【问题标题】:RMSE loss for multi output regression problem in PyTorchPyTorch 中多输出回归问题的 RMSE 损失
【发布时间】:2020-09-11 09:07:39
【问题描述】:

我正在训练一个 CNN 架构来使用 PyTorch 解决回归问题,其中我的输出是 20 个值的张量。我计划使用 RMSE 作为模型的损失函数,并尝试使用 PyTorch 的 nn.MSELoss() 并使用 torch.sqrt() 为其取平方根,但在获得结果后感到困惑。我会尽力解释原因.很明显,对于批量大小bs,我的输出张量的尺寸将是[bs , 20]。我尝试实现自己的 RMSE 函数:

   def loss_function (predicted_x , target ):
        loss = torch.sum(torch.square(predicted_x - target) , axis= 1)/(predicted_x.size()[1]) #Taking the mean of all the squares by dividing it with the number of outputs i.e 20 in my case
        loss = torch.sqrt(loss)
        loss = torch.sum(loss)/predicted_x.size()[0]  #averaging out by batch-size
        return loss

但是我的 loss_function() 的输出和 PyTorch 如何使用 nn.MSELoss() 实现它不同。我不确定我的实现是错误的还是我以错误的方式使用了nn.MSELoss()

【问题讨论】:

    标签: python deep-learning pytorch artificial-intelligence loss-function


    【解决方案1】:

    MSE 损失是误差平方均值。您在计算 MSE 后取平方根,因此无法将损失函数的输出与 PyTorch nn.MSELoss() 函数的输出进行比较——它们正在计算不同的值。

    但是,您可以使用 nn.MSELoss() 创建自己的 RMSE 损失函数:

    loss_fn = nn.MSELoss()
    RMSE_loss = torch.sqrt(loss_fn(prediction, target))
    RMSE_loss.backward()
    

    希望对您有所帮助。

    【讨论】:

      【解决方案2】:

      要复制默认 PyTorch 的 MSE(均方误差)损失函数,您需要将 loss_function 方法更改为以下内容:

      def loss_function (predicted_x , target ):
          loss = torch.sum(torch.square(predicted_x - target) , axis= 1)/(predicted_x.size()[1])
          loss = torch.sum(loss)/loss.shape[0]
          return loss
      

      这就是上述方法有效的原因 - MSE 损失意味着均方误差损失。因此,您不必在代码中实现平方根 (torch.sqrt)。默认情况下,PyTorch 中的损失对批处理中的所有示例进行平均以计算损失。因此方法中的第二行。

      要实施 RMSELoss 并将其集成到您的训练中,您可以这样做:

      class RMSELoss(torch.nn.Module):
          def __init__(self):
              super(RMSELoss,self).__init__()
      
          def forward(self,x,y):
              criterion = nn.MSELoss()
              loss = torch.sqrt(criterion(x, y))
              return loss
      

      你可以调用这个类,类似于 PyTorch 中的任何损失函数。

      【讨论】:

      • 我确实使用torch.sqrt()nn.MSELoss() 来获得RMSE。我已经在我的问题中更新了这一点。但是问题是,在我的原始损失函数中,如果我在对批量大小进行平均后取平方根,它符合 PyTorch 的版本,但是如果我之前应用它,就像我在我的帖子中所做的那样,它会出现不同
      猜你喜欢
      • 2020-05-13
      • 2019-06-23
      • 2021-06-21
      • 2019-07-24
      • 2020-11-01
      • 1970-01-01
      • 2018-12-14
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多