【问题标题】:Can't call numba functions from inside njit'ed functions无法从 njit 函数内部调用 numba 函数
【发布时间】:2020-06-21 19:31:15
【问题描述】:

numba 一起工作时,我偶然发现了非常意想不到的行为。我创建了一个nb.njit 函数,在其中我试图创建int8 numpy 数组的nb.typed.List,所以我尝试创建一个对应的numba 类型。

nb.int8[:]  # type of the list elements

所以,我通过 lsttype 关键字将此类型设置为 nb.typed.List

l = nb.typed.List(lsttype=nb.int8[:])  # list of int8 numpy ndarrays

我得到的是:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(class(int8), slice<a:b>)
 
There are 16 candidate implementations:
   - Of which 16 did not match due to:
   Overload in function 'getitem': File: <built-in>: Line <N/A>.
     With argument(s): '(class(int8), slice<a:b>)':
    No match.

我想这意味着numba 正在尝试对nb.int8 类型对象进行切片,就好像它不理解符号一样。

所以,我尝试了另一种方式,创建一个 np.int8 类型的空数组,并使用 nb.typeof 函数。

nb.typeof(np.array([], dtype=np.int8))

它返回了:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'typeof' of type Module(<module 'numba' from '/Users/.../venv37/lib/python3.7/site-packages/numba/__init__.py'>)

这个我不明白! numba怎么看不到自己

最小的例子很简单:

import numba as nb
import numpy as np

@nb.njit
def v():
    print(nb.typeof(np.array([], dtype=np.int8)))

v()

所以我尝试使用相同的功能,但没有@nb.njit

并打印出来!

array(int8, 1d, C)

另外,我尝试在函数中导入 numba,因为它看不到模块,但它产生了:

numba.core.errors.UnsupportedError: Failed in nopython mode pipeline (step: analyzing                    bytecode)
Use of unsupported opcode (IMPORT_NAME) found

我也尝试重新安装和更新numbanumpy

这是什么诡计?

【问题讨论】:

    标签: python numpy types numba


    【解决方案1】:

    我发现了一个讨厌的解决方法。

    @nb.njit
    def f(types=(nb.int8[::1], nb.float64[::1])):
        a = nb.typed.List.empty_list(types[0])
        b = nb.typed.List.empty_list(types[1])
        # and so on...
    

    [::1] 表示 C 类型的一维 numpy.ndarray

    【讨论】:

    • 但问题仍然悬而未决。有没有更好的办法?
    • 我也很好奇,因为我刚刚遇到了这个问题。我也发布了一个链接here
    【解决方案2】:

    让 numba 自己判断类型怎么样?

    @nb.njit
    def v():
        l=nb.typed.List()
        l.append(np.array([1,2,3], dtype=np.int8))
        return l
    v()
    

    为我返回ListType[array(int8, 1d, C)]([[1 2 3]])

    【讨论】:

    • 在函数中,我想用for的方法,numba无法推断类型。而且,添加一个元素然后删除它听起来非常愚蠢。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-03-05
    • 2019-09-24
    • 1970-01-01
    • 1970-01-01
    • 2017-08-18
    • 1970-01-01
    相关资源
    最近更新 更多