【问题标题】:How can I determine neighborhood from pairwise distance matrix efficiently?如何有效地从成对距离矩阵中确定邻域?
【发布时间】:2020-05-16 19:08:10
【问题描述】:

我在 A 组的 M 个点和 B 组的 N 个点之间有一个 M * N 成对距离矩阵。

我想为 A 组中的每个点获取 B 组中的相邻点列表。

有没有使用 pytorch 解决这个问题的有效代码?而不是多个“for”循环。

谢谢

【问题讨论】:

  • 请附上minimal reproducible example 来澄清您的问题,最好是一个小例子。此外,请分享您将用于暴力破解上述解决方案的任何实现(无论是伪代码还是实际代码)。

标签: pytorch


【解决方案1】:

你可以使用sort:

import torch

# fake pairwise distance matrix, M=3, N=4
x = torch.rand((3,4))
print(x)
# tensor([[0.7667, 0.6847, 0.3779, 0.3007],
#         [0.9881, 0.9909, 0.3180, 0.5389],
#         [0.6341, 0.8095, 0.4214, 0.7216]])

closest = torch.sort(x, dim=-1)  # default is -1, but I prefer to be clear

# let's say you want the k=2 closest points
k=2
closest_k_values = closest[0][:, :k]
closest_k_indices = closest[1][:, :k]

print(closest_k_values)
# tensor([[0.3007, 0.3779],
#         [0.3180, 0.5389],
#         [0.4214, 0.6341]])

print(closest_k_indices)
# tensor([[3, 2],
#         [2, 3],
#         [2, 0]])

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-09-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-12-08
    相关资源
    最近更新 更多