【问题标题】:max_length doesn't fix the question-answering modelmax_length 不修复问答模型
【发布时间】:2021-03-29 21:13:48
【问题描述】:

我的问题: 给定一个大 (>512b) .txt 文件,如何让我的“问答”模型运行?

上下文: 我正在使用来自谷歌的词嵌入模型 BERT 创建一个问答模型。当我导入包含几个句子的 .txt 文件时,模型运行良好,但是当 .txt 文件超过 512b 单词作为模型学习上下文的限制时,模型不会回答我的问题。

我尝试解决问题: 我在编码部分设置了一个 max_length,但这似乎并没有解决问题(我的尝试代码如下)。

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

max_seq_length = 512


tokenizer = AutoTokenizer.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("henryk/bert-base-multilingual-cased-finetuned-dutch-squad2")

f = open("test.txt", "r")

text = str(f.read())

questions = [
    "Wat is de hoofdstad van Nederland?",
    "Van welk automerk is een Cayenne?",
    "In welk jaar is pindakaas geproduceerd?",
]

for question in questions:
    inputs = tokenizer.encode_plus(question, 
                                   text, 
                                   add_special_tokens=True, 
                                   max_length=max_seq_length,
                                   truncation=True,
                                   return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]

    text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    answer_start_scores, answer_end_scores = model(**inputs, return_dict=False)

    answer_start = torch.argmax(
        answer_start_scores
    )  # Get the most likely beginning of answer with the argmax of the score
    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score

    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

    print(f"Question: {question}")
    print(f"Answer: {answer}\n")

代码结果:

> Question: Wat is de hoofdstad van Nederland?
> Answer: [CLS]
>
> Question: Van welk automerk is een Cayenne?
> Answer: [CLS]
>
> Question: In welk jaar is pindakaas geproduceerd?
> Answer: [CLS]

正如我们所看到的,该模型仅返回在标记器编码部分发生的 [CLS]-token。

编辑:我想出了解决这个问题的方法,就是迭代 .txt 文件,这样模型就可以通过迭代找到答案。

【问题讨论】:

标签: python machine-learning text bert-language-model maxlength


【解决方案1】:

编辑:我想出解决这个问题的方法是遍历 .txt 文件,这样模型就可以通过迭代找到答案。模型使用 [CLS] 回答的原因是因为它无法在 512b 上下文中找到答案,它必须更深入地查看上下文。

通过创建这样的循环:

with open("sample.txt", "r") as a_file:
  for line in a_file:
    text = line.strip()
    print(text)

可以将迭代后的文本应用到encode_plus中。

【讨论】:

    猜你喜欢
    • 2018-08-23
    • 1970-01-01
    • 2021-04-23
    • 1970-01-01
    • 2011-06-26
    • 1970-01-01
    • 2015-07-25
    • 1970-01-01
    • 2014-12-15
    相关资源
    最近更新 更多