【问题标题】:Hidden state tensors have a different order than the returned tensors隐藏状态张量的顺序与返回的张量不同
【发布时间】:2020-03-06 13:50:29
【问题描述】:


作为 GRU 训练的一部分,我想检索隐藏状态张量。

我定义了一个有两层的 GRU:

self.lstm = nn.GRU(params.vid_embedding_dim, params.hidden_dim , 2)

forward函数定义如下(以下只是部分实现):

    def forward(self, s, order, batch_size, where, anchor_is_phrase = False):
    """
    Forward prop. 
    """
      # s is of shape [128 , 1 , 300] , 128 is batch size
      output, (a,b) = self.lstm(s.cuda())
      output.data.contiguous()

out 的形状是:[128 , 400](128 是每个样本嵌入到 400 维向量中的样本数)。

我知道out 是最后一个隐藏状态的输出,因此我希望它等于b。但是,在我检查了这些值之后,我发现它确实相等,但 b 包含以不同顺序的张量,例如 output[0]b[49]。我在这里错过了什么吗?

谢谢。

【问题讨论】:

    标签: tensorflow deep-learning pytorch lstm recurrent-neural-network


    【解决方案1】:

    我理解你的困惑。看看下面的例子和 cmets:

    # [Batch size, Sequence length, Embedding size]
    inputs = torch.rand(128, 5, 300)
    gru = nn.GRU(input_size=300, hidden_size=400, num_layers=2, batch_first=True)
    
    with torch.no_grad():
        # output is all hidden states, for each element in the batch of the last layer in the RNN
        # a is the last hidden state of the first layer
        # b is the last hidden state of the second (last) layer
        output, (a, b) = gru(inputs)
    

    如果我们打印出形状,它们将证实我们的理解:

    print(output.shape) # torch.Size([128, 5, 400])
    print(a.shape) # torch.Size([128, 400])
    print(b.shape) # torch.Size([128, 400])
    

    此外,我们可以测试从output 获得的最后一层的最后一层的每个元素的最后隐藏状态是否等于b

    np.testing.assert_almost_equal(b.numpy(), output[:,:-1,:].numpy())
    

    最后,我们可以创建一个 3 层的 RNN,并运行相同的测试:

    gru = nn.GRU(input_size=300, hidden_size=400, num_layers=3, batch_first=True)
    with torch.no_grad():
        output, (a, b, c) = gru(inputs)
    
    np.testing.assert_almost_equal(c.numpy(), output[:,-1,:].numpy())
    

    再一次,断言通过了,但前提是我们为c(现在是 RNN 的最后一层)这样做。否则:

    np.testing.assert_almost_equal(b.numpy(), output[:,-1,:].numpy())
    

    引发错误:

    AssertionError: 数组几乎不等于 7 位小数

    我希望这能让你明白。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-12-06
      • 1970-01-01
      • 2019-07-11
      • 1970-01-01
      • 2019-06-12
      • 2018-12-13
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多