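【Posted】: 2021-05-10 20:06:49
【Problem description】:
I am running some simulations, and I use numba to compile my Python code to speed them up. One of my functions overwrites one of its input arrays, so I want to pass in a copy of that array. However, doing so makes the code considerably slower, by far more than the time the copy itself takes.
Here are the timing results: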
> population_ = population.copy()
> %timeit _ = run_simulation(population_, Tmax, dt, Nskip = Nskip)
64.6 ms ± 215 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
> %timeit _ = run_simulation(population.copy(), Tmax, dt, Nskip = Nskip)
87.4 ms ± 778 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
> %timeit _ = population.copy()
442 ns ± 10.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
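So calling run_simulation directly with the result of .copy() as the argument is about 23 ms slower, even though making the copy takes only about 0.0004 ms. I don't understand why this happens.
For context, here is the full code: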
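The arithmetic behind those two figures can be checked with a short sketch. The numbers are copied from the %timeit output above; the variable names are only illustrative.

```python
# Timings copied from the %timeit output above, converted to milliseconds.
t_reused_ms = 64.6        # run_simulation(population_, ...)
t_fresh_copy_ms = 87.4    # run_simulation(population.copy(), ...)
t_copy_ms = 442e-6        # population.copy() alone: 442 ns = 0.000442 ms

# Extra time per call when a fresh copy is passed in
extra_ms = t_fresh_copy_ms - t_reused_ms

# How many times larger the extra time is than the copy itself
ratio = extra_ms / t_copy_ms

print(f"extra time per call: {extra_ms:.1f} ms")
print(f"copy alone:          {t_copy_ms:.6f} ms")
print(f"extra / copy:        {ratio:.0f}x")
```

The difference is several orders of magnitude larger than the cost of the copy, so the copy operation by itself cannot account for it.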
import numpy as np
from numba import jit, int32, int64, float64
@jit('int32[:,:,:](int32[:,:,:], float64)', nopython=True)
def one_step(population, dt):
    # Hard-coding model parameters here
    beta = 0.55
    tau = 10
    # This probability doesn't depend on the other states
    pIR = 1 - np.exp(-dt/tau)
    # Double for loop over towns and towns
    for i in range(population.shape[0]):
        I = np.sum(population[i,1,:])
        N = np.sum(population[i,:,:])
        # Transition probability from susceptible to infected
        pSI = 1 - np.exp(-dt*beta*I/N)
        for j in range(population.shape[1]):
            # Unpack variables for convenience
            S, I, R = population[i,j,:]
            S2I = np.random.binomial(S, pSI)
            I2R = np.random.binomial(I, pIR)
            # Calculate new values
            S = S - S2I
            I = I + S2I - I2R
            R = R + I2R
            population[i,j,:] = (S, I, R)
    return population
@jit('int32[:,:,:](int32[:,:,:], float64, float64, int64)', nopython=True)
def run_simulation(population, Tmax, dt, Nskip = 10):
    Nt = int(Tmax/dt)
    history = np.zeros((population.shape[0], 3, int((Tmax/dt)/Nskip) + 1), dtype = np.int32)
    history[:,:,0] = np.sum(population, axis = 1)
    t = 0
    for i in range(1, Nt+1):
        population = one_step(population, dt)
        t += dt
        if i % Nskip == 0:
            history[:,:,int(i/Nskip)] = np.sum(population, axis = 1)
    return history
# Initial state
population = np.random.randint(low = 0, high = 1000, size = (10,10,3), dtype = np.int32)
# Run simulation for 100 days
Tmax = 100
dt = 0.01
# Only store once per day
Nskip = int(1/dt)
# Call one timestep to compile numba-decorated functions
# prior to measuring timing
_ = run_simulation(population, 1.0, 1.0, Nskip = 1)
# Run timing
population_ = population.copy()
%timeit _ = run_simulation(population_, Tmax, dt, Nskip = Nskip)
# Run timing
%timeit _ = run_simulation(population.copy(), Tmax, dt, Nskip = Nskip)
# Run timing
%timeit _ = population.copy()
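Because run_simulation overwrites its input, the two timed calls differ in more than the cost of the copy: the first reuses one array that has already been modified by every earlier loop, while the second starts from the same initial state each time. The sketch below illustrates only that difference in starting state. It uses a hypothetical stand-in function, decay_in_place, which is not part of the code above and does no simulation; whether this effect explains the measured gap is an assumption that would need to be checked against the real run_simulation.

```python
import numpy as np

def decay_in_place(state):
    # Hypothetical stand-in for run_simulation: it overwrites its input
    # array in place and returns a summary value.
    state //= 2
    return state.sum()

rng = np.random.default_rng(0)
original = rng.integers(1, 1000, size=(10, 10, 3), dtype=np.int32)

# Pattern 1: one copy made up front and reused for every call,
# as in run_simulation(population_, ...)
reused = original.copy()
start_sums_reused = []
for _ in range(3):
    start_sums_reused.append(int(reused.sum()))  # state at the start of this call
    decay_in_place(reused)

# Pattern 2: a fresh copy for every call,
# as in run_simulation(population.copy(), ...)
start_sums_fresh = []
for _ in range(3):
    fresh = original.copy()
    start_sums_fresh.append(int(fresh.sum()))    # state at the start of this call
    decay_in_place(fresh)

print("reused array, starting sums:", start_sums_reused)
print("fresh copies, starting sums:", start_sums_fresh)
```

With the reused array, each call begins from wherever the previous call left off; with fresh copies, every call begins from the original values. Comparing population_ before and after a single timed call would show whether the same applies to the simulation.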
【Discussion】:
How large is the input data?
What are the dimensions of population?
The timings listed above were done with a size of 10 x 10 x 3, as shown in the code example.
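For reference, the size of that array can be inspected directly. This sketch rebuilds an array with the same shape and dtype as in the question; the values are random and do not matter here.

```python
import numpy as np

# Same construction as the initial state in the question
population = np.random.randint(low=0, high=1000, size=(10, 10, 3), dtype=np.int32)

print("shape: ", population.shape)   # dimensions of the array
print("size:  ", population.size)    # total number of elements
print("nbytes:", population.nbytes)  # memory used by the data buffer
```

At 300 int32 elements the buffer is 1200 bytes, which is consistent with the copy taking well under a microsecond.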