【问题标题】:Numba @jit fails to optimise simple functionNumba @jit 无法优化简单功能
【发布时间】:2015-11-18 21:22:13
【问题描述】:

我有一个非常简单的函数,它使用 Numpy 数组和 for 循环,但是添加 Numba @jit 装饰器绝对不会加快速度:

# @jit(float64[:](int32,float64,float64,float64,int32))
@jit
def Ising_model_1D(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    beta = 1/T
    s = randn(N,1) > 10  
    s[N-1] = s[0]
    mag = zeros((n_iter,1))
    aux_idx =  randint(low=0,high=N,size=(n_iter,1))

    for i1 in arange(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx]*2 - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1

        if(delta_E < 0):
            s[rnd_idx] = np.logical_not(s[rnd_idx]) 
        elif(np.exp(-1*beta*delta_E) >= rand()):
            s[rnd_idx] = np.logical_not(s[rnd_idx])
        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N 
    return mag

另一方面,MATLAB 的运行时间不到 0.5 秒! 为什么 Numba 缺少一些如此基本的东西?

【问题讨论】:

  • 您正在循环体中的标量值上调用 NumPy 函数。这些函数旨在有效地处理大型数组,而不是单个值。 numba 无法优化这些函数调用。简而言之,您需要对代码进行矢量化,而不是 JIT 编译。
  • @ajcr 我认为其中一些实际上可以,例如rand()ndarray.sum()(至少在最新版本的numba 中可以)。
  • @jme:啊,谢谢,我不知道是这样的。我曾认为反复调用np.logical_not(和其他编译函数)会减慢循环速度。我应该更深入地研究 numba 文档。

标签: python numpy numba


【解决方案1】:

这是对您的代码的修改,在我的机器上运行大约需要 0.4 秒:

def ising_model_1d(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    n_iter = int(n_iter)
    beta = 1/T
    s = randn(N) > 10
    s[N-1] = s[0]

    mag = zeros(n_iter)
    aux_idx =  randint(low=0,high=N,size=n_iter)

    pre_rand = rand(n_iter)

    _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag)

    return mag


@jit(nopython=True)
def _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag):
    for i1 in range(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx*2] - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1
        t = rand()
        if delta_E < 0:
            s[rnd_idx] = not s[rnd_idx]
        elif np.exp(-1*beta*delta_E) >= pre_rand[i1]:
            s[rnd_idx] = not s[rnd_idx]

        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N

请确保结果符合预期!我改变了你的很多东西,不能保证计算是正确的!

使用numba 需要一点小心。 Python 函数以及大多数numpy 函数无法由编译器优化。我觉得有帮助的一件事是使用nopython 选项到@jit。这意味着只要你给它一些它不能真正优化的代码,编译器就会抱怨。然后,您可以查看错误消息并找到可能会降低您的代码速度的行。

我发现,诀窍是在 Python 中编写一个“网关”函数,该函数使用numpy 及其矢量化函数来完成尽可能多的工作。它应该创建你需要存储结果的空数组。它应该打包你在计算过程中需要的所有数据。然后它应该将所有这些传递到一个大而长的参数列表中的 jitter 函数中。

举个例子:注意我是如何在 jitted 代码中处理随机数生成的。在您的原始代码中,您调用了rand()

elif(np.exp(-1*beta*delta_E) >= rand()):

但是rand() 不能被numba 优化(至少在旧版本的numba 中是这样。在较新的版本中它可以,只要调用rand 时不带参数)。观察结果是,每次n_iter 迭代都需要一个随机数。所以我们只需在包装函数中使用numpy 创建一个随机数组,然后将这个随机数组提供给jitted 函数。获得一个随机数就像索引这个数组一样简单。

最后,关于可以被最新版本的编译器优化的numpy 函数列表,请参阅here。在我修改您的代码时,我积极地删除了对numpy 函数的调用,以便代码可以在更多版本的numba 上运行。

【讨论】:

  • 太棒了!感谢您的详细回复。 Numba 文档在细节上有点稀疏。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2015-05-08
  • 2019-01-13
  • 1970-01-01
  • 1970-01-01
  • 2018-01-18
  • 2015-12-25
相关资源
最近更新 更多