【发布时间】:2021-05-24 00:37:21
【问题描述】:
我从 [this repository][1] 获得了一个 pytorch 模型,我必须将其转换为 tflite。 代码如下:
def get_torch_model(model_path):
"""
Loads state-dict into model and creates an instance.
"""
model= torch.load(model_path)
return model
# Conversion
import torch
from torchvision import transforms
import onnx
import cv2
import numpy as np
import onnx
import tensorflow as tf
import torch
from PIL import Image
import torch.onnx
image, tf_lite_image, sample_input = get_sample_input("crop.jpg")
torch_model = get_torch_model("pose_resnet_152_256x256.pth")
ONNX_FILE = "./m_model.onnx"
到这里为止,一切都很顺利。但是当我运行下面的单元格时:
torch.onnx.export(
model=torch_model,
args=sample_input,
f=ONNX_FILE,
verbose=False,
export_params=True,
do_constant_folding=False, # fold constant values for optimization
input_names=['input'],
opset_version=10,
output_names=['output']
)
onnx_model = onnx.load(ONNX_FILE)
onnx.checker.check_model(onnx_model)
完整的错误日志:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
8 input_names=['input'],
9 opset_version=10,
---> 10 output_names=['output']
11 )
12
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
274 do_constant_folding, example_outputs,
275 strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276 custom_opsets, enable_onnx_checker, use_external_data_format)
277
278
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
92 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
93 custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94 use_external_data_format=use_external_data_format)
95
96
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
677 _set_opset_version(opset_version)
678 _set_operator_export_type(operator_export_type)
--> 679 with select_model_mode_for_export(model, training):
680 val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
681 operator_export_type,
~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
79 def __enter__(self):
80 try:
---> 81 return next(self.gen)
82 except StopIteration:
83 raise RuntimeError("generator didn't yield") from None
~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
36 def select_model_mode_for_export(model, mode):
37 if not isinstance(model, torch.jit.ScriptFunction):
---> 38 is_originally_training = model.training
39
40 if mode is None:
AttributeError: 'collections.OrderedDict' object has no attribute 'training'
当我使用 torch.onnx.export() 时出现此错误。
请让我知道这里出了什么问题。 我没有正确加载重量吗?如果没有,那么我如何加载模型?我不知道类或架构细节,那么我该如何使用 model.load_state_dict() ??
[1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
【问题讨论】:
-
我猜你加载了错误的对象,这说明你有一个属性错误,这通常发生在你试图访问一个对象没有的属性时。因此,请检查您是否加载了模型或其他对象。还请说出导致此错误的第二个单元格的确切说明。
-
我不确定,但我认为我没有加载任何错误的对象。错误指向 torch.onnx.export() 的
output_names=['output']行,我尝试将其删除,因此它只指向之前的行。 -
打印模型对象并检查它是否打印出类似包含torch.nn模块的元组。它肯定会打印一个 OrderedDict 对象。实际模型应该在 OrderedDict 中。使用索引来访问模型。不要传递 OrderedDict 对象。
-
@Shai 回答正确,需要导入模型的类来加载训练好的权重。