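【Question title】: How can I optimize the JONSWAP spectrum in Cython?
【Posted】: 2019-04-10 23:17:43
【Question】:

I am trying to improve the performance of a JONSWAP spectrum calculation by using Cython, but the Cython version seems to be much slower than the original code. How can I improve it?

Cython code: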

from libc.math cimport exp
from libc.stdlib cimport malloc
import numpy as np
cimport numpy as np

DTYPE_float = np.float64
ctypedef np.float64_t DTYPE_float_t

def jonswap(np.ndarray[DTYPE_float_t, ndim=1, mode ='c'] w, double Hs, double Tp, double gamma = 3.7):
    '''
    get Jonswap spectra
    :param w: np.ndarray Angular Frequency
    '''
    cdef:
        int n = w.shape[0]
        double *sigma = <double*>malloc(n * sizeof(double)) 
        double *a = <double*>malloc(n * sizeof(double)) 
        int i 
    cdef double wp
    cdef np.ndarray[DTYPE_float_t, ndim=1, mode='c'] sj = np.ones(n, dtype=DTYPE_float)

    wp = 2 * np.pi / Tp
    for i in range(n):
        sigma[i] = 0.07 if w[i] < wp else 0.09
        a[i] = exp(-0.5 * pow((w[i] - wp) / (sigma[i] * w[i]), 2.0))
        sj[i] = 320 * pow(Hs, 2) * pow(w[i], -5.0) / pow(Tp, 4) * exp(-1950 * pow(w[i], -4) / pow(Tp, 4)) * pow(gamma, a[i])

    return sj

Original code:

def jonswap(w: np.ndarray, Hs: float, Tp: float, gamma: float = 3.7) -> np.ndarray:
    '''
    get Jonswap spectra
    :param w: np.ndarray Angular Frequency
    '''
    omega = w
    wp = 2 * np.pi / Tp
    sigma = np.where(omega < wp, 0.07, 0.09)
    a = np.exp(-0.5 * np.power((omega - wp) / (sigma * omega), 2.0))
    sj = 320 * np.power(Hs, 2) * np.power(omega, -5.0) / np.power(Tp, 4) * \
          np.exp(-1950 * np.power(omega, -4) / np.power(Tp, 4)) * np.power(gamma, a)

    return sj

【Comments】:

    Tags: python c performance cython


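    【Solution 1】:

    Your original code consists entirely of vectorized NumPy operations, so the room for improvement is limited. Running Cython with the annotate flag (-a) points to the following possible improvements.

    • Use the libc pow instead of the Python built-in
    • Turn off bounds checking and wraparound semantics
    • Use C division to turn off the division-by-zero check (if it is safe to do so!)

    New Cython version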

    from libc.math cimport exp, pow
    from libc.stdlib cimport malloc, free
    import numpy as np
    cimport numpy as np
    cimport cython
    
    DTYPE_float = np.float64
    ctypedef np.float64_t DTYPE_float_t
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    def cy_jonswap(np.ndarray[DTYPE_float_t, ndim=1, mode ='c'] w, double Hs, double Tp, double gamma = 3.7):
        '''
        get Jonswap spectra
        :param w: np.ndarray Angular Frequency
        '''
        cdef:
            int n = w.shape[0]
            # Scratch buffers for the per-element peak width and exponent
            double *sigma = <double*>malloc(n * sizeof(double))
            double *a = <double*>malloc(n * sizeof(double))
            int i
        cdef double wp
        cdef np.ndarray[DTYPE_float_t, ndim=1, mode='c'] sj = np.ones(n, dtype=DTYPE_float)
    
        wp = 2 * np.pi / Tp
        with nogil:
            for i in range(n):
                sigma[i] = 0.07 if w[i] < wp else 0.09
                a[i] = exp(-0.5 * pow((w[i] - wp) / (sigma[i] * w[i]), 2.0))
                sj[i] = 320 * pow(Hs, 2) * pow(w[i], -5.0) / pow(Tp, 4) * exp(-1950 * pow(w[i], -4) / pow(Tp, 4)) * pow(gamma, a[i])
    
        # Release the scratch buffers; without this every call leaks 2 * n doubles
        free(sigma)
        free(a)
        return sj
    

    Timings

    w = np.random.randn(1_000_000)
    
    %timeit cy_jonswap(w, .5, .5)
    289 ms ± 7.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    %timeit jonswap(w, .5, .5)
    411 ms ± 26.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Also, note that your Cython version leaks the memory allocated for sigma and a: the malloc calls are never matched by free.

    【Discussion】:

    • Most of the slow pow calls are unnecessary. For example, Hs*Hs is much faster than pow(Hs,2) (see stackoverflow.com/a/53172561/4045774). It would also help to write pow(w[i], -5) instead of pow(w[i], -5.0), or even 1./(w[i]*w[i]*w[i]*w[i]*w[i]).
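
    The comment's suggestion can be sketched in plain NumPy as follows. This is only an illustration under assumptions: it uses the same formula as the question's jonswap, and the names jonswap_ref and jonswap_mul as well as the sample frequency grid are hypothetical. It writes the integer powers as explicit multiplications and checks that the result agrees with the np.power version; it makes no claim about the speedup, which should be measured on your own machine.

```python
import numpy as np


def jonswap_ref(w, Hs, Tp, gamma=3.7):
    """Formulation from the question, using np.power throughout."""
    wp = 2 * np.pi / Tp
    sigma = np.where(w < wp, 0.07, 0.09)
    a = np.exp(-0.5 * np.power((w - wp) / (sigma * w), 2.0))
    return (320 * np.power(Hs, 2) * np.power(w, -5.0) / np.power(Tp, 4)
            * np.exp(-1950 * np.power(w, -4) / np.power(Tp, 4))
            * np.power(gamma, a))


def jonswap_mul(w, Hs, Tp, gamma=3.7):
    """Hypothetical variant: integer powers written as multiplications."""
    wp = 2 * np.pi / Tp
    sigma = np.where(w < wp, 0.07, 0.09)
    r = (w - wp) / (sigma * w)
    a = np.exp(-0.5 * r * r)      # r * r replaces power(r, 2.0)
    w2 = w * w
    w4 = w2 * w2                  # w**4 from two multiplications
    Tp2 = Tp * Tp
    Tp4 = Tp2 * Tp2               # Tp**4, computed once as a scalar
    return (320 * Hs * Hs / (w4 * w * Tp4)
            * np.exp(-1950 / (w4 * Tp4))
            * gamma ** a)         # non-integer exponent, still needs a power


# Positive angular frequencies only, so that w**-5 stays finite.
w = np.linspace(0.1, 3.0, 1000)
assert np.allclose(jonswap_ref(w, 2.0, 8.0), jonswap_mul(w, 2.0, 8.0))
```

    The same rewrite carries over to the loop body of the Cython version; only pow(gamma, a[i]) has a non-integer exponent and still needs a real pow call.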