【问题标题】:PyTorch tensor advanced indexingPyTorch 张量高级索引
【发布时间】:2020-07-20 14:51:10
【问题描述】:

假设我有一个矩阵和一个向量如下:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

有没有办法对它进行切片x[y] 所以结果是:

res = [1, 6, 8]

所以基本上我取y 的第一个元素并取x 中与第一行和元素列相对应的元素。

干杯

【问题讨论】:

  • x 的定义中添加了括号,不确定是这里的错字还是您的代码中的错字。
  • 打错字了,谢谢

标签: python numpy pytorch


【解决方案1】:

可以指定对应的行索引为:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])

【讨论】:

  • 感谢您的快速回复,我收到以下错误:TypeError:只有单个元素的整数张量可以转换为索引
  • 我能够毫无问题地重现这个答案。请修改您的问题@yarin
  • 对不起,我不小心打错了括号,现在它可以工作了
【解决方案2】:

pytorch 中的高级索引与NumPy's 一样工作,即索引数组在轴上一起广播。所以你可以像 FBruzzesi 的回答那样做。

虽然类似于np.take_along_axis,但在pytorch 中你也有torch.gather,用于沿特定轴取值:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])

【讨论】:

    猜你喜欢
    • 2018-06-08
    • 2020-09-26
    • 2020-03-28
    • 2019-11-26
    • 2021-11-04
    • 2019-09-15
    • 1970-01-01
    • 2021-06-27
    • 2020-12-06
    相关资源
    最近更新 更多