【问题标题】:How to apply condition to indices of numpy array efficiently?如何有效地将条件应用于 numpy 数组的索引?
【发布时间】:2020-07-27 09:06:29
【问题描述】:

我有一个 2D NumPy 数组,我想为数组设置值,前提是它的索引满足特定条件。

我可以使用for 循环来做到这一点:

import numpy as np

new_a = np.ones((5,10), dtype=np.float32)
for i in range(new_a.shape[0]):
    for j in range(new_a.shape[1]):
        if (np.nan_to_num(i/np.nan_to_num(j))) >= new_a.shape[0]/new_a.shape[1]:               #(This is the condition, which I may change a little as needed)
            new_a[i, j] = 0
            
print(new_a)

''' Output:                                                                                    (This gives a upper triangular matrix)
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
'''

但问题是我想为大约(10000, 20000) 的大型二维数组执行此操作。所以for 循环会非常慢。 如何使用 NumPy 或任何其他方法有效地(在更短的时间内)做到这一点 图书馆?

注意:我不希望创建对角矩阵的解决方案(因为我想将代码应用于许多不同的条件)。我正在寻找“有效地将条件应用于 numpy 数组的索引”的解决方案(一种比使用 for 循环更快的方法)。

【问题讨论】:

  • 不,它只适用于矩形阵列。另外,我想根据不同的问题改变条件。如果您告诉有关“如何有效地将条件应用于 numpy 数组的索引?”的方法将会很有帮助?

标签: python arrays numpy matrix


【解决方案1】:

你可以像这样得到数组索引:

import numpy as np

new_a = np.ones((5,10), dtype=np.float32)
indices = np.indices(new_a.shape)
y_indices = indices[0]
x_indices = indices[1]

要获得特定比较成立的索引,您可以:

locations = np.nan_to_num(indices[0] / indices[1]) >= new_a.shape[0] / new_a.shape[1]

要应用它,只需:

new_a[locations] = 0
print(new_a)

返回

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]

【讨论】:

    【解决方案2】:

    您可以通过使用np.meshgrid 对操作进行矢量化来提高速度(支付额外的内存消耗):

    xv, yv = np.meshgrid(np.arange(new_a.shape[1]), np.arange(new_a.shape[0]))
    idx = np.nan_to_num(yv/xv) >= new_a.shape[0]/new_a.shape[1]
    new_a[idx] = 0
    print(new_a)
    

    打印

    >>> new_a
    array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
           [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
           [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)
    

    【讨论】:

    • 这会产生错误:布尔索引与维度 0 上的索引数组不匹配;维度为 1000,但对应的布尔维度为 1。在 new_a[idx]=0 行上
    • 当我尝试:new_a = np.ones((1000,2000), dtype=np.float32)
    • 是的,我在删除硬编码值时不小心删除了 np.arange()。现在已经修复了
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-01-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-09-23
    相关资源
    最近更新 更多