【发布时间】:2020-06-26 11:23:27
【问题描述】:
据我了解,常规 GAN 与 WGAN 之间的区别在于,我们在每个 epoch 中使用更多示例来训练鉴别器/批评者。如果在常规 GAN 中,我们在每个 epoch 中都有一个批次用于两个模块,那么在 WGAN 中,我们将有 5 个批次(或更多)用于鉴别器,一个用于生成器。
所以基本上我们有另一个判别器的内部循环:
real_images_labels = np.ones((BATCH_SIZE, 1))
fake_images_labels = -real_images_labels
for epoch in range(epochs):
for batch in range(NUM_BACHES):
for critic_iter in range(n_critic):
random_batches_idx = np.random.randint(0, NUM_BACHES) # Choose random batch from dataset
imgs_data=dataset_list[random_batches_idx]
c_loss_real = critic.train_on_batch(imgs_data, real_images_labels) # update the weights after 1 batch
noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
generated_images = generator(noise, training=True)
c_loss_fake = critic.train_on_batch(generated_images, fake_images_labels) # update the weights after 1 batch
imgs_data=dataset_list[batch]
noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
gen_loss_batch = gen_loss_batch + gan.train_on_batch(noise,real_images_labels)
训练花费了我很多时间,每个 epoch 大约 3m。我不得不减少训练时间的想法是为每个批次向前运行 n_critic 次,我可以增加鉴别器的 batch_size 并使用更大的 batch_size 向前运行一次。
我正在寻求反馈:这听起来合理吗?
(我没有粘贴我的整个代码,它只是其中的一部分)。
【问题讨论】:
标签: tensorflow keras generative-adversarial-network