【发布时间】:2017-12-08 13:48:55
【问题描述】:
我有一个网络在 5D 输入张量上执行 3D 卷积。我的网络的输出大小(1、12、60、36、60)对应于(BatchSize、NumClasses、x-dim、y-dim、z-dim)。我需要计算体素交叉熵损失。但是我不断收到错误。
当尝试使用 torch.nn.CrossEntropyLoss() 计算交叉熵损失时,我不断收到以下错误消息:
RuntimeError: multi-target not supported at .../src/THCUNN/generic/ClassNLLCriterion.cu:16
这是我的代码摘录:
import torch
import torch.nn as nn
from torch.autograd import Variable
criterion = torch.nn.CrossEntropyLoss()
images = Variable(torch.randn(1, 12, 60, 36, 60)).cuda()
labels = Variable(torch.zeros(1, 12, 60, 36, 60).random_(2)).long().cuda()
loss = criterion(images.view(1,-1), labels.view(1,-1))
当我为标签创建一个单热张量时也会发生同样的情况:
nclasses = 12
labels = (np.random.randint(0,12,(1,60,36,60))) # Random labels with values between [0..11]
labels = (np.arange(nclasses) == labels[..., None] - 1).astype(int) # Converts labels to one_hot_tensor
a = np.transpose(labels,(0,4,3,2,1)) # Reorder dimensions to match shape of "images" ([1, 12, 60, 36, 60])
b = Variable(torch.from_numpy(a)).cuda()
loss = criterion(images.view(1,-1), b.view(1,-1))
知道我做错了什么吗? 有人可以提供一个在 5D 输出张量上计算交叉熵的示例吗?
【问题讨论】:
标签: pytorch