【问题标题】:Torch select index for mult dimensional vector多维向量的 Torch 选择索引
【发布时间】:2021-06-12 10:29:54
【问题描述】:

假设我有以下数据:

import torch
torch.manual_seed(42)
logits = torch.randn(2, 5, 3)
idx = torch.randint(0, 3, (2, 5))

我想做的是:

[[logits[i,j,idx[i][j]] for j in range(len(idx[i]))] for i in range(len(idx))]

但是,这显然是低效的。

我最接近的就是这样做,但这又看起来很丑:

new_idx = torch.stack([idx]*logits.shape[-1], dim=-1)
logits.gather(dim=-1, index=new_idx)

在上述情况下,所需的输出被复制 3 次。

就实际用例而言,我正在考虑一种语言模型,其中 logits 的形状为 (batch_size, sequence_len, vocabulary),而索引只是 (batch_size, sequence_len)。抱歉,如果之前有人问过这个问题,但我找不到任何东西。

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    same with this answer torch.gather(logits, -1, idx.unsqueeze(-1))

    >>> [[logits[i,j,idx[i][j]] for j in range(len(idx[i]))] for i in range(len(idx))]
    [[tensor(0.9007), tensor(0.6784), tensor(-0.0431), tensor(-1.4036), tensor(-0.7279)], [tensor(-0.2168), tensor(1.7174), tensor(-0.4245), tensor(0.9956), tensor(-1.2742)]]
    
    >>> torch.gather(logits, -1, idx.unsqueeze(-1))
    tensor([[[ 0.9007],
             [ 0.6784],
             [-0.0431],
             [-1.4036],
             [-0.7279]],
    
            [[-0.2168],
             [ 1.7174],
             [-0.4245],
             [ 0.9956],
             [-1.2742]]])
    

    【讨论】:

      猜你喜欢
      • 2021-06-15
      • 1970-01-01
      • 2016-05-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2012-07-20
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多