【问题标题】:How to load a torchvision model from disk?如何从磁盘加载torchvision 模型?
【发布时间】:2021-11-30 01:45:06
【问题描述】:

我正在尝试按照top answer here 的解决方案从 .pth 文件加载对象检测模型。

os.environ['TORCH_HOME'] = '../input/torchvision-fasterrcnn-resnet-50/' #setting the environment variable
model = detection.fasterrcnn_resnet50_fpn(pretrained=False).to(DEVICE)

我收到以下错误

NotADirectoryError: [Errno 20] Not a directory: '../input/torchvision-fasterrcnn-resnet-50/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth/hub'

谷歌没有透露该错误的答案,我不知道它的含义,除了显而易见的(那个文件夹“hub”不见了)。

我必须解压或创建文件夹吗? 我已尝试加载权重,但收到相同的错误消息。

这就是我加载模型的方式

model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
checkpoint = torch.load('../input/torchvision-fasterrcnn-resnet-50/model.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

感谢您的帮助!

完整的错误跟踪:

gaierror: [Errno -3] Temporary failure in name resolution

During handling of the above exception, another exception occurred:

URLError                                  Traceback (most recent call last)
/tmp/ipykernel_42/1218627017.py in <module>
      1 # to load
----> 2 model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
      3 checkpoint = torch.load('../input/torchvision-fasterrcnn-resnet-50/model.pth.tar')
      4 model.load_state_dict(checkpoint['state_dict'])

/opt/conda/lib/python3.7/site-packages/torchvision/models/detection/faster_rcnn.py in fasterrcnn_resnet50_fpn(pretrained, progress, num_classes, pretrained_backbone, trainable_backbone_layers, **kwargs)
    360     if pretrained:
    361         state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
--> 362                                               progress=progress)
    363         model.load_state_dict(state_dict)
    364     return model

/opt/conda/lib/python3.7/site-packages/torch/hub.py in load_state_dict_from_url(url, model_dir, map_location, progress, check_hash, file_name)
    553             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
    554             hash_prefix = r.group(1) if r else None
--> 555         download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    556 
    557     if _is_legacy_zip_format(cached_file):

/opt/conda/lib/python3.7/site-packages/torch/hub.py in download_url_to_file(url, dst, hash_prefix, progress)
    423     # certificates in older Python
    424     req = Request(url, headers={"User-Agent": "torch.hub"})
--> 425     u = urlopen(req)
    426     meta = u.info()
    427     if hasattr(meta, 'getheaders'):

/opt/conda/lib/python3.7/urllib/request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
    220     else:
    221         opener = _opener
--> 222     return opener.open(url, data, timeout)
    223 
    224 def install_opener(opener):

/opt/conda/lib/python3.7/urllib/request.py in open(self, fullurl, data, timeout)
    523             req = meth(req)
    524 
--> 525         response = self._open(req, data)
    526 
    527         # post-process response

/opt/conda/lib/python3.7/urllib/request.py in _open(self, req, data)
    541         protocol = req.type
    542         result = self._call_chain(self.handle_open, protocol, protocol +
--> 543                                   '_open', req)
    544         if result:
    545             return result

/opt/conda/lib/python3.7/urllib/request.py in _call_chain(self, chain, kind, meth_name, *args)
    501         for handler in handlers:
    502             func = getattr(handler, meth_name)
--> 503             result = func(*args)
    504             if result is not None:
    505                 return result

/opt/conda/lib/python3.7/urllib/request.py in https_open(self, req)
   1391         def https_open(self, req):
   1392             return self.do_open(http.client.HTTPSConnection, req,
-> 1393                 context=self._context, check_hostname=self._check_hostname)
   1394 
   1395         https_request = AbstractHTTPHandler.do_request_

/opt/conda/lib/python3.7/urllib/request.py in do_open(self, http_class, req, **http_conn_args)
   1350                           encode_chunked=req.has_header('Transfer-encoding'))
   1351             except OSError as err: # timeout error
-> 1352                 raise URLError(err)
   1353             r = h.getresponse()
   1354         except:

URLError: <urlopen error [Errno -3] Temporary failure in name resolution>

【问题讨论】:

  • 如果您没有加载任何预训练模型,为什么还要更新 TORCH_HOME
  • @Ivan 我正在加载一个预训练模型,我将 .path 文件保存到了 torch_home 中定义的目录
  • 你能显示你正在加载模型的行吗?
  • @Ivan 嗨,我添加了一行,根据我认为只需要的评论,但我在根据加载教程使用 torch.save 保存后也尝试了上述(现已编辑)和储蓄。但是我已经在第一行得到了同样的错误
  • 好的,能否提供完整的错误回溯? *“但是我已经在第一行得到了同样的错误”*那是因为你设置了pretrained=True,在TORCH_HOME下找不到/hub子目录。

标签: pytorch torch torchvision


【解决方案1】:

如果您正在加载预训练的网络,则不需要从 torchvision pretrained 加载模型(如 pretrained by torchvision 在 ImageNet 上使用 pretrained=True) .你有两个选择:

  1. 设置pretrained=False 并使用以下方法加载您的权重:

    checkpoint = torch.load('../input/torchvision-fasterrcnn-resnet-50/model.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])
    
  2. 或者,如果您决定更改 TORCH_HOME(这并不理想),您需要保持 Torchvision 的相同目录结构:

    inputs/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth 
    

    实际上,您不会为了加载一个模型而更改 TORCH_HOME

【讨论】:

  • 当我由于 pretrained = False 我仍然收到相同的错误:我在更新的跟踪中发布了该错误。感谢您的耐心等待
  • 您在没有互联网的情况下尝试过您的解决方案吗?
  • 如果您已经在本地保存了体重,则无需访问互联网。
  • 谢谢 ivan,但也许我的速度很慢,我见过类似的其他问题。您需要定义模型类和结构吗?或者你将如何在没有互联网的情况下实例化模型?
  • 我无法在没有互联网的情况下运行它:model = detection.fasterrcnn_resnet50_fpn(pretrained=False) checkpoint = torch.load('../input/torchvision-fasterrcnn-resnet-50/model.pth.tar') model.load_state_dict(checkpoint['state_dict'])
【解决方案2】:

我找到了深入github的解决方案,这个问题有点隐藏。

检测.() 除了 pretrained 之外还有一个默认参数,称为 pretrained_backbone,默认设置为 true,如果为 True,则将模型设置为从 url 的字典路径下载。

这将起作用:

detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone = False, num_classes = 91).

然后像往常一样加载模型。 num_classes 是预期的,在文档中它是默认值 = 91,但在 github 中我将其视为无,这就是为什么我在此处添加它以确保安全。

【讨论】:

    猜你喜欢
    • 2019-02-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-01-08
    • 2020-10-15
    • 2017-07-28
    • 2010-11-02
    • 2018-10-23
    相关资源
    最近更新 更多