【问题标题】:How to plot ROC Curve for multiclass data and measure MAUC from confusion matrix如何绘制多类数据的 ROC 曲线并从混淆矩阵测量 MAUC
【发布时间】:2020-04-03 11:04:44
【问题描述】:

我在具有 3 个类的数据集上使用了反向传播:L、B、R。在制作神经网络后,我还制作了一个混淆矩阵。

实际类数组:

sample_test = array([0, 1, 0, 2, 0, 2, 1, 1, 0, 1, 1, 1], dtype=int64)

预测的类数组:

yp = array([0, 1, 0, 2, 0, 2, 0, 1, 0, 1, 1, 1], dtype=int64)

混淆矩阵代码:

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

class_names = ['B','R','L']

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = [0, 1, 2]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plot_confusion_matrix(sample_test, yp, classes=class_names, 
                      title='Confusion matrix, without normalization')

# Plot normalized confusion matrix
plot_confusion_matrix(sample_test, yp, classes=class_names , normalize=True,
                      title='Normalized confusion matrix')

plt.show()

输出:

现在我想为此绘制 ROC 曲线并计算 MAUC。我看到了documentation,但无法正确理解该怎么做。

如果有人可以通过提供一些建议来帮助我,我将非常感激。提前致谢。

【问题讨论】:

    标签: python-3.x machine-learning roc confusion-matrix


    【解决方案1】:

    ROC 是按类计算的 - 将每个类视为“正”类,将其他类视为“负”类。注意 - 首先你必须使用 predict_proba() - 来获得每个类别的预测概率。像这样的:

    import seaborn as sns
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import train_test_split
    from sklearn import preprocessing
    from sklearn.metrics import roc_auc_score
    
    iris = sns.load_dataset('iris')
    X = iris.drop('species',axis=1)
    y = iris['species']
    X_train, X_test, y_train, y_test = train_test_split(X,y)
    
    le = preprocessing.LabelEncoder()
    le.fit(y_train)
    le.transform(y_train)
    
    model = DecisionTreeClassifier(max_depth=1)
    model.fit(X_train,le.transform(y_train))
    
    predictions =pd.DataFrame(model.predict_proba(X_test),columns=list(le.inverse_transform(model.classes_)))
    
    print(roc_auc_score((y_test == 'versicolor').astype(float), predictions['versicolor']))
    

    【讨论】:

    • 此值为 1 类。我对吗?那么如何获得其他类的 auc 值以及如何将它们结合起来计算 MAUC?你能给我提供任何文件吗?感谢您的帮助。
    • 看看pdfs.semanticscholar.org/780d/…(等式3),您可以使用那里的公式来组合成对的AUC
    猜你喜欢
    • 2019-12-26
    • 2017-08-29
    • 2018-12-26
    • 2018-09-23
    • 2016-01-15
    • 2016-11-28
    • 2022-06-15
    • 1970-01-01
    相关资源
    最近更新 更多