【问题标题】:Batchnormalization over which dimension?在哪个维度上进行批量标准化?
【发布时间】:2021-08-17 14:53:46
【问题描述】:

我们在哪个维度上计算均值和标准差?是在NN层的隐藏维度上,还是在批次中的所有样本上分别针对每个隐藏维度?

在论文中它说我们对批次进行标准化。

torch.nn.BatchNorm1d 但输入参数是num_features,这对我来说没有意义。

为什么 pytorch 不遵循关于 Batchnormalization 的原始论文?

【问题讨论】:

    标签: pytorch batch-normalization


    【解决方案1】:

    我们在哪个维度上计算平均值和标准差?

    超过0th 维度,对于1D 形状(batch, num_features) 的输入,它将是:

    batch = 64
    features = 12
    data = torch.randn(batch, features)
    
    mean = torch.mean(data, dim=0)
    var = torch.var(data, dim=0)
    

    在 torch.nn.BatchNorm1d 中,输入参数是“num_features”, 这对我来说毫无意义。

    它与归一化无关,而是通过gammabeta 可学习参数对meanvar 进行重新参数化。来自论文:

    scale 和 shift phase 中使用的参数都是num_features 的形状,因此我们必须传递这个值才能将它们初始化为特定的形状。

    下面是一个从头实现的示例供参考:

    class BatchNorm1d(torch.nn.Module):
        def __init__(self, num_features, momentum: float = 0.9, eps: float = 1e-7):
            super().__init__()
            self.num_features = num_features
    
            self.gamma = torch.nn.Parameter(torch.ones(1, self.num_features))
            self.beta = torch.nn.Parameter(torch.zeros(1, self.num_features))
            
            self.register_buffer("running_mean", torch.ones(1, self.num_features))
            self.register_buffer("running_var", torch.ones(1, self.num_features))
    
            self.momentum = momentum
            self.eps = eps
    
        def forward(self, X):
            if not self.training:
                X_hat = X - self.running_mean / torch.sqrt(self.running_var + self.eps)
            else:
                mean = X.mean(dim=0).unsqueeze(dim=0)
                var = ((X - mean) ** 2).mean(dim=0).unsqueeze(dim=0)
    
                # Update running mean and variance
                self.running_mean *= self.momentum
                self.running_mean += (1 - self.momentum) * mean
    
                self.running_var *= self.momentum
                self.running_var += (1 - self.momentum) * var
    
                X_hat = X - mean / torch.sqrt(var + self.eps)
    
            return X_hat * self.gamma + self.beta
    

    为什么 pytorch 不遵循关于 Batchnormalization 的原始论文?

    一目了然

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-05-23
      • 2017-10-21
      • 2017-03-03
      • 1970-01-01
      • 1970-01-01
      • 2018-04-19
      • 2018-04-09
      • 2019-08-21
      相关资源
      最近更新 更多