【问题标题】:Select elements of numpy array via mask and preserving original dimensions通过掩码选择numpy数组的元素并保留原始尺寸
【发布时间】:2021-01-18 00:15:25
【问题描述】:

您好,我有以下数据

ids = np.concatenate([1.0 * np.ones(shape=(4, 9,)), 
                      2.0 * np.ones(shape=(4, 3,))], axis=1)

logits = np.random.normal(size=(4, 9 + 3, 256))

现在我只想获取具有 1.0 的 id 的 numpy 数组,并且我想获取大小为 (4,9, 256) 的数组

我尝试了logits[ids == 1.0, :],但我得到了(36, 256) 如何在不连接前两个维度的情况下进行切片?

当前尺寸只是示例尺寸,我正在寻找通用解决方案。

【问题讨论】:

  • 您尝试做的事情是不可能的。但是您可以在条件不满意的地方填写值,这样它就不会影响您的进一步处理。参考stackoverflow.com/questions/29046162/…
  • 在您的示例中,您仅屏蔽了第二维。因此,您可以将掩码减少为一维数组:logits[:,(ids==1.0).all(axis=0),:]。由您来提供有关True 分布的附加信息,无论这意味着reshape 之后,还是事先修改掩码。

标签: python numpy


【解决方案1】:

您的问题似乎假设每一行都有相同数量的非零条目;在这种情况下,您通常可以像这样解决您的问题:

mask = (ids == 1)
num_per_row = mask.sum(1)

# same number of entries per row is required
assert np.all(num_per_row == num_per_row[0])  

result = logits[mask].reshape(logits.shape[0], num_per_row[0], logits.shape[2])

print(result.shape)
# (4, 9, 256)

【讨论】:

    猜你喜欢
    • 2013-11-27
    • 1970-01-01
    • 2018-10-30
    • 1970-01-01
    • 2016-03-28
    • 2011-02-08
    • 2016-07-31
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多