【问题标题】:TensorFlow: slow performance when getting gradients at inputsTensorFlow:在输入处获取梯度时性能缓慢
【发布时间】:2016-03-27 08:44:20
【问题描述】:

我正在使用 TensorFlow 构建一个简单的多层感知器,我还需要获取神经网络输入处损失的梯度(或误差信号)。

这是我的代码,它有效:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
...
for i in range(epochs):
    ....
    for batch in batches:
        ...
        sess.run(optimizer, feed_dict=feed_dict)
        grads_wrt_input = sess.run(tf.gradients(cost, self.x), feed_dict=feed_dict)[0]

(已编辑以包括训练循环)

没有最后一行 (grads_wrt_input...),这在 CUDA 机器上运行得非常快。但是,tf.gradients() 将性能大大降低了十倍或更多。

我记得节点处的错误信号在反向传播算法中被计算为中间值,我已经使用 Java 库 DeepLearning4j 成功地完成了这项工作。我也有这样的印象,这将是对optimizer 已经构建的计算图的轻微修改。

如何使这更快,或者有没有其他方法可以计算损失 w.r.t 的梯度。输入?

【问题讨论】:

  • 你真的在训练循环中调用tf.gradients()吗?如果是这样,我怀疑开销来自每次调用它时构建反向传播图?
  • 为了清楚起见,我已经包含了训练循环代码;是的,我在训练循环中调用tf.gradients()。程序逐渐变慢。我应该怎么做才能防止这种建筑开销?
  • 在循环外调用 tf.gradients 来构建一次梯度计算图。您还可以使用 compute_gradients 重用为优化器制作的梯度图

标签: tensorflow


【解决方案1】:

tf.gradients() 函数每次调用都会构建一个新的反向传播图,因此速度变慢的原因是 TensorFlow 必须在循环的每次迭代中解析一个新图。 (这可能会非常昂贵:当前版本的 TensorFlow 已针对多次执行 same 图进行了优化。)

幸运的是,解决方案很简单:只需在循环外计算一次梯度。您可以按如下方式重组您的代码:

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.network, self.y))
optimizer = tf.train.AdagradOptimizer(learning_rate=nn_learning_rate).minimize(cost)
grads_wrt_input_tensor = tf.gradients(cost, self.x)[0]
# ...
for i in range(epochs):
    # ...
    for batch in batches:
        # ...
        _, grads_wrt_input = sess.run([optimizer, grads_wrt_input_tensor],
                                      feed_dict=feed_dict)

请注意,为了提高性能,我还合并了两个 sess.run() 调用。这确保了前向传播和大部分反向传播将被重用。


顺便说一句,发现此类性能错误的一个技巧是在开始训练循环之前致电tf.get_default_graph().finalize()。如果您无意中将任何节点添加到图表中,这将引发异常,从而更容易追踪这些错误的原因。

【讨论】:

  • 成功了,谢谢!我的程序现在很快。顺便说一句,我认为sess.run() 调用列表中的grads_wrt_input 应该是grads_wrt_input_tensor
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2015-11-09
  • 1970-01-01
  • 2019-06-18
  • 1970-01-01
  • 2017-08-14
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多