【问题标题】:How to add LSTM layer on top of Huggingface BERT model如何在 Huggingface BERT 模型之上添加 LSTM 层
【发布时间】:2021-01-17 16:58:08
【问题描述】:

我正在处理二进制分类任务,并想尝试在 huggingface BERT 模型的最后一个隐藏层之上添加 lstm 层,但是我无法到达最后一个隐藏层。 BERT 可以和 LSTM 结合吗?

tokenizer = BertTokenizer.from_pretrained(model_path)
tain_inputs, train_labels, train_masks = data_prepare_BERT(
    train_file, lab2ind, tokenizer, content_col, label_col, 
    max_seq_length)
validation_inputs, validation_labels, validation_masks = data_prepare_BERT(
    dev_file, lab2ind, tokenizer, content_col, label_col,max_seq_length)

# Load BertForSequenceClassification, the pretrained BERT model with a single linear classification layer on top.
model = BertForSequenceClassification.from_pretrained(
    model_path, num_labels=len(lab2ind))

【问题讨论】:

  • 您需要提供更多代码和详细信息。
  • @Ruli 我刚刚编辑了它!
  • 这里有一个可用的答案:stackoverflow.com/questions/65205582/…
  • @AshwinGeetD'Sa,谢谢!我已经尝试过了,但是我收到了这个错误TypeError: __init__() got an unexpected keyword argument 'batch_first' for nn.Linear()!
  • batch_first 仅适用于 LSTM,不适用于线性。所以,请再次检查您的代码。

标签: nlp pytorch recurrent-neural-network bert-language-model huggingface-transformers


【解决方案1】:

确实可以,但是需要自己实现。 BertForSequenceClassification 类是 BertModel 的包装器。它运行模型,获取与 [CLS] 标记对应的隐藏状态,并在此之上应用分类器。

在您的情况下,您可以将类作为起点,并在 BertModel 和分类器之间添加一个 LSTM 层。 BertModel 在元组中返回隐藏状态和池化状态以进行分类。只取原始类中使用的另一个元组成员。

虽然在技术上是可行的,但与使用 BertForSequenceClassification 相比,我预计会有任何性能提升。 Transformer 层的微调可以学习任何其他 LSTM 层的能力。

【讨论】:

  • 非常感谢您的回复!将 LSTM 添加到 BertForSequenceClassification 类需要额外的计算成本(我使用的是 colab GPU)吗?我的意思是如果我将它添加到原始类中,我需要重新训练整个模型吗?
  • 另外,我想尝试添加 BiLSTM 和 CNN-LSTM。您认为这会提高性能吗?
  • 重新训练整个模型是什么意思? Bert 模型是预训练的,你可以微调也可以不微调。你在 Bert 上添加的任何东西都需要从头开始训练,无论是简单的分类器还是 LSTM。
  • 广告其他架构:您的特定问题可能非常适合 LSTM 或 CNN,并且您会获得明显更好的性能,但与微调 BERT 相比,我不期望太多。跨度>
  • 非常有用!非常感谢你,@Jindřich。我想如果我在 BERT 的最后一个隐藏层之上添加一个额外的层,那么有一种方法可以获取该层的隐藏状态并将它们输入到一个新层,而无需从头开始训练 BERT。
猜你喜欢
  • 2021-01-17
  • 2021-03-20
  • 2021-03-24
  • 1970-01-01
  • 2020-10-28
  • 2021-10-07
  • 2021-02-16
  • 2019-01-14
  • 2017-09-23
相关资源
最近更新 更多