【问题标题】:Calculating similarity between Tfidf matrix and predicted vector causes memory overflow计算 Tfidf 矩阵和预测向量之间的相似度导致内存溢出
【发布时间】:2017-09-26 20:25:36
【问题描述】:

我已经使用以下代码在大约 20,000,000 个文档上生成了一个 tf-idf 模型,效果很好。问题是当我尝试使用 linear_kernel 计算相似度分数时,内存使用量会爆炸:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel

train_file = "docs.txt"
train_docs = DocReader(train_file) #DocReader is a generator for individual documents

vectorizer = TfidfVectorizer(stop_words='english',max_df=0.2,min_df=5)
X = vectorizer.fit_transform(train_docs)

#predicting a new vector, this works well when I check the predictions
indoc = "This is an example of a new doc to be predicted"
invec = vectorizer.transform([indoc])

#This is where the memory blows up
similarities = linear_kernel(invec, X).flatten()

似乎这不应该占用太多内存,将 1-row-CSR 与 20mil-row-CSR 进行比较应该输出 1x20mil ndarray。

Justy 仅供参考:X 是内存中约 12 GB 的 CSR 矩阵(我的计算机只有 16 个)。我曾尝试研究 gensim 来替换它,但我找不到一个很好的例子。

对我缺少什么有什么想法吗?

【问题讨论】:

    标签: python scikit-learn gensim tf-idf csr


    【解决方案1】:

    您可以批量处理。这是一个基于您的代码 sn-p 的示例,但将数据集替换为 sklearn 中的内容。对于这个较小的数据集,我也以原始方式计算它以表明结果是等效的。您或许可以使用更大的批量大小。

    import numpy as np
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import linear_kernel
    from sklearn.datasets import fetch_20newsgroups
    
    train_docs = fetch_20newsgroups(subset='train')
    
    vectorizer = TfidfVectorizer(stop_words='english', max_df=0.2,min_df=5)
    X = vectorizer.fit_transform(train_docs.data)
    
    #predicting a new vector, this works well when I check the predictions
    indoc = "This is an example of a new doc to be predicted"
    invec = vectorizer.transform([indoc])
    
    #This is where the memory blows up
    batchsize = 1024
    similarities = []
    for i in range(0, X.shape[0], batchsize):
        similarities.extend(linear_kernel(invec, X[i:min(i+batchsize, X.shape[0])]).flatten())
    similarities = np.array(similarities)
    similarities_orig = linear_kernel(invec, X)
    print((similarities == similarities_orig).all())
    

    输出:

    True
    

    【讨论】:

    • 谢谢布拉德,这对我的目的非常有效!仍然不确定我为什么会出现内存溢出,可能与我对稀疏矩阵乘法的理解不佳有关:)
    猜你喜欢
    • 2020-06-11
    • 1970-01-01
    • 2018-04-29
    • 2023-03-12
    • 2016-09-29
    • 2015-07-21
    • 2016-02-15
    • 1970-01-01
    • 2023-04-05
    相关资源
    最近更新 更多