【发布时间】:2021-07-01 15:23:02
【问题描述】:
我试图找到一个解释,为什么我与 Numba 的矩阵乘法比使用 NumPy 的点函数慢得多。尽管我使用最基本的代码来编写 Numba 矩阵乘法函数,但我不认为性能明显变慢是由于算法。为简单起见,我考虑两个 k x k 方阵,A 和 B。我的代码如下:
1 @njit('f8[:,:](f8[:,:], f8[:,:])')
2 def numba_dot(A, B):
3
4 k=A.shape[1]
5 C = np.zeros((k, k))
6
7 for i in range(k):
8 for j in range(k):
9
10 tmp = 0.
11 for l in range(k):
12 tmp += A[i, l] * B[l, j]
13
14 C[i, j] = tmp
15
16 return C
使用两个随机矩阵 1000 x 1000 矩阵重复运行此代码,通常至少需要大约 1.5 秒才能完成。 另一方面,如果我不更新矩阵 C,即如果我删除第 14 行,或者为了测试而将其替换为例如以下行:
14 C[i, j] = i * j
代码在大约 1-5 毫秒内完成。相比之下,NumPy 的 dot 函数需要 10 ms 左右的矩阵乘法。
上面的矩阵乘法代码和这个小变化之间的运行时间差异背后的原因是什么?有没有办法在不显着降低代码性能的情况下将变量 tmp 的值存储在 C[i, j] 中?
【问题讨论】:
-
你的算法绝对没有优化。关于如何实现矩阵乘法的真实示例看起来像gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0 Numpy 在这种情况下调用 BLAS 函数 dgemm。如果输入是连续的,Numba 也会这样做。例如。
@njit('f8[:,::1](f8[:,::1], f8[:,::1])') -
感谢您的回复。由于某种原因,连续输入我也得到了相似的运行时间。
-
只需在 Numba 中调用 np.dot(使用连续数组)。在这两种情况下,numpy 和 numba 都会做同样的事情(调用外部 BLAS 库)。该链接只是为了展示现实世界的矩阵乘法是多么复杂。这是一个很好的学习,例如,但如果你只是不想计算点积,这就是这样做的方法。您也可以在 C 中尝试它。(如果不对算法进行一些改进,它仍然会慢 100 倍以上)。还要考虑编译器试图优化掉无用的部分。如果您编写 C[i, j] = i * j,整个内部循环将被检测为无用。