【Question title】: PyTorch - RuntimeError: Error(s) in loading state_dict for VGG:
【Posted】: 2020-11-10 01:33:27
【Question】:

I trained a model with PyTorch and saved a state dict file. When I load the pretrained model with the code below, I get this error:

RuntimeError: Error(s) in loading state_dict for VGG:
    Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 
    Unexpected key(s) in state_dict: "state_dict", "optimizer_state_dict", "globalStep", "train_paths", "test_paths". 
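The two lists in this message are set differences between the parameter names the model expects and the top-level keys of the dictionary passed to load_state_dict. Below is a minimal pure-Python sketch of that comparison. diff_keys is a hypothetical helper, not a PyTorch function. Real PyTorch also compares tensor shapes and lists keys in the order it encounters them rather than sorted.

```python
def diff_keys(expected_keys, loaded_keys):
    """Hypothetical helper: return (missing, unexpected) like the RuntimeError.

    Mimics only the key-name check load_state_dict does with strict=True;
    shape checks are omitted and results are sorted for deterministic output.
    """
    expected = set(expected_keys)
    loaded = set(loaded_keys)
    missing = sorted(expected - loaded)      # the model needs these, the file lacks them
    unexpected = sorted(loaded - expected)   # the file has these, the model does not
    return missing, unexpected


# A few of the parameter names torchvision's vgg16 expects (taken from the error).
expected = ["features.0.weight", "features.0.bias", "classifier.6.weight"]
# Top-level keys of the checkpoint dict written by the training script.
loaded = ["state_dict", "optimizer_state_dict", "globalStep",
          "train_paths", "test_paths"]

missing, unexpected = diff_keys(expected, loaded)
print("Missing:", missing)
print("Unexpected:", unexpected)
```

Every real parameter shows up as missing and every checkpoint entry as unexpected, which is exactly the pattern in the error above.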

I am following the instructions on this page: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices

Many thanks

import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim

from convnet3 import Convnet
from VGG import VGG
from dataset2 import CellsDataset
from torchvision import models
from Conv import Conv2d

parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name',type=str,help='Name of experiment')
parser.add_argument('data_dir',type=str,help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay',type=float,default=0.0,help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr',type=float,default=0.0001,help='Learning rate')
args = parser.parse_args()

metadata = pd.read_csv(join(args.data_dir,'gt.csv'))
metadata.set_index('filename', inplace=True)

# create datasets:

dataset = CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
dataset = DataLoader(dataset,num_workers=4,pin_memory=True)
model_path = '/Users/nubstech/Documents/GitHub/CellCountingDirectCount/VGG_model_V1/checkpoints/checkpoint.pth'

class VGG(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG, self).__init__()
        vgg = models.vgg16(pretrained=pretrained)
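        # if pretrained:
        # The RuntimeError is raised here: checkpoint.pth holds a dict with five
        # entries, not the bare vgg16 state_dict that load_state_dict expects.
        vgg.load_state_dict(torch.load(model_path))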
        features = list(vgg.features.children())
        self.features4 = nn.Sequential(*features[0:23])


        self.de_pred = nn.Sequential(Conv2d(512, 128, 1, same_padding=True, NL='relu'),
                                     Conv2d(128, 1, 1, same_padding=True, NL='relu'))


    def forward(self, x):
        x = self.features4(x)       
        x = self.de_pred(x)

        return x

model=VGG()
#model.load_state_dict(torch.load(model_path),strict=False)
model.eval()        

#optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

for images, paths in tqdm(dataset):

    targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
    targets = targets.float()

    # code to print training data to a csv file
    #filename=CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
    output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
    preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
    print(preds)
    paths_test = np.array([paths])
    names_preds = np.hstack(paths)
    print(names_preds)                
    df=pd.DataFrame({'Image_Name':names_preds, 'Target':targets.detach(), 'Prediction':preds.detach()})
    print(df) 
    # save image name, targets, and predictions
    df.to_csv(r'model.csv', index=False, mode='a')

Code that saves the state dict:

torch.save({'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'globalStep': global_step,
            'train_paths': dataset_train.files,
            'test_paths': dataset_test.files}, checkpoint_path)
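With this save format, the weights live under the 'state_dict' entry of the loaded object, not at its top level. A minimal sketch of a loader that accepts either format, assuming the file was written by one of the two torch.save patterns in its docstring. extract_state_dict is a hypothetical helper, and the placeholder numbers and key names stand in for real tensors so the snippet runs without PyTorch.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="state_dict"):
    """Hypothetical helper: return the parameter dict from a torch.load() result.

    A file written with
        torch.save({'state_dict': model.state_dict(), ...}, path)
    keeps the weights under `key`; a file written with
        torch.save(model.state_dict(), path)
    already is the parameter dict, so it is returned unchanged.
    """
    if isinstance(checkpoint, Mapping) and key in checkpoint:
        return checkpoint[key]
    return checkpoint


# Placeholder values stand in for tensors (illustrative names only).
weights = {"features4.0.weight": 0.1, "features4.0.bias": 0.2}
checkpoint = {
    "state_dict": weights,
    "optimizer_state_dict": {},
    "globalStep": 0,          # placeholder step counter
    "train_paths": [],
    "test_paths": [],
}

print(sorted(extract_state_dict(checkpoint)))  # only the parameter names
print(extract_state_dict(weights) is weights)  # a bare state_dict passes through
```

In the question's script this would be used as model.load_state_dict(extract_state_dict(torch.load(model_path, map_location='cpu'))), called on an instance of the class that produced the checkpoint. If the file was written from the custom VGG module during training (its keys would then start with features4. and de_pred.), it has to be loaded into VGG() after construction rather than into torchvision's vgg16() inside __init__; otherwise the names still will not match.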

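【Discussion】:

  • Please post the code where you save and load the state dict. You could create a new (untrained) network, save it and reload it; if that works, the problem is narrowed down. The rest of the code is irrelevant to the problem. Please post minimal code that others can run and that reproduces the problem; see stackoverflow.com/help/minimal-reproducible-example
  • I have updated the code with the part that saves the state dict.
  • Thanks a lot. That is enough to spot the problem.
  • What the error means: you first saved your model and then changed your network structure. You need the same network structure as the one you saved.

Tags: machine-learning deep-learning pytorch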


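【Solution 1】:

The problem is that what is being saved is not what the loading code expects. The code tries to load only a state_dict, but the file contains much more than that: the state_dict is nested inside another dictionary together with additional information (the optimizer state, the global step and the train/test file lists). load_state_dict has no logic to look inside that outer dictionary, so it treats the five top-level keys as parameter names, reports them as unexpected, and reports every real parameter as missing.

This should work: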

import torch, torchvision.models
model = torchvision.models.vgg16()
path = 'test.pth'
torch.save(model.state_dict(), path) # nothing else here
model.load_state_dict(torch.load(path))
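Saving only the state_dict, as above, is the simplest fix. If you would rather keep the training script's checkpoint format (optimizer state, step counter, file lists), the alternative is to index into the dictionary when loading. A minimal sketch, assuming a small nn.Sequential stand-in for the VGG network and an in-memory io.BytesIO buffer in place of checkpoint_path. It also shows that passing the whole checkpoint with strict=False, as in the question's commented-out line, does not raise but loads no weights at all.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def make_model():
    # Stand-in for the real VGG-based network; any nn.Module behaves the same way.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


model = make_model()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Save a checkpoint with the same layout as the training script.
buffer = io.BytesIO()  # stands in for checkpoint_path
torch.save({"state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "globalStep": 100}, buffer)
buffer.seek(0)

# Load: rebuild the same architecture, then pick the entries out of the dict.
checkpoint = torch.load(buffer, map_location="cpu")
restored = make_model()
restored_opt = optim.Adam(restored.parameters(), lr=1e-4)
restored.load_state_dict(checkpoint["state_dict"])
restored_opt.load_state_dict(checkpoint["optimizer_state_dict"])
global_step = checkpoint["globalStep"]
restored.eval()

# Passing the whole checkpoint with strict=False raises nothing, but no
# parameter name matches, so this model keeps its random initial weights.
result = make_model().load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Rebuilding the model from the same class before calling load_state_dict is what makes the parameter names line up; strict=False only hides a mismatch, it does not fix it.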

【Discussion】:
