【发布时间】:2021-02-11 21:07:21
【问题描述】:
如何索引具有 n 维的张量 t 和 m index 张量,以便保留 t 的最后一个维度?对于维度 m 之前的所有维度,index 张量的形状等于张量 t。或者换句话说,我想索引张量的中间维度,同时保留所选索引的所有以下维度。
例如,假设我们有两个张量:
t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
与 t:
tensor([[[ 15.2165, -7.9702],
[ 0.6646, 5.2844],
[-22.0657, -5.9876],
[ -9.7319, 11.7384],
[ 4.3985, -6.7058]],
[[-15.6854, -11.9362],
[ 11.3054, 3.3068],
[ -4.7756, -7.4524],
[ 5.0977, -17.3831],
[ 3.9152, -11.5047]],
[[ -5.4265, -22.6456],
[ 1.6639, 10.1483],
[ 13.2129, 3.7850],
[ 3.8543, -4.3496],
[ -8.7577, -12.9722]]])
然后我想要的输出将具有形状 (3, 2, 2) 并且是:
tensor([[[ 0.6646, 5.2844],
[ -9.7319, 11.7384]],
[[-15.6854, -11.9362],
[ 3.9152, -11.5047]],
[[ 3.8543, -4.3496],
[ 13.2129, 3.7850]]])
另一个例子是我有一个形状为(40, 10, 6, 2) 的张量t 和一个形状为(40, 10, 3) 的索引张量。这应该查询张量 t 的维度 3,并且预期的输出形状将是 (40, 10, 3, 2)。
如何在不使用循环的情况下以通用方式实现这一点?
【问题讨论】:
标签: python indexing pytorch tensor