【问题标题】:Input Data Format for RNNRNN 的输入数据格式
【发布时间】:2017-09-30 16:48:06
【问题描述】:

我很困惑如何准确地将数据序列编码为 LSTM RNN 的输入。

在普通 DNN 中,每个标签都有一个输入。 RNN 中的“输入”是什么?为了训练与标签关联的顺序事件,它是否必须是一组(或序列)数据?

我很困惑如何编码顺序信息,因为似乎应该有多个与给定标签关联的输入。

【问题讨论】:

    标签: machine-learning keras lstm recurrent-neural-network


    【解决方案1】:

    似乎应该有多个与给定标签关联的输入

    是的,你是对的。实际上,您的输入需要是 3D 矩阵。例如,如果您有 n 个序列,每个序列的长度为 m,并且您的每个序列数据都有 d 个特征,那么您的 RNN 的输入必须是维度 (n,m,d)。

    【讨论】:

      【解决方案2】:

      例如,如果您有一个时间序列(X1,..,Xt),并且您想训练一个预测器在 +1 的范围内进行预测并使用长度为 3 的序列,您的输入和输出将是:

      [[X1,X2,X3]]    [X4]
      [[X2,X3,X4]]    [X5]
      ...
      [[Xt-3,Xt-2,Xt-1]] [Xt]
      

      所以,有 t-3 个序列,每个序列的长度为 3,并且有 1 个特征。维度应该是(t-3,3,1)。

      【讨论】:

        【解决方案3】:

        让我们用代码写一个例子。

        假设我们有一些句子,其中句子中的每个单词都被编码为一个向量(可能来自 word2vec 的向量)。

        假设我们要将每个句子分类为两个类 (0, 1) 中的一个。我们可以像这样构建一个简单的分类器:

        import numpy as np
        from keras.models import Sequential
        from keras.layers import LSTM, Dense
        
        # each example (of which we have a 100) is a sequence of 10 words and
        # each words is encoded as 16 element vectors
        
        X = np.random.rand(100, 10, 16) 
        y = np.random.choice(1, 100)
        
        model = Sequential()
        model.add(LSTM(128, input_shape=(10, 16)))
        model.add(Dense(1, activation='sigmoid'))
        model.compile(loss='binary_crossentropy', optimizer='sgd')
        
        # fit model
        model.fit(X, y, epochs=3, batch=16)
        

        【讨论】:

        • model.add(LSTM(128, input_shape=(10, 16))
        【解决方案4】:

        Graves 的书(第 19 页)用明确的维度扩充了答案:

        考虑一个长度为 T 的输入序列 x 呈现给具有 I 个输入单元、H 个隐藏单元和 K 个输出单元的 RNN。令 $x_i^t$ 为输入 i 在时间 t 的值。
        在每个单词 t=1,...,T 的输入句子中,单词 t 的长度为 I 嵌入向量。标量级别 ( $x_i^t$ ) 的输入是该向量的第 i 个分量。

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 1970-01-01
          • 2017-11-26
          • 2018-03-06
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2022-12-10
          相关资源
          最近更新 更多