【Title】: How do I compute bootstrapped cross entropy loss in PyTorch?
【Posted】: 2020-12-23 09:03:28
【Question】:

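I've read some papers that use a "bootstrapped cross entropy loss" to train their segmentation networks. The idea is to focus only on the hardest k% (say 15%) of pixels to improve learning, especially when easy pixels dominate.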

Currently, I'm using the standard cross entropy:

loss = F.binary_cross_entropy(mask, gt)

How can I efficiently turn this into a bootstrapped version in PyTorch?
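For the binary setting in the question (a probability map `mask` and 0/1 labels `gt` passed to `F.binary_cross_entropy`), a minimal sketch of the top-k idea could look like this. It assumes `mask` already holds probabilities (e.g. after a sigmoid); `bootstrapped_bce` and `top_p` are illustrative names, not a PyTorch API. The steps are: compute the per-pixel loss with `reduction="none"`, keep the largest fraction with `torch.topk`, and average only those.

```python
import torch
import torch.nn.functional as F


def bootstrapped_bce(mask: torch.Tensor, gt: torch.Tensor, top_p: float = 0.15) -> torch.Tensor:
    """Mean binary cross entropy over the hardest `top_p` fraction of pixels.

    `mask` holds probabilities in [0, 1]; `gt` has the same shape with 0/1 values.
    """
    if not 0 < top_p <= 1:
        raise ValueError(f"top_p should be in (0, 1], got {top_p}")
    # Per-pixel loss, flattened over batch and spatial dimensions
    raw_loss = F.binary_cross_entropy(mask, gt, reduction="none").reshape(-1)
    # Number of pixels to keep; at least one so the mean is never over an empty tensor
    k = max(1, int(raw_loss.numel() * top_p))
    hardest, _ = torch.topk(raw_loss, k, sorted=False)
    return hardest.mean()


# Usage with random data standing in for a network output and a ground-truth mask
logits = torch.randn(2, 1, 64, 64, requires_grad=True)
mask = torch.sigmoid(logits)
gt = (torch.rand(2, 1, 64, 64) > 0.5).float()
loss = bootstrapped_bce(mask, gt, top_p=0.15)
loss.backward()
```

Since `torch.topk` only passes gradients back to the selected entries, the easy pixels contribute nothing to that update step. If the network outputs logits instead of probabilities, `F.binary_cross_entropy_with_logits` can be used the same way.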

    Tags: deep-learning neural-network pytorch loss-function


    【Answer 1】:

    Adding to @hkchengrex's self-answer (for my future self and for API parity with PyTorch):

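    One can first implement the functional version like this, with a few extra arguments on top of those of the original torch.nn.functional.cross_entropy (I also prefer reduction as a callable rather than a predefined string):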

    import typing
    
    import torch
    
    
    def bootstrapped_cross_entropy(
        inputs,
        targets,
        iteration,
        p: float,
        warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
        weight=None,
        ignore_index=-100,
        reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
    ):
        if not 0 < p < 1:
            raise ValueError("p should be in (0, 1) range, got: {}".format(p))
    
        # Fraction of pixels to keep at this iteration
        if isinstance(warmup, int):
            # Plain cross entropy (keep every pixel) for the first `warmup` iterations
            this_p = 1.0 if iteration < warmup else p
        elif callable(warmup):
            this_p = warmup(p, iteration)
        else:
            raise ValueError(
                "warmup should be int or callable, got {}".format(type(warmup))
            )
    
        # Per-pixel losses, flattened over batch and spatial dimensions
        raw_loss = torch.nn.functional.cross_entropy(
            inputs, targets, weight=weight, ignore_index=ignore_index, reduction="none"
        ).view(-1)
    
        # Shortcut: keep every pixel. F.cross_entropy only accepts string
        # reductions, so the callable is applied to the per-pixel losses instead.
        if this_p == 1.0:
            return reduction(raw_loss)
    
        num_pixels = raw_loss.numel()
        # Keep the hardest `this_p` fraction of pixels (at least one, so the
        # reduction never sees an empty tensor)
        loss, _ = torch.topk(raw_loss, max(1, int(num_pixels * this_p)), sorted=False)
        return reduction(loss)
    

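    warmup can be given either as a callable (taking p and the current iteration) or as an int (plain cross entropy for the first warmup iterations, then p), which allows both flexible and simple scheduling.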
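    As an example of the callable form, the hypothetical helper `linear_warmup` below (not part of the answer) builds a `warmup` callable with the `(p, iteration) -> float` signature that `bootstrapped_cross_entropy` expects. It reproduces the linear schedule from the other answer: all pixels before `start`, a linear decay to `p` until `end`, then `p`.

```python
import typing


def linear_warmup(start: int, end: int) -> typing.Callable[[float, int], float]:
    """Hypothetical helper: build a `warmup` callable for bootstrapped_cross_entropy.

    The returned function keeps every pixel (1.0) before `start`, decays
    linearly to `p` between `start` and `end`, and returns `p` afterwards.
    """
    if end <= start:
        raise ValueError("end must be greater than start")

    def schedule(p: float, iteration: int) -> float:
        if iteration < start:
            return 1.0
        if iteration >= end:
            return p
        # Fraction of the decay window still remaining (1.0 at start, 0.0 at end)
        remaining = (end - iteration) / (end - start)
        return p + (1.0 - p) * remaining

    return schedule
```

    It would then be passed as, e.g., `bootstrapped_cross_entropy(inputs, targets, iteration, p=0.15, warmup=linear_warmup(20000, 70000))`.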

    And a class based on _WeightedLoss that automatically increments iteration on every call (so only inputs and targets have to be passed):

    class BootstrappedCrossEntropy(torch.nn.modules.loss._WeightedLoss):
        def __init__(
            self,
            p: float,
            warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
            weight=None,
            ignore_index=-100,
            reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
        ):
            # Initialise the base class first; it registers `weight` as a buffer
            # and stores `reduction` as-is (a callable here, not a string)
            super().__init__(weight, size_average=None, reduce=None, reduction=reduction)
            self.p = p
            self.warmup = warmup
            self.ignore_index = ignore_index
            # Incremented at the start of every forward(), so the first call sees iteration 0
            self._current_iteration = -1
    
        def forward(self, inputs, targets):
            self._current_iteration += 1
            return bootstrapped_cross_entropy(
                inputs,
                targets,
                self._current_iteration,
                self.p,
                self.warmup,
                self.weight,
                self.ignore_index,
                self.reduction,
            )
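    One detail worth checking when ignore_index is used: with `reduction="none"`, `F.cross_entropy` returns a loss of 0 for ignored pixels, so in the function above they still count towards `num_pixels` and, once `this_p` is large enough, get averaged in as zeros. If that is not wanted, a minimal sketch of a variant that drops them before the top-k step follows. `topk_ce_ignoring` is a hypothetical name, class weights are left out for brevity, and this behaviour is an assumption about what one usually wants rather than part of the answer.

```python
import torch
import torch.nn.functional as F


def topk_ce_ignoring(inputs: torch.Tensor, targets: torch.Tensor, p: float,
                     ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical variant: mean cross entropy over the hardest fraction `p`
    of the non-ignored pixels.

    inputs: (N, C, H, W) logits; targets: (N, H, W) class indices.
    """
    if not 0 < p <= 1:
        raise ValueError(f"p should be in (0, 1], got {p}")
    raw_loss = F.cross_entropy(
        inputs, targets, ignore_index=ignore_index, reduction="none"
    ).reshape(-1)
    # Drop ignored pixels so they neither count towards k nor dilute the mean
    raw_loss = raw_loss[targets.reshape(-1) != ignore_index]
    if raw_loss.numel() == 0:
        # Nothing supervised in this batch: a zero that is still attached to the graph
        return inputs.sum() * 0.0
    k = max(1, int(raw_loss.numel() * p))
    hardest, _ = torch.topk(raw_loss, k, sorted=False)
    return hardest.mean()
```

    With `p=1.0` and no class weights this matches `F.cross_entropy(inputs, targets, ignore_index=...)`, which makes for an easy sanity check.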
    


      【Answer 2】:

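      Usually we also add a "warm-up" period to the loss, so that the network can first learn to fit the easy regions and then transition to the harder ones.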

      This implementation starts with k=100 (all pixels) for 20000 iterations, then decays linearly to k=15 over the next 50000 iterations.

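      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      
      
      class BootstrappedCE(nn.Module):
          def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
              super().__init__()
      
              self.start_warm = start_warm
              self.end_warm = end_warm
              self.top_p = top_p
      
          def forward(self, input, target, it):
              # Warm-up: plain cross entropy over all pixels
              if it < self.start_warm:
                  return F.cross_entropy(input, target), 1.0
      
              # Per-pixel losses, flattened over batch and spatial dimensions
              raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
              num_pixels = raw_loss.numel()
      
              if it > self.end_warm:
                  this_p = self.top_p
              else:
                  # Linear decay from 1.0 at start_warm to top_p at end_warm
                  this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
              # Average only over the hardest this_p fraction of pixels;
              # this_p is also returned so the caller can log it
              loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
              return loss.mean(), this_p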
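      Written out, the fraction of pixels that BootstrappedCE keeps at iteration t, with t_s = start_warm, t_e = end_warm and p_top = top_p, is:

```latex
p(t) =
\begin{cases}
1, & t < t_s,\\[4pt]
p_{\text{top}} + (1 - p_{\text{top}})\,\dfrac{t_e - t}{t_e - t_s}, & t_s \le t \le t_e,\\[4pt]
p_{\text{top}}, & t > t_e.
\end{cases}
```

      With the defaults (t_s = 20000, t_e = 70000, p_top = 0.15), halfway through the decay at t = 45000 this gives 0.15 + 0.85 * 25000/50000 = 0.575, so about 57.5% of the pixels are kept. The middle branch equals 1 at t = t_s and p_top at t = t_e, so the schedule has no jumps.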
      

