【发布时间】:2021-11-04 18:31:33
【问题描述】:
我是 PyTorch 的新手,我仍在思考如何形成正确的 gather 语句。我有一个大小为(1,200,61,1632) 的 4D 输入张量,其中1632 是时间维度。我想用张量idx 对其进行索引,它的大小为(4,1632),其中idx 的每一行都是我想从input 张量中提取的值。所以idx 的行看起来像:
[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...
这样输出的大小为1632。换句话说,我想这样做:
output = []
for i in range(1632):
output.append(input[idx[0,i], idx[1,i], idx[2,i], idx[3,i]])
这是否是 torch.gather 的合适用例?查看收集的文档,它说输入张量和索引张量必须具有相同的形状。
【问题讨论】:
标签: python indexing pytorch tensor