是的,您可以在 TorchScript 模型中定义多个入口点,方法是使用 @torch.jit.export 装饰器指定应将哪些方法导出为入口点。
例如,给定一个 PyTorch 模型定义如下:
class MyModel(nn.Module):
def update(self):
# Update some params.
def predict(self, X):
# Predict with some input tensor.
您可以使用 @torch.jit.export 装饰器指定更新和预测方法应作为结果 TorchScript 模块中的入口点导出,如下所示:
class MyModel(nn.Module):
@torch.jit.export
def update(self):
# Update some params.
@torch.jit.export
def predict(self, X):
# Predict with some input tensor.
然后,您可以使用以下代码将 MyModel 类导出到 TorchScript:
model = MyModel()
traced_model = torch.jit.script(model)
生成的 TorchScript 模块将有两个入口点,更新和预测,您可以使用它们来调用模型的相应方法。
traced_model.update()
traced_model.predict(X)
或者,您也可以在类级别使用 torch.jit.export 装饰器来指定类中的所有方法都应导出为生成的 TorchScript 模块中的入口点。例如:
@torch.jit.export
class MyModel(nn.Module):
def update(self):
# Update some params.
def predict(self, X):
# Predict with some input tensor.
在此代码中,@torch.jit.export 装饰器应用于 MyModel 类本身,它告诉 torch.jit.script 函数将 MyModel 类中的所有方法导出为生成的 TorchScript 模块中的入口点。
然后,您可以使用以下代码将 MyModel 类导出到 TorchScript:
model = MyModel()
traced_model = torch.jit.script(model)
生成的 TorchScript 模块将有两个入口点,更新和预测,您可以使用它们来调用模型的相应方法。
traced_model.update()
traced_model.predict(X)