【发布时间】:2021-09-06 02:33:24
【问题描述】:
接口torch.max将返回值和索引,我如何使用索引从另一个张量中获取相应的元素? 例如:
a = torch.rand(2,3,4)
b = torch.rand(2,3,4)
# indices shape is [2, 4]
indices = torch.max(a, 1)[1]
# how to get elements by indices ?
b_max = ????
【问题讨论】:
接口torch.max将返回值和索引,我如何使用索引从另一个张量中获取相应的元素? 例如:
a = torch.rand(2,3,4)
b = torch.rand(2,3,4)
# indices shape is [2, 4]
indices = torch.max(a, 1)[1]
# how to get elements by indices ?
b_max = ????
【问题讨论】:
keepdim=True 在调用 torch.max() 和 torch.take_along_dim() 时应该可以解决问题。
>>> import torch
>>> a=torch.rand(2,3,4)
>>> b=torch.rand(2,3,4)
>>> indices=torch.max(a,1,keepdim=True)[1]
>>> b_max = torch.take_along_dim(b,indices,dim=1)
二维示例:
>>> a=torch.rand(2,3)
>>> a
tensor([[0.0163, 0.0711, 0.5564],
[0.4507, 0.8675, 0.5974]])
>>> b=torch.rand(2,3)
>>> b
tensor([[0.7542, 0.1793, 0.5399],
[0.2292, 0.5329, 0.2084]])
>>> indices=torch.max(a,1,keepdim=True)[1]
>>> torch.take_along_dim(b,indices,dim=1)
tensor([[0.5399],
[0.5329]])
【讨论】: