【问题标题】:Can this cython code be optimized?这个cython代码可以优化吗?
【发布时间】:2016-12-24 16:06:01
【问题描述】:

我第一次使用 cython 来获得一些功能的速度。该函数采用方阵A 浮点数并输出单个浮点数。它正在计算的函数是permanent of a matrix

当 A 为 30 x 30 时,我的代码目前在我的 PC 上大约需要 60 秒。

在下面的代码中,我已经从 wiki 页面实现了永久的 Balasubramanian-Bax/Franklin-Glynn 公式。我称矩阵为 M。

代码的一个复杂部分是数组 f,它用于保存数组 d 中要翻转的下一个位置的索引。数组 d 包含 +-1 的值。在循环中对 f 和 j 的操作只是快速更新格雷码的一种聪明方法。

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython


DTYPE_int = np.int
ctypedef np.int_t DTYPE_int_t
DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def permfunc(np.ndarray [DTYPE_float_t, ndim =2, mode='c'] M):
    cdef int n = M.shape[0]
    cdef np.ndarray[DTYPE_float_t, ndim =1, mode='c' ] d = np.ones(n, dtype=DTYPE_float)
    cdef int j =  0
    cdef int s = 1
    cdef np.ndarray [DTYPE_int_t, ndim =1, mode='c'] f = np.arange(n, dtype=DTYPE_int)
    cdef np.ndarray [DTYPE_float_t, ndim =1, mode='c'] v = M.sum(axis=0)
    cdef DTYPE_float_t p = 1
    cdef int i
    cdef DTYPE_float_t prod
    for i in range(n):
        p *= v[i]
    while (j < n-1):
        for i in range(n):
            v[i] -= 2*d[j]*M[j, i]
        d[j] = -d[j]
        s = -s
        prod = 1
        for i in range(n):
            prod *= v[i]
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]
    return p/2**(n-1)   

我已经使用了我在 cython 教程中找到的所有简单优化。有些方面我不得不承认我并不完全理解。例如,如果我创建数组 d ints,因为值永远只有 +-1,代码运行速度会慢 10%,所以我将其保留为 float64s。

我还能做些什么来加快代码速度?


这是 cython -a 的结果。如您所见,循环中的所有内容都被编译为 C,因此基本优化已经奏效。

这是 numpy 中的相同函数,它比我当前的 cython 版本慢 100 倍以上。

def npperm(M):
    n = M.shape[0]
    d = np.ones(n)
    j =  0
    s = 1
    f = np.arange(n)
    v = M.sum(axis=0)
    p = np.prod(v)
    while (j < n-1):
        v -= 2*d[j]*M[j]
        d[j] = -d[j]
        s = -s
        prod = np.prod(v)
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]
    return p/2**(n-1)  

时间已更新

这是我的 cython 版本、numpy 版本和 romeric 对 cython 代码的改进的时间安排(使用 ipython)。我已经为可重复性设置了种子。

from scipy.stats import ortho_group
import pyximport; pyximport.install()
import permlib # This loads in the functions from permlib.pyx
import numpy as np; np.random.seed(7)
M = ortho_group.rvs(23) #Creates a random orthogonal matrix 
%timeit permlib.npperm(M) # The numpy version
1 loop, best of 3: 44.5 s per loop
%timeit permlib.permfunc(M) # The cython version
1 loop, best of 3: 273 ms per loop
%timeit permlib.permfunc_modified(M) #romeric's improvement
10 loops, best of 3: 198 ms per loop
M = ortho_group.rvs(28)
%timeit permlib.permfunc(M) # The cython version run on a 28x28 matrix
1 loop, best of 3: 15.8 s per loop
%timeit permlib.permfunc_modified(M) # romeric's improvement run on a 28x28 matrix
1 loop, best of 3: 12.4 s per loop

cython 代码可以加速吗?

我使用的是 gcc,CPU 是 AMD FX 8350。

【问题讨论】:

  • 是的:你可以在 Code Review 上问这个问题。
  • @RadLexus 谢谢。但是,似乎 cython 问题在那里很少见。已经有 30 个了!
  • @eleanora:正是出于这种原因,这个数字一直很低。
  • @Rad Lexus,除非你准备在那里回答这个话题,否则不要推荐 Code Review。没有足够熟练的numpycython 程序员在该板附近提供良好和及时的答案。我目前是那里最活跃的numpy 编码员之一,我的声誉只有 1000。
  • 我已经添加了numpy标签,并去掉了optimization标签。这是关于 SO 的常见 numpy/cython 问题 - 如何从 cython 端口获得最佳速度改进。

