【问题标题】:How to plot precision and recall of multiclass classifier?如何绘制多类分类器的精度和召回率?
【发布时间】:2019-09-29 03:02:47
【问题描述】:

我正在使用 scikit learn,我想绘制准确率和召回率曲线。我使用的分类器是RandomForestClassifier。 scikit learn 文档中的所有资源都使用二进制分类。另外,我可以为多类绘制 ROC 曲线吗?

另外,我只找到了用于多标签的 SVM,它有一个 decision_function,而 RandomForest 没有

【问题讨论】:

标签: python matplotlib scikit-learn roc precision-recall


【解决方案1】:

来自 scikit-learn 文档:

Precision-recall 曲线通常用于二元分类中 研究分类器的输出。为了延长 精度召回曲线和平均精度到多类或 多标签分类,需要对输出进行二值化处理。 每个标签可以绘制一条曲线,但也可以绘制一条曲线 通过考虑标签的每个元素的精确召回曲线 指标矩阵作为二元预测(微平均)。

ROC 曲线通常用于二元分类来研究 分类器的输出。为了将 ROC 曲线和 ROC 面积扩展到 多类或多标签分类,需要二值化 输出。每个标签可以绘制一条 ROC 曲线,但也可以绘制一条 ROC 曲线 通过考虑标签指标的每个元素绘制ROC曲线 矩阵作为二元预测(微平均)。

因此,您应该对输出进行二值化,并考虑每个类的精确召回和 roc 曲线。此外,您将使用predict_proba 来获取类概率。

我把代码分成三部分:

  1. 常规设置、学习和预测
  2. 精确召回曲线
  3. ROC 曲线

1.一般设置、学习和预测

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_openml("mnist_784")
y = mnist.target
y = y.astype(np.uint8)
n_classes = len(set(y))

Y = label_binarize(mnist.target, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)

2。精确召回曲线

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()

3. ROC曲线

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i]))
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

【讨论】:

  • 为什么我使用 OneVsRestClassifier? RandomForest 不是已经支持多类了吗?
  • 我在运行第一部分时遇到了这些错误: UserWarning: Label not 0 is present in all tr​​aining example UserWarning: Label not 1 is present in all tr​​aining example UserWarning: Label not 2 is present in all训练示例
  • 请注意,警告不是错误。考虑到这一行Y = label_binarize(mnist.target, classes=[*range(n_classes)]),您应该在数据集中提供类。在我的示例中,这些类是[0,1,2,...,9]
  • 如何用微平均创建PR曲线或ROC曲线?据我所知,如果你有 3 个类别,你将获得 3 个概率向量,每个类别的概率为 1。然后观察被分配到概率最高的类。也就是说,与阈值无关。但是对于 ROC 和 PR 曲线,您需要一个阈值,那么您将如何进行微平均呢?如何根据特定阈值将观察结果分配给类?
  • 我只是尝试在阈值等于 0 时反向计算精度和召回率,并查看它是否与分类报告()函数给出的匹配,但它返回的结果却出奇地不同。我在这里解决这个问题:stats.stackexchange.com/questions/559203/…
猜你喜欢
  • 2018-01-18
  • 2019-03-31
  • 2023-03-08
  • 2016-06-19
  • 2013-10-01
  • 1970-01-01
  • 2016-01-09
  • 2021-05-05
  • 2015-07-22
相关资源
最近更新 更多