【问题标题】:Extract features from 2 auto-encoders and feed them into an MLP从 2 个自动编码器中提取特征并将它们输入 MLP
【发布时间】:2018-11-12 11:46:26
【问题描述】:

我知道从自动编码器中提取的特征可以输入到 mlp 中以用于分类或回归目的。这是我之前做过的事情。
但是如果我有 2 个自动编码器呢?我可以从 2 个自动编码器的瓶颈层中提取特征并将它们输入到基于这些特征执行分类的 mlp 中吗?如果是,那么如何?我不确定如何连接这两个功能集。我尝试使用 numpy.hstack() 给我“不可散列切片”错误,而使用 tf.concat() 给我错误“模型的输入张量必须是 Keras 张量”。两个自动编码器的瓶颈层的维度均为 (None,100)。所以,基本上,如果我将它们水平堆叠,我应该得到一个(无,200)。 mlp 的隐藏层可能包含一些 (num_hidden=100) 个神经元。有人可以帮忙吗?

x1 = autoencoder1.get_layer('encoder2').output
x2 = autoencoder2.get_layer('encoder2').output

#inp = np.hstack((x1, x2))
inp = tf.concat([x1, x2], 1)
x = tf.concat([x1, x2], 1)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
y = Dense(1, activation='sigmoid', name='prediction')(h)
mymlp = Model(inputs=inp, outputs=y)

# Compile model
mymlp.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train model
mymlp.fit(x_train, y_train, epochs=20, batch_size=8)

根据@twolffpiggott 的建议更新:

from keras.layers import Input, Dense, Dropout
from keras import layers
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import numpy as np

x1 = Data1
x2 = Data2
y = Data3

num_neurons1 = x1.shape[1]
num_neurons2 = x2.shape[1]

# Train-test split
x1_train, x1_test, x2_train, x2_test, y_train, y_test = train_test_split(x1, x2, y, test_size=0.2)

# scale data within [0-1] range
scalar = MinMaxScaler()
x1_train = scalar.fit_transform(x1_train)
x1_test = scalar.transform(x1_test)

x2_train = scalar.fit_transform(x2_train)
x2_test = scalar.transform(x2_test)

x_train = np.concatenate([x1_train, x2_train], axis =-1)
x_test = np.concatenate([x1_test, x2_test], axis =-1)

# Auto-encoder1

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons1,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded1)
decoded = Dense(num_neurons1, activation='sigmoid', name='decoder2')(decoded)

# this model maps an input to its reconstruction
autoencoder1 = Model(inputs=input_data, outputs=decoded)
autoencoder1.compile(optimizer='sgd', loss='mse')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    

# training
autoencoder1.fit(x1_train, x1_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x1_test, x1_test))

# Auto-encoder2

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons2,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded2 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded2)
decoded = Dense(num_neurons2, activation='sigmoid', name='decoder2')(decoded)


# this model maps an input to its reconstruction
autoencoder2 = Model(inputs=input_data, outputs=decoded)
autoencoder2.compile(optimizer='sgd', loss='mse')

# training
autoencoder2.fit(x2_train, x2_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x2_test, x2_test))

# MLP

num_hidden = 100

encoded1.trainable = False
encoded2.trainable = False

encoded1 = autoencoder1(autoencoder1.inputs)
encoded2 = autoencoder2(autoencoder2.inputs)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
x = Dropout(0.2)(concatenated)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

