【问题标题】:Pytorch crossentropy loss with 3d input带有 3d 输入的 Pytorch 交叉熵损失
【发布时间】:2020-12-18 06:29:16
【问题描述】:

我有一个输出大小为(batch_size, max_len, num_classes) 的 3D 张量的网络。我的真实情况是(batch_size, max_len)。如果我确实对标签执行 one-hot 编码,它的形状将是(batch_size, max_len, num_classes),即max_len 中的值是[0, num_classes] 范围内的整数。由于原代码太长,我写了一个更简单的版本,重现原错误。

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)

pred和label的形状分别是torch.Size([32, 350, 1000])torch.Size([32, 350])

遇到的错误是

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

如果我用 one-hot 编码标签来计算损失

x = nn.functional.one_hot(label)
criterion(pred, x)

它会抛出以下错误

ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])

【问题讨论】:

    标签: python neural-network pytorch cross-entropy


    【解决方案1】:

    Pytorch documentationCrossEntropyLoss 期望其输入的形状为(N, C, ...),因此第二个维度始终是类的数量。如果您将 preds 重塑为大小为 (batch_size, num_classes, max_len),您的代码应该可以工作。

    【讨论】:

    • 即使对标签进行one-hot编码,传递给CrossEntropyLoss时也会抛出错误
    • 对不起,我想我找到了问题并更新了我的答案。
    猜你喜欢
    • 2021-08-25
    • 2020-08-13
    • 2018-04-14
    • 2019-11-02
    • 2020-07-16
    • 2021-01-21
    • 2021-01-02
    • 2022-01-09
    • 2020-03-14
    相关资源
    最近更新 更多