【问题标题】:Filling torch tensor with zeros after certain index在某个索引后用零填充火炬张量
【发布时间】:2019-12-24 04:02:06
【问题描述】:

给定一个 3d 张量,说: batch x sentence length x embedding dim

a = torch.rand((10, 1000, 96)) 

以及每个句子的实际长度数组(或张量)

lengths =  torch .randint(1000,(10,))

outputs tensor([ 370., 502., 652., 859., 545., 964., 566., 576.,1000., 803.])

如何根据张量'lengths'沿维度1(句子长度)的某个索引后用零填充张量'a'?

我想要那样的东西:

a[ : , lengths : , : ]  = 0

一种方法(如果批量足够大,速度会很慢):

for i_batch in range(10):
    a[ i_batch  , lengths[i_batch ] : , : ]  = 0

【问题讨论】:

    标签: python nlp pytorch


    【解决方案1】:

    您可以使用二进制掩码来做到这一点。
    使用lengths 作为mask 的列索引,我们指示每个序列的结束位置(请注意,我们使maska.size(1) 长,以允许全长序列)。
    使用cumsum(),我们将seq len 之后mask 中的所有条目设置为1。

    mask = torch.zeros(a.shape[0], a.shape[1] + 1, dtype=a.dtype, device=a.device)
    mask[(torch.arange(a.shape[0]), lengths)] = 1
    mask = mask.cumsum(dim=1)[:, :-1]  # remove the superfluous column
    a = a * (1. - mask[..., None])     # use mask to zero after each column
    

    对于a.shape = (10, 5, 96)lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3]
    在每一行将 1 分配给各自的 lengthsmask 看起来像:

    mask = 
    tensor([[0., 1., 0., 0., 0., 0.],
            [0., 0., 1., 0., 0., 0.],
            [0., 1., 0., 0., 0., 0.],
            [0., 1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0., 0.],
            [1., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1., 0.],
            [0., 1., 0., 0., 0., 0.],
            [0., 0., 0., 1., 0., 0.]])
    

    cumsum 之后得到

    mask = 
    tensor([[0., 1., 1., 1., 1.],
            [0., 0., 1., 1., 1.],
            [0., 1., 1., 1., 1.],
            [0., 1., 1., 1., 1.],
            [0., 0., 0., 1., 1.],
            [1., 1., 1., 1., 1.],
            [0., 0., 0., 0., 1.],
            [0., 0., 0., 0., 1.],
            [0., 1., 1., 1., 1.],
            [0., 0., 0., 1., 1.]])
    

    请注意,有效序列条目所在的位置恰好为零,而超出序列长度的位置为零。使用1 - mask 可以满足您的需求。

    享受;)

    【讨论】:

      猜你喜欢
      • 2021-07-08
      • 1970-01-01
      • 2020-08-02
      • 2020-06-11
      • 2022-07-20
      • 2016-07-20
      • 2020-10-26
      • 2016-06-27
      相关资源
      最近更新 更多