【发布时间】:2020-12-18 06:29:16
【问题描述】:
我有一个输出大小为(batch_size, max_len, num_classes) 的 3D 张量的网络。我的真实情况是(batch_size, max_len)。如果我确实对标签执行 one-hot 编码,它的形状将是(batch_size, max_len, num_classes),即max_len 中的值是[0, num_classes] 范围内的整数。由于原代码太长,我写了一个更简单的版本,重现原错误。
criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)
pred和label的形状分别是torch.Size([32, 350, 1000])和torch.Size([32, 350])
遇到的错误是
ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])
如果我用 one-hot 编码标签来计算损失
x = nn.functional.one_hot(label)
criterion(pred, x)
它会抛出以下错误
ValueError: Expected target size (32, 1000), got torch.Size([32, 350, 1000])
【问题讨论】:
标签: python neural-network pytorch cross-entropy