【Question Title】:Unable to load model weights while predicting (using pytorch)
【Posted】:2020-03-22 02:30:01
【Question】:

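I have trained a Mask RCNN network using PyTorch and am trying to use the resulting weights to predict the location of apples in an image.

I am using the dataset from this paper, and here is the github link to the code being used.

I am simply following the instructions given in the README file.

This is the command I ran at the prompt (personal information removed):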

python predict_rcnn.py --data_path "my_directory\datasets\apples-minneapple\detection" --output_file "my_directory\samples\apples\predictions" --weight_file "my_directory\samples\apples\weights\model_19.pth" --mrcnn

model_19.pth is the file containing all the weights generated after the 19th epoch.

The error is as follows:

Loading model
Traceback (most recent call last):
  File "predict_rcnn.py", line 122, in <module>
    main(args)
  File "predict_rcnn.py", line 77, in main
    model.load_state_dict(checkpoint['model'], strict=False)
KeyError: 'model'

For convenience, here is predict_rcnn.py:

import os
import torch
import torch.utils.data
import torchvision
import numpy as np

from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import utility.utils as utils
import utility.transforms as T


######################################################
# Predict with either a Faster-RCNN or Mask-RCNN predictor
# using the MinneApple dataset
######################################################
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)


def get_maskrcnn_model_instance(num_classes):
    # build a Mask R-CNN model; pretrained=False skips the COCO-trained detection weights
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model


def get_frcnn_model_instance(num_classes):
    # build a Faster R-CNN model; pretrained=False skips the COCO-trained detection weights
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def main(args):
    num_classes = 2
    device = args.device

    # Load the model from
    print("Loading model")
    # Create the correct model type
    if args.mrcnn:
        model = get_maskrcnn_model_instance(num_classes)
    else:
        model = get_frcnn_model_instance(num_classes)

    # Load model parameters and keep on CPU

    checkpoint = torch.load(args.weight_file, map_location=device)
    #checkpoint = torch.load(args.weight_file, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()

    print("Creating data loaders")
    dataset_test = AppleDataset(args.data_path, get_transform(train=False))
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                                   shuffle=False, num_workers=1,
                                                   collate_fn=utils.collate_fn)

    # Create output directory
    base_path = os.path.dirname(args.output_file)
    if not os.path.exists(base_path):
        os.makedirs(base_path)

    # Predict on bboxes on each image
    f = open(args.output_file, 'a')
    for image, targets in data_loader_test:
        image = list(img.to(device) for img in image)
        outputs = model(image)
        for ii, output in enumerate(outputs):
            img_id = targets[ii]['image_id']
            img_name = data_loader_test.dataset.get_img_name(img_id)
            print("Predicting on image: {}".format(img_name))
            boxes = output['boxes'].detach().numpy()
            scores = output['scores'].detach().numpy()

            im_names = np.repeat(img_name, len(boxes), axis=0)
            stacked = np.hstack((im_names.reshape(len(scores), 1), boxes.astype(int), scores.reshape(len(scores), 1)))

            # File to write predictions to
            np.savetxt(f, stacked, fmt='%s', delimiter=',', newline='\n')


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Detection')
    parser.add_argument('--data_path', required=True, help='path to the data to predict on')
    parser.add_argument('--output_file', required=True, help='path where to write the prediction outputs')
    parser.add_argument('--weight_file', required=True, help='path to the weight file')
    parser.add_argument('--device', default='cuda', help='device to use. Either cpu or cuda')
    model = parser.add_mutually_exclusive_group(required=True)
    model.add_argument('--frcnn', action='store_true', help='use a Faster-RCNN model')
    model.add_argument('--mrcnn', action='store_true', help='use a Mask-RCNN model')

    args = parser.parse_args()
    main(args)

【Comments】:

    Tags: python tensorflow machine-learning pytorch image-segmentation


    【Solution 1】:

    The saved checkpoint has no 'model' key. If you look at train_rcnn.py:106

    torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
    

    you will see that it saves only the model parameters (the state_dict itself). It should be something like this:

    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict()
    }, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
    

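    That way, after loading you get a dictionary with a 'model' entry, along with the other state they apparently intended to keep.

    This looks like a bug in their code.

    【Discussion】: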
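If retraining is not an option, the existing model_19.pth can still be used, because it already is the state_dict that predict_rcnn.py is looking for under 'model'. The following is a minimal sketch, not part of the repository: `extract_state_dict` is a hypothetical helper that accepts either checkpoint layout. The demo values are plain lists standing in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="model"):
    """Return the model state_dict from either checkpoint layout.

    Layout A (what train_rcnn.py writes): the state_dict itself,
    a mapping from parameter names to tensors.
    Layout B (what predict_rcnn.py expects): a dict that wraps the
    state_dict under `key`, next to optimizer / scheduler state.
    """
    if not isinstance(checkpoint, Mapping):
        raise TypeError(f"expected a mapping, got {type(checkpoint).__name__}")
    wrapped = checkpoint.get(key)
    # Only treat `key` as a wrapper if it holds a nested mapping.
    if isinstance(wrapped, Mapping):
        return wrapped
    return checkpoint


# Stand-in values; a real checkpoint holds tensors returned by torch.load.
raw = OrderedDict([("roi_heads.box_predictor.cls_score.weight", [0.0])])
wrapped = {"model": raw, "optimizer": {}, "lr_scheduler": {}}
assert extract_state_dict(raw) is extract_state_dict(wrapped)
```

In predict_rcnn.py this would replace the direct `checkpoint['model']` lookup, for example `model.load_state_dict(extract_state_dict(checkpoint), strict=False)`. Because `strict=False` tolerates mismatched keys, it may be worth inspecting the `missing_keys` and `unexpected_keys` fields of the value that `load_state_dict` returns.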

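    • Hi! I trained the network with the change you suggested, but ended up with this error: "Loading model Creating data loaders Traceback (most recent call last): File "predict_rcnn.py", line 122, in <module> main(args) File "predict_rcnn.py", line 92, in main f = open(args.output_file, 'a') PermissionError: [Errno 13] Permission denied: 'D:\\My_folder\\samples\\apples\\predictions'"
    • Ignore the above. I fixed it by creating a txt file.
    • I was able to train and predict!! Thank you!!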
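The PermissionError in the comments is most likely caused by --output_file pointing at a directory rather than a file: on Windows, calling open() on a directory raises [Errno 13]. Below is a small sketch of a guard the script could apply before opening the file. `resolve_output_file` and the default name `predictions.txt` are hypothetical and do not exist in the repository.

```python
import os
import tempfile


def resolve_output_file(path, default_name="predictions.txt"):
    """Return a path that open(..., 'a') can write to.

    If `path` is an existing directory, place `default_name` inside it;
    otherwise treat `path` as the file itself. The parent directory is
    created when it is missing.
    """
    if os.path.isdir(path):
        path = os.path.join(path, default_name)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return path


# Demo in a throwaway directory standing in for the predictions folder.
with tempfile.TemporaryDirectory() as out_dir:
    target = resolve_output_file(out_dir)
    with open(target, "a") as f:
        f.write("example_row\n")  # made-up placeholder, not real model output
    assert os.path.basename(target) == "predictions.txt"
```

Passing a full file path such as `...\predictions\results.txt` on the command line achieves the same thing without changing the script, which matches the fix the commenter reports.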