【问题标题】:Extract fixed number of elements per row in numpy array在numpy数组中每行提取固定数量的元素
【发布时间】:2021-05-09 01:51:19
【问题描述】:

假设我有一个数组a和一个布尔数组b,我想从a的每一行的有效元素中提取固定数量的元素。有效元素是b 指示的元素。

这是一个例子:

a = np.arange(24).reshape(4,6)
b = np.array([[0,0,1,1,0,0],[0,1,0,1,0,1],[0,1,1,1,1,0],[0,0,0,0,1,1]]).astype(bool)
x = []
for i in range(a.shape[0]):
    c = a[i,b[i]]
    d = np.random.choice(c, 2)
    x.append(d)

这里我使用了一个 for 循环,如果这些数组又大又高维,它会很慢。有没有更有效的方法来做到这一点?谢谢。

【问题讨论】:

    标签: arrays numpy numpy-ndarray array-broadcasting numpy-ufunc


    【解决方案1】:
    1. 生成形状为 a 的随机均匀 [0, 1] 矩阵。
    2. 将此矩阵乘以掩码 b 以将无效元素设置为零。
    3. 从每行中选择k 的最大索引(仅从该行中的有效元素模拟无偏随机k-sample)。
    4. (可选)使用这些索引来获取元素。
    a = np.arange(24).reshape(4,6)
    b = np.array([[0,0,1,1,0,0],[0,1,0,1,0,1],[0,1,1,1,1,0],[0,0,0,0,1,1]])
    k = 2
    
    r = np.random.uniform(size=a.shape)
    indices = np.argpartition(-r * b, k)[:,:k]
    

    从索引中获取元素:

    >>> indices
    array([[3, 2],
           [5, 1],
           [3, 2],
           [4, 5]])
    >>> a[np.arange(a.shape[0])[:,None], indices]
    array([[ 3,  2],
           [11,  7],
           [15, 14],
           [22, 23]])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-08-26
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多