【问题标题】:Indexing the max elements in a multidimensional tensor in PyTorch在 PyTorch 中索引多维张量中的最大元素
【发布时间】:2019-06-01 02:21:19
【问题描述】:

我正在尝试索引多维张量中最后一维的最大元素。例如,假设我有一个张量

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

这里 idx 存储最大索引,可能看起来像

>>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
     [ 0.8627,  0.0685,  1.4241]],

    [[ 1.2924,  0.2456,  0.1764],
     [ 1.3777,  0.9401,  1.4637]],

    [[ 0.5235,  0.4550,  0.2476],
     [ 0.7823,  0.3004,  0.7792]],

    [[ 1.9384,  0.3291,  0.7914],
     [ 0.5211,  0.1320,  0.6330]],

    [[ 0.3292,  0.9086,  0.0078],
     [ 1.3612,  0.0610,  0.4023]]])
>>>> idx
tensor([[ 2,  2],
    [ 0,  2],
    [ 0,  0],
    [ 0,  2],
    [ 1,  0]])

我希望能够访问这些索引并根据它们分配给另一个张量。意思是我希望能够做到

B = torch.new_zeros(A.size())
B[idx] = A[idx]

其中 B 处处为 0,除了 A 在最后一个维度上最大的地方。那就是B应该存储

>>>>B
tensor([[[ 0,  0,  1.8663],
     [ 0,  0,  1.4241]],

    [[ 1.2924,  0,  0],
     [ 0,  0,  1.4637]],

    [[ 0.5235,  0,  0],
     [ 0.7823,  0,  0]],

    [[ 1.9384,  0,  0],
     [ 0,  0,  0.6330]],

    [[ 0,  0.9086,  0],
     [ 1.3612,  0,  0]]])

事实证明这比我预期的要困难得多,因为 idx 没有正确索引数组 A。到目前为止,我一直无法找到使用 idx 来索引 A 的矢量化解决方案。

有没有很好的矢量化方法来做到这一点?

【问题讨论】:

    标签: python multidimensional-array deep-learning pytorch tensor


    【解决方案1】:

    您可以使用torch.meshgrid 创建索引元组:

    >>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
    >>> B = torch.zeros_like(A)
    >>> B[index_tuple] = A[index_tuple]
    

    请注意,您也可以通过(针对 3D 的特定情况)模仿meshgrid

    >>> index_tuple = (
    ...     torch.arange(A.size(0))[:, None],
    ...     torch.arange(A.size(1))[None, :],
    ...     idx
    ... )
    

    更多解释:
    我们将有类似这样的索引:

    In [173]: idx 
    Out[173]: 
    tensor([[2, 1],
            [2, 0],
            [2, 1],
            [2, 2],
            [2, 2]])
    

    由此,我们想要找到三个索引(因为我们的张量是 3D,我们需要三个数字来检索每个元素)。基本上我们要在前两个维度构建一个网格,如下图所示。 (这就是我们使用网格网格的原因)。

    In [174]: A[0, 0, 2], A[0, 1, 1]  
    Out[174]: (tensor(0.6288), tensor(-0.3070))
    
    In [175]: A[1, 0, 2], A[1, 1, 0]  
    Out[175]: (tensor(1.7085), tensor(0.7818))
    
    In [176]: A[2, 0, 2], A[2, 1, 1]  
    Out[176]: (tensor(0.4823), tensor(1.1199))
    
    In [177]: A[3, 0, 2], A[3, 1, 2]    
    Out[177]: (tensor(1.6903), tensor(1.0800))
    
    In [178]: A[4, 0, 2], A[4, 1, 2]          
    Out[178]: (tensor(0.9138), tensor(0.1779))
    

    在上面的 5 行中,索引中的前两个数字基本上是我们使用 meshgrid 构建的网格,第三个数字来自idx

    即前两个数字形成一个网格。

     (0, 0) (0, 1)
     (1, 0) (1, 1)
     (2, 0) (2, 1)
     (3, 0) (3, 1)
     (4, 0) (4, 1)
    

    【讨论】:

    • 感谢您的解决方案!但是如果 idx 是第二维而不是最后一个呢?
    【解决方案2】:

    一个丑陋的解决方法是用idx 创建一个二进制掩码,并用它来索引数组。基本代码如下所示:

    import torch
    torch.manual_seed(0)
    
    A = torch.randn((5, 2, 3))
    _, idx = torch.max(A, dim=2)
    
    mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
    B = torch.zeros_like(A)
    B[mask] = A[mask]
    print(A)
    print(B)
    

    诀窍在于torch.arange(A.size(2)) 枚举了idx 中的可能值,而mask 在它们等于idx 的地方不为零。备注:

    1. 如果真的舍弃了torch.max的第一个输出,可以改用torch.argmax
    2. 我认为这是一些更广泛问题的最小示例,但请注意,您目前正在使用大小为 (1, 1, 3) 的内核重新发明 torch.nn.functional.max_pool3d
    3. 另外,请注意,使用掩码赋值对张量进行就地修改可能会导致 autograd 出现问题,因此您可能需要使用 torch.where,如 here 所示。

    我希望有人提出更简洁的解决方案(避免 mask 数组的中间分配),可能会使用 torch.index_select,但我现在无法让它工作。

    【讨论】:

      【解决方案3】:

      可以使用 torch.scatter here

      >>> import torch
      >>> a = torch.randn(4,2,3)
      >>> a
      tensor([[[ 0.1583,  0.1102, -0.8188],
               [ 0.6328, -1.9169, -0.5596]],
      
              [[ 0.5335,  0.4069,  0.8403],
               [-1.2537,  0.9868, -0.4947]],
      
              [[-1.2830,  0.4386, -0.0107],
               [ 1.3384,  0.5651,  0.2877]],
      
              [[-0.0334, -1.0619, -0.1144],
               [ 0.1954, -0.7371,  1.7001]]])
      >>> ind = torch.max(a,1,keepdims=True)[1]
      >>> ind
      tensor([[[1, 0, 1]],
      
              [[0, 1, 0]],
      
              [[1, 1, 1]],
      
              [[1, 1, 1]]])
      >>> torch.zeros_like(a).scatter(1,ind,a)
      tensor([[[ 0.0000,  0.1102,  0.0000],
               [ 0.1583,  0.0000, -0.8188]],
      
              [[ 0.5335,  0.0000,  0.8403],
               [ 0.0000,  0.4069,  0.0000]],
      
              [[ 0.0000,  0.0000,  0.0000],
               [-1.2830,  0.4386, -0.0107]],
      
              [[ 0.0000,  0.0000,  0.0000],
               [-0.0334, -1.0619, -0.1144]]])
      

      【讨论】:

        猜你喜欢
        • 2019-02-05
        • 2021-09-14
        • 1970-01-01
        • 2018-01-26
        • 2019-09-15
        • 2020-09-26
        • 2020-06-10
        • 2018-06-08
        • 2019-11-06
        相关资源
        最近更新 更多