【Question Title】: Tensorflow Keras cannot properly resume training at initial epoch from checkpoint file
【Posted】: 2019-10-07 09:20:09
【Question】:

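I am loading a Keras model in TensorFlow to resume training. I want to continue training from the epoch where I stopped, so that epoch numbers stay unique and I can keep track of how many epochs have been run. The model is loaded from a checkpoint file created by a callback that saves the model with the highest accuracy. When I resume training with model.fit(), I set initial_epoch to 52 and epochs to 52+5. However, training starts at 1/57 instead of 53/57, and it keeps going up to 57 even though I only want 5 epochs. Am I loading the model incorrectly? Training resumes "normally" in the sense that accuracy picks up where I left off, but the epoch count does not continue from where I want; it restarts from 1.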
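For reference, Keras treats `epochs` as the index of the final epoch rather than as a count: `fit()` loops over `range(initial_epoch, epochs)` and its progress bar prints each epoch index plus one. This pure-Python sketch only mimics that numbering (the helper `epoch_labels` is hypothetical, not a Keras API) to show what `initial_epoch=52, epochs=57` should print compared with the default `initial_epoch=0`.

```python
def epoch_labels(initial_epoch, epochs):
    """Return the progress-bar headers a Keras-style training loop would print.

    Hypothetical helper: it mimics fit(), which iterates over
    range(initial_epoch, epochs) and displays each epoch index as index + 1.
    """
    return [f"Epoch {epoch + 1}/{epochs}" for epoch in range(initial_epoch, epochs)]


resumed = epoch_labels(initial_epoch=52, epochs=52 + 5)
print(resumed)  # five headers, from 'Epoch 53/57' to 'Epoch 57/57'
print(len(epoch_labels(initial_epoch=0, epochs=57)))  # 57 epochs when initial_epoch stays at 0
```

With `initial_epoch` left at its default of 0, the same `epochs=57` yields 57 epochs starting at `Epoch 1/57`, which is exactly the behaviour described in this question.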

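I tried removing the checkpoint callback initialization when loading from the checkpoint file, but that raises a NameError because callbacks_list is no longer defined.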

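from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model

model = load_model('my_model.hdf5')
checkpoint = ModelCheckpoint(cp_filepath, monitor='acc',
                             verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

bs = 32  # batch size
epoch_count = 52  # epochs already completed before this run
cur_epochs = 5  # additional epochs to train in this run
model.fit(
    training_set,
    steps_per_epoch=len(training_set)//bs,
    inital_epoch=epoch_count,  # note: the fit() keyword is spelled initial_epoch
    epochs=cur_epochs+epoch_count,
    validation_data=test_set,
    validation_steps=len(test_set)//bs,
    callbacks=callbacks_list,
    shuffle=True,
    verbose=1
    )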

When resuming from the saved file, I expect to see epoch 53/57 and 5 epochs of training. Instead I get 1/57 and 57 epochs of training.
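One way to avoid hard-coding `epoch_count = 52` (not part of the original post) is to put the epoch in the checkpoint filename: `ModelCheckpoint` fills `{epoch}` placeholders in its `filepath` with the 1-based epoch number, which equals the number of completed epochs and therefore the value to pass as `initial_epoch`. This sketch covers only the string handling in plain Python; the template `my_model_{epoch:02d}.hdf5` and both helper names are hypothetical. With `save_best_only=True`, the newest file belongs to the best epoch, which may be earlier than the last epoch trained.

```python
import re

# Hypothetical filepath template for ModelCheckpoint; Keras substitutes the
# 1-based epoch number for {epoch} when it writes a file.
TEMPLATE = "my_model_{epoch:02d}.hdf5"


def checkpoint_filename(epoch_index):
    """Name ModelCheckpoint would produce at the end of 0-based epoch `epoch_index`."""
    return TEMPLATE.format(epoch=epoch_index + 1)


def initial_epoch_from_filename(filename):
    """Recover the number of completed epochs from a name built with TEMPLATE."""
    match = re.search(r"_(\d+)\.hdf5$", filename)
    if match is None:
        raise ValueError(f"no epoch number in {filename!r}")
    return int(match.group(1))


name = checkpoint_filename(51)  # the 52nd epoch (index 51) has just finished
print(name)                               # my_model_52.hdf5
print(initial_epoch_from_filename(name))  # 52, i.e. resume with initial_epoch=52
```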

【Comments】:

    Tags: python tensorflow keras callback checkpoint


    【Solution 1】:

    I had the same problem. To solve it, I modified the ModelCheckpoint(Callback) class.

    In the on_epoch_begin callback, I added and save a dedicated TensorFlow checkpoint that holds the epoch number.

    The network doesn't store its training progress with respect to the training data; the epoch count is not part of its state, because at any point you could decide to feed it a different data set.
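Because the epoch count has to live outside the model, any small persistent store can hold it. As an illustration (a swapped-in stand-in for the `tf.train.Checkpoint` used in this answer, not the answer's own code), this plain-Python sketch keeps the counter in a JSON file; `save_epoch_counter`, `load_epoch_counter` and the `completed_epochs` key are hypothetical names.

```python
import json
import os


def save_epoch_counter(path, epoch):
    """Persist the 0-based index of the epoch that is starting (= epochs completed)."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"completed_epochs": int(epoch)}, f)
    # Swap the finished file into place in one step, so an interruption while
    # writing never leaves a truncated counter behind.
    os.replace(tmp_path, path)


def load_epoch_counter(path):
    """Return the stored counter, or 0 when no counter exists (fresh training)."""
    if not os.path.exists(path):
        return 0
    with open(path) as f:
        return int(json.load(f)["completed_epochs"])
```

In Keras, the saving half could be hooked in with `tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: save_epoch_counter(path, epoch))`, and the value returned by `load_epoch_counter(path)` passed to `fit()` as `initial_epoch`.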

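    import os
    
    import tensorflow as tf
    
    
    class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
        """ModelCheckpoint that also stores the index of the epoch being started
        in a separate tf.train.Checkpoint, so that fit() can resume from it."""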
    
        def __init__(self, filepath, monitor='val_loss', verbose=1,
                     save_best_only=True, save_weights_only=True,
                     mode='auto'):
    
            super(EpochModelCheckpoint, self).__init__(
                filepath=filepath, monitor=monitor, verbose=verbose,
                save_best_only=save_best_only,
                save_weights_only=save_weights_only, mode=mode)
    
            # Epoch counter kept outside the model weights, in its own TF checkpoint
            self.ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0, trainable=False, dtype='int32'))
            ckpt_dir = f'{os.path.dirname(filepath)}/tf_ckpts'
            self.manager = tf.train.CheckpointManager(self.ckpt, ckpt_dir, max_to_keep=3)
    
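        def on_epoch_begin(self, epoch, logs=None):
            # Run ModelCheckpoint's own on_epoch_begin first (some TF versions
            # track the current epoch there)
            super(EpochModelCheckpoint, self).on_epoch_begin(epoch, logs)
            # `epoch` is the 0-based index of the epoch about to start, which
            # equals the number of epochs already completed
            self.ckpt.completed_epochs.assign(epoch)
            self.manager.save()
            print(f"Epoch checkpoint {self.ckpt.completed_epochs.numpy()} saved to: {self.manager.latest_checkpoint}")
            print(logs)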
    
    def callbacks(checkpoint_dir, model_name):
    
        best_model = os.path.join(checkpoint_dir, '{}_best.hdf5'.format(model_name))
        save_best = EpochModelCheckpoint( best_model  )
        return [ save_best ]
    
    def train():
    
        ...
    
        model = get_compiled_model()
        checkpoint_dir = "./checkpoint_dir"
        model_name = "my_model"
        # Init checkpoint value
        ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0,trainable=False,dtype='int32'))
        manager = tf.train.CheckpointManager(ckpt, f'{checkpoint_dir}/tf_ckpts', max_to_keep=3)    
    
        best_weights = os.path.join(checkpoint_dir, f'{model_name}_best.hdf5') 
        if os.path.exists(best_weights):
            print(f'Loading model {best_weights}')
            model.load_weights(best_weights)
    
            # Restore last Epoch
            ckpt.restore(manager.latest_checkpoint)
            if manager.latest_checkpoint:
                print(f"Restored epoch ckpt from {manager.latest_checkpoint}, value is ",ckpt.completed_epochs.numpy())
            else:
                print("Initializing from scratch.")
    
        ...
        # Set initial_epoch in the model fit to last seen Epoch
        completed_epochs=ckpt.completed_epochs.numpy()
        history = model.fit(
            x=train_ds,
            epochs=cfg.epochs,
            steps_per_epoch=cfg.steps,
            callbacks=callbacks( checkpoint_dir, model_name ),        
            validation_data=val_ds,
            validation_steps=cfg.val_steps,
            initial_epoch=completed_epochs )
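To see what the counter buys on resume, here is a toy, TensorFlow-free simulation of the flow above (the `simulate_fit` function is hypothetical, and a plain dict stands in for the `tf.train.Checkpoint`). It records the epoch index at the start of each epoch, as `on_epoch_begin` does, interrupts the first run partway through an epoch, and resumes with the stored value as `initial_epoch`.

```python
def simulate_fit(state, initial_epoch, epochs, interrupt_after=None):
    """Toy stand-in for model.fit() plus EpochModelCheckpoint.on_epoch_begin.

    state: dict playing the role of the tf.train.Checkpoint.
    interrupt_after: stop during the next epoch once this many epochs
        have finished in this run (None means run to completion).
    Returns the 1-based epoch numbers that finished in this run.
    """
    finished = []
    for epoch in range(initial_epoch, epochs):
        state["completed_epochs"] = epoch  # what on_epoch_begin saves
        if interrupt_after is not None and len(finished) == interrupt_after:
            return finished  # crash partway through this epoch
        finished.append(epoch + 1)
    return finished


state = {"completed_epochs": 0}
print(simulate_fit(state, 0, 10, interrupt_after=3))       # [1, 2, 3]
print(state["completed_epochs"])                           # 3: epoch 4 was cut short
print(simulate_fit(state, state["completed_epochs"], 10))  # [4, 5, 6, 7, 8, 9, 10]
print(state["completed_epochs"])                           # 9
```

The last line shows a side effect of saving at the start of each epoch: after a run that finishes normally, the stored value is one less than the number of epochs completed, so restarting once more would repeat the final epoch.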
    
