【发布时间】:2021-01-18 00:15:25
【问题描述】:
您好,我有以下数据
ids = np.concatenate([1.0 * np.ones(shape=(4, 9,)),
2.0 * np.ones(shape=(4, 3,))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))
现在我只想获取具有 1.0 的 id 的 numpy 数组,并且我想获取大小为 (4,9, 256) 的数组
我尝试了logits[ids == 1.0, :],但我得到了(36, 256)
如何在不连接前两个维度的情况下进行切片?
当前尺寸只是示例尺寸,我正在寻找通用解决方案。
【问题讨论】:
-
您尝试做的事情是不可能的。但是您可以在条件不满意的地方填写值,这样它就不会影响您的进一步处理。参考stackoverflow.com/questions/29046162/…
-
在您的示例中,您仅屏蔽了第二维。因此,您可以将掩码减少为一维数组:
logits[:,(ids==1.0).all(axis=0),:]。由您来提供有关True分布的附加信息,无论这意味着reshape之后,还是事先修改掩码。