【问题标题】:Treating a SubGraph of a Neural Network as a Model In TensorFlow/Keras将神经网络的子图视为 TensorFlow/Keras 中的模型
【发布时间】:2020-07-10 04:16:32
【问题描述】:

我正在尝试使用 Keras Layer API 在 tensorflow 中训练自动编码器。这个 API 非常好并且易​​于使用来设置深度学习层。

快速回顾一下,自动编码器(在我看来)是一个函数 $f(x) = z$ 及其伪逆 \hat{x} = f^{-1}(z) 使得 f(f ^{-1}(x)) \约x。在神经网络模型中,您将设置一个带有瓶颈层的神经网络,它试图使用 f^{-1}(f(x)) 来预测自身 x。当训练误差最小化时,您有两个分量,z = f(x) 是直到并包括瓶颈层的预测。 f^{-1}(z) 是最后的瓶颈层。

所以我设置了编码器:

SZ = 6
model = tf.keras.Sequential()

model.add(layers.InputLayer(SZ))
model.add(layers.Dense(SZ))
model.add(layers.Dense(1))
model.add(layers.Dense(SZ))
model.summary()

model.compile('sgd','mse',metrics = ['accuracy'])
history= model.fit(returns.values,returns.values,epochs=100)

我在这里的困难是权重和分量(f 是输入+dense(SZ)+dense(1),f^{-1} 是dense(1)+dense(SZ))是经过训练的,但我没有知道如何解开它们。有什么方法可以打破神经网络中的两层,并将它们视为自己独立的模型吗?

【问题讨论】:

    标签: python tensorflow keras neural-network


    【解决方案1】:
    import tensorflow as tf
    SZ=6
    encoder_input = tf.keras.layers.Input(shape=(SZ,))
    x = tf.keras.layers.Dense(SZ)(encoder_input)
    x = tf.keras.layers.Dense(1)(x)
    encoder_model = tf.keras.Model(inputs=encoder_input, outputs=x, name='encoder')
    
    decoder_input = tf.keras.layers.Input(shape=(1,))
    x2 = tf.keras.layers.Dense(SZ)(decoder_input)
    decoder_model = tf.keras.Model(inputs=decoder_input, outputs=x2, name='decoder')
    
    encoder_output = encoder_model(encoder_input)
    decoder_output = decoder_model(encoder_output)
    
    encoder_decoder_model = tf.keras.Model(inputs=encoder_input , outputs=decoder_output, name='encoder-decoder')
    encoder_decoder_model.summary()
    

    总结如下:

    Model: "encoder-decoder"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_8 (InputLayer)         [(None, 6)]               0         
    _________________________________________________________________
    encoder (Model)              (None, 1)                 49        
    _________________________________________________________________
    decoder (Model)              (None, 6)                 12        
    =================================================================
    Total params: 61
    Trainable params: 61
    Non-trainable params: 0
    

    您可以训练编码器-解码器模型,并且您将 encoder_modeldecoder_model 分开将被自动训练。您还可以从您的 encoder_decoder 模型中检索它们,如下所示:

    retrieved_encoder = encoder_decoder_model.get_layer('encoder')
    retrieved_encoder.summary()
    

    打印出来:

    Model: "encoder"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_8 (InputLayer)         [(None, 6)]               0         
    _________________________________________________________________
    dense_11 (Dense)             (None, 6)                 42        
    _________________________________________________________________
    dense_12 (Dense)             (None, 1)                 7         
    =================================================================
    Total params: 49
    Trainable params: 49
    Non-trainable params: 0
    

    和解码器:

    retrieved_decoder = encoder_decoder_model.get_layer('decoder')
    retrieved_decoder.summary()
    

    哪个打印:

    Model: "decoder"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_9 (InputLayer)         [(None, 1)]               0         
    _________________________________________________________________
    dense_13 (Dense)             (None, 6)                 12        
    =================================================================
    Total params: 12
    Trainable params: 12
    Non-trainable params: 0
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-08-18
      • 1970-01-01
      相关资源
      最近更新 更多