【问题标题】:Speed up solving a triangular linear system with numpy?用 numpy 加速求解三角线性系统?
【发布时间】:2013-03-18 04:44:47
【问题描述】:

我有一个方阵 S (160 x 160) 和一个巨大的矩阵 X (160 x 250000)。两者都是密集的numpy数组。

我的目标:找到 Q 使得 Q = inv(chol(S)) * X,其中 chol(S) 是 S 的下 Cholesky 分解。

当然,一个简单的解决方案是

cholS = scipy.linalg.cholesky( S, lower=True)
scipy.linalg.solve( cholS, X )

我的问题:这个解决方案在 python 中明显比我在 Matlab 中尝试时慢 (>2x)。以下是一些计时实验:

timeit np.linalg.solve( cholS, X)
1 loops, best of 3: 1.63 s per loop

timeit scipy.linalg.solve_triangular( cholS, X, lower=True)
1 loops, best of 3: 2.19 s per loop

timeit scipy.linalg.solve( cholS, X)
1 loops, best of 3: 2.81 s per loop

[matlab]
cholS \ X
0.675 s

[matlab using only one thread via -singleCompThread]
cholS \ X
1.26 s

基本上,我想知道:(1)我可以在 python 中达到类似 Matlab 的速度吗? (2) 为什么 scipy 版本这么慢?

求解器应该能够利用 chol(S) 是三角形的事实。然而,使用 numpy.linalg.solve() 比 scipy.linalg.solve_triangular() 更快,即使 numpy 调用根本不使用三角形结构。是什么赋予了?当我的矩阵是三角形时,matlab 求解器似乎会自动检测,但 python 不能。

我很乐意使用对 BLAS/LAPACK 例程的自定义调用来求解三角线性系统,但我真的不想自己编写该代码。

作为参考,我使用的是 scipy 版本 11.0 和 Enthought python 发行版(它使用英特尔的 MKL 库进行矢量化),所以我认为我应该能够达到类似于 Matlab 的速度。

