【问题标题】:What is the correct way of encoding a large batch of documents with sentence transformers/pytorch?使用句子转换器/pytorch 对大量文档进行编码的正确方法是什么?
【发布时间】:2021-09-21 00:53:12
【问题描述】:

我在使用sentence_transformers 库对大量文档(超过一百万)进行编码时遇到问题。

给定一个非常相似的corpus 字符串列表。当我这样做时:

from sentence_transformers import SentenceTransformer
embedder = SentenceTransformer('msmarco-distilbert-base-v2')
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=False)
    

几个小时后,进程似乎被卡住了,因为它永远不会完成,并且在检查进程查看器时没有任何运行。

由于我怀疑这是一个 ram 问题(GPU 板没有足够的内存来一次性容纳所有内容),我尝试将语料库分成批次,将它们转换为 NumPy 数组,然后将它们连接到单个矩阵如下:

from itertools import zip_longest
from sentence_transformers import SentenceTransformer, util
import torch
from loguru import logger
import glob
from natsort import natsorted



def grouper(iterable, n, fillvalue=np.nan):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

embedder = SentenceTransformer('msmarco-distilbert-base-v2')

for j, e in enumerate(list(grouper(corpus, 3))):
    try:
#         print('------------------')
        for i  in filter(lambda v: v==v, e):
            corpus_embeddings=embedder.encode(i, convert_to_tensor=False)
            torch.save(corpus_embeddings, f'/Users/user/Downloads/embeddings_part_{j}.npy')
    except TypeError:
        print(j, e)
        logger.debug("TypeError in batch {batch_num}", batch_num=j)

l = []
for e in natsorted(glob.glob("/Users/user/Downloads/*.npy")):
    l.append(torch.load(e))
    corpus_embeddings = np.vstack(l)
corpus_embeddings

尽管如此,上述过程似乎不起作用。原因是当我尝试使用和不使用批处理方法的语料库的小样本时,我得到的矩阵是不同的,例如:

没有批处理方法:

array([[-0.6828216 , -0.26541945,  0.31026787, ...,  0.19941986,
         0.02366139,  0.4489861 ],
       [-0.45781   , -0.02955275,  1.0897563 , ..., -0.20077021,
        -0.37821707,  0.2248317 ],
       [ 0.8532193 , -0.13642257, -0.8872398 , ..., -0.57482916,
         0.12760726, -0.66986346],
       ...,
       [-0.04036704,  0.06745373, -0.6010259 , ..., -0.08174597,
        -0.18513843, -0.64744204],
       [-0.30782765, -0.04935509, -0.11624689, ...,  0.10423593,
        -0.14073376, -0.09206307],
       [-0.77139395, -0.08119706,  0.43753916, ...,  0.1653319 ,
         0.06861683, -0.16276269]], dtype=float32)

采用批处理方式:

array([[ 0.8532191 , -0.13642241, -0.8872397 , ..., -0.5748289 ,
         0.12760736, -0.6698637 ],
       [ 0.3679317 , -0.21968201,  0.9932826 , ..., -0.86282325,
        -0.04683857,  0.18995859],
       [ 0.23026675,  0.69587034, -0.8116473 , ...,  0.23903558,
         0.413471  , -0.23438476],
       ...,
       [ 0.923319  ,  0.4152724 , -0.3153545 , ..., -0.6863369 ,
         0.01149149, -0.51300013],
       [-0.30782777, -0.04935484, -0.11624689, ...,  0.10423636,
        -0.1407339 , -0.09206269],
       [-0.77139413, -0.08119693,  0.43753892, ...,  0.16533189,
         0.06861652, -0.16276267]], dtype=float32)

执行上述批处理程序的正确方法是什么?

更新

检查上述批处理过程后,我发现当我将上述代码(enumerate(list(grouper(corpus, 1)))) 的批处理大小设置为1 时,无论是否使用批处理,我都能获得相同的矩阵输出。因此,我的问题是,将编码器应用于大量文档的正确方法是什么?

【问题讨论】:

    标签: python numpy machine-learning nlp pytorch


    【解决方案1】:

    这一行here 在进行编码之前按文本长度对输入进行排序。我不知道为什么。

    所以,要么将这些行注释掉,要么将它们复制到您的代码中,例如

    length_sorted_idx = np.argsort([-embedder._text_length(sen) for sen in corpus])
    corpus_sorted = [corpus[idx] for idx in length_sorted_idx]
    

    然后使用corpus_sorted 对输出进行编码并将输出映射回使用length_sorted_idx

    或者只是一个一个地编码,你不需要关心哪个输出来自哪个文本。

    【讨论】:

    • 不知道。我只尝试了链接中的前几个示例,但它在我的机器上运行良好。
    • 一对一编码是否也有同样的问题?
    • 这更奇怪,因为据我所知,随着批量大小的变化输出是Batch Normalizationencode具有self.evalhere的行为
    • 有一个参数叫batch_sizehere也许你可以试试把它改成1或者其他数字看看会发生什么?
    • 不知道。我只是看到代码,但还没有尝试过。已经关闭电脑
    【解决方案2】:

    我认为问题在于转换器应该根据附近的单词进行编码,基本上就是所谓的Attention,但是当你批量制作时,单词会发生变化。例如 - 如果没有批处理,您在语料库中有 100 个单词,而 Attention 用于所有 100 个单词,但是通过批处理,您现在每批中只有 25、25、25、25 个单词,然后 Attention 仅用于它们,因此值从剩下的 75 个单词中不存在。所以基本上不应该有一种方法可以使用批处理来实现它并获得相同的结果,因为在计算 Multi-head Attention 时,所有单词都应该存在。

    这与静态编码不同,其中单词的编码与其相邻的单词和上下文无关。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2017-11-23
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多