【Question Title】: How does the roc_curve() function calculate FPR and TPR values behind the scenes? In my case, I got (53,) from (400,)-dimensional input data
【Posted】: 2020-12-23 08:37:28
【Question Description】:

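I am plotting an ROC curve. I have a neural network classifier with a single hidden layer. My output is the result of the last layer's activation function, which I call A2; this is the probability input to roc_curve(). My A2 and predictions have the following shapes and data: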

print(A2.ravel().shape)
print(predictions.ravel().shape)
print(A2, predictions)

Output:

(400,)
(400,)
[[3.22246780e-04 7.64373268e-01 7.64385217e-01 7.64372464e-01
  1.63920340e-01 7.64372463e-01 2.75254103e-04 7.65185909e-01
  2.06186064e-01 2.12094433e-01 2.75251983e-04 7.64372463e-01
  2.11985152e-01 2.10202927e-01 2.75252955e-04 9.44088883e-02
  2.02522498e-01 2.07370306e-01 2.50282683e-03 2.75260253e-04
  2.11928461e-01 2.75251291e-04 2.75251291e-04 2.75251498e-04
  2.75251306e-04 1.35809613e-01 2.75464969e-04 1.74181943e-01
  2.75435676e-04 2.75251294e-04 2.96236579e-04 2.75268578e-04
  2.76053487e-04 2.78105904e-04 2.75293008e-04 2.75251307e-04
  2.87538148e-04 2.75270689e-04 2.39320951e-06 4.45134656e-02
  2.75251367e-04 2.75251506e-04 2.75251303e-04 2.31132556e-06
  3.69449012e-04 2.75251293e-04 5.59346558e-02 2.31132310e-06
  1.82980485e-01 6.20515482e-06 2.32293394e-02 1.58108674e-03
  2.75252597e-04 1.19360888e-02 2.27051743e-01 2.31161383e-06
  2.31132421e-06 2.31132310e-06 2.31573234e-06 2.31132310e-06
  5.15530179e-01 2.31132310e-06 2.31132311e-06 2.46803695e-06
  2.31132310e-06 2.31141693e-06 2.31132314e-06 2.31181353e-06
  1.08428788e-03 3.91750347e-01 2.15413251e-01 2.31136922e-06
  2.31132310e-06 2.31135038e-06 2.31132310e-06 4.18257225e-02
  2.31132310e-06 2.31692274e-06 2.31132315e-06 2.34152146e-06
  2.31132310e-06 2.31132310e-06 2.31134156e-06 2.32276423e-06
  2.31184444e-06 2.31189807e-06 2.31132310e-06 3.03902587e-06
  2.33123340e-06 6.74029292e-03 1.37374673e-04 7.11777353e-06
  2.31332212e-06 2.31134309e-06 2.85446765e-01 8.45686446e-04
  2.95393201e-06 6.30729453e-02 2.35681287e-06 1.67406531e-05
  1.39482094e-04 1.47208937e-05 2.64716376e-05 1.48764918e-05
  2.37288319e-06 1.76484186e-05 1.47209077e-05 4.24952409e-05
  2.47222738e-04 1.53198138e-05 5.10281474e-06 1.47209298e-05
  1.47208667e-05 2.64277585e-01 1.47208667e-05 1.55307243e-01
  1.47208865e-05 2.91081049e-03 1.47208667e-05 1.47208667e-05
  1.47208667e-05 1.47903704e-05 1.47238820e-05 3.11567098e-02
  4.14289114e-01 1.50836911e-05 2.78303520e-02 1.47208667e-05
  1.47251817e-05 1.47947695e-05 1.47208667e-05 1.47208667e-05
  1.47208940e-05 1.48783712e-05 2.05607558e-04 1.47208667e-05
  4.83812804e-05 1.47208667e-05 2.09377734e-01 1.49642652e-05
  1.47221481e-05 1.47568362e-05 2.77831915e-01 4.82959556e-01
  4.50969045e-01 3.82364226e-02 4.11377002e-02 2.16308926e-01
  8.88141165e-02 2.12679453e-01 2.24050631e-02 1.47208667e-05
  2.12677744e-01 2.12677744e-01 2.12677760e-01 2.33568941e-01
  2.28926909e-01 2.13773365e-01 2.12678951e-01 1.35565877e-03
  2.47656669e-01 1.08727082e-01 2.12677744e-01 2.12678014e-01
  2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.11844159e-01 1.51525672e-03 2.12677744e-01 2.12677744e-01
  7.65697761e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.12668331e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.12677744e-01 2.12677744e-01 2.12677744e-01 6.01467058e-02
  2.12677744e-01 2.12677744e-01 2.12677495e-01 2.12677744e-01
  2.12677743e-01 2.12677744e-01 7.72857608e-01 2.09249431e-01
  7.86146268e-01 7.64683696e-01 8.39288704e-01 2.12677744e-01
  8.05987357e-01 7.73524718e-01 7.64722596e-01 7.64646794e-01
  2.12677744e-01 8.54868081e-01 7.66923142e-01 8.54244158e-01
  2.11261708e-01 7.66992993e-01 2.12677744e-01 2.12598362e-01
  7.66165847e-01 9.99643109e-01 7.65268010e-01 9.99685903e-01
  9.99685903e-01 7.65043689e-01 2.12677744e-01 2.12677744e-01
  7.64840536e-01 9.99685901e-01 9.99332786e-01 2.12677743e-01
  7.79121852e-01 9.99685785e-01 7.79074180e-01 7.65194741e-01
  8.98667738e-01 9.99684795e-01 9.58419683e-01 9.99685902e-01
  9.99685882e-01 9.99639779e-01 9.99639274e-01 9.99677983e-01
  9.99685736e-01 9.99685902e-01 9.92940564e-01 9.99685903e-01
  9.99685839e-01 8.30995491e-01 9.90611316e-01 9.99997341e-01
  9.99670704e-01 9.23825584e-01 9.99685666e-01 9.99996824e-01
  9.99685902e-01 9.40290068e-01 9.99685903e-01 9.99996965e-01
  9.99685364e-01 9.99997362e-01 9.99685801e-01 9.99997362e-01
  9.99996900e-01 9.99685513e-01 9.99997362e-01 9.99684995e-01
  9.99676405e-01 9.99997362e-01 6.89410113e-01 5.28997119e-01
  9.93019339e-01 6.62017810e-01 9.99997362e-01 9.99997362e-01
  9.99997358e-01 9.99997362e-01 9.99997362e-01 9.99997346e-01
  9.99997362e-01 9.99997352e-01 9.99997362e-01 9.99997362e-01
  9.99997362e-01 9.99071790e-01 9.99997362e-01 9.99997362e-01
  3.46195433e-01 9.99995537e-01 9.99997362e-01 9.99997362e-01
  9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99996894e-01
  7.67197871e-01 9.99997179e-01 1.65047845e-01 9.99978488e-01
  2.93981729e-01 9.99997362e-01 9.99997361e-01 9.29067186e-01
  9.99997362e-01 9.48399940e-01 9.99997362e-01 9.99997362e-01
  6.78299886e-01 9.99997362e-01 9.99997362e-01 9.63677152e-01
  9.99997362e-01 3.67733752e-01 9.99997222e-01 7.74993071e-01
  6.37972260e-01 9.99943783e-01 9.77268446e-01 9.99976242e-01
  7.00255679e-01 9.99983200e-01 9.99983201e-01 9.99983138e-01
  9.99983197e-01 9.86360906e-01 9.99983201e-01 9.99389801e-01
  9.98380059e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01
  9.99983199e-01 9.99983199e-01 9.99983201e-01 9.99391768e-01
  9.99983201e-01 9.99983201e-01 9.99981131e-01 9.99983201e-01
  9.99983201e-01 9.76520592e-01 8.44076103e-01 9.99983201e-01
  9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01
  9.99899640e-01 9.99983201e-01 9.99983193e-01 9.99964112e-01
  9.99983201e-01 9.99983201e-01 9.99983201e-01 7.58322592e-01
  9.99983201e-01 9.99983201e-01 9.99981971e-01 7.64372463e-01
  7.64372463e-01 9.99983201e-01 9.06823611e-01 9.99983201e-01
  7.64372463e-01 9.99983201e-01 2.01516877e-01 7.64372463e-01
  3.98768426e-01 7.64372463e-01 9.81611504e-01 7.64372463e-01
  7.64372463e-01 7.64370725e-01 7.64372463e-01 7.64372463e-01
  9.99979567e-01 1.90105310e-01 7.64372463e-01 7.64372463e-01
  4.09226724e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64387743e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.76876797e-01 7.64372463e-01 2.07693046e-01 7.64372463e-01
  7.64372463e-01 7.59770748e-01 7.64372463e-01 7.64372463e-01
  7.66343703e-01 2.05588421e-01 7.64372828e-01 2.06636497e-01
  1.97645490e-01 2.09816835e-01 7.64372464e-01 1.77842165e-01]] [[0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 1 0 1 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 1 1 1 1 1 1 0 0
  1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1
  0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1
  1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 0
  0 0 1 0]]

Now, when I pass these values to roc_curve(), I get fpr, tpr, and threshold arrays with the following shapes and values:

from sklearn.metrics import roc_curve

fpr, tpr, threshold = roc_curve(Y.ravel(), A2.ravel())
print(fpr.shape, tpr.shape, threshold.shape)
print(fpr, tpr, threshold)

Output:

(53,) (53,) (53,)
[0.    0.    0.    0.    0.    0.    0.    0.    0.005 0.005 0.015 0.015
 0.025 0.025 0.03  0.03  0.035 0.035 0.05  0.05  0.06  0.06  0.065 0.065
 0.075 0.075 0.095 0.095 0.1   0.1   0.19  0.27  0.285 0.285 0.3   0.3
 0.32  0.32  0.325 0.325 0.335 0.335 0.34  0.34  0.345 0.345 0.35  0.35
 0.355 0.355 0.36  0.36  1.   ] [0.    0.015 0.045 0.19  0.23  0.235 0.245 0.615 0.615 0.62  0.62  0.64
 0.64  0.665 0.665 0.675 0.675 0.685 0.685 0.69  0.69  0.7   0.7   0.715
 0.825 0.895 0.895 0.905 0.905 0.92  0.92  0.935 0.935 0.945 0.945 0.95
 0.95  0.955 0.955 0.96  0.96  0.965 0.965 0.97  0.97  0.975 0.975 0.99
 0.99  0.995 0.995 1.    1.   ] [1.99999736e+00 9.99997362e-01 9.99997362e-01 9.99995537e-01
 9.99983201e-01 9.99983201e-01 9.99983201e-01 8.44076103e-01
 8.39288704e-01 8.30995491e-01 7.86146268e-01 7.74993071e-01
 7.72857608e-01 7.66165847e-01 7.65697761e-01 7.65194741e-01
 7.65185909e-01 7.64840536e-01 7.64646794e-01 7.64387743e-01
 7.64373268e-01 7.64372464e-01 7.64372464e-01 7.64372463e-01
 7.64372463e-01 5.28997119e-01 4.14289114e-01 3.98768426e-01
 3.91750347e-01 2.93981729e-01 2.12677744e-01 2.12677744e-01
 2.12677744e-01 2.12677743e-01 2.12668331e-01 2.12598362e-01
 2.11844159e-01 2.11261708e-01 2.10202927e-01 2.09816835e-01
 2.09249431e-01 2.07693046e-01 2.07370306e-01 2.06636497e-01
 2.06186064e-01 2.05588421e-01 2.02522498e-01 1.90105310e-01
 1.82980485e-01 1.77842165e-01 1.74181943e-01 1.65047845e-01
 2.31132310e-06]

So my ROC curve looks like this:

import matplotlib.pyplot as plt

plt.figure()
plt.plot(fpr, tpr)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC Curve)')
plt.show()

Output: [ROC curve plot; image not available]

Why do FPR, TPR, and thresholds have shape (53,)? My case is just a simple two-class classification. Thanks for your help.

【Question Discussion】:

    Tags: python numpy machine-learning scikit-learn roc


    【Solution 1】:

    The number of thresholds is determined as follows:

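    • Step 1: Keep one threshold per distinct score value, plus one extra threshold that roc_curve prepends at the end of the procedure so that the curve starts at (0, 0). In the version used in the question, that extra threshold is the maximum score plus 1, which is why the first threshold printed above is 1.99999736.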

    Source:

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
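The Step 1 snippet can be illustrated with a minimal, simplified sketch. `distinct_threshold_counts` is a hypothetical helper, not part of scikit-learn; it assumes 0/1 labels with 1 as the positive class and no sample weights. It sorts the scores in decreasing order, keeps one index per distinct score, and counts the true and false positives obtained when thresholding at each of those scores.

```python
import numpy as np


def distinct_threshold_counts(y_true, y_score):
    """Hypothetical helper: simplified Step 1 (0/1 labels, no sample weights)."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    # Sort by decreasing score (stable mergesort, then reversed);
    # tied scores end up adjacent.
    order = np.argsort(y_score, kind="mergesort")[::-1]
    y_score, y_true = y_score[order], y_true[order]
    # Positions where the score changes, plus the last position:
    # exactly one index per distinct score value (the last of each tie run).
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
    # Cumulative true/false positives when predicting "positive"
    # for every sample whose score is >= the threshold.
    tps = np.cumsum(y_true == 1)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    return fps, tps, y_score[threshold_idxs]


scores = [0.1, 0.4, 0.4, 0.8, 0.8]
labels = [0, 0, 1, 1, 1]
fps, tps, thresholds = distinct_threshold_counts(labels, scores)
print(thresholds)  # [0.8 0.4 0.1]
print(tps, fps)    # [2 3 3] [0 1 2]
```

On this toy input there are three distinct scores, so three thresholds survive Step 1; each group of tied scores (the two 0.4s and the two 0.8s) collapses into a single threshold. This is how 400 scores with many repeated values, as the printed A2 appears to have, can yield far fewer than 400 thresholds.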
    
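    • Step 2 (only if drop_intermediate and len(fps) > 2; drop_intermediate defaults to True): drop thresholds corresponding to points that lie between, and are collinear with, other points.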

    Source:

    # Attempt to drop thresholds corresponding to points in between and
    # collinear with other points. These are always suboptimal and do not
    # appear on a plotted ROC curve (and thus do not affect the AUC).
    # Here np.diff(_, 2) is used as a "second derivative" to tell if there
    # is a corner at the point. Both fps and tps must be tested to handle
    # thresholds with multiple data points (which are combined in
    # _binary_clf_curve). This keeps all cases where the point should be kept,
    # but does not drop more complicated cases like fps = [1, 3, 7],
    # tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
    if drop_intermediate and len(fps) > 2:
        optimal_idxs = np.where(np.r_[True,
                                      np.logical_or(np.diff(fps, 2),
                                                    np.diff(tps, 2)),
                                      True])[0]
        fps = fps[optimal_idxs]
        tps = tps[optimal_idxs]
        thresholds = thresholds[optimal_idxs]
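A minimal sketch of the Step 2 filter, applied to cumulative counts like those produced in Step 1. `drop_collinear` is a hypothetical name; its body follows the logic of the quoted snippet: the first and last points are always kept, and an interior point is kept only where the second difference of `fps` or `tps` is non-zero.

```python
import numpy as np


def drop_collinear(fps, tps, thresholds):
    """Hypothetical helper mirroring the drop_intermediate snippet above."""
    fps, tps, thresholds = np.asarray(fps), np.asarray(tps), np.asarray(thresholds)
    if len(fps) <= 2:
        return fps, tps, thresholds
    # np.diff(x, 2) is non-zero wherever the step size of x changes,
    # i.e. where the curve turns a corner.
    corner = np.logical_or(np.diff(fps, 2), np.diff(tps, 2))
    # Always keep the first and last points.
    optimal_idxs = np.where(np.r_[True, corner, True])[0]
    return fps[optimal_idxs], tps[optimal_idxs], thresholds[optimal_idxs]


fps = [0, 0, 0, 1, 2, 3]
tps = [1, 2, 3, 3, 3, 3]
thresholds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
f, t, th = drop_collinear(fps, tps, thresholds)
print(f, t, th)  # [0 0 3] [1 3 3] [0.9 0.7 0.4]
```

Here six points shrink to three: the points (0, 2), (1, 3) and (2, 3) lie on straight segments and are dropped, while the corner at (0, 3) is kept. Passing `drop_intermediate=False` to `roc_curve` skips this step, so the output then has one entry per distinct score plus one.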
    

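    FPR and TPR are then computed for each remaining threshold: FPR = FP / (total negatives) and TPR = TP / (total positives).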
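Putting the steps together, here is a simplified end-to-end sketch. `roc_from_scratch` is a hypothetical helper written for illustration, not scikit-learn's implementation; it assumes 0/1 labels, no sample weights, and that both classes are present (otherwise the final divisions are by zero). It uses the older convention of the maximum score plus 1 as the extra first threshold, which matches the 1.99999736 in the question; newer scikit-learn releases use `np.inf` there instead.

```python
import numpy as np


def roc_from_scratch(y_true, y_score, drop_intermediate=True):
    """Hypothetical, simplified re-implementation of the steps above.

    Assumes 0/1 labels, no sample weights and both classes present.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(y_score, kind="mergesort")[::-1]
    y_score, y_true = y_score[order], y_true[order]

    # Step 1: one threshold per distinct score.
    idxs = np.r_[np.where(np.diff(y_score))[0], y_true.size - 1]
    tps = np.cumsum(y_true == 1)[idxs]
    fps = 1 + idxs - tps
    thresholds = y_score[idxs]

    # Step 2: keep only the first/last points and the corners.
    if drop_intermediate and len(fps) > 2:
        corner = np.logical_or(np.diff(fps, 2), np.diff(tps, 2))
        keep = np.where(np.r_[True, corner, True])[0]
        fps, tps, thresholds = fps[keep], tps[keep], thresholds[keep]

    # Extra first point so the curve starts at (0, 0); older scikit-learn
    # used max score + 1 as its threshold (newer releases use np.inf).
    fps, tps = np.r_[0, fps], np.r_[0, tps]
    thresholds = np.r_[thresholds[0] + 1, thresholds]

    # Rates = cumulative counts / total negatives (or total positives),
    # which are the last entries of fps and tps.
    fpr = fps / fps[-1]
    tpr = tps / tps[-1]
    return fpr, tpr, thresholds


fpr, tpr, thr = roc_from_scratch([0, 0, 1, 1, 1], [0.1, 0.4, 0.4, 0.8, 0.8])
print(fpr)
print(tpr)
print(thr)
```

Comparing its output with `sklearn.metrics.roc_curve` on the same `y_true`/`y_score` is a quick way to check the sketch; the FPR/TPR arrays should agree for 0/1 labels, while the first threshold depends on the scikit-learn version.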

    【Discussion】:
