【Question title】: With the HuggingFace transformer, how can I return multiple samples when generating text?
【Posted】: 2020-10-09 20:04:02
【Question】:

I'm working from https://github.com/cortexlabs/cortex/blob/master/examples/pytorch/text-generator/predictor.py

But if I pass num_samples=5, I get:

    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 5 and 1 in dimension 0

The code is:

def sample_sequence(
    model,
    length,
    context,
    num_samples=1,
    temperature=1,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
    device="cpu",
):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    print('context_size', context.shape)
    generated = context
    print('context', context)
    with torch.no_grad():
        for _ in trange(length):
            inputs = {"input_ids": generated}
            outputs = model(
                **inputs
            )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.0)

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for _ in set(generated.view(-1).tolist()):
                next_token_logits[_] /= repetition_penalty

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0:  # greedy sampling:
                next_token = torch.argmax(filtered_logits).unsqueeze(0)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated

【Comments】:

    Tags: python pytorch huggingface-transformers


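    【Answer 1】:

    As far as I can tell, this code does not support multiple samples, but you can adapt it with a few changes.

    This line already uses multinomial, but it only draws one token: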

    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    

    Change it to:

    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=num_samples)
    

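    Now you also need to change how the result is built. The following line concatenates next_token with the sentence. You now get num_samples next tokens, and you need to unsqueeze them along dimension 1 so that each one lands in its own row: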

    generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    

    Change it to:

    generated = torch.cat((generated, next_token.unsqueeze(1)), dim=1)
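
The shape reasoning behind that change can be checked in isolation. The snippet below is a minimal sketch that uses made-up tensors in place of real model output (the values in `generated` and `next_token` are placeholders, not tokens from GPT-2):

```python
import torch

num_samples = 5
# Stand-in for the running sequences: shape (num_samples, seq_len).
generated = torch.zeros(num_samples, 3, dtype=torch.long)
# Stand-in for the sampled tokens: shape (num_samples,).
next_token = torch.arange(num_samples)

row = next_token.unsqueeze(0)     # shape (1, num_samples)
column = next_token.unsqueeze(1)  # shape (num_samples, 1)

# Concatenating along dim=1 requires matching sizes in dim 0,
# so the (1, num_samples) tensor is rejected.
try:
    torch.cat((generated, row), dim=1)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# The (num_samples, 1) tensor adds exactly one new token per row.
extended = torch.cat((generated, column), dim=1)
print(cat_failed, tuple(extended.shape))
```

With `unsqueeze(1)`, every row of `generated` grows by one column, which is what the loop needs on each step.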
    

    The whole function should now look like this:

    import torch
    import torch.nn.functional as F
    from tqdm import trange

    # top_k_top_p_filtering is assumed to be available, as in the original script.


    def sample_sequence(
        model,
        length,
        context,
        num_samples=1,
        temperature=1,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        device="cpu",
    ):
        context = torch.tensor(context, dtype=torch.long, device=device)
        # Repeat the prompt so that `generated` has shape (num_samples, seq_len).
        context = context.unsqueeze(0).repeat(num_samples, 1)
        generated = context
        with torch.no_grad():
            for _ in trange(length):
                inputs = {"input_ids": generated}
                outputs = model(
                    **inputs
                )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
                # Logits for the last position of the first row only.
                next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.0)

                # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
                for token_id in set(generated.view(-1).tolist()):
                    next_token_logits[token_id] /= repetition_penalty

                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                if temperature == 0:  # greedy sampling:
                    # Every row receives the same argmax token, so the shape is (num_samples,).
                    next_token = torch.argmax(filtered_logits).unsqueeze(0).repeat(num_samples)
                else:
                    # Draw num_samples distinct tokens, one for each row.
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=num_samples)
                generated = torch.cat((generated, next_token.unsqueeze(1)), dim=1)
        return generated
    

    Last but not least, you have to change the tokenizer.decode call to tokenizer.batch_decode, because the return value now contains multiple samples:

    tokenizer.batch_decode(output.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True)
    

    Something you have to decide for yourself is what to do when there is no valid next_token. Currently you will get an error message like this:

    RuntimeError: invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)
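
Going by the message, this happens when sampling without replacement asks for more tokens than have non-zero probability, for example when top-p filtering leaves fewer than num_samples candidates. One possible way to handle it, shown here only as a sketch, is to count the usable tokens first and fall back to sampling with replacement when there are too few. `draw_tokens` is a hypothetical helper, not part of the original script, and the exact behaviour of `torch.multinomial` in this edge case may differ between PyTorch versions:

```python
import torch


def draw_tokens(probs, num_samples):
    """Draw token ids from a 1-D probability vector.

    Hypothetical helper: samples without replacement when enough tokens
    have non-zero probability, otherwise samples with replacement
    (so rows may share a token).
    """
    available = int((probs > 0).sum())
    if available == 0:
        raise ValueError("no token has non-zero probability")
    return torch.multinomial(probs, num_samples, replacement=available < num_samples)


# Example distribution in which only two tokens survived filtering.
probs = torch.tensor([0.5, 0.5, 0.0, 0.0])
tokens = draw_tokens(probs, num_samples=3)
print(tokens.tolist())
```

Falling back to replacement means several rows can receive the same token on that step. Whether that is acceptable depends on how different the samples are supposed to be.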

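    Another thing you need to consider is whether their code is correct at all. In the few tests I ran, the quality of the generated sentences seemed to drop as num_samples increased (maybe the quality is better if you call sample_sequence several times in a simple loop?). I haven't worked with GPT2 yet and can't help you here.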
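
One thing can be read directly off the function: `outputs[0][0, -1, :]` takes the logits of the first row only, and `torch.multinomial` without replacement then hands each row a different token drawn from that single distribution. Whether this explains the quality drop is an assumption, not something verified here. A minimal sketch of the alternative is to take the last-position logits of every row and draw one token per row. `ToyLM` and `sample_per_row` are hypothetical names. `ToyLM` is a stand-in so that the snippet runs without downloading GPT-2. Top-k/top-p filtering and the repetition penalty are left out, and temperature is assumed to be greater than 0:

```python
import torch
import torch.nn.functional as F


class ToyLM(torch.nn.Module):
    """Hypothetical stand-in for a causal LM.

    Returns a tuple whose first item holds logits of shape
    (batch, seq_len, vocab_size), like the model call in the question.
    """

    def __init__(self, vocab_size=50, hidden=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return (self.head(self.embed(input_ids)),)


def sample_per_row(model, length, context, num_samples=1, temperature=1.0):
    # Repeat the prompt: shape (num_samples, seq_len).
    generated = torch.tensor(context, dtype=torch.long).unsqueeze(0).repeat(num_samples, 1)
    with torch.no_grad():
        for _ in range(length):
            # Last-position logits for every row: shape (num_samples, vocab_size).
            logits = model(input_ids=generated)[0][:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            # One draw per row; the result already has shape (num_samples, 1).
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated


torch.manual_seed(0)
samples = sample_per_row(ToyLM(), length=4, context=[1, 2, 3], num_samples=5)
print(tuple(samples.shape))
```

Because each row is sampled from its own distribution, no `unsqueeze` is needed before `torch.cat`, and each draw requests a single token per row.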

    【Comments】:
