【问题标题】:How to load a Pytorch model when the parameters are saved as numpy arrays?参数保存为numpy数组时如何加载Pytorch模型?
【发布时间】:2021-03-19 00:54:05
【问题描述】:

this GitHub repo,我已经下载了预训练模型senet50_ft

我是这样加载的:

import pickle
f = open('pretrained_models/senet50_ft_weight.pkl', 'rb')
state_dict = pickle.load(f, encoding='latin1')
f.close()

状态加载完毕,Github repos也提供了SENet模型类here

所以我设法实例化了该模型:

model = senet.senet50()

然后我尝试加载状态,但出现错误:

model.load_state_dict(state_dict)

Traceback (most recent call last):
  File "...\module.py", line 982, in _load_from_state_dict
    param.copy_(input_param)
TypeError: copy_(): argument 'other' (position 1) must be Tensor, not numpy.ndarray

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...\module.py", line 1037, in load_state_dict
    load(self)
  File "...\module.py", line 1035, in load
    load(child, prefix + name + '.')
  File "...\module.py", line 1032, in load
    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  File "...\module.py", line 988, in _load_from_state_dict
    .format(key, param.size(), input_param.size(), ex.args))
TypeError: 'int' object is not callable

我尝试通过执行以下操作将ndarray 转换为Tensor

for key in state_dict.keys():
    state_dict[key] = torch.from_numpy(state_dict[key])

但我又遇到了一个错误,我想我哪儿也不去。

我是 PyTorch 的新手,但我怀疑这个模型是用旧版本的 PyTorch 序列化的。你知道是否有解决方案吗?

【问题讨论】:

    标签: python pytorch torch


    【解决方案1】:

    他们有一个 load_state_dict 函数,可以满足您的需求。

    【讨论】:

    • 谢谢,它有效,我还必须用正确数量的类实例化模型,例如 demo.py 中的 8631
    猜你喜欢
    • 2021-12-15
    • 2019-09-26
    • 1970-01-01
    • 1970-01-01
    • 2017-12-25
    • 2012-09-24
    • 2021-06-23
    • 2017-11-09
    相关资源
    最近更新 更多