【问题标题】:PyTorch: _thnn_nll_loss_forward is not implemented for type torch.LongTensorPyTorch:_thnn_nll_loss_forward 未针对类型 torch.LongTensor 实现
【发布时间】:2019-09-18 17:11:25
【问题描述】:

当我尝试使用 PyTorch 创建模型时,当我尝试实现损失函数 nll_loss 时,它会抛出以下错误

RuntimeError: _thnn_nll_loss_forward is not implemented for type torch.LongTensor 

我创建的拟合函数是:

for epoch in tqdm_notebook(range(1, epochs+1)):
    for batch_idx, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        net.float()
        output = net(data)
        output_x = output.argmax(dim=2) #to convert (64,50,43) -> (64, 50)
        loss = F.nll_loss(output_x, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train epochs: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx*len(data), len(ds.data),
                100.*batch_idx / len(ds), loss.item()
            ))

输出和目标的形状是 (64, 50) 并且两者的 dtype 都是 torch.int64

【问题讨论】:

    标签: python machine-learning computer-vision pytorch loss-function


    【解决方案1】:

    查看descriptionF.nll_loss。它期望获得的输入不是预测的argmax(类型torch.long),而是完整的64x50x43 预测向量(类型torch.float)。请注意,您提供给F.nll_loss 的预测确实比您提供的地面实况目标具有额外的维度。

    在您的情况下,只需删除 argmax:

    loss = F.nll_loss(output, targets)
    

    【讨论】:

    • 我已经尝试删除 argmax,然后我收到此错误 RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'
    • @thanatoz 似乎您将net 转换为float,但由于某种原因,您的output 类型为torch.double
    • 你有什么建议。我应该如何解决这个问题?我发现nn.GRU(input_size=100, hidden_size=50, dropout=0.5, bidirectional=True, num_layers=2, batch_first=True)对此负责。
    • @thanatoz 确保您的data 和隐藏状态都是torch.float 类型。请注意,x.to(torch.float)x.float() 不是就地操作,您需要 x = x.to(torch.float) 才能将 x 设为浮点数。
    【解决方案2】:

    看起来您正在处理具有43 类的分类任务,使用的批量大小为64,“序列长度”为50

    如果是这样,我相信您对使用argmax()F.log_softmax 有点困惑。正如 Shai 给出的参考,鉴于 output 是 logit 值,您可以使用:

    output_x = F.log_softmax(output, dim=2)
    loss = F.nll_loss(output_x, targets)
    

    这是使用nll_loss的正确方法,或者如果你不想做log_softmax 你自己,你可以改用nn.CrossEntropyLoss

    【讨论】:

    • 是的,你没看错。我确实尝试了您的建议,但遇到了以下错误RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1' 。我做错了什么?
    • 好的,谢谢大卫。
    猜你喜欢
    • 2019-12-26
    • 2019-03-26
    • 2021-06-09
    • 1970-01-01
    • 2019-01-11
    • 2018-12-29
    • 2020-02-29
    • 2021-06-05
    • 1970-01-01
    相关资源
    最近更新 更多