【问题标题】:visualize predict_proba for multiclass classification可视化 predict_proba 进行多类分类
【发布时间】:2020-09-06 16:16:54
【问题描述】:

有了model.predict_proba(X),我得到了一个包含大量数字的大数组。

我正在寻找一种方法来可视化所有类的分类概率(在我的例子中是 13)。我使用RandomForestClassifier

有什么推荐吗?

【问题讨论】:

  • 你输入空间的维度是多少?
  • 一共13个班,从几百到几千不等。
  • 恕我直言,像@Venkatachalam 这样的方法不会提供有用的可视化,因为您需要对输出模式有所了解——我的第一个想法是根据输入来做,比如我的回答如下。在您的情况下,这需要首先从 d=2000 或其他任何东西到 d=2 进行降维,但这对于高维度数据来说并不难或不寻常。

标签: scikit-learn classification data-visualization


【解决方案1】:

热图是可视化二维矩阵的好方法。当然,如果您的 X 中的记录数量很大,则很难一次将所有内容可视化。否则您可能必须对记录进行抽样。这里我展示了前 10 条记录的视觉效果,如果预测概率大于0.1,则标记预测类别。

看看这个例子:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np


X, y = make_classification(n_samples=10000,n_features=40,
                           n_informative=30, n_classes=13,
                           n_redundant=0, n_clusters_per_class=1,
                           random_state=42)


X_train, X_test, y_train, y_test = train_test_split(X,y, random_state=42)

forest = RandomForestClassifier(n_estimators=10, random_state=42).fit(X_train, y_train)

pred = forest.predict_proba(X_test)[:10]
fig, ax = plt.subplots(figsize= (20,8))
im = ax.imshow(pred, cmap='Blues')

ax.grid(axis='y')
ax.set_xticklabels([])

ax.set_yticks(np.arange(pred.shape[0]))

plt.ylabel('Records', fontsize='xx-large')
plt.xlabel('Classes', fontsize='xx-large')
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04) 

for i in range(pred.shape[0]):
    for j in range(13):
        if pred[i, j] >.1:
             ax.text(j, i, j,
                     ha="center", va="center", color="w", fontsize=30)

【讨论】:

    【解决方案2】:

    如果您的输入空间是二维的,或者如果您使用某种降​​维技术将其嵌入到二维中,您可以绘制多类决策面:

    # generate toy data
    X, y = sklearn.datasets.make_blobs(n_samples=1000, centers=13)
    
    # fit classifier
    clf = sklearn.ensemble.RandomForestClassifier().fit(X, y)
    
    # create decision surface
    xx, yy = np.meshgrid(np.linspace(-13, 12, 100),
                         np.linspace(-13, 12, 100))
    Z = clf.predict(np.array([xx.ravel(), yy.ravel()]).T)
    Z = Z.reshape(xx.shape)
    
    fig, ax = plt.subplots(1,1, figsize=(8,8))
    ax.scatter(X[:,0], X[:,1], c=y, cmap='Paired')
    ax.contourf(xx, yy, Z, cmap='Paired', alpha=0.5)
    

    请注意,这只是每个标签的阴影(predict 不是 predict_proba),但您可以根据概率将其扩展为不同的阴影。

    【讨论】:

      猜你喜欢
      • 2020-06-14
      • 1970-01-01
      • 2018-08-14
      • 2017-08-29
      • 2012-02-28
      • 1970-01-01
      • 2018-10-23
      • 1970-01-01
      • 2020-01-19
      相关资源
      最近更新 更多