【问题标题】:Why is matrix multiplication with Numba slow?为什么用 Numba 进行矩阵乘法很慢?
【发布时间】:2021-07-01 15:23:02
【问题描述】:

我试图找到一个解释,为什么我与 Numba 的矩阵乘法比使用 NumPy 的点函数慢得多。尽管我使用最基本的代码来编写 Numba 矩阵乘法函数,但我不认为性能明显变慢是由于算法。为简单起见,我考虑两个 k x k 方阵,A 和 B。我的代码如下:

1     @njit('f8[:,:](f8[:,:], f8[:,:])')
2     def numba_dot(A, B):
3
4         k=A.shape[1]
5         C = np.zeros((k, k))
6
7         for i in range(k):
8             for j in range(k):
9
10                 tmp = 0.
11                for l in range(k):
12                    tmp += A[i, l] * B[l, j]
13     
14                C[i, j] = tmp
15
16         return C

使用两个随机矩阵 1000 x 1000 矩阵重复运行此代码,通常至少需要大约 1.5 秒才能完成。 另一方面,如果我不更新矩阵 C,即如果我删除第 14 行,或者为了测试而将其替换为例如以下行:

14                C[i, j] = i * j

代码在大约 1-5 毫秒内完成。相比之下,NumPy 的 dot 函数需要 10 ms 左右的矩阵乘法。

上面的矩阵乘法代码和这个小变化之间的运行时间差异背后的原因是什么?有没有办法在不显着降低代码性能的情况下将变量 tmp 的值存储在 C[i, j] 中?

【问题讨论】:

  • 你的算法绝对没有优化。关于如何实现矩阵乘法的真实示例看起来像gist.github.com/nadavrot/5b35d44e8ba3dd718e595e40184d03f0 Numpy 在这种情况下调用 BLAS 函数 dgemm。如果输入是连续的,Numba 也会这样做。例如。 @njit('f8[:,::1](f8[:,::1], f8[:,::1])')
  • 感谢您的回复。由于某种原因,连续输入我也得到了相似的运行时间。
  • 只需在 Numba 中调用 np.dot(使用连续数组)。在这两种情况下,numpy 和 numba 都会做同样的事情(调用外部 BLAS 库)。该链接只是为了展示现实世界的矩阵乘法是多么复杂。这是一个很好的学习,例如,但如果你只是不想计算点积,这就是这样做的方法。您也可以在 C 中尝试它。(如果不对算法进行一些改进,它仍然会慢 100 倍以上)。还要考虑编译器试图优化掉无用的部分。如果您编写 C[i, j] = i * j,整个内部循环将被检测为无用。

标签: python numpy numba


【解决方案1】:

本机 NumPy 实现适用于矢量化操作。如果您的 CPU 支持这些,则处理速度会快得多。当前的微处理器具有片上矩阵乘法,用于对数据传输和向量运算进行流水线化处理。

您的实现执行 k^3 次循环迭代;十亿的任何事情都需要一些不平凡的时间。 您的代码指定您希望单独执行每个单元一个单元的操作,十亿个不同的操作,而不是并行和流水线完成的大约 5k 个操作。

【讨论】:

  • 感谢您的回答。我认为我的示例表明,不仅仅是必须执行的操作数量,还有操作的类型。当按照描述修改代码并使用 Numba 编译代码时,可以在类似于 NumPy 的点函数的时间内执行三个循环。
猜你喜欢
  • 2012-11-13
  • 2012-07-14
  • 2017-07-25
  • 2023-03-12
  • 2011-03-14
  • 2012-06-22
  • 2016-11-04
  • 1970-01-01
  • 2019-03-16
相关资源
最近更新 更多