【问题标题】:Cython optimizationCython 优化
【发布时间】:2016-11-02 21:40:03
【问题描述】:

我正在用 Python 编写一个相当大的模拟,并希望从 Cython 获得一些额外的性能。但是,对于下面的代码,我似乎并没有得到那么多,即使它包含一个相当大的循环。大约 10 万次迭代。

我是否犯了一些初学者的错误,或者这个循环大小只是太小而产生很大的影响? (在我的测试中,Cython 代码只快了大约 2 倍)。

import numpy as np;
cimport numpy as np;
import math

ctypedef np.complex64_t cpl_t
cpl = np.complex64

def example(double a, np.ndarray[cpl_t,ndim=2] A):

    cdef int N = 100

    cdef np.ndarray[cpl_t,ndim=3] B = np.zeros((3,N,N),dtype = cpl)

    cdef Py_ssize_t n, m;
    for n in range(N):
        for m in range(N):

            if np.sqrt(A[0,n]) > 1:
                B[0,n,m] = A[0,n] + 1j * A[0,m]

    return B;

【问题讨论】:

  • 您正在循环内进行np.sqrt 调用。这看起来像一个性能问题。无论如何,为什么它在循环中? a 永不改变。为什么不在循环前做if a <= 1: return B
  • @GWW:这对我来说就像第一行。
  • @user2357112 谢谢,这实际上是我忽略的东西,我可以将其移出。实际上,在 cython 中我应该避免诸如 np.sqrt() 或 np.exp() 之类的数学运算吗?
  • 这些回调到 Python 中,所以是的,如果你想在 GIL 之外运行(例如,用于多线程),你可能想要避免它们。
  • 您通常最好在 Cython 中使用 c 标准库数学运算。

标签: python numpy optimization cython


【解决方案1】:

您应该使用编译器指令。我用 Python 编写了你的​​函数

import numpy as np

def example_python(a, A):
    N = 100
    B = np.zeros((3,N,N),dtype = np.complex)
    aux = np.sqrt(A[0])
    for n in range(N):
        if aux[n] > 1:
            for m in range(N):
                B[0,n,m] = A[0,n] + 1j * A[0,m]
return B

在 Cython 中(您可以了解编译器指令here

import cython
import numpy as np
cimport numpy as np

ctypedef np.complex64_t cpl_t
cpl = np.complex64

@cython.boundscheck(False) # compiler directive
@cython.wraparound(False) # compiler directive
def example_cython(double a, np.ndarray[cpl_t,ndim=2] A):

    cdef int N = 100
    cdef np.ndarray[cpl_t,ndim=3] B = np.zeros((3,N,N),dtype = cpl)
    cdef np.ndarray[float, ndim=1] aux
    cdef Py_ssize_t n, m
    aux = np.sqrt(A[0,:]).real
    for n in range(N):
        if aux[n] > 1.:
            for m in range(N):
                B[0,n,m] = A[0,n] + 1j * A[0,m]
    return B

我比较两个函数

c = np.array(np.random.rand(100,100)+1.5+1j*np.random.rand(100,100), dtype=np.complex64)

%timeit example_python(100, c)
10 loops, best of 3: 61.8 ms per loop

%timeit example_cython(100, c)
10000 loops, best of 3: 134 µs per loop

在这种情况下,Cython 比 Python 快约 450 倍。

【讨论】:

  • 非常感谢,这正是我想要的!我还不知道编译器指令。
猜你喜欢
  • 1970-01-01
  • 2023-03-26
  • 2011-07-16
  • 1970-01-01
  • 1970-01-01
  • 2018-05-20
  • 2014-02-18
  • 2014-12-14
  • 1970-01-01
相关资源
最近更新 更多