【问题标题】:Adjusting GAN hyperparameters调整 GAN 超参数
【发布时间】:2020-12-07 19:03:20
【问题描述】:

如上两张图所示,在训练一个DCGAN模型时,梯度不稳定,波动很大。由于这个原因,模型无法绘制出完美的图像,甚至无法绘制出被人识别的图像人眼。有没有人能告诉我如何调整诸如辍学率或学习率之类的参数以使模型运行得更好?我将非常感谢你! 这是我之前制作的模型(使用 Keras 构建):

鉴别器:

学习率为 0.0005

辍学率为 0.6

batch_size 为 25

dis=Sequential()

dis.add(Conv2D(depth*1, 5, strides=2, input_shape=(56,56,3),padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*2, 5, strides=2, padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*4, 5, strides=2, padding='same',kernel_initializer='RandomNormal', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Conv2D(depth*8,5,strides=1,padding='same',kernel_initializer='RandomUniform', bias_initializer='zeros'))

dis.add(LeakyReLU(alpha=alp))

dis.add(Dropout(dropout))

dis.add(Flatten())

dis.add(Dense(1))

dis.add(Activation('sigmoid'))

dis.summary()

dis.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=d_lr))

生成器和 GAN 模型:

学习率为 0.0001

动量为 0.9

gen=Sequential()

gen.add(Dense(dim*dim*dep,input_dim=100))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(Reshape((dim,dim,dep)))

gen.add(Dropout(dropout))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/2),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/4),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(UpSampling2D())

gen.add(Conv2DTranspose(int(dep/8),5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(BatchNormalization(momentum=momentum))

gen.add(Activation('relu'))

gen.add(Conv2DTranspose(3,5,padding='same',kernel_initializer='RandomNormal', bias_initializer='RandomNormal'))

gen.add(Activation('sigmoid'))

gen.summary()


GAN=Sequential()

GAN.add(gen)

GAN.add(dis)

GAN.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=g_lr))

【问题讨论】:

  • 判别器看起来不错,但生成器不是。我认为您可以更长时间地训练生成器并降低学习率。

标签: keras deep-learning generative-adversarial-network


【解决方案1】:

稳定的 GAN 训练是一个开放的研究问题。不过,我可以给你两个提示。如果您坚持使用原始的 GAN 训练例程并且对自己在做什么没有绝对的了解,请使用 DCGAN 架构和他们的论文 (https://arxiv.org/pdf/1511.06434.pdf%C3%AF%C2%BC%E2%80%B0) 中描述的可用超参数。 GAN 训练非常不稳定,使用其他超参数会导致模式崩溃或梯度消失。

使用 GAN 更简单的方法是使用 Wasserstein GAN。这些对于 abritrary 架构来说是相当稳定的。但是,我强烈建议使用他们论文中建议的超参数,因为对我来说,不同超参数的训练也崩溃了。改进的 Wasserstein GAN:[https://arxiv.org/pdf/1704.00028.pdf]

【讨论】:

  • 谢谢。我会尝试 Wasserstein GAN。
猜你喜欢
  • 2021-06-14
  • 2020-07-24
  • 2017-10-26
  • 2018-07-24
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2023-02-12
  • 2021-07-04
相关资源
最近更新 更多