【Posted】: 2021-06-11 16:46:40
【Question】:
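I need to compute the gradients of tf.Variable objects in one class method, but use those gradients later, in a different method, to update the variables. This works when I do not use the @tf.function decorator, but with @tf.function I get TypeError: An op outside of the function building code is being passed a "Graph" tensor. I have searched for an explanation of this error and a way to resolve it, but have come up short.
FYI, in case you are curious: I want to do this because I have variables that appear in many different equations. Rather than trying to create a single equation that relates all of the variables, it is easier (and computationally cheaper) to keep the equations separate, compute the gradients for each equation at a given moment, and then apply the updates incrementally. I recognize that the two approaches are not mathematically identical.
Here is my code (a minimal example), followed by the results and the error message. Note that there is no error when the gradients are computed and used to update the variables within a single method, .iterate().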
import tensorflow as tf

class Example():

    def __init__(self, x, y, target, lr=0.01):
        self.x = x
        self.y = y
        self.target = target
        self.lr = lr
        self.variables = [self.x, self.y]

    @tf.function
    def iterate(self):
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2
        self.gradients = tape.gradient(loss, self.variables)
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)

    @tf.function
    def compute_update(self):
        with tf.GradientTape(persistent=False) as tape:
            loss = (self.target - self.x * self.y)**2
        self.gradients = tape.gradient(loss, self.variables)

    @tf.function
    def apply_update(self):
        for g, v in zip(self.gradients, self.variables):
            v.assign_add(-self.lr * g)

x = tf.Variable(1.)
y = tf.Variable(3.)
target = tf.Variable(5.)
example = Example(x, y, target)

# Compute and apply updates in a single tf.function method
example.iterate()
print('')
print(example.variables)
print('')

# Compute and apply updates in separate tf.function methods
example.compute_update()
example.apply_update()
print('')
print(example.variables)
print('')
Output:
$ python temp_bug.py
[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.12>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=3.04>]
Traceback (most recent call last):
  File "temp_bug.py", line 47, in <module>
    example.apply_update()
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 650, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 75, in quick_execute
    raise e
  File "/home/mroos/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: gradient_tape/mul/Mul:0
【Discussion】:
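A likely reading of the error, going by the tensor name gradient_tape/mul/Mul:0: inside compute_update(), tape.gradient() returns symbolic tensors that belong to that function's graph, and storing them on self lets them leak out of the trace. apply_update() is traced as a separate function and then tries to consume tensors from another graph. In iterate() the gradients are produced and consumed within the same trace, so nothing leaks. The sketch below is one possible workaround, not a confirmed fix from the original thread: it keeps the two-method layout but stores the gradients in tf.Variable buffers created in __init__. The buffer list and its trainable=False setting are assumptions introduced here for illustration.

```python
import tensorflow as tf


class Example:
    def __init__(self, x, y, target, lr=0.01):
        self.x = x
        self.y = y
        self.target = target
        self.lr = lr
        self.variables = [self.x, self.y]
        # Assumption: one non-trainable buffer per variable holds the most
        # recent gradient. Variables are state that can be shared between
        # separately traced tf.functions, unlike symbolic graph tensors.
        self.gradients = [
            tf.Variable(tf.zeros_like(v), trainable=False)
            for v in self.variables
        ]

    @tf.function
    def compute_update(self):
        with tf.GradientTape() as tape:
            loss = (self.target - self.x * self.y) ** 2
        grads = tape.gradient(loss, self.variables)
        # Write the values into the buffers instead of rebinding a Python
        # attribute to tensors that only exist inside this trace.
        for buf, g in zip(self.gradients, grads):
            buf.assign(g)

    @tf.function
    def apply_update(self):
        # Read the buffered gradients and take one gradient-descent step.
        for buf, v in zip(self.gradients, self.variables):
            v.assign_sub(self.lr * buf)


example = Example(tf.Variable(1.0), tf.Variable(3.0), tf.Variable(5.0))
example.compute_update()
example.apply_update()
print([v.numpy() for v in example.variables])
```

Starting from the same initial values as the question, this should leave x near 1.12 and y near 3.04, matching the result of the single-method .iterate() call. A variation would be to return the gradients from compute_update() and pass them to apply_update() as arguments; values returned from a tf.function come back as eager tensors, so they can be fed into another tf.function without the graph mismatch.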
Tags: python tensorflow gradient tensor eager-execution