【Question title】: How to return a boolean array in numba.njit?
【Posted】: 2020-10-12 09:34:41
【Question】:
import numpy as np
from numba import njit, float64
from numba.experimental import jitclass

@njit(fastmath=True)
def compare(values1, values2):
    shape = values1.shape[0]
    res = np.zeros(shape, dtype=bool)
    
    for i in range(shape):
        res[i] = values1[i] > values2[i]
    
    return res

spec = [("x", float64[:]),
        ("y", float64[:]),
        ("z", float64[:])]

@jitclass(spec)
class Math:
    
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
    
    def calculate(self):
        i = compare(self.x, self.y)
        return self.z[i]

If I test it like this:

x = np.random.rand(10)
y = np.random.rand(10)
compare(x, y)

it returns:

Traceback (most recent call last):

  File "<ipython-input-25-586dc5d173c7>", line 3, in <module>
    compare(x, y)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)

TypingError: No implementation of function Function(<built-in function zeros>) found for signature:
 
zeros(int64, dtype=Function(<class 'bool'>))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 504.
    With argument(s): '(int64, dtype=Function(<class 'bool'>))':
   No match.

During: resolving callee type: Function(<built-in function zeros>)
During: typing of call at <ipython-input-24-69a4f907fb89> (4)

Ultimately, I need to use it inside a jitclass:

x = np.random.rand(10)
y = np.random.rand(10)
z = np.random.rand(10)

m = Math(x, y, z)
m.calculate()

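In plain NumPy the output is just z[x > y], but how can I do this with njit and jitclass?

I need both of them to speed up the rest of my code.

If the compare function could return a boolean array, the problem should be solved.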

【Discussion】:

    Tags: python arrays numpy boolean numba


    【Answer 1】:

    To do this, you have to use Numba's special bool_ type:

    import numpy as np
    from numba import njit
    from numba.types import bool_, int_, float32
    
    @njit(bool_[:,:](float32[:,:,:],float32[:,:,:],int_))
    def test(im1, im2, j_delta=1):
        diff = ((im1 - im2)**2).sum(axis=2)/3
        mask = np.zeros_like(diff, bool_)  # <--- like so
        for i in range(diff.shape[0]):
            for j in range(diff.shape[1]):
                mask[i,j] = diff[i,j] > 1.0
        return mask
    

    Replacing bool_ with bool, or even with np.bool, causes a compilation error.

    【Comments】:

    • np.bool does not work, but np.bool_ does.