【问题标题】:Checking for and indexing non-unique/duplicate values in a numpy array检查和索引 numpy 数组中的非唯一/重复值
【发布时间】:2014-10-05 13:14:38
【问题描述】:

我有一个包含对象 ID 的数组 traced_descIDs,我想确定哪些项目在该数组中不是唯一的。然后,对于每个唯一的重复(小心)ID,我需要确定traced_descIDs 的哪些索引与其相关联。

例如,如果我们在这里取 traced_descIDs,我希望发生以下过​​程:

traced_descIDs = [1, 345, 23, 345, 90, 1]
dupIds = [1, 345]
dupInds = [[0,5],[1,3]]

我目前正在通过以下方式找出哪些对象有超过 1 个条目:

mentions = np.array([len(np.argwhere( traced_descIDs == i)) for i in traced_descIDs])
dupMask = (mentions > 1)

但是,这需要很长时间,因为 len( traced_descIDs ) 大约是 150,000。有没有更快的方法来达到同样的效果?

非常感谢任何帮助。干杯。

【问题讨论】:

    标签: python arrays numpy unique


    【解决方案1】:

    虽然字典是 O(n),但 Python 对象的开销有时使使用 numpy 的函数更方便,它使用排序并且是 O(n*log n)。在您的情况下,起点是:

    a = [1, 345, 23, 345, 90, 1]
    unq, unq_idx, unq_cnt = np.unique(a, return_inverse=True, return_counts=True)
    

    如果您使用的 numpy 版本早于 1.9,那么最后一行必须是:

    unq, unq_idx = np.unique(a, return_inverse=True)
    unq_cnt = np.bincount(unq_idx)
    

    我们创建的三个数组的内容是:

    >>> unq
    array([  1,  23,  90, 345])
    >>> unq_idx
    array([0, 3, 1, 3, 2, 0])
    >>> unq_cnt
    array([2, 1, 1, 2])
    

    获取重复项:

    cnt_mask = unq_cnt > 1
    dup_ids = unq[cnt_mask]
    
    >>> dup_ids
    array([  1, 345])
    

    获取索引有点复杂,但非常简单:

    cnt_idx, = np.nonzero(cnt_mask)
    idx_mask = np.in1d(unq_idx, cnt_idx)
    idx_idx, = np.nonzero(idx_mask)
    srt_idx = np.argsort(unq_idx[idx_mask])
    dup_idx = np.split(idx_idx[srt_idx], np.cumsum(unq_cnt[cnt_mask])[:-1])
    
    >>> dup_idx
    [array([0, 5]), array([1, 3])]
    

    【讨论】:

    • 我对这个答案更满意,而且它似乎不会比上面的字典答案花太多时间。感谢您的宝贵时间。
    【解决方案2】:

    scipy.stats.itemfreq 会给出每个项目的频率:

    >>> xs = np.array([1, 345, 23, 345, 90, 1])
    >>> ifreq = sp.stats.itemfreq(xs)
    >>> ifreq
    array([[  1,   2],
           [ 23,   1],
           [ 90,   1],
           [345,   2]])
    >>> [(xs == w).nonzero()[0] for w in ifreq[ifreq[:,1] > 1, 0]]
    [array([0, 5]), array([1, 3])]
    

    【讨论】:

    • 我不知道这个功能。感谢您提请我注意。
    【解决方案3】:

    你目前的做法是O(N**2),用字典在O(N)time:

    >>> from collections import defaultdict
    >>> traced_descIDs = [1, 345, 23, 345, 90, 1]
    >>> d = defaultdict(list)
    >>> for i, x in enumerate(traced_descIDs):
    ...     d[x].append(i)
    ...     
    >>> for k, v in d.items():
    ...     if len(v) == 1:
    ...         del d[k]
    ...         
    >>> d
    defaultdict(<type 'list'>, {1: [0, 5], 345: [1, 3]})
    

    并获取项目和索引:

    >>> from itertools import izip
    >>> dupIds, dupInds = izip(*d.iteritems())
    >>> dupIds, dupInds
    ((1, 345), ([0, 5], [1, 3]))
    

    请注意,如果您想保留 dupIds 中项目的顺序,请使用 collections.OrderedDictdict.setdefault() 方法。

    【讨论】:

    • 我个人更喜欢 numpy 解决方案,但如果你想这样做,标准库已经涵盖了:from collections import Counter
    • 您能否详细说明不使用 OrderedDict 未保留的内容?
    • 注意这个解决方案会创建很多python对象,因此内存使用会爆炸;这就是为什么如果您正在处理大型数据集,留在 numpy 中可能更可取。
    • @CarlM 这里的输出可能是[345, 1],因为字典没有顺序。 OrderedDict 将确保输出为 [1, 345]。`
    【解决方案4】:
    td = np.array(traced_descIDs)
    si = np.argsort(td)
    td[si][np.append(False, np.diff(td[si]) == 0)]
    

    这给了你:

    array([  1, 345])
    

    我还没有完全弄清楚第二部分,但也许这对你来说已经足够灵感了,或者我会回到它。 :)

    【讨论】:

      【解决方案5】:

      numpy_indexed 包中嵌入了与 Jaime 提出的相同矢量化效率的解决方案(免责声明:我是其作者):

      import numpy_indexed as npi
      print(npi.group_by(traced_descIDs, np.arange(len(traced_descIDs))))
      

      这让我们大部分时间到达那里;但是,如果我们还想过滤掉单例组,同时避免任何 python 循环并保持完全向量化,我们可以降低一点,然后这样做:

      g = npi.group_by(traced_descIDs)
      unique = g.unique
      idx = g.split_array_as_list(np.arange(len(traced_descIDs)))
      duplicates = unique[g.count>1]
      idx_duplicates = np.asarray(idx)[g.count>1]
      print(duplicates, idx_duplicates)
      

      【讨论】:

        【解决方案6】:

        np.unqiue 用于 Ndims

        我在 ndArray 中遇到了类似的问题,我想在其中查找重复的行。

        x = np.arange(60).reshape(5,4,3)
        x[1] = x[0]
        

        0 和 1 应该在轴 0 中重复。我使用了 np.unique 并返回了所有选项。然后使用Jaime的方法定位重复项。

        _,i,_,c = np.unique(x,1,1,1,axis=0)
        x_dup = x[i[1<c]]
        

        为了清楚起见,我不必要地使用了return_inverse。结果如下:

        >>> print(x_dupilates)
        [[[ 0  1  2]
          [ 3  4  5]
          [ 6  7  8]
          [ 9 10 11]]]
        

        【讨论】:

          猜你喜欢
          • 2019-05-11
          • 1970-01-01
          • 1970-01-01
          • 2015-06-16
          • 2018-10-25
          • 2021-12-07
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多