【问题标题】:How does python numpy.where() work?python numpy.where() 是如何工作的?
【发布时间】:2011-08-04 06:55:39
【问题描述】:

我正在玩 numpy 并挖掘文档,我发现了一些魔法。即我说的是numpy.where()

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

他们如何在内部实现您能够将x > 5 之类的内容传递给方法?我想这与__gt__ 有关,但我正在寻找详细的解释。

【问题讨论】:

    标签: python numpy magic-methods


    【解决方案1】:

    他们如何在内部实现您能够将 x > 5 之类的内容传递给方法?

    简短的回答是他们没有。

    对 numpy 数组的任何类型的逻辑操作都会返回一个布尔数组。 (即__gt____lt__ 等都返回给定条件为真的布尔数组)。

    例如

    x = np.arange(9).reshape(3,3)
    print x > 5
    

    产量:

    array([[False, False, False],
           [False, False, False],
           [ True,  True,  True]], dtype=bool)
    

    如果x 是一个numpy 数组,这与if x > 5: 之类的东西会引发ValueError 的原因相同。它是一组真/假值,而不是单个值。

    此外,numpy 数组可以由布尔数组索引。例如。在这种情况下,x[x>5] 产生 [6 7 8]

    老实说,您实际上很少需要 numpy.where,但它只返回布尔数组为 True 的索引。通常你可以用简单的布尔索引来做你需要的事情。

    【讨论】:

    • 只是指出 numpy.where 确实有 2 种“操作模式”,第一个返回 indices,其中 condition is True 以及是否存在可选参数 xy (与condition 相同的形状,或可广播到这种形状!),当condition is True 时,它将返回来自x 的值,否则来自y。因此,这使where 更加通用,并使其可以更频繁地使用。谢谢
    • 在某些情况下,使用[]__getitem__ 语法超过numpy.wherenumpy.take 也会产生开销。由于__getitem__ 还必须支持切片,因此存在一些开销。在使用 Python Pandas 数据结构和对非常大的列进行逻辑索引时,我看到了明显的速度差异。在这些情况下,如果您不需要切片,那么 takewhere 实际上会更好。
    【解决方案2】:

    旧答案 这有点令人困惑。它为您提供了您的陈述正确的位置(所有这些位置)。

    所以:

    >>> a = np.arange(100)
    >>> np.where(a > 30)
    (array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
           48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
           65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
           82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
           99]),)
    >>> np.where(a == 90)
    (array([90]),)
    
    a = a*40
    >>> np.where(a > 1000)
    (array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
           43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
           60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
           77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
           94, 95, 96, 97, 98, 99]),)
    >>> a[25]
    1000
    >>> a[26]
    1040
    

    我将它用作 list.index() 的替代方法,但它也有许多其他用途。我从未将它与二维数组一起使用过。

    http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

    新答案 这个人似乎在问一些更根本的问题。

    问题是您如何实现某些功能,让函数(例如 where)知道请求的内容。

    首先请注意,调用任何比较运算符都会做一件有趣的事情。

    a > 1000
    array([False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`
    

    这是通过重载“__gt__”方法来完成的。例如:

    >>> class demo(object):
        def __gt__(self, item):
            print item
    
    
    >>> a = demo()
    >>> a > 4
    4
    

    如您所见,“a > 4”是有效代码。

    您可以在此处获取所有重载函数的完整列表和文档:http://docs.python.org/reference/datamodel.html

    令人难以置信的是,这样做是多么简单。 python中的所有操作都是以这种方式完成的。说 a > b 等价于 a.gt(b)!

    【讨论】:

    • 这种比较运算符重载似乎不适用于更复杂的逻辑表达式 - 例如我不能做 np.where(a > 30 and a < 50)np.where(30 < a < 50) 因为它最终试图评估逻辑 AND两个布尔数组,这是毫无意义的。有没有办法用np.where写这样的条件?
    • @meowsqueak np.where((a > 30) & (a < 50))
    • 为什么 np.where() 在您的示例中返回一个列表?
    【解决方案3】:

    np.where 返回一个长度等于调用它的 numpy ndarray 的维度的元组(换句话说,ndim),元组的每个项目都是初始中所有这些值的索引的 numpy ndarray条件为真的ndarray。 (请不要将尺寸与形状混淆)

    例如:

    x=np.arange(9).reshape(3,3)
    print(x)
    array([[0, 1, 2],
          [3, 4, 5],
          [6, 7, 8]])
    y = np.where(x>4)
    print(y)
    array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))
    


    y 是长度为 2 的元组,因为 x.ndim 是 2。元组中的第一项包含所有大于 4 的元素的行号,第二项包含所有大于 4 的项的列号。如您所见,[1, 2,2,2] 对应行号 5,6,7,8,[2,0,1,2] 对应列号 5,6,7,8 请注意,ndarray 是沿第一维(按行)遍历的。

    同样,

    x=np.arange(27).reshape(3,3,3)
    np.where(x>4)
    


    将返回一个长度为 3 的元组,因为 x 有 3 个维度。

    但是等等,np.where 还有更多内容!

    当两个额外的参数被添加到np.where;它将对上述元组获得的所有成对行列组合进行替换操作。

    x=np.arange(9).reshape(3,3)
    y = np.where(x>4, 1, 0)
    print(y)
    array([[0, 0, 0],
       [0, 0, 1],
       [1, 1, 1]])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2011-12-09
      • 2011-04-30
      • 2013-12-28
      • 2011-05-16
      • 2012-04-01
      • 2023-03-15
      相关资源
      最近更新 更多