【问题标题】:PyTorch: Zero all elements of vector except top k?PyTorch:除前k之外的向量的所有元素都归零?
【发布时间】:2021-04-14 22:36:07
【问题描述】:

我正在尝试创建一个新的激活层,我们称之为 topk,它的工作方式如下。它将一个大小为 n 的向量 x 作为输入(将前一层输出乘以权重矩阵并添加偏差的结果)和一个正整数 k,并将输出一个大小为 n 的向量 topk(x),其元素是:

              x_i (if x_i is one of the top k elements of x) 
topk(x)_i = 
              0 (otherwise)

在计算topk(x)的梯度时,x的前k个元素的梯度应该是1,其他的都是0。

我应该如何实现这个?任何帮助将不胜感激。

【问题讨论】:

    标签: python pytorch top-n


    【解决方案1】:

    您可以为此使用torch.topk

    k = 2
    output = torch.randn(5)
    vals, idx = output.topk(k)
    
    topk = torch.zeros_like(output)
    topk[idx] = vals
    
    >>> topk
    tensor([1.0557, 0.0000, 0.0000, 1.4562, 0.0000])
    

    请注意,虽然 topk()'values' 是可微分的,但 'indices' are not (类似于 argmax 是不可微分的函数)。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2013-11-04
      • 1970-01-01
      • 2021-11-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-09-02
      • 1970-01-01
      相关资源
      最近更新 更多