【问题标题】:Pytorch cosine similarity NxN elementsPytorch 余弦相似度 NxN 个元素
【发布时间】:2021-07-15 20:28:45
【问题描述】:

我有 128 个嵌入向量

image.shape = torch.Size([128, 512])
text.shape = torch.Size([128, 512])

我想计算包含所有元素之间余弦相似度的张量(即:

cosine.shape = torch.Size([128, 128])

其中第一行是第一张图片与所有文本的余弦相似度(128)等

目前我只是这样做,但结果是一个仅包含 N 个余弦相似度的一维数组。

cosine_similarity = torch.nn.CosineSimilarity()

cosine = cosine_similarity(image, text)

我该怎么做?我尝试转置文本,但没有成功

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    这应该可以解决问题:

    repeated_image = image.repeat_interleave(128, dim=0)
    repeated_text = text.repeat(128, 1)
    
    distances = torch.nn.functional.cosine_similarity(repeated_image, repeated_text, dim=1)
    distances = distances.reshape(128, 128)
    

    说明: 为了更快地进行计算,最好复制数据,然后使用 GPU 并行执行这些步骤,而不是在所有可能的对上进行 for 循环。前两行以适当的方式复制数据:

    repeat_interleave 将沿 dim=0 复制张量(这是一个 512 维图像)的每个元素,这将创建一个 (128*128, 512) 张量,其前 128 个元素属于第一个图像,并且以此类推。

    repeat 将整个张量沿第 0 维重复 128 次,再次创建一个 (128*128, 512) 的张量,其中第一个元素属于第一个文本,第二个元素属于第二个文本,依此类推(第 129 个元素属于第一个文本。)

    然后我们使用cosine_similarity 函数计算所有成对距离,然后重新整形以将它们变成适当的形状。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2020-08-12
      • 2011-01-01
      • 2017-12-12
      • 2013-05-24
      • 1970-01-01
      • 2014-02-25
      相关资源
      最近更新 更多