【Question title】: PyTorch: creating a model from load_state_dict
【Posted】: 2021-10-31 22:43:12
【Question】:

I am trying to learn how to save and load trained models in PyTorch, but so far I only get errors. Consider the following self-contained code:

import torch
lin=torch.nn.Linear; act=torch.nn.ReLU(); fnc=torch.nn.functional;
class Ann(torch.nn.Module): 
   def __init__(self):
      super(Ann, self).__init__() 
      self.conv1 = torch.nn.Conv2d( 1, 10, kernel_size=5) 
      self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=4) 
      self.drop = torch.nn.Dropout2d(p=0.5) 
      self.fc1 = torch.nn.Linear(320,128)  
      self.fc2 = torch.nn.Linear(128,10)   
   def forward(self, x): 
      x = self.conv1(x[:,None,:,:]);        
      x = fnc.relu(fnc.max_pool2d(x,2));    
      x = self.drop(self.conv2(x));         
      x = fnc.relu(fnc.max_pool2d(x,2));    
      x = torch.flatten(x,1);               
      x = fnc.relu(self.fc1(x));            
      x = fnc.dropout(self.fc2(x),training=self.training);   
      return fnc.log_softmax(x,dim=1)  # normalize over the class dimension (dim=1), not over the batch
x,y=torch.rand((5,28,28)),torch.randint(0,9,(5,)); 
f=fnc.nll_loss;   
ann1 = torch.nn.Sequential( torch.nn.Flatten(start_dim=1), 
   lin(784,256), act, lin(256,128), act, lin(128,10), torch.nn.LogSoftmax(dim=1)) 
ann2=Ann()
F1 = torch.optim.SGD(ann1.parameters(),lr=0.01,momentum=0.5)
F2 = torch.optim.SGD(ann2.parameters(),lr=0.01,momentum=0.5)

F1.zero_grad(); y_=ann1(x); loss=f(y_,y); loss.backward(); F1.step()
print(x.dtype,y.dtype,x.shape,y.shape,y_.shape,loss); 
F2.zero_grad(); y_=ann2(x); loss=f(y_,y); loss.backward(); F2.step()
print(x.dtype,y.dtype,x.shape,y.shape,y_.shape,loss); 
name='/home/leon/'

#ann3 = ann1.__class__().load_state_dict(ann1.state_dict());  print(ann3(x)) #outputs errors
#ann4 = ann2.__class__().load_state_dict(ann2.state_dict());  print(ann4(x)) #outputs errors
torch.save( [ann1.state_dict(),F1.state_dict()], name+'annF1.pth'); 
torch.save( [ann2.state_dict(),F2.state_dict()], name+'annF2.pth'); 
a1,d1=torch.load(name+'annF1.pth')
a2,d2=torch.load(name+'annF2.pth') #so far, works as expected
ann3, F3 = ann1.__class__().load_state_dict(a1), F1.__class__().load_state_dict(d1) #outputs errors
ann4, F4 = ann2.__class__().load_state_dict(a2), F2.__class__().load_state_dict(d2) #outputs errors

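As you can see, ann1 and ann2 work, since they produce valid output. However, (re)constructing the models ann3 and ann4 from a given state_dict() always fails, with the following two errors respectively: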

Unexpected key(s) in state_dict: "1.weight", "1.bias", "3.weight", "3.bias", "5.weight", "5.bias". 
TypeError: '_IncompatibleKeys' object is not callable

Can anyone tell me how to correctly build a model from given parameters, so that I can later export and import my trained models?

【Comments】:

    Tags: python import neural-network pytorch export


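    【Answer 1】:

    Hey, you have two problems:

    1. Remove .__class__()
    2. Separate the definitions of ann3 and ann4 from the load_state_dict calls. load_state_dict updates the model in place and returns a key-mismatch report (_IncompatibleKeys), not the model.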
    ann1.load_state_dict(ann1.state_dict())
    ann3 = ann1
    print(ann3(x))
    ann2.load_state_dict(ann2.state_dict())
    ann4 = ann2
    print(ann4(x))
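Both error messages from the question can be reproduced in a few lines. The following is a minimal sketch, not the asker's code: `model` and `empty` are hypothetical stand-ins for ann1 and ann1.__class__(). It assumes the default strict=True behavior of load_state_dict.

```python
import torch

# Small stand-in model (hypothetical; any nn.Module behaves the same way).
model = torch.nn.Sequential(
    torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2)
)

# load_state_dict copies the tensors into `model` in place and returns a
# report of key mismatches, not the model itself.
result = model.load_state_dict(model.state_dict())
print(type(result).__name__)
print(result.missing_keys, result.unexpected_keys)

# Sequential() with no arguments has no layers, so every saved key is
# "unexpected" and the strict load raises a RuntimeError.
empty = torch.nn.Sequential()
try:
    empty.load_state_dict(model.state_dict())
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", str(err).splitlines()[0])
```

This would explain why ann3 fails with "Unexpected key(s)" (the class is rebuilt without its layers) and why ann4 fails with "'_IncompatibleKeys' object is not callable" (the return value of load_state_dict was assigned instead of the model).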
    

    But what is the purpose of ann1.__class__().load_state_dict(ann1.state_dict())?

    Maybe you want to do this instead?

    ann3 = torch.nn.Sequential( torch.nn.Flatten(start_dim=1), 
       lin(784,256), act, lin(256,128), act, lin(128,10), torch.nn.LogSoftmax(dim=1))
    ann3.load_state_dict(ann1.state_dict())
    print(ann3(x))
    
    ann4 = Ann()
    ann4.load_state_dict(ann2.state_dict())
    print(ann4(x))
    

    This works the same way as the PyTorch guide "Saving & Loading Model for Inference": create a new model with the same architecture, then load the saved/existing state_dict into it.

    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
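The same pattern extends to resuming training, where the optimizer state is needed too. The sketch below is one possible way to do it, under stated assumptions: `make_model` and `make_optimizer` are hypothetical helper functions that rebuild the architecture and optimizer from the question, the checkpoint keys "model" and "optimizer" are arbitrary names, and a temporary directory stands in for a real path.

```python
import os
import tempfile
import torch

def make_model():
    # Hypothetical factory: rebuilds the same architecture as ann1 in the question.
    return torch.nn.Sequential(
        torch.nn.Flatten(start_dim=1),
        torch.nn.Linear(784, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 128), torch.nn.ReLU(),
        torch.nn.Linear(128, 10), torch.nn.LogSoftmax(dim=1),
    )

def make_optimizer(model):
    # SGD needs the parameters it will update, so it cannot be
    # constructed with no arguments the way F1.__class__() tries to.
    return torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

x, y = torch.rand(5, 28, 28), torch.randint(0, 10, (5,))

# Train one step so the optimizer has momentum buffers worth saving.
model = make_model()
opt = make_optimizer(model)
opt.zero_grad()
torch.nn.functional.nll_loss(model(x), y).backward()
opt.step()

# Save both state dicts in one file.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({"model": model.state_dict(), "optimizer": opt.state_dict()}, path)

# Later: rebuild the objects first, then load the saved state into them.
checkpoint = torch.load(path)
model2 = make_model()
model2.load_state_dict(checkpoint["model"])
opt2 = make_optimizer(model2)
opt2.load_state_dict(checkpoint["optimizer"])

# Compare the two models on the same input.
model.eval()
model2.eval()
with torch.no_grad():
    same = torch.equal(model(x), model2(x))
print(same)
```

The key design point is that a state_dict holds only tensors and hyperparameters, not the architecture, so both the model and the optimizer have to be constructed in code before their state is loaded.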
    

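    【Discussion】:

    • No, I wanted to make sure that load_state_dict works, so that the next step would be to use torch.save and torch.load. I have a script in which I construct ann, train it for a few epochs, and then export it. Later, I want to run the same script, import my model, and continue training. I have read that it is recommended to import/export only ann.state_dict() and F.state_dict(), and to reconstruct ann and F from that data. It is exactly this reconstruction that I am having trouble with. I will edit my question accordingly and hope you can answer it.
    • Doesn't ann4 = Ann(); ann4.load_state_dict(ann2.state_dict()) work for you?
    • OK, thanks. Is there a way to create a new object of the same class as ann2 using only ann2? ann4 = copy.deepcopy(ann2); ann4.load_state_dict(ann2.state_dict()) ?
    • Your last comment was the answer I was looking for. Thanks!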
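The two ways of getting a second object "from ann2 alone" mentioned in the comments can be compared in a short sketch. `TinyNet` is a hypothetical stand-in for Ann. Note that copy.deepcopy already carries the weights over, so a following load_state_dict call would be redundant, and that the type(net)() route assumes the constructor takes no arguments.

```python
import copy
import torch

class TinyNet(torch.nn.Module):
    # Hypothetical stand-in for Ann; its __init__ takes no arguments.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()

# Option 1: deepcopy duplicates the module together with its current weights.
clone = copy.deepcopy(net)

# Option 2: build a fresh instance of the same class, then load the weights.
# This only works when the class constructor needs no arguments.
fresh = type(net)()
fresh.load_state_dict(net.state_dict())

x = torch.rand(3, 4)
with torch.no_grad():
    # Both copies reproduce the original model's output.
    print(torch.equal(net(x), clone(x)), torch.equal(net(x), fresh(x)))
    # The copies are independent: changing `net` does not change `clone`.
    net.fc.weight.add_(1.0)
    print(torch.equal(net.fc.weight, clone.fc.weight))
```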