【问题标题】:How to solve the problem that Trainable-False does not work in Keras?如何解决 Trainable-False 在 Keras 中不起作用的问题?
【发布时间】:2019-09-02 23:08:50
【问题描述】:

我现在面临“trainable=false”的故障。

当我开发的代码具有这样的结构时,

该模型有两个细分模型(FC模型,CN模型),它们以串行方式连接。

只训练FC模型后,我想冻结FC并训练FC+CN,整个模型。

但是用 trainable 冻结不起作用,并且出现了一些奇怪的东西。

不冻结时:

model.FCnetwork.trainable = True
model.FCnetwork.summary()
Total params: 2,584,576
Trainable params: 2,578,432
Non-trainable params: 6,144

当冻结时:

model.FCnetwork.trainable = False
model.FCnetwork.summary()
Total params: 5,163,008
Trainable params: 2,578,432
Non-trainable params: 2,584,576

总参数增加。当然,冻结不起作用。

这是我设计的课程

class MYMAP():
    def __init__(self):
        # Input shape


        optimizer = optimizers.Adam()

        self.CNnetwork= self.Convolutional_network()
        self.CNnetwork.compile()




        self.FCnetwork = self.Fullyconnected_network()
        self.FCnetwork.compile(loss='mse',
            optimizer=optimizer)

        z = Input(shape=(input_size,))
        img = self.FCnetwork(z)

        valid = self.CNnetwork(img)

        self.combined = Model(z, valid)

        optimizer_DG = optimizers.Adam()
        self.combined.compile(loss='mse', optimizer=optimizer_DG)

    def Fullyconnected_network(self):

        noise = Input(shape=(input_size,))
        img = model(noise)

        return Model(noise, img)




    def Convolutional_network(self):

        img = Input(shape=(image_size_vectored,))
        validity = model(img)

        return Model(img, validity)

想出解决方法对我来说有点困难。

非常感谢。

【问题讨论】:

    标签: tensorflow keras deep-learning


    【解决方案1】:

    警告说得很清楚

    你是否设置了model.trainable而不调用model.compile

    正确的示例代码:

    class MYMAP():
        def __init__(self):        
            self.optimizer = optimizers.Adam()
            self.FCnetwork = self.Fullyconnected_network()
    
            self.FCnetwork.compile(loss='mse',
                optimizer=self.optimizer)
    
            z = Input(shape=(32,))
            img = self.FCnetwork(z)
    
    
        def Fullyconnected_network(self):            
            noise = Input(shape=(32,))        
            img = Dense(8)(noise)
            return Model(noise, img)
    
    model = MYMAP()
    model.FCnetwork.trainable = True
    model.FCnetwork.compile(loss='mse', optimizer=optimizers.Adam())
    model.FCnetwork.summary()
    model.FCnetwork.trainable = False
    model.FCnetwork.compile(loss='mse', optimizer=optimizers.Adam())
    model.FCnetwork.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_39 (InputLayer)        (None, 32)                0         
    _________________________________________________________________
    dense_15 (Dense)             (None, 8)                 264       
    =================================================================
    Total params: 264
    Trainable params: 264
    Non-trainable params: 0
    _________________________________________________________________
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_39 (InputLayer)        (None, 32)                0         
    _________________________________________________________________
    dense_15 (Dense)             (None, 8)                 264       
    =================================================================
    Total params: 264
    Trainable params: 0
    

    因此,请确保在更改模型的可训练参数后运行 model.compile。

    【讨论】:

      猜你喜欢
      • 2019-09-21
      • 1970-01-01
      • 1970-01-01
      • 2014-01-28
      • 1970-01-01
      • 2013-07-31
      • 2011-04-30
      • 2019-10-20
      • 2021-11-08
      相关资源
      最近更新 更多