【问题标题】:What is a fast(er) way to get the center points of objects represented in a 2D numpy array?获取 2D numpy 数组中表示的对象的中心点的快速(更好)方法是什么?
【发布时间】:2021-11-17 23:40:14
【问题描述】:

我有一个存储为 2D numpy 数组的图像掩码,其中的值指示图像中是否存在已分割的对象(0 = 无对象,1..n = 对象 1 到 n)。我想为每个表示对象中心的对象获取一个坐标。它不必是完全准确的质心或重心。我只是取数组中包含每个对象的所有单元格的 x 和 y 索引的平均值。我想知道是否有比我目前的方法更快的方法:

for obj in np.unique(mask):
    if obj == 0:
        continue
    x, y = np.mean(np.where(mask == obj), axis=1)

这是一个可重现的例子:

import numpy as np
mask = np.array([
    [0,0,0,0,0,2,0,0,0,0],
    [0,1,1,0,2,2,2,0,0,0],
    [0,0,1,0,2,2,2,0,0,0],
    [0,0,0,0,0,0,0,0,0,0],
    [0,3,3,3,0,0,4,0,0,0],
    [0,0,0,0,0,4,4,4,0,0],
    [0,0,0,0,0,0,4,0,0,0],
])

points = []
for obj in np.unique(mask):
    if obj == 0:
        continue
    points.append(np.mean(np.where(mask == obj), axis=1))
print(points)

这个输出:

[array([1.33333333, 1.66666667]),
 array([1.28571429, 5.        ]),
 array([4., 2.]),
 array([5., 6.])]

【问题讨论】:

  • 定义“慢”的含义。
  • 这与是否有更快的方法有什么关系?
  • 你能做一些可重现的例子吗?
  • 请注意,xy 似乎倒置了。但我认为关于这个国家的惯例会发生变化。这里y是这里最后一个维度相关的值(一般是最连续的那个)。
  • 我添加了一个例子。是的,x 和 y 是倒置的。

标签: python numpy


【解决方案1】:

我想出了另一种方法,它的速度似乎快了大约 3 倍:

import numpy as np
mask = np.array([
    [0,0,0,0,0,2,0,0,0,0],
    [0,1,1,0,2,2,2,0,0,0],
    [0,0,1,0,2,2,2,0,0,0],
    [0,0,0,0,0,0,0,0,0,0],
    [0,3,3,3,0,0,4,0,0,0],
    [0,0,0,0,0,4,4,4,0,0],
    [0,0,0,0,0,0,4,0,0,0],
])

flat = mask.flatten()
split = np.unique(np.sort(flat), return_index=True)[1]
points = []
for inds in np.split(flat.argsort(), split)[2:]:
    points.append(np.array(np.unravel_index(inds, mask.shape)).mean(axis=1))
print(points)

我想知道是否可以将 for 循环替换为可能更快的 numpy 操作。

【讨论】:

    【解决方案2】:

    你可以copy this answer(如果这个答案对你有用,也给他们一个赞成票)并使用稀疏矩阵而不是 np 数组。但是,这仅证明对于大型阵列更快,随着阵列越大,速度越快:

    import numpy as np, time
    from scipy.sparse import csr_matrix
    
    def compute_M(data):
        cols = np.arange(data.size)
        return csr_matrix((cols, (np.ravel(data), cols)),
                          shape=(data.max() + 1, data.size))
    
    def get_indices_sparse(data,M):
        #M = compute_M(data)
        return [np.mean(np.unravel_index(row.data, data.shape),1) for R,row in enumerate(M) if R>0]
    
    def gen_random_mask(C, n, m):
        mask = np.zeros([n,m],int)
        for i in range(C):
            x = np.random.randint(n)
            y = np.random.randint(m)
            mask[x:x+np.random.randint(n-x),y:y+np.random.randint(m-y)] = i
        return mask
    
    N = 100
    C = 4
    for S in [10,100,1000,10000]:
        mask = gen_random_mask(C, S, S)
        print('Time for size {:d}x{:d}:'.format(S,S))
        s = time.time()
        for _ in range(N):
            points = []
            for obj in np.unique(mask):
                if obj == 0:
                    continue
                points.append(np.mean(np.where(mask == obj), axis=1))
        points_np = np.array(points)
        print('NP: {:f}'.format((time.time() - s)/N))
        mask_s = compute_M(mask)
        s = time.time()
        for _ in range(100):
            points = get_indices_sparse(mask,mask_s)
        print('Sparse: {:f}'.format((time.time() - s)/N))
        np.testing.assert_equal(points,points_np)
    

    这会导致:

    Time for size 10x10:
    NP: 0.000066
    Sparse: 0.000226
    Time for size 100x100:
    NP: 0.000207
    Sparse: 0.000253
    Time for size 1000x1000:
    NP: 0.018662
    Sparse: 0.004472
    Time for size 10000x10000:
    NP: 2.545973
    Sparse: 0.501061
    

    【讨论】:

      【解决方案3】:

      问题可能来自np.where(mask == obj),它一遍又一遍地迭代整个mask 数组。当有很多对象时,这是一个问题。您可以使用 group-by 策略有效地解决此问题。但是,Numpy 还没有提供这样的操作。您可以使用 sort 后跟 split 来实现它。但是排序通常效率不高。另一种方法是让 Numpy 在 unique 调用中返回索引,以便您可以累积有关对象的值(如 reduce-by-key其中归约运算符是加法,键是对象整数)。最后通过简单的除法即可得到均值。

      objects, inverts, counts = np.unique(mask, return_counts=True, return_inverse=True)
      
      # Reduction by object
      x = np.full(len(objects), 0.0)
      y = np.full(len(objects), 0.0)
      xPos = np.repeat(np.arange(mask.shape[0]), mask.shape[1])
      yPos = np.tile(np.arange(mask.shape[1]), reps=mask.shape[0])
      np.add.at(x, inverts, xPos)
      np.add.at(y, inverts, yPos)
      
      # Compute the final mean from the sum
      x /= counts
      y /= counts
      
      # Discard the first item (when obj == 0)
      x = x[1:]
      y = y[1:]
      

      如果您需要更快的速度,您可以使用 Numba 并手动执行缩减(可能在 parallel 中)。

      编辑:如果你真的需要输出列表,你可以使用points = list(np.stack([x, y]).T),但是使用列表而不是 Numpy 数组会很慢(而且内存效率也不高)。

      【讨论】:

        猜你喜欢
        • 2016-09-29
        • 1970-01-01
        • 1970-01-01
        • 2017-11-29
        • 1970-01-01
        • 2021-01-28
        • 1970-01-01
        • 2011-03-30
        • 1970-01-01
        相关资源
        最近更新 更多