【问题标题】:Resume training with Adam optimizer in Keras在 Keras 中使用 Adam 优化器恢复训练
【发布时间】:2020-05-21 10:09:47
【问题描述】:

我的问题很简单,但我在网上找不到明确的答案(到目前为止)。

我保存了一个使用 adam 优化器训练的 keras 模型的权重,该模型在经过定义数量的训练后使用:

callback = tf.keras.callbacks.ModelCheckpoint(filepath=path, save_weights_only=True)
model.fit(X,y,callbacks=[callback])

当我关闭jupyter后恢复训练时,我可以简单地使用:

model.load_weights(path)

继续训练。

由于 Adam 依赖于 epoch 数(例如在学习率衰减的情况下),我想知道在与以前相同的条件下恢复训练的最简单方法。

根据 ibarrond 的回答,我编写了一个小的自定义回调。

optim = tf.keras.optimizers.Adam()
model.compile(optimizer=optim, loss='categorical_crossentropy',metrics=['accuracy'])

weight_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1, save_best_only=False)

class optim_callback(tf.keras.callbacks.Callback):
    '''Custom callback to save optimiser state'''

          def on_epoch_end(self,epoch,logs=None):
                optim_state = tf.keras.optimizers.Adam.get_config(optim)
                with open(optim_state_pkl,'wb') as f_out:                  
                       pickle.dump(optim_state,f_out)

model.fit(X,y,callbacks=[weight_callback,optim_callback()])

当我恢复训练时:

model.load_weights(checkpoint_path)
with open(optim_state_pkl,'rb') as f_out:                  
                    optim_state = pickle.load(f_out)
tf.keras.optimizers.Adam.from_config(optim_state)

我只是想检查一下这是否正确。再次感谢!

附录:进一步阅读Adam的默认Keras implementationoriginal Adam paper,我认为默认的Adam不依赖于epoch数,而只依赖于迭代次数。因此,这是不必要的。但是,对于希望跟踪其他优化器的任何人来说,该代码可能仍然有用。

【问题讨论】:

  • 为什么这不是 Adam 不依赖于时代编号的必要条件?纪元数和迭代数的相关性是否不够,以至于您在恢复训练时仍想跟踪它们? Epoch = iteration * batch_size ,因为batch_size(大多数时候)是不变的,我猜它们都同样重要。
  • @ibarrond 我之所以相信 Adam 不会在 epochs 中发生变化是因为从 get_config 打印出 dict 显示了 Adam 在 Keras 中的原始默认配置。因此,我认为它保持不变。然而,你确实让我觉得也许还有更多。仅在每个 epoch 内更改 lr 并在下一个 epoch 重新启动回默认 lr 是没有意义的。但是我在 keras 中找不到任何关于如何在 epoch 之间跟踪迭代的文档。如果是这样,get_config 似乎并没有给我关于优化器状态的所有信息。

标签: python tensorflow keras adam


【解决方案1】:

为了完美捕获优化器的状态,您应该使用函数get_config() 存储其配置。此函数返回一个字典(包含选项),可以使用pickle 将其序列化并存储在一个文件中。

要重新启动该过程,只需d = pickle.load('my_saved_tfconf.txt') 使用配置检索字典,然后使用Keras Adam Optimizer. 的函数from_config(d) 生成您的亚当优化器

【讨论】:

  • 非常感谢这个提示,听起来像是一个更合理的方法:)
  • 我已经添加了小的自定义代码,如果可能的话,想知道它是否正确实现,抱歉给您带来麻烦!
猜你喜欢
  • 2019-10-18
  • 1970-01-01
  • 2018-07-04
  • 1970-01-01
  • 2019-08-27
  • 1970-01-01
  • 2017-07-08
  • 2020-05-18
  • 2020-08-25
相关资源
最近更新 更多