【问题标题】:Sort numpy array according to a boolean square matrix根据布尔方阵对 numpy 数组进行排序
【发布时间】:2021-09-21 06:35:55
【问题描述】:

假设一个形状为(n,) 的numpy 数组A 和一个形状为(n,n) 的布尔numpy 矩阵B

如果B[i][j]True,那么A[i] 应该排在A[j] 之前的位置。

如果B[i][j]False,那么A[i] 应该排在A[j] 之后的位置。

这些规则仅在B[i][j] 位于主对角线下方时适用。主对角线上或主对角线上方的元素应被忽略。

话虽如此,根据矩阵BA 进行排序的最有效方法是什么?

我知道有几种简单的方法可以做到这一点,但我必须执行这个操作数千次,所以我正在寻找一种计算效率高的方法来实现这个(可读性不是我主要关心的问题)。

【问题讨论】:

  • 如果您知道一些方法,请用示例展示它们!提出改进建议比从头创建示例和代码更容易。

标签: python algorithm numpy sorting numba


【解决方案1】:

我认为这可行,我已尝试将其设为纯 python3。如果你使用 numpy,它可以做得更简单。 我用https://stackoverflow.com/a/57003713/3895321

from functools import cmp_to_key

B = [[True, False, True],
     [True, True, True],
     [False, False, True], ]

A = [2, 1, 3]

tmp = list(range(len(A)))

def compare(i, j):
    if B[i][j]:
        return -1
    else:
        return 1

tmp = sorted(tmp, key=cmp_to_key(compare))

sorted_A = [A[i] for i in tmp]
print(tmp)
print(sorted_A)

【讨论】:

    【解决方案2】:

    我通过的手写合并排序抛出 numba.njit 几乎是纯 python 方法的 10 倍。请注意,这在技术上是一个 argsort,即您必须将索引传递给它并获取如果应用于数组的索引让它排序。

    @numba.njit
    def ceil_log2(x):
        n = 0
        while 2**n < x:
            n += 1
        return n
                
    @numba.njit
    def merge(arr, relation):
        out = np.zeros(len(arr), dtype='int')
        j = 0
        k = len(arr)//2
        for i in range(len(out)):
            if j == len(arr)//2:
                out[i:] = arr[k:]
                break
            elif k == len(arr):
                out[i:] = arr[j:len(arr)//2]
                break
            elif relation[arr[j], arr[k]]:
                out[i] = arr[j]
                j += 1
            else:
                out[i] = arr[k]
                k += 1
        arr[:] = out[:]
    
    @numba.njit
    def merge_sort(arr, relation):
        for i in range(1, 1+ceil_log2(len(arr))):
            idx = np.arange(len(arr))[::2**i]
            idx = [*idx, len(arr)]
            for i, j in zip(idx[:-1], idx[1:]):
                merge(arr[i:j], relation)
        return arr
    
    def example_relation(arr):
        idx = np.arange(len(arr))
        grid = np.meshgrid(idx, idx)
        relation = arr[grid[1]] <= arr[grid[0]]
        return relation
    
    np.random.seed(0)
    array = np.random.normal(size=2**14)
    relation = example_relation(array)
    
    def compare(i, j):
        if relation[i, j]:
            return -1
        else:
            return 1
    
    %time sorted(np.arange(len(array)), key=cmp_to_key(compare))
    %time merge_sort(np.arange(len(array)), relation)
    

    给我

    61.5 ms ± 785 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    6.98 ms ± 43.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    【讨论】:

    • sorted(np.arange(len(array)), key=cmp_to_key(compare)) 包装在一个函数中并使用@numba.njit 编译它怎么样?
    • @xskxzr 好吧,如果 Tim Peters(python 排序算法的作者)这么简单,他自己可能会这样做。问题是sorted 实际上不是用python 编写的,但可能是c。请记住,它并不比我的版本“差”,因为它需要更加灵活。
    猜你喜欢
    • 1970-01-01
    • 2012-08-17
    • 2021-06-29
    • 1970-01-01
    • 2023-04-05
    • 1970-01-01
    • 2016-04-18
    • 2018-01-30
    • 2021-03-08
    相关资源
    最近更新 更多