【发布时间】:2016-05-08 04:18:19
【问题描述】:
做类似的事情
import numpy as np
a = np.random.rand(10**4, 10**4)
b = np.dot(a, a)
使用多核,运行良好。
不过,a 中的元素是 64 位浮点数(或 32 位平台中的 32 位?),我想乘以 8 位整数数组。不过,请尝试以下方法:
a = np.random.randint(2, size=(n, n)).astype(np.int8)
导致点积不使用多个内核,因此在我的 PC 上运行速度慢了约 1000 倍。
array: np.random.randint(2, size=shape).astype(dtype)
dtype shape %time (average)
float32 (2000, 2000) 62.5 ms
float32 (3000, 3000) 219 ms
float32 (4000, 4000) 328 ms
float32 (10000, 10000) 4.09 s
int8 (2000, 2000) 13 seconds
int8 (3000, 3000) 3min 26s
int8 (4000, 4000) 12min 20s
int8 (10000, 10000) It didn't finish in 6 hours
float16 (2000, 2000) 2min 25s
float16 (3000, 3000) Not tested
float16 (4000, 4000) Not tested
float16 (10000, 10000) Not tested
我知道 NumPy 使用 BLAS,它不支持整数,但如果我使用 SciPy BLAS 包装器,即。
import scipy.linalg.blas as blas
a = np.random.randint(2, size=(n, n)).astype(np.int8)
b = blas.sgemm(alpha=1.0, a=a, b=a)
计算是多线程的。现在,blas.sgemm 的运行时间与 float32 的 np.dot 完全相同,但对于非浮点数,它将所有内容转换为 float32 并输出浮点数,这是 np.dot 不做的。 (此外,b 现在处于F_CONTIGUOUS 顺序,这是一个较小的问题)。
所以,如果我想进行整数矩阵乘法,我必须执行以下操作之一:
- 使用 NumPy 令人痛苦的缓慢
np.dot,很高兴我能保留 8 位整数。 - 使用 SciPy 的
sgemm并使用 4 倍内存。 - 使用 Numpy 的
np.float16并且只使用 2 倍内存,但需要注意的是,np.dot在 float16 数组上比在 float32 数组上慢得多,比 int8 更慢。 - 为多线程整数矩阵乘法找到一个优化的库(实际上,Mathematica 可以做到这一点,但我更喜欢 Python 解决方案),理想情况下支持 1 位数组,虽然 8 位数组也很好......(我实际上的目标是在有限域 Z/2Z 上进行矩阵乘法,并且我知道我可以使用 Sage 来做到这一点,这很 Pythonic,但是,再次,有什么严格意义上的 Python 吗?)
我可以遵循选项 4 吗?有这样的图书馆吗?
免责声明:我实际上是在运行 NumPy + MKL,但我在 vanilly NumPy 上尝试了类似的测试,结果类似。
【问题讨论】:
-
作为选项 4 的可能答案,bitbucket.org/malb/m4ri 看起来很有趣。 “M4RI 是一个在 F2 上具有密集矩阵的快速算术库。”我想这就是 Sage 已经在使用的东西,但我看不出有什么理由不能直接从 Python 中使用它,并使用合适的 Cython 包装器。 (事实上,您可能已经在 Sage 源代码中找到了这样的包装器。)
-
还没有人提到
numpy.einsum,但这可能是一个不错的选择 5。 -
请注意,如果要避免整数溢出,则需要将结果转换为更大的值。如果每个元素是 0 或 1,则需要一个整数格式,该格式可以保存至少
n的值,以保证不会溢出。对于您的示例,n=10000, (u)int16 应该就足够了。你的真实矩阵是稀疏的吗?如果是这样,您最好使用scipy.sparse.csr_matrix。 -
您能否为您要解决的整体问题提供更多背景信息?将大整数矩阵相乘是一件相当不寻常的事情。更多地了解这些矩阵的属性将特别有用。这些值总是 0 还是 1?如果它们可以更大,那么您很可能会发现自己最终受到可以使用 uint64 表示的最大整数的限制。矩阵是如何生成的?它们是否有任何特殊结构(例如对称、块、带等)?
标签: python multithreading numpy matrix-multiplication blas