【发布时间】:2020-05-21 08:47:05
【问题描述】:
我遇到了以下问题:通过使用@tf.function,我想沿定义的组件取消堆叠张量。
@tf.function
def f1(x):
y = tf.unstack(x)
return y
@tf.function
def f2(x):
y = tf.unstack(x, axis=0)
return y
@tf.function
def f3(x):
y = tf.unstack(x, axis=1)
return y
x = tf.random.uniform((4,2))
y1 = tf.unstack(x, axis=0) #f2
y2 = tf.unstack(x, axis=1) #f3
y = f1(x) # No problem! (output equal to y1)
z = f2(x) #Problem!
zz = f3(x) #Problem
TypeError:在用户代码中:
<ipython-input-339-c5b8c0b032bb>:8 f2 *
y = tf.unstack(x, axis=0)
TypeError: 'set' object is not callable
不确定是由于我对 AutoGraph 和 @tf.function 的无知还是其他原因造成的。如果有人能让我了解发生了什么,将不胜感激:-)
【问题讨论】:
-
您的代码在 TensorFlow 2.2.0 上运行良好。
-
天哪!你说的对;我从 jupyter-notebook 运行。非常感谢!
标签: python tensorflow machine-learning tensorflow2.0