class FocalLoss(nn.Module):
    """Binary focal loss (Lin et al., 2017, "Focal Loss for Dense Object Detection").

    Down-weights well-classified examples by scaling the element-wise binary
    cross-entropy by ``alpha * (1 - p_t) ** gamma``, where ``p_t`` is the
    probability the model assigns to the true class.

    Args:
        alpha: Scalar weight applied uniformly to every element. Note this is
            NOT the class-balanced ``alpha_t`` from the paper (``alpha`` for
            positives, ``1 - alpha`` for negatives); it only rescales the loss.
        gamma: Focusing parameter. ``gamma=0`` reduces to ``alpha``-scaled BCE.
        logits: If True, ``inputs`` are raw scores and PyTorch's
            ``binary_cross_entropy_with_logits`` applies the sigmoid internally.
            Otherwise ``inputs`` must already be probabilities in [0, 1].
        reduce: If True, ``forward`` returns the mean over all elements.
            Otherwise it returns the unreduced element-wise loss.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        """Compute the focal loss of ``inputs`` against ``targets``.

        Args:
            inputs: Logits (``self.logits`` True) or probabilities (False).
            targets: Tensor of the same shape as ``inputs``. It is assumed to
                hold hard 0/1 labels; see the note on ``pt`` below.

        Returns:
            A scalar tensor if ``self.reduce`` is True. Otherwise, the
            element-wise loss tensor.
        """
        # ``reduction="none"`` replaces the deprecated legacy ``reduce=False``
        # keyword, which makes PyTorch emit a deprecation UserWarning.
        if self.logits:
            bce_loss = F.binary_cross_entropy_with_logits(
                inputs, targets, reduction="none"
            )
        else:
            bce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")

        # For targets exactly 0 or 1, BCE = -log(p_t), so exp(-BCE) recovers
        # p_t without recomputing the sigmoid. With soft targets, this is no
        # longer the true-class probability.
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        if self.reduce:
            return torch.mean(focal_loss)
        return focal_loss

来源：TGS Salt Identification Challenge

相关文章:

  • 2022-12-23
  • 2021-09-04
  • 2021-07-14
  • 2021-12-14
  • 2022-12-23
  • 2021-04-12
  • 2021-05-12
猜你喜欢
  • 2022-12-23
  • 2021-12-23
  • 2022-12-23
  • 2021-04-27
  • 2021-07-26
  • 2021-11-30
相关资源
相似解决方案