# Compile model
myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit(x_train, y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict(x_test)

给我一​​个错误:unhashable type: 'list' from the line: myMLP = 模型(输入=[autoencoder1.inputs,autoencoder2.inputs],输出=y)

【问题讨论】:

    标签: python neural-network keras autoencoder


    【解决方案1】:

    问题是您将 numpy 数组与 keras 张量混合在一起。这不能走。

    有两种方法。

    • 从每个自动编码器预测 numpy 数组,连接数组,将它们发送到第三个模型
    • 连接所有模型,可能使自动编码器无法训练,适合每个自动编码器的一个输入。

    就个人而言,我会选择第一个。 (假设自动编码器已经过训练并且不需要更改)。

    第一种方法

    numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1)    
    numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2)
    
    inputDataForThird = np.concatenate([numpyOutputFromAuto1,numpyOutputFromAuto2],axis=-1)
    
    inputTensorForMlp = Input(inputsForThird.shape[1:])
    h = Dense(num_hidden, activation='relu', name='hidden')(inputTensorForMlp)
    y = Dense(1, activation='sigmoid', name='prediction')(h)
    
    mymlp = Model(inputs=inputTensorForMlp, outputs=y)
    
    ....
    mymlp.fit(inputDataForThird ,someY)
    

    第二种方法

    这有点复杂,起初我认为没有太多理由这样做。 (但当然也有可能是个不错的选择)

    现在我们完全忘记了 numpy 并使用 keras 张量。

    自行创建 mlp(如果您稍后在没有自动编码器的情况下使用它,那就太好了):

    inputTensorForMlp = Input(input_shape_compatible_with_concatenated_encoder_outputs)
    x = Dropout(0.2)(inputTensorForMlp)
    h = Dense(num_hidden, activation='relu', name='hidden')(x)
    h = Dropout(0.5)(h)
    y = Dense(1, activation='sigmoid', name='prediction')(h)
    myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)
    

    我们可能想要自动编码器的瓶颈特性,对吧?如果您碰巧使用以下方法正确创建了自动编码器:编码器模型,解码器模型,加入两者,那么仅使用编码器模型会更容易。其他:

    encodedOutput1 = autoencoder1.layers[bottleneckLayer].outputs #or encoder1.outputs
    encodedOutput2 = autoencoder1.layers[bottleneckLayer].outputs #or encoder2.outputs
    

    创建连接模型。连接必须使用 keras 层(我们正在使用 keras 张量):

    concatenated = Concatenate()([encodedOutput1,encodedOutput2])
    output = myMLP(concatenated)
    
    joinedModel = Model([autoencoder1.input,autoencoder2.input],output)
    

    【讨论】:

    • 我猜,你的意思是 mymlp = Model(inputs=inputTensorForMlp, outputs=y) ?
    • 非常感谢@Daniel。它工作得很好。但是,我仍在尝试使第二种方法起作用。
    • 抱歉,需要澄清一下。对于第一种方法,您提到 - numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1); numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2) 那么 mymlp 如何减少功能? [阅读:分号=换行]
    • 我认为,使用 input_data = Input(shape=(num_neurons1,)); encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data); encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded); encoder = Model(inputs=input_data, outputs=encoded1); encoder.predict(numpyInputs) 之类的东西会减少功能。请让我知道我哪里出错了。但是在测试模型时我会使用什么作为 x_test - mymlp.predict(???,y_test) [换行的分号]
    • 您完美地描述了第一种方法。但是如果你打算训练编码器,你需要把它和解码器关联起来,使用y = x
    【解决方案2】:

    我也会采用 Daniel 的第一种方法(为了简单和高效),但如果您对第二种方法感兴趣;例如,如果您对端到端运行网络感兴趣,您可以这样处理:

    # make autoencoders not trainable
    autoencoder1.trainable = False
    autoencoder2.trainable = False
    
    encoded1 = autoencoder1(kerasInputs1)
    encoded2 = autoencoder2(kerasInputs2)
    
    concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
    h = Dense(num_hidden, activation='relu', name='hidden')(concatenated)
    y = Dense(1, activation='sigmoid', name='prediction')(h)
    
    myMLP = Model([input_data1, input_data2], y)
    
    myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    # Training
    myMLP.fit([x1_train, x2_train], y_train, epochs=200, batch_size=8)
    
    # Testing
    myMLP.predict([x1_test, x2_test])
    

    关键修改

    1. 两个自动编码器的权重应端到端冻结(否则随机初始化的 MLP 的早期梯度更新可能会导致大部分学习的损失)。
    2. 自动编码器输入层应分配给每个自动编码器的单独变量input_data1input_data2(而不是同时分配给input_data)。即使autoencoder1.inputs 返回一个 tf 张量,这也是unhashable type: list 异常的来源,替换为[input_data1, input_data2] 即可解决问题。
    3. 在为端到端模型拟合 MLP 时,输入应该是 x1_trainx2_train 的列表,而不是串联的输入。预测时也一样。

    【讨论】:

    • 感谢@twolffpiggott 的建议。你的意思是我应该冻结两个自动编码器中最终编码器层的权重吗?我已经更新了我的问题以分享代码 sn-p。仍然出现错误 - 不可散列的类型:'list'。
    • 你提到的-encoded1 = autoencoder1(kerasInputs1) encoded2 = autoencoder2(kerasInputs2)。我可以使用encoded1 = autoencoder1.get_layer('encoder2').output 和encoded2 = autoencoder1.get_layer('encoder2').output 而不是你提到的吗?会有什么不同?
    • 这种方法的用处在于构建一个端到端模型,您可以将原始输入提供给您的自动编码器。这意味着将两个完整的自动编码器模型与密集层组合在一起。如果您愿意,这还允许您智能地重新训练关于 MLP 目标的自动编码器权重。您提到的另一种方法与 Daniel 的第一种方法类似。
    • 感谢您的澄清。
    猜你喜欢
    • 1970-01-01
    • 2021-05-22
    • 1970-01-01
    • 1970-01-01
    • 2019-09-19
    • 1970-01-01
    • 1970-01-01
    • 2020-01-11
    • 2022-01-15
    相关资源
    最近更新 更多