【问题标题】:Setting learning rate for Stochastic Weight Averaging in PyTorch在 PyTorch 中设置随机权重平均的学习率
【发布时间】:2021-10-13 23:41:27
【问题描述】:

以下是 Pytorch 中随机权重平均的小工作代码,取自 here

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

    # Update bn statistics for the swa_model at the end
    torch.optim.swa_utils.update_bn(loader, swa_model)
    # Use swa_model to make predictions on test data
    preds = swa_model(test_input)

在第 160 个 epoch 之后的代码中,swa_scheduler 被使用,而不是通常的 schedulerswa_lr 是什么意思? documentation 说,

通常,在 SWA 中,学习率设置为较高的常数值。 SWALR 是一个学习率调度器,它将学习率退火到一个固定值,然后保持不变。

  1. 那么在第 160 个 epoch 之后,optimizer 的学习率会发生什么变化?
  2. swa_lr 会影响optimizer 的学习率吗?

假设在代码的开头optimizerADAM 初始化为1e-4 的学习率。那么上面的代码是否暗示对于前 160 个 epoch,训练的学习率将是 1e-4,然后对于剩余的 epoch 数,它将是 swa_lr=0.05?如果是,将swa_lr 也定义为1e-4 是否是个好主意?

【问题讨论】:

    标签: python machine-learning optimization pytorch


    【解决方案1】:
    • 上面的代码是否暗示对于前 160 个 epoch 的训练学习率将是 1e-4

      不,它不等于 1e-4,在前 160 个 epoch 期间,学习率由第一个调度程序 scheduler 管理。这是一个初始化为torch.optim.lr_scheduler.CosineAnnealingLR。学习率将遵循这条曲线:


    • 对于剩余的 epoch 数,它将是 swa_lr=0.05

      这是部分正确的,在第二部分 - 从 epoch 160 - 优化器的学习率将由第二个调度器 swa_scheduler 处理。这个被初始化为torch.optim.swa_utils.SWALR。您可以在文档页面上阅读:

      SWALR 是一个学习率调度程序,它将学习率退火到一个固定值 [swa_lr],然后保持不变

      默认情况下(参见source code),退火前的 epoch 数等于 10。因此,从 epoch 170 到 epoch 300 的学习率将等于 swa_lr 并将保持这种状态。第二部分将是:

      这个完整的个人资料,两部分:


    • 如果是,将swa_lr 也定义为1e-4 是否是个好主意

      文档中有提到:

      通常,在 SWA 中,学习率设置为较高的常数值。

      swa_lr 设置为1e-4 将产生以下学习率配置文件:

    【讨论】:

    • 谢谢。还有一个问题。保存 SWA 模型是否可行:torch.save({'model': swa_model.state_dict()},os.path.join('weights_{}'.format(epoch)))?
    猜你喜欢
    • 2020-09-13
    • 1970-01-01
    • 2021-05-09
    • 2020-11-16
    • 2017-10-04
    • 1970-01-01
    • 2019-09-03
    • 2020-03-19
    • 1970-01-01
    相关资源
    最近更新 更多