【问题标题】:Speed up simple distance calculation加快简单距离计算
【发布时间】:2013-09-27 16:19:19
【问题描述】:

我正在实现一个简单的代码,它计算list_A 中的点(x_a, y_a)list_B 中的所有点(x_b, y_b) 之间的距离,并返回找到的最小距离。对list_A 中的所有点重复此操作。

我的代码的MWE

# list_A points defined in array.
list_A = np.array([
    [x_data_a,  # x
     y_data_a]  # y
    ], dtype=float)

# list_B points defined in list.
list_B = [[x_data_b], [y_data_b]]

# Iterate through all data points in list_A
for ind, x_a in enumerate(list_A[0][0]):
    y_a = list_A[0][1][ind]

    # Iterate through all points in list_B.
    dist_min = 1000.
    for ind2, x_b in enumerate(list_B[0]):
        y_b = list_B[1][ind2]
        # Find distance between points.
        dist = (x_a-x_b)**2 + (y_a-y_b)**2
        if dist < dist_min:
            # Update value of min distance.
            dist_min = dist

    print 'Min dist to (', x_a, y_a, '): ', dist_min

数据格式如下:

list_A = [[[1.2 2.3 1.5 2.3 5.8 4.6 9.1] [2.5 1.0 4.6 2.4 7.4 1.1 3.2]]]

list_B = [[1.4, 5.8, 7.9], [6.1, 1.2, 3.7]]

对于大型列表/数组,这可能需要相当长的时间才能完成。可以加快速度吗?

【问题讨论】:

  • 根据您对某些答案的 cmets,我意识到我不了解您的数据格式。你是说x_data_a 本身就是一个点序列吗?您能否提供一个包含文字数值的数据结构的简单示例?
  • 请查看已编辑的问题。我认为使用zip 可能会成功,因为我收到了ValueError: XA and XB must have the same number of columns (i.e. feature dimension.) 错误。
  • 你的例子仍然没有意义。我在那里看不到任何要点,只是个别数字的列表。你不能在你的个人点内有...,因为那样你就不会知道这些点的尺寸,也无法找到它们之间的距离。请提供一个没有... 的小文字示例。
  • 抱歉,... 要缩短列表的位置。我已经更新了 que question,显示了一组真实数据的样子。无论如何,我很确定使用zip(*) 是解决我上面提到的错误的方法。
  • 是的,您的数组格式错误,您可以使用zip(*list_A) 将它们转换为正确的格式。

标签: python performance algorithm numpy distance


【解决方案1】:

运行您的代码,我得到以下信息:

Min dist to ( 1.2 2.5 ):  13.0
Min dist to ( 2.3 1.0 ):  12.29
Min dist to ( 1.5 4.6 ):  2.26
Min dist to ( 2.3 2.4 ):  13.69
Min dist to ( 5.8 7.4 ):  18.1
Min dist to ( 4.6 1.1 ):  1.45
Min dist to ( 9.1 3.2 ):  1.69

将您的数组转换为以下 Nx2 数组:

a
[[ 1.2  2.5]
 [ 2.3  1. ]
 [ 1.5  4.6]
 [ 2.3  2.4]
 [ 5.8  7.4]
 [ 4.6  1.1]
 [ 9.1  3.2]]

b
[[ 1.4  6.1]
 [ 5.8  1.2]
 [ 7.9  3.7]]

现在以下应该可以工作了:

import scipy.spatial.distance as spdist

dist_arr = spdist.cdist(a,b)

print dist_arr**2
[[ 13.    22.85  46.33]
 [ 26.82  12.29  38.65]
 [  2.26  30.05  41.77]
 [ 14.5   13.69  33.05]
 [ 21.05  38.44  18.1 ]
 [ 35.24   1.45  17.65]
 [ 67.7   14.89   1.69]]

ind = np.argmin(dist_arr,axis=1)

print ind
[0 1 0 1 2 1 2]

print dist_arr[np.arange(ind.shape[0]),ind]**2
[ 13.    12.29   2.26  13.69  18.1    1.45   1.69]

如果 ab 是 2X5000 而原始代码约为 135 秒,则需要 ~.3 秒。加速 450 倍。

