【问题标题】:Loss function for comparing two vectors for categorization用于比较两个向量进行分类的损失函数
【发布时间】:2021-07-24 00:37:35
【问题描述】:

我正在执行一项 NLP 任务,我在其中分析文档并将其分类为六个类别之一。但是,我在三个不同的时间段执行此操作。所以最终的输出是一个由三个整数组成的数组(稀疏),其中每个整数是 0-5 的类别。所以标签看起来像这样:[1, 4, 5]

我正在使用 BERT,我正在尝试决定我应该附加什么类型的头,以及我应该使用什么类型的损失函数。使用 BERT 的大小为 1024 的输出并通过具有 18 个神经元的 Dense 层运行它,然后重新整形为大小为 (3,6) 的东西是否有意义?

最后,我假设我会使用稀疏分类交叉熵作为我的损失函数?

【问题讨论】:

  • 您为什么以及如何在 3 个不同的时间段内这样做?
  • @GaryOng 我正在使用三个不同的时间段,因为我试图预测 1 个月、6 个月和 1 年的结果。我试图使用同一个 ML 网络来完成这三个任务,而不是为每个时间段构建单独的网络。

标签: python machine-learning nlp bert-language-model


【解决方案1】:

在典型设置中,您获取 BERT 的 CLS 输出(bert-base 的向量长度为​​ 768,bert-large 的向量长度为​​ 1024)并添加一个分类头(它可能是一个简单的密集层有辍学)。在这种情况下,输入是单词标记,分类头的输出是每个类的 logits 向量,通常使用常规的交叉熵损失函数。然后你申请softmax 并获得每个班级的概率分数,或者如果你申请argmax 你将获得获胜班级。因此,结果可能是分类分数向量 [1x6] 或主要类别索引(整数)。

图像取自 d2l.ai

您可以简单地连接 3 个这样的网络(每个时间段)以获得所需的结果。

显然,我只描述了一种可能的解决方案。但由于它通常会提供良好的结果,我建议您在转向更复杂的之前尝试一下。

最后,当输出稀疏时使用稀疏分类交叉熵损失(比如[4]),当输出是单热编码时使用常规分类交叉熵损失(比如[0 0 0 0 1 0])。否则它们是完全一样的。

【讨论】:

  • 由于我试图预测 1 个月、6 个月和 1 年的结果,您的建议是创建三个不同的神经网络?还是一个具有三个输出的 NN?如果是三个输出(例如,[3, 5, 2] 表示类3 和周期1,类5 在周期2 和类2 在周期3),我使用什么损失函数将其与 [3, 5, 1] 的正确结果进行比较?
  • 例如,比较一个时间段的稀疏分类交叉熵很容易,比如[3][4] 的输出,但如果它是一个数组([3, 5, 2][3, 5, 1])我不确定到底用什么作为损失函数
  • 我建议您创建 3 个不同的分类网络,每个分类网络用于特定的时间段,然后连接它们的输出(每个输出一个整数)。在这种情况下,网络将是完全独立的,您可以更轻松地训练它们。在这种情况下,您的第一个网络产生3 作为输出,正确答案也是3,第二个网络也是如此,而第三个预测类2,而实际上它应该是1。在这种情况下,您将不得不使用稀疏分类交叉熵损失。这是最简单的设置。
【解决方案2】:

bert 最终隐藏状态为 (512,1024)。您可以选择第一个代币,即 CLS 代币,也可以采用平均池化。无论哪种方式,您的最终输出都是形状 (1024,) 现在只需像 nn.Linear(1024,6) 一样放置 3 个形状 (1024,6) 的线性层,并将其传递到下面的损失函数中。 (如果你愿意,你可以让它更复杂)

只需将损失相加并向后调用。请记住,您可以在任何标量张量上调用 loss.backward()。(pytorch)

def loss(time1output,time2output,time3output,time1label,time2label,time3label):
    loss1 = nn.CrossEntropyLoss()(time1output,time1label)
    loss2 = nn.CrossEntropyLoss()(time2output,time2label)
    loss3 = nn.CrossEntropyLoss()(time3output,time3label)

    return loss1 + loss2 + loss3

【讨论】:

    猜你喜欢
    • 2013-03-15
    • 2016-06-29
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2010-12-25
    • 1970-01-01
    • 2014-10-21
    • 2017-08-10
    相关资源
    最近更新 更多