【问题标题】:Numpy first occurrence of value greater than existing valueNumpy第一次出现的值大于现有值
【发布时间】:2013-04-21 01:42:58
【问题描述】:

我在 numpy 中有一个一维数组,我想找到索引的位置,其中某个值超过了 numpy 数组中的值。

例如

aa = range(-10,10)

aa 中查找位置,其中5 的值被超出。

【问题讨论】:

  • 正如 ambrus 所评论的那样,应该清楚是否可能没有解决方案(因为例如 argmax 答案在这种情况下将不起作用(max of (0,0,0,0) = 0))跨度>
  • 我同意这一点,并在下面提供了一个答案(即使有一个我认为仍然模棱两可的公认答案)。我认为代码的正确性比性能更重要。

标签: python numpy


【解决方案1】:

这有点快(而且看起来更好)

np.argmax(aa>5)

由于argmax 将在第一个True 处停止(“如果多次出现最大值,则返回与第一次出现对应的索引。”)并且不会保存另一个列表。

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop

【讨论】:

  • 请注意:如果其输入数组中没有 True 值,np.argmax 将愉快地返回 0(在这种情况下这不是您想要的)。
  • 结果是正确的,但我觉得解释有点可疑。 argmax 似乎并没有止步于第一个 True。 (这可以通过在不同位置使用单个True 创建布尔数组来测试。)速度可能是因为argmax 不需要创建输出列表。
  • 我认为你是对的,@DrV。我的解释是关于为什么它给出了正确的结果,尽管最初的意图实际上并没有寻求最大值,而不是为什么它更快,因为我不能声称理解 argmax 的内部细节。
  • @DrV,我刚刚使用 NumPy 1.11.2 在具有单个 True 的 1000 万个元素的布尔数组上运行 argmaxTrue 的位置很重要。所以 1.11.2 的 argmax 似乎在布尔数组上“短路”了。
  • 我用 2^30 个元素的数组重复了@UlrichStern 的实验(先用 1 填充每个元素,然后用 0 填充,然后添加单个真值以消除空白页欺骗、页面错误噪音等)。当唯一真正的元素位于数组的开头而不是结尾时,np.argmax 的速度提高了 1e5 倍。这是 numpy 1.16.5。
【解决方案2】:

鉴于数组的排序内容,有一个更快的方法:searchsorted

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop

【讨论】:

  • 这确实是假设数组已排序的最佳答案(实际上并没有在问题中指定)。你可以用np.searchsorted(..., side='right')避免尴尬的+1
  • 我认为side 参数只有在排序数组中有重复值时才会产生影响。它不会改变返回索引的含义,它始终是您可以插入查询值的索引,将以下所有条目向右移动,并维护一个排序数组。
  • @Gus, side 当相同的值在 both 排序和插入数组中时有效,无论其中任何一个重复值如何。排序数组中的重复值只是夸大了效果(两边之间的差异是被插入的值出现在排序数组中的次数)。 side 确实改变了返回索引的含义,尽管它不会改变结果数组,将值插入到这些索引处的排序数组中。一个微妙但重要的区别;事实上,如果N/2 不在aa 中,这个答案给出了错误的索引。
  • 如上述评论中所暗示的,如果N/2 不在aa 中,则此答案会减一。正确的形式是np.searchsorted(aa, N/2, side='right')(没有+1)。否则,两种形式都给出相同的索引。考虑N 的测试用例是奇数(如果使用 python 2,N/2.0 强制浮动)。
【解决方案3】:

我对此也很感兴趣,我已将所有建议的答案与perfplot 进行了比较。 (免责声明:我是 perfplot 的作者。)

如果您知道您正在查看的数组已经排序,那么

numpy.searchsorted(a, alpha)

适合你。这是 O(log(n)) 操作,即速度几乎不取决于数组的大小。没有比这更快的了。

如果您对阵列一无所知,那么您不会出错

numpy.argmax(a > alpha)

已排序:

未分类:

重现情节的代码:

import numpy
import perfplot


alpha = 0.5
numpy.random.seed(0)


def argmax(data):
    return numpy.argmax(data > alpha)


def where(data):
    return numpy.where(data > alpha)[0][0]


def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]


def searchsorted(data):
    return numpy.searchsorted(data, alpha)


perfplot.save(
    "out.png",
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[argmax, where, nonzero, searchsorted],
    n_range=[2 ** k for k in range(2, 23)],
    xlabel="len(array)",
)

