【发布时间】:2021-04-23 22:54:45
【问题描述】:
我有一长串存储在 1d 张量中的值,我喜欢从中收集值。给定一个二维张量,其中第二维表示一组索引,我想计算一个新的二维张量,其中第二维是索引给定的所有值。基本上,我有一批索引,我使用这些索引从一维张量中获取一批值。
我是 pytorch 的新手,但我设法使用 pytorch 的收集函数准确计算了这一点
list_values = torch.tensor([1, 2, 3, 4, 5, 6])
list_values = list_values.unsqueeze(0)
list_values = list_values.expand((2, 6))
indices = torch.tensor([[1, 2], [2, 3]])
result = torch.gather(list_values, 1, indices)
这非常有效,并给出了正确的结果。但是,如果我没记错的话,随着“list_value”中元素数量的增加,展开操作在内存方面是相当昂贵的。
有没有更好的解决方案?
【问题讨论】: