【发布时间】:2021-09-25 22:04:24
【问题描述】:
我已经在英特尔图像多类分类任务上训练了一个 ResNet50 模型。任务是尝试预测图像,无论是街道还是冰川等。该模型已成功训练并能够进行预测。我已保存模型并尝试在新图像上使用保存的模型。 这是训练的代码
import os
import torch
import tarfile
import torchvision
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import random_split
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_url
import PIL
import PIL.Image
import numpy as np
transform_train=transforms.Compose([
transforms.Resize((150,150)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
transform_test=transforms.Compose([
transforms.Resize((150,150)),
transforms.ToTensor(),
transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
...
torch.save(model2.state_dict(),'/content/drive/MyDrive/saved_model/model_resnet.pth')
当我在其他文件中调用模型时,我使用了类似的图像转换,但是它给了我一个错误,这是代码和错误
model = torch.load('/content/drive/MyDrive/saved_model/model_resnet.pth')
image=Image.open(Path('/content/drive/MyDrive/images/seg_pred/seg_pred/10004.jpg'))
transform_train=transforms.Compose([
transforms.Resize((150,150)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize((.5,.5,.5),(.5,.5,.5))
])
input = transform_train(image)
#input = input.view(1, 3, 150,150)
output = model(input)
prediction = int(torch.max(output.data, 1)[1].numpy())
print(prediction)
给我的错误是
TypeError: 'collections.OrderedDict' object is not callable
我的 pytorch 版本是 1.9.0+cu102
【问题讨论】:
标签: python-3.x machine-learning pytorch multiclass-classification