使用 Numba 的解决方案
如果您用到的所有功能都受 Numba 支持,这种改写在很多情况下都很容易。在您的代码中,`win = signal.detrend(win, type = 'linear')` 这部分需要在 Numba 中自行实现,因为 Numba 不支持该函数。
在 Numba 中实现去趋势(detrend)
查看 `scipy.signal.detrend` 的源代码并提取与您的问题相关的部分,可以得到如下实现:
@nb.njit()
def detrend(w):
    """Remove a least-squares linear trend from each row of a 2-D array.

    Numba port of the part of ``scipy.signal.detrend(w, type='linear')`` that
    the question needs. The fit runs along the last axis (SciPy's default
    ``axis=-1``) with the same regressors as SciPy: a ramp ``(1..N)/N`` and a
    constant.

    Parameters
    ----------
    w : 2-D float array
        Data to detrend; every row is fitted independently.

    Returns
    -------
    2-D array of the same shape as ``w``: ``w`` minus its fitted linear trend.
    """
    # Number of samples along the detrended (last) axis. Using shape[1]
    # (not shape[0]) lets non-square inputs work; square windows are unchanged.
    Npts = w.shape[1]
    # Design matrix: column 0 is the ramp (i+1)/Npts, column 1 the intercept.
    A = np.empty((Npts, 2), dtype=w.dtype)
    for i in range(Npts):
        A[i, 0] = 1. * (i + 1) / Npts
        A[i, 1] = 1.
    # Each column of w.T is one row of w, so all rows are fitted in one solve.
    # Only the coefficients are needed from lstsq.
    coef = np.linalg.lstsq(A, w.T)[0]
    out = w.T - np.dot(A, coef)
    return out.T
我还为 `np.max(np.isnan(win)) == 1` 实现了一个更快的替代方案:逐个元素检查,遇到第一个 NaN 就立即返回:
@nb.njit()
def isnan(win):
    """Report whether the 2-D array ``win`` contains at least one NaN.

    Elements are visited row by row, and the function returns True at the
    first NaN it meets. No boolean mask is allocated.
    """
    nrows, ncols = win.shape
    for r in range(nrows):
        row = win[r]
        for c in range(ncols):
            if np.isnan(row[c]):
                return True
    return False
主函数
由于这里使用了 Numba,并行化很简单:只需将外层循环改为 `nb.prange`,并在装饰器中设置 `parallel=True`。
import numpy as np
import numba as nb
@nb.njit(parallel=True)
def RMSH_det_nb(DEM, w):
    """RMS height of the linearly detrended moving window over a 2-D grid.

    For every interior cell (row, col) with ``w+1 <= row < nrows-w`` and
    ``w+1 <= col < ncols-w``, the window ``DEM[row-w:row+w-1, col-w:col+w-1]``
    is taken. If it contains a NaN, the result for that cell is NaN. Otherwise
    the window is passed through ``detrend`` and the sample standard deviation
    (``n - 1`` denominator) of the residual is stored. Rows are processed in
    parallel with ``nb.prange``.

    Returns an array shaped like ``DEM``; cells that are never visited stay NaN.
    """
    nrows, ncols = DEM.shape
    # NaN-filled output with DEM's shape and dtype.
    rms = DEM * np.nan
    for row in nb.prange(w + 1, nrows - w):
        for col in range(w + 1, ncols - w):
            window = DEM[row - w:row + w - 1, col - w:col + w - 1]
            if isnan(window):
                rms[row, col] = np.nan
                continue
            resid = detrend(window).flatten()
            n = resid.size
            rms[row, col] = np.sqrt(1 / (n - 1) * np.sum((resid - np.mean(resid)) ** 2))
    return rms
