【问题标题】:How can I use Cython well to solve a differential equation faster?如何更好地使用 Cython 更快地求解微分方程?
【发布时间】:2017-03-16 15:18:06
【问题描述】:

我想缩短 Scipy 的 odeint 求解微分的时间 方程。

为了练习,我使用Python in scientific computations 中的示例作为模板。因为 odeint 将函数 f 作为参数,所以我将此函数编写为静态类型的 Cython 版本并希望 odeint 的运行时间会显着减少。

函数f 包含在名为ode.pyx 的文件中,如下所示:

import numpy as np
cimport numpy as np
from libc.math cimport sin, cos

def f(y, t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
  return derivs

def fCMath(y, double t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
  return derivs

然后我创建一个文件setup.py 来编译函数:

from distutils.core import setup
from Cython.Build import cythonize

setup(ext_modules=cythonize('ode.pyx'))

求解微分方程的脚本(也包含 Python f) 的版本被称为 solveODE.py,看起来像:

import ode
import numpy as np
from scipy.integrate import odeint
import time

def f(y, t, params):
    theta, omega = y
    Q, d, Omega = params
    derivs = [omega,
             -omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
    return derivs

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))

start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))

start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))

然后我运行:

python setup.py build_ext --inplace
python solveODE.py 

在终端中。

python版本的时间约为0.055秒, 而 Cython 版本大约需要 0.04 秒。

是否有人建议我改进我解决问题的尝试 微分方程,最好不要用 Cython 修改 odeint 例程本身?

编辑

我在ode.pyxsolveODE.py 两个文件中加入了DavidW 的建议,运行这些建议的代码只用了大约0.015 秒。

【问题讨论】:

  • 您应该将其发布到 codereview
  • 我可能会尝试numba 而不是cython,但任何差异都可能很小。大部分计算时间可能是在odeint 调用您的函数时发生的上下文切换。老实说,您可能会看到编写自己的数值积分函数(再次使用 cython 或 numba)以避免上下文切换的最佳收益
  • @Farhan.K 我有点同意,但经验表明人们在这里得到的“加速 Cython”问题的答案比 codereview 更好,所以我不确定这是否是个好建议
  • @fabian 我没有通读源代码本身,但你的函数fode.f 是python 对象,每次调用至少需要一次上下文切换(4000 次调用 0-200以 0.05 为步长)否则 odeint 将无法使用任何旧的自定义用户功能。我已经用 numba 获得了 4 倍的加速,但我正在努力获得更多...
  • @Farhan.K 不要仅仅因为他们想让代码更快而建议 CodeReview。注意标签在相应板上的受欢迎程度。如果您想改进 C++ 或 Java 代码,CR 非常棒,但在处理像 Cython 这样的专用包时,CR 就差了很多。

标签: python scipy cython differential-equations scientific-computing


【解决方案1】:

最简单的更改(可能会让您受益匪浅)是使用 C 数学库 sincos 对单个数字而不是数字进行运算。对numpy 的调用以及确定它不是数组所花费的时间相当昂贵。

from libc.math cimport sin, cos

    # later
    -omega/Q + sin(theta) + d*cos(Omega*t)

我很想为输入 d 分配一个类型(在不更改界面的情况下,其他输入都不容易输入):

def f(y, double t, params):

我想我也会像您在 Python 版本中那样返回一个列表。我认为使用 C 数组不会有什么好处。

【讨论】:

  • 感谢您的建议!事实上,通过使用 C 数学库,代码相对于我的版本提高了大约 40%,并且总速度大约是 Python 代码的两倍。按照您的建议键入 t 可进一步将代码改进几个百分点。
【解决方案2】:

tldr;使用 numba.jit 实现 3 倍加速...

我对 cython 没有太多经验,但我的机器似乎为您的严格 python 版本获得了相似的计算时间,所以我们应该能够大致比较苹果和苹果。我使用numba 来编译函数f(我稍微重新编写了它以使其更好地与编译器配合使用)。

def f(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

numba_f = numba.jit(f)

numba_f 代替你的ode.f 给我这个输出...

The Python Code took: 0.0468 seconds
The Numba Code took: 0.0155 seconds

然后我想知道我是否可以复制 odeint 并使用 numba 编译以进一步加快速度...(我不能)

这是我的龙格-库塔数值微分方程积分器:

#function f is provided inline (not as an arg)
def runge_kutta(y0, steps, dt, args=()): #improvement on euler's method. *note: time steps given in number of steps and dt
    Y = np.empty([steps,y0.shape[0]])
    Y[0] = y0
    t = 0
    n = 0
    for n in range(steps-1):
        #calculate coeficients
        k1 = f(Y[n], t, args) #(euler's method coeficient) beginning of interval
        k2 = f(Y[n] + (dt * k1 / 2), t + (dt/2), args) #interval midpoint A
        k3 = f(Y[n] + (dt * k2 / 2), t + (dt/2), args) #interval midpoint B
        k4 = f(Y[n] + dt * k3, t + dt, args) #interval end point

        Y[n + 1] = Y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4) #calculate Y(n+1)
        t += dt #calculate t(n+1)
    return Y

简单的循环函数通常是编译后最快的,尽管这可能会被重新构造以获得更好的速度。我应该注意,这给出了与odeint 不同的答案,在大约 2000 步后偏差多达 0.001,并且在 3000 步之后完全不同。对于函数的 numba 版本,我只是将 f 替换为 @987654333 @,并添加了以@numba.jit 作为装饰器的编译。在这种情况下,正如预期的那样,纯 python 版本非常慢,但 numba 版本并不比带有 odeint 的 numba 快(再次,ymmv)。

using custom integrator
The Python Code took: 0.2340 seconds
The Numba Code took: 0.0156 seconds

这是一个提前编译的例子。我在这台计算机上没有编译所需的工具链,也没有管理员来安装它,所以这给了我一个错误,我没有所需的编译器,但它应该可以正常工作。

import numpy as np
from numba.pycc import CC

cc = CC('diffeq')

@cc.export('func', 'f8[:](f8[:], f8, f8[:])')
def func(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

cc.compile()

【讨论】:

  • 非常感谢您的详细解答!我从中学到了很多。不幸的是,当我使用@jit 声明时,我的代码被严重拖慢了。事实上,代码大约花了 2 秒。我的 python 版本附带 anaconda 4.3.1 并提供 numbda 0.30.1。对于如此缓慢的结果,您有什么想法吗?
  • @fabian 初始编译需要时间,但随后的每次运行都应该很快。有人已经在 cmets 主线程中提到了这一点。这就像在运行时而不是事先进行 cython 编译。 numba 确实有预编译支持,如果您查看他们的文档,但我从未使用过它。
  • @fabian 或使用cache=Truejit (numba.pydata.org/numba-doc/0.30.1/reference/…) 可能
  • @fabian 我添加了一个编译示例。我实际上在这台计算机上没有编译器,所以我现在无法测试它。 (缺少 vcvarsall.bat 错误)
【解决方案3】:

如果其他人使用其他模块回答这个问题,我不妨插一句:

我是JiTCODE 的作者,它接受用 SymPy 符号编写的 ODE,然后将此 ODE 转换为 Python 模块的 C 代码,编译此 C 代码,加载结果并将其用作 @987654322 的派生@。您翻译为 JiTCODE 的示例如下所示:

from jitcode import jitcode, provide_basic_symbols
import numpy as np
from sympy import sin, cos
import time

Q = 2.0
d = 1.5
Ω = 0.65

t, y = provide_basic_symbols()

f = [
    y(1),
    -y(1)/Q + sin(y(0)) + d*cos(Ω*t)
    ]

initial_state = np.array([0.0,0.0])

ODE = jitcode(f)
ODE.set_integrator("lsoda")
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
data = np.vstack(ODE.integrate(T) for T in np.arange(0.05, 200., 0.05))
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

这需要 0.11 秒,与基于 odeint 的解决方案相比,这速度非常慢,但这不是由于实际集成,而是由于处理结果的方式:虽然 odeint 直接在内部高效地创建了一个数组,这是通过 Python here 完成的。根据您所做的工作,这可能是一个关键的缺点,但这很快就会与更粗略的采样或更大的微分方程无关。

所以,让我们删除数据集合,只看一下集成,将最后几行替换为以下内容:

ODE = jitcode(f)
ODE.set_integrator("lsoda", max_step=0.05, nsteps=1e10)
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
ODE.integrate(200.0)
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

请注意,我设置 max_step=0.05 以强制积分器至少执行与您的示例中一样多的步骤,并确保唯一的区别是积分结果不会存储到某个数组中。这在 0.010 秒内运行。

【讨论】:

    【解决方案4】:

    NumbaLSODA 耗时 0.00088 秒(比 Cython 快 17 倍)。

    from NumbaLSODA import lsoda_sig, lsoda
    import numba as nb
    import numpy as np
    import time
    
    @nb.cfunc(lsoda_sig)
    def f(t, y_, dy, p_):
        p = nb.carray(p_, (3,))
        y = nb.carray(y_, (2,))
        theta, omega = y
        Q, d, Omega = p
        dy[0] = omega
        dy[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
    
    funcptr = f.address # address to ODE function
    y0 = np.array([0.0, 0.0])
    data = np.array([2.0, 1.5, 0.65])
    t = np.arange(0., 200., 0.05)
    
    start_time = time.time()
    usol, success = lsoda(funcptr, y0, t, data = data)
    print("NumbaLSODA took: %.8s seconds ---" % (time.time() - start_time))
    

    结果

    NumbaLSODA took: 0.000880 seconds ---
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-05-01
      • 1970-01-01
      • 2019-01-19
      • 1970-01-01
      • 2016-01-13
      • 2020-11-01
      相关资源
      最近更新 更多