以一个简单例子来说明各个 Loss 函数的使用

# Toy data for the loss examples below.
# NOTE: np.float was deprecated in NumPy 1.20 and removed in 1.24 -- use np.float64.
label_numpy = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype=np.float64)  # simulated target
out_numpy = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float64)    # simulated prediction (raw scores)
num_classes = 2

$l_n = -w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$

label = torch.from_numpy(label_numpy).unsqueeze(0)  # shape (1, 4, 4): batch dim added in front
output = torch.from_numpy(out_numpy).unsqueeze(0)   # shape (1, 4, 4)
# ======================================================= #
# BCELoss expects probabilities in [0, 1], so squash the raw scores first.
# F.sigmoid has been deprecated since PyTorch 0.4.1 -- use torch.sigmoid instead.
criterion = nn.BCELoss()
loss = criterion(torch.sigmoid(output), label)  # 0.6219
# ======================================================= #
# ======================================================= #

nn.BCEWithLogitsLoss

output = torch.from_numpy(out_numpy).unsqueeze(0)   # raw scores (logits), shape (1, 4, 4)
label = torch.from_numpy(label_numpy).unsqueeze(0)  # targets, shape (1, 4, 4)
# ======================================================= #
# No explicit sigmoid here: BCEWithLogitsLoss consumes logits directly.
criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, label)  # 0.6219 -- same value as BCELoss on sigmoid(output)
# ======================================================= #
# ======================================================= #

这个损失将Sigmoid层和BCELoss合并在一个类中,且数值稳定性更好

具体计算过程如下

class BCEWithLogitsLoss(nn.Module):
    """Binary cross entropy that takes raw logits as input.

    Fuses the sigmoid and the BCE computation into one layer and evaluates
    the softplus term with the log-sum-exp trick, which is numerically more
    stable than feeding ``sigmoid(x)`` into a plain ``BCELoss``.
    """

    def __init__(self):
        super(BCEWithLogitsLoss, self).__init__()

    def forward(self, input, target, weight=None, size_average=None,
                reduce=None, reduction='mean', pos_weight=None):
        # Map the deprecated size_average/reduce flags onto `reduction`.
        if size_average is not None or reduce is not None:
            reduction = _Reduction.legacy_get_string(size_average, reduce)
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

        # softplus(-x) = log(1 + exp(-x)), computed as
        # m + log(exp(-m) + exp(-x - m)) with m = max(-x, 0) so that the
        # largest exponent is pinned at zero and exp() cannot overflow.
        stabilizer = (-input).clamp(min=0)
        softplus_term = stabilizer + ((-stabilizer).exp() + (-input - stabilizer).exp()).log()

        if pos_weight is None:
            loss = input - input * target + softplus_term
        else:
            # Positive examples have their log-term rescaled by pos_weight.
            scale = 1 + (pos_weight - 1) * target
            loss = input - input * target + scale * softplus_term

        # Optional per-element weighting, then the requested reduction.
        if weight is not None:
            loss = loss * weight
        if reduction == 'none':
            return loss
        if reduction == 'mean':
            return loss.mean()
        return loss.sum()
View Code

相关文章:

  • 2021-08-09
  • 2021-11-19
  • 2022-01-03
  • 2021-06-17
  • 2021-10-11
  • 2021-07-09
  • 2022-12-23
猜你喜欢
  • 2022-12-23
  • 2021-10-28
  • 2021-11-27
  • 2022-02-16
  • 2021-10-08
  • 2021-10-20
  • 2021-11-18
相关资源
相似解决方案