【发布时间】:2021-10-25 13:32:35
【问题描述】:
我目前正在尝试添加一项功能,以中断和恢复从以下示例代码创建的 GAN 上的训练:https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/
我设法让它工作,我将整个复合 GAN 的权重保存在 summarise_performance 函数中,该函数每 10 个 epoch 触发一次,如下所示:
# save all weights
filename3 = 'weights_%08d.h5' % (step+1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))
它加载在我添加到程序开头的一个名为 load_model 的函数中,该函数采用正常构建的 gan 架构,但将其权重更新为最新值,如下所示:
#load model from file and return startBatch number
def load_model(gan_model):
start_batch = 0
files = glob.glob("./weights_0*.h5")
if(len(files) > 0 ):
most_recent_file = files[len(files)-1]
gan_model.load_weights(most_recent_file)
#TODO: breaks if using more than 8 digits for batches
startBatch = int(most_recent_file[10:18])
if (start_batch != 0):
print("> found existing weights; starting at batch %d" % start_batch)
return start_batch
将 start_batch 传递给 train 函数以跳过已完成的 epoch。
虽然这种减轻重量的方法确实“有效”,但我仍然认为我的方法是错误的,因为我发现重量数据显然不包括 GAN 的优化器状态,因此训练不会继续如果它没有被打断的话。
我发现保存进度同时保存优化器状态的方法显然是通过保存整个模型而不是权重来完成的
在这里我遇到了一个问题,因为在 GAN 中我不仅有一个模型可以训练,而且我有 3 个模型:
- 生成器模型 g_model
- 判别器模型 d_model
- 和复合 GAN 模型 gan_model
它们都相互连接并相互依赖。如果我采用幼稚的方法并分别保存和恢复这些部分模型,我最终会得到 3 个独立的脱节模型而不是 GAN
有没有办法保存和恢复整个 GAN,让我可以像没有中断一样继续训练?
【问题讨论】:
标签: python tensorflow keras deep-learning generative-adversarial-network