计时(小规模示例)
# Small benchmark: the reference implementation vs. the Numba version.
# NOTE(review): RMSH_det is not defined in this snippet; presumably it is the
# question's original NumPy/SciPy implementation -- confirm before running.
# %timeit is an IPython magic, so this cell only runs in IPython/Jupyter.
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
# The first call also triggers Numba's JIT compilation, so the %timeit below
# measures only the compiled code.
res2=RMSH_det_nb(DEM, w)
# Results must agree; NaN cells (border / NaN windows) count as equal.
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb(DEM, w) #approx. 55 times faster
#29 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
大型数组的计时
# Large-array timing for the Numba version only (1355 x 1165, about 1.6 million cells).
# %timeit is an IPython magic; run this in IPython/Jupyter.
w = 10
DEM=np.random.rand(1355, 1165).astype(np.float32)
%timeit res2=RMSH_det_nb(DEM, w)
#6.63 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[编辑] 使用正规方程的实现
参见:超定方程组(Overdetermined system)
此方法的数值精度较低(构造正规方程 AᵀA·x = Aᵀb 会使系数矩阵的条件数平方),但速度要快得多。
@nb.njit()
def isnan(win):
    """Return True if any element of the 2-D array ``win`` is NaN.

    Scans element by element and returns at the first NaN found. It can stop
    early instead of building a full boolean mask as ``np.isnan(win)`` does.
    """
    for i in range(win.shape[0]):
        for j in range(win.shape[1]):
            # NaN compares unequal to everything, itself included, so
            # ``win[i, j] == np.nan`` is always False; np.isnan is required.
            if np.isnan(win[i, j]):
                return True
    return False
@nb.njit()
def detrend(w):
    """Remove a least-squares linear trend from each row of a 2-D array.

    Numba port of the relevant part of ``scipy.signal.detrend(w, type='linear')``.
    The fit runs along the last axis (SciPy's default ``axis=-1``) with a ramp
    ``(1..N)/N`` and a constant as regressors.

    Parameters
    ----------
    w : 2-D float array
        Data to detrend; every row is fitted independently.

    Returns
    -------
    2-D array of the same shape as ``w``: ``w`` minus its fitted linear trend.
    """
    # Samples along the detrended (last) axis; shape[1] keeps non-square
    # inputs working while giving identical results for square windows.
    Npts = w.shape[1]
    # Design matrix: column 0 is the ramp (i+1)/Npts, column 1 the intercept.
    A = np.empty((Npts, 2), dtype=w.dtype)
    for i in range(Npts):
        A[i, 0] = 1. * (i + 1) / Npts
        A[i, 1] = 1.
    # Fit every row of w at once (rows of w are the columns of w.T);
    # only the coefficients are needed from lstsq.
    coef = np.linalg.lstsq(A, w.T)[0]
    out = w.T - np.dot(A, coef)
    return out.T
@nb.njit()
def detrend_2(w,T1,A):
    """Linear detrend of ``w`` by solving the normal equations.

    ``T1`` is the precomputed 2x2 matrix ``A.T @ A`` and ``A`` the design
    matrix. The trend coefficients come from ``T1 @ coef = A.T @ w.T``, and
    ``w`` minus the fitted trend is returned in ``w``'s orientation.
    """
    coef = np.linalg.solve(T1, np.dot(A.T, w.T))
    trend = np.dot(A, coef)
    return (w.T - trend).T
@nb.njit(parallel=True)
def RMSH_det_nb_normal_eq(DEM,w):
    """RMS height of the linearly detrended moving window, via normal equations.

    Uses the same windows and output layout as ``RMSH_det_nb``, but each window
    is detrended with ``detrend_2``. The design matrix ``A`` and ``A.T @ A``
    depend only on the window size, so they are built once here rather than
    once per window.

    Returns an array shaped like ``DEM``; unvisited cells and windows containing
    NaN are NaN.
    """
    nrows, ncols = np.shape(DEM)
    # All-NaN output with DEM's shape and dtype.
    rms = DEM * np.nan
    # Side length of the window slice [i-w, i+w-1): 2*w - 1 samples.
    npts = w * 2 - 1
    A = np.empty((npts, 2), dtype=DEM.dtype)
    for k in range(npts):
        A[k, 0] = 1. * (k + 1) / npts
        A[k, 1] = 1.
    T1 = np.dot(A.T, A)
    nz = npts ** 2
    for r in nb.prange(w + 1, nrows - w):
        for c in range(w + 1, ncols - w):
            window = DEM[r - w:r + w - 1, c - w:c + w - 1]
            if isnan(window):
                rms[r, c] = np.nan
                continue
            detrended = detrend_2(window, T1, A)
            rms[r, c] = np.sqrt(1 / (nz - 1) * np.sum((detrended - np.mean(detrended)) ** 2))
    return rms
