您可以使用二进制掩码来做到这一点。
使用lengths 作为mask 的列索引,我们指示每个序列的结束位置(请注意,我们使mask 比a.size(1) 长,以允许全长序列)。
使用cumsum(),我们将seq len 之后mask 中的所有条目设置为1。
mask = torch.zeros(a.shape[0], a.shape[1] + 1, dtype=a.dtype, device=a.device)
mask[(torch.arange(a.shape[0]), lengths)] = 1
mask = mask.cumsum(dim=1)[:, :-1] # remove the superfluous column
a = a * (1. - mask[..., None]) # use mask to zero after each column
对于a.shape = (10, 5, 96) 和lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3]。
在每一行将 1 分配给各自的 lengths,mask 看起来像:
mask =
tensor([[0., 1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]])
cumsum 之后得到
mask =
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 1., 1.]])
请注意,有效序列条目所在的位置恰好为零,而超出序列长度的位置为零。使用1 - mask 可以满足您的需求。
享受;)