【发布时间】:2020-03-08 07:46:54
【问题描述】:
我希望有效地计算应在数组中插入元素以保持顺序的索引,但包括一个小数部分,表示数组中两个最近点之间的“距离”。
应该可以使用索引和分数取回原始值。在实践中,以及性能很重要的原因,我需要对大量数据点执行此操作。
为了证明我的意思,我通过np.searchsorted 和一些if 语句提出了一些工作逻辑,但无法使用NumPy 对逻辑进行矢量化。我也很高兴看到一个有效的解决方案,它利用 numba 并具有与 NumPy 相当或更好的性能。甚至是我不知道的 NumPy、Scipy 等现成的解决方案。
我还在下面包含了一些基准测试代码。
import numpy as np
np.random.seed(0)
datapoint = np.random.random() * np.random.choice([1, -1]) * 500 # -274.4067
line = np.linspace(-500, 500, 101) # [-500, -490, ... , 0, ..., 490, 500] - an ordered array, may not be linspace
def get_position(line, point):
position = np.searchsorted(line, point, side='right')
size = line.shape[0]
if position == 0:
main = 0
fraction = 0
elif position == size:
main = size-1
fraction = 0
else:
main = position - 1
fraction = (point - line[position-1]) / (line[position] - line[position-1])
return main, fraction
idx, frac = get_position(line, datapoint) # (22, 0.55932480363376269)
print(line[idx] + frac * (line[idx + 1] - line[idx])) # -274.4067; test to see if you get back original value
def run_multiple(line, data):
out = np.empty((data.shape[0], 3))
for i in range(data.shape[0]):
idx, frac = get_position(line, data[i])
out[i, 0] = data[i]
out[i, 1] = idx
out[i, 2] = frac
return out
基准测试
# Python 3.6.0, NumPy 1.11.3, Numba 0.30.1
# Note: Numba 0.30.1 does not support "side" argument of np.searchsorted; not able to upgrade
n = 10**5 # Actual n will be larger
res = run_multiple(line, np.random.random(n) * np.random.choice([1, -1], n) * 500) # 901 ms per loop
# array([[ -4.22132874e+02, 7.00000000e+00, 7.86712571e-01],
# [ -4.28972809e+02, 7.00000000e+00, 1.02719119e-01],
# [ 4.23625869e+02, 9.20000000e+01, 3.62586939e-01],
# ...,
# [ -1.88627877e+02, 3.10000000e+01, 1.37212282e-01],
# [ 4.98162640e+01, 5.40000000e+01, 9.81626397e-01],
# [ 1.35777097e+02, 6.30000000e+01, 5.77709684e-01]])
【问题讨论】:
-
另外,现在 numba does support
right参数searchsorted。您还在使用旧版本吗? -
@user3483203,不幸的是,现在,是的,我在基准测试部分添加了代码注释,应该更多地突出显示它。
-
为什么不使用新 Numba 版本的实现?它在第 3347 行 github.com/numba/numba/blob/master/numba/targets/arraymath.py
标签: python arrays algorithm performance numpy