你可以使用advanced-indexing -
a[np.arange(a.shape[0])[:,None], sortidxs]
示例运行 -
In [144]: a = np.random.randint(0,9,(2,3,4))
In [145]: a
Out[145]:
array([[[1, 1, 5, 5],
[1, 1, 7, 5],
[6, 1, 2, 8]],
[[7, 2, 5, 4],
[3, 7, 3, 7],
[8, 4, 4, 6]]])
In [146]: sortidxs = np.argsort(np.linalg.norm(a, axis=-1))
In [147]: np.array([a[_][sortidxs[_]] for _ in range(a.shape[0])])
Out[147]:
array([[[1, 1, 5, 5],
[1, 1, 7, 5],
[6, 1, 2, 8]],
[[7, 2, 5, 4],
[3, 7, 3, 7],
[8, 4, 4, 6]]])
In [149]: a[np.arange(a.shape[0])[:,None], sortidxs]
Out[149]:
array([[[1, 1, 5, 5],
[1, 1, 7, 5],
[6, 1, 2, 8]],
[[7, 2, 5, 4],
[3, 7, 3, 7],
[8, 4, 4, 6]]])
进一步提升性能
我们可以优化计算 sortidxs 和 np.einsum -
sortidxs = np.einsum('ijk,ijk->ij',a,a).argsort()
让我们计时并验证这个想法 -
In [94]: a = np.random.randint(0,9,(20,30,40))
In [95]: %timeit np.argsort(np.linalg.norm(a, axis=-1))
10000 loops, best of 3: 63.5 µs per loop
In [96]: %timeit np.einsum('ijk,ijk->ij',a,a).argsort()
10000 loops, best of 3: 19.7 µs per loop
In [97]: a = np.random.randint(0,9,(200,300,400))
In [98]: %timeit np.argsort(np.linalg.norm(a, axis=-1))
10 loops, best of 3: 88.6 ms per loop
In [99]: %timeit np.einsum('ijk,ijk->ij',a,a).argsort()
10 loops, best of 3: 22.6 ms per loop
更高维度的数组
对于a 是4D 数组的额外情况,我们需要使用更多数组进行索引。
1] 对于第一个轴:使用np.arange(a.shape[0]),最后有两个新轴。
2] 对于第二个轴:使用np.arange(a.shape[0]),最后一个新轴。
3] 对于第三个轴:使用sortidxs 对其进行索引。
因此,我们会:
m,n,r,s = a.shape
out = a[np.arange(m)[:,None,None],np.arange(n)[:,None], sortidxs]
单例暗淡的数组(长度为 1 的暗淡)
作为一种特殊情况,假设输入数组的第二个轴已经是一个单轴,我们可以简单地使用 0 作为那个轴,从而简化事情,就像这样 -
a[np.arange(m)[:,None,None],0, sortidxs]
示例运行 -
In [58]: a = np.array([[[3, 4],
...: [1, 2]],
...:
...: [[5, 6],
...: [7, 8]]])
...:
...: a = a.reshape((2,1,2,2))
...:
In [59]: sortidxs = np.argsort(np.linalg.norm(a, axis=-1))
In [60]: a[np.arange(a.shape[0])[:,None,None],0, sortidxs]
Out[60]:
array([[[[1, 2],
[3, 4]]],
[[[5, 6],
[7, 8]]]])
为具有(2,3,4) 的通用形状的数组运行另一个示例,以使事情变得非常清楚 -
In [70]: a = np.random.randint(0,9,(2,1,3,4))
In [71]: a
Out[71]:
array([[[[6, 4, 8, 6],
[4, 0, 1, 0],
[5, 3, 2, 5]]],
[[[3, 6, 0, 4],
[6, 2, 5, 2],
[0, 8, 0, 8]]]])
In [72]: sortidxs = np.argsort(np.linalg.norm(a, axis=-1))
In [73]: sortidxs
Out[73]:
array([[[1, 2, 0]],
[[0, 1, 2]]])
In [74]: a[np.arange(a.shape[0])[:,None,None],0, sortidxs]
Out[74]:
array([[[[4, 0, 1, 0],
[5, 3, 2, 5],
[6, 4, 8, 6]]],
[[[3, 6, 0, 4],
[6, 2, 5, 2],
[0, 8, 0, 8]]]])