【问题标题】:Error when converting PyTorch model to TorchScript将 PyTorch 模型转换为 TorchScript 时出错
【发布时间】:2019-05-18 02:41:40
【问题描述】:

我正在尝试关注PyTorch guide to load models in C++

以下示例代码有效:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

但是,当尝试其他网络时,例如squeezenet(或alexnet),我的代码失败了:

sq = torchvision.models.squeezenet1_0(pretrained=True)
traced_script_module = torch.jit.trace(sq, example) 

>> traced_script_module = torch.jit.trace(sq, example)                                      
/home/fabio/.local/lib/python3.6/site-packages/torch/jit/__init__.py:642: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function.
 Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 785] (3.1476082801818848 vs. 3.945478677749634) and 999 other locations (100.00%)
  _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)

【问题讨论】:

    标签: pytorch torchscript


    【解决方案1】:

    我刚刚发现从torchvision.models 加载的模型默认处于训练模式。 AlexNet 和 SqueezeNet 都有 Dropout 层,如果在训练模式下,推理是不确定的。只需更改为评估模式即可解决问题:

    sq = torchvision.models.squeezenet1_0(pretrained=True)
    sq.eval()
    traced_script_module = torch.jit.trace(sq, example) 
    

    【讨论】:

      猜你喜欢
      • 2022-08-05
      • 2022-09-26
      • 1970-01-01
      • 1970-01-01
      • 2022-10-13
      • 2023-01-31
      • 2018-10-05
      • 1970-01-01
      • 2019-02-01
      相关资源
      最近更新 更多