大多数浮点运算都会有梯度,因此第一次通过的答案就是检查图中是否没有 int32/int64 dtype 张量。这很容易做到,但可能没有用(即任何重要的模型都将执行不可微分的索引操作)。
您可以进行某种类型的自省,循环遍历 GraphDef 中的操作并检查它们是否注册了渐变。我认为这也不是非常有用。如果我们不相信梯度首先被注册,为什么要相信它们在注册后是正确的?
相反,我会在您的模型的几个点上进行数值梯度检查。例如,假设我们注册了一个没有渐变的 PyFunc:
import tensorflow as tf
import numpy
def my_func(x):
return numpy.sinh(x)
with tf.Graph().as_default():
inp = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [inp], tf.float32) + inp
grad, = tf.gradients(y, inp)
with tf.Session() as session:
print(session.run([y, grad], feed_dict={inp: 3}))
print("Gradient error:", tf.test.compute_gradient_error(inp, [], y, []))
这让我得到如下输出:
[13.017875, 1.0]
Gradient error: 1.10916996002
数值梯度可能有点棘手,但通常任何比机器 epsilon(float32 约为 1e-7)多几个数量级的梯度误差都会对我提出所谓的平滑函数的危险信号。