【问题标题】:Select elements from a numpy array based on values in another array that is not an index array根据另一个不是索引数组的数组中的值从 numpy 数组中选择元素
【发布时间】:2013-02-08 01:36:28
【问题描述】:

假设我有以下两个数组:

a = array([(1, 'L', 74.423088306605), (5, 'H', 128.05441039929008),
       (2, 'L', 68.0581377353869), (0, 'H', 88.15726964130869), 
       (4, 'L', 97.4501582588212), (3, 'H', 92.98550136344437),
       (7, 'L', 87.75945631669309), (6, 'L', 90.43196739694255),
       (8, 'H', 111.13662092749307), (15, 'H', 91.44444608631304),
       (10, 'L', 85.43615908319185), (11, 'L', 78.11685661303494),
       (13, 'H', 108.2841293816308), (17, 'L', 74.43917911042259),
       (14, 'H', 64.41057325770373), (9, 'L', 27.407214746467943),
       (16, 'H', 81.50506434964355), (12, 'H', 97.79700070323196),
       (19, 'L', 51.139258140713025), (18, 'H', 118.34835768605957)], 
      dtype=[('id', '<i4'), ('name', 'S1'), ('value', '<f8')])

b = array([ 0,  3,  5,  8, 12, 13, 14, 15, 16, 18], dtype=int32)

我想从a 中选择idb 中给出的元素。也就是说,b 不是索引数组。它包含观察的ids。我怎样才能在 numpy 中做到这一点?

感谢您的帮助。

【问题讨论】:

    标签: python numpy


    【解决方案1】:

    你应该得到你想要的

    indeces = [i for i,id in enumerate(a['id']) if id in b]
    suba = a[indeces]
    print(suba)
    >>>array([(5, 'H', 128.05441039929008), (0, 'H', 88.15726964130869),
       (3, 'H', 92.98550136344437), (8, 'H', 111.13662092749307),
       (15, 'H', 91.44444608631304), (13, 'H', 108.2841293816308),
       (14, 'H', 64.41057325770373), (16, 'H', 81.50506434964355),
       (12, 'H', 97.79700070323196), (18, 'H', 118.34835768605957)], 
      dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
    

    【讨论】:

    • 谢谢!这似乎很好。如果我在某个时候没有看到更好的答案,我会接受这个。
    【解决方案2】:

    对于您的示例数组,以下方法比 Francesco 的方法快几倍:

    In [7]: a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
    Out[7]: 
    array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
           (5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
           (12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
           (14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
           (16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)], 
          dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
    
    In [8]: %timeit a[np.argmax(a['id'][None, :] == b[:, None], axis=1)]
    100000 loops, best of 3: 11.6 us per loop
    
    In [9]: %timeit indices = [i for i,id in enumerate(a['id']) if id in b]; a[indices]
    10000 loops, best of 3: 66.9 us per loop
    

    要了解它的工作原理,请看一下:

    In [10]: a['id'][None, :] == b[:, None]
    Out[10]: 
    array([[False, False, False,  True, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False,
            False, False],
        ... # several rows removed 
        [False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False,
            False,  True]], dtype=bool)
    

    它是一个数组,行数与b 中的元素一样多,列数与a 中的元素一样多。 np.argmax然后找到每行第一个True的位置,即b对应元素在a['id']中第一次出现的索引。

    如上所示,对于小型数组,这在性能方面优于 python。但是如果ab 太大,那么bools 的中间数组的大小会削弱性能。此外,np.argmax 必须搜索整行,它永远不会提前跳出循环,如果a 太长,这不是一件好事。我在对this question 的回答中使用了类似的方法进行了一些计时,对于中等大小的数组来说,这仍然是可行的方法。

    Francesco 的方法绝对不那么老套,更容易理解,我必须承认,对于样本大小的数组,性能差异无关紧要。但它不会让你感觉像this...

    【讨论】:

    • 哇,这太神奇了,虽然我不能说我理解 [None,:] 背后的逻辑。只是好奇:你对你的方法的扩展有任何想法吗?天真地我会说我的缩放与 a 和 b 的大小大致呈线性关系(如果 if id in b 是懒惰的,缩放会更好)
    • @FrancescoMontesano 这正是问题所在,我认为这是 O(n**2),我会说你的更好,尽管根据 this 你可能需要将 b 转换为set 是真的。所以最终你的方法将是最快的,但是对于非常大范围的较小尺寸,python 的缓慢性或 numpy/C 的速度都无关紧要。
    • @FrancescoMontesano [None, :] 等价于.reshape(1, -1),它将一维数组转换为列向量。因此,当它比较列向量和行向量时,会将它们广播成完整的矩形。
    • 明白了!我(认为我)不知道这个广播。谢谢解释
    【解决方案3】:
    sorted = numpy.sort(a)
    sorted[b]
     array([(0, 'H', 88.15726964130869), (3, 'H', 92.98550136344437),
       (5, 'H', 128.05441039929008), (8, 'H', 111.13662092749307),
       (12, 'H', 97.79700070323196), (13, 'H', 108.2841293816308),
       (14, 'H', 64.41057325770373), (15, 'H', 91.44444608631304),
       (16, 'H', 81.50506434964355), (18, 'H', 118.34835768605957)], 
      dtype=[('id', '<i4'), ('name', '|S1'), ('value', '<f8')])
    

    只要数组中的 id 数与行数一样多。

    【讨论】:

    • 我不想依赖排序。我希望它对订购具有鲁棒性。
    • 我相信第一行应该是 sorted=numpy.argsort(a) 在这种情况下它对排序是健壮的。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-11-13
    • 2012-07-19
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多