【发布时间】:2019-02-24 10:33:51
【问题描述】:
我正在尝试编写一个快速算法来计算log gamma function。目前我的实现看起来很幼稚,只是迭代了 1000 万次来计算 gamma 函数的日志(我也在使用 numba 来优化代码)。
import numpy as np
from numba import njit
EULER_MAS = 0.577215664901532 # euler mascheroni constant
HARMONC_10MIL = 16.695311365860007 # sum of 1/k from 1 to 10,000,000
@njit(fastmath=True)
def gammaln(z):
"""Compute log of gamma function for some real positive float z"""
out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
n = 10000000 # number of iters
for k in range(1,n+1,4):
# loop unrolling
v1 = np.log(1 + z/k)
v2 = np.log(1 + z/(k+1))
v3 = np.log(1 + z/(k+2))
v4 = np.log(1 + z/(k+3))
out -= v1 + v2 + v3 + v4
return out
我根据scipy.special.gammaln 实现对我的代码进行了计时,而我的代码实际上慢了 100,000 倍。所以我在做一些非常错误或非常幼稚的事情(可能两者兼而有之)。尽管与 scipy 相比,我的答案至少在小数点后 4 位以内是正确的。
我试图阅读实现 scipy 的 gammaln 函数的 _ufunc 代码,但是我不明白 _gammaln 函数所用的 cython 代码。
是否有更快、更优化的方法可以计算对数伽玛函数?我如何理解 scipy 的实现,以便将其与我的结合起来?
【问题讨论】:
-
z的示例输入是什么?我不知道公式,但这并不意味着人们不能尝试对其进行矢量化——不过,我们需要知道如何调用函数进行测试。 -
另外,如果我们谈论的是比 Scipy 慢 100,000 倍,请确保使用示例输入运行它不会花费我们很多时间 :)
-
.@roganjosh 在我的机器上运行带有
1参数的函数大约需要 50 毫秒,所以我想这会是安全的 -
@user8408080 oki doki。你知道输入应该是一个int还是一个数组?
-
据我所知,它可以是任何复数(参见here)。但只有一个数字
标签: python performance math scipy gamma-function