【问题标题】:How can I create a custom keras optimizer?如何创建自定义 keras 优化器?
【发布时间】:2021-03-15 22:44:50
【问题描述】:

我正在比较 SVRG、SAG 和其他优化器在深度学习最小化方面的性能。

如何使用 keras 实现自定义优化器,我尝试在 source code 处查看 SGD keras 实现,但找不到 tf.raw_ops.ResourceApplyGradientDescent 的源代码,这使得很难为另一个优化器重现。

【问题讨论】:

  • 我认为你的意思是自定义优化器,装扮的优化器会是别的东西,就像 cosplay 中的优化器 :)

标签: tensorflow keras optimization tf.keras


【解决方案1】:

自定义优化器:

  • 扩展tf.keras.optimizers.Optimizer
  • 覆盖_create_slots:用于为每个可训练变量创建优化器变量。如果您需要为优化器添加动力,这将非常有用。
  • 覆盖 _resource_apply_dense_resource_apply_sparse 以执行优化器的实际更新和方程。
  • get_config(可选):存储您传递给优化器的参数,以便您可以在之后克隆或保存您的模型。

这里是一个简单的 SGD 示例,动量取自 here

class MyMomentumOptimizer(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.001, momentum=0.9, name="MyMomentumOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
        self._set_hyper("decay", self._initial_decay) # 
        self._set_hyper("momentum", momentum)
    
    def _create_slots(self, var_list):
        """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots".
        For momentum optimization, we need one momentum slot per model variable.
        """
        for var in var_list:
            self.add_slot(var, "momentum")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Update the slots and perform one optimization step for one model variable
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype) # handle learning rate decay
        momentum_var = self.get_slot(var, "momentum")
        momentum_hyper = self._get_hyper("momentum", var_dtype)
        momentum_var.assign(momentum_var * momentum_hyper - (1. - momentum_hyper)* grad)
        var.assign_add(momentum_var * lr_t)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }

【讨论】:

  • 谢谢!但是我仍然对 _resource_apply_dense 方法有一些问题。以前,当我必须在与神经网络不同的框架中编写优化算法时,我必须循环一定次数的迭代来更新参数,我觉得我们在这里只进行了 1 次更新。
  • 很高兴它有帮助,_resource_apply_dense 在每个批次上运行,您所说的循环是在 keras 的 fit 函数中实现的,而不是在优化器中。
猜你喜欢
  • 2020-03-05
  • 1970-01-01
  • 1970-01-01
  • 2018-11-26
  • 1970-01-01
  • 2019-07-17
  • 2015-09-07
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多