【问题标题】:How to save and resume training a GAN with multiple model parts with Tensorflow 2/ Keras如何使用 Tensorflow 2/Keras 保存和恢复训练具有多个模型部分的 GAN
【发布时间】:2021-10-25 13:32:35
【问题描述】:

我目前正在尝试添加一项功能,以中断和恢复从以下示例代码创建的 GAN 上的训练:https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/

我设法让它工作,我将整个复合 GAN 的权重保存在 summarise_performance 函数中,该函数每 10 个 epoch 触发一次,如下所示:

# save all weights
filename3 = 'weights_%08d.h5' % (step+1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))

它加载在我添加到程序开头的一个名为 load_model 的函数中,该函数采用正常构建的 gan 架构,但将其权重更新为最新值,如下所示:

#load model from file and return startBatch number
def load_model(gan_model):
   start_batch = 0
   files = glob.glob("./weights_0*.h5")
   if(len(files) > 0 ):
       most_recent_file = files[len(files)-1]
       gan_model.load_weights(most_recent_file)
       #TODO: breaks if using more than 8 digits for batches
       startBatch = int(most_recent_file[10:18])
       if (start_batch != 0):
           print("> found existing weights; starting at batch %d" % start_batch)
   return start_batch

将 start_batch 传递给 train 函数以跳过已完成的 epoch。

虽然这种减轻重量的方法确实“有效”,但我仍然认为我的方法是错误的,因为我发现重量数据显然不包括 GAN 的优化器状态,​​因此训练不会继续如果它没有被打断的话。

我发现保存进度同时保存优化器状态的方法显然是通过保存整个模型而不是权重来完成的

在这里我遇到了一个问题,因为在 GAN 中我不仅有一个模型可以训练,而且我有 3 个模型:

  • 生成器模型 g_model
  • 判别器模型 d_model
  • 和复合 GAN 模型 gan_model

它们都相互连接并相互依赖。如果我采用幼稚的方法并分别保存和恢复这些部分模型,我最终会得到 3 个独立的脱节模型而不是 GAN

有没有办法保存和恢复整个 GAN,让我可以像没有中断一样继续训练?

【问题讨论】:

    标签: python tensorflow keras deep-learning generative-adversarial-network


    【解决方案1】:

    如果您想恢复整个 GAN,可以考虑使用tf.train.Checkpoint

    ### In your training loop
    
    checkpoint_dir = '/checkpoints'
    checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
                                discriminator_optimizer=discriminator_optimizer,
                                      generator=generator,
                                      discriminator=discriminator
                                      gan_model = gan_model)
      
    ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)  
        print ('Latest checkpoint restored!!')
    
    ....
    ....
    
    
    if (epoch + 1) % 40 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))
    
    ### After x number of epochs, just save your generator model for inference.
    
    generator.save('your_model.h5')
    

    您也可以考虑完全摆脱复合模型。 Here 就是我的意思的一个例子。

    【讨论】:

    • 感谢您的回答,检查点似乎真的是要走的路。我只是好奇他们是否需要像您的示例中那样使用覆盖的 tensorflow train_step() ?如果是这样,那么我不能使用它们,因为我的基本实现使用 model.train_on_batch()
    • 好问题。检查点不需要它,但使用自定义训练循环与 model.train_on_batch() 一样有效,它会为您提供所需的所有灵活性。以tutorial 为例。
    • 感谢您的快速回答。我肯定会尝试添加检查点。另一个简单的问题:我应该如何处理保持批号一致?使用我刚刚从文件名中的批次/纪元迭代编码的权重,但检查点似乎不可能。
    • 您能解释一下批号一致性的确切含义吗?如果您想将纪元编号用作检查点文件的一部分,则只需执行ckpt_manager.save(checkpoint_number=epoch)。查看docs了解更多信息。
    • 基本上我想记录我所处的批次/时期/迭代,并在循环中的那个点而不是 0 处重新开始训练,以保持任何快照或工件的编号一致。我设法让它在我的训练循环中与ckpt_manager.save(checkpoint_number=i+1)start_batch = int(ckpt_manager.latest_checkpoint[19:]) 一起工作,非常感谢,我对这个恢复功能的实现现在似乎可以工作了。我会将您的答案标记为已接受
    猜你喜欢
    • 2018-02-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-04-25
    • 2019-08-27
    • 1970-01-01
    • 2016-02-18
    相关资源
    最近更新 更多