【问题标题】:How to select indices according to another tensor in pytorch如何根据pytorch中的另一个张量选择索引
【发布时间】:2021-04-25 04:18:35
【问题描述】:

任务看起来很简单,但我不知道该怎么做。

所以我有两个张量:

  • 形状为(2, 5, 2) 的索引张量indices,其中最后一个维度对应于x 和y 维度中的索引
  • 形状为(2, 5, 2, 16, 16) 的“值张量”value,我希望使用 x 和 y 索引选择最后两个维度

更具体地说,索引在 0 到 15 之间,我想得到一个输出:

out = value[:, :, :, x_indices, y_indices]

因此输出的形状应该是(2, 5, 2)。有人可以在这里帮助我吗?非常感谢!

编辑:

我尝试了用聚集的建议,但不幸的是它似乎不起作用(我改变了尺寸,但没关系):

首先我生成一个坐标网格:

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(1, 3, 1, 1, 1)

下一步,我将创建一些索引。在这种情况下,我总是取索引 1:

indices = torch.ones([1, 3, 2], dtype=torch.int64)

接下来,我用的是你的方法:

indices = indices.unsqueeze(-1).unsqueeze(-1)
new_coords = torch.gather(grid, -1, indices).squeeze(-1).squeeze(-1)

最后,我手动为 x 和 y 坐标选择索引 1:

new_coords_manual = grid[:, :, :, 1, 1]

这会输出以下新坐标:

new_coords
tensor([[[-1.0000, -0.8667],
         [-1.0000, -0.8667],
         [-1.0000, -0.8667]]])

new_coords_manual
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

如您所见,它仅适用于一维。你知道如何解决这个问题吗?

【问题讨论】:

  • 您能否展示一个最小的示例 indicesvalue 以及所需的输出?
  • 产生new_coords_manual时达到想要的输出

标签: python pytorch indices


【解决方案1】:

您可以做的是将前三个轴展平并应用torch.gather

>>> grid.flatten(start_dim=0, end_dim=2).shape
torch.Size([6, 16, 16])

>>> torch.gather(grid.flatten(0, 2), axis=1, indices)
tensor([[[-0.8667, -0.8667],
         [-0.8667, -0.8667],
         [-0.8667, -0.8667]]])

如文档页面所述,这将执行:

out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1

【讨论】:

  • 感谢您的帮助!这确实适用于批量大小为 1,但它似乎面临批量大小 > 1 的相同问题:/ 我还尝试将问题拆分为 x 和 y 坐标并应用indices_y = indices[:, :, 0].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),然后应用new_y = torch.gather(grid, 3, indices_y).squeeze(-1).squeeze(-1)。在对 x 值执行相同操作并在 dim=4 上收集之后,我连接了张量。但我从你的第一个建议中得到了确切的结果。
【解决方案2】:

我想通了,再次感谢@Ivan 的帮助! :)

问题是,我在最后一个维度上没有挤压,而我应该在中间维度上没有挤压,所以索引在最后:

y_t = torch.linspace(-1., 1., 16, device='cpu').reshape(16, 1).repeat(1, 16).unsqueeze(-1)
x_t = torch.linspace(-1., 1., 16, device='cpu').reshape(1, 16).repeat(16, 1).unsqueeze(-1)
grid = torch.cat((y_t, x_t), dim=-1).permute(2, 0, 1).unsqueeze(0)
grid = grid.unsqueeze(1).repeat(2, 3, 1, 1, 1)

indices = torch.ones([2, 3, 2], dtype=torch.int64).unsqueeze(-2).unsqueeze(-2)
new_coords = torch.gather(grid, 3, indices).squeeze(-2).squeeze(-2)

new_coords_manual = grid[:, :, :, 1, 1]

现在new_coords 等于new_coords_manual

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2021-04-11
    • 1970-01-01
    • 2019-08-18
    • 1970-01-01
    • 1970-01-01
    • 2021-11-04
    • 2019-06-13
    • 1970-01-01
    相关资源
    最近更新 更多