【发布时间】:2021-06-12 10:29:54
【问题描述】:
假设我有以下数据:
import torch
torch.manual_seed(42)
logits = torch.randn(2, 5, 3)
idx = torch.randint(0, 3, (2, 5))
我想做的是:
[[logits[i,j,idx[i][j]] for j in range(len(idx[i]))] for i in range(len(idx))]
但是,这显然是低效的。
我最接近的就是这样做,但这又看起来很丑:
new_idx = torch.stack([idx]*logits.shape[-1], dim=-1)
logits.gather(dim=-1, index=new_idx)
在上述情况下,所需的输出被复制 3 次。
就实际用例而言,我正在考虑一种语言模型,其中 logits 的形状为 (batch_size, sequence_len, vocabulary),而索引只是 (batch_size, sequence_len)。抱歉,如果之前有人问过这个问题,但我找不到任何东西。
【问题讨论】:
标签: pytorch