【问题标题】:How to merge keras sequential models with same input?如何合并具有相同输入的 keras 顺序模型?
【发布时间】:2018-02-06 10:08:32
【问题描述】:

我正在尝试在 keras 中创建我的第一个集成模型。我的数据集中有 3 个输入值和一个输出值。

from keras.optimizers import SGD,Adam
from keras.layers import Dense,Merge
from keras.models import Sequential

model1 = Sequential()
model1.add(Dense(3, input_dim=3, activation='relu'))
model1.add(Dense(2, activation='relu'))
model1.add(Dense(2, activation='tanh'))
model1.compile(loss='mse', optimizer='Adam', metrics=['accuracy'])

model2 = Sequential()
model2.add(Dense(3, input_dim=3, activation='linear'))
model2.add(Dense(4, activation='tanh'))
model2.add(Dense(3, activation='tanh'))
model2.compile(loss='mse', optimizer='SGD', metrics=['accuracy'])

model3 = Sequential()
model3.add(Merge([model1, model2], mode = 'concat'))
model3.add(Dense(1, activation='sigmoid'))
model3.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['accuracy'])

model3.input_shape

集成模型 (model3) 编译时没有任何错误,但在拟合模型时,我必须将相同的输入传递两次 model3.fit([X,X],y)。我认为这是一个不必要的步骤,而不是两次传递输入,我想为我的集成模型提供一个公共输入节点。我该怎么做?

【问题讨论】:

    标签: python keras keras-layer ensemble-learning


    【解决方案1】:

    定义新的输入层并直接使用模型输出(在函数式 api 中工作):

    assert model1.input_shape == model2.input_shape # make sure they got same shape
    
    inp = tf.keras.layers.Input(shape=model1.input_shape[1:])
    model = tf.keras.models.Model(inputs=[inp], outputs=[model1(inp), model2(inp)])
    

    【讨论】:

      【解决方案2】:

      etov 的回答是一个不错的选择。

      但是假设您已经准备好model1model2 而您不想更改它们,您可以像这样创建第三个模型:

      singleInput = Input((3,))
      
      out1 = model1(singleInput)   
      out2 = model2(singleInput)
      #....
      #outN = modelN(singleInput)
      
      out = Concatenate()([out1,out2]) #[out1,out2,...,outN]
      out = Dense(1, activation='sigmoid')(out)
      
      model3 = Model(singleInput,out)
      

      如果您已经准备好所有模型并且不想更改它们,您可以拥有这样的东西(未经测试):

      singleInput = Input((3,))
      output = model3([singleInput,singleInput])
      singleModel = Model(singleInput,output)
      

      【讨论】:

      • 这确实是一个不错的选择,实际上我认为这两种方式几乎是等价的,除了输入层是嵌入。在这种情况下,使用公共输入层与为每个模型使用不同的输入层会有所不同(两者都是有效的 - 正确的选择取决于应用程序)
      • 是的。但生成的模型不是连续的。
      【解决方案3】:

      Keras functional API 似乎更适合您的用例,因为它允许计算图的更大灵活性。例如:

      from keras.layers import concatenate
      from keras.models import Model
      from keras.layers import Input, Merge
      from keras.layers.core import Dense
      from keras.layers.merge import concatenate
      
      # a single input layer
      inputs = Input(shape=(3,))
      
      # model 1
      x1 = Dense(3, activation='relu')(inputs)
      x1 = Dense(2, activation='relu')(x1)
      x1 = Dense(2, activation='tanh')(x1)
      
      # model 2 
      x2 = Dense(3, activation='linear')(inputs)
      x2 = Dense(4, activation='tanh')(x2)
      x2 = Dense(3, activation='tanh')(x2)
      
      # merging models
      x3 = concatenate([x1, x2])
      
      # output layer
      predictions = Dense(1, activation='sigmoid')(x3)
      
      # generate a model from the layers above
      model = Model(inputs=inputs, outputs=predictions)
      model.compile(optimizer='adam',
                    loss='binary_crossentropy',
                    metrics=['accuracy'])
      
      # Always a good idea to verify it looks as you expect it to 
      # model.summary()
      
      data = [[1,2,3], [1,1,3], [7,8,9], [5,8,10]]
      labels = [0,0,1,1]
      
      # The resulting model can be fit with a single input:
      model.fit(data, labels, epochs=50)
      

      注意事项:

      • Keras 版本(前版本和后版本 2)之间的 API 可能略有不同
      • 上面的示例为每个模型指定了不同的优化器和损失函数。但是,由于 fit() 仅被调用一次(在模型 3 上),相同的设置 - 模型 3 的设置 - 将应用于整个模型。为了在训练子模型时有不同的设置,它们必须分别进行 fit() - 查看@Daniel 的评论。

      编辑:基于 cmets 的更新笔记

      【讨论】:

      • 仅当您将 fit 用于该特定模型时,才会考虑模型的编译(优化器和损失)。如果你在model3中使用fit,只有model3的编译才会生效。 --- 根本不需要编译model1model2,除非你要单独训练它们(使用model1.fitmodel2.fit)。权重和预测不需要compile
      猜你喜欢
      • 1970-01-01
      • 2019-08-09
      • 1970-01-01
      • 1970-01-01
      • 2019-11-28
      • 2018-03-05
      • 2019-02-13
      • 2018-02-09
      • 2021-02-24
      相关资源
      最近更新 更多