【问题标题】:Using np.min with list input in a numba function在 numba 函数中使用带有列表输入的 np.min
【发布时间】:2019-07-25 02:57:19
【问题描述】:

这里使用np.min有什么问题?为什么 numba 不喜欢在该函数中使用列表,还有其他方法可以让 np.min 工作吗?

from numba import njit
import numpy as np

@njit
def availarray(length):
    out=np.ones(14)
    if length>0:
        out[0:np.min([int(length),14])]=0
    return out

availarray(3)

该函数在min 下运行良好,但np.min 应该更快...

【问题讨论】:

  • 一般来说,应该避免在 Python 列表上使用 NumPy 函数,而在 NumPy 数组上使用 Python 函数(例如参见 stackoverflow.com/a/49908528/539338)。但这不适用于 numba 函数,因为 numba 不使用 Python 列表,而是使用 Python 列表的同构类型版本。 Numba 还将所有对内置或 NumPy 函数的函数调用替换为它们自己的实现。因此,您不能真正根据非 numba 实现的经验对 numba 函数的性能进行争论 - 确实需要衡量性能。

标签: python list numpy numba


【解决方案1】:

问题是np.min 的numba 版本需要array 作为输入。

from numba import njit
import numpy as np

@njit
def test_numba_version_of_numpy_min(inp):
    return np.min(inp)

>>> test_numba_version_of_numpy_min(np.array([1, 2]))  # works
1

>>> test_numba_version_of_numpy_min([1, 2]) # doesn't work
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function amin at 0x000001B5DBDEE598>) with argument(s) of type(s): (reflected list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.

更好的解决方案是只使用 Python 的 numba 版本min

from numba import njit
import numpy as np

@njit
def availarray(length):
    out = np.ones(14)
    if length > 0:
        out[0:min(length, 14)] = 0
    return out

因为np.minmin 实际上都是这些函数的Numba 版本(至少在njitted 函数中)min 在这种情况下也应该快得多。然而,这不太可能引起注意,因为数组的分配和将一些元素设置为零将是这里的主要运行时贡献者。

请注意,您甚至不需要在此处调用 min - 因为即使使用了更大的停止索引,切片也会隐式停止在数组的末尾:

from numba import njit
import numpy as np

@njit
def availarray(length):
    out = np.ones(14)
    if length > 0:
        out[0:length] = 0
    return out

【讨论】:

    【解决方案2】:

    要使您的代码与 numba 一起使用,您必须将 np.min 应用于 NumPy 数组,这意味着您必须将列表 [int(length),14] 转换为 NumPy 数组,如下所示

    from numba import njit
    import numpy as np
    
    @njit
    def availarray(length):
        out=np.ones(14)
        if length>0:
            out[0:np.min(np.array([int(length),14]))]=0   
        return out
    
    availarray(3)
    # array([0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
    

    【讨论】:

    • 看起来这会比只使用 min(int(length), 14) 在计算上更昂贵?
    猜你喜欢
    • 2021-06-04
    • 1970-01-01
    • 1970-01-01
    • 2021-03-12
    • 2021-09-19
    • 2021-02-11
    • 2017-06-21
    • 1970-01-01
    • 2020-10-27
    相关资源
    最近更新 更多