【问题标题】:How to efficiently compare each pair of rows in a 2D matrix?如何有效地比较二维矩阵中的每一对行?
【发布时间】:2019-05-01 19:25:04
【问题描述】:

我正在处理一个子程序,我需要处理矩阵的每一行并找出当前行中包含的其他行。为了说明一行包含另一行时,请考虑如下 3x3 矩阵:

[[1, 0, 1], 

 [0, 1, 0], 

 [1, 0, 0]]

这里第 1 行包含第 3 行,因为第 1 行中的每个元素都大于或等于第 3 行,但第 1 行不包含第 2 行。

我想出了以下解决方案,但由于 for 循环(矩阵大小约为 6000x6000),它非常慢。

for i in range(no_of_rows):
    # Here Adj is the 2D matrix 
    contains = np.argwhere(np.all(Adj[i] >= Adj, axis = 1))

能否让我知道是否可以更有效地做到这一点?

【问题讨论】:

  • 尝试以np.triu_indices 开头并检查此stackoverflow.com/questions/52690963/…
  • 您想识别每个元素都包含自己吗?此外,您已将其标记为 broadcasting。你可以试试(a >= a[:, None]).all(-1),但这会很快用大数组炸毁你的内存
  • @kvitaliy 谢谢!我已经看过那个帖子了。在我的例子中,triu_indices 会产生非常大(25715206 大小)的 R 和 C 向量,并且需要很长时间。
  • @user3483203 谢谢! “包含”一词意味着将每一行与其他每一行元素进行比较。
  • 是的,所以根据定义,第一行包含第一行。你想要这些吗?

标签: python numpy vectorization numpy-ndarray array-broadcasting


【解决方案1】:

由于矩阵的大小以及问题的要求,我认为迭代是不可避免的。你不能使用广播,因为它会爆炸你的内存,所以你需要对现有的数组逐行操作。但是,您可以使用 numbanjit 来大大加快速度,而不是纯 Python 方法。


import numpy as np
from numba import njit


@njit
def zero_out_contained_rows(a):
    """
    Finds rows where all of the elements are
    equal or smaller than all corresponding
    elements of anothe row, and sets all
    values in the row to zero

    Parameters
    ----------
    a: ndarray
      The array to modify

    Returns
    -------
    The modified array

    Examples
    --------
    >>> zero_out_contained_rows(np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]]))
    array([[1, 0, 1],
            [0, 1, 0],
            [0, 0, 0]])
    """
    x, y = a.shape

    contained = np.zeros(x, dtype=np.bool_)

    for i in range(x):
        for j in range(x):
            if i != j and not contained[j]:
                equal = True
                for k in range(y):
                    if a[i, k] < a[j, k]:
                        equal = False
                        break
                contained[j] = equal

    a[contained] = 0

    return a

这会记录一行是否在另一行中使用。这可以通过短路来防止许多不必要的比较,然后最终用0 清除其他行中包含的行。


与您最初使用迭代的尝试相比,这提高了速度,并且还可以将正确的行归零。


a = np.random.randint(0, 2, (6000, 6000))

%timeit zero_out_contained_rows(a)
1.19 s ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

一旦您的尝试完成运行(目前大约 10 分钟),我将更新时间。

【讨论】:

    【解决方案2】:

    如果您的矩阵 6000x6000 超出您的需要 (6000*6000 - 6000)/2 = 17997000 次计算。

    除了使用 np.triu_indices,您可以尝试对矩阵的顶部三角形使用生成器 - 它应该会减少内存消耗。试试这个,也许会有所帮助..

    def indxs(lst):
       for i1, el1 in enumerate(lst):
          for el2 in lst[i1:][1:]:
             yield (el1, el2)
    

    【讨论】:

      猜你喜欢
      • 2014-05-14
      • 1970-01-01
      • 2018-05-25
      • 1970-01-01
      • 2021-06-20
      • 2017-06-06
      • 1970-01-01
      • 2018-11-12
      • 2013-11-05
      相关资源
      最近更新 更多