【问题标题】:What is the impact of `pos_weight` argument in `BCEWithLogitsLoss`?`BCEWithLogitsLoss` 中的 `pos_weight` 参数有什么影响?
【发布时间】:2022-02-09 23:21:39
【问题描述】:

根据nn.BCEWithLogitsLosspytorch docpos_weight 是一个可选参数a,它采用正例的权重。我不完全理解该页面中的声明“pos_weight > 1 增加召回率和 pos_weight

【问题讨论】:

    标签: machine-learning pytorch


    【解决方案1】:

    带有 logits 损失的二元交叉熵(nn.BCEWithLogitsLoss,相当于 F.binary_cross_entropy_with_logits)是一个 sigmoid 层(nn.Sigmoid),后面是二元交叉熵损失(nn.BCELoss)。一般情况假设您处于多标签分类任务,即单个输入可以用多个类标记。一个常见的子情况是只有一个类:二元分类任务。如果您将q 定义为预测类别的张量,而p 定义为与每个类别的真实概率相对应的基本事实[0,1]

    二元交叉熵的显式公式为:

    z = torch.sigmoid(q)
    loss = -(w_p*p*torch.log(z) + (1-p)*torch.log(1-z))
    

    引入w_p,与每个类的真实标签相关的权重。阅读this post,了解有关BCELoss 使用的加权方案的更多详细信息。

    对于给定的类:

    precision =  TP / (TP + FP)
    recall = TP / (TP + FN)
    

    然后如果w_p > 1,它会增加正分类的权重(分类为真)。这往往会增加误报(FP),从而降低精度。类似地,如果w_p < 1,我们正在减少真实类的权重,这意味着它会倾向于增加假阴性(FN),从而降低召回率

    【讨论】:

      猜你喜欢
      • 2021-06-14
      • 2011-09-01
      • 2020-05-09
      • 2012-07-19
      • 1970-01-01
      • 1970-01-01
      • 2012-05-08
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多