【发布时间】:2022-01-22 13:03:42
【问题描述】:
我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的浮点数。
Pytorch 中的torch.nn.CrossEntropy 不支持软标签,所以我正在尝试自己编写一个交叉熵函数。
我的函数是这样的
def cross_entropy(self, pred, target):
loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
return loss
def step(self, batch: Any):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
preds = logits
# torch.argmax(logits, dim=1)
return loss, preds, y
但它根本不起作用。
谁能给我一个建议,我的损失函数有什么错误吗?
【问题讨论】:
-
请勿发布代码截图。将代码粘贴到您的问题中并正确格式化
-
当然,我已经添加了
标签: pytorch cross-entropy