【问题标题】:How to get weight matrix of one layer at every epoch in LSTM model based on Keras?如何在基于 Keras 的 LSTM 模型的每个时期获得一层的权重矩阵?
【发布时间】:2017-09-28 16:13:40
【问题描述】:

我有一个基于 Keras 的简单 LSTM 模型。

X_train, X_test, Y_train, Y_test = train_test_split(input, labels, test_size=0.2, random_state=i*10)

X_train = X_train.reshape(80,112,12)
X_test = X_test.reshape(20,112,12)

y_train = np.zeros((80,112),dtype='int')
y_test = np.zeros((20,112),dtype='int')

y_train = np.repeat(Y_train,112, axis=1)
y_test = np.repeat(Y_test,112, axis=1)
np.random.seed(1)

# create the model
model = Sequential()
batch_size = 20

model.add(BatchNormalization(input_shape=(112,12), mode = 0, axis = 2))#4
model.add(LSTM(100, return_sequences=False, input_shape=(112,12))) #7 

model.add(Dense(112, activation='hard_sigmoid'))#9
model.compile(loss='binary_crossentropy', optimizer='RMSprop', metrics=['binary_accuracy'])#9

model.fit(X_train, y_train, nb_epoch=30)#9

# Final evaluation of the model
scores = model.evaluate(X_test, y_test, batch_size = batch_size, verbose=0)

我知道如何通过model.get_weights() 获取权重列表,但这是模型完全训练后的值。我想在每个时期都获得权重矩阵(例如,我的 LSTM 中的最后一层),而不仅仅是它的最终值。换句话说,我有 30 个 epoch,我需要得到 30 个权重矩阵值。

真的谢谢,在keras的wiki上没找到解决办法。

【问题讨论】:

    标签: tensorflow deep-learning keras lstm


    【解决方案1】:

    您可以为其编写自定义回调:

    from keras.callbacks import Callback
    
    class CollectWeightCallback(Callback):
        def __init__(self, layer_index):
            super(CollectWeightCallback, self).__init__()
            self.layer_index = layer_index
            self.weights = []
    
        def on_epoch_end(self, epoch, logs=None):
            layer = self.model.layers[self.layer_index]
            self.weights.append(layer.get_weights())
    

    回调的属性self.model 是对正在训练的模型的引用。训练开始时通过Callback.set_model() 设置。

    要获得每个时期最后一层的权重,请使用它:

    cbk = CollectWeightCallback(layer_index=-1)
    model.fit(X_train, y_train, nb_epoch=30, callbacks=[cbk])
    

    然后将权重矩阵收集到cbk.weights

    【讨论】:

      猜你喜欢
      • 2017-12-01
      • 2020-08-09
      • 1970-01-01
      • 2019-03-18
      • 2018-08-06
      • 1970-01-01
      • 2019-02-27
      • 2017-10-07
      • 2017-06-02
      相关资源
      最近更新 更多