我。沿最后一个轴(行)的 Ndim 数组掩码
对于沿行屏蔽的 n-dim 数组,我们可以这样做 -
def mask_from_start_indices(a, mask_indices):
r = np.arange(a.shape[-1])
return mask_indices[...,None]<=r
示例运行 -
In [177]: np.random.seed(0)
...: a = np.random.randint(10, size=(2, 2, 5))
...: mask_indices = np.argmax(a, axis=-1)
In [178]: a
Out[178]:
array([[[5, 0, 3, 3, 7],
[9, 3, 5, 2, 4]],
[[7, 6, 8, 8, 1],
[6, 7, 7, 8, 1]]])
In [179]: mask_indices
Out[179]:
array([[4, 0],
[2, 3]])
In [180]: mask_from_start_indices(a, mask_indices)
Out[180]:
array([[[False, False, False, False, True],
[ True, True, True, True, True]],
[[False, False, True, True, True],
[False, False, False, True, True]]])
二。沿通用轴的 Ndim 数组掩码
对于沿通用轴屏蔽的 n-dim 数组,它将是 -
def mask_from_start_indices_genericaxis(a, mask_indices, axis):
r = np.arange(a.shape[axis]).reshape((-1,)+(1,)*(a.ndim-axis-1))
mask_indices_nd = mask_indices.reshape(np.insert(mask_indices.shape,axis,1))
return mask_indices_nd<=r
示例运行 -
数据数组设置:
In [288]: np.random.seed(0)
...: a = np.random.randint(10, size=(2, 3, 5))
In [289]: a
Out[289]:
array([[[5, 0, 3, 3, 7],
[9, 3, 5, 2, 4],
[7, 6, 8, 8, 1]],
[[6, 7, 7, 8, 1],
[5, 9, 8, 9, 4],
[3, 0, 3, 5, 0]]])
沿axis=1 设置和屏蔽索引-
In [290]: mask_indices = np.argmax(a, axis=1)
In [291]: mask_indices
Out[291]:
array([[1, 2, 2, 2, 0],
[0, 1, 1, 1, 1]])
In [292]: mask_from_start_indices_genericaxis(a, mask_indices, axis=1)
Out[292]:
array([[[False, False, False, False, True],
[ True, False, False, False, True],
[ True, True, True, True, True]],
[[ True, False, False, False, False],
[ True, True, True, True, True],
[ True, True, True, True, True]]])
沿axis=2 设置和屏蔽索引-
In [293]: mask_indices = np.argmax(a, axis=2)
In [294]: mask_indices
Out[294]:
array([[4, 0, 2],
[3, 1, 3]])
In [295]: mask_from_start_indices_genericaxis(a, mask_indices, axis=2)
Out[295]:
array([[[False, False, False, False, True],
[ True, True, True, True, True],
[False, False, True, True, True]],
[[False, False, False, True, True],
[False, True, True, True, True],
[False, False, False, True, True]]])
其他场景
A.扩展到给定的结束/停止索引以进行屏蔽
当我们获得结束/停止索引以进行屏蔽时,要扩展解决方案,即我们正在寻找矢量化 mask[r, :m] = True,我们只需将发布的解决方案中的最后一个比较步骤编辑为以下 -
return mask_indices_nd>r
B.输出整数数组
在某些情况下,我们可能希望获得一个 int 数组。在那些上,简单地查看输出。因此,如果out 是已发布解决方案的输出,那么我们可以简单地为int8 和uint8 dtype 输出分别执行out.view('i1') 或out.view('u1')。
对于其他数据类型,我们需要使用.astype() 进行数据类型转换。
C.用于停止索引的包含索引的屏蔽
对于包含索引的掩码,即索引将包含在停止索引的情况下,我们需要简单地在比较中包含相等性。因此,最后一步是 -
return mask_indices_nd>=r
D.对于起始索引的索引排他屏蔽
这是一种情况,当给定开始索引并且这些索引不被屏蔽,但仅从下一个元素开始直到结束时才被屏蔽。因此,类似于上一节中列出的推理,对于这种情况,我们将最后一步修改为 -
return mask_indices_nd<r