import tensorflow as tf

SZ = 6

# Encoder: SZ inputs -> Dense(SZ) -> Dense(1)
encoder_input = tf.keras.layers.Input(shape=(SZ,))
x = tf.keras.layers.Dense(SZ)(encoder_input)
x = tf.keras.layers.Dense(1)(x)
encoder_model = tf.keras.Model(inputs=encoder_input, outputs=x, name='encoder')

# Decoder: 1 input -> Dense(SZ)
decoder_input = tf.keras.layers.Input(shape=(1,))
x2 = tf.keras.layers.Dense(SZ)(decoder_input)
decoder_model = tf.keras.Model(inputs=decoder_input, outputs=x2, name='decoder')

# Calling each model on a tensor nests it, as a single layer, inside the combined model
encoder_output = encoder_model(encoder_input)
decoder_output = decoder_model(encoder_output)
encoder_decoder_model = tf.keras.Model(inputs=encoder_input, outputs=decoder_output, name='encoder-decoder')
encoder_decoder_model.summary()
The summary is:
Model: "encoder-decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_8 (InputLayer)         [(None, 6)]               0
_________________________________________________________________
encoder (Model)              (None, 1)                 49
_________________________________________________________________
decoder (Model)              (None, 6)                 12
=================================================================
Total params: 61
Trainable params: 61
Non-trainable params: 0
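The parameter counts above follow from the Dense-layer formula: a weight matrix of inputs × units plus one bias per unit. A small pure-Python check of that arithmetic, assuming every Dense layer keeps the Keras default `use_bias=True`; the helper `dense_params` is hypothetical, not a Keras function:

```python
# Hypothetical helper: parameters of a Dense layer with a bias vector.
# weights: in_features * units, biases: units
def dense_params(in_features: int, units: int) -> int:
    return in_features * units + units


SZ = 6
encoder_params = dense_params(SZ, SZ) + dense_params(SZ, 1)  # 42 + 7
decoder_params = dense_params(1, SZ)                        # 6 + 6
total_params = encoder_params + decoder_params

print(encoder_params, decoder_params, total_params)  # → 49 12 61
```

The InputLayer contributes no parameters, which is why the combined total is simply 49 + 12.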
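You can train the encoder-decoder model, and your separate encoder_model and decoder_model will be trained automatically, because the combined model reuses their layers (and therefore their weights) rather than copying them. You can also retrieve them from your encoder_decoder_model as follows: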
retrieved_encoder = encoder_decoder_model.get_layer('encoder')
retrieved_encoder.summary()
which prints:
Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_8 (InputLayer)         [(None, 6)]               0
_________________________________________________________________
dense_11 (Dense)             (None, 6)                 42
_________________________________________________________________
dense_12 (Dense)             (None, 1)                 7
=================================================================
Total params: 49
Trainable params: 49
Non-trainable params: 0
And the decoder:
retrieved_decoder = encoder_decoder_model.get_layer('decoder')
retrieved_decoder.summary()
which prints:
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_9 (InputLayer)         [(None, 1)]               0
_________________________________________________________________
dense_13 (Dense)             (None, 6)                 12
=================================================================
Total params: 12
Trainable params: 12
Non-trainable params: 0
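To check both claims (that `get_layer` returns the original sub-models and that training the combined model updates them), here is a minimal sketch assuming TensorFlow 2.x with `tf.keras`. The random training data, the two epochs, and the `adam`/`mse` settings are illustrative assumptions, not part of the original answer:

```python
import numpy as np
import tensorflow as tf

SZ = 6

# Same architecture as above
encoder_input = tf.keras.layers.Input(shape=(SZ,))
x = tf.keras.layers.Dense(SZ)(encoder_input)
x = tf.keras.layers.Dense(1)(x)
encoder_model = tf.keras.Model(inputs=encoder_input, outputs=x, name='encoder')

decoder_input = tf.keras.layers.Input(shape=(1,))
x2 = tf.keras.layers.Dense(SZ)(decoder_input)
decoder_model = tf.keras.Model(inputs=decoder_input, outputs=x2, name='decoder')

decoder_output = decoder_model(encoder_model(encoder_input))
encoder_decoder_model = tf.keras.Model(inputs=encoder_input, outputs=decoder_output,
                                       name='encoder-decoder')

# get_layer hands back the very same objects, not copies
print(encoder_decoder_model.get_layer('encoder') is encoder_model)
print(encoder_decoder_model.get_layer('decoder') is decoder_model)

# Snapshot the encoder weights, then train the combined model as an autoencoder
# on random toy data (assumption: any data would do for this check)
before = [w.copy() for w in encoder_model.get_weights()]
data = np.random.rand(256, SZ).astype('float32')
encoder_decoder_model.compile(optimizer='adam', loss='mse')
encoder_decoder_model.fit(data, data, epochs=2, verbose=0)

# The standalone encoder's weights moved, because the combined model trained them
after = encoder_model.get_weights()
changed = any(not np.allclose(b, a) for b, a in zip(before, after))
print('encoder weights changed:', changed)

# Chaining the sub-models by hand reproduces the combined model's output
sample = data[:4]
combined = encoder_decoder_model.predict(sample, verbose=0)
chained = decoder_model.predict(encoder_model.predict(sample, verbose=0), verbose=0)
print('outputs match:', np.allclose(combined, chained, atol=1e-5))
```

After training, `encoder_model` alone can therefore be used to compress inputs to one value, and `decoder_model` alone to expand them back, without any weight copying.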