我很确定最昂贵的东西就是 logs 和那些你可以保存的东西(即调用一次而不是 sig.size - 在这个例子中 = 1000 - 次)。基础身份是log(a) + log(b) = log(a*b)。让我的速度提高了三倍。
顺便说一句。我将产生空数组的random.rand(0, 1, 1000) 更改为random.rand(1000)。另外我认为sig=sqrt(res) 不是一般假设,所以我没有尝试利用它。
我做了一些其他优化,但它们的效果远不及摆脱 logs 的影响。 1) 预计算 log(2*pi) 2) 使用 np.dot 进行最后一项。
更新:
正如@JPdL 指出的那样,这种方法容易受到上溢和下溢的影响。可以使用分块来交换各种性能以防止出现这种情况,请参阅下面的f4. f5. f6。
import numpy as np
np.random.seed(1)
res = np.random.rand(1000)
f1 = lambda r, sig: -0.5*(np.sum(2*np.log(sig) + np.log(2*np.pi) + (r/sig)**2))
def f2(r, sig):
return -0.5*(np.sum(2*np.log(sig) + np.log(2*np.pi) + (r/sig)**2))
# gung ho
def f3(r, sig, precomp=np.log(2*np.pi)):
rs = r/sig
return -np.log(np.prod(sig)) - 0.5*(r.size * precomp + np.dot(rs, rs))
# chunk and hope for the best
def f4(r, sig, chnk=32, precomp=np.log(2*np.pi)):
rs = r/sig
sumlogsig = np.log(np.multiply.reduceat(sig, np.arange(0, len(r), chnk))).sum()
return -sumlogsig - 0.5*(r.size * precomp + np.dot(rs, rs))
# chunk and check for extreme values
def f5(r, sig, chnk=32,
precomp=np.log(2*np.pi), precomp2=np.exp([-8, 8])):
rs = r/sig
bad = np.where((sig<precomp2[0]) | (sig>precomp2[1]))[0]
sumlogsig = np.log(sig[bad]).sum()
keep = sig[bad]
sig[bad] = 1
sumlogsig += np.log(np.multiply.reduceat(sig, np.arange(0, len(r), chnk))).sum()
sig[bad] = keep
return -sumlogsig - 0.5*(r.size * precomp + np.dot(rs, rs))
# chunk and try to be extra clever
def f6(r, sig, chnk=512,
precomp=np.log(2*np.pi), precomp2=np.exp(np.arange(-18, 19))):
binidx = np.digitize(sig, precomp2[1::2])<<1
rs = r/sig
sig = sig*precomp2[36 - binidx]
bad = np.where((binidx==0) | (binidx==36))[0]
sumlogsig = binidx.sum() - 18*r.size + np.log(sig[bad]).sum()
sig[bad] = 1
sumlogsig += np.log(np.multiply.reduceat(sig, np.arange(0, len(r), chnk))).sum()
return -sumlogsig - 0.5*(r.size * precomp + np.dot(rs, rs))
from timeit import timeit
sr = np.sqrt(res)
print(timeit('f1(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f1':f1}))
print(timeit('f2(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f2':f2}))
print(timeit('f3(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f3':f3}))
print(timeit('f4(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f4':f4}))
print(timeit('f5(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f5':f5}))
print(timeit('f6(res,sr)', number=100, globals={'res':res, 'np':np, 'sr':sr, 'f6':f6}))
print(f1(res,np.sqrt(res)))
print(f2(res,np.sqrt(res)))
print(f3(res,np.sqrt(res)))
print(f4(res,np.sqrt(res)))
print(f5(res,np.sqrt(res)))
print(f6(res,np.sqrt(res)))
样本输出:
0.004246247990522534
0.00418912700843066
0.0011273059935774654
0.0013386670034378767
0.0022679700050503016
0.004274581006029621
-662.250886322
-662.250886322
-662.250886322
-662.250886322
-662.250886322
-662.250886322