计时
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
res2=RMSH_det_nb(DEM, w)
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb_normal_eq(DEM,w)
#7.97 ms ± 89.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
使用正规方程的优化实现
重用临时数组以避免代价高昂的内存分配,并使用自定义的矩阵乘法实现。这只推荐用于非常小的矩阵;在大多数其他情况下,np.dot(sgemm)会快得多。
@nb.njit()
def matmult_2(A,B,out):
    """Compute ``A @ B`` into ``out`` for an ``A`` with exactly two rows.

    Only rows 0 and 1 of ``A`` are read. ``out`` must have at least two rows
    and ``B.shape[1]`` columns; it is overwritten and returned. Each column is
    accumulated in two scalar running sums that start as float32 zeros.
    """
    n_inner, n_cols = B.shape
    for col in range(n_cols):
        s0 = nb.float32(0)
        s1 = nb.float32(0)
        for k in range(n_inner):
            b = B[k, col]
            s0 += A[0, k] * b
            s1 += A[1, k] * b
        out[0, col] = s0
        out[1, col] = s1
    return out
@nb.njit(fastmath=True)
def matmult_mod(A,B,w,out):
    """Write ``w - (A @ B).T`` into ``out`` and return it.

    Fused helper for the optimized detrend. ``A`` has two columns and only
    rows 0 and 1 of ``B`` are read, so each product entry is an inline
    two-term dot product. Subtracting the fitted trend from ``w`` yields the
    detrended data with the same sign convention as ``detrend_2``
    (data minus trend).

    Parameters
    ----------
    A : (m, 2) array
    B : array with at least 2 rows and n columns
    w : (n, m) array the transposed product is subtracted from
    out : (n, m) preallocated buffer; overwritten and returned
    """
    for j in range(B.shape[1]):
        for i in range(A.shape[0]):
            trend = A[i, 0] * B[0, j] + A[i, 1] * B[1, j]
            # data - trend, matching detrend_2's output sign.
            out[j, i] = w[j, i] - trend
    return out
@nb.njit()
def detrend_2_opt(w,T1,A,Tempvar_1,Tempvar_2):
    """Allocation-light counterpart of ``detrend_2`` using scratch buffers.

    As called from ``RMSH_det_nb_normal_eq_opt``, ``w`` is an (Npts, Npts)
    window, ``A`` the (Npts, 2) design matrix, ``T1 = A.T @ A``, and
    ``Tempvar_1`` / ``Tempvar_2`` are preallocated (2, Npts) / (Npts, Npts)
    buffers.

    The right-hand side ``A.T @ w.T`` is written into ``Tempvar_1`` by
    ``matmult_2``. After solving for the trend coefficients, ``matmult_mod``
    combines the fitted trend with ``w`` into ``Tempvar_2``, which is returned.
    ``np.linalg.solve`` still allocates the small coefficient array.
    """
    rhs = matmult_2(A.T, w.T, Tempvar_1)
    fit = np.linalg.solve(T1, rhs)
    return matmult_mod(A, fit, w, Tempvar_2)
