【问题标题】:Loading trained model to make prediction of single image加载经过训练的模型以预测单个图像
【发布时间】:2021-09-25 22:04:24
【问题描述】:

我已经在英特尔图像多类分类任务上训练了一个 ResNet50 模型。任务是尝试预测图像,无论是街道还是冰川等。该模型已成功训练并能够进行预测。我已保存模型并尝试在新图像上使用保存的模型。 这是训练的代码

import os
import torch
import tarfile
import torchvision
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import random_split
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_url
import PIL
import PIL.Image
import numpy as np
transform_train=transforms.Compose([
                                    transforms.Resize((150,150)),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.RandomVerticalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
transform_test=transforms.Compose([
                                   transforms.Resize((150,150)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
...
torch.save(model2.state_dict(),'/content/drive/MyDrive/saved_model/model_resnet.pth')

当我在其他文件中调用模型时,我使用了类似的图像转换,但是它给了我一个错误,这是代码和错误

model = torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth')
image=Image.open(Path('/content/drive/MyDrive/images/seg_pred/seg_pred/10004.jpg'))
transform_train=transforms.Compose([
                                    transforms.Resize((150,150)),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.RandomVerticalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
input = transform_train(image)

#input = input.view(1, 3, 150,150)

output = model(input)

prediction = int(torch.max(output.data, 1)[1].numpy())
print(prediction)

给我的错误是

TypeError: 'collections.OrderedDict' object is not callable

我的 pytorch 版本是 1.9.0+cu102

【问题讨论】:

    标签: python-3.x machine-learning pytorch multiclass-classification


    【解决方案1】:

    你需要先创建模型的结构,类似于在你的训练代码上创建model2,可以是这样的:

    model = resnet()
    

    然后加载保存的状态字典:

    model.load_state_dict(torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth'))
    model.eval()
    

    参考:

    https://pytorch.org/tutorials/beginner/saving_loading_models.html

    【讨论】:

    • 谢谢,我做了 model=models.resnet50(), model.load_state_dict(torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth'))。它给了我另一个错误 RuntimeError: Error(s) in loading state_dict for ResNet:
    • 我正在尝试这个教程,kaggle.com/awadhi123/… 并保存模型,我想预测一个新的图像数据。感谢您的帮助
    • @rickyfajrin93 你已经修改了 resnet 模型,首先加载模型例如mymodel = IntelCnnModelresnet(),然后加载 state_dict:mymodel.load_state_dict(...)
    【解决方案2】:

    根据您的问题,您显然希望在新图像上进行预测。但是您正在尝试使用transform 来增强和转换图像,这不是获得预测的正确方法。

    因此,由于您提供的代码链接包含大量代码,您可以像在代码中一样使用它们。

    我正在分享 fast.ai 和简单的 `TensorFlow 代码,您可以通过它们预测新图像,然后能够看到结果。

    img = open_image('any_image.jpg')
    print(learn.predict(img)[0])
    

    或者你可以试试这个功能:

    import matplotlib.pyplot as plt # visualization
    import matplotlib.image as mpimg
    import tensorflow as tf # Deep Learning Framework
    import pathlib
    def pred_plot(file, model, class_names=class_names, image_size=(150, 150)):
        img = tf.io.read_file(file)
        img = tf.io.decode_image(img, channels=3)
        img = tf.image.resize(img, size=image_size)
        
        pred_probs = model.predict(tf.expand_dims(img, axis=0))
        pred_class = class_names[pred_probs.argmax()]
        
        plt.imshow(img/225.)
        plt.title(f'Pred: {pred_class}')
        plt.axis(False);
    

    传递任何图像,您将获得可视化的预测。

    url ='dummy.jpg'
    pred_plot(url, model=model_2, class_names=class_names)
    

    【讨论】:

      猜你喜欢
      • 2018-10-12
      • 1970-01-01
      • 2021-09-12
      • 1970-01-01
      • 2017-07-28
      • 2020-07-04
      • 1970-01-01
      • 1970-01-01
      • 2018-04-21
      相关资源
      最近更新 更多