【发布时间】:2018-01-30 09:30:14
【问题描述】:
在 Tensorflow 中,有没有办法找到评估某个输出张量所需的所有占位符张量?也就是说,当调用sess.run(output_tensor) 时,是否有一个函数将返回所有必须输入feed_dict 的(占位符)张量?
这是我想做的一个例子,用伪代码:
import tensorflow as tf
a = tf.placeholder(dtype=tf.float32,shape=())
b = tf.placeholder(dtype=tf.float32,shape=())
c = tf.placeholder(dtype=tf.float32,shape=())
d = a + b
f = b + c
# This should return [a,b] or [a.name,b.name]
d_input_tensors = get_dependencies(d)
# This should return [b,c] or [b.name,c.name]
f_input_tensors = get_dependencies(f)
编辑:为了澄清,我不是(必然)寻找图中的所有占位符,只是定义特定输出张量所需的占位符。所需的占位符可能只是图中所有占位符的一个子集。
【问题讨论】:
-
为了获取图中的所有占位符,有一个答案:stackoverflow.com/a/44371483/4834515。至于获取依赖项......不知道。
-
@Seven 我只想获取依赖项,而不是所有占位符。我将编辑我的问题以澄清。
标签: python tensorflow