【问题标题】:how to get max value by the indices return by torch.max?如何通过torch.max返回的索引获取最大值?
【发布时间】:2021-09-06 02:33:24
【问题描述】:

接口torch.max将返回值和索引,我如何使用索引从另一个张量中获取相应的元素? 例如:

a = torch.rand(2,3,4)
b = torch.rand(2,3,4)
# indices shape is [2, 4]
indices = torch.max(a, 1)[1]
# how to get elements by indices ?
b_max = ????

【问题讨论】:

    标签: python pytorch max tensor


    【解决方案1】:

    keepdim=True 在调用 torch.max()torch.take_along_dim() 时应该可以解决问题。

    >>> import torch
    >>> a=torch.rand(2,3,4)
    >>> b=torch.rand(2,3,4)
    >>> indices=torch.max(a,1,keepdim=True)[1]
    >>> b_max = torch.take_along_dim(b,indices,dim=1)
    

    二维示例:

    >>> a=torch.rand(2,3)
    >>> a
    tensor([[0.0163, 0.0711, 0.5564],
            [0.4507, 0.8675, 0.5974]])
    >>> b=torch.rand(2,3)
    >>> b
    tensor([[0.7542, 0.1793, 0.5399],
            [0.2292, 0.5329, 0.2084]])
    >>> indices=torch.max(a,1,keepdim=True)[1]
    >>> torch.take_along_dim(b,indices,dim=1)
    tensor([[0.5399],
            [0.5329]])
    

    【讨论】:

      猜你喜欢
      • 2019-09-11
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-06-25
      • 1970-01-01
      • 2012-07-03
      • 1970-01-01
      相关资源
      最近更新 更多