【问题标题】:How to print the values of tensors inside a while loop?如何在while循环中打印张量的值?
【发布时间】:2017-02-24 16:21:27
【问题描述】:

我对 tensorflow 很陌生,我无法弄清楚这一点。

我有这个while循环:

def process_tree_tf(n_child, reprs, weights, bias, embed_dim, activation = tf.nn.relu):
    n_child, reprs = n_child, reprs
    parent_idxs = generate_parents_numpy(n_child)
    loop_idx = reprs.shape[0] - 1
    loop_vars = loop_idx, reprs, parent_idxs, weights, embed_dim

    def  loop_condition(loop_ind, *_):
        return tf.greater(0, loop_idx)

    def loop_body(loop_ind, reprs, parent_idxs, weights, embed_dim):
        x = reprs[loop_ind]
        x_expanded = tf.expand_dims(x, axis=-1)
        w = weights
        out = tf.squeeze(tf.add(tf.matmul(x_expanded,w,transpose_a=True), bias))
        activated = activation(out)
        par_idx = parent_idxs[loop_ind]
        reprs = update_parent(reprs, par_idx, embed_dim, activated)
        reprs = tf.Print(reprs, [reprs]) #This doesn't work
        loop_ind = loop_ind-1
        return loop_ind, reprs, parent_idxs, weights, embed_dim

    return tf.while_loop(loop_condition, loop_body, loop_vars)

我是这样评价的:

embed_dim = 2
hidden_dim = 2
n_nodes = 4
batch = 2
reprs = np.ones((n_nodes, embed_dim+hidden_dim))
n_child = np.array([1, 1, 1, 0])
weights = np.ones((embed_dim+hidden_dim, hidden_dim))
bias = np.ones(hidden_dim)
with tf.Session() as sess:
    _, r, *_ = process_tree_tf(n_child, reprs,  weights, bias, embed_dim, activation=tf.nn.relu)
    print(r.eval())

我想在 while 循环中检查 reprs 的值,但 tf.Print 似乎不起作用,print 只是告诉我这是一个张量并给了我它的形状。 我该怎么做?

非常感谢!

【问题讨论】:

    标签: python-3.x tensorflow


    【解决方案1】:

    看看这个网页:https://www.tensorflow.org/api_docs/python/tf/Print

    您可以看到 tf.Print 是一个身份运算符,在评估时具有打印数据的副作用。因此,您应该使用这一行来打印:

    reprs = tf.Print(reprs, [reprs])

    希望这会有所帮助,祝你好运!

    【讨论】:

    • 感谢您的回答!我试过你说的,但也没有用。我尝试与 tf.control_dependencies 一起使用,但也失败了。
    • 你能说一下输出是什么吗: print(reprs) reprs_out = tf.Print(reprs, [reprs]) print(reprs_out)
    • 这是我得到的:Tensor("while/update_parent:0", shape=(4, 4), dtype=float64) Tensor("while/Print:0", shape=(4, 4), dtype=float64)
    • 你可以使用 tensorboard 来调试你的图表吗?如果不打印,可能是接线错误。
    • 我不熟悉tensorboard,但我会尝试一下。如果我提到我在尝试评估循环内的 reprs 时遇到此错误会有所帮助吗? ValueError: Operation 'while/update_parent' has been marked as not fetchable.?
    【解决方案2】:

    rmeertens 建议的方法是我认为正确的方法。我只想添加(作为对您的 cmets 的响应),如果某些内容正在打印“Tensor("while/update_parent:0, ......”,那么这意味着图中的该值没有被评估。

    您可能会将其视为“print(r.eval())”语句的输出,而不是 tf.Print() 语句的输出。

    请注意,tf.Print() 的输出在 PyCharm(我正在使用的 IDE)中显示为红色,而普通 python 打印操作的输出显示为黑色。所以 tf.Print() 输出看起来像一条警告消息。可能它确实正在打印出来,但您只是忽略了它。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2019-04-28
      • 2018-06-05
      • 1970-01-01
      • 1970-01-01
      • 2011-03-17
      • 2021-12-16
      • 2023-01-11
      相关资源
      最近更新 更多