【发布时间】:2021-05-06 03:05:09
【问题描述】:
我对如何实现 tfa 的 SWA optimizer 感到困惑。这里有两点:
- 当您查看文档时,它会将您指向 [this] 模型平均教程。该教程使用 tfa.callbacks.AverageModelCheckpoint,它允许您
- 将移动平均权重分配给模型并保存。
- (或)保留旧的非平均权重,但保存的模型使用平均权重。
拥有一个独特的 ModelCheckpoint 可以让您保存移动平均权重(而不是当前权重)是有意义的。但是 - 似乎 SWA 应该管理权重平均。这让我想设置update_weights=False。
这是正确的吗?本教程使用update_weights=True。
- documentation 中有一条关于 SWA 未更新 BN 层的说明。按照here 的建议,我这样做了,
# original training
model.fit(...)
# updating weights from final run
optimizer.assign_average_vars(model.variables)
# batch-norm-hack: lr=0 as suggested https://stackoverflow.com/a/64376062/607528
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0),
loss=loss,
metrics=metrics)
model.fit(
data,
validation_data=None,
epochs=1,
callbacks=final_callbacks)
在保存我的模型之前。
这对吗?
谢谢!
【问题讨论】:
标签: python tensorflow machine-learning deep-learning tensorflow2.0