【问题标题】:How to load and use a pretained PyTorch InceptionV3 model to classify an image如何加载和使用预训练的 PyTorch Inception V3 模型对图像进行分类
【发布时间】:2019-05-19 13:37:00
【问题描述】:

我和How can I load and use a PyTorch (.pth.tar) model 有同样的问题,它没有一个可以接受的答案,或者我可以弄清楚如何遵循给出的建议。

我是 PyTorch 的新手。我正在尝试加载此处引用的预训练 PyTorch 模型:https://github.com/macaodha/inat_comp_2018

我很确定我缺少一些胶水。

# load the model
import torch
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')

# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns cuda tensor"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0)  
    return image.cpu()  #assumes that you're using CPU

image = image_loader("test-image.jpg")

产生错误:

在 () ----> 1 个模型.预测(图像)

AttributeError: 'dict' 对象没有属性 'predict

【问题讨论】:

    标签: python pytorch torch


    【解决方案1】:

    问题

    您的model 实际上不是模特。保存时不仅包含参数,还包含模型的其他信息,形式有点类似于dict。

    因此,torch.load("iNat_2018_InceptionV3.pth.tar") 只是简单地返回dict,它当然没有名为predict 的属性。

    model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
    type(model)
    # dict
    

    解决方案

    在这种情况下,通常情况下,您首先需要做的是实例化您想要的模型类,按照官方指南"Load models"

    # First try
    from torchvision.models import Inception3
    v3 = Inception3()
    v3.load_state_dict(model['state_dict']) # model that was imported in your code.
    

    但是,直接输入model['state_dict'] 会引发一些关于Inception3 参数形状不匹配的错误。

    了解Inception3 在实例化后发生了什么变化很重要。幸运的是,你可以在原作者的train_inat.py找到。

    # What the author has done
    model = inception_v3(pretrained=True)
    model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
    model.aux_logits = False
    

    现在我们知道要更改什么,让我们对我们的第一次尝试进行一些修改。

    # Second try
    from torchvision.models import Inception3
    v3 = Inception3()
    v3.fc = nn.Linear(2048, 8142)
    v3.aux_logits = False
    v3.load_state_dict(model['state_dict']) # model that was imported in your code.
    

    然后你就可以成功加载模型了!

    【讨论】:

      猜你喜欢
      • 2018-03-10
      • 2017-08-17
      • 1970-01-01
      • 2017-03-28
      • 2016-12-18
      • 2020-07-27
      • 2021-07-16
      • 2021-04-05
      • 2018-09-25
      相关资源
      最近更新 更多