我手写的归并排序经过 numba.njit 编译后,速度几乎是纯 Python 方法的 10 倍。请注意,它在技术上是一个 argsort:您需要传入一个索引数组,返回的是一组索引,把这组索引应用到原数组上即可得到排序后的结果。
@numba.njit
def ceil_log2(x):
    """Return the smallest non-negative integer ``n`` with ``2**n >= x``.

    Any ``x`` that is at most 1 yields 0.
    """
    exponent = 0
    power = 1  # invariant: power == 2**exponent
    while power < x:
        power *= 2
        exponent += 1
    return exponent
@numba.njit
def _merge_runs(arr, mid, relation):
    """Merge the sorted runs ``arr[:mid]`` and ``arr[mid:]`` in place.

    ``arr`` holds integer indices and ``relation[a, b]`` is truthy when index
    ``a`` may be placed before index ``b``. On ties the element from the left
    run is taken first, so the merge is stable.
    """
    n = len(arr)
    out = np.empty_like(arr)
    left = 0
    right = mid
    for pos in range(n):
        if left == mid:
            # Left run exhausted: the rest of the right run is already ordered.
            out[pos:] = arr[right:]
            break
        elif right == n:
            # Right run exhausted: copy the remainder of the left run.
            out[pos:] = arr[left:mid]
            break
        elif relation[arr[left], arr[right]]:
            out[pos] = arr[left]
            left += 1
        else:
            out[pos] = arr[right]
            right += 1
    # Write back through the view so the caller's array is updated.
    arr[:] = out


@numba.njit
def merge(arr, relation):
    """Merge the two sorted halves of ``arr`` (split at ``len(arr) // 2``) in place."""
    _merge_runs(arr, len(arr) // 2, relation)


@numba.njit
def merge_sort(arr, relation):
    """Sort the index array ``arr`` in place by ``relation`` and return it.

    This is a bottom-up merge sort: sorted runs of width 1, 2, 4, ... are
    merged pairwise until a single run covers the whole array.

    Args:
        arr: 1-D integer array of indices into ``relation``; modified in place.
        relation: 2-D array where ``relation[a, b]`` is truthy when index
            ``a`` should be ordered before (or equal to) index ``b``.

    Returns:
        The same ``arr`` object, now sorted.
    """
    n = len(arr)
    width = 1
    while width < n:
        for start in range(0, n, 2 * width):
            mid = start + width
            # A trailing run with no right-hand partner is already sorted.
            if mid < n:
                stop = min(mid + width, n)
                # Split at the true run boundary (``width``), not at the
                # chunk's midpoint: the last chunk may be shorter than
                # 2 * width when n is not a power of two.
                _merge_runs(arr[start:stop], width, relation)
        width *= 2
    return arr
def example_relation(arr):
    """Build the pairwise "less than or equal" matrix of a 1-D sequence.

    Args:
        arr: 1-D array (or sequence) of mutually comparable values.

    Returns:
        Boolean array ``rel`` of shape ``(len(arr), len(arr))`` with
        ``rel[i, j] == (arr[i] <= arr[j])``.
    """
    arr = np.asarray(arr)
    # Broadcasting a column against a row yields the full comparison grid
    # without materialising two n-by-n integer index arrays first.
    return arr[:, None] <= arr[None, :]
# Benchmark input: seed the global NumPy RNG so the data is reproducible,
# draw 2**14 standard-normal samples, and precompute their pairwise
# comparison matrix (relation[i, j] is True when array[i] <= array[j]).
np.random.seed(0)
array = np.random.normal(size=2**14)
relation = example_relation(array)
def compare(i, j):
    """Three-way comparison of indices ``i`` and ``j`` via the global ``relation``.

    Intended for use with ``functools.cmp_to_key``.

    Returns:
        -1 if ``relation[i, j]`` holds but ``relation[j, i]`` does not,
        0 if the relation holds in both directions (the two indices tie),
        1 otherwise.
    """
    if relation[i, j]:
        # A relation that holds both ways is a tie; reporting 0 keeps the
        # comparator antisymmetric so that sorted() stays stable on ties.
        return 0 if relation[j, i] else -1
    return 1
%time sorted(np.arange(len(array)), key=cmp_to_key(compare))
%time merge_sort(np.arange(len(array)), relation)
在我的机器上得到的结果如下(第一行是纯 Python 的 sorted,第二行是 numba 归并排序):
61.5 ms ± 785 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
6.98 ms ± 43.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)