我也遇到了同样的问题。
为了解决这个问题,我继承并扩展了 ModelCheckpoint(Callback) 类。
我在 on_epoch_begin 回调函数中添加了一个专用的 TensorFlow 检查点,用来保存当前的 epoch 编号。
The network doesn't store its training progress with respect to training data - this is not part of its state, because at any point you could decide to change what data set to feed it.
class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """ModelCheckpoint that also persists the index of the epoch being started.

    On top of the regular Keras checkpoint written by the parent class, a
    small ``tf.train.Checkpoint`` holding one ``completed_epochs`` counter is
    saved to a ``tf_ckpts`` directory next to ``filepath`` at the beginning
    of every epoch. A later run can restore that counter and pass it to
    ``model.fit(initial_epoch=...)`` to resume where training stopped.

    Args:
        filepath: Destination of the Keras checkpoint; its parent directory
            also hosts the ``tf_ckpts`` epoch-counter checkpoints.
        monitor: Forwarded to ``ModelCheckpoint``.
        verbose: Forwarded to ``ModelCheckpoint``.
        save_best_only: Forwarded to ``ModelCheckpoint``.
        save_weights_only: Forwarded to ``ModelCheckpoint``.
        mode: Forwarded to ``ModelCheckpoint``.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=1,
                 save_best_only=True, save_weights_only=True,
                 mode='auto', ):
        super().__init__(filepath=filepath, monitor=monitor,
                         verbose=verbose, save_best_only=save_best_only,
                         save_weights_only=save_weights_only, mode=mode)
        # Non-trainable int32 counter; it is the only value in this checkpoint.
        self.ckpt = tf.train.Checkpoint(
            completed_epochs=tf.Variable(0, trainable=False, dtype='int32'))
        # os.path.join keeps the directory relative when `filepath` has no
        # directory part; plain string formatting would yield '/tf_ckpts'
        # (filesystem root) in that case.
        ckpt_dir = os.path.join(os.path.dirname(filepath), 'tf_ckpts')
        self.manager = tf.train.CheckpointManager(self.ckpt, ckpt_dir, max_to_keep=3)

    def on_epoch_begin(self, epoch, logs=None):
        """Record ``epoch`` in the counter checkpoint before the epoch runs."""
        # Delegate first so the parent class keeps whatever per-epoch state
        # it maintains in this hook.
        super().on_epoch_begin(epoch, logs=logs)
        # At the start of epoch N exactly N epochs have been completed.
        self.ckpt.completed_epochs.assign(epoch)
        self.manager.save()
        print( f"Epoch checkpoint {self.ckpt.completed_epochs.numpy()} saved to: {self.manager.latest_checkpoint}" )
        print(logs)
def callbacks(checkpoint_dir, model_name):
    """Return the list of Keras callbacks to pass to ``model.fit``.

    The list holds a single ``EpochModelCheckpoint`` that writes to
    ``<checkpoint_dir>/<model_name>_best.hdf5``.
    """
    weights_file = '{}_best.hdf5'.format(model_name)
    checkpoint_cb = EpochModelCheckpoint(os.path.join(checkpoint_dir, weights_file))
    return [checkpoint_cb]
# Sketch of a resumable training entry point: it loads previously saved
# weights when the weights file exists, restores the saved epoch counter, and
# hands that counter to model.fit() as initial_epoch.
# NOTE(review): indentation was lost in this copy, so it cannot be determined
# from here whether the "Restore last Epoch" statements are nested under the
# os.path.exists() check below — confirm against the original before re-indenting.
def train():
# (setup elided in this excerpt)
...
model = get_compiled_model()
checkpoint_dir = "./checkpoint_dir"
model_name = "my_model"
# Init checkpoint value
# The checkpoint tracks a single non-trainable int32 variable named
# completed_epochs, starting at 0.
ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0,trainable=False,dtype='int32'))
# Same <checkpoint_dir>/tf_ckpts location that EpochModelCheckpoint derives
# from the weights file path; max_to_keep=3 is passed to the manager.
manager = tf.train.CheckpointManager(ckpt, f'{checkpoint_dir}/tf_ckpts', max_to_keep=3)
# Same file name that callbacks() hands to EpochModelCheckpoint.
best_weights = os.path.join(checkpoint_dir, f'{model_name}_best.hdf5')
if os.path.exists(best_weights):
print(f'Loading model {best_weights}')
model.load_weights(best_weights)
# Restore last Epoch
# NOTE(review): presumably restore() is a no-op when latest_checkpoint is
# empty — confirm against the tf.train.Checkpoint.restore documentation.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print(f"Restored epoch ckpt from {manager.latest_checkpoint}, value is ",ckpt.completed_epochs.numpy())
else:
print("Initializing from scratch.")
# (further setup elided in this excerpt; train_ds, val_ds and cfg are not
# defined in the visible code)
...
# Set initial_epoch in the model fit to last seen Epoch
# Reads the counter back as a NumPy value; presumably still the initial 0
# when nothing was restored.
completed_epochs=ckpt.completed_epochs.numpy()
history = model.fit(
x=train_ds,
epochs=cfg.epochs,
steps_per_epoch=cfg.steps,
callbacks=callbacks( checkpoint_dir, model_name ),
validation_data=val_ds,
validation_steps=cfg.val_steps,
initial_epoch=completed_epochs )