【发布时间】:2018-10-23 18:52:06
【问题描述】:
我正在查看一些运行 tf.while_loop() 的 Tensorflow 代码,我有一个问题。该代码计算 4 次多项式的根,作为学习实验室的一部分 tensorflow. 我的问题是为什么特定的打印语句不会输出所有中间值。
代码如下:
import tensorflow as tf
def f(x, w):
return (w[0] + w[1] * x + w[2] * tf.pow(x,2) + w[3] * tf.pow(x,3) + w[4] * tf.pow(x,4) )
def f1(x, w):
return (w[1] + 2. * w[2] * x + 3. * w[3] * tf.pow(x,2) + 4. * w[4] * tf.pow(x,3) )
def f2(x, w):
return (2. * w[2] + 6. * w[3] * x + 12. * w[4] * tf.pow(x,2) )
def fxn_plus_1(xn, w):
return (xn - (2. * f(xn, w) * f1(xn, w) / (2. * tf.square(f1(xn, w)) - f(xn, w) * f2(xn, w))))
def c(x, weights):
return tf.abs(x - fxn_plus_1(x, weights)) > 0.001
def b(x, weights):
x = fxn_plus_1(x, weights)
return x, weights
weights = tf.constant( [-1000., 1., 1. , 5. , 0.1])
x = fxn_plus_1(-10., weights)
out = tf.while_loop(c, b, [x, weights])
with tf.Session() as sess:
x, weights = sess.run(out)
print(x)
输出正确,值为5.575055。现在,我想看看随着算法的进行,循环体b() 的中间值是什么。我将函数b() 更改为以下内容:
def b(x, weights):
x = fxn_plus_1(x, weights)
print(x) ## ADDED PRINT STATEMENT
print(weights)
return x, weights
我得到的回报是:
Tensor("while_4/sub_4:0", shape=(), dtype=float32)
Tensor("while_4/Identity_1:0", shape=(5,), dtype=float32)
5.575055
这似乎给出了x,weights 值的调试输出或图形信息,而不是实际值。我不确定如何让循环在每个步骤中实际打印值。
有什么建议吗? 谢谢。
关于尝试和建议的事情的更新:
其中一位评论者@user49593 建议我尝试使用tf.Print()。
这是代码和输出。我仍然只是获取图形信息而不是实际的值向量。
def b(x, weights):
x = fxn_plus_1(x, weights)
x = tf.Print(x, [x], message="here: ") #CHANGED TO tf.Print STATEMENT
return x, weights
输出仍然只是5.575055。没有中间值向量。
【问题讨论】:
-
@user49593 感谢您的检查,但我认为这是一个不同的问题。我阅读了链接,它建议使用
eval()方法。但是,如果您正在评估with tf.Session() as sess:内的函数,则使用它。在 OP 中,关键是b()函数中的 print 语句是在会话上下文之前设置的,但该 print 语句是从会话上下文中调用的。我尝试使用eval()并得到一个错误。 -
您尝试过使用
tf.Print()吗? “要打印张量的值而不将其返回到 Python 程序,可以使用 tf.Print() 运算符” -
我确实尝试过
tf.Print(),但它仍然只是给我图形输出而不是实际值。我可以使用tf.Print()时得到的输出来更新 OP。是的,这只是令人困惑,为什么这个问题看起来如此令人困惑。我习惯了你在尝试在sess.run()中调试时提到的技巧,但这些是在sess.run()中调用的函数。 -
这不是你使用
tf.Print()的方式。做x = tf.Print(x, [x], message="here: ")Using tf.Print blog post
标签: python tensorflow