【问题标题】:How to ignore zeros in matrix multiplications?如何忽略矩阵乘法中的零?
【发布时间】:2019-05-24 15:21:28
【问题描述】:

假设我有一个带有随机数的 10000 x 10000 矩阵 W,以及两个 10000 暗淡向量 U 和 V,U 中有随机数,V 用零填充。 使用 numpy 或 pytorch,计算 U @ W 和 V @ W 需要相同的时间。我的问题是,有没有办法优化矩阵乘法,使其在计算过程中跳过或忽略零,从而更快地计算 V @ W 之类的东西?

import numpy as np
W = np.random.rand(10000, 10000)

U = np.random.rand(10000)
V = np.zeros(10000)

y1 = U @ W
y2 = V @ W
# computing y2 should take less amount of time than y1 since it always returns zero vector.

【问题讨论】:

标签: python numpy math


【解决方案1】:

您可以使用scipy.sparse 类来提高性能,但这完全取决于矩阵。例如,使用V 作为稀疏矩阵获得的性能将非常好。通过将U 转换为稀疏矩阵获得的效果不会很好,或者实际上可能会降低性能(在这种情况下U 实际上是密集的)。

import numpy as np
import scipy.sparse as sps

W = np.random.rand(10000, 10000)
U = np.random.rand(10000)
V = np.zeros(10000)

%timeit U @ W
125 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit V @ W
128 ms ± 6.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Vsp = sps.csr_matrix(V)
Usp = sps.csr_matrix(U)
Wsp = sps.csr_matrix(W)

%timeit Vsp.dot(Wsp)
1.34 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 
%timeit Vsp @ Wsp
1.39 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit Usp @ Wsp
2.37 s ± 84.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

正如您所见,对 V @ W 使用稀疏方法有很大的改进,但实际上会降低 U @ W 的性能,因为 U 或 W 中的条目都不为零。

【讨论】:

    【解决方案2】:
    In [274]: W = np.random.rand(10000, 10000) 
         ...:  
         ...: U = np.random.rand(10000) 
         ...: V = np.zeros(10000)                                                                            
    In [275]: timeit U@W                                                                                     
    125 ms ± 263 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    In [276]: timeit V@W                                                                                     
    153 ms ± 18.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    现在考虑V 的 100 个元素非零 (1s) 的情况。稀疏实现可以是:

    In [277]: Vdata=np.ones(100); Vind=np.arange(0,10000,100)                                                
    In [278]: Vind.shape                                                                                     
    Out[278]: (100,)
    In [279]: timeit Vdata@W[Vind,:]                                                                         
    4.99 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    这时候我有点惊讶,以为W的索引可以抵消乘法次数。

    让我们更改V来验证结果:

    In [280]: V[Vind]=1                                                                                      
    In [281]: np.allclose(V@W, Vdata@W[Vind,:])  
    

    如果我必须先找到非零元素怎么办:

    In [282]: np.allclose(np.where(V),Vind)                                                                  
    Out[282]: True
    In [283]: timeit idx=np.where(V); V[idx]@W[idx,:]                                                        
    5.07 ms ± 77.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    W 的大小,尤其是第 2 维可能是这种加速的一个重要因素。在这些大小下,内存管理对速度的影响与原始乘法一样大。

    ===

    在这种情况下,sparse 的表现比我预期的要好(其他测试表明我需要 1% 左右的稀疏度才能获得时间优势):

    In [294]: from scipy import sparse                                                                       
    In [295]: Vc=sparse.csr_matrix(V)                                                                        
    In [296]: Vc.dot(W)                                                                                      
    Out[296]: 
    array([[46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
            45.54413903, 48.28613399]])
    In [297]: V.dot(W)                                                                                       
    Out[297]: 
    array([46.01437545, 50.46422246, 44.80337192, ..., 55.57660691,
           45.54413903, 48.28613399])
    In [298]: np.allclose(Vc.dot(W),V@W)                                                                     
    Out[298]: True
    
    In [299]: timeit Vc.dot(W)                                                                               
    1.48 ms ± 84.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    即使是稀疏创建:

    In [300]: timeit Vm=sparse.csr_matrix(V); Vm.dot(W)                                                      
    2.01 ms ± 7.89 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    【讨论】:

      猜你喜欢
      • 2012-12-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-04-05
      • 1970-01-01
      • 2018-04-11
      • 1970-01-01
      相关资源
      最近更新 更多