【发布时间】:2017-05-23 19:26:32
【问题描述】:
我想我误解了 numpy 中的索引。
我有一个形状为 (dim_x, dim_y, dim_z) 的 3D-numpy 数组,我想找到沿第三个轴 (dim_z) 的最大值,并将其值设置为 1 并将所有其他值设置为零。
问题是我最终在同一行中有几个 1,即使值不同。
代码如下:
>>> test = np.random.rand(2,3,2)
>>> test
array([[[ 0.13110146, 0.07138861],
[ 0.84444158, 0.35296986],
[ 0.97414498, 0.63728852]],
[[ 0.61301975, 0.02313646],
[ 0.14251848, 0.91090492],
[ 0.14217992, 0.41549218]]])
>>> result = np.zeros_like(test)
>>> result[:test.shape[0], np.arange(test.shape[1]), np.argmax(test, axis=2)]=1
>>> result
array([[[ 1., 0.],
[ 1., 1.],
[ 1., 1.]],
[[ 1., 0.],
[ 1., 1.],
[ 1., 1.]]])
我希望以 :
结尾array([[[ 1., 0.],
[ 1., 0.],
[ 1., 0.]],
[[ 1., 0.],
[ 0., 1.],
[ 0., 1.]]])
我可能在这里遗漏了一些东西。据我了解,0:dim_x, np.arange(dim_y) 返回dim_y 元组中的dim_x,np.argmax(test, axis=dim_z) 的形状为(dim_x, dim_y),因此如果索引的形式为[x, y, z],则不应该出现一对[x, y]两次。
谁能解释我哪里错了?提前致谢。
【问题讨论】:
标签: python numpy multidimensional-array