【发布时间】:2021-06-14 00:56:53
【问题描述】:
我试图了解pos_weight 参数是如何在BCEWithLogitsLoss 中使用的,以便能够正确定义pos_weight 张量。文档只提到:“正例的权重。必须是长度等于类数的向量。”。由于我无法通过查看代码获得足够的理解(实际代码隐藏在多个函数加载器后面,我什至没有达到使用pos_weight 的地步),我有几个关于pos_weight 论据:
- 负样本的权重是否始终为 1?
- 如果负样本的权重始终为 1,并且假设我希望每个样本对损失的贡献相同,我会执行以下操作。设
l = [100, 10, 5, 15]其中l[0]是负样本的数量,l[1:]是每个标签的正样本数量。在伪代码中是这样的:
l = [100, 10, 5, 15]
lcm = LCM(l) # 300
weights = lcm / l # weights = [3, 30, 60, 20]
weights = weights / l[0] # weights = [1, 10, 20, 6.6667]
positive_weights = weights[1:] # [10, 20, 6.66667]
criterion = nn.BCEWithLogitsLoss(pos_weight=positive_weights)
有人可以确认我对如何使用pos_weight 的理解是否正确?
【问题讨论】:
标签: pytorch