【发布时间】:2022-02-09 23:21:39
【问题描述】:
根据nn.BCEWithLogitsLoss 的pytorch doc,pos_weight 是一个可选参数a,它采用正例的权重。我不完全理解该页面中的声明“pos_weight > 1 增加召回率和 pos_weight
【问题讨论】:
根据nn.BCEWithLogitsLoss 的pytorch doc,pos_weight 是一个可选参数a,它采用正例的权重。我不完全理解该页面中的声明“pos_weight > 1 增加召回率和 pos_weight
【问题讨论】:
带有 logits 损失的二元交叉熵(nn.BCEWithLogitsLoss,相当于 F.binary_cross_entropy_with_logits)是一个 sigmoid 层(nn.Sigmoid),后面是二元交叉熵损失(nn.BCELoss)。一般情况假设您处于多标签分类任务,即单个输入可以用多个类标记。一个常见的子情况是只有一个类:二元分类任务。如果您将q 定义为预测类别的张量,而p 定义为与每个类别的真实概率相对应的基本事实[0,1]。
二元交叉熵的显式公式为:
z = torch.sigmoid(q)
loss = -(w_p*p*torch.log(z) + (1-p)*torch.log(1-z))
引入w_p,与每个类的真实标签相关的权重。阅读this post,了解有关BCELoss 使用的加权方案的更多详细信息。
对于给定的类:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
然后如果w_p > 1,它会增加正分类的权重(分类为真)。这往往会增加误报(FP),从而降低精度。类似地,如果w_p < 1,我们正在减少真实类的权重,这意味着它会倾向于增加假阴性(FN),从而降低召回率。
【讨论】: