【发布时间】:2021-08-17 12:19:25
【问题描述】:
我想做一些类似 argmax 但有多个最高值的事情。我知道如何使用普通的 torch.argmax
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 1.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a)
tensor(0)
但现在我需要找到前 N 个值的索引。所以像这样的
>>> a = torch.randn(4, 4)
>>> a
tensor([[ 1.3398, 1.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[ 0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195]])
>>> torch.argmax(a,top_n=2)
tensor([0,1])
我在 pytorch 中没有找到任何能够做到这一点的函数,有人知道吗?
【问题讨论】:
标签: python machine-learning deep-learning pytorch