【问题标题】:Finding required placeholders in Tensorflow graph在 TensorFlow 图中查找所需的占位符
【发布时间】:2018-01-30 09:30:14
【问题描述】:

在 Tensorflow 中,有没有办法找到评估某个输出张量所需的所有占位符张量?也就是说,当调用sess.run(output_tensor) 时,是否有一个函数将返回所有必须输入feed_dict 的(占位符)张量?

这是我想做的一个例子,用伪代码:

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32,shape=())
b = tf.placeholder(dtype=tf.float32,shape=())
c = tf.placeholder(dtype=tf.float32,shape=())
d = a + b
f = b + c

# This should return [a,b] or [a.name,b.name]
d_input_tensors = get_dependencies(d)

# This should return [b,c] or [b.name,c.name]
f_input_tensors = get_dependencies(f)

编辑:为了澄清,我不是(必然)寻找图中的所有占位符,只是定义特定输出张量所需的占位符。所需的占位符可能只是图中所有占位符的一个子集。

【问题讨论】:

  • 为了获取图中的所有占位符,有一个答案:stackoverflow.com/a/44371483/4834515。至于获取依赖项......不知道。
  • @Seven 我只想获取依赖项,而不是所有占位符。我将编辑我的问题以澄清。

标签: python tensorflow


【解决方案1】:

经过一番修补和发现 this 几乎相同的 SO 问题后,我想出了以下解决方案:

def get_tensor_dependencies(tensor):

    # If a tensor is passed in, get its op
    try:
        tensor_op = tensor.op
    except:
        tensor_op = tensor

    # Recursively analyze inputs
    dependencies = []
    for inp in tensor_op.inputs:
        new_d = get_tensor_dependencies(inp)
        non_repeated = [d for d in new_d if d not in dependencies]
        dependencies = [*dependencies, *non_repeated]

    # If we've reached the "end", return the op's name
    if len(tensor_op.inputs) == 0:
        dependencies = [tensor_op.name]

    # Return a list of tensor op names
    return dependencies

注意:这不仅会返回占位符,还会返回变量和常量。如果dependencies = [tensor_op.name]dependencies = [tensor_op.name] if tensor_op.type == 'Placeholder' else [] 替换,那么只会返回占位符。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-10-30
    • 1970-01-01
    • 2017-10-12
    • 2018-04-23
    相关资源
    最近更新 更多