【问题标题】:numpy sort error in complie with numba jit decorator使用 numba jit 装饰器编译时出现 numpy 排序错误
【发布时间】:2021-11-17 07:29:38
【问题描述】:

我正在尝试实现 numba.jit 函数来调用 numpy.sort 函数对 numpy 数组进行排序,但它失败为“检测到从 nopython 编译路径回退到对象模式编译路径”。我的代码如下:

gg = numpy.array ([[1,0,2],[1,2,1]],dtype = np.dtype((int,int)))

@nb.jit(nb.void(numba.int32[:,:]))
def kk (gg):
    np.sort(gg)

我也尝试过 njit 模式,但也出现以下错误:

"Failed in nopython mode pipeline (step: nopython frontend)
   [1m[1m[1mNo implementation of function Function(<intrinsic stub>) found for signature:

    >>> stub(array(int32, 2d, A))

   There are 2 candidate implementations:
   [1m  - Of which 2 did not match due to:
     Intrinsic of function 'stub': File: numba\core\overload_glue.py: Line 35.
       With argument(s): '(array(int32, 2d, A))':"

我检查了 numba 文档,因为它显示支持 numpy.sort 函数。我的代码有问题吗?还是排序功能只能在对象模式下工作?

【问题讨论】:

    标签: python jit numba


    【解决方案1】:

    Numba 不支持对二维数组进行排序。为了解决这个问题,您可以遍历感兴趣的维度并对每一行或每一列进行排序。 不过,这会比直接使用np.sort 慢。

    import numba as nb
    import numpy as np
    
    @nb.njit(nb.int32[:, :](nb.int32[:, :]))
    def sort_by_second_axis(arr):
        # Make a copy so we do not modify original.
        arr = arr.copy()
        for i in range(arr.shape[0]):
            arr[i].sort()
        return arr
    

    这是一个使用示例:

    prng = np.random.RandomState(42)
    x = (prng.uniform(size=16) * 10).astype("int32").reshape(4, 4)
    np.array_equal(np.sort(x), sort_by_second_axis(x))
    

    如果您使用@nb.jit(nb.void(nb.int32[:]))(即,将其应用于一维数组),警告会消失。 Numba 似乎不支持 np.sort 在具有 nopython 模式的非平面数组上。这就是为什么它必须退回到对象模式。

    import numba as nb
    import numpy as np
    
    @nb.jit(nb.void(nb.int32[:]))
    def sortme(arr):
        np.sort(arr)
    

    我还会质疑在这种情况下您是否需要 numba。 np.sort 在 C 中实现并且已经编译。它非常快,而且根据我的测试,numba 有点慢。

    import numba as nb
    import numpy as np
    
    @nb.njit(nb.int32[:](nb.int32[:]))
    def sort_numba(arr):
        return np.sort(arr)
    
    prng = np.random.RandomState(seed=42)
    x = (prng.rand(100_000) * 1_000).astype("int32")
    assert np.array_equal(np.sort(x), sort_numba(x))
    
    %timeit np.sort(x)
    # 3.73 ms ± 45.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    %timeit sort_numba(x)
    # 3.98 ms ± 37 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    【讨论】:

    • 嗨,我有一个多维数组。不是 1D .. 我正在努力提高对数据进行排序的速度,这就是为什么我想看看 numba 是否能完成这项工作。
    • @user2625363 - 请查看我的编辑
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-09-03
    • 1970-01-01
    • 2018-05-26
    • 2019-08-07
    • 2018-09-24
    • 1970-01-01
    相关资源
    最近更新 更多