【发布时间】:2021-11-23 12:09:04
【问题描述】:
我有这两个代码执行相同但针对不同的数据结构
res = np.array([np.array([2.0, 4.0, 6.0]), np.array([8.0, 10.0, 12.0])], dtype=np.int)
%timeit np.sum(res, axis=1)
4.08 µs ± 728 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
list_obj_array = np.ndarray((2,), dtype=np.object)
list_obj_array[0] = [2.0, 4.0, 6.0]
list_obj_array[1] = [8.0, 10.0, 12.0]
v_func = np.vectorize(np.sum, otypes=[np.int])
%timeit v_func(list_obj_array)
20.6 µs ± 486 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
第二个慢了 5 倍,有没有更好的优化方法?
@nb.jit()
def nb_np_sum(arry_list):
return [np.sum(row) for row in arry_list]
%timeit nb_np_sum(list_obj_array)
30.8 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
@nb.jit()
def nb_sum(arry_list):
return [sum(row) for row in arry_list]
%timeit nb_sum(list_obj_array)
13.6 µs ± 669 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
迄今为止最好的(感谢@hpaulj)
%timeit [sum(l) for l in list_obj_array]
850 ns ± 115 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
@nb.njit()
def nb_sum(arry_list):
return [sum(row) for row in arry_list]
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>
File "<ipython-input-54-3bb48c5273bb>", line 3:
def nb_sum(arry_list):
return [sum(row) for row in arry_list]
对于更长的数组
list_obj_array = np.ndarray((n,), dtype=np.object)
for i in range(n):
list_obj_array[i] = list(range(7))
vectorized 版本更接近最佳选择(列表理解)
%timeit [sum(l) for l in list_obj_array]
23.4 µs ± 4.19 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit v_func(list_obj_array)
29.6 µs ± 4.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
numba 还是比较慢
%timeit nb_sum(list_obj_array)
74.4 µs ± 6.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
【问题讨论】:
-
Numpy 为小型阵列引入了相当大的延迟。如果您在这么小的数组上进行大量计算,那么 Numpy 可能不是正确的工具。但是,请注意,如果您批量执行 Numba 可用于加速此类操作(由于解释器速度慢,可能比使用纯 Python 函数/数据类型快得多)。
-
@JérômeRichard, 10xs 4 解释清楚我确实尝试过 numba 看到上面的结果并没有更快
-
你为什么还要使用带有 python 列表的
numpy数组?只需使用常规列表 -
@juanpa.arrivillaga ,上面只是一个玩具代码,我的系统完全使用 numpy ,但在某些情况下它应该处理列表
-
Numba 并不快,因为它没有成功编译函数(请参阅
Failed in nopython mode pipeline),主要是因为该函数适用于纯 Python 对象。这样的对象天生就很慢。由于 Numba 无法编译代码,因此基准函数只是一个纯 Python 函数。如果你想获得高性能,你需要处理静态类型的原生对象,例如带有原生类型项目的 Numpy 数组(或者可能具有原生类型项目的列表,尽管它们在 Numba 中速度不是那么快)。此外,请注意第一个数组项的类型是np.float64,而不是np.int。
标签: python numpy vectorization