【问题标题】:Discriminator Loss Not Changing in Generative Adversarial Network生成对抗网络中的鉴别器损失没有改变
【发布时间】:2021-02-19 10:45:42
【问题描述】:

我正在尝试使用 pix2pix GAN 生成器和 Unet 作为鉴别器来训练 GAN。但经过一些时期后,我的鉴别器损失停止变化并停留在 5.546 附近。 GAN 训练是好兆头还是坏兆头。

这是我的损失计算:

def discLoss(rValid, rLabel, fValid, fLabel):
    # validity loss
    bce =     tf.keras.losses.BinaryCrossentropy(from_logits=True,label_smoothing=0.1)
    # classifier loss
    scce =     tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Loss for real
    real_dloss = (bce(tf.ones_like(rValid), rValid) + scce(label, rLabel))#/2
    # Loss for fake
    fake_dloss = (bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel))#/2
    # Total discriminator loss
    d_loss = (real_dloss + fake_dloss)# / 2
    return d_loss

def generator_loss(disc_generated_output, gen_output, target):
  loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  LAMBDA = 100
  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss

这是我的火车步骤:

def train_step(img1, img2, label, generator,discriminator,generator_optimizer,discriminator_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fImg = generator([img1, label], training=True)
    rValid, rLabel = discriminator(img2, training=True)
    fValid, fLabel = discriminator(fImg, training=True)

    disc_loss = discLoss(rValid, rLabel, fValid, fLabel)
    gen_loss = generator_loss(fValid, fImg, img2)
    # genLoss(label, rValid, rLabel, fValid, fLabel)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return tf.math.reduce_mean(gen_loss).numpy(), disc_loss.numpy()

【问题讨论】:

  • 请直接复制代码,不要链接到图片。这是你的全部代码吗?

标签: python keras tensorflow2.0 loss-function generative-adversarial-network


【解决方案1】:

这个损失太高了。您需要注意 G 和 D 都以均匀的速度学习。访问此问题和相关链接:How to balance the generator and the discriminator performances in a GAN?

【讨论】:

    猜你喜欢
    • 2017-07-30
    • 2018-07-26
    • 2017-11-27
    • 2017-11-08
    • 2019-12-22
    • 2020-09-25
    • 2020-05-19
    • 1970-01-01
    • 2017-12-08
    相关资源
    最近更新 更多