【问题标题】:Get indices of matches from one array in another从另一个数组中获取匹配的索引
【发布时间】:2020-06-30 12:53:03
【问题描述】:

给定两个 np.array;

a = np.array([1, 6, 5, 3, 8, 345, 34, 6, 2, 867])
b = np.array([867, 8, 34, 75])

我想获得一个与 b 具有相同维度的 np.array,其中每个值是 b 中的值出现在 a 中的索引,如果 b 中的值不存在于 a 中,则为 np.nan。

结果应该是;

[9, 4, 6, nan]

a 和 b 将始终具有相同数量的维度,但维度的大小可能不同。

类似的东西;

(伪代码)

c = np.where(b in a)

但它适用于数组(“in”不适用)

我更喜欢“单线”或至少是完全在数组级别的解决方案,并且不需要循环。

谢谢!

【问题讨论】:

    标签: python numpy


    【解决方案1】:

    方法#1

    这是一个np.searchsorted -

    def find_indices(a,b,invalid_specifier=-1):
        # Search for matching indices for each b in sorted version of a. 
        # We use sorter arg to account for the case when a might not be sorted 
        # using argsort on a
        sidx = a.argsort()
        idx = np.searchsorted(a,b,sorter=sidx)
    
        # Remove out of bounds indices as they wont be matches
        idx[idx==len(a)] = 0
    
        # Get traced back indices corresponding to original version of a
        idx0 = sidx[idx]
        
        # Mask out invalid ones with invalid_specifier and return
        return np.where(a[idx0]==b, idx0, invalid_specifier)
    

    给定样本的输出 -

    In [41]: find_indices(a, b, invalid_specifier=np.nan)
    Out[41]: array([ 9.,  4.,  6., nan])
    

    方法#2

    另一个基于 lookup 的正数 -

    def find_indices_lookup(a,b,invalid_specifier=-1):
        # Setup array where we will assign ranged numbers
        N = max(a.max(), b.max())+1
        lookup = np.full(N, invalid_specifier)
        
        # We index into lookup with b to trace back the positions. Non matching ones
        # would have invalid_specifier values as wount had been indexed by ranged ones
        lookup[a] = np.arange(len(a))
        indices  = lookup[b]
        return indices
    

    基准测试

    问题中没有提到效率作为要求,但可能会出现无循环要求。使用尝试重复给定示例设置的设置进行测试,但将其放大1000x

    In [98]: a = np.random.permutation(np.unique(np.random.randint(0,20000,10000)))
    
    In [99]: b = np.random.permutation(np.unique(np.random.randint(0,20000,4000)))
    
    # Solutions from this post
    In [100]: %timeit find_indices(a,b,invalid_specifier=np.nan)
         ...: %timeit find_indices_lookup(a,b,invalid_specifier=np.nan)
    1.35 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    220 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    # @Quang Hoang-soln2
    In [101]: %%timeit
         ...: commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)
         ...: orders = np.argsort(idx_b)
         ...: output = np.full(b.shape, np.nan)
         ...: output[orders] = idx_a[orders]
    1.63 ms ± 59.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    # @Quang Hoang-soln1
    In [102]: %%timeit
         ...: s = b == a[:,None]
         ...: np.where(s.any(0), np.argmax(s,0), np.nan)
    137 ms ± 9.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    【讨论】:

    • 不错的答案谢谢!但你能详细说明一下它是如何工作的吗?
    【解决方案2】:

    你可以做一个广播:

    s = b == a[:,None]
    np.where(s.any(0), np.argmax(s,0), np.nan)
    

    输出:

    array([ 9.,  4.,  6., nan])
    

    更新intersect1d的另一个解决方案:

    commons, idx_a, idx_b = np.intersect1d(a,b, return_indices=True)
    
    orders = np.argsort(idx_b)
    
    output = np.full(b.shape, np.nan)
    output[orders] = idx_a[orders]
    

    【讨论】:

    • 太棒了!但是 argmax 到底是做什么的呢?你能详细说明一下解决方案吗?
    • argmax 检查沿给定轴的最大值的索引(本例为0)。
    猜你喜欢
    • 2016-10-22
    • 1970-01-01
    • 1970-01-01
    • 2021-06-24
    • 1970-01-01
    • 2016-02-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多