【问题标题】:Top K indices of a multi-dimensional tensor多维张量的前 K 个索引
【发布时间】:2021-01-22 05:59:25
【问题描述】:

我有一个 2D 张量,我想获得前 k 个值的索引。我知道pytorch's topk 功能。 pytorch 的 topk 函数的问题是,它计算某个维度上的 topk 值。我想获得两个维度的 topk 值。

例如对于以下张量

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

pytorch 的 topk 函数会给我以下。

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

但我想得到以下内容

tensor([[0, 1],
        [2, 0],
        [3, 1]])

这是 9 在二维张量中的索引。

有什么方法可以使用 pytorch 实现这一点吗?

【问题讨论】:

  • 您只是想要示例中的前 1 个,还是对 topk 也感兴趣?

标签: python pytorch tensor matrix-indexing


【解决方案1】:
v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)

输出:

[[3 1]
 [2 0]
 [0 1]]
  1. 展平并找到top k
  2. 使用unravel_index将一维索引转换为二维

【讨论】:

    【解决方案2】:

    您可以根据需要进行一些向量操作进行过滤。在这种情况下不使用 topk。

    print(a)
    tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])
    
    values, indices = torch.max(a,1)   # get max values, indices
    temp= torch.zeros_like(values)     # temporary
    temp[values==9]=1                  # fill temp where values are 9 (wished value)
    seq=torch.arange(values.shape[0])  # create a helper sequence
    new_seq=seq[temp>0]                # filter sequence where values are 9
    new_temp=indices[new_seq]          # filter indices with sequence where values are 9
    final = torch.stack([new_seq, new_temp], dim=1)  # stack both to get result
    
    print(final)
    tensor([[0, 1],
            [2, 0],
            [3, 1]])
    

    【讨论】:

      【解决方案3】:

      您可以flatten 原始张量,应用topk,然后将结果标量索引转换回多维索引,如下所示:

      def descalarization(idx, shape):
          res = []
          N = np.prod(shape)
          for n in shape:
              N //= n
              res.append(idx // N)
              idx %= N
          return tuple(res)
      

      例子:

      torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
      # Returns 
      # tensor([[3, 1],
      #         [2, 0],
      #         [0, 1],
      #         [3, 4],
      #         [2, 4]])
      

      【讨论】:

        猜你喜欢
        • 2016-04-17
        • 2020-09-23
        • 2020-10-26
        • 2018-03-03
        • 2019-09-15
        • 2021-11-25
        • 2019-02-05
        • 2020-09-26
        • 1970-01-01
        相关资源
        最近更新 更多