【发布时间】:2021-09-23 10:36:50
【问题描述】:
我正在尝试用 Numba 编写一个 ufunc。我阅读了this 并合并到我的代码中。所以,我运行的基本代码是
import numpy as np
arr = np.arange(15).reshape((3,5))
def myadd(x, y):
return x+y
myadd = np.frompyfunc(myadd, 2, 1)
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))
现在如果我使用 Numba 向量化
import numpy as np
import numba as nb
arr = np.arange(15).reshape((3,5))
@nb.vectorize
def myadd(x, y):
return x+y
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))
我得到了错误
Traceback (most recent call last):
File "<ipython-input-11-ea9f981e42b2>", line 7, in <module>
print(myadd.accumulate(arr,dtype=object,axis=1).astype(int))
ValueError: could not find a matching type for myadd.accumulate, requested type has type code 'O'
解决方法是什么?我在 Anaconda 的 Spyder 下使用 Numba 0.54.0、Numpy 1.20.3 和 Windows 10 上的 Python 3.8.10。
【问题讨论】:
标签: python-3.x numpy vectorization numba numpy-ufunc