如何在 Numba Vectorize 签名中指定元组?
在numba.vectorize 函数中不能使用元组。这是因为vectorize 对这些类型的数组 的代码进行了矢量化处理。
因此,使用float, float, tuple 签名会创建一个函数,该函数需要两个包含浮点数的数组和一个包含元组的数组。问题是包含元组的数组没有 dtype - 如果您使用结构化数组而不是包含元组的数组,它可能会起作用,但我没有尝试过。
如何在 Numba jit 签名中指定元组?
在 numba 签名中指定UniTuple 的正确方法是使用numba.types.containers.UniTuple。在你的情况下:
nb.types.containers.UniTuple(nb.types.float64, 9)
所以正确的签名应该是这样的:
import numba as nb
@nb.njit(
nb.types.float64(
nb.types.float64,
nb.types.float64,
nb.types.containers.UniTuple(nb.types.float64, 9)))
def func(f1, f2, ftuple):
# ...
return f1
我经常避免显式键入我的 numba 函数 - 但当我这样做时,我发现使用 numba.typeof 非常有用,例如:
>>> nb.typeof((1.0, ) * 9)
tuple(float64 x 9)
>>> type(nb.typeof((1.0, ) * 9))
numba.types.containers.UniTuple
>>> help(type(nb.typeof((1.0, ) * 9))) # I shortened the result:
Help on class UniTuple in module numba.types.containers:
class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, numba.types.abstract.Sequence)
| UniTuple(*args, **kwargs)
|
| Type class for homogeneous tuples.
|
| Methods defined here:
|
| __init__(self, dtype, count)
| Initialize self. See help(type(self)) for accurate signature.
所以信息就在那里:它是numba.types.containes.UniTuple,你用两个参数实例化它,dtype(这里是float64)和数字(这里是9)。
如果您只想对浮点数组进行矢量化
如果你不想为元组参数向量化函数,你可以简单地在另一个函数中创建向量化函数并在那里调用它:
import numba as nb
import numpy as np
def func(E, L, fparams):
@nb.vectorize(['float64(float64, float64)'])
def fn_vec(e, l):
return e + l + fparams[1] # just to illustrate that the tuple is available
return fn_vec(E, L)
这使得元组在vectorized 函数中可用。但是,它必须创建内部函数并在每次调用外部函数时对其进行编译,因此这实际上可能会更慢。我也不确定这是否适用于target="cuda",您可能需要自己测试。