【发布时间】:2020-12-29 23:03:05
【问题描述】:
我正在尝试查找使用 PyTorch 创建的模型的准确性,但出现错误。最初我有一个不同的错误,已修复,但现在我得到了这个错误。
我用它来获取我的测试集:
testset = torchvision.datasets.FashionMNIST(MNIST_DIR, train=False,
download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
torchvision.transforms.ToTensor(), # image to Tensor
torchvision.transforms.Normalize((0.1307,), (0.3081,)) # image, label
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False)
当我尝试访问我创建的测试集时,它会出于某种原因尝试重新训练模型,然后继续出错。 这是获取准确率并调用测试集的代码
correct = 0
total = 0
with torch.no_grad():
print("entered here")
for (x, y_gt) in testloader:
x = x.to(device)
y_gt = y_gt.to(device)
outputs = teacher_model(x)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
这是我得到的错误:
Traceback (most recent call last):
File "[path]/train_teacher_1.py", line 134, in <module>
outputs = teacher_model(x)
File "[path]\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "[path]\models.py", line 17, in forward
x = F.relu(self.layer1(x))
File "[path]\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "[path]\anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "[path]\anaconda3\lib\site-packages\torch\nn\functional.py", line 1692, in linear
output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
如果您想要训练模型的其余代码,请告诉我。因为帖子太长,我把它省略了。
我是 PyTorch 的新手,感谢任何帮助。提前致谢。
【问题讨论】:
-
据我所知,该错误与您发布的代码无关。
-
@hkchengrex 你知道如果不是代码可能会导致问题吗?
标签: python machine-learning pytorch