【问题标题】:Using Softmax Activation function after calculating loss from BCEWithLogitLoss (Binary Cross Entropy + Sigmoid activation)从 BCEWithLogitLoss 计算损失后使用 Softmax 激活函数(Binary Cross Entropy + Sigmoid 激活)
【发布时间】:2020-09-14 14:47:36
【问题描述】:

我正在学习使用 PyTorch 的二元分类教程,在这里,网络的最后一层是 torch.Linear(),只有一个神经元。 (有意义)这将给我们一个神经元。如pred=network(input_batch)

之后损失函数的选择是loss_fn=BCEWithLogitsLoss()(这比先使用softmax然后计算损失在数值上稳定),它将Softmax函数应用于最后一层的输出,给我们一个概率。所以在那之后,它会计算二进制交叉熵来最小化损失。

loss=loss_fn(pred,true)

我担心的是,毕竟作者使用了torch.round(torch.sigmoid(pred))

为什么会这样?我的意思是我知道它会得到[0,1] 范围内的预测概率,然后以默认阈值 0.5 对值进行四舍五入。

在网络中的最后一层之后使用sigmoid 一次不是更好,而不是在两个不同的地方使用softmax 和sigmoid,因为它是一个二进制分类?

这样不是更好吗

out = self.linear(batch_tensor)
return self.sigmoid(out)

然后计算BCE损失并使用argmax()检查准确性??

我只是好奇这会是一个有效的策略吗?

【问题讨论】:

  • 您正在学习的教程是什么?为什么不简单地包含一个链接?
  • 不,BCEWithLogitsLoss 不应用 softmax。它将 sigmoid 应用于 logits,然后是二进制交叉熵。 pytorch.org/docs/stable/nn.html#bcewithlogitsloss
  • 这是来自 PluralSight 的付费视频。抱歉,我在二元类和多类之间感到困惑

标签: deep-learning neural-network pytorch recurrent-neural-network


【解决方案1】:

您似乎将二元分类视为具有两个类的多类分类,但在使用二元交叉熵方法时这并不完全正确。在查看任何实现细节之前,让我们先澄清二进制分类的目标。

从技术上讲,有两个类,0 和 1,但不要将它们视为两个独立的类,您可以将它们视为彼此对立的。例如,您想对 StackOverflow 答案是否有用进行分类。这两个类将是“有帮助”“没有帮助”。自然地,你会简单地问“答案有帮助吗?”,消极的一面被忽略了,如果不是这样,你可以推断它是“没有帮助” 。 (记住,这是一个二元情况,没有中间立场)。

因此,您的模型只需要预测单个类,但为了避免与实际的两个类混淆,可以表示为:模型预测正例发生的概率。在上一个示例的上下文中:StackOverflow 的答案有帮助的概率是多少?

Sigmoid 为您提供 [0, 1] 范围内的值,即概率。现在,您需要通过定义阈值来确定模型何时足够自信以使其为正。为了平衡,阈值为 0.5,因此只要概率大于 0.5 即为正(1 类:"有帮助"),否则为负(0 类:"不有帮助的”),这是通过四舍五入实现的(即torch.round(torch.sigmoid(pred)))。

之后损失函数的选择是loss_fn=BCEWithLogitsLoss()(比先使用softmax然后计算损失在数值上稳定),它将Softmax函数应用于最后一层的输出,给我们一个概率。

在网络的最后一层之后使用一次 sigmoid 不是更好,而不是在 2 个不同的地方使用 softmax 和 sigmoid,因为它是二元分类吗??

BCEWithLogitsLoss 应用 Sigmoid 而不是 Softmax,根本不涉及 Softmax。来自nn.BCEWithLogitsLoss documentation

这种损失将 Sigmoid 层和 BCELoss 组合在一个类中。这个版本比使用简单的 Sigmoid 后跟 BCELoss 在数值上更稳定,因为通过将操作组合到一层,我们利用了 log-sum-exp 技巧为了数值稳定性。

通过不在模型中应用 Sigmoid,您可以获得数值更稳定的二元交叉熵版本,但这意味着如果您想在训练之外进行实际预测,则必须手动应用 Sigmoid。

[...] 并使用argmax() 来检查准确性??

再次,您正在考虑多类场景。您只有一个输出类,即输出的大小为 [batch_size, 1]。拿 argmax 来说,总是给你 0,因为这是唯一可用的类。

【讨论】:

  • 哦!!!知道了。谢谢你。当时我很困惑,因为我没有清楚地思考。
猜你喜欢
  • 2021-03-23
  • 1970-01-01
  • 2023-04-03
  • 2019-10-31
  • 2018-06-27
  • 2018-04-15
  • 2020-12-05
  • 2018-10-24
  • 1970-01-01
相关资源
最近更新 更多