【发布时间】:2020-06-26 04:30:55
【问题描述】:
我定义我的神经网络
class Classifier(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 64)
self.fc4 = nn.Linear(64, 10)
def forward(self, x):
# make sure input tensor is flattened
x = x.view(x.shape[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = F.log_softmax(self.fc4(x), dim=1)
return x
model = Classifier()
我训练神经网络
我保存训练好的神经网络:
checkpoint = {'input_size': 784,
'output_size': 10,
'hidden_layers': [256, 128, 64],
'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
state_dict = torch.load('checkpoint.pth')
当我尝试加载保存的神经网络时,出现错误
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
# I believe the error is in the line directly below
model_b = model(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers'])
model_b.load_state_dict(checkpoint['state_dict'])
return model_b
model_b = load_checkpoint('checkpoint.pth')
我收到以下错误:
TypeError: forward() takes 2 positional arguments but 4 were given
【问题讨论】:
-
总是将完整的错误消息(从单词“Traceback”开始)作为文本(不是屏幕截图)放在有问题的(不是评论)中。还有其他有用的信息。
-
显示完整的错误信息 - 它应该显示出问题的确切位置。也许普通的
nn.Model有函数forward(),它有4 个参数,它在某些地方使用它,但你用forward()替换它,它只有2 个参数。
标签: python neural-network