【问题标题】:Cross Entropy for Soft Labeling in PytorchPytorch 中软标签的交叉熵
【发布时间】:2022-01-22 13:03:42
【问题描述】:

我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的浮点数。

Pytorch 中的torch.nn.CrossEntropy 不支持软标签,所以我正在尝试自己编写一个交叉熵函数。

我的函数是这样的

def cross_entropy(self, pred, target):
    loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
    return loss

def step(self, batch: Any):
    x, y = batch
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = logits
    # torch.argmax(logits, dim=1)
    return loss, preds, y

但它根本不起作用。

谁能给我一个建议,我的损失函数有什么错误吗?

【问题讨论】:

  • 请勿发布代码截图。将代码粘贴到您的问题中并正确格式化
  • 当然,我已经添加了

标签: pytorch cross-entropy


【解决方案1】:

似乎BCELoss 和强大的版本BCEWithLogitsLoss 正在“开箱即用”地处理模糊目标。他们不希望 target 是二进制的”任何介于 0 和 1 之间的数字都可以。
请阅读文档。

【讨论】:

    猜你喜欢
    • 2021-10-24
    • 2021-08-25
    • 2018-04-14
    • 2020-07-16
    • 2019-11-02
    • 2020-06-04
    • 1970-01-01
    • 2018-09-03
    • 2021-01-21
    相关资源
    最近更新 更多