【问题标题】:Get each sequence's last item from packed sequence从打包序列中获取每个序列的最后一项
【发布时间】:2019-08-19 07:31:34
【问题描述】:

我正在尝试通过 GRU 放置一个打包和填充的序列,并检索每个序列的最后一项的输出。当然,我指的不是-1 项目,而是实际最后一个未填充的项目。我们事先知道序列的长度,因此应该很容易为每个序列提取length-1 项。

我尝试了以下

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data
input = torch.Tensor([[[0., 0., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]],

                      [[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[0., 0., 0.],
                       [1., 0., 0.],
                       [1., 1., 1.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[1., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]]])

lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)

# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)

# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])

last_seq_items = torch.index_select(output, 1, last_seq_idxs) 

print(last_seq_items.size())
# torch.Size([4, 4, 12])

但形状不是我所期望的。我原本期望得到4x12,即last item of each individual sequence x hidden。`

我可以遍历整个事情,并构建一个包含我需要的项目的新张量,但我希望有一种利用一些智能数学的内置方法。我担心手动循环和构建会导致性能很差。

【问题讨论】:

    标签: indexing deep-learning pytorch tensor zero-padding


    【解决方案1】:

    除了最后两个操作last_seq_idxslast_seq_items 你可以只做last_seq_items=output[torch.arange(4), input_sizes-1]

    我不认为index_select 做的事情是正确的。它将在您传递的索引处选择整个批次,因此您的输出大小为 [4,4,12]。

    【讨论】:

    • 谢谢。这确实看起来比我的方法更直接。
    【解决方案2】:

    Umang Gupta 回答的更详细的替代方案:

    # ...
    output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
    # One per sequence, with its last actual node extracted, and unsqueezed
    last_seq = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
    # Merge them together all sequences together to get batch
    last_seq = torch.cat(last_seq, dim=0)
    

    【讨论】:

    • 我觉得这样会慢一些,不是吗?
    • @UmangGupta 哦,是的,我绝对认为你的只是一个切片,而我的方法需要迭代和连接。我出于说明目的发布了我的。但你的才是应该使用的。
    猜你喜欢
    • 2021-11-05
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-08-24
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多