【问题标题】:Searching in numpy arrays在 numpy 数组中搜索
【发布时间】:2020-09-11 09:45:38
【问题描述】:

我想从现有数据中构建一个已知坐标的数组,有什么方法可以避免所有这些循环以使其更快?尝试使用 np.where,接近但没有到达那里。 谢谢!

import numpy as np
import matplotlib.pyplot as plt

x = np.array([[0, 1, 3, 6], [0, 1, 2, 4], [0, 1, 2, 3]])
y = np.array([[0, 2, 0, 1], [1, 2, 1, 1], [2, 3, 2, 4]])
z = np.array([[100, 100, 100, 100], [100, 150, 100, 100], [100, 100, 100, 200]])

xx = np.arange(np.min(x), np.max(x) + 1, 1)
yy = np.arange(np.min(y), np.max(y) + 1, 1)
grid_x, grid_y = np.meshgrid(xx, yy)

L1, C1 = x.shape
L2, C2 = grid_x.shape
v = np.zeros((L2, C2))
n = np.zeros((L2, C2))

for l2 in range(L2):
    for c2 in range(C2):
        for l1 in range(L1):
            for c1 in range(C1):
                if grid_x[l2, c2] == x[l1, c1] and grid_y[l2, c2] == y[l1, c1]:
                    v[l2, c2] = v[l2, c2] + z[l1, c1]
                    n[l2, c2] = n[l2, c2] + 1

v = v/n

plt.imshow(v)
plt.show()

【问题讨论】:

  • 添加实际输出和预期或期望输出
  • 这将返回所需的输出,处理数千行和列太慢了。
  • 你能详细说明“接近但没有到达那里。”

标签: python arrays numpy search


【解决方案1】:

有两种方法可以加快速度:

  1. 您可以创建又大又丑的 4 维蒙版并将它们应用到您的网格中(使用 2 个 np 数组的 np.newaxis==
  2. 您可以使用 Numba (http://numba.pydata.org/numba-doc/0.17.0/user/jit.html),而且您根本不需要重写代码。基本上它使用 jit 编译,让你的代码几乎和 numpy 代码一样快。您所需要的只是用@jit 解码器包装您的函数。请注意,并非所有函数都可以通过这种方式编译。可能(不确定)您需要将嵌套循环移动到另一个函数中并用 @jit 包装它。

【讨论】:

    【解决方案2】:

    部分改进,目前关注n,匹配数。

    import numpy as np
    #import matplotlib.pyplot as plt
    
    x = np.array([[0, 1, 3, 6], [0, 1, 2, 4], [0, 1, 2, 3]])
    y = np.array([[0, 2, 0, 1], [1, 2, 1, 1], [2, 3, 2, 4]])
    z = np.array([[100, 100, 100, 100], [100, 150, 100, 100], [100, 100, 100, 200]])
    
    xx = np.arange(np.min(x), np.max(x) + 1, 1)
    yy = np.arange(np.min(y), np.max(y) + 1, 1)
    grid_x, grid_y = np.meshgrid(xx, yy)
    
    L1, C1 = x.shape
    L2, C2 = grid_x.shape
    v = np.zeros((L2, C2),int)
    n = np.zeros((L2, C2),int)
    
    for l2 in range(L2):
        for c2 in range(C2):
            for l1 in range(L1):
                for c1 in range(C1):
                    if grid_x[l2, c2] == x[l1, c1] and grid_y[l2, c2] == y[l1, c1]:
                        #v[l2, c2] += z[l1, c1]
                        n[l2, c2] += 1
    
    # with broadcasting create 4d arrays, matching x and y with their grids
    X = grid_x[:,:,None,None]==x[None,None,:,:]
    Y = grid_y[:,:,None,None]==y[None,None,:,:]
    XY = X & Y
    XY = XY.sum(axis=(2,3))
    
    #v = v/n
    print(x)
    print(y)
    print(n)
    
    print(XY)
    

    XY 匹配 n

    1032:~/mypy$ python3 stack63844601.py 
    [[0 1 3 6]
     [0 1 2 4]
     [0 1 2 3]]
    [[0 2 0 1]
     [1 2 1 1]
     [2 3 2 4]]
    [[1 0 0 1 0 0 0]             # n
     [1 0 1 0 1 0 1]
     [1 2 1 0 0 0 0]
     [0 1 0 0 0 0 0]
     [0 0 0 1 0 0 0]]
    [[1 0 0 1 0 0 0]              # XY
     [1 0 1 0 1 0 1]
     [1 2 1 0 0 0 0]
     [0 1 0 0 0 0 0]
     [0 0 0 1 0 0 0]]
    

    这应该比迭代更快,尽管在较大的情况下可能会遇到内存问题。

    grid_xgrid_y 可能是 sparse 版本,尽管这不会节省中间 4d 数组中的内存。

    我还没有完全理解这个搜索应该做什么。目前,这是一种蛮力方法,使用对 4 级嵌套循环的相当肤浅的理解。

    【讨论】:

    • 这是一个网格问题,但没有插值。有 3 个具有 x、y 和 z 的数组(它们可以展平,我们只有 2 个 for 循环)。 x 和 y 值是网格节点中的确切值,我们可以轻松地从值到索引。网格是根据最大和最小 x 和 y 值构建的。然后就是把 z 值放在正确的位置,当有几个在同一个位置时取平均值。
    猜你喜欢
    • 2012-10-22
    • 2021-06-03
    • 2016-07-31
    • 2016-05-24
    • 1970-01-01
    • 1970-01-01
    • 2018-11-19
    • 1970-01-01
    • 2021-02-28
    相关资源
    最近更新 更多