标签: python c numpy cython


【解决方案1】:

您的cython 函数无能为力,因为它已经得到了很好的优化。但是,您仍然可以通过完全避免调用numpy 来获得适度的加速。

import numpy as np
cimport numpy as np
cimport cython
from libc.stdlib cimport malloc, free
from libc.math cimport pow

cdef inline double sum_axis(double *v, double *M, int n):
    cdef:
        int i, j
    for i in range(n):
        for j in range(n):
            v[i] += M[j*n+i]


@cython.boundscheck(False) 
@cython.wraparound(False)
def permfunc_modified(np.ndarray [double, ndim =2, mode='c'] M):
    cdef:
        int n = M.shape[0], j=0, s=1, i
        int *f = <int*>malloc(n*sizeof(int))
        double *d = <double*>malloc(n*sizeof(double))
        double *v = <double*>malloc(n*sizeof(double))
        double p = 1, prod

    sum_axis(v,&M[0,0],n)

    for i in range(n):
        p *= v[i]
        f[i] = i
        d[i] = 1

    while (j < n-1):
        for i in range(n):
            v[i] -= 2.*d[j]*M[j, i]
        d[j] = -d[j]
        s = -s
        prod = 1
        for i in range(n):
            prod *= v[i]
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]

    free(d)
    free(f)
    free(v)
    return p/pow(2.,(n-1)) 

以下是基本检查和时间安排:

In [1]: n = 12
In [2]: M = np.random.rand(n,n)
In [3]: np.allclose(permfunc_modified(M),permfunc(M))
True
In [4]: n = 28
In [5]: M = np.random.rand(n,n)
In [6]: %timeit permfunc(M) # your version
1 loop, best of 3: 28.9 s per loop
In [7]: %timeit permfunc_modified(M) # modified version posted above
1 loop, best of 3: 21.4 s per loop

编辑 让我们通过展开内部prod 循环来执行一些基本的SSE 向量化,即将上述代码中的循环更改为以下内容

# define t1, t2 and t3 earlier as doubles
t1,t2,t3=1.,1.,1.
for i in range(0,n-1,2):
    t1 *= v[i]
    t2 *= v[i+1]
# define k earlier as int
for k in range(i+2,n):
    t3 *= v[k]
p += s*(t1*t2*t3) 

现在是时候了

In [8]: %timeit permfunc_modified_vec(M) # vectorised
1 loop, best of 3: 14.0 s per loop

因此,与原始优化的 cython 代码相比,速度几乎提高了 2 倍,还不错。

【讨论】:

  • 这很棒。谢谢你。两个问题。 a) 如果我使用单精度浮点数而不是双精度浮点数会有帮助吗? b) 有什么方法可以帮助编译器矢量化代码吗?
  • 将 double 更改为 single 可能有帮助,也可能没有帮助,具体取决于架构和编译器。关于矢量化,请查看我的更新答案
  • @eleanora 哦。展开时,您必须确保 for 循环大小是步长的倍数(在本例中为 2)。我现在已经更新了答案。此外,您的代码对变量p 有溢出问题,因为它很快就会爆炸。因此,SSE 和标量代码会导致较大的M 产生不同的结果。
【解决方案2】:

免责声明:我是下述工具的核心开发者。

作为 Cython 的替代品,您可以试试 Pythran。 原始 NumPy 代码的单个注释:

#pythran export npperm(float[:, :])
import numpy as np
def npperm(M):
    n = M.shape[0]
    d = np.ones(n)
    j =  0
    s = 1
    f = np.arange(n)
    v = M.sum(axis=0)
    p = np.prod(v)
    while j < n-1:
        v -= 2*d[j]*M[j]
        d[j] = -d[j]
        s = -s
        prod = np.prod(v)
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]
    return p/2**(n-1)

编译:

> pythran perm.py

产生类似于 Cython 的加速:

> # numpy version
> python -mtimeit -r3 -n1 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)'
1 loops, best of 3: 21.7 sec per loop
> # pythran version
> pythran perm.py
> python -mtimeit -r3 -n1 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)' 
1 loops, best of 3: 171 msec per loop

无需重新实现 sum_axis(Pythran 会处理这个问题)。

