【发布时间】:2020-06-08 05:47:38
【问题描述】:
我正在处理嵌套 for 循环的代码。 a_list 和 b_list 是元组列表,其中每个元组由两个张量 [(tens1, tens2), ...] 组成。我正在尝试计算a_list 中的每个tens1 与b_list 中的每个tens1 的相似性。下面是我的代码。嵌套循环似乎是一个瓶颈。有没有更好的方法(pythonic)我可以重写循环?
a2b= defaultdict(dict)
b2a= defaultdict(dict)
ab_sim = []
for a, vec_a in a_list:
for b, vec_b in b_list:
# Ignore combination if the first element in both a and b are same
if a[0] == b[0]:
continue
# Calculate cosine similarity of combination
sim = self.calculate_similarity(vec_a, vec_b )
a2b[a][b] = sim
b2a[b][a] = sim
ab_sim.append(sim)
calculate_similarity 只是一种计算余弦相似度的方法。 a_list 和 b_list 可以是任意大小。我有 b2a 和 a2b 因为我需要它们来进行其他计算。
【问题讨论】:
-
如果你真的需要所有余弦相似度的完整 n×m 矩阵,我认为没有办法计算所有这些。如果你可以编辑你的问题来解释为什么你需要所有这些余弦相似性,也许我们可以避免XY problem。
-
你意识到你实际上是在比较第二个张量吗?第一个存储到
a,第二个存储到vec_a,b和vec_b相同。你正在比较vec_a和vec_b。如果您希望a和b成为索引,则需要使用enumerate。 -
作为一个小的优化,你不一定需要
a2b和b2a,如果你用排序的元组键将它们存储在一个dict中。 -
@Adirio 我们不知道
a_list/b_list的形状,它们可能已经是索引了。 -
所以我试图计算两种语言之间句子表示的余弦相似度。我有
a_list和b_list是一个元组列表,其中每个元组都包含(句子,embedding_representation)。
标签: python python-3.x oop for-loop pytorch