【发布时间】:2021-04-03 16:15:51
【问题描述】:
我有一个旧的 TF1.1x 检查点,包括一个 LSTM 层,并且我还有一个早期运行的层激活,用于旧网络的每一层。我正在尝试使用 Python 在 TF2.2 和 Keras 中重新创建这个网络。 旧网络中使用的层是 'tf.contrib.rnn.LSTMBlockFusedCell'。
我将检查点的 LSTM 内核权重拆分为相应的“内核”和“Recurrent_kernel”,并将它们分别加载到 TF2.2 中的 LSTM 层(以及“偏差”)。
但是,当我使用旧激活运行 model.prediction 时,与旧模型激活相比,我从新 LSTM 层得到完全不同的输出。
我只加载了上面的,即:Kernel、Recurrent_Kernel 和Bias weights。该层没有其他参数。
希望已经提取了下面代码sn-p中的要点:
# Create minimalistic Model, and Build it
#
modelC = keras.Sequential()
modelC.add( keras.layers.Reshape([-1,2048], name='l4_lstm' ))
modelC.add( keras.layers.LSTM( units=2048 ) )
modelC.build(input_shape = (batch_size, 2048))
# Load Weights from Checkpoint Dictionary 'ckptdict',
#
weights_ds = []
weights_ds.append(ckptdict['lstm_fused_cell/kernel'][:2048] ) # "W"
weights_ds.append(ckptdict['lstm_fused_cell/kernel'][2048:] ) # "U"
weights_ds.append(ckptdict['lstm_fused_cell/bias']) # "b"
modelC.set_weights(weights_ds)
# Run the minimal model on Activations from last layer before LSTM
# (data corresponding to the Checkpointed TF1.1x model)
#
l3pred = modelC.predict( l3 )
# At this point, l3pred is wildly different from the TF1.1x version,
#
对于网络的其他层,导入权重的类似方法可以正常工作(== 与旧激活的结果相同),这些层都是“密集”的,但 LSTM 层让我望而却步。
谁能指出解释如何正确导入和运行 LSTM 层的描述?非常感谢!
(2019 年 7 月出现了类似的问题,但我还没有看到答案。)
【问题讨论】:
标签: python-3.x tensorflow keras lstm