【发布时间】:2018-08-05 23:25:16
【问题描述】:
我正在使用 scikit-learn 构建文档分类系统,它运行良好。我正在将模型转换为 Core ML 模型格式。但是模型格式除了输入参数为multiArrayType。我想让它排除字符串或字符串数组,以便我可以轻松地从 IOS 应用程序中预测。我尝试了以下方法:
from sklearn.linear_model import LogisticRegression
logreg = LogisticRegression()
logreg.fit(X_train_dtm, y_train)
#testing a value
docs_new = ['get exclusive prize offer']
docs_pred_class = nb.predict(count_vect.transform(docs_new))
#Exporting to coremodel
import coremltools
coreml_model = coremltools.converters.sklearn.convert(logreg)
#print model
coreml_model
打印 coreml 模型会得到以下输出:
input {
name: "input"
type {
multiArrayType {
shape: 7505
dataType: DOUBLE
}
}
}
output {
name: "classLabel"
type {
int64Type {
}
}
}
output {
name: "classProbability"
type {
dictionaryType {
int64KeyType {
}
}
}
}
predictedFeatureName: "classLabel"
predictedProbabilitiesName: "classProbability"
我检查了 GitHub 库中的Core ML model,我可以看到有不同的输入和输出。
我怎样才能做到这一点,以便我可以从 IOS 应用程序传递一个简单的参数来进行预测。
【问题讨论】:
标签: python scikit-learn text-classification coreml coremltools