正如other answer 所说,在这种情况下不是因为len 函数,而是因为对numba 函数的调用实际上比对普通Python 函数的调用慢。
jit-ted 函数有何不同?
要理解为什么调用 numba jitted 函数更慢,我们必须明白 numba jited 函数不再是一个函数。这是一个调度程序对象:
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
print(nb_len) # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)
这个CPUDispatcher 实例代表(可能)多个基于修饰函数生成的编译函数。
这意味着当您调用CPUDispatcher 实例时,有多个步骤:
- 获取参数的类型。
- 如果没有适合这些类型参数的编译函数,则使用参数类型编译修饰函数。
- 有时:将参数转换为相应的 numba 类型。
- 调用编译后的函数。
与未装饰的函数相比,所有这些步骤都会增加开销。特别是如果没有合适的编译函数并且调度程序需要编译函数 - 或者 - 输入类型需要转换(仅适用于 Python 类型,如:列表、集合、字典)调用 CPUDispatcher 会慢很多 - 这些类型在编写 numba 0.46 时已被弃用,部分原因是,请参阅"2.11.2. Deprecation of reflection for List and Set types"。
你的情况
在您的情况下,由于编译,第一次调用 jited 函数会明显变慢。
任何后续调用只会稍微慢一些,因为 numba 必须获取参数类型,检查是否已经存在编译函数,然后调用该编译函数。有趣的是,额外的时间取决于参数的数量和该函数已编译的“重载”数量。通常这个额外的时间是微不足道的,因为该函数的作用远不止调用len。
编译时间
尽管函数非常简单,但第一次调用的编译需要大量时间:
import numpy as np
import numba as nb
def first_call(seq):
@nb.njit
def nb_len(seq):
return len(seq)
return nb_len(seq)
@nb.njit
def _nb_len(seq):
return len(seq)
def subsequent_calls(seq):
return _nb_len(seq)
t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))
%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
转化时间
另外,如果 numba 需要转换参数,它会慢很多。这只发生在 numba 无法直接处理的 Python 类型上,例如列表:
import numpy as np
import numba as nb
@nb.njit
def nb_len(seq):
return len(seq)
arr = np.random.rand(10_000)
lst = arr.tolist()
nb_len(arr)
nb_len(lst)
%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
总结
- 与普通 Python 函数相比,Numba 函数有一些额外的开销。因此,请确保你做了 numba 擅长优化的“足够”的事情,否则一个普通的 Python 函数会更快、更灵活并且更容易调试。
- numba 函数中的函数调用确实不同于 numba 函数之外的函数调用。所以
nb_len 中的len() 和py_len 中的len() 可以有完全不同的运行时间。但是在这种情况下,运行时间几乎相同。但通常最好能意识到这一点。
- 根据参数类型,numba 函数可能(在幕后)非常慢,尤其是在将 Python 类型作为参数或返回类型处理时!