【问题标题】:What are the parameters to tf.GradientTape()'s __exit__ function?tf.GradientTape() 的 __exit__ 函数的参数是什么?
【发布时间】:2020-06-09 13:43:05
【问题描述】:

根据tf.GradientTapedocumentation,其__exit__() 方法采用三个位置参数:typ, value, traceback

这些参数究竟是什么?

with 语句如何推断它们?

我应该在下面的代码中给它们什么值(我不是使用with语句):

x = tf.Variable(5)

gt = tf.GradientTape()
gt.__enter__()
y = x ** 2
gt.__exit__(typ = __, value = __, traceback = __)

【问题讨论】:

标签: python tensorflow oop with-statement automatic-differentiation


【解决方案1】:

sys.exc_info() 返回一个包含三个值 (type, value, traceback) 的元组。

  1. 这里type获取正在处理的异常的异常类型
  2. value 是传递给异常类的构造函数的参数。
  3. traceback 包含堆栈信息,例如发生异常的位置等。

在 GradientTape 上下文中,当异常发生时 sys.exc_info() 详细信息将传递给 exit() 函数,该函数将 Exits the recording context, no further operations are traced

以下是说明相同的示例。

让我们考虑一个简单的函数。

def f(w1, w2):
    return 3 * w1 ** 2 + 2 * w1 * w2

不使用with 声明:

w1, w2 = tf.Variable(5.), tf.Variable(3.)

tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)
    exec_tup = sys.exc_info()
    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

打印:

GradientTape.gradient 只能在非持久性磁带上调用一次。

即使你没有通过传值显式退出,程序也会通过这些值来退出 GradientTaoe 录制,下面是示例。

w1, w2 = tf.Variable(5.), tf.Variable(3.)

tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)

打印相同的异常消息。

通过使用with 声明。

with tf.GradientTape() as tape:
    z = f(w1, w2)

dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)
    exec_tup = sys.exc_info()
    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

下面是上述异常的sys.exc_info() 响应。

(RuntimeError,
 RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),
 <traceback at 0x7fcd42dd4208>)

编辑 1:

正如评论中提到的user2357112 supports Monica。提供非异常情况的解决方案。

在非异常情况下,规范要求传递给__exit__ 的值都应为None

示例 1:

x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z  = x*x
dy_dx = g.gradient(y, x) 
# dz_dx = g.gradient(z, x) 
print(dy_dx)
# print(dz_dx)

打印:

tf.Tensor(6.0, shape=(), dtype=float32) 

由于y__exit__ 之前被捕获,它返回渐变值。

示例 2:

x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z  = x*x
# dy_dx = g.gradient(y, x) 
dz_dx = g.gradient(z, x) 
# print(dy_dx)
print(dz_dx)

打印:

None 

这是因为z__exit__ 之后被捕获,因此渐变停止记录。

【讨论】:

  • @Abhimanyu Pallavi Sudhir - 如果您认为我已经回答了您的问题,请接受并投票。
  • 您的答案省略了极其重要的非异常情况。
  • @user2357112 支持莫妮卡 我已经修改了非异常情况的答案,谢谢。
  • 在非异常情况下,规范要求传递给__exit__的值都应该是None
  • @user2357112 支持莫妮卡,感谢您的指出,我已根据您的评论通过None进行了更改。
猜你喜欢
  • 2020-03-07
  • 2014-11-04
  • 2018-12-25
  • 2014-08-18
  • 2017-10-19
  • 2018-10-20
  • 1970-01-01
  • 2022-11-11
  • 1970-01-01
相关资源
最近更新 更多