循环并非总是不好的(尤其是当您需要循环时)。此外,没有任何工具或算法可以比 O(n) 更快。所以让我们做一个好的循环。
生成器函数
def cumsum_breach(x, target):
total = 0
for i, y in enumerate(x):
total += y
if total >= target:
yield i
total = 0
list(cumsum_breach(x, 10))
[4, 9]
使用 Numba 及时编译
Numba 是需要安装的第三方库。
Numba 对于支持哪些功能可能会很挑剔。但这行得通。
此外,正如 Divakar 所指出的,Numba 在数组方面表现更好
from numba import njit
@njit
def cumsum_breach_numba(x, target):
total = 0
result = []
for i, y in enumerate(x):
total += y
if total >= target:
result.append(i)
total = 0
return result
cumsum_breach_numba(x, 10)
测试两者
因为我喜欢它¯\_(ツ)_/¯
设置
np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()
准确度
i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))
assert i0 == i1
时间
%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))
582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Numba 大约快 100 倍。
为了更真实的苹果对苹果测试,我将列表转换为 Numpy 数组
%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))
43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
这让他们差不多。