【问题标题】:Error Making prediction with python onnxruntime使用 python onnxruntime 进行预测时出错
【发布时间】:2020-03-21 15:46:17
【问题描述】:

我使用sklearn 库创建了一个非常基本的决策树。这棵树基于 4 个特征进行训练:

feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT

标签/目标特征是一个布尔值(0 或 1)。

我将树转换为ONNX 格式,现在我想使用onnxruntime python 库进行预测。我在互联网上找到了执行此操作的示例代码。问题是我不完全理解这段代码、函数和参数的所有部分到底发生了什么。这导致我得到一个错误。我确实搜索了一些文档,但我找不到这个。

在下面的代码中,我将树模型转换为 ONNX 格式。这是成功的,但部分代码我不明白。在initial_type 变量中,根据我之前提到的 4 个特征列和标签/目标特征,我必须在此处输入什么?现在我输入了FloatTensorType([None, 4],因为我有4 个特征列,而None 是什么我不知道。

##Convert to ONNX format

initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
    f.write(onx.SerializeToString())

在下面的代码中,我想使用 onnxruntime 库进行预测,但出现此错误:

RuntimeError: Either type_proto was null or it was not of sequence type

这是因为我看不懂下面的最后一行代码。我输入了这个{input_name: [4, 8, 77.8, 143.45],因为这是特征列的四个值。我在这里做错了什么?

sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]

【问题讨论】:

    标签: python scikit-learn sklearn-pandas onnx onnxruntime


    【解决方案1】:

    你试过{input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}吗? onnxruntime 需要 numpy 数组作为输入。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-07-24
      • 2014-01-03
      • 2019-10-29
      • 1970-01-01
      • 1970-01-01
      • 2013-05-31
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多