【问题标题】:Pytorch: How to index a tensor?Pytorch:如何索引张量?
【发布时间】:2021-11-04 18:31:33
【问题描述】:

我是 PyTorch 的新手,我仍在思考如何形成正确的 gather 语句。我有一个大小为(1,200,61,1632) 的 4D 输入张量,其中1632 是时间维度。我想用张量idx 对其进行索引,它的大小为(4,1632),其中idx 的每一行都是我想从input 张量中提取的值。所以idx 的行看起来像:

[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...

这样输出的大小为1632。换句话说,我想这样做:

output = []
for i in range(1632):
  output.append(input[idx[0,i], idx[1,i], idx[2,i], idx[3,i]])

这是否是 torch.gather 的合适用例?查看收集的文档,它说输入张量和索引张量必须具有相同的形状。

【问题讨论】:

    标签: python indexing pytorch tensor


    【解决方案1】:

    既然 PyTorch doesn't offerravel_multi_index 的实现,那么丑陋的做法就是这样:

    output = input[idx[0, :], idx[1, :], idx[2, :], idx[3, :]]
    

    在 NumPy 中,您可以这样做:

    output = np.take(input, np.ravel_multi_index(idx, input.shape))
    

    【讨论】:

    • @Ambrose 没问题 :) 玩得开心
    猜你喜欢
    • 2020-03-28
    • 2019-11-26
    • 2020-07-20
    • 2019-09-16
    • 2019-01-13
    • 2019-08-26
    • 2019-09-15
    • 1970-01-01
    • 2020-09-26
    相关资源
    最近更新 更多