【问题标题】:calculate cosine similarity in Pytorch在 Pytorch 中计算余弦相似度
【发布时间】:2020-12-14 10:14:44
【问题描述】:

我有一个候选文档嵌入张量,即cdd_doc_embeddings,大小为[batch_size, cdd_size, signal_length, embedding_dim],一个历史点击文档嵌入张量,即his_doc_embeddings,大小为[batch_size, his_size, signal_length, embedding_dim]

现在我想计算它们之间的cosine similarity,产生一个大小为[batch_size, cdd_size, his_size, signal_length, signal_length] 的张量fusion_matrix,其中entry [ b,i,j,u,v ] 表示ui 批候选文档中的第 b 个单词和 j 中的第 v 个单词第 b 个批次中的第 em> 个历史点击文档

如何使用 PyTorch 有效地做到这一点?

【问题讨论】:

  • 有个计算余弦相似度的pytorch函数here

标签: deep-learning nlp pytorch


【解决方案1】:

好的,我想通了。

import torch.nn.functional as F
# [bs, cs, 1, sl, ed]
cdd_news_embedding = F.normalize(self.embedding[cdd_news_batch].unsqueeze(dim=2), dim=-1)
# [bs, 1, hs, ed, sl]
his_news_embedding = F.normalize(self.embedding[his_news_batch].unsqueeze(dim=1), dim=-1).transpose(-1,-2)

# transform cosine similarity calculation into normalized matrix production
fusion_matrices = torch.matmul(cdd_news_embedding, his_news_embedding).unsqueeze(dim=-1)

【讨论】:

    猜你喜欢
    • 2015-05-24
    • 1970-01-01
    • 2021-05-19
    • 2011-05-21
    • 1970-01-01
    • 2017-07-07
    • 2018-04-11
    • 2017-02-03
    相关资源
    最近更新 更多