【问题讨论】:

    标签: python numpy scipy linear-algebra


    【解决方案1】:

    TL;DR:当你有一个三角形系统时,不要使用 numpy 或 scipy 的 solve,只需使用 scipy.linalg.solve_triangular 至少带有 check_finite=False 关键字参数,以获得快速且非破坏性的解决方案。


    在偶然发现numpy.linalg.solvescipy.linalg.solve(以及scipy's lu_solve 等)之间存在一些差异后,我发现了这个线程。我没有 Enthought 的基于 MKL 的 Numpy/Scipy,但我希望我的发现可以在某种程度上对您有所帮助。

    使用 Numpy 和 Scipy 的预构建二进制文件(32 位,在 Windows 7 上运行):

    1. 在求解向量 X 时,我发现 numpy.linalg.solvescipy.linalg.solve 之间存在显着差异(即,X 是 160 x 1)。 Scipy 运行时是 numpy 的 1.23 倍,我认为这是可观的。

    2. 但是,大部分差异似乎是由于 scipy 的 solve 检查无效条目。将check_finite=False 传递给 scipy.linalg.solve 时,scipy 的 solve 运行时是 1.02x numpy。

    3. Scipy 使用破坏性更新的求解,即overwrite_a=True, overwrite_b=True 比 numpy 的求解(非破坏性)稍快。 Numpy 的求解运行时是 1.021x 破坏性 scipy.linalg.solve。只有check_finite=False 的 Scipy 的运行时间是破坏性案例的 1.04 倍。总之,破坏性scipy.linalg.solve 比这两种情况都快得多。

    4. 以上是矢量X。如果我将X 设为一个宽数组,特别是160 x 10000,scipy.linalg.solvecheck_finite=False 基本上与check_finite=False, overwrite_a=True, overwrite_b=True 一样快。 Scipy 的solve(没有任何特殊关键字)运行时是这个“不安全”(check_finite=False)调用的 1.09 倍。 Numpy 的 solve 在这个数组 X 案例中的运行时间是 scipy 最快的 1.03 倍。

    5. scipy.linalg.solve_triangular 在这两种情况下都能显着提高速度,但您必须关闭输入检查,即传入check_finite=False。对于向量和数组X,最快求解的运行时间分别是solve_triangular 的5.68x 和1.76x,check_finite=False

    6. solve_triangular 具有破坏性计算 (overwrite_b=True) 不会让您在 check_finite=False 之上加速(实际上对数组 X 的情况有些伤害)。

    7. 我,无知,以前不知道 solve_triangular 并使用 scipy.linalg.lu_solve 作为三角求解器,即,而不是 solve_triangular(cholS, X)lu_solve((cholS, numpy.arange(160)), X) (两者产生相同的答案)。但是我发现以这种方式使用的lu_solve 对于向量X 的情况有1.07x 不安全solve_triangular 的运行时间,而对于数组X 情况它的运行时间是1.76 倍。我不确定为什么数组Xlu_solve 比向量X 慢得多,但教训是使用solve_triangular(没有无限检查)。

    8. 将数据复制为 Fortran 格式似乎一点也不重要。也不会转换为numpy.matrix

    我不妨将我的非 MKL Python 库与单线程 (maxNumCompThreads=1) Matlab 2013a 进行比较。上面最快的 Python 实现对于向量 X 案例的运行时间延长了 4.5 倍,对于胖矩阵 X 案例的运行时间延长了 6.3 倍。

    但是,这是我用来对这些进行基准测试的 Python 脚本,也许拥有 MKL 加速 Numpy/Scipy 的人可以发布他们的数字。请注意,我只是注释掉n = 10000 行以禁用胖矩阵X 案例并执行n=1 向量案例。 (对不起。)

    import scipy.linalg as sla
    import numpy.linalg as nla
    from numpy.random import RandomState
    from timeit import timeit
    import numpy as np
    
    RNG = RandomState(69)
    
    m=160
    n=1
    #n=10000
    Ac = RNG.randn(m,m)
    if 1:
        Ac = np.triu(Ac)
    
    bc = RNG.randn(m,n)
    Af = Ac.copy("F")
    bf = bc.copy("F")
    
    if 0: # Save to Matlab format
        import scipy.io as io
        io.savemat("b_%d.mat"%(n,), dict(A=Ac, b=bc))
        import sys
        sys.exit(0)
    
    def lapper(fn, source, **kwargs):
        Alocal = source[0].copy()
        blocal = source[1].copy()
        fn(Alocal, blocal,**kwargs)
    
    laps = (1000 if n<=1 else 100)
    def printer(t, s=''):
        print ("%g seconds, %d laps, " % (t/float(laps), laps)) + s
        return t/float(laps)
    
    t=[]
    print "C"
    t.append(printer(timeit(lambda: lapper(sla.solve, (Ac,bc)), number=laps),
                     "scipy.solve"))
    t.append(printer(timeit(lambda: lapper(sla.solve, (Ac,bc), check_finite=False),
                            number=laps), "scipy.solve, infinite-ok"))
    t.append(printer(timeit(lambda: lapper(nla.solve, (Ac,bc)), number=laps),
                     "numpy.solve"))
    
    #print "F" # Doesn't seem to matter
    #printer(timeit(lambda: lapper(sla.solve, (Af,bf)), number=laps))
    #printer(timeit(lambda: lapper(nla.solve, (Af,bf)), number=laps))
    
    print "sla with tweaks"
    t.append(printer(timeit(lambda: lapper(sla.solve, (Ac,bc), overwrite_a=True,
                                  overwrite_b=True,  check_finite=False),
                            number=laps), "scipy.solve destructive"))
    
    print "Tri"
    t.append(printer(timeit(lambda: lapper(sla.solve_triangular, (Ac,bc)),
                            number=laps), "scipy.solve_triangular"))
    t.append(printer(timeit(lambda: lapper(sla.solve_triangular, (Ac,bc),
                                  check_finite=False), number=laps),
                     "scipy.solve_triangular, inf-ok"))
    t.append(printer(timeit(lambda: lapper(sla.solve_triangular, (Ac,bc),
                                           overwrite_b=True, check_finite=False),
                            number=laps), "scipy.solve_triangular destructive"))
    
    print "LU"
    piv = np.arange(m)
    t.append(printer(timeit(lambda: lapper(
        lambda X,b: sla.lu_solve((X, piv),b,check_finite=False),
        (Ac,bc)), number=laps), "LU"))
    
    print "all times:"
    print t
    

    上述向量情况下脚本的输出,n=1:

    C
    0.000739405 seconds, 1000 laps, scipy.solve
    0.000624746 seconds, 1000 laps, scipy.solve, infinite-ok
    0.000590003 seconds, 1000 laps, numpy.solve
    sla with tweaks
    0.000608365 seconds, 1000 laps, scipy.solve destructive
    Tri
    0.000208711 seconds, 1000 laps, scipy.solve_triangular
    9.38371e-05 seconds, 1000 laps, scipy.solve_triangular, inf-ok
    9.37682e-05 seconds, 1000 laps, scipy.solve_triangular destructive
    LU
    0.000100215 seconds, 1000 laps, LU
    all times:
    [0.0007394047886284343, 0.00062474593940593, 0.0005900030818282472, 0.0006083650710913095, 0.00020871054023307778, 9.383710445114923e-05, 9.37682389063692e-05, 0.00010021534750467032]
    

    上述矩阵案例n=10000的脚本输出:

    C
    0.118985 seconds, 100 laps, scipy.solve
    0.113687 seconds, 100 laps, scipy.solve, infinite-ok
    0.115569 seconds, 100 laps, numpy.solve
    sla with tweaks
    0.113122 seconds, 100 laps, scipy.solve destructive
    Tri
    0.0725959 seconds, 100 laps, scipy.solve_triangular
    0.0634396 seconds, 100 laps, scipy.solve_triangular, inf-ok
    0.0638423 seconds, 100 laps, scipy.solve_triangular destructive
    LU
    0.1115 seconds, 100 laps, LU
    all times:
    [0.11898513112988955, 0.11368747217793944, 0.11556863916356903, 0.11312182352918797, 0.07259593807427585, 0.0634396208970783, 0.06384230931663318, 0.11150022257648459]
    

    请注意,上述 Python 脚本可以将其数组保存为 Matlab .MAT 数据文件。这目前被禁用(if 0,抱歉),但如果启用,您可以在完全相同的数据上测试 Matlab 的速度。这是 Matlab 的计时脚本:

    clear
    q = load('b_10000.mat');
    A=q.A;
    b=q.b;
    clear q
    matrix_time = timeit(@() A\b)
    
    q = load('b_1.mat');
    A=q.A;
    b=q.b;
    clear q
    vector_time = timeit(@() A\b)
    

    您需要来自 Mathworks File Exchange 的 timeit 函数:http://www.mathworks.com/matlabcentral/fileexchange/18798-timeit-benchmarking-function。这会产生以下输出:

    matrix_time =
        0.0099989
    vector_time =
       2.2487e-05
    

    这个实证分析的结果是,至少在 Python 中,当你有一个三角系统时,不要使用 numpy 或 scipy 的 solve,只需使用 scipy.linalg.solve_triangular 和至少 check_finite=False 关键字参数来实现快速和无损解决方案。

    【讨论】:

      【解决方案2】:

      为什么不直接使用公式:Q = inv(chol(S)) * X,这是我的测试:

      import scipy.linalg
      import numpy as np
      
      N = 160
      M = 100000
      S = np.random.randn(N, N)
      B = np.random.randn(N, M)
      S = np.dot(S, S.T)
      
      cS = scipy.linalg.cholesky(S, lower=True)
      Y1 = scipy.linalg.solve(cS, B)
      icS = scipy.linalg.inv(cS)
      Y2 = np.dot(icS, B)
      
      np.allclose(Y1, Y2)
      

      输出:

      True
      

      这是时间测试:

      %time scipy.linalg.solve(cholS, B)
      %time np.linalg.solve(cholS, B)
      %time scipy.linalg.solve_triangular(cholS, B, lower=True)
      %time ics=scipy.linalg.inv(cS);np.dot(ics, B)
      

      输出:

      CPU times: user 2.07 s, sys: 0.00 s, total: 2.07 s
      Wall time: 2.08 s
      CPU times: user 1.93 s, sys: 0.00 s, total: 1.93 s
      Wall time: 1.92 s
      CPU times: user 1.12 s, sys: 0.00 s, total: 1.12 s
      Wall time: 1.13 s
      CPU times: user 0.71 s, sys: 0.00 s, total: 0.71 s
      Wall time: 0.72 s
      

      我不知道为什么scipy.linalg.solve_triangular 在你的系统上比numpy.linalg.solve 慢,但是inv 版本是最快的。

      【讨论】:

      • 来自Numerical Recipes: The Art of Scientific Computing, page 41 of the latest edition "如果我们得到矩阵逆矩阵,我们以后不能让它乘以一个新的右手边来得到一个额外的解决方案吗?这确实有效,但它给出了一个答案很容易受到舍入误差的影响,并且不如新向量在第一个实例中包含在右侧向量集合中那么好。”
      • @Jaime 实际上它的accuracy 并不像通常想象的那么糟糕,但这仍然不是解决任何线性系统的好方法。 “一些广泛使用的教科书让读者相信,通过将 b 乘以计算出的逆 inv(A) 来求解线性方程组 Ax = b 是不准确的。[...] 事实上,在对逆如何进行合理假设的情况下计算后,x=inv(A)*b 与最佳后向稳定求解器计算的解一样准确。"
      • 我可以确认使用显式逆确实比调用“求解”至少快 2 倍。由于经久不衰的民间智慧,我什至没有尝试过这个解决方案,即使用显式逆很容易出现不准确性。我会去试试这个,看看是否有明显的准确性问题。谢谢。
      • @jorgeca 当然,“在合理的假设下”结果将是相同的。问题是知道你的系统是否满足那些“合理的假设”。如果没有事先的测试或理论上的证明,我认为沿着这条路节省 1 分并不是一个好主意。计算机时间。
      • @jorgeca 的链接论文非常有趣。它说如果 b 不是“坏”,即如果它不是(几乎)正交,则将 Ax=b 求解为 inv(A)*b 与“更好”的方法(LU 等)一样稳定和准确到“小奇异子空间”(具有小奇异值的左奇异向量的跨度)。即使 b “坏”,它仍然是 accurate(inv(A)*b 与 A\b 一样接近实际解决方案)但不是稳定(A*(inv(A)*b) 将比 A*(A\b) 离 b 更远)。所以,如果你只是想要一个解决方案, inv(A)*b 将是准确的(通常是稳定的)。
      【解决方案3】:

      有几件事可以尝试:

      • X = X.copy('F') # 使用fortran-order数组,避免复制

      • Y = solve_triangular(cholS, X, overwrite_b=True) # 避免再次复制,但删除X的内容

      • Y = solve_triangular(cholS, X, check_finite=False) # Scipy >= 0.12 only --- 但似乎对速度没有太大影响...

      有了这两个,它应该几乎等同于直接调用 MKL 而没有缓冲区副本。

      我无法重现 np.linalg.solvescipy.linalg.solve 具有不同速度的问题 --- 使用我拥有的 BLAS + LAPACK 组合,两者的速度似乎相同。

      【讨论】:

        猜你喜欢
        • 2020-04-03
        • 2012-12-03
        • 2012-09-27
        • 2013-09-07
        • 2017-12-13
        • 1970-01-01
        • 1970-01-01
        • 2018-01-17
        • 2019-01-15
        相关资源
        最近更新 更多