【问题标题】:How to get Top N predictions using sklearn's SGDClassifier如何使用 sklearn 的 SGDClassifier 获得 Top N 预测
【发布时间】:2019-03-12 22:06:04
【问题描述】:

我尝试使用 scikit 的 SGDClassifier 设置一个简单的文本分类任务,并尝试获取前 N 个预测,包括它们的概率。作为样本训练数据,我有三个类

  • 苹果
  • 柠檬
  • 橙子

每个类一个文档:

  • 在苹果中:“苹果和柠檬”
  • 在柠檬中:“柠檬和橙子”
  • 在橙子中:“橙子和苹果”

我现在想预测三个测试文档“apple”、“lemon”和“orange”,并希望获得每个文档的 Top-2-Predictions,包括它们的概率。到目前为止,我的代码如下所示:

from sklearn.linear_model import SGDClassifier
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
import numpy as np

train = load_files('data/test/')

text_clf_svm = Pipeline([('vect', CountVectorizer()), ('tfidf', TfidfTransformer()),
                     ('clf-svm', SGDClassifier(loss='modified_huber', penalty='l2',alpha=1e-3, n_iter=5, random_state=42))])
text_clf_svm = text_clf_svm.fit(train.data, train.target)

docs=['apple', 'orange', 'lemon']
predicted = text_clf_svm.predict(docs)
#Perform a Top 1 prediction
for doc, category in zip(docs, predicted):
    print('%r => %s' % (doc, train.target_names[category]))

# Perform a top 2 prediction
print(np.argsort(text_clf_svm.predict_proba(docs), axis=1)[-2:])

我的输出如下:

'apple' => apples
'orange' => lemons
'lemon' => lemons
[[1 2 0]
[0 1 2]]

我现在难以解释数据。我真正想摆脱的是:

'apple' => apples (0.54...), lemons (0.43...)
'orange' => apples (0.48...), oranges (0.43...)
'lemon' => lemons (0.48...), oranges (0.43...)

有人能告诉我怎么做吗?提前感谢您的帮助!

【问题讨论】:

    标签: python scikit-learn multilabel-classification


    【解决方案1】:

    您正在使用argsort,argsort所做的就是它为您提供了排序阵列的索引,所以您应该做的是如下:

    preds = text_clf_svm.predict_proba(docs)
    preds_idx = np.argsort(preds, axis=1)[-2:]
    
    for i,d in enumerate(docs):
        print d,"=>"
        for p in preds_idx[i]:
            print(text_clf_svm.classes_[p],"(",preds[i][p],")")
    

    只是重新格式化打印到您的风格,您将拥有所需的内容:)

    【讨论】:

    • 亲爱的Imtinan,很酷,对我来说完全努力。非常感谢!
    • 没有问题,乐于帮助:)只接受答案所以将来任何人都可以参考这个 span>
    【解决方案2】:

    @Imtinan 答案的快速补充,因为该答案将您的标签排序为第二高,然后是第一高概率(升序)。如果您希望它按降序排列,只需修改:

    preds_idx = np.argsort(-preds, axis = 1)[ :2]

    【讨论】:

      猜你喜欢
      • 2015-12-04
      • 2019-07-18
      • 2019-03-08
      • 2020-05-18
      • 1970-01-01
      • 1970-01-01
      • 2020-06-01
      • 2013-06-08
      • 1970-01-01
      相关资源
      最近更新 更多