【问题标题】:How to early-stop autoregressive model with a list of stop words?如何使用停用词列表尽早停止自回归模型?
【发布时间】:2022-04-26 01:29:51
【问题描述】:

我正在使用来自transformers 的 GPT-Neo 模型来生成文本。因为我使用的提示以'{'开头,所以我想在生成配对'}'后停止这句话。 我发现源代码中有一个StoppingCriteria 方法,但没有进一步说明如何使用它。有没有人找到一种方法来尽早停止模型生成?谢谢!

这是我尝试过的:

from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torch_dtype=dtype).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids in self.keywords:
            return True
        return False

stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w) for w in stop_words]
stop_ids.append(tokenizer.eos_token_id)
stop_criteria = KeywordsStoppingCriteria(stop_ids)

model.generate(
    text_inputs='some text:{', 
    StoppingCriteria=stop_criteria
)

【问题讨论】:

  • 您可以发布您当前代码的minimal reproducible example 吗?
  • 如果我有这个问题的示例答案,我不必首先发布这个问题:p。但我会发布我尝试过的内容的 sn-p。

标签: python huggingface-transformers autoregressive-models gpt-2


【解决方案1】:

我已经能够调整您的代码以使其正常工作。此外,请确保您使用的是最新版本的转换器,您可能需要升级。

import torch
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids:list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if input_ids[0][-1] in self.keywords:
            return True
        return False


stop_words = ['}', ' }', '\n']
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)


inputs = tokenizer.encode('some text: {', add_special_tokens=False, return_tensors='pt')

output = model.generate(
    inputs,
    do_sample=True,
    stopping_criteria=StoppingCriteriaList([stop_criteria]),

)
print(tokenizer.decode(*output))

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2011-07-26
    • 2011-02-14
    • 1970-01-01
    • 2020-09-16
    • 1970-01-01
    • 2015-05-30
    • 2015-04-26
    • 2019-09-11
    相关资源
    最近更新 更多