【发布时间】:2021-05-26 22:47:32
【问题描述】:
我想为多标签分类创建自定义损失函数。这个想法是对正面和负面标签进行不同的权衡。为此,我正在使用这个自定义代码实现。
class WeightedBCEWithLogitLoss(nn.Module):
def __init__(self, pos_weight, neg_weight):
super(WeightedBCEWithLogitLoss, self).__init__()
self.register_buffer('neg_weight', neg_weight)
self.register_buffer('pos_weight', pos_weight)
def forward(self, input, target):
assert input.shape == target.shape, "The loss function received invalid input shapes"
y_hat = torch.sigmoid(input + 1e-8)
loss = -1.0 * (self.pos_weight * target * torch.log(y_hat + 1e-6) + self.neg_weight * (1 - target) * torch.log(1 - y_hat + 1e-6))
# Account for 0 times inf which leads to nan
loss[torch.isnan(loss)] = 0
# We average across each of the extra attribute dimensions to generalize it
loss = loss.mean(dim=1)
# We use mean reduction for our task
return loss.mean()
我开始得到 nan 值,我意识到这是由于 0 次 inf 乘法而发生的。我处理它如图所示。接下来,我再次看到将inf 作为错误值并通过在日志中添加 1e-6 来纠正它(我尝试使用 1e-8 但仍然给了我 inf 错误值)。
如果有人可以查看并建议进一步改进并纠正此处可见的任何更多错误,那就太好了。
【问题讨论】:
标签: machine-learning deep-learning pytorch