【问题标题】:Load/test previously trained and saved neural network Python加载/测试先前训练和保存的神经网络 Python
【发布时间】:2020-06-26 04:30:55
【问题描述】:

我定义我的神经网络


class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        
    def forward(self, x):
        # make sure input tensor is flattened
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.log_softmax(self.fc4(x), dim=1)
        
        return x

model = Classifier()

我训练神经网络

我保存训练好的神经网络:

checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [256, 128, 64],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')
state_dict = torch.load('checkpoint.pth')

当我尝试加载保存的神经网络时,出现错误

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    # I believe the error is in the line directly below
    model_b = model(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers'])
    model_b.load_state_dict(checkpoint['state_dict'])
    return model_b

model_b = load_checkpoint('checkpoint.pth')

我收到以下错误:

TypeError: forward() takes 2 positional arguments but 4 were given

【问题讨论】:

  • 总是将完整的错误消息(从单词“Traceback”开始)作为文本(不是屏幕截图)放在有问题的(不是评论)中。还有其他有用的信息。
  • 显示完整的错误信息 - 它应该显示出问题的确切位置。也许普通的nn.Model 有函数forward(),它有4 个参数,它在某些地方使用它,但你用forward() 替换它,它只有2 个参数。

标签: python neural-network


【解决方案1】:

我认为你遗漏了几点:

  • 您的 __init__ 类函数没有参数,您的神经网络具有固定参数,因此您不能使用 dict 对象的其他键来创建具有相同参数的模型。
  • nn.Module 函数有一个名为__call__ 的方法,该方法重定向到forward 方法。每当您运行 Object(something) 时都会运行此函数,其中某些内容将是函数参数。在load_checkpoint 中,您运行了model_b = model(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers'])。您尝试使用字典中的一些元素进行前向传递。因此错误(4个参数是modelcheckpoint['input_size']checkpoint['output_size']checkpoint['hidden_layers'])。

要解决加载模型的问题,只需删除这行model_b = model(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers']),我认为它应该可以工作。

如果您希望使用检查点 input_size、output_size 和隐藏层创建模型,您应该在构造函数中使用这些参数: model = Classifier(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers'])。您的代码需要进行一些更改才能使其正常工作。

【讨论】:

    猜你喜欢
    • 2018-10-30
    • 2020-05-20
    • 1970-01-01
    • 2019-07-23
    • 1970-01-01
    • 2020-09-10
    • 2013-05-25
    • 2014-02-24
    • 2015-06-04
    相关资源
    最近更新 更多