【问题标题】:Membership checking in Numpy ndarrayNumpy ndarray 中的成员资格检查
【发布时间】:2016-08-26 10:55:21
【问题描述】:

我编写了一个脚本来评估arr 的某些条目是否在check_elements 中。我的方法比较单个条目,而是比较arr 内的整个向量。因此,脚本会检查 [8, 3][4, 5]、... 是否在 check_elements 中。

这是一个例子:

import numpy as np

# arr.shape -> (2, 3, 2)
arr = np.array([[[8,  3],
                 [4,  5],
                 [6,  2]],

                [[9,  0],
                 [1, 10],
                 [7, 11]]])

# check_elements.shape -> (3, 2)
# generally: (n, 2)
check_elements = np.array([[4, 5], [9, 0], [7, 11]])

# rslt.shape -> (2, 3)
rslt = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.bool)

for i, j in np.ndindex((arr.shape[0], arr.shape[1])):
    if arr[i, j] in check_elements:   # <-- condition is checked against
                                      #     the whole last dimension
        rslt[i, j] = True
    else:
        rslt[i, j] = False

现在:

print(rslt)

...将打印:

[[False  True False]
 [ True False  True]]

获取 I 使用的索引:

print(np.transpose(np.nonzero(rslt)))

...打印以下内容:

[[0 1]    # arr[0, 1] -> [4, 5] -> is in check_elements
 [1 0]    # arr[1, 0] -> [9, 0] -> is in check_elements
 [1 2]]   # arr[1, 2] -> [7, 11] -> is in check_elements

如果我要检查单个值的条件,例如 arr &gt; 3np.where(...),那么这项任务会很简单且高效,但我对单个值感兴趣。我想检查整个最后一个维度(或它的切片)的条件。

我的问题是:有没有更快的方法来达到同样的效果?我对矢量化尝试和np.where 之类的东西不能用于我的问题是否正确,因为它们总是对单个值而不是整个维度或该维度的切片进行操作?

【问题讨论】:

    标签: python numpy indexing


    【解决方案1】:

    这是使用broadcasting 的 Numpythonic 方法:

    >>> (check_elements == arr[:,:,None]).reshape(2, 3, 6).any(axis=2)
    array([[False,  True, False],
           [ True, False,  True]], dtype=bool)
    

    【讨论】:

      【解决方案2】:

      numpy_indexed 包(免责声明:我是它的作者)包含执行此类查询的功能;具体来说,nd(子)数组的包含关系:

      import numpy_indexed as npi
      flatidx = npi.indices(arr.reshape(-1, 2), check_elements)
      idx = np.unravel_index(flatidx, arr.shape[:-1])
      

      请注意,实现是完全矢量化的。

      另外,请注意,使用这种方法,idx 中的索引顺序与 check_elements 的顺序匹配; idx 中的第一项是 check_elements 中第一项的行和列。当使用您上面发布的方法时,或者使用其他建议的答案之一时,此信息会丢失,这将为您提供按其在 arr 中的出现顺序排序的 idx,这通常是不可取的。

      【讨论】:

        【解决方案3】:

        您可以使用np.in1d,即使它适用于一维数组,只要给它一个数组的一维视图,每个最后一个轴包含一个元素:

        arr_view = arr.view((np.void, arr.dtype.itemsize*arr.shape[-1])).ravel()
        check_view = check_elements.view((np.void,
                check_elements.dtype.itemsize*check_elements.shape[-1])).ravel()
        

        将为您提供两个一维数组,其中包含沿最后一个轴的 2 个元素数组的 void 类型版本。现在您可以通过以下方式检查arr 中的哪些元素也在check_view 中:

        flatResult = np.in1d(arr_view, check_view)
        

        这将给出一个扁平数组,然后您可以将其重塑为 arr 的形状,删除最后一个轴:

        print(flatResult.reshape(arr.shape[:-1]))
        

        这会给你想要的结果:

        array([[False,  True, False],
               [ True, False,  True]], dtype=bool)
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2011-10-19
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2016-09-19
          • 1970-01-01
          • 1970-01-01
          • 2015-08-24
          相关资源
          最近更新 更多