元素之间具有恒定步长的数组
如果是 range 或任何其他线性增加的数组,您可以简单地以编程方式计算索引,根本不需要实际迭代数组:
def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('no value greater than {}'.format(val))
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
first_value = arr[0]
step = arr[1] - first_value
# For linearly decreasing arrays or constant arrays we only need to check
# the first element, because if that does not satisfy the condition
# no other element will.
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
calculated_position = (val - first_value) / step
if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))
return int(calculated_position) + 1
也许可以稍微改进一下。我已经确保它适用于一些示例数组和值,但这并不意味着其中不会有错误,尤其是考虑到它使用浮点数...
>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16] # double check
6
>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15
鉴于它可以在没有任何迭代的情况下计算位置,它将是恒定时间 (O(1)),并且可能会击败所有其他提到的方法。但是它需要数组中的一个恒定步长,否则会产生错误的结果。
使用 numba 的一般解决方案
更通用的方法是使用 numba 函数:
@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1
这适用于任何数组,但它必须遍历数组,所以在平均情况下它将是O(n):
>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16
基准测试
尽管 Nico Schlömer 已经提供了一些基准,但我认为包含我的新解决方案并测试不同的“值”可能会很有用。
测试设置:
import numpy as np
import math
import numba as nb
def first_index_using_argmax(val, arr):
return np.argmax(arr > val)
def first_index_using_where(val, arr):
return np.where(arr > val)[0][0]
def first_index_using_nonzero(val, arr):
return np.nonzero(arr > val)[0][0]
def first_index_using_searchsorted(val, arr):
return np.searchsorted(arr, val) + 1
def first_index_using_min(val, arr):
return np.min(np.where(arr > val))
def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('empty array')
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
first_value = arr[0]
step = arr[1] - first_value
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))
calculated_position = (val - first_value) / step
if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))
return int(calculated_position) + 1
@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1
funcs = [
first_index_using_argmax,
first_index_using_min,
first_index_using_nonzero,
first_index_calculate_range_like,
first_index_numba,
first_index_using_searchsorted,
first_index_using_where
]
from simple_benchmark import benchmark, MultiArgument
并且这些图是使用以下方法生成的:
%matplotlib notebook
b.plot()
项目在开头
b = benchmark(
funcs,
{2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
numba 函数的性能最好,其次是 calculate-function 和 searchsorted 函数。其他解决方案的性能要差得多。
项目在末尾h2>
b = benchmark(
funcs,
{2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
对于小型数组,numba 函数的执行速度惊人地快,但是对于较大的数组,它的计算函数和 searchsorted 函数的表现要好。
项目位于 sqrt(len)
b = benchmark(
funcs,
{2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")
这更有趣。同样 numba 和 calculate 函数表现出色,但这实际上触发了最坏的 searchsorted 情况,在这种情况下确实不能正常工作。
没有值满足条件时的函数比较
另一个有趣的点是,如果没有应该返回其索引的值,这些函数的行为:
arr = np.ones(100)
value = 2
for func in funcs:
print(func.__name__)
try:
print('-->', func(value, arr))
except Exception as e:
print('-->', e)
有了这个结果:
first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0
Searchsorted、argmax 和 numba 只是返回错误值。但是 searchsorted 和 numba 返回的索引不是数组的有效索引。
函数where、min、nonzero 和calculate 抛出异常。但是,只有 calculate 的例外实际上说明了任何有用的信息。
这意味着实际上必须将这些调用包装在一个适当的包装函数中,该函数捕获异常或无效返回值并进行适当处理,至少在您不确定该值是否可以在数组中的情况下。
注意:calculate 和searchsorted 选项仅在特殊条件下有效。 “计算”函数需要一个恒定的步骤,而 searchsorted 需要对数组进行排序。因此,这些在适当的情况下可能很有用,但不是针对此问题的通用解决方案。如果您正在处理 sorted Python 列表,您可能需要查看 bisect 模块,而不是使用 Numpys searchsorted。