In [274]: W = np.random.rand(10000, 10000)
...:
...: U = np.random.rand(10000)
...: V = np.zeros(10000)
In [275]: timeit U@W
125 ms ± 263 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [276]: timeit V@W
153 ms ± 18.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
现在考虑V 的 100 个元素非零 (1s) 的情况。稀疏实现可以是:
In [277]: Vdata=np.ones(100); Vind=np.arange(0,10000,100)
In [278]: Vind.shape
Out[278]: (100,)
In [279]: timeit Vdata@W[Vind,:]
4.99 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
这时候我有点惊讶,以为W的索引可以抵消乘法次数。
让我们更改V来验证结果:
In [280]: V[Vind]=1
In [281]: np.allclose(V@W, Vdata@W[Vind,:])
如果我必须先找到非零元素怎么办:
In [282]: np.allclose(np.where(V),Vind)
Out[282]: True
In [283]: timeit idx=np.where(V); V[idx]@W[idx,:]
5.07 ms ± 77.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
W 的大小,尤其是第 2 维可能是这种加速的一个重要因素。在这些大小下,内存管理对速度的影响与原始乘法一样大。
===
在这种情况下,sparse 的表现比我预期的要好(其他测试表明我需要 1% 左右的稀疏度才能获得时间优势):
In [294]: from scipy import sparse
In [295]: Vc=sparse.csr_matrix(V)
In [296]: Vc.dot(W)
Out[296]:
array([[46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
45.54413903, 48.28613399]])
In [297]: V.dot(W)
Out[297]:
array([46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
45.54413903, 48.28613399])
In [298]: np.allclose(Vc.dot(W),V@W)
Out[298]: True
In [299]: timeit Vc.dot(W)
1.48 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
即使是稀疏创建:
In [300]: timeit Vm=sparse.csr_matrix(V); Vm.dot(W)
2.01 ms ± 7.89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)