numba 可以这么慢是很奇怪的。
这并不奇怪。当您在 numba 函数中调用 NumPy 函数时,您将调用这些函数的 numba 版本。这些可以更快、更慢或与 NumPy 版本一样快。您可能很幸运,也可能很不幸(您很不幸!)。但即使在 numba 函数中,您仍然会创建许多临时数组,因为您使用 NumPy 函数(一个临时数组用于点结果,一个用于每个平方和总和,一个用于点加第一个总和),因此您不会利用numba 的可能性。
我是不是用错了?
基本上:是的。
我真的需要加快速度
好的,我试试看。
让我们从沿轴 1 调用展开平方和开始:
import numba as nb
@nb.njit
def sum_squares_2d_array_along_axis1(arr):
res = np.empty(arr.shape[0], dtype=arr.dtype)
for o_idx in range(arr.shape[0]):
sum_ = 0
for i_idx in range(arr.shape[1]):
sum_ += arr[o_idx, i_idx] * arr[o_idx, i_idx]
res[o_idx] = sum_
return res
@nb.njit
def euclidean_distance_square_numba_v1(x1, x2):
return -2 * np.dot(x1, x2.T) + np.expand_dims(sum_squares_2d_array_along_axis1(x1), axis=1) + sum_squares_2d_array_along_axis1(x2)
在我的计算机上,它已经比 NumPy 代码快 2 倍,比原始 Numba 代码快近 10 倍。
从经验上来说,让它比 NumPy 快 2 倍通常是极限(至少在 NumPy 版本不是不必要的复杂或低效的情况下),但是您可以通过展开所有内容来挤出更多内容:
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v2(x1, x2):
f1 = 0.
for i_idx in range(x1.shape[1]):
f1 += x1[0, i_idx] * x1[0, i_idx]
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
val_from_x2 = x2[o_idx, i_idx]
val += (-2) * x1[0, i_idx] * val_from_x2 + val_from_x2 * val_from_x2
val += f1
res[o_idx] = val
return res
但这仅比最新方法提高了约 10-20%。
此时您可能会意识到您可以简化代码(即使它可能不会加快速度):
import numba as nb
@nb.njit
def euclidean_distance_square_numba_v3(x1, x2):
res = np.empty(x2.shape[0], dtype=x2.dtype)
for o_idx in range(x2.shape[0]):
val = 0
for i_idx in range(x2.shape[1]):
tmp = x1[0, i_idx] - x2[o_idx, i_idx]
val += tmp * tmp
res[o_idx] = val
return res
是的,这看起来很简单,而且速度并不慢。
然而,在所有的兴奋中,我忘了提及 明显 解决方案:scipy.spatial.distance.cdist,它有一个 sqeuclidean(平方欧几里得距离)选项:
from scipy.spatial import distance
distance.cdist(x1, x2, metric='sqeuclidean')
它并不比 numba 快,但无需编写自己的函数即可使用...
测试
测试正确性并进行热身:
x1 = np.array([[1.,2,3]])
x2 = np.array([[1.,2,3], [2,3,4], [3,4,5], [4,5,6], [5,6,7]])
res1 = euclidean_distance_square(x1, x2)
res2 = euclidean_distance_square_numba_original(x1, x2)
res3 = euclidean_distance_square_numba_v1(x1, x2)
res4 = euclidean_distance_square_numba_v2(x1, x2)
res5 = euclidean_distance_square_numba_v3(x1, x2)
np.testing.assert_array_equal(res1, res2)
np.testing.assert_array_equal(res1, res3)
np.testing.assert_array_equal(res1[0], res4)
np.testing.assert_array_equal(res1[0], res5)
np.testing.assert_almost_equal(res1, distance.cdist(x1, x2, metric='sqeuclidean'))
时间安排:
x1 = np.random.random((1, 512))
x2 = np.random.random((1000000, 512))
%timeit euclidean_distance_square(x1, x2)
# 2.09 s ± 54.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_original(x1, x2)
# 10.9 s ± 158 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v1(x1, x2)
# 907 ms ± 7.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v2(x1, x2)
# 715 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit euclidean_distance_square_numba_v3(x1, x2)
# 731 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit distance.cdist(x1, x2, metric='sqeuclidean')
# 706 ms ± 4.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
注意:如果您有整数数组,您可能希望将 numba 函数中的硬编码 0.0 更改为 0。