【讨论】:

  • 请参阅我在 BrenBarn 的回答中提出的关于输入列表尺寸的问题。另外,你为什么选择这种特殊的元素配置?我的设置是每个父列表(A 和 B)中的两个子列表,包含 x 和 y 值,并且 x,y 对的总数在 A 和 B 中不一定相同。
  • @Gabriel:正如我在对我的回答的评论中解释的那样,他的示例已经显示了在输入列表中使用不同长度的情况。
  • @Gabriel 我已经使用cdist 复制了您的结果,与原始代码相比,速度提高了约 400 倍。
  • @Ophion 我得到了类似的结果,虽然没有你发现的那么快(我发现了约 40 倍的加速,但我将答案与更多代码混合在一起)感谢所有答案的家伙!我将此标记为已接受,因为它比 BrenBarn 的答案更详细,即使它们都基于cdist。干杯。
【解决方案2】:

使用scipy.spatial.distance.cdist,您根本不需要编写自己的距离计算代码。

编辑:您需要转置您的数据。它应该是这样的格式:

list_A = [
 [1, 2],
 [3, 4],
 [4, 5]
]

list_B = [
 [8, 9],
 [10, 11],
 [11, 12],
 [13, 14]
]

目前您拥有的是一个 X 坐标列表和一个单独的 Y 坐标列表。您需要重新定向这些,以便您拥有一个 XY 对列表。如果您的数据是普通列表,您可以使用list_A = zip(*list_A) 转置它们;如果它们是 numpy 数组,您可以使用 list_A = list_A.T 转置它们。

【讨论】:

  • 不会使用这个要求x_data_ax_data_b 具有相同的长度(y 值相同)?因为这不是我可以对我的数据施加的限制。
  • @Gabriel:不,如果我没听错的话。如果您有一个 M 点列表和另一个 N 点列表,则可以使用 cdist 查找从 M 中的每个点到 N 中的每个点的所有距离。这两个列表不必具有相等的长度。 (您要找到之间距离的点必须具有相同数量的组件 --- 即相同的尺寸 --- 但如果您想找到所有距离,无论如何都需要。)
  • @Gabriel:我现在看到了你的格式。您需要对其进行转置,以便获得 XY 对列表,而不是单独的 X 和 Y 坐标列表。请参阅我编辑的答案。
【解决方案3】:

如果你想避免使用 scipy 来获取 scipy.spatial.dist

import numpy as np

a = np.random.rand(2,1000) 
b = np.random.rand(2,1001)

min_dist = np.sqrt(np.min([np.min(np.sum((b - a[:,i,None])**2, axis=0)) 
                           for i in range(a.shape[1])]))

如果您正在寻找 a 中每个点的最小距离,则将最后一行替换为

min_dists = np.sqrt([np.min(np.sum((b - a[:,i,None])**2, axis=0)) 
                           for i in range(a.shape[1])])

【讨论】:

  • 为了获得时间,您可以避免使用np.sqrt 并使用np.argmin 而不是np.min 然后指向值的索引。接下来你只需要返回 value[index]。 (获得 np.sqrt 时间)
  • @Katsu 他想要找到最小距离,所以我必须在某个时候执行 sqrt 并且只调用一次。也许我不理解你。
  • 它只在浮动而不是列表上调用是的,对不起。您可以使用 xrange 代替 range,这里最好使用迭代器。
  • 或者你可以简单地这样做np.sqrt(np.min(np.sum((a[:,None,:]-b[:,:,None])**2,axis=0))),这里不需要循环,但这只会给出一个奇异值。我相信他正在寻找多个。
  • @Ophion 我想到了(我喜欢 python 广播!),但担心内存需求。对于我的示例, a[:,None,:]-b[:,:,None] 创建一个 2 x 1001 x 1000 数组。
猜你喜欢
  • 2017-08-10
  • 2020-05-17
  • 2019-08-01
  • 1970-01-01
  • 2013-04-23
  • 1970-01-01
  • 1970-01-01
  • 2016-05-19
  • 1970-01-01
相关资源
最近更新 更多