【问题标题】:Find all indexes of a numpy array closest to a value查找最接近某个值的 numpy 数组的所有索引
【发布时间】:2018-07-31 17:49:51
【问题描述】:

在一个 numpy 数组中,需要所有最接近给定常量的值的索引。 背景是数字信号处理。该数组包含一个滤波器的幅度函数(np.abs(np.fft.rfft(h))),并且在幅度为例如的地方搜索某些频率(=索引)。 0.5 或在另一种情况下为 0。 大多数时候,所讨论的值并不完全包含在序列中。关闭值的索引应该在这个里面找到。

到目前为止,我想出了以下方法,在该方法中,我查看了序列和常数之间差异的符号变化。但是,这仅适用于在相关点单调递增或递减的序列。有时它也会关闭 1。

def findvalue(seq, value):
    diffseq = seq - value
    signseq = np.sign(diffseq)
    signseq[signseq == 0] = 1
    return np.where(np.diff(signseq))[0]

我想知道是否有更好的解决方案。仅适用于一维实浮点数组,对我的计算效率要求不高。

作为一个数字示例,以下代码应返回 [8, 41]。为了简单起见,我在这里用半波替换了滤波器幅度响应。

f=np.sin(np.linspace(0, np.pi))
findvalue(f, 0.5)

我发现的类似问题如下,但它们只返回第一个或第二个索引:
Find the second closest index to value
Find nearest value in numpy array

【问题讨论】:

    标签: python numpy search


    【解决方案1】:

    这可能远不是最好的方法(我还在学习 numpy),但我希望它可以帮助你找到一个。

    min_distance = np.abs(your_array - your_constant).min()
    # These two tuples contain number closest to your constant from each side.
    np.where(bar == val - min_distance)  # Closest, < your_constant
    np.where(bar == val + min_distance)  # Closest, > your_constant
    

    【讨论】:

      【解决方案2】:

      以下函数将返回一个小数索引,显示大约何时超过该值:

      def FindValueIndex(seq, val):
          r = np.where(np.diff(np.sign(seq - val)) != 0)
          idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r])
          idx = np.append(idx, np.where(seq == val))
          idx = np.sort(idx)
          return idx
      

      逻辑:查找 seq - val 的符号在哪里发生变化。在转换和插值的下方和上方取值一个索引。添加到该索引,其中值实际上等于该值。

      如果你想要一个整数索引,只需使用 np.round。您还可以选择 np.floor 或 np.ceil 将索引四舍五入到您的偏好。

      def FindValueIndex(seq, val):
          r = np.where(np.diff(np.sign(seq - val)) != 0)
          idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r])
          idx = np.append(idx, np.where(seq == val))
          idx = np.sort(idx)
          return np.round(idx)
      

      【讨论】:

      • 谢谢,我也考虑过插值以改进结果。一个问题:你为什么写r + np.ones_like(r)而不是简单的r + 1
      • 那是因为 np.where 返回一个元组。或者我可能已经完成r = np.where(np.diff(np.sign(seq - val)) != 0)[0] 然后r + 1 会工作。
      • 谢谢,根据您的输入,我最终采用了以下方法。我将 sort 更改为 unique 以删除重复项并使舍入可选。 def argvalue(seq, val, intidx=True): r = np.where(np.diff(np.sign(seq - val)) != 0) idx = r + (val - seq[r]) / (seq[r + np.ones_like(r)] - seq[r]) idx = np.append(idx, np.where(seq == val)) if intidx: idx = np.round(idx).astype(int) idx = np.unique(idx) return idx
      【解决方案3】:
      def findvalue(seq, value):
          diffseq = seq - value
          signseq = np.sign(diffseq)
          zero_crossings = signseq[0:-2] != signseq[1:-1]
          indices = np.where(zero_crossings)[0]
          for i, v in enumerate(indices):
              if abs(seq[v + 1] - value) < abs(seq[v] - value):
                  indices[i] = v + 1
          return indices
      

      更多解释

      def print_vec(v):
          for i, f in enumerate(v):
              print("[{}]{:.2f} ".format(i,f), end='')
          print('')
      
      def findvalue_loud(seq, value):
          diffseq = seq - value
          signseq = np.sign(diffseq)
          print_vec(signseq)
          zero_crossings = signseq[0:-2] != signseq[1:-1]
          print(zero_crossings)
      
          indices = np.where(zero_crossings)[0]
          # indices contains the index in the original vector
          # just before the seq crosses the value [8 40]
          # this may be good enough for you
          print(indices)
      
          for i, v in enumerate(indices):
              if abs(seq[v + 1] - value) < abs(seq[v] - value):
                  indices[i] = v + 1
          # now indices contains the closest [8 41]
          print(indices)
          return indices
      

      【讨论】:

        【解决方案4】:

        我认为您在这里有两个选择。一种是对形状进行一些假设,并寻找seq 和您的val 之间差异的零交叉点(就像@ColonelFazackerleytheir answer 中所做的那样)。另一种是说明您希望将值考虑到足够接近的相对容差。

        在后一种情况下,您可以使用numpy.isclose:

        import numpy as np
        
        def findvalue(seq, val, rtol=0.05):    # value that works for your example
            return np.where(np.isclose(seq, val, rtol=rtol))[0]
        

        例子:

        x = np.sin(np.linspace(0, np.pi))
        print(findvalue(x, 0.5))
        # array([ 8, 41])
        

        这有一个缺点,它依赖于rtol 的值。将其设置得太大(此示例为0.1),您会在靠近交叉口的位置获得多个值,将其设置得太低,您将得不到任何值。

        【讨论】:

          猜你喜欢
          • 2018-01-03
          • 2011-08-29
          • 2016-07-19
          • 2013-03-01
          • 1970-01-01
          • 2016-03-17
          • 1970-01-01
          • 2018-06-05
          相关资源
          最近更新 更多