【问题标题】:Finding which rows have all elements as zeros in a matrix with numpy用numpy查找矩阵中哪些行的所有元素都为零
【发布时间】:2014-07-06 17:19:00
【问题描述】:

我有一个大的numpy 矩阵M。矩阵的某些行的所有元素都为零,我需要获取这些行的索引。我正在考虑的天真的方法是遍历矩阵中的每一行,然后检查每个元素。但是我认为使用numpy 可以更好更快地完成此任务。希望能帮到你!

【问题讨论】:

    标签: python numpy matrix


    【解决方案1】:

    这是一种方法。我假设已经使用import numpy as np 导入了 numpy。

    In [20]: a
    Out[20]: 
    array([[0, 1, 0],
           [1, 0, 1],
           [0, 0, 0],
           [1, 1, 0],
           [0, 0, 0]])
    
    In [21]: np.where(~a.any(axis=1))[0]
    Out[21]: array([2, 4])
    

    这个答案略有不同:How to check that a matrix contains a zero column?

    这是怎么回事:

    如果数组中的任何值是“真”,any 方法将返回 True。非零数被认为是真,0 被认为是假。通过使用参数axis=1,该方法应用于每一行。对于示例a,我们有:

    In [32]: a.any(axis=1)
    Out[32]: array([ True,  True, False,  True, False], dtype=bool)
    

    所以每个值表示对应的行是否包含非零值。 ~ 运算符是二进制“非”或补码:

    In [33]: ~a.any(axis=1)
    Out[33]: array([False, False,  True, False,  True], dtype=bool)
    

    (给出相同结果的替代表达式是(a == 0).all(axis=1)。)

    要获取行索引,我们使用where 函数。它返回其参数为 True 的索引:

    In [34]: np.where(~a.any(axis=1))
    Out[34]: (array([2, 4]),)
    

    请注意,where 返回一个包含单个数组的元组。 where 适用于 n 维数组,因此它总是返回一个元组。我们想要那个元组中的单个数组。

    In [35]: np.where(~a.any(axis=1))[0]
    Out[35]: array([2, 4])
    

    【讨论】:

      【解决方案2】:

      如果元素是int(0),则接受的答案有效。如果要查找所有值为 0.0(浮点数)的行,则必须使用 np.isclose()

      print(x)
      # output
      tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 1., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               1., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0.],
      ])
      np.where(np.all(np.isclose(labels, 0), axis=1))
      (array([ 0, 3]),)
      

      注意:这也适用于 PyTorch 张量,当您想要找到归零的多热编码向量时,这非常有用。

      【讨论】:

        【解决方案3】:

        使用np.sum的解决方案,
        如果您想使用阈值,则很有用

        a = np.array([[1.0, 1.0, 2.99],
                  [0.0000054, 0.00000078, 0.00000232],
                  [0, 0, 0],
                  [1, 1, 0.0],
                  [0.0, 0.0, 0.0]])
        print(np.where(np.sum(np.abs(a), axis=1)==0)[0])
        >>[2 4]
        print(np.where(np.sum(np.abs(a), axis=1)<0.0001)[0])
        >>[1 2 4]  
        

        使用np.prod 检查行是否包含至少一个零元素

        print(np.where(np.prod(a, axis=1)==0)[0])
        >>[2 3 4]
        

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2015-09-20
          • 1970-01-01
          • 2012-02-27
          • 2023-04-09
          • 1970-01-01
          • 2014-04-30
          • 1970-01-01
          相关资源
          最近更新 更多