【问题标题】:How to mask a 3D tensor with 2D mask and keep the dimensions of original vector?如何用 2D 蒙版屏蔽 3D 张量并保持原始向量的尺寸?
【发布时间】:2020-05-22 14:08:56
【问题描述】:

假设,我有一个 3D 张量 A

A = torch.arange(24).view(4, 3, 2)
print(A)

并要求使用 2D 张量对其进行屏蔽

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

使用 PyTorch 中的 masked_select 功能会导致以下错误。

torch.masked_select(X, (mask == 1))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
     12 
     13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
     15 #Y = X * mask_
     16 print(Y)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2

如何用 2D 蒙版屏蔽 3D 张量并保持原始向量的尺寸?任何提示将不胜感激。

【问题讨论】:

    标签: pytorch tensor


    【解决方案1】:

    本质上,我们需要将张量掩码的维度与被掩码的张量相匹配。

    有两种方法可以做到。

    方法 1:不保留原始张量维度。

    X = torch.arange(24).view(4, 3, 2)
    print(X)
    
    mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
    mask[0, 0] = 1
    mask[1, 1] = 1
    mask[3, 0] = 1
    print('Mask: ', mask)
    
    # Add a dimension to the mask tensor and expand it to the size of original tensor
    mask_ = mask.unsqueeze(-1).expand(X.size())
    print(mask_)
    
    # Select based on the new expanded mask
    Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
    print(Y)
    

    方法 1 的输出:

    tensor([ 0,  1,  8,  9, 18, 19])
    

    方法 2:保留原始张量尺寸(通过填充)。

    X = torch.arange(24).view(4, 3, 2)
    print(X)
    
    mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
    mask[0, 0] = 1
    mask[1, 1] = 1
    mask[3, 0] = 1
    print('Mask: ', mask)
    
    # Add a dimension to the mask tensor and expand it to the size of original tensor
    mask_ = mask.unsqueeze(-1).expand(X.size())
    print(mask_)
    
    # Select based on the new expanded mask
    Y = X * mask_
    print(Y)
    

    方法 2 的输出:

    tensor([[[ 0,  1],
             [ 2,  3],
             [ 4,  5]],
    
            [[ 6,  7],
             [ 8,  9],
             [10, 11]],
    
            [[12, 13],
             [14, 15],
             [16, 17]],
    
            [[18, 19],
             [20, 21],
             [22, 23]]])
    Mask:  tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 0],
            [1, 0, 0]])
    tensor([[[1, 1],
             [0, 0],
             [0, 0]],
    
            [[0, 0],
             [1, 1],
             [0, 0]],
    
            [[0, 0],
             [0, 0],
             [0, 0]],
    
            [[1, 1],
             [0, 0],
             [0, 0]]])
    tensor([[[ 0,  1],
             [ 0,  0],
             [ 0,  0]],
    
            [[ 0,  0],
             [ 8,  9],
             [ 0,  0]],
    
            [[ 0,  0],
             [ 0,  0],
             [ 0,  0]],
    
            [[18, 19],
             [ 0,  0],
             [ 0,  0]]]
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-10-30
      • 1970-01-01
      • 1970-01-01
      • 2015-01-05
      • 2018-10-31
      • 2019-08-28
      • 1970-01-01
      • 2016-07-02
      相关资源
      最近更新 更多