【问题标题】:Pytorch running_mean, running_var and num_batches_tracked are updated during training, but I want to fix themPytorch running_mean、running_var 和 num_batches_tracked 在训练期间更新,但我想修复它们
【发布时间】:2022-01-12 12:18:18
【问题描述】:

在pytorch中,我想使用一个预训练的模型,训练我的模型给模型结果添加一个delta,即:

        ╭----- (pretrained model) ------ result ---╮
 input------------- (my model) --------- Δresult --+-- final_result

这是我所做的:

  1. 使用load_state_dict 加载预训练模型的参数
  2. 设置所有预训练模型的参数requires_grad = False
  3. 创建我的模型并开始训练

但是在训练过程之后,当我检查result(预训练模型的输出)时,我发现它与原始预训练模型输出不匹配。我仔细比较了预训练模型的参数,唯一的变化是BatchNorm2drunning_meanrunning_varnum_batches_tracked(因为我设置了所有预训练模型的参数requires_grad = False),当我把这三个参数改回原始的,result 匹配原始预训练模型输出。

我不希望对预训练模型进行任何更改。那么有没有办法修复running_meanrunning_varnum_batches_tracked

【问题讨论】:

    标签: python pytorch pre-trained-model batch-normalization


    【解决方案1】:

    我偶然发现了同样的问题,因此我将this repo 中的上下文管理器修改如下:

    @contextlib.contextmanager
    def _disable_tracking_bn_stats(self):
        def switch_attr():
            if not hasattr(self, 'running_stats_modules'):
                self.running_stats_modules = \
                    [mod for n, mod in self.model.named_modules() if
                     hasattr(mod, 'track_running_stats')]
    
            for mod in self.running_stats_modules:
                mod.track_running_stats ^= True
    
        switch_attr()
        yield
        switch_attr()
    

    作为替代方案,我认为您可以通过在 BatchNorm 模块上调用 eval 来获得类似的结果:

    for layer in net.modules():
        if isinstance(layer, BatchNorm2d):
            layer.eval()
    

    虽然第一种方法更有原则。

    【讨论】:

    • 非常感谢您的回答!我解决了我的问题!
    • 那么请把答案标记为正确:)
    猜你喜欢
    • 2020-08-02
    • 2018-06-22
    • 2021-03-28
    • 2019-07-15
    • 1970-01-01
    • 2021-05-09
    • 2019-12-07
    • 2019-10-24
    • 2019-03-10
    相关资源
    最近更新 更多