【发布时间】:2021-01-22 05:59:25
【问题描述】:
我有一个 2D 张量,我想获得前 k 个值的索引。我知道pytorch's topk 功能。 pytorch 的 topk 函数的问题是,它计算某个维度上的 topk 值。我想获得两个维度的 topk 值。
例如对于以下张量
a = torch.tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
pytorch 的 topk 函数会给我以下。
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
# [0, 2, 1],
# [0, 1, 4],
# [1, 4, 3],
# [1, 0, 4]])
但我想得到以下内容
tensor([[0, 1],
[2, 0],
[3, 1]])
这是 9 在二维张量中的索引。
有什么方法可以使用 pytorch 实现这一点吗?
【问题讨论】:
-
您只是想要示例中的前 1 个,还是对 topk 也感兴趣?
标签: python pytorch tensor matrix-indexing