【问题标题】:Predictions using AUC metrics for multilabel classification使用 AUC 指标进行多标签分类的预测
【发布时间】:2021-07-19 13:21:13
【问题描述】:
我正在使用 AUC 指标进行多标签分类。由于 keras 已经删除了 prediction_classes 来获取预测类,所以我只使用 0.5 的阈值来获取输出类。但是,据我了解,对于不平衡的数据集,AUC 的阈值不应为 0.5。如何获得用于训练模型的阈值?
此外,我知道 AUC 用于二进制分类。我可以将它用于多标签问题吗?如何计算阈值?取平均值与否。
【问题讨论】:
标签:
python
tensorflow
keras
multilabel-classification
【解决方案1】:
多标签问题可以使用AUC,查看this。
import numpy as np
y_true = np.random.randint(0,2,(100,4))
y_pred = np.random.randint(0,2,(100,4))
m = tf.keras.metrics.AUC(multi_label=True, thresholds=[0, 0.5])
m(y_true, y_pred).numpy()
仅供参考,来自 tf 2.5,它现在支持 logit 预测。