【问题标题】:What is running loss in PyTorch and how is it calculatedPyTorch 中的运行损失是什么以及如何计算
【发布时间】:2020-07-20 09:11:15
【问题描述】:

我查看了 PyTorch 文档中的 this 教程以了解迁移学习。有一条线我没听懂。

使用loss = criterion(outputs, labels)计算loss后,使用running_loss += loss.item() * inputs.size(0)计算running loss,最后使用running_loss / dataset_sizes[phase]计算epoch loss。

loss.item() 不应该用于整个小批量(如果我错了,请纠正我)。即,如果 batch_size 是 4,loss.item() 将给出整个 4 图像集的损失。如果这是真的,为什么loss.item() 在计算running_loss 时要与inputs.size(0) 相乘?在这种情况下,这一步是不是像一个额外的乘法?

任何帮助将不胜感激。谢谢!

【问题讨论】:

    标签: python deep-learning pytorch torch torchvision


    【解决方案1】:

    这是因为CrossEntropy或其他损失函数给出的损失除以元素个数即reduction参数默认为mean

    torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

    因此,loss.item() 包含整个小批量的损失,但除以批量大小。这就是为什么loss.item() 乘以由inputs.size(0) 给出的批量大小,同时计算running_loss

    【讨论】:

      【解决方案2】:

      如果 batch_size 为 4,loss.item() 将给出整个 4 张图像集的损失

      这取决于loss 的计算方式。请记住,loss 和其他张量一样是张量。一般来说,PyTorch API 默认返回 avg loss

      “损失是在每个小批量的观察中平均的。”

      t.item() 用于张量 t 只是将其转换为 python 的默认 float32。

      更重要的是,如果您是 PyTorch 的新手,知道我们使用 t.item() 来维持运行损失而不是 t 可能会对您有所帮助,因为 PyTorch 张量存储其值的历史记录,这可能会使您的 GPU 非常过载很快。

      【讨论】:

        猜你喜欢
        • 2020-12-23
        • 2020-04-30
        • 2022-01-09
        • 1970-01-01
        • 1970-01-01
        • 2021-01-23
        • 2021-08-08
        • 1970-01-01
        • 2011-03-21
        相关资源
        最近更新 更多