Question title: Wrong LSTM time series predicted for input size different from trained input size
Posted: 2020-01-05 16:11:16
Question:

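I am working with the Bach chorales dataset. Each chorale is roughly 100 to 500 time steps long, and each time step contains 4 integers (e.g. [74, 70, 65, 58]), where each integer corresponds to the index of a note on a piano.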

I am trying to train a model that, given a sequence of time steps from a chorale, predicts the next time step (4 notes).

The problem: I get correct output for inputs of the same length the model was trained on, but wrong output for inputs of a different length.

What I have done so far: I used Keras's TimeseriesGenerator to generate the input sequences and their corresponding outputs:

from keras.preprocessing.sequence import TimeseriesGenerator

generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
print(generator[0])

Output:

(array([[[74, 70, 65, 58],
        [74, 70, 65, 58],
        [74, 70, 65, 58]]]), array([[75, 70, 58, 55]]))
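The windowing that TimeseriesGenerator performs can be reproduced in a few lines of NumPy, which makes it easy to see exactly what the model is trained on: with `length=3`, sample `i` is `data[i:i+3]` and its target is `targets[i+3]`, so every training sample has exactly 3 time steps. A minimal sketch, assuming the generator's defaults (`stride=1`, `sampling_rate=1`, `start_index=0`); `make_windows` is a hypothetical helper, not part of Keras, and the toy data is made up:

```python
import numpy as np


def make_windows(data, targets, length):
    """Mimic TimeseriesGenerator's default indexing (stride=1, sampling_rate=1,
    start_index=0): sample i is data[i:i+length], its target is targets[i+length]."""
    n_samples = len(data) - length
    xs = np.stack([data[i:i + length] for i in range(n_samples)])
    ys = np.stack([targets[i + length] for i in range(n_samples)])
    return xs, ys


# Toy data: 6 time steps x 4 voices (made-up values, not from the dataset)
data = np.arange(24).reshape(6, 4)
xs, ys = make_windows(data, data, length=3)
print(xs.shape, ys.shape)  # (3, 3, 4) (3, 4)
print(xs[0].tolist(), ys[0].tolist())
```

This also explains the 147 batches in the second answer's log: 150 rows minus a window length of 3.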

Then I trained an LSTM model. I used None in input_shape to allow variable-length inputs.

from keras.models import Sequential
from keras.layers import LSTM, Dense

n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128, activation='relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')

# fit model
model.fit_generator(generator, epochs=500, validation_data=validation_generator)
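Since `generator` only ever yields windows of length 3, the network never sees any other sequence length during training, even though `input_shape=(None, n_features)` allows it to accept one. One possible remedy, not proposed in this thread and offered only as a sketch, is to train on several window lengths. Every sequence within a single Keras batch must have the same number of time steps, so the hypothetical generator below draws one length per batch (the lengths 3 to 5 and the batch size are assumptions):

```python
import numpy as np


def mixed_length_batches(data, targets, lengths=(3, 4, 5), batch_size=8, seed=0):
    """Hypothetical helper (not part of Keras): yields (x, y) batches forever.
    Each batch uses a single window length drawn from `lengths`, because all
    sequences inside one batch must have the same number of time steps."""
    rng = np.random.default_rng(seed)
    while True:
        length = int(rng.choice(lengths))
        # start positions chosen so that targets[start + length] always exists
        starts = rng.integers(0, len(data) - length, size=batch_size)
        x = np.stack([data[s:s + length] for s in starts])
        y = np.stack([targets[s + length] for s in starts])
        yield x.astype("float32"), y.astype("float32")
```

Because this generator never ends, it would need an explicit `steps_per_epoch` when passed to `model.fit_generator(...)`.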

Predicting on an input of length 3 seems to work (since the model was trained on inputs of length 3):

# demonstrate prediction
x_input = dataX[5:8]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[8])

Output:

[[[75 70 58 55]
  [75 70 60 55]
  [75 70 60 55]]]
[[76.25768  68.525444 59.745518 53.799873]]
expected:  [77 69 62 50]

Now I try to predict for an input of a different length, e.g. 5, but it doesn't work. Output for a test sample:

# demonstrate prediction
x_input = dataX[1:6]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[6])

Output:

[[[74 70 65 58]
  [74 70 65 58]
  [74 70 65 58]
  [75 70 58 55]
  [75 70 58 55]]]
[[227.16667 217.89767 213.62988 148.44817]]
expected:  [75 70 60 55]

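The prediction is completely wrong; the model seems to be summing something up. Any insight into why this happens and how to fix it would be greatly appreciated.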

Discussion:

    Tags: python keras time-series lstm recurrent-neural-network


    Answer 1:

    I can suggest three possible reasons why your model is not learning.

    The last Dense layer

    model.add(Dense(n_features))

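    This is probably the main culprit in your model (but I'd suggest addressing all three). The last layer of a classification model needs to be a softmax layer, so just change it to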

    model.add(Dense(n_features, activation='softmax'))
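One caveat worth checking before applying this: with `n_features = 4`, a softmax layer outputs four non-negative numbers that sum to 1, so it cannot directly emit four pitch indices such as `[75, 70, 60, 55]`. Softmax only fits if the targets are recast as classes (for example, a distribution over possible pitches for each voice). A small NumPy illustration; the logits are made-up values:

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


# Raw outputs of a 4-unit Dense layer (made-up values)
logits = np.array([76.3, 68.5, 59.7, 53.8])
probs = softmax(logits)
print(probs.round(4), probs.sum())  # four probabilities summing to 1, not pitches
```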

    Loss function

    Cross-entropy is usually better suited than mse for classification problems, so try

    model.compile(optimizer='adam', loss='categorical_crossentropy')
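`categorical_crossentropy` expects one-hot (or probability) targets with the same shape as the softmax output, so the integer pitches would first have to be encoded as classes (Keras offers `keras.utils.to_categorical` for this). A minimal NumPy sketch of the encoding and of the loss for a single voice, assuming a hypothetical pitch vocabulary of 128 values (0 to 127); the pitches are the first-voice values from the question's length-3 example:

```python
import numpy as np

N_PITCHES = 128  # assumed size of the pitch vocabulary (hypothetical)


def one_hot(pitches, n_classes=N_PITCHES):
    """Integer pitch indices -> one-hot rows (same idea as to_categorical)."""
    pitches = np.asarray(pitches)
    out = np.zeros(pitches.shape + (n_classes,), dtype=np.float32)
    np.put_along_axis(out, pitches[..., None], 1.0, axis=-1)
    return out


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred)), clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


y_true = one_hot([75, 75, 77])
uniform = np.full_like(y_true, 1.0 / N_PITCHES)  # an uninformative guess
print(y_true.shape)  # (3, 128)
print(categorical_crossentropy(y_true, uniform))  # log(128), about 4.85
```

In a real model this would mean one softmax head of size `N_PITCHES` per voice rather than a single `Dense(4)` layer.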

    Activation in the LSTM

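    LSTMs use tanh as their activation. Unless you have a good reason to change it to relu, don't, because an LSTM does not behave the way a normal feed-forward layer would when its activation function is changed.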
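This also gives a plausible explanation for the blow-up on length-5 inputs: with an unbounded activation such as relu, the recurrent state can keep growing with every extra time step, especially with unscaled inputs around 70, whereas tanh keeps it inside [-1, 1]. The toy single-unit recurrence below is not an LSTM, and its weights are arbitrary assumptions, but it shows the effect; with `w_h = 1.0` the relu version literally sums its inputs, much like the "summing" the question observed:

```python
import numpy as np


def run_rnn(xs, activation, w_x=0.01, w_h=1.0):
    """Toy scalar recurrence h_t = activation(w_x * x_t + w_h * h_{t-1})."""
    h = 0.0
    for x in xs:
        h = activation(w_x * x + w_h * h)
    return h


def relu(z):
    return max(z, 0.0)


for length in (3, 5, 10):
    xs = [74.0] * length  # the same unscaled value repeated
    # relu column: 2.22, 3.7, 7.4 (grows with length); tanh column stays below 1
    print(length, round(run_rnn(xs, relu), 3), round(float(run_rnn(xs, np.tanh)), 3))
```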

    Comments:

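    • Thanks! After incorporating all 3 suggestions, the model does not seem to converge. I wonder whether I need to standardize/normalize the inputs. After standardizing, the predictions are better than before, but still not what I expect.
    • @kumar, how many time steps of data do you have?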
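The comment above mentions standardizing the inputs. A minimal NumPy sketch, assuming per-voice statistics computed on the training split only, applied to both inputs and targets, with predictions mapped back to pitch values afterwards; the sample rows are copied from the question's output:

```python
import numpy as np


def fit_scaler(train):
    """Per-column (per-voice) mean and std, computed on training data only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # guard against constant columns
    return mean, std


def transform(x, mean, std):
    return (x - mean) / std


def inverse_transform(z, mean, std):
    return z * std + mean


# Rows from the question's output (time steps x 4 voices)
train = np.array([[74, 70, 65, 58],
                  [75, 70, 58, 55],
                  [75, 70, 60, 55],
                  [77, 69, 62, 50]], dtype=np.float64)
mean, std = fit_scaler(train)
scaled = transform(train, mean, std)
restored = inverse_transform(scaled, mean, std)
print(scaled.mean(axis=0).round(6), scaled.std(axis=0).round(6))
```

Under this setup the model would be trained on scaled windows, and `inverse_transform` applied to `model.predict(...)` before comparing against the expected pitches.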
    Answer 2:

    I suggest keeping the length of x_input at 3. Here is my test code:

    import sys
    from keras.models import Sequential
    from keras.layers import Dense,Activation,LSTM
    from keras.preprocessing.sequence import TimeseriesGenerator
    import numpy as np
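    # Stand-in for the custom `logger` module used originally (not shown):
    # logger.info(...) simply prints its arguments.
    import types
    logger = types.SimpleNamespace(info=print)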
    
    
    def bc_pitches():
        a = open('chorales.lisp', 'r')
    
        #parse the input as vectors and store vectors
    
        def obtainNum(elemSt):
            a = elemSt.split(" ")
            return int(a[1])
    
        bookOfLists = []
    
        for i in range(210):
            counter = 0
            gun = a.readline()
            if (len(gun) <= 1): #for /n accommodation
                continue
            else:
                while (gun[counter:(counter+2)] != "(("):
                    counter += 1
                tribo = gun[(counter+2):(len(gun)-4)]
                stringArr = tribo.split("))((") #separates each vector into an element
                lister = [x.split(") (") for x in stringArr]
                #lister = map(lambda x : x.split(") ("), stringArr) #each vector becomes
                #a list of component elements so lister is a list of lists
                lister2 = [[obtainNum(each) for each in x] for x in lister]
                #lister2 = map(lambda x : map(obtainNum, x), lister)
                bookOfLists.append(lister2)
        pitches=np.zeros([100,500],dtype=np.int32)
        for i in range(len(bookOfLists)):
            for j in range(len(bookOfLists[i])):
                for t in range(bookOfLists[i][j][0],bookOfLists[i][j][0]+bookOfLists[i][j][2]):
                    try:
                        pitches[i][t]=bookOfLists[i][j][1]
                    except:
                        print(i,j,t)
                        sys.exit()
        return pitches
    
    pitches=bc_pitches()
    dataX=dataY=(pitches[:4,:].T)[:150]
    generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
    for i in range(len(generator)):
        logger.info(i,generator[i])
    
    validation_dataX=validation_dataY=(pitches[:4,:].T)[150:]
    validation_generator = TimeseriesGenerator(validation_dataX, validation_dataY, length=3, batch_size=1)
    
    
    n_features = 4
    model = Sequential()
    model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
    model.add(LSTM(128 , activation = 'relu'))
    model.add(Dense(n_features))
    model.compile(optimizer='adam', loss='mse')
    
    # fit model
    model.fit_generator(generator, epochs=50, validation_data=validation_generator)
    
    
    # demonstrate prediction
    x_input = (pitches[:4,:].T)[155:158]
    x_input = x_input.reshape((1, len(x_input), 4))
    logger.info(x_input)
    yhat = model.predict(x_input, verbose=0)
    logger.info(yhat)
    logger.info('expected: ', (pitches[:4,:].T)[158])
    
    
    # demonstrate prediction
    x_input = (pitches[:4,:].T)[151:156]
    x_input = x_input.reshape((1, len(x_input), 4))
    logger.info(x_input)
    yhat = model.predict(x_input, verbose=0)
    logger.info(yhat)
    logger.info('expected: ', (pitches[:4,:].T)[156])
    
    for i in range(10):
        yhat = model.predict(validation_generator[i][0], verbose=0)
        logger.info(i,yhat)
        logger.info('expected: ', validation_generator[i][1])
    

    Results:

    ...
        100 (array([[[72, 73, 69, 73],
                [72, 73, 69, 73],
                [72, 73, 69, 73]]]),
         array([[72, 73, 69, 73]])) 
        101 (array([[[72, 73, 69, 73],
                [72, 73, 69, 73],
                [72, 73, 69, 73]]]),
         array([[74, 71, 71, 71]])) 
        102 (array([[[72, 73, 69, 73],
                [72, 73, 69, 73],
                [74, 71, 71, 71]]]),
         array([[74, 71, 71, 71]])) 
        103 (array([[[72, 73, 69, 73],
                [74, 71, 71, 71],
                [74, 71, 71, 71]]]),
         array([[74, 71, 71, 71]])) 
        104 (array([[[74, 71, 71, 71],
                [74, 71, 71, 71],
                [74, 71, 71, 71]]]),
         array([[74, 71, 71, 71]])) 
        105 (array([[[74, 71, 71, 71],
                [74, 71, 71, 71],
                [74, 71, 71, 71]]]),
         array([[74, 73, 67, 71]])) 
        106 (array([[[74, 71, 71, 71],
                [74, 71, 71, 71],
                [74, 73, 67, 71]]]),
         array([[74, 73, 67, 71]])) 
        107 (array([[[74, 71, 71, 71],
                [74, 73, 67, 71],
                [74, 73, 67, 71]]]),
         array([[74, 73, 67, 71]])) 
        108 (array([[[74, 73, 67, 71],
                [74, 73, 67, 71],
                [74, 73, 67, 71]]]),
         array([[74, 73, 67, 71]])) 
        109 (array([[[74, 73, 67, 71],
                [74, 73, 67, 71],
                [74, 73, 67, 71]]]),
         array([[74, 74, 69, 76]])) 
        110 (array([[[74, 73, 67, 71],
                [74, 73, 67, 71],
                [74, 74, 69, 76]]]),
         array([[74, 74, 69, 76]])) 
        111 (array([[[74, 73, 67, 71],
                [74, 74, 69, 76],
                [74, 74, 69, 76]]]),
         array([[72, 74, 71, 76]])) 
        112 (array([[[74, 74, 69, 76],
                [74, 74, 69, 76],
                [72, 74, 71, 76]]]),
         array([[72, 74, 71, 76]])) 
        113 (array([[[74, 74, 69, 76],
                [72, 74, 71, 76],
                [72, 74, 71, 76]]]),
         array([[71, 73, 72, 71]])) 
        114 (array([[[72, 74, 71, 76],
                [72, 74, 71, 76],
                [71, 73, 72, 71]]]),
         array([[71, 73, 72, 71]])) 
        115 (array([[[72, 74, 71, 76],
                [71, 73, 72, 71],
                [71, 73, 72, 71]]]),
         array([[71, 73, 72, 71]])) 
        116 (array([[[71, 73, 72, 71],
                [71, 73, 72, 71],
                [71, 73, 72, 71]]]),
         array([[71, 73, 72, 71]])) 
        117 (array([[[71, 73, 72, 71],
                [71, 73, 72, 71],
                [71, 73, 72, 71]]]),
         array([[69, 71, 71, 73]])) 
        118 (array([[[71, 73, 72, 71],
                [71, 73, 72, 71],
                [69, 71, 71, 73]]]),
         array([[69, 71, 71, 73]])) 
        119 (array([[[71, 73, 72, 71],
                [69, 71, 71, 73],
                [69, 71, 71, 73]]]),
         array([[69, 71, 71, 73]]))
        120 (array([[[69, 71, 71, 73],
                [69, 71, 71, 73],
                [69, 71, 71, 73]]]),
         array([[69, 71, 71, 73]]))
        121 (array([[[69, 71, 71, 73],
                [69, 71, 71, 73],
                [69, 71, 71, 73]]]),
         array([[69, 70, 72, 68]]))
        122 (array([[[69, 71, 71, 73],
                [69, 71, 71, 73],
                [69, 70, 72, 68]]]),
         array([[69, 70, 72, 68]]))
        123 (array([[[69, 71, 71, 73],
                [69, 70, 72, 68],
                [69, 70, 72, 68]]]),
         array([[69, 70, 71, 69]]))
        124 (array([[[69, 70, 72, 68],
                [69, 70, 72, 68],
                [69, 70, 71, 69]]]),
         array([[69, 70, 71, 69]]))
        125 (array([[[69, 70, 72, 68],
                [69, 70, 71, 69],
                [69, 70, 71, 69]]]),
         array([[67, 71, 69, 71]]))
        126 (array([[[69, 70, 71, 69],
                [69, 70, 71, 69],
                [67, 71, 69, 71]]]),
         array([[67, 71, 69, 71]]))
        127 (array([[[69, 70, 71, 69],
                [67, 71, 69, 71],
                [67, 71, 69, 71]]]),
         array([[67, 71, 69, 71]]))
        128 (array([[[67, 71, 69, 71],
                [67, 71, 69, 71],
                [67, 71, 69, 71]]]),
         array([[67, 71, 69, 71]]))
        129 (array([[[67, 71, 69, 71],
                [67, 71, 69, 71],
                [67, 71, 69, 71]]]),
         array([[71, 71, 68, 69]]))
        130 (array([[[67, 71, 69, 71],
                [67, 71, 69, 71],
                [71, 71, 68, 69]]]),
         array([[71, 71, 68, 69]]))
        131 (array([[[67, 71, 69, 71],
                [71, 71, 68, 69],
                [71, 71, 68, 69]]]),
         array([[71, 71, 68, 69]]))
        132 (array([[[71, 71, 68, 69],
                [71, 71, 68, 69],
                [71, 71, 68, 69]]]),
         array([[71, 71, 68, 69]]))
        133 (array([[[71, 71, 68, 69],
                [71, 71, 68, 69],
                [71, 71, 68, 69]]]),
         array([[71, 71, 69, 68]]))
        134 (array([[[71, 71, 68, 69],
                [71, 71, 68, 69],
                [71, 71, 69, 68]]]),
         array([[71, 71, 69, 68]]))
        135 (array([[[71, 71, 68, 69],
                [71, 71, 69, 68],
                [71, 71, 69, 68]]]),
         array([[71, 71, 69, 68]]))
        136 (array([[[71, 71, 69, 68],
                [71, 71, 69, 68],
                [71, 71, 69, 68]]]),
         array([[71, 71, 69, 68]]))
        137 (array([[[71, 71, 69, 68],
                [71, 71, 69, 68],
                [71, 71, 69, 68]]]),
         array([[72, 64, 69, 68]]))
        138 (array([[[71, 71, 69, 68],
                [71, 71, 69, 68],
                [72, 64, 69, 68]]]),
         array([[72, 64, 69, 68]]))
        139 (array([[[71, 71, 69, 68],
                [72, 64, 69, 68],
                [72, 64, 69, 68]]]),
         array([[72, 64, 69, 68]]))
        140 (array([[[72, 64, 69, 68],
                [72, 64, 69, 68],
                [72, 64, 69, 68]]]),
         array([[72, 64, 69, 68]]))
        141 (array([[[72, 64, 69, 68],
                [72, 64, 69, 68],
                [72, 64, 69, 68]]]),
         array([[74, 69, 76, 66]]))
        142 (array([[[72, 64, 69, 68],
                [72, 64, 69, 68],
                [74, 69, 76, 66]]]),
         array([[74, 69, 76, 66]]))
        143 (array([[[72, 64, 69, 68],
                [74, 69, 76, 66],
                [74, 69, 76, 66]]]),
         array([[74, 69, 76, 66]]))
        144 (array([[[74, 69, 76, 66],
                [74, 69, 76, 66],
                [74, 69, 76, 66]]]),
         array([[74, 69, 76, 66]]))
        145 (array([[[74, 69, 76, 66],
                [74, 69, 76, 66],
                [74, 69, 76, 66]]]),
         array([[74, 71, 72, 69]]))
        146 (array([[[74, 69, 76, 66],
                [74, 69, 76, 66],
                [74, 71, 72, 69]]]),
         array([[74, 71, 72, 69]]))
        Epoch 1/50
        147/147 [==============================] - 2s 16ms/step - loss: 514.8802 - val_loss: 0.0082
        Epoch 2/50
        147/147 [==============================] - 2s 11ms/step - loss: 51.5768 - val_loss: 0.0249
        Epoch 3/50
        147/147 [==============================] - 2s 11ms/step - loss: 71.6900 - val_loss: 0.0464
        Epoch 4/50
        147/147 [==============================] - 2s 10ms/step - loss: 47.4575 - val_loss: 0.1303
        Epoch 5/50
        147/147 [==============================] - 2s 10ms/step - loss: 52.6841 - val_loss: 0.5772
        Epoch 6/50
        147/147 [==============================] - 2s 11ms/step - loss: 47.3059 - val_loss: 5.2535
        Epoch 7/50
        147/147 [==============================] - 2s 11ms/step - loss: 43.6491 - val_loss: 41.2008
        Epoch 8/50
        147/147 [==============================] - 2s 11ms/step - loss: 37.8593 - val_loss: 28.5831
        Epoch 9/50
        147/147 [==============================] - 2s 11ms/step - loss: 40.8553 - val_loss: 41.5958
        Epoch 10/50
        147/147 [==============================] - 2s 11ms/step - loss: 34.5995 - val_loss: 57.3419
        Epoch 11/50
        147/147 [==============================] - 2s 11ms/step - loss: 34.2054 - val_loss: 38.9516
        Epoch 12/50
        147/147 [==============================] - 2s 11ms/step - loss: 36.9247 - val_loss: 38.1881
        Epoch 13/50
        147/147 [==============================] - 2s 10ms/step - loss: 34.5922 - val_loss: 49.7601
        Epoch 14/50
        147/147 [==============================] - 2s 11ms/step - loss: 38.1668 - val_loss: 46.0043
        Epoch 15/50
        147/147 [==============================] - 2s 10ms/step - loss: 35.4724 - val_loss: 39.1485
        Epoch 16/50
        147/147 [==============================] - 2s 11ms/step - loss: 35.7787 - val_loss: 38.2263
        Epoch 17/50
        147/147 [==============================] - 2s 11ms/step - loss: 32.5241 - val_loss: 38.0783
        Epoch 18/50
        147/147 [==============================] - 2s 11ms/step - loss: 35.1693 - val_loss: 35.3403
        Epoch 19/50
        147/147 [==============================] - 2s 11ms/step - loss: 34.5822 - val_loss: 28.0546
        Epoch 20/50
        147/147 [==============================] - 2s 11ms/step - loss: 32.7388 - val_loss: 37.5600
        Epoch 21/50
        147/147 [==============================] - 2s 11ms/step - loss: 36.7384 - val_loss: 19.3809
        Epoch 22/50
        147/147 [==============================] - 2s 11ms/step - loss: 34.0202 - val_loss: 38.0124
        Epoch 23/50
        147/147 [==============================] - 2s 11ms/step - loss: 31.7241 - val_loss: 36.0455
        Epoch 24/50
        147/147 [==============================] - 2s 10ms/step - loss: 33.6021 - val_loss: 19.4785
        Epoch 25/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.5922 - val_loss: 37.5662
        Epoch 26/50
        147/147 [==============================] - 2s 10ms/step - loss: 31.7600 - val_loss: 25.8877
        Epoch 27/50
        147/147 [==============================] - 2s 11ms/step - loss: 31.0494 - val_loss: 25.5513
        Epoch 28/50
        147/147 [==============================] - 2s 11ms/step - loss: 32.7150 - val_loss: 22.6177
        Epoch 29/50
        147/147 [==============================] - 2s 11ms/step - loss: 30.3998 - val_loss: 26.8450
        Epoch 30/50
        147/147 [==============================] - 2s 10ms/step - loss: 30.3076 - val_loss: 42.8708
        Epoch 31/50
        147/147 [==============================] - 2s 11ms/step - loss: 30.6752 - val_loss: 32.9248
        Epoch 32/50
        147/147 [==============================] - 2s 10ms/step - loss: 29.2235 - val_loss: 33.0209
        Epoch 33/50
        147/147 [==============================] - 2s 11ms/step - loss: 30.7826 - val_loss: 21.4303
        Epoch 34/50
        147/147 [==============================] - 2s 11ms/step - loss: 31.5795 - val_loss: 28.7224
        Epoch 35/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.2187 - val_loss: 19.5436
        Epoch 36/50
        147/147 [==============================] - 2s 10ms/step - loss: 28.8158 - val_loss: 23.3435
        Epoch 37/50
        147/147 [==============================] - 2s 10ms/step - loss: 27.8942 - val_loss: 29.7689
        Epoch 38/50
        147/147 [==============================] - 2s 11ms/step - loss: 31.8379 - val_loss: 19.7113
        Epoch 39/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.4185 - val_loss: 30.7159
        Epoch 40/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.2826 - val_loss: 22.0266
        Epoch 41/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.3911 - val_loss: 22.6929
        Epoch 42/50
        147/147 [==============================] - 2s 10ms/step - loss: 28.0742 - val_loss: 16.1369
        Epoch 43/50
        147/147 [==============================] - 2s 11ms/step - loss: 27.4483 - val_loss: 19.0667
        Epoch 44/50
        147/147 [==============================] - 2s 11ms/step - loss: 27.6157 - val_loss: 15.3852
        Epoch 45/50
        147/147 [==============================] - 2s 11ms/step - loss: 27.9996 - val_loss: 21.4107
        Epoch 46/50
        147/147 [==============================] - 2s 11ms/step - loss: 28.4632 - val_loss: 17.0626
        Epoch 47/50
        147/147 [==============================] - 2s 11ms/step - loss: 29.0796 - val_loss: 21.7797
        Epoch 48/50
        147/147 [==============================] - 2s 10ms/step - loss: 28.2646 - val_loss: 21.8080
        Epoch 49/50
        147/147 [==============================] - 2s 11ms/step - loss: 28.7243 - val_loss: 18.9899
        Epoch 50/50
        147/147 [==============================] - 2s 11ms/step - loss: 28.2579 - val_loss: 28.6534
        [[[72 73 74 68]
          [71 74 76 66]
          [71 74 76 66]]]
        [[72.415985 69.27797  71.99651  69.86983 ]]
        expected:  [71 74 76 66]
        [[[74 71 72 69]
          [72 73 74 68]
          [72 73 74 68]
          [72 73 74 68]
          [72 73 74 68]]]
        [[153.16042 179.3388  158.57655 169.93341]]
        expected:  [71 74 76 66]
        0 [[73.17023 69.77195 71.62949 71.44139]]
        expected:  [[72 73 74 68]]
        1 [[72.80142  69.71678  71.557175 71.15702 ]]
        expected:  [[72 73 74 68]]
        2 [[72.39997  69.51012  71.5443   70.574905]]
        expected:  [[72 73 74 68]]
        3 [[72.39997  69.51012  71.5443   70.574905]]
        expected:  [[71 74 76 66]]
        4 [[72.51985  69.45031  71.813896 70.3402  ]]
        expected:  [[71 74 76 66]]
        5 [[72.415985 69.27797  71.99651  69.86983 ]]
        expected:  [[71 74 76 66]]
        6 [[72.11394  68.977165 72.128334 69.17176 ]]
        expected:  [[71 74 76 66]]
        7 [[72.11394  68.977165 72.128334 69.17176 ]]
        expected:  [[71 76 74 61]]
        8 [[72.221664 69.22221  71.957596 68.933846]]
        expected:  [[71 76 74 61]]
        9 [[72.15421  69.480225 71.38563  68.43072 ]]
        expected:  [[71 76 74 61]]
    
    

    Comments:

    • Thanks for your answer. It looks like your code has a similar problem: for the input [[[74 71 72 69] [72 73 74 68] [72 73 74 68] [72 73 74 68] [72 73 74 68]]], the output is [[153.16042 179.3388 158.57655 169.93341]]
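Following Answer 2's advice to keep the input length at 3, a longer input such as the five-step one in the comment above can be trimmed to its three most recent steps before calling `predict`. A minimal sketch; `last_window` is a hypothetical helper, and `model` is assumed to be a model trained only on length-3 windows:

```python
import numpy as np


def last_window(seq, length=3, n_features=4):
    """Keep only the most recent `length` time steps and add a batch axis,
    giving the (1, length, n_features) shape a length-3 model was trained on."""
    seq = np.asarray(seq)
    if len(seq) < length:
        raise ValueError(f"need at least {length} time steps, got {len(seq)}")
    return seq[-length:].reshape((1, length, n_features))


x_long = np.array([[74, 71, 72, 69],
                   [72, 73, 74, 68],
                   [72, 73, 74, 68],
                   [72, 73, 74, 68],
                   [72, 73, 74, 68]])
x_input = last_window(x_long)
print(x_input.shape)  # (1, 3, 4)
# yhat = model.predict(x_input, verbose=0)  # `model`: the length-3 model
```

This discards the older context rather than using it, which is the trade-off of keeping training and inference lengths identical.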