class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

  

相关文章:

  • 2021-08-29
  • 2021-08-25
  • 2021-12-24
  • 2021-12-10
  • 2022-12-23
  • 2022-12-23
  • 2021-05-19
  • 2021-07-02
猜你喜欢
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2021-10-11
  • 2021-10-27
  • 2022-12-23
  • 2022-12-23
相关资源
相似解决方案