【讨论】:

  • np.searchsorted 不是恒定时间。实际上是O(log(n))。但是您的测试用例实际上是对searchsorted(即O(1))的最佳用例进行基准测试。
  • @MSeifert 需要什么样的输入数组/alpha 才能看到 O(log(n))?
  • 在索引 sqrt(length) 处获取项目确实会导致非常糟糕的性能。我还在这里写了一个answer,包括那个基准。
  • 我怀疑searchsorted(或任何算法)可以击败O(log(n)) 的二进制搜索排序均匀分布的数据。编辑:searchsorted 二分搜索。
  • 如果你知道均匀分布,你可以用 O(1) 击败二进制搜索。如果我有 0 - 1000 之间的单调数,并且您想找到值 748,则可以转到位置 784。这是一组有序的均匀分布的数据,并且是一种可以击败它的算法。
【解决方案4】:
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16

【讨论】:

    【解决方案5】:

    元素之间具有恒定步长的数组

    如果是 range 或任何其他线性增加的数组,您可以简单地以编程方式计算索引,根本不需要实际迭代数组:

    def first_index_calculate_range_like(val, arr):
        if len(arr) == 0:
            raise ValueError('no value greater than {}'.format(val))
        elif len(arr) == 1:
            if arr[0] > val:
                return 0
            else:
                raise ValueError('no value greater than {}'.format(val))
    
        first_value = arr[0]
        step = arr[1] - first_value
        # For linearly decreasing arrays or constant arrays we only need to check
        # the first element, because if that does not satisfy the condition
        # no other element will.
        if step <= 0:
            if first_value > val:
                return 0
            else:
                raise ValueError('no value greater than {}'.format(val))
    
        calculated_position = (val - first_value) / step
    
        if calculated_position < 0:
            return 0
        elif calculated_position > len(arr) - 1:
            raise ValueError('no value greater than {}'.format(val))
    
        return int(calculated_position) + 1
    

    也许可以稍微改进一下。我已经确保它适用于一些示例数组和值,但这并不意味着其中不会有错误,尤其是考虑到它使用浮点数...

    >>> import numpy as np
    >>> first_index_calculate_range_like(5, np.arange(-10, 10))
    16
    >>> np.arange(-10, 10)[16]  # double check
    6
    
    >>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
    15
    

    鉴于它可以在没有任何迭代的情况下计算位置,它将是恒定时间 (O(1)),并且可能会击败所有其他提到的方法。但是它需要数组中的一个恒定步长,否则会产生错误的结果。

    使用 numba 的一般解决方案

    更通用的方法是使用 numba 函数:

    @nb.njit
    def first_index_numba(val, arr):
        for idx in range(len(arr)):
            if arr[idx] > val:
                return idx
        return -1
    

    这适用于任何数组,但它必须遍历数组,所以在平均情况下它将是O(n)

    >>> first_index_numba(4.8, np.arange(-10, 10))
    15
    >>> first_index_numba(5, np.arange(-10, 10))
    16
    

    基准测试

    尽管 Nico Schlömer 已经提供了一些基准,但我认为包含我的新解决方案并测试不同的“值”可能会很有用。

    测试设置:

    import numpy as np
    import math
    import numba as nb
    
    def first_index_using_argmax(val, arr):
        return np.argmax(arr > val)
    
    def first_index_using_where(val, arr):
        return np.where(arr > val)[0][0]
    
    def first_index_using_nonzero(val, arr):
        return np.nonzero(arr > val)[0][0]
    
    def first_index_using_searchsorted(val, arr):
        return np.searchsorted(arr, val) + 1
    
    def first_index_using_min(val, arr):
        return np.min(np.where(arr > val))
    
    def first_index_calculate_range_like(val, arr):
        if len(arr) == 0:
            raise ValueError('empty array')
        elif len(arr) == 1:
            if arr[0] > val:
                return 0
            else:
                raise ValueError('no value greater than {}'.format(val))
    
        first_value = arr[0]
        step = arr[1] - first_value
        if step <= 0:
            if first_value > val:
                return 0
            else:
                raise ValueError('no value greater than {}'.format(val))
    
        calculated_position = (val - first_value) / step
    
        if calculated_position < 0:
            return 0
        elif calculated_position > len(arr) - 1:
            raise ValueError('no value greater than {}'.format(val))
    
        return int(calculated_position) + 1
    
    @nb.njit
    def first_index_numba(val, arr):
        for idx in range(len(arr)):
            if arr[idx] > val:
                return idx
        return -1
    
    funcs = [
        first_index_using_argmax, 
        first_index_using_min, 
        first_index_using_nonzero,
        first_index_calculate_range_like, 
        first_index_numba, 
        first_index_using_searchsorted, 
        first_index_using_where
    ]
    
    from simple_benchmark import benchmark, MultiArgument
    

    并且这些图是使用以下方法生成的:

    %matplotlib notebook
    b.plot()
    

    项目在开头

    b = benchmark(
        funcs,
        {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
        argument_name="array size")
    

    numba 函数的性能最好,其次是 calculate-function 和 searchsorted 函数。其他解决方案的性能要差得多。

    项目在末尾​​h2>
    b = benchmark(
        funcs,
        {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
        argument_name="array size")
    

    对于小型数组,numba 函数的执行速度惊人地快,但是对于较大的数组,它的计算函数和 searchsorted 函数的表现要好。

    项目位于 sqrt(len)

    b = benchmark(
        funcs,
        {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
        argument_name="array size")
    

    这更有趣。同样 numba 和 calculate 函数表现出色,但这实际上触发了最坏的 searchsorted 情况,在这种情况下确实不能正常工作。

    没有值满足条件时的函数比较

    另一个有趣的点是,如果没有应该返回其索引的值,这些函数的行为:

    arr = np.ones(100)
    value = 2
    
    for func in funcs:
        print(func.__name__)
        try:
            print('-->', func(value, arr))
        except Exception as e:
            print('-->', e)
    

    有了这个结果:

    first_index_using_argmax
    --> 0
    first_index_using_min
    --> zero-size array to reduction operation minimum which has no identity
    first_index_using_nonzero
    --> index 0 is out of bounds for axis 0 with size 0
    first_index_calculate_range_like
    --> no value greater than 2
    first_index_numba
    --> -1
    first_index_using_searchsorted
    --> 101
    first_index_using_where
    --> index 0 is out of bounds for axis 0 with size 0
    

    Searchsorted、argmax 和 numba 只是返回错误值。但是 searchsortednumba 返回的索引不是数组的有效索引。

    函数whereminnonzerocalculate 抛出异常。但是,只有 calculate 的例外实际上说明了任何有用的信息。

    这意味着实际上必须将这些调用包装在一个适当的包装函数中,该函数捕获异常或无效返回值并进行适当处理,至少在您不确定该值是否可以在数组中的情况下。


    注意:calculate 和searchsorted 选项仅在特殊条件下有效。 “计算”函数需要一个恒定的步骤,而 searchsorted 需要对数组进行排序。因此,这些在适当的情况下可能很有用,但不是针对此问题的通用解决方案。如果您正在处理 sorted Python 列表,您可能需要查看 bisect 模块,而不是使用 Numpys searchsorted。

    【讨论】:

      【解决方案6】:

      我想提议

      np.min(np.append(np.where(aa>5)[0],np.inf))
      

      这将返回满足条件的最小索引,而如果条件从未满足则返回无穷大(并且where 返回一个空数组)。

      【讨论】:

        【解决方案7】:

        我会去

        i = np.min(np.where(V >= x))
        

        其中V 是向量(一维数组),x 是值,i 是结果索引。

        【讨论】:

          【解决方案8】:

          您应该使用np.where 而不是np.argmax。即使没有找到值,后者也会返回位置 0,这不是您期望的索引。

          >>> aa = np.array(range(-10,10))
          >>> print(aa)
          array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
                   3,   4,   5,   6,   7,   8,   9])
          

          如果满足条件,则返回一个索引数组。

          >>> idx = np.where(aa > 5)[0]
          >>> print(idx)
          array([16, 17, 18, 19], dtype=int64)
          

          否则,如果不满足,则返回一个空数组。

          >>> not_found = len(np.where(aa > 20)[0])
          >>> print(not_found)
          array([], dtype=int64)
          

          在这种情况下反对argmax 的要点是:越简单越好,如果解决方案不模棱两可。因此,要检查是否有问题符合条件,只需执行if len(np.where(aa &gt; value_to_search)[0]) &gt; 0

          【讨论】:

            猜你喜欢
            • 1970-01-01
            • 2018-06-30
            • 2018-05-28
            • 2017-08-23
            • 2023-03-12
            • 1970-01-01
            • 1970-01-01
            • 1970-01-01
            • 2019-08-19
            相关资源
            最近更新 更多