【问题标题】:Subsetting A Pytorch Tensor Using Square-Brackets使用方括号对 Pytorch 张量进行子集化
【发布时间】:2020-02-13 06:44:16
【问题描述】:

我在 PyTorch 中遇到了一行代码,用于将 3D 张量简化为 2D 张量。 3D 张量x 的大小为torch.Size([500, 50, 1]),这行代码:

x = x[lengths - 1, range(len(lengths))]

用于将x 缩减为大小为torch.Size([50, 1]) 的二维张量。 lengths 也是一个形状为 torch.Size([50]) 的张量,包含值。

请谁能解释一下这是如何工作的?谢谢。

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    在被这种行为弄得一头雾水之后,我对此进行了更多挖掘,发现它是consistent behavior with the indexing of multi-dimensional NumPy arrays。使这种反直觉的原因是两个数组必须具有相同的长度这一不太明显的事实,即在这种情况下len(lengths)

    事实上,它的工作原理如下: * lengths 正在确定您访问第一个维度的顺序。即,如果您有一个一维数组a = [0, 1, 2, ...., 500],并使用列表b = [300, 200, 100] 访问它,则结果为a[b] = [301, 201, 101](这也解释了lengths - 1 运算符,它只会导致访问的值与分别在blengths 中使用的索引)。 * range(len(lengths)) 然后*只需选择i-th 行中的i-th 元素。如果你有一个方阵,你可以把它解释为矩阵的对角线。由于您只能访问前两个维度上每个位置的单个元素,因此可以将其存储在单个维度中(从而将您的 3D 张量减少到 2D)。后一个维度只是保持“原样”。

    如果你想玩这个,我强烈建议将 range() 值更改为更长/更短的值,这将导致以下错误:

    IndexError:形状不匹配:无法广播索引数组 连同形状 (x,) (y,)

    其中xy 是您的特定长度值。

    要以长格式编写此访问方法以了解“幕后”发生的情况,还请考虑以下示例:

    import torch
    x = torch.randint(500, 50, 1)
    lengths = torch.tensor([2, 30, 1, 4])  # random examples to explore
    diag = list(range(len(lengths)))  # [0, 1, 2, 3]
    result = []
    for i, row in enumerate(lengths):
        temp_tensor = x[row, :, :]  # temp_tensor.shape = [1, 50, 1]
        temp_tensor = temp_tensor.squeeze(0)[diag[i]]  # temp_tensor.shape = [1, 1]
        result.append(temp.tensor)
    
    # back to pytorch
    result = torch.tensor(result)
    result.shape  # [4, 1]
    

    【讨论】:

      【解决方案2】:

      这里的关键特性是将张量lengths 的值作为x 的索引传递。 这里简化的例子,我交换了容器的尺寸,所以 index dimenson 优先:

      container = torch.arange(0, 50 )
      container = f.reshape((5, 10))
      >>>tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
              [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
              [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
              [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
              [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])
      
      indices = torch.arange( 2, 7, dtype=torch.long )
      >>>tensor([2, 3, 4, 5, 6])
      
      print( container[ range( len(indices) ), indices] )
      >>>tensor([ 2, 13, 24, 35, 46])    
      

      注意:我们从一行中得到一件事(range( len(indices) ) 生成连续的行号),列号由索引[ row_number ]

      【讨论】:

        猜你喜欢
        • 2021-11-02
        • 1970-01-01
        • 2020-04-26
        • 2020-06-23
        • 2021-07-08
        • 1970-01-01
        • 1970-01-01
        • 2021-10-13
        • 2020-02-16
        相关资源
        最近更新 更多