sys.exc_info() 返回一个包含三个值 (type, value, traceback) 的元组。
- 这里
type获取正在处理的异常的异常类型
-
value 是传递给异常类的构造函数的参数。
-
traceback 包含堆栈信息,例如发生异常的位置等。
在 GradientTape 上下文中,当异常发生时 sys.exc_info() 详细信息将传递给 exit() 函数,该函数将 Exits the recording context, no further operations are traced。
以下是说明相同的示例。
让我们考虑一个简单的函数。
def f(w1, w2):
return 3 * w1 ** 2 + 2 * w1 * w2
不使用with 声明:
w1, w2 = tf.Variable(5.), tf.Variable(3.)
tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
exec_tup = sys.exc_info()
tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])
打印:
GradientTape.gradient 只能在非持久性磁带上调用一次。
即使你没有通过传值显式退出,程序也会通过这些值来退出 GradientTaoe 录制,下面是示例。
w1, w2 = tf.Variable(5.), tf.Variable(3.)
tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
打印相同的异常消息。
通过使用with 声明。
with tf.GradientTape() as tape:
z = f(w1, w2)
dz_dw1 = tape.gradient(z, w1)
try:
dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
print(ex)
exec_tup = sys.exc_info()
tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])
下面是上述异常的sys.exc_info() 响应。
(RuntimeError,
RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),
<traceback at 0x7fcd42dd4208>)
编辑 1:
正如评论中提到的user2357112 supports Monica。提供非异常情况的解决方案。
在非异常情况下,规范要求传递给__exit__ 的值都应为None。
示例 1:
x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z = x*x
dy_dx = g.gradient(y, x)
# dz_dx = g.gradient(z, x)
print(dy_dx)
# print(dz_dx)
打印:
tf.Tensor(6.0, shape=(), dtype=float32)
由于y 在__exit__ 之前被捕获,它返回渐变值。
示例 2:
x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z = x*x
# dy_dx = g.gradient(y, x)
dz_dx = g.gradient(z, x)
# print(dy_dx)
print(dz_dx)
打印:
None
这是因为z 在__exit__ 之后被捕获,因此渐变停止记录。