【问题标题】:Using weights in CrossEntropyLoss and BCELoss (PyTorch)在 CrossEntropyLoss 和 BCELoss (PyTorch) 中使用权重
【发布时间】:2021-08-16 04:44:25
【问题描述】:

我正在训练一个 PyTorch 模型来执行二进制分类。我的少数类约占数据的 10%,所以我想使用加权损失函数。 BCELossCrossEntropyLoss 的文档说我可以为每个样本使用 'weight'

但是,当我声明CE_loss = nn.BCELoss()nn.CrossEntropyLoss() 然后执行CE_Loss(output, target, weight=batch_weights),其中outputtargetbatch_weightsTensors 的batch_size,我收到以下错误留言:

forward() got an unexpected keyword argument 'weight'

【问题讨论】:

    标签: pytorch loss-function


    【解决方案1】:

    您是否想对数据集中第 0 类和第 1 类的所有元素应用单独的固定权重?目前尚不清楚您在这里为 batch_weights 传递了什么值。如果是这样,那么这不是 BCELoss 中的权重参数所做的。 weight 参数要求您为数据集中的每个 ELEMENT 传递一个单独的权重,而不是为每个 CLASS 传递一个单独的权重。有几种方法可以解决这个问题。您可以为每个元素构建一个权重表。或者,您可以使用自定义损失函数来满足您的需求:

    def BCELoss_class_weighted(weights):
    
        def loss(input, target):
            input = torch.clamp(input,min=1e-7,max=1-1e-7)
            bce = - weights[1] * target * torch.log(input) - (1 - target) * weights[0] * torch.log(1 - input)
            return torch.mean(bce)
    
      return loss
    

    请注意,添加钳位以避免数值不稳定很重要。

    HTH 杰伦

    【讨论】:

      【解决方案2】:

      实现目标的另一种方法是在初始化损失时使用reduction=none,然后在计算平均值之前将结果张量乘以权重。 例如

      loss = torch.nn.BCELoss(reduction='none')
      model = torch.sigmoid
      
      weights = torch.rand(10,1)
      inputs = torch.rand(10,1)
      targets = torch.rand(10,1)
      
      intermediate_losses = loss(model(inputs), targets)
      final_loss = torch.mean(weights*intermediate_losses)
      

      当然,对于您的场景,您仍然需要计算权重张量。但希望这会有所帮助!

      【讨论】:

        【解决方案3】:

        问题在于您提供了 weight 参数。正如文档here 中提到的,应该在模块实例化期间提供 weights 参数。

        例如,类似,

        from torch import nn
        weights = torch.FloatTensor([2.0, 1.2]) 
        loss = nn.BCELoss(weights=weights)
        

        您可以找到更具体的示例 here 或其他有用的 PT 论坛讨论 here

        【讨论】:

        • BCELoss 的文档说“权重”应该是“手动重新调整权重,赋予每个批次元素的损失”。如果给定,则必须是大小为 nbatch 的张量。如果每批的权重都发生变化怎么办?
        • 从表面上看,我认为这是不可能的。最重要的是,我认为每批动态调整权重可能会对学习产生负面影响,因为损失函数属性每批都在不断变化。
        【解决方案4】:

        你需要像下面这样传递权重:

        CE_loss = CrossEntropyLoss(weight=[…])
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2021-06-06
          • 2020-03-21
          • 2020-01-27
          • 1970-01-01
          • 2021-03-16
          • 2019-09-19
          • 2021-06-10
          • 2020-10-20
          相关资源
          最近更新 更多