更有趣的是,Pythran 能够通过选项标志识别几种可矢量化(在生成 SSE/AVX 内在函数的意义上)模式:

> pythran perm.py -DUSE_BOOST_SIMD -march=native
>  python -mtimeit -r3 -n10 -s 'from scipy.stats import ortho_group; from perm import npperm; import numpy as np; np.random.seed(7); M = ortho_group.rvs(23)' 'npperm(M)' 
10 loops, best of 3: 93.2 msec per loop

相对于 NumPy 版本实现了最终的 x232 加速,与展开的 Cython 版本相当,无需太多手动调整。

【讨论】:

  • 这很有趣,谢谢!我会在几天内正确测试你的代码并报告。
【解决方案3】:

此答案基于之前发布的@romeric 的代码。我更正了代码并简化了它,并添加了cdivisioncompiler 指令。

@cython.boundscheck(False) 
@cython.wraparound(False)
@cython.cdivision(True)
def permfunc_modified_2(np.ndarray [double, ndim =2, mode='c'] M):
    cdef:
        int n = M.shape[0], s=1, i, j
        int *f = <int*>malloc(n*sizeof(int))
        double *d = <double*>malloc(n*sizeof(double))
        double *v = <double*>malloc(n*sizeof(double))
        double p = 1, prod

    for i in range(n):
        v[i] = 0.
        for j in range(n):
            v[i] += M[j,i]
        p *= v[i]
        f[i] = i
        d[i] = 1
    j = 0
    while (j < n-1):
        prod = 1.
        for i in range(n):
            v[i] -= 2.*d[j]*M[j, i]
            prod *= v[i]
        d[j] = -d[j]
        s = -s            
        p += s*prod
        f[0] = 0
        f[j] = f[j+1]
        f[j+1] = j+1
        j = f[0]

    free(d)
    free(f)
    free(v)
    return p/pow(2.,(n-1))

@romeric 的原始代码没有初始化v,所以有时会得到不同的结果。此外,我将while 之前的两个循环和while 内部的两个循环分别组合在一起。

最后是对比

In [1]: from scipy.stats import ortho_group
In [2]: import permlib
In [3]: import numpy as np; np.random.seed(7)
In [4]: M = ortho_group.rvs(5)
In [5]: np.equal(permlib.permfunc(M), permlib.permfunc_modified_2(M))
Out[5]: True
In [6]: %timeit permfunc(M)
10000 loops, best of 3: 20.5 µs per loop
In [7]: %timeit permlib.permfunc_modified_2(M)
1000000 loops, best of 3: 1.21 µs per loop
In [8]: M = ortho_group.rvs(15)
In [9]: np.equal(permlib.permfunc(M), permlib.permfunc_modified_2(M))
Out[9]: True
In [10]: %timeit permlib.permfunc(M)
1000 loops, best of 3: 1.03 ms per loop
In [11]: %timeit permlib.permfunc_modified_2(M)
1000 loops, best of 3: 432 µs per loop
In [12]: M = ortho_group.rvs(28)
In [13]: np.equal(permlib.permfunc(M), permlib.permfunc_modified_2(M))
Out[13]: True
In [14]: %timeit permlib.permfunc(M)
1 loop, best of 3: 14 s per loop
In [15]: %timeit permlib.permfunc_modified_2(M)
1 loop, best of 3: 5.73 s per loop

【讨论】:

  • 太棒了!我无法弄清楚为什么代码有时会给出错误的答案。谢谢你。我认为乘以 2 可以移出循环,对吗?
  • 是的,你可以写d[i] = 2.v[i] -= d[j]*M[j,i],但这不会改变代码的性能
【解决方案4】:

嗯,一个明显的优化是将 d[i] 设置为 -2 和 +2 并避免乘以 2。我怀疑这不会有任何区别,但仍然。

另一个是确保编译结果代码的 C++ 编译器已启用所有优化(尤其是矢量化)。

计算新 v[i]s 的循环可以与Cython's support of OpenMP 并行化。在 30 次迭代时,这也可能不会产生影响。

【讨论】:

  • 感谢您的想法。您对 2 的因素是正确的。cython 代码是 C(不是 C++)。我将尝试使用编译器标志,看看它是否有任何区别。我想我需要按照您的建议并行化主 while 循环,如果我使用不同的方法来计算格雷码,这将是可行的。
猜你喜欢
  • 2012-03-13
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2013-09-11
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多