【发布时间】:2021-03-22 00:01:42
【问题描述】:
我正在尝试从 roc_curve() 获取 tpr(true positive rate) 和 fpr(false positive rate),然后是 auc score(),然后可以绘制图表以查看我的模型在多标签(500 个标签) 数据不平衡但出现错误。
我正在计算每个标签预测的概率,以便我可以更改阈值以获得更好的精度、召回率和准确度,并在预测时获得大多数目标标签。
代码:
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import ClassifierChain
rfc = RandomForestClassifier(n_jobs = -1, random_state =0, class_weight = 'balanced')
clf2 = ClassifierChain(rfc)
clf2.fit(X_train , y_train)
y_pred = clf2.predict_proba(X_test)
y_pred.shape
>> (8125,500)
y_pred[0]
>> array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01, 0. , 0. , 0.01, 0. , 0.01, 0. , 0. , 0. ,
0. , 0.01, 0. , 0. , 0. , 0.01, 0. , 0.01, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01, 0. ,
0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.03, 0. , 0. , 0. , 0. , 0. , 0.01,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.5 , 0.01, 0. , 0. , 0. , 0. , 0.01, 0. ,
0. , 0.05, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01, 0. , 0.02, 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.03, 0.04, 0. ,
0. , 0. , 0.01, 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0. , 0.02, 0. ,
0. , 0.01, 0. , 0.01, 0. , 0.28, 0. , 0. , 0. , 0. , 0.01,
0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0. ,
0.01, 0. , 0. , 0. , 0. , 0. , 0.02, 0.07, 0. , 0.01, 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01, 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.02, 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0.02, 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01, 0. , 0. , 0.02, 0.01, 0. , 0. ,
0. , 0. , 0. , 0.01, 0. , 0. , 0.01, 0. , 0. , 0.01, 0. ,
0. , 0. , 0. , 0.03, 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0.15, 0. , 0. , 0.02, 0. ,
0.01, 0. , 0.11, 0. , 0.01, 0. , 0. , 0. , 0. , 0.02, 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.02, 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0.1 , 0.02, 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01, 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.02,
0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01, 0. , 0. , 0.01, 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ]])
from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
fpr, tpr, thresholds = roc_curve(y_test,y_pred)
最后一行代码给了我错误。
追溯:
ValueError Traceback (most recent call last)
<ipython-input-72-ea45ece64953> in <module>()
1 from sklearn.metrics import roc_auc_score,roc_curve,precision_recall_curve
----> 2 fpr, tpr, thresholds = roc_curve(y_test,y_pred)
1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
534 if not (y_type == "binary" or
535 (y_type == "multiclass" and pos_label is not None)):
--> 536 raise ValueError("{0} format is not supported".format(y_type))
537
538 check_consistent_length(y_true, y_score, sample_weight)
ValueError: multilabel-indicator format is not supported
【问题讨论】:
标签: python machine-learning scikit-learn multilabel-classification