@nb.njit(parallel=True)
def RMSH_det_nb_normal_eq_opt(DEM,w):
    """Optimized normal-equation RMS height of the detrended moving window.

    Uses the same windows and output as ``RMSH_det_nb_normal_eq``, but
    detrending goes through ``detrend_2_opt`` with scratch buffers. The buffers
    are allocated once per parallel row iteration and reused for every column
    of that row.

    Returns an array shaped like ``DEM``; unvisited cells and windows containing
    NaN are NaN.
    """
    nrows, ncols = np.shape(DEM)
    # All-NaN output with DEM's shape and dtype.
    rms = DEM * np.nan
    # Side length of the window slice [i-w, i+w-1): 2*w - 1 samples.
    npts = w * 2 - 1
    A = np.empty((npts, 2), dtype=DEM.dtype)
    for k in range(npts):
        A[k, 0] = 1. * (k + 1) / npts
        A[k, 1] = 1.
    T1 = np.dot(A.T, A)
    nz = npts ** 2
    for r in nb.prange(w + 1, nrows - w):
        # Created inside the prange body so each row iteration has its own buffers.
        rhs_buf = np.empty((2, npts), dtype=DEM.dtype)
        detrended_buf = np.empty((npts, npts), dtype=DEM.dtype)
        for c in range(w + 1, ncols - w):
            window = DEM[r - w:r + w - 1, c - w:c + w - 1]
            if isnan(window):
                rms[r, c] = np.nan
                continue
            detrended = detrend_2_opt(window, T1, A, rhs_buf, detrended_buf)
            rms[r, c] = np.sqrt(1 / (nz - 1) * np.sum((detrended - np.mean(detrended)) ** 2))
    return rms
计时
# Small benchmark for the optimized normal-equation variant.
# NOTE(review): RMSH_det is not defined in this snippet; presumably it is the
# question's original NumPy/SciPy implementation -- confirm before running.
# %timeit is an IPython magic, so this cell only runs in IPython/Jupyter.
w = 10
DEM=np.random.rand(100, 100).astype(np.float32)
res1=RMSH_det(DEM, w)
res2=RMSH_det_nb_normal_eq_opt(DEM, w)
# Agreement check against the reference; NaN cells count as equal.
print(np.allclose(res1,res2,equal_nan=True))
#True
%timeit res1=RMSH_det(DEM, w)
#1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit res2=RMSH_det_nb_normal_eq_opt(DEM,w)
#4.66 ms ± 87.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
isnan 的计时
这个函数的实现方式完全不同:它在遇到第一个 NaN 时立即返回,因此如果 NaN 位于数组开头,会快得多;即使数组中没有 NaN,也仍有一定的加速。我分别用小数组(约为窗口大小)和 @user3666197 建议的大数组进行了基准测试。
# Benchmark: early-exit Numba isnan (defined earlier in this answer) vs.
# np.any(np.isnan(...)), which always builds a full boolean mask.
# case_1: 20x20, all NaN
# case_2: 20x20, a single NaN at [10, 10]
# case_3: 20x20, no NaN
# case_4: 10000x10000 float64 (~800 MB), all NaN
# case_5: 10000x10000 float64 (~800 MB), no NaN
# %timeit is an IPython magic, so this cell only runs in IPython/Jupyter.
case_1=np.full((20,20),np.nan)
case_2=np.full((20,20),0.)
case_2[10,10]=np.nan
case_3=np.full((20,20),0.)
case_4 = np.full( ( int( 1E4 ), int( 1E4 ) ),np.nan)
case_5 = np.ones( ( int( 1E4 ), int( 1E4 ) ) )
%timeit np.any(np.isnan(case_1))
%timeit np.any(np.isnan(case_2))
%timeit np.any(np.isnan(case_3))
%timeit np.any(np.isnan(case_4))
%timeit np.any(np.isnan(case_5))
#2.75 µs ± 73.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#2.75 µs ± 46.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#2.76 µs ± 32.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#81.3 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#86.7 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit isnan(case_1)
%timeit isnan(case_2)
%timeit isnan(case_3)
%timeit isnan(case_4)
%timeit isnan(case_5)
#244 ns ± 5.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#357 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#475 ns ± 9.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#235 ns ± 0.933 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
#58.8 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)