【问题标题】:Fast algorithm for log gamma function对数伽玛函数的快速算法
【发布时间】:2019-02-24 10:33:51
【问题描述】:

我正在尝试编写一个快速算法来计算log gamma function。目前我的实现看起来很幼稚,只是迭代了 1000 万次来计算 gamma 函数的日志(我也在使用 numba 来优化代码)。

import numpy as np
from numba import njit
EULER_MAS = 0.577215664901532 # euler mascheroni constant
HARMONC_10MIL = 16.695311365860007 # sum of 1/k from 1 to 10,000,000

@njit(fastmath=True)
def gammaln(z):
"""Compute log of gamma function for some real positive float z"""
    out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
    n = 10000000 # number of iters
    for k in range(1,n+1,4):
        # loop unrolling
        v1 = np.log(1 + z/k)
        v2 = np.log(1 + z/(k+1))
        v3 = np.log(1 + z/(k+2))
        v4 = np.log(1 + z/(k+3))
        out -= v1 + v2 + v3 + v4

    return out

我根据scipy.special.gammaln 实现对我的代码进行了计时,而我的代码实际上慢了 100,000 倍。所以我在做一些非常错误或非常幼稚的事情(可能两者兼而有之)。尽管与 scipy 相比,我的答案至少在小数点后 4 位以内是正确的。

我试图阅读实现 scipy 的 gammaln 函数的 _ufunc 代码,但是我不明白 _gammaln 函数所用的 cython 代码。

是否有更快、更优化的方法可以计算对数伽玛函数?我如何理解 scipy 的实现,以便将其与我的结合起来?

【问题讨论】:

  • z 的示例输入是什么?我不知道公式,但这并不意味着人们不能尝试对其进行矢量化——不过,我们需要知道如何调用函数进行测试。
  • 另外,如果我们谈论的是比 Scipy 慢 100,000 倍,请确保使用示例输入运行它不会花费我们很多时间 :)
  • .@roganjosh 在我的机器上运行带有 1 参数的函数大约需要 50 毫秒,所以我想这会是安全的
  • @user8408080 oki doki。你知道输入应该是一个int还是一个数组?
  • 据我所知,它可以是任何复数(参见here)。但只有一个数字

标签: python performance math scipy gamma-function


【解决方案1】:

您的函数的运行时间将随着迭代次数线性扩展(直至一些恒定的开销)。所以减少迭代次数是加速算法的关键。虽然预先计算HARMONIC_10MIL 是一个聪明的想法,但当您截断系列时,它实际上会导致更差的准确性;只计算系列中的一部分结果可以提供更高的准确度。

下面的代码是上面发布的代码的修改版本(尽管使用cython 而不是numba)。

from libc.math cimport log, log1p
cimport cython
cdef:
    float EULER_MAS = 0.577215664901532 # euler mascheroni constant

@cython.cdivision(True)
def gammaln(float z, int n=1000):
    """Compute log of gamma function for some real positive float z"""
    cdef:
        float out = -EULER_MAS*z - log(z)
        int k
        float t
    for k in range(1, n):
        t = z / k
        out += t - log1p(t)

    return out

如下图所示,即使经过 100 次近似,也能得到一个接近的近似值。

在 100 次迭代时,它的运行时间与 scipy.special.gammaln 处于同一数量级:

%timeit special.gammaln(5)
# 932 ns ± 19 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit gammaln(5, 100)
# 1.25 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

剩下的问题当然是要使用多少次迭代。函数log1p(t) 可以扩展为小t 的泰勒级数(这与大k 的极限有关)。特别是,

log1p(t) = t - t ** 2 / 2 + ...

这样,对于大的k,总和的参数变为

t - log1p(t) = t ** 2 / 2 + ...

因此,在t 中,和的参数在二阶之前为零,如果t 足够小,则可以忽略不计。也就是说,迭代次数至少应该和z一样大,最好至少大一个数量级。

但是,如果可能的话,我会坚持使用 scipy 的经过充分测试的实现。

【讨论】:

  • 很好的答案,在您的示例中它似乎确实运行得很快!一个愚蠢的问题,我如何获得 libc.math 库?我已经 pip 安装了 Cython,但似乎找不到 libc.math 库。
  • libc.math 我认为应该默认包含在内。但是,我经常犯错误写 import 而不是 cimport 用于 cython 包含。这可能是问题吗?
  • 它似乎不适用于上面代码中importcimport 的任何排列...我可以运行import cython 但不能运行from libc.math cimport ...(这会导致语法错误)或from libc.math import ...(这会导致 ModuleNotFoundError)
  • 我的错误@Till Hoffmann,我在没有正确设置的情况下在 jupyter notebook 中运行它。得到它的工作。非常感谢!
【解决方案2】:

通过尝试 numba 的并行模式并主要使用矢量化函数,我设法将性能提高了大约 3 倍(遗憾的是,numba 无法理解 numpy.substract.reduce

from functools import reduce
import numpy as np
from numba import njit

@njit(fastmath=True, parallel=True)
def gammaln_vec(z):
    out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
    n = 10000000

    v = np.log(1 + z/np.arange(1, n+1))

    return out-reduce(lambda x1, x2: x1-x2, v, 0)

次:

#Your function:
%timeit gammaln(1.5)
48.6 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

#My function:
%timeit gammaln_vec(1.5)
15 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

#scpiy's function
%timeit gammaln_sp(1.5)
1.07 µs ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

因此,使用 scipy 的功能会更好。如果没有 C 代码,我不知道如何进一步分解它

【讨论】:

    【解决方案3】:

    关于您之前的问题,我想一个将 scipy.special 函数包装到 Numba 的示例也很有用。

    示例

    只要只涉及简单的数据类型(int、double、double*、...),包装 Cython cdef 函数就非常容易和可移植。有关如何调用 scipy.special 函数 have a look at this 的文档。您实际需要包装函数的函数名称在scipy.special.cython_special.__pyx_capi__ 中。可以使用不同数据类型调用的函数名称被破坏,但确定正确的名称非常容易(只需查看数据类型)

    #slightly modified version of https://github.com/numba/numba/issues/3086
    from numba.extending import get_cython_function_address
    from numba import vectorize, njit
    import ctypes
    import numpy as np
    
    _PTR = ctypes.POINTER
    _dble = ctypes.c_double
    _ptr_dble = _PTR(_dble)
    
    addr = get_cython_function_address("scipy.special.cython_special", "gammaln")
    functype = ctypes.CFUNCTYPE(_dble, _dble)
    gammaln_float64 = functype(addr)
    
    @njit
    def numba_gammaln(x):
      return gammaln_float64(x)
    

    在 Numba 中的使用

    #Numba example with loops
    import numba as nb
    import numpy as np
    @nb.njit()
    def Test_func(A):
      out=np.empty(A.shape[0])
      for i in range(A.shape[0]):
        out[i]=numba_gammaln(A[i])
      return out
    

    时间安排

    data=np.random.rand(1_000_000)
    Test_func(A): 39.1ms
    gammaln(A):   39.1ms
    

    当然,您可以轻松地并行化此函数,并在 scipy 中超越单线程 gammaln 实现,并且您可以在任何 Numba 编译函数中有效地调用此函数。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-03-18
      • 1970-01-01
      • 1970-01-01
      • 2019-01-18
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多