【问题标题】:In a GAN with custom training loop, how can I train the discriminator more times than the generator (such as in WGAN) in tensorflow在具有自定义训练循环的 GAN 中,如何在 tensorflow 中训练判别器的次数比生成器(例如 WGAN 中)的次数多
【发布时间】:2022-02-24 18:00:47
【问题描述】:

我正在编写基于 [Pix2Pix tensorflow 教程][教程] 的代码,并且我正在尝试遵循 Wasserstein GAN (WGAN) 要求:(a) 权重裁剪,(b) 鉴别器的线性激活,(c) Wasserstein 损失,以及 (d) 对每个生成器步骤多次训练鉴别器

我有一个自定义训练循环,使用两个渐变磁带(例如教程中的)。训练步骤的代码如下所示:

@tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

我的问题:我如何调整代码多次训练鉴别器为我训练生成器的每一个?

【问题讨论】:

    标签: python tensorflow generative-adversarial-network gradienttape


    【解决方案1】:

    您可以对生成器和鉴别器使用单独的梯度带进行训练,并在鉴别器上循环多次。

    @tf.function
    def train_step(input_image, target, step):
      with tf.GradientTape() as gen_tape:
        gen_output = generator(input_image, training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
      generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
      generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    
      disc_train_iterations = 5
      for i in range(disc_train_iterations):
        with tf.GradientTape() as disc_tape:
          gen_output = generator(input_image, training=True)
          disc_real_output = discriminator([input_image, target], training=True)
          disc_generated_output = discriminator([input_image, gen_output], training=True)
          disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
        discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    

    【讨论】:

    • 我会试试的,Sascha,然后告诉你。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-05-10
    • 2020-08-11
    • 2018-07-23
    • 2020-02-22
    • 1970-01-01
    • 2019-01-28
    • 1970-01-01
    相关资源
    最近更新 更多