【问题标题】:What are saved in optimizer's state_dict? what "state","param_groups" stands for?优化器的 state_dict 中保存了什么? "state","param_groups" 代表什么?
【发布时间】:2020-06-08 11:25:46
【问题描述】:

当我们使用 Adam 优化器时,如果我们想从预训练模型继续训练网络,我们不仅应该加载“model.state_dict”,还应该加载“optimizer.state_dict”。而且,如果我们修改了网络的结构,我们还应该修改保存的优化器的 state_dict 以使我们的加载成功。

但我不明白保存的“optimizer.state_dict”中的一些参数。像 optim_dict["state"] (dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])) 和 optim_dict['param_groups'][0]['params']。有很多这样的数字:

 b['optimizer_state_dict']['state'].keys()
Out[71]: dict_keys([140623218628000, 140623218628072, 140623218628216, 140623218628360, 140623218628720, 140623218628792, 140623218628936, 140623218629080, 140623218629656, 140623218629728, 140623218629872, 140623218630016, 140623218630376, 140623218630448, 140623218716744, 140623218716816, 140623218717392, 140623218717464, 140623218717608, 140623218717752, 140623218718112, 140623218718184, 140623218718328, 140623218718472, 140623218719048, 140623218719120, 140623218719264, 140623218719408, 140623218719768, 140623218719840, 140623218719984, 140623218720128, 140623218720704, 140623209943112, 140623209943256, 140623209943400, 140623209943760, 140623209943832, 140623209943976, 140623209944120, 140623209944696, 140623209944768, 140623209944912, 140623209945056, 140623209945416, 140623209945488, 140623209945632, 140623209945776, 140623209946352, 140623209946424, 140623209946568, 140623209946712, 140623209947072, 140623210041416, 140623210041560, 140623210041704, 140623244033768, 140623244033840, 140623244033696, 140623244033912, 140623244033984, 140623244070984, 140623244071056, 140623244071128, 140623429501576, 140623244071200, 140623244071272, 140623244071344, 140623244071416, 140623244071488, 140623244071560, 140623244071632, 140623244071848, 140623244071920, 140623244072064, 140623244072208, 140623244072424, 140623244072496, 140623244072640, 140623244072784, 140623244073216, 140623244073288, 140623244073432, 140623244073576, 140623244073792, 140623244073864, 140623244074008, 140623244074152, 140623244074584, 140623244074656, 140623244074800, 140623244074944, 140623218540760, 140623218540832, 140623218540976, 140623218541120, 140623218541552, 140623218541624, 140623218541768, 140623218541912, 140623218542128, 140623218542200, 140623218542344, 140623218542488, 140623218542920, 140623218542992, 140623218543136, 140623218543280, 140623218543496, 140623218543568, 140623218543712, 140623218543856, 140623218544288, 140623218544360, 140623218544504, 140623218626632, 140623218626992, 140623218627064, 140623218627208, 140623218627352, 140623218627784, 140623218629440, 140623218717176, 140623218718832, 140623218720488, 140623209944480, 140623209946136, 140623210043000])

In [44]: b['optimizer_state_dict']['state'][140623218628072].keys()
Out[44]: dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])

In [45]: b['optimizer_state_dict']['state'][140623218628072]['exp_avg'].shape
Out[45]: torch.Size([480])

【问题讨论】:

    标签: deep-learning pytorch


    【解决方案1】:

    与保存可学习参数的模型state_dict 相比,优化器的state_dict 包含有关优化器状态(要优化的参数)以及使用的超参数的信息。


    PyTorch 中的所有优化器都需要从基类torch.optim.Optimizer 继承。它需要两个条目:

    • params (iterable)——torch.Tensors 或 dicts 的可迭代对象。指定应该优化哪些张量。
    • defaults (dict)dict 包含优化选项的默认值(在参数组未指定它们时使用)。

    除此之外,优化器还支持指定每个参数的选项。

    为此,不要传递Tensors 的迭代,而是传递dicts 的迭代。它们中的每一个都将定义一个单独的参数组,并且应该包含一个 params 键,其中包含属于它的参数列表。

    举个例子,

    optim.SGD([
                    {'params': model.base.parameters()},
                    {'params': model.classifier.parameters(), 'lr': 1e-3}
                ], lr=1e-2, momentum=0.9)
    

    在这里,我们提供了 a) params、b) 默认超参数:lrmomentum 和 c) 参数组。在这种情况下,model.base 的参数将使用默认的 1e-2 学习率,model.classifier 的参数将使用 1e-3 的学习率,所有参数将使用 0.9 的动量。


    step (optimizer.step()) 执行单个优化步骤(参数更新),这会改变优化器的状态。


    现在,来到优化器的state_dict,它将优化器的状态返回为dict。它包含两个条目:

    • state - 一个 dict 保持当前优化状态。
    • param_groups - 一个包含所有参数组的dict(如上所述)

    某些超参数特定于使用的优化器或模型,例如(用于亚当)

    • exp_avg:梯度值的指数移动平均值
    • exp_avg_sq:平方梯度值的指数移动平均值

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-03-06
      • 1970-01-01
      • 2015-09-15
      • 2011-05-03
      相关资源
      最近更新 更多