【问题标题】:np.where checking also for subelements in multidimensional arraysnp.where 还检查多维数组中的子元素
【发布时间】:2020-11-18 09:35:29
【问题描述】:

我有两个具有相同第二维的多维数组。我想确保第一个数组的任何元素(即没有行)也是第二个数组的一行。

为此,我使用numpy.where,但它的行为也在检查同一位置的子元素。例如考虑以下代码:

x = np.array([[0,1,2,3], [4,0,6,9]])
z= np.array([[0,1,2,3], [5, 11, 6,98]])
for el in x:
    print(np.where(z==el))

打印出来:

(array([0, 0, 0, 0]), array([0, 1, 2, 3]))
(array([1]), array([2]))

第一个结果是由于第一个数组相等,第二个结果是因为z[1]x[1] 都将6 作为第三个元素。有没有办法告诉np.where 只返回严格相等元素的索引,即上面示例中的0

【问题讨论】:

    标签: python arrays python-3.x numpy


    【解决方案1】:
    [i for i, e in enumerate(x) if (e == z).all(1).any()]
    

    测试用例:

    x = np.array([[0,1,2,3], [4,0,6,9], [4,0,6,19]])
    z= np.array([[4,0,6,9], [0,1,2,3]])
    
    [i for i, e in enumerate(x) if (e == z).all(1).any()]
    

    输出:

    [0, 1]
    

    【讨论】:

    • 使用矢量化方法将大大加快速度 - 请参阅我的答案
    • 实际上,在我的测试中,您的答案仅比这快 10%。我的 void view 方法比这两种方法都快 60 倍。
    • 这是一个很好的答案,我喜欢它的简单性,同时获得第一个向量的索引会更好。
    【解决方案2】:

    简单地返回你的条件的索引 - 这里是元素明智的平等

    回答

    您可以使用矢量化操作找到重复项:

    duplicates = (x[:, None] == z).all(-1).any(-1)
    

    获取值

    要获取重复值,请使用掩码

    x[duplicates]
    

    在这个例子中:

    duplicates = [True False]
    
    x[duplicates] = [[0, 1, 2, 3]]
    

    逻辑

    1. 扩展数组[:, None]
    2. 仅查找整行匹配 all(-1)
    3. 返回至少有一个匹配的行any(-1)

    【讨论】:

      【解决方案3】:

      伙计,自从np.unique 添加了axis 参数以来,我还没有机会链接到this 答案。感谢@Jaime

      vview = lambda a: np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))
      

      基本上,这会将矩阵的“行”转换为行的原始数据流上的一维视图数组。这使您可以像比较单个值一样比较行。

      那么就很简单了:

      print(np.where(vview(x) == vview(z).T))
      (array([0], dtype=int64), array([0], dtype=int64))
      

      表示x的第一行匹配z的第一行

      如果您只想知道x 的行是否在z 的行中:

      print(np.where(np.isin(vview(x), vview(z)).squeeze()))
      (array([0], dtype=int64),)
      

      与@mujjiga 在大数组上的检查时间相比:

      x = np.random.randint(10, size = (1000, 4))
      
      z = np.random.randint(10, size = (1000, 4))
      
      %timeit np.where(np.isin(vview(x), vview(z)).squeeze())
      365 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
      
      %timeit [i for i, e in enumerate(x) if (e == z).all(1).any()]  # @mujjiga
      21.3 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
      
      %timeit np.where((x[:, None] == z).all(-1).any(-1))  # @orgoro
      20 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
      

      因此,循环切片的速度大约提高了 60 倍,这可能是由于快速短路并且只比较了 1/4 的值

      【讨论】:

      • 那不是日常代码:我喜欢它。让我回顾一下,看看我是否做对了。 ascontigous() 在内存中返回一个连续的数组表示,view 方法基本上告诉如何重新格式化没有指定类型的数组 (np.void),我假设 a.dtype.itemsize * a.shape[1] 是一行的字节大小。我不明白为什么在比较时转置。
      • 在某个时候view 开始保持尺寸,所以(n, m) 输入的sn-p 的输出现在是(n, 1) 而不是(n,)。所以,这意味着我可以用广播通常需要的[:, None][None, :] 进行简单的转置。
      【解决方案4】:

      嗯,对于 2D 数组,类似下面的内容可能有用。我认为您必须小心在浮点运算中检查ee==0

      import numpy as np
      aa = np.arange(16).reshape(4,4)
      # we are trying to find the row in aa which is equal to bb
      bb = np.asarray([0,1,2,3])
      cc = bb[None,:]
      dd = aa - cc
      ee = np.linalg.norm(dd,axis=1)
      idx = np.where(ee==0)
      

      【讨论】:

      • 这并不可靠。如果a = [1,2,3]b=[1,1,4](a-b).sum()==0 为真,但即使条目不匹配
      • 感谢您指出这一点。让我看看我是否可以修改它。编辑:用 np.linalg.norm 替换 np.sum。
      • 我的意思是,它可以工作,但即使与循环相比,这也是大量的计算开销。平方根很慢
      猜你喜欢
      • 2018-06-01
      • 1970-01-01
      • 2019-06-04
      • 2017-12-27
      • 1970-01-01
      • 2022-12-06
      • 1970-01-01
      • 2012-07-04
      • 1970-01-01
      相关资源
      最近更新 更多