我开发了一个numpy-only 可以工作的版本,但是经过测试,我发现它的性能很差,因为它不能利用short-circuiting。既然这是你要求的,我在下面描述它。但是,使用numba 和稍微修改过的代码版本有一个好多 更好的方法。 (请注意,所有这些都返回a 中第一个匹配项的索引,而不是值本身。我发现这种方法更灵活。)
@numba.jit(nopython=True)
def find_reps_numba(a, max_len):
streak = 1
val = a[0]
for i in range(1, len(a)):
if a[i] == val:
streak += 1
if streak >= max_len:
return i - max_len + 1
else:
streak = 1
val = a[i]
return -1
事实证明,这比纯 Python 版本快约 100 倍。
numpy 版本使用rolling window trick 和argmax trick。但同样,这甚至比纯 Python 版本慢得多,大约 30 倍。
def rolling_window(a, window):
a = numpy.ascontiguousarray(a) # This approach requires a C-ordered array
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
strides = a.strides + (a.strides[-1],)
return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
def find_reps_numpy(a, max_len):
windows = rolling_window(a, max_len)
return (windows == windows[:, 0:1]).sum(axis=1).argmax()
我针对第一个函数的非 jitted 版本测试了这两个函数。 (我使用 Jupyter 的 %%timeit 功能进行测试。)
a = numpy.random.randint(0, 100, 1000000)
%%timeit
find_reps_numpy(a, 3)
28.6 ms ± 553 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
find_reps_orig(a, 3)
4.04 ms ± 40.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
find_reps_numba(a, 3)
8.29 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
请注意,这些数字可能会有很大差异,具体取决于函数必须搜索的a 的深度。为了更好地估计预期性能,我们可以每次都重新生成一组新的随机数,但是如果不将那一步包含在时序中,就很难做到这一点。因此,为了在这里进行比较,我将生成随机数组所需的时间包括在内而不运行其他任何东西:
a = numpy.random.randint(0, 100, 1000000)
9.91 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
a = numpy.random.randint(0, 100, 1000000)
find_reps_numpy(a, 3)
38.2 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
a = numpy.random.randint(0, 100, 1000000)
find_reps_orig(a, 3)
13.7 ms ± 404 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
a = numpy.random.randint(0, 100, 1000000)
find_reps_numba(a, 3)
9.87 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
如您所见,find_reps_numba 速度如此之快,以至于运行numpy.random.randint(0, 100, 1000000) 所需的时间差异要大得多——因此第一次和最后一次测试之间的加速是虚幻的。
所以这个故事的主要寓意是numpy 解决方案并不总是最好的。有时甚至纯 Python 也更快。在这些情况下,nopython 模式下的numba 可能是迄今为止最好的选择。