【发布时间】:2016-12-24 16:06:01
【问题描述】:
我第一次使用 cython 来获得一些功能的速度。该函数采用方阵A 浮点数并输出单个浮点数。它正在计算的函数是permanent of a matrix
当 A 为 30 x 30 时,我的代码目前在我的 PC 上大约需要 60 秒。
在下面的代码中,我已经从 wiki 页面实现了永久的 Balasubramanian-Bax/Franklin-Glynn 公式。我称矩阵为 M。
代码的一个复杂部分是数组 f,它用于保存数组 d 中要翻转的下一个位置的索引。数组 d 包含 +-1 的值。在循环中对 f 和 j 的操作只是快速更新格雷码的一种聪明方法。
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE_int = np.int
ctypedef np.int_t DTYPE_int_t
DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def permfunc(np.ndarray [DTYPE_float_t, ndim =2, mode='c'] M):
cdef int n = M.shape[0]
cdef np.ndarray[DTYPE_float_t, ndim =1, mode='c' ] d = np.ones(n, dtype=DTYPE_float)
cdef int j = 0
cdef int s = 1
cdef np.ndarray [DTYPE_int_t, ndim =1, mode='c'] f = np.arange(n, dtype=DTYPE_int)
cdef np.ndarray [DTYPE_float_t, ndim =1, mode='c'] v = M.sum(axis=0)
cdef DTYPE_float_t p = 1
cdef int i
cdef DTYPE_float_t prod
for i in range(n):
p *= v[i]
while (j < n-1):
for i in range(n):
v[i] -= 2*d[j]*M[j, i]
d[j] = -d[j]
s = -s
prod = 1
for i in range(n):
prod *= v[i]
p += s*prod
f[0] = 0
f[j] = f[j+1]
f[j+1] = j+1
j = f[0]
return p/2**(n-1)
我已经使用了我在 cython 教程中找到的所有简单优化。有些方面我不得不承认我并不完全理解。例如,如果我创建数组 d ints,因为值永远只有 +-1,代码运行速度会慢 10%,所以我将其保留为 float64s。
我还能做些什么来加快代码速度?
这是 cython -a 的结果。如您所见,循环中的所有内容都被编译为 C,因此基本优化已经奏效。
这是 numpy 中的相同函数,它比我当前的 cython 版本慢 100 倍以上。
def npperm(M):
n = M.shape[0]
d = np.ones(n)
j = 0
s = 1
f = np.arange(n)
v = M.sum(axis=0)
p = np.prod(v)
while (j < n-1):
v -= 2*d[j]*M[j]
d[j] = -d[j]
s = -s
prod = np.prod(v)
p += s*prod
f[0] = 0
f[j] = f[j+1]
f[j+1] = j+1
j = f[0]
return p/2**(n-1)
时间已更新
这是我的 cython 版本、numpy 版本和 romeric 对 cython 代码的改进的时间安排(使用 ipython)。我已经为可重复性设置了种子。
from scipy.stats import ortho_group
import pyximport; pyximport.install()
import permlib # This loads in the functions from permlib.pyx
import numpy as np; np.random.seed(7)
M = ortho_group.rvs(23) #Creates a random orthogonal matrix
%timeit permlib.npperm(M) # The numpy version
1 loop, best of 3: 44.5 s per loop
%timeit permlib.permfunc(M) # The cython version
1 loop, best of 3: 273 ms per loop
%timeit permlib.permfunc_modified(M) #romeric's improvement
10 loops, best of 3: 198 ms per loop
M = ortho_group.rvs(28)
%timeit permlib.permfunc(M) # The cython version run on a 28x28 matrix
1 loop, best of 3: 15.8 s per loop
%timeit permlib.permfunc_modified(M) # romeric's improvement run on a 28x28 matrix
1 loop, best of 3: 12.4 s per loop
cython 代码可以加速吗?
我使用的是 gcc,CPU 是 AMD FX 8350。
【问题讨论】:
-
是的:你可以在 Code Review 上问这个问题。
-
@RadLexus 谢谢。但是,似乎 cython 问题在那里很少见。已经有 30 个了!
-
@eleanora:正是出于这种原因,这个数字一直很低。
-
@Rad Lexus,除非你准备在那里回答这个话题,否则不要推荐 Code Review。没有足够熟练的numpy和cython程序员在该板附近提供良好和及时的答案。我目前是那里最活跃的numpy编码员之一,我的声誉只有 1000。 -
我已经添加了
numpy标签,并去掉了optimization标签。这是关于 SO 的常见 numpy/cython 问题 - 如何从 cython 端口获得最佳速度改进。