【问题标题】:How to override gradient vector calculation method for optimization algos in Keras, Tensorflow?如何覆盖 Keras、Tensorflow 中优化算法的梯度向量计算方法?
【发布时间】:2020-11-08 19:52:18
【问题描述】:

所以我正在尝试修改 Keras 中的几个优化算法,即 Adam 或只是 SGD。因此,默认情况下,我很确定参数更新的工作方式是对批处理中的数据点进行平均损失,然后根据该损失值计算梯度向量。另一种思考方式是对批次中每个数据点的损失值的梯度进行平均。这是我想要改变的计算,它会很昂贵,所以我试图在使用 GPU 的优化框架内进行。

因此,对于每个批次,我需要针对批次中每个数据点的损失计算梯度,然后我不会取梯度的平均值,而是做一些其他的平均值或计算。有谁知道我将如何访问以覆盖 Adam 或 SGD 的此功能?

在发表了很棒的评论后,我发现应该有一种方法可以使用GradientTape 中的jacobian 方法来做我想做的事情。但是文档不是那么彻底,我无法弄清楚它如何适应整体情况。在这里我希望有人可以帮助我调整代码以使用jacobian 而不是gradient

作为一个 hello world 示例,我试图用一些使用 jacobian 的代码简单地替换 gradient 行并产生相同的输出。这将说明如何使用jacobian 方法以及与gradient 方法的输出的连接。

工作代码

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars) # <-- line to change
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

【问题讨论】:

  • 不确定这是否是你想要的,所以发表评论——假设你使用GradientTape,你可以使用jacobian方法(而不是gradient)来获得单独的渐变每个批处理元素,然后对它们做任何你想做的事情。
  • @xdurch0 非常感谢,我正在深入研究内部方法,并且已经看到带有gradient 的 GradientTape 对象。好的,我将研究雅可比行列式。现在查看文档中的该方法,它确实返回了渐变 w.r.t。每个数据点。太好了,谢谢!
  • @xdurch0 我想知道gradient 方法和jacobian 方法之间的确切联系?我在问题中发布了一些工作代码,作为第一个“hello world”,我试图简单地删除带有gradient 的行并使用jacobian 将其替换为一些代码。你觉得你能帮我解决这个问题吗?我没有任何运气,我也很难找到一份好的文档来准确说明我应该从jacobian 的输出中得到什么。
  • 是的,我稍后会在答案中发布一些代码。
  • @xdurch0 更多地查看文档,似乎我看到的行为可能是仅从损失函数返回平均损失的结果(这是默认值),我正在阅读损失函数可以实际上从批次中的数据点返回所有损失。使用梯度函数的损失数组可能会返回所有梯度。

标签: python tensorflow keras deep-learning


【解决方案1】:

您应该能够做到以下几点:

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)

        new_gradients = []
        for grad in gradients:
            new_grad = do_something_to(grad)
            new_gradients.append(new_grad)

        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

一些重要注意事项:compiled_loss 函数返回的 loss 不能在批处理轴上平均,即我假设它是形状为 (batch_size, ) 的张量,不是标量.
这将导致 jacobian 返回形状为(batch_size, ) + variable_shape 的渐变,也就是说,您现在拥有每批元素的渐变。您现在可以随心所欲地操纵这些梯度,并且当然应该在某些时候摆脱额外的批处理轴(例如平均)。也就是说,new_grad 应该与对应的变量具有相同的形状。

关于您的最后一条评论:正如我所提到的,损失函数确实需要为每个数据点返回一个损失,即不能对批次进行平均。但是,这还不够,因为如果您将此向量提供给tape.gradient,梯度函数将简单地对损失值求和(因为它仅适用于标量)。这就是为什么需要jacobian

最后,jacobian 可能会很慢。在最坏的情况下,运行时间可能会乘以批量大小,因为它需要计算那么多单独的梯度。但是,这在某种程度上是并行完成的,因此减速可能没有那么严重。

【讨论】:

  • 谢谢你,这太好了,这就是我的想法。我试过了,现在我在fit 行上收到了这个错误。我昨天遇到了这个并正在查找它,但无法弄清楚问题是什么。我是TF 2.2.0版。错误是:UnrecognizedFlagError: Unknown command line flag 'f'.
  • 似乎是一个可以通过更改版本来修复的错误。我要试试。
  • 好的,知道了,谢谢!是的,我升级到 nightly ( 2.5.~ ),并且能够在批处理轴上使用 K.mean 重现我在第一个工作代码版本中看到的内容。
猜你喜欢
  • 1970-01-01
  • 2014-01-11
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-05-22
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多