【发布时间】:2015-05-12 07:31:51
【问题描述】:
在过去的一周里,我一直在询问有关此堆栈的相关问题,以尝试找出我不了解的有关在 Python 中将 @jit 装饰器与 Numba 一起使用的问题。但是,我碰壁了,所以我会写整个问题。
当前的问题是计算大量 段对之间的最小距离。分段由它们的 3D 起点和终点表示。在数学上,每个段都被参数化为 [AB] = A + (B-A)*s,其中 s 在 [0,1] 中,A 和 B 是段的起点和终点。对于两个这样的段,可以计算最小距离,并给出公式here。
我已经在另一个thread 上暴露了这个问题,并且给出的答案是通过向量化问题来替换我的代码的双循环,但这会导致大量段的内存问题。因此,我决定坚持使用循环,并改用 numba 的 jit。
由于解决问题需要很多点积,而numpy的点积是not supported by numba,所以我从实现自己的3D点积开始。
import numpy as np
from numba import jit, autojit, double, float64, float32, void, int32
def my_dot(a,b):
res = a[0]*b[0]+a[1]*b[1]+a[2]*b[2]
return res
dot_jit = jit(double(double[:], double[:]))(my_dot) #I know, it's not of much use here.
计算 N 段中所有对的最小距离的函数将 Nx6 数组(6 个坐标)作为输入
def compute_stuff(array_to_compute):
N = len(array_to_compute)
con_mat = np.zeros((N,N))
for i in range(N):
for j in range(i+1,N):
p0 = array_to_compute[i,0:3]
p1 = array_to_compute[i,3:6]
q0 = array_to_compute[j,0:3]
q1 = array_to_compute[j,3:6]
s = ( dot_jit((p1-p0),(q1-q0))*dot_jit((q1-q0),(p0-q0)) - dot_jit((q1-q0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2 )
t = ( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(p0-q0)) -dot_jit((p1-p0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2 )
con_mat[i,j] = np.sum( (p0+(p1-p0)*s-(q0+(q1-q0)*t))**2 )
return con_mat
fast_compute_stuff = jit(double[:,:](double[:,:]))(compute_stuff)
因此,compute_stuff(arg) 将 2D np.array (double[:,:]) 作为参数,执行一堆 numba 支持的 (?) 操作,并返回另一个 2D np.array (double[:, :])。然而,
v = np.random.random( (100,6) )
%timeit compute_stuff(v)
%timeit fast_compute_stuff(v)
每个循环我得到 134 和 123 毫秒。你能解释一下为什么我不能加快我的功能吗?任何反馈将不胜感激。
【问题讨论】:
-
使用 numba 的 JIT 编译器非常不太可能击败
np.dot。np.dot只是一个瘦包装器,它调用 BLAS*gemm/*gemv函数,这些函数经过大量优化并且通常是多线程的。您最好的选择可能是确保 numpy 与您可以获得的最快的多线程 BLAS 库链接(可能是英特尔的 MKL 或 OpenBLAS)。 -
问题不在于 np.dot,问题是如果 jit 编译器遇到 np.dot 调用,它无法推断其返回类型,然后不会加速我的整个函数(顺便说一句,对于 3d 矢量标量产品,我编码的 dot_jit 比 np.dot 快)
-
您是否对原始代码进行了线路分析?我怀疑你大部分时间都在
np.dot中度过,所以不应该期望从嵌套for循环的开销中通过JIT 来获得太多性能优势。 -
使用 cProfile,我看到对于 1000 个片段,我在这些深度操作(np.dot、np.sum 等)中花费了 13 秒中的大约 1 秒的累积时间
-
好的,再次查看您的代码,我意识到那是因为您点的向量只有 2 长!你能在你的问题中发布线路分析时间吗?