【问题标题】:huggingface transformers: truncation strategy in encode_plus拥抱脸转换器:encode_plus 中的截断策略
【发布时间】:2020-11-26 12:43:46
【问题描述】:

huggingface 的转换器库中的encode_plus 允许截断输入序列。有两个参数是相关的:truncationmax_length。我将成对的输入序列传递给encode_plus,并且需要以“截断”方式简单地截断输入序列,即,如果由两个输入texttext_pair 组成的整个序列比@987654328 长@ 它应该从右边相应地被截断。

似乎这两种截断策略都不允许这样做,而是longest_first 从最长的序列中删除标记(可以是 text 或 text_pair,但不仅仅是从序列的右侧或末尾,例如,如果 text 比 text_pair 长,这似乎会首先从文本中删除标记),only_firstonly_second 仅从第一个或第二个删除标记(因此,也不仅仅是从末尾),do_not_truncate 不会完全截断。还是我误解了这一点,实际上longest_first 可能就是我要找的东西?

【问题讨论】:

    标签: pytorch huggingface-transformers


    【解决方案1】:

    不,longest_firstcut from the right 不同。当您将截断策略设置为longest_first 时,标记器将在每次需要删除标记时比较texttext_pair 的长度,并从最长的标记中删除一个标记。例如,这可能意味着它将首先从 text_pair 中删除 3 个令牌,然后从 texttext_pair 中交替删除其余的令牌。一个例子:

    from transformers import BertTokenizerFast
    
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    
    seq1 = 'This is a long uninteresting text'
    seq2 = 'What could be a second sequence to the uninteresting text'
    
    print(len(tokenizer.tokenize(seq1)))
    print(len(tokenizer.tokenize(seq2)))
    
    print(tokenizer(seq1, seq2))
    
    print(tokenizer(seq1, seq2, truncation= True, max_length = 15))
    print(tokenizer.decode(tokenizer(seq1, seq2, truncation= True, max_length = 15)['input_ids']))
    

    输出:

    9
    13
    {'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
    {'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
    [CLS] this is a long unint [SEP] what could be a second sequence [SEP]
    

    据我所知,您实际上是在寻找only_second,因为它从右侧切开(即text_pair):

    print(tokenizer(seq1, seq2, truncation= 'only_second', max_length = 15))
    

    输出:

    {'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
    

    当您尝试text 输入比指定的 max_length 更长时,它会引发异常。我认为这是正确的,因为在这种情况下,它不再是序列对输入。

    万一only_second 不符合您的要求,您可以简单地创建自己的截断策略。举个例子only_second手工:

    
    tok_seq1 = tokenizer.tokenize(seq1)
    tok_seq2 = tokenizer.tokenize(seq2)
    
    maxLengthSeq2 =  myMax_len - len(tok_seq1) - 3 #number of special tokens for bert sequence pair
    if len(tok_seq2) >  maxLengthSeq2:
        tok_seq2 = tok_seq2[:maxLengthSeq2]
    
    input_ids = [tokenizer.cls_token_id] 
    input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
    input_ids += [tokenizer.sep_token_id]
    
    token_type_ids = [0]*len(input_ids)
    
    input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
    input_ids += [tokenizer.sep_token_id]
    token_type_ids += [1]*(len(tok_seq2)+1) 
    
    
    attention_mask = [1]*len(input_ids)
    print(input_ids)
    print(token_type_ids)
    print(attention_mask)
    

    输出:

    [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    

    【讨论】:

      猜你喜欢
      • 2021-03-30
      • 2021-04-01
      • 2020-11-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-08-10
      • 2022-06-28
      • 2021-05-07
      相关资源
      最近更新 更多