【发布时间】:2022-02-24 18:00:47
【问题描述】:
我正在编写基于 [Pix2Pix tensorflow 教程][教程] 的代码,并且我正在尝试遵循 Wasserstein GAN (WGAN) 要求:(a) 权重裁剪,(b) 鉴别器的线性激活,(c) Wasserstein 损失,以及 (d) 对每个生成器步骤多次训练鉴别器。
我有一个自定义训练循环,使用两个渐变磁带(例如教程中的)。训练步骤的代码如下所示:
@tf.function
def train_step(input_image, target, step):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
gen_output = generator(input_image, training=True)
disc_real_output = discriminator([input_image, target], training=True)
disc_generated_output = discriminator([input_image, gen_output], training=True)
gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
generator_gradients = gen_tape.gradient(gen_total_loss,
generator.trainable_variables)
discriminator_gradients = disc_tape.gradient(disc_loss,
discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(generator_gradients,
generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
discriminator.trainable_variables))
我的问题:我如何调整代码以多次训练鉴别器为我训练生成器的每一个?
【问题讨论】:
标签: python tensorflow generative-adversarial-network gradienttape