【问题标题】:Which output should I use for prediction with LSTM for sequenced data?对于序列数据,我应该使用 LSTM 预测哪个输出?
【发布时间】:2020-09-18 09:16:28
【问题描述】:

我还是机器学习和深度学习的新手。我目前正在尝试在 PyTorch 中使用 LSTM 预测时间序列数据。我遇到的问题是我不明白应该使用哪个输出来进行最终预测。 我的代码如下:

class Model(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, seq_len, dropout):
    super(Model, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout = dropout
    self.seq_len = seq_len
    self.lstm = nn.LSTM(
        input_size = self.input_size,
        hidden_size = self.hidden_size,
        dropout = self.dropout
    )
    self.linear = nn.Linear(self.hidden_size, self.output_size)

  def reset_hidden_state(self):
    self.hidden = (
        torch.zeros(1, self.seq_len, self.hidden_size),
        torch.zeros(1, self.seq_len, self.hidden_size)
    )

  def forward(self, sequences):
    lstm_out, self.hidden = self.lstm(sequences, self.hidden)
    y_pred = self.linear(lstm_out[-1, :, :])
    return y_pred

mymodel = Model(5, 10, 1, 3, 0.0)
inps = torch.randn(10, 3, 5)   #input
#print(inps)
mymodel.reset_hidden_state()
out = mymodel.forward(inps)
print(out.shape)
print(out)

输出:

torch.Size([3, 1])

张量([[-0.0996], [-0.0587], [-0.0421]], grad_fn=)

如您所见,这给了我三个输出,但我的输出大小为 1,因为我试图仅预测 1 个变量。那么,在这种情况下,我应该使用哪个变量来进行 final 预测?或者,是否有可能只预测这样的顺序数据的 1 个值?

注意:我的python版本是3.7.4 我的 PyTorch 版本是 1.4.0

如果我在提问时犯了任何错误,我深表歉意。这是我第一次在这里提问。

【问题讨论】:

    标签: python time-series pytorch lstm


    【解决方案1】:

    您已经在使用 LSTM 的正确输出,这是最后一个隐藏状态。方便的是,这也是 lstm_out 中的最后一个元素,您将其用作 lstm_out[-1, :, :]

    模型inps的输入是多个序列,因为它们的大小是[seq_len, batch_size, num_featuers] = [10, 3, 5]。这意味着您有 3 个独立的序列,每个序列有 10 个时间步,每个时间步有 5 个特征。

    因此,out(大小:[3, 1])包含 3 个序列中每一个的预测。 out[0][0] 是第一个序列的预测,out[1][0] 是第二个序列,out[2][0] 是第三个序列。您还可以使用 out.unsqueeze(1) 摆脱奇异的第二个序列,因此您有一个具有 3 个预测的一维张量。

    如果您只想预测单个序列,您可以使用批量大小为 1,这意味着输入的大小为 [10, 1, 5],然后您会得到一个返回值,即使它在大小为 [1, 1] 的张量中。

    【讨论】:

      猜你喜欢
      • 2020-05-02
      • 2018-06-10
      • 1970-01-01
      • 1970-01-01
      • 2021-11-03
      • 1970-01-01
      • 2021-10-02
      • 2018-02-16
      • 2017-06-16
      相关资源
      最近更新 更多