【Posted】: 2020-10-12 09:34:41
【Question】:
import numpy as np
from numba import njit, float64
from numba.experimental import jitclass

@njit(fastmath=True)
def compare(values1, values2):
    shape = values1.shape[0]
    res = np.zeros(shape, dtype=bool)
    for i in range(shape):
        # element-wise comparison of the two input arrays
        res[i] = values1[i] > values2[i]
    return res

spec = [("x", float64[:]),
        ("y", float64[:]),
        ("z", float64[:]),]

@jitclass(spec)
class Math:
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def calculate(self):
        i = compare(self.x, self.y)
        return self.z[i]
If I test it like this:
x = np.random.rand(10)
y = np.random.rand(10)
compare(x, y)
I get this error:
Traceback (most recent call last):
File "<ipython-input-25-586dc5d173c7>", line 3, in <module>
compare(x, y)
File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 415, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 358, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\utils.py", line 80, in reraise
raise value.with_traceback(tb)
TypingError: No implementation of function Function(<built-in function zeros>) found for signature:
zeros(int64, dtype=Function(<class 'bool'>))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 504.
With argument(s): '(int64, dtype=Function(<class 'bool'>))':
No match.
During: resolving callee type: Function(<built-in function zeros>)
During: typing of call at <ipython-input-24-69a4f907fb89> (4)
Ultimately I need to use it inside a jitclass:
x = np.random.rand(10)
y = np.random.rand(10)
z = np.random.rand(10)
m = Math(x, y, z)
m.calculate()
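In plain NumPy the result is simply z[x > y], but how do I get the same thing with njit and jitclass?
I need both of them to speed up my other code.
If the compare function could return a boolean array, the problem would be solved.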
【Discussion】:
Tags: python arrays numpy boolean numba