[Question title]: How to get top-k elements of each row in a 2D tensor?
[Posted]: 2020-03-10 08:48:39
[Question]:

How can I get the top-k elements of each row of a 2D tensor in an elegant way, instead of using a for loop like the one below?

import torch

elements = torch.rand(5,10)
topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
index_list = [] # record the topk index in elements

for i in range(5):
    index_list.append(elements[i].topk(topk_list[i]))

[Discussion]:

    Tags: python pytorch tensor


    [Answer 1]:

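    Whether something is elegant is always debatable. Hard-coding the range in the for loop can certainly be improved: at the very least you can use range(len(topk_list)), so the same code works with a different top-k list.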
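For reference, the `range(len(topk_list))` version mentioned above would look roughly like this. It is a minimal sketch; the input is random, so the values differ on every run:

```python
import torch

elements = torch.rand(5, 10)
topk_list = [2, 3, 1, 2, 0]  # k for each row
index_list = []  # one topk result (values, indices) per row

# Loop over as many rows as there are k values, instead of a hard-coded range(5)
for i in range(len(topk_list)):
    index_list.append(elements[i].topk(topk_list[i]))
```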

    You can improve it further like this:

    for i, n in enumerate(topk_list): 
        index_list.append(elements[i].topk(n))
    

    Or even:

    index_list = [elements[i].topk(n) for i, n in enumerate(topk_list)]
    

    But this is just syntactic sugar.
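A small sketch of consuming the comprehension's result: each `topk` call returns a named tuple with `values` and `indices` fields, so the per-row results can be unpacked into two parallel lists, and a row with k = 0 simply yields empty tensors. The names `vals_per_row` and `inds_per_row` are illustrative:

```python
import torch

elements = torch.rand(5, 10)
topk_list = [2, 3, 1, 2, 0]

# Each topk call returns a named tuple with .values and .indices
index_list = [elements[i].topk(n) for i, n in enumerate(topk_list)]

# Unpack into parallel lists of per-row values and per-row column indices
vals_per_row = [res.values for res in index_list]
inds_per_row = [res.indices for res in index_list]
```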

    [Discussion]:

      [Answer 2]:

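      If your k values don't differ much from row to row and you want to vectorize the code, you can first compute the top-k for the largest k in every row (one batched topk call) and then keep only the entries each row actually needs. Rows with a smaller k waste some work this way, which is why it pays off mainly when the k values are similar.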

      # Code from OP
      import torch
      
      elements = torch.rand(5,10)
      topk_list = [2,3,1,2,0] # means top2 for 1st row, top3 for 2nd row, top1 for 3rd row,....
      index_list = [] # record the topk index in elements
      
      for i in range(5):
          index_list.append(elements[i].topk(topk_list[i]))
      
      # Print the result
      print(index_list)
      
      # Get topk for max_k
      max_k = max(topk_list)
      topk_vals, topk_inds = elements.topk(max_k, dim=-1)
      
      # Select desired topk using mask
      mask = torch.arange(max_k)[None, :] < torch.tensor(topk_list)[:, None]
      vals, inds = topk_vals[mask], topk_inds[mask]
      rows, _ = mask.nonzero().T
      print("-" * 10)
      print("rows", rows)
      print("inds", inds)
      print("vals", vals)
      
      # Or split
      vals_per_row = vals.split(topk_list)
      inds_per_row = inds.split(topk_list)
      print("-" * 10)
      print("vals_per_row", vals_per_row)
      print("inds_per_row", inds_per_row)
      
      # Or zip (for loop but should be cheap)
      index_list = zip(vals_per_row, inds_per_row)
      print("-" * 10)
      print("zipped results", list(index_list))
      

      This gives the following output:

      [torch.return_types.topk(
      values=tensor([0.8148, 0.7443]),
      indices=tensor([8, 4])), torch.return_types.topk(
      values=tensor([0.7529, 0.7352, 0.6354]),
      indices=tensor([8, 1, 9])), torch.return_types.topk(
      values=tensor([0.8792]),
      indices=tensor([7])), torch.return_types.topk(
      values=tensor([0.9626, 0.8728]),
      indices=tensor([6, 2])), torch.return_types.topk(
      values=tensor([]),
      indices=tensor([], dtype=torch.int64))]
      ----------
      rows tensor([0, 0, 1, 1, 1, 2, 3, 3])
      inds tensor([8, 4, 8, 1, 9, 7, 6, 2])
      vals tensor([0.8148, 0.7443, 0.7529, 0.7352, 0.6354, 0.8792, 0.9626, 0.8728])
      ----------
      vals_per_row (tensor([0.8148, 0.7443]), tensor([0.7529, 0.7352, 0.6354]), tensor([0.8792]), tensor([0.9626, 0.8728]), tensor([]))
      inds_per_row (tensor([8, 4]), tensor([8, 1, 9]), tensor([7]), tensor([6, 2]), tensor([], dtype=torch.int64))
      ----------
      zipped results [(tensor([0.8148, 0.7443]), tensor([8, 4])), (tensor([0.7529, 0.7352, 0.6354]), tensor([8, 1, 9])), (tensor([0.8792]), tensor([7])), (tensor([0.9626, 0.8728]), tensor([6, 2])), (tensor([]), tensor([], dtype=torch.int64))]
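The same steps can be wrapped in a reusable function and cross-checked against the per-row loop from the question. A minimal sketch; `ragged_topk` is a hypothetical name, and the check assumes no tied values within a row (practically guaranteed with `torch.rand`), since tie-breaking could otherwise change which index is reported:

```python
import torch


def ragged_topk(x, ks):
    """Per-row top-k where row r uses its own k = ks[r] (hypothetical helper).

    Returns flat tensors (rows, inds, vals), ordered row by row.
    """
    ks_t = torch.as_tensor(ks, device=x.device)
    max_k = int(ks_t.max())
    # One batched topk over every row, using the largest requested k
    topk_vals, topk_inds = x.topk(max_k, dim=-1)
    # mask[r, j] is True when slot j belongs to row r's own top-k (j < ks[r])
    mask = torch.arange(max_k, device=x.device)[None, :] < ks_t[:, None]
    rows = mask.nonzero()[:, 0]
    return rows, topk_inds[mask], topk_vals[mask]


elements = torch.rand(5, 10)
topk_list = [2, 3, 1, 2, 0]
rows, inds, vals = ragged_topk(elements, topk_list)

# Cross-check against the per-row loop from the question
loop = [elements[i].topk(k) for i, k in enumerate(topk_list)]
assert torch.equal(vals, torch.cat([r.values for r in loop]))
assert torch.equal(inds, torch.cat([r.indices for r in loop]))
print(rows)  # tensor([0, 0, 1, 1, 1, 2, 3, 3])
```

Creating `ks_t` and the `arange` on `x.device` keeps the mask on the same device as the `topk` results, so the function should also work for CUDA tensors.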
      

      [Discussion]:

      • This is exactly what I wanted: trading space for GPU-friendly parallel efficiency. Thanks.
      • If this is what you wanted, could you accept the answer? :)
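Building on the space-for-parallelism point in the comment above: if downstream code can work with fixed-shape tensors, the flatten/split step can be skipped by keeping the padded `(rows, max_k)` output of the single `topk` call and overwriting the unused slots. A minimal sketch; `padded_topk` and the `-inf`/`-1` sentinels are illustrative choices, not part of the answer:

```python
import torch


def padded_topk(x, ks, fill_value=float("-inf"), fill_index=-1):
    """Per-row top-k kept in a fixed (rows, max_k) layout (hypothetical helper).

    Slots beyond a row's own k are overwritten with sentinel values.
    """
    ks_t = torch.as_tensor(ks, device=x.device)
    max_k = int(ks_t.max())
    vals, inds = x.topk(max_k, dim=-1)
    # True where slot j is inside row r's own top-k (j < ks[r])
    mask = torch.arange(max_k, device=x.device)[None, :] < ks_t[:, None]
    vals = vals.masked_fill(~mask, fill_value)
    inds = inds.masked_fill(~mask, fill_index)
    return vals, inds, mask


elements = torch.rand(5, 10)
vals, inds, mask = padded_topk(elements, [2, 3, 1, 2, 0])
print(vals.shape)  # torch.Size([5, 3])
print(inds[4])     # tensor([-1, -1, -1])
```

The returned `mask` marks the valid slots, so later reductions can ignore the padding.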