【发布时间】:2020-05-19 07:31:34
【问题描述】:
我正在使用 Numba 非 Python 模式和一些 NumPy 函数。
@njit
def invert(W, copy=True):
'''
Inverts elementwise the weights in an input connection matrix.
In other words, change the from the matrix of internode strengths to the
matrix of internode distances.
If copy is not set, this function will *modify W in place.*
Parameters
----------
W : np.ndarray
weighted connectivity matrix
copy : bool
Returns
-------
W : np.ndarray
inverted connectivity matrix
'''
if copy:
W = W.copy()
E = np.where(W)
W[E] = 1. / W[E]
return W
在这个函数中,W 是一个矩阵。但我收到以下错误。它可能与W[E] = 1. / W[E] 行有关。
File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))
那么使用 NumPy 和 Numba 的正确方法是什么?我知道 NumPy 在矩阵计算方面做得很好。在这种情况下,NumPy 是否足够快以至于 Numba 不再提供加速?
【问题讨论】:
-
Numba 不支持“花式”索引,请查看 here。您应该沿着两个数组维度循环