【Question Title】: Keras GAN (generator) not training well despite accurate discriminator
【Posted】: 2018-07-23 16:23:06
【Question】:

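I've been trying to sort this out for several days and have tried many suggestions from forums and elsewhere; at this point any suggestion as to what is wrong is welcome!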

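I'm attempting my first GAN training: a simple feed-forward deep network, very similar to the ones used with the MNIST dataset, but trained on spectral power windows derived from the VCTK-Corpus (size (1, 513)).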
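For context on the data shape: 513 bins is what a real FFT of 1024 samples produces (1024 // 2 + 1 = 513). The question does not say how its frames were computed, so the following is only an assumed sketch: a Hann-windowed 1024-point `np.fft.rfft` of a synthetic 440 Hz tone, using a hypothetical helper `power_frame` and an arbitrary 16 kHz sample rate.

```python
import numpy as np

def power_frame(signal, start, n_fft=1024):
    """Hypothetical helper: power spectrum of one Hann-windowed frame.

    A 1024-point real FFT yields 1024 // 2 + 1 = 513 frequency bins.
    """
    frame = signal[start:start + n_fft] * np.hanning(n_fft)
    spectrum = np.fft.rfft(frame, n=n_fft)
    return (np.abs(spectrum) ** 2).reshape(1, -1)  # shape (1, 513), like the question's frames

sr = 16000  # arbitrary sample rate, chosen only for this illustration
t = np.arange(sr) / float(sr)
signal = np.sin(2 * np.pi * 440.0 * t)  # 1 second of a 440 Hz tone

frame = power_frame(signal, 0)
print(frame.shape)  # (1, 513)
```

With a different FFT size the frame width changes accordingly, so the generator's output layer and the discriminator's input must match whatever transform was actually used.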

As the TensorBoard graphs below show, the networks appear to be interacting and some kind of training is taking place: [images: TensorBoard graph overview; TensorBoard graph zoom]

However, the results are poor and noisy: [image: comparison of generated and validation frames]

The generator takes normally distributed noise with a mean of 0 and a standard deviation of 0.5 (typically noise vectors of 30 to 100 values).
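The training loop calls a helper `noisegen(shape)` that the question does not show. A minimal sketch of what it might look like, assuming it simply draws from a normal distribution with mean 0 and standard deviation 0.5 via NumPy; the `seed` parameter and the noise length of 64 are illustrative choices, not taken from the original.

```python
import numpy as np

def noisegen(shape, mean=0.0, std=0.5, seed=None):
    """Hypothetical stand-in for the question's noisegen helper.

    Draws Gaussian noise (mean 0, std 0.5 by default) with the requested
    shape, e.g. (n_samples, 1, n_noise) to match the generator's Input.
    """
    rng = np.random.RandomState(seed)
    return rng.normal(loc=mean, scale=std, size=shape).astype(np.float32)

# Example: 8000 noise vectors of length 64 (an arbitrary value in the 30-100 range)
X_noise = noisegen((8000, 1, 64), seed=0)
print(X_noise.shape)  # (8000, 1, 64)
```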

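from keras.layers import Input, Dense, LeakyReLU, BatchNormalization
from keras.models import Model

def gan_generator(x_shape, frame_size):
    # x_shape: shape of one noise sample, e.g. (1, n_noise); frame_size: (1, 513)
    g_input = Input(shape=x_shape)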
    H = BatchNormalization()(g_input)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    out = Dense(frame_size[1], activation='linear')(H)  # one linear output per frequency bin

    generator = Model(g_input, out)
    generator.summary()
    return generator

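The discriminator outputs a one-hot classification of the generated frames. (I'm unsure about batch normalization here: I have read that it should not be used if you mix real and generated samples in the same batch. However, the generator produces more convincing results with it than without it, despite a higher loss.)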

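from keras.layers import Input, Dense, Dropout, LeakyReLU, BatchNormalization, Reshape
from keras.models import Model

def gan_discriminator(input_shape):
    # input_shape: one spectral frame, e.g. (1, 513)
    d_input = Input(shape=input_shape)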
    H = Dropout(0.1)(d_input)
    H = Dense(256)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Reshape((1, -1))(H)  # already (batch, 1, 100) for (1, 513) inputs
    d_V = Dense(2, activation='softmax')(H)  # (batch, 1, 2): [P(real), P(fake)]

    discriminator = Model(d_input,d_V)
    discriminator.summary()
    return discriminator
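The discriminator ends in a 2-unit softmax trained with `categorical_crossentropy`. For two classes this is mathematically equivalent to a single sigmoid unit trained with `binary_crossentropy`, because softmax([a, b])[0] = sigmoid(a - b). The NumPy check below illustrates this with made-up logits shaped (n, 1, 2) like the discriminator output; it is a sketch of the math only and does not call Keras.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Made-up logits for 4 frames, shaped like the discriminator output (n, 1, 2)
logits = np.array([[[2.0, -1.0]], [[0.0, 0.0]], [[-3.0, 1.5]], [[0.5, 0.2]]])
labels = np.array([[[1.0, 0.0]], [[1.0, 0.0]], [[0.0, 1.0]], [[0.0, 1.0]]])  # real, real, fake, fake

p = softmax(logits)                            # 2-class softmax, as in d_V
s = sigmoid(logits[..., 0] - logits[..., 1])   # single-unit sigmoid giving P(real)
t = labels[..., 0]                             # binary target: 1 = real

cce = -(labels * np.log(p)).sum(axis=-1)         # categorical cross-entropy per frame
bce = -(t * np.log(s) + (1 - t) * np.log(1 - s)) # binary cross-entropy per frame

print(np.allclose(p[..., 0], s), np.allclose(cce, bce))  # True True
```

Either output form should therefore train the same way; the 2-unit version just carries a redundant second column.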

The GAN itself is simple:

def init_gan(generator, discriminator):
    # input_shape is a plain tuple of ints, e.g. (None, 1, n_noise); drop the batch axis
    x = Input(shape=generator.input_shape[1:])

    #Generator makes a prediction
    pred = generator(x)

    #Discriminator attempts to categorise prediction
    y = discriminator(pred)

    GAN = Model(x, y)
    return GAN

Some training settings:

  • GAN (generator): Adam, lr=1e-4, categorical_crossentropy
  • Discriminator: Adam, lr=1e-3, categorical_crossentropy
  • Batch: about 8000 samples
  • Mini-batch (weight update period): 32
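These settings determine how many weight updates each network receives. Assuming `fit_model` (whose body is not shown) runs one epoch over the arrays it is given with a mini-batch size of 32 unless told otherwise, and that each loaded batch holds about 8000 real frames, a quick calculation with a hypothetical helper `updates_per_epoch`:

```python
import math

def updates_per_epoch(n_samples, batch_size=32):
    """Weight updates in one pass over n_samples (a final partial batch counts as one)."""
    return int(math.ceil(n_samples / float(batch_size)))

n_real = 8000                               # approximate size of one loaded batch of real frames
d_updates = updates_per_epoch(2 * n_real)   # real and generated frames are concatenated
g_updates = updates_per_epoch(n_real)       # one noise vector per real frame
pretrain_updates = 8 * d_updates            # discriminator pre-training runs for 8 epochs

print(d_updates, g_updates, pretrain_updates)
```

Under these assumptions the discriminator gets twice as many updates per dataset as the generator, at a 10x higher learning rate, after a head start of several thousand updates.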

Training loop:

#Pre-training Discriminator Network
#Load new batch of real frames
frames = load_data(data_dir)
frames_label = np.zeros((frames.shape[0], 1, 2))
frames_label[:, :, 0] = 1 #mark as real frames

#Generate Frames from noise vector
X_noise = noisegen((frames.shape[0], 1, n_noise))
generated_frames = generator.predict(X_noise)
generated_label = np.zeros((generated_frames.shape[0], 1, 2))
generated_label[:, :, 1] = 1 #mark as false frames

#Prep Data - concat real and false data
dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

#Make discriminator trainable and train for 8 epochs
make_trainable(discriminator, True)
discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
fit_model(discriminator, dis_batch_x, dis_batch_y, 8)

#Training Loop
for d in range(data_sets):
    print("Starting New Dataset: {0}/{1}".format(d+1, data_sets))

    """ Fit Discriminator """
    #Load new batch of real frames
    frames = load_data(data_dir)
    frames_label = np.zeros((frames.shape[0], 1, 2))
    frames_label[:, :, 0] = 1 #mark as real frames

    #Generate Frames from noise vector
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_frames = generator.predict(X_noise)
    generated_label = np.zeros((generated_frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    #Prep Data - concat real and false data
    dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
    dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

    #Make discriminator trainable & fit
    make_trainable(discriminator, True)
    discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
    fit_model(discriminator, dis_batch_x, dis_batch_y)


    """ Fit Generator """
    #Prep Data
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_label = np.zeros((generated_frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    make_trainable(discriminator, False)
    GAN.layers[2].trainable = False #done twice just to be sure
    GAN.compile(optimizer=GAN_optimizer, loss=GAN_loss) 
    fit_model(GAN, X_noise, generated_label)
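In the loop above, real and generated frames are concatenated into one array before `fit_model`. If `fit_model` wraps Keras `Model.fit` with its default `shuffle=True` (an assumption, since its body is not shown), each 32-sample mini-batch mixes real and fake frames, which is exactly the batch-normalization situation the question is unsure about. ganhacks (https://github.com/soumith/ganhacks) suggests building mini-batches that are entirely real or entirely generated instead. A NumPy sketch of such a batching helper; `split_minibatches` is hypothetical, and its labels follow the question's convention (column 0 = real, column 1 = fake):

```python
import numpy as np

def split_minibatches(real, fake, batch_size=32, seed=None):
    """Yield (x, y) mini-batches that each contain only real or only generated frames.

    y is one-hot with shape (n, 1, 2): column 0 = real, column 1 = fake,
    matching the labels built in the question's training loop.
    """
    rng = np.random.RandomState(seed)
    batches = []
    for data, cls in ((real, 0), (fake, 1)):
        idx = rng.permutation(len(data))  # shuffle within each source only
        for start in range(0, len(data), batch_size):
            x = data[idx[start:start + batch_size]]
            y = np.zeros((len(x), 1, 2), dtype=np.float32)
            y[:, :, cls] = 1.0
            batches.append((x, y))
    # Shuffle the order of the batches, not their contents, so each batch stays pure
    for i in rng.permutation(len(batches)):
        yield batches[i]

# Example with random stand-in data: 100 "real" and 70 "generated" frames
real = np.random.rand(100, 1, 513).astype(np.float32)
fake = np.random.rand(70, 1, 513).astype(np.float32)
batches = list(split_minibatches(real, fake, batch_size=32, seed=0))
print(len(batches))  # 7 (4 real batches + 3 generated batches)
```

Each `(x, y)` pair could then be passed to `discriminator.train_on_batch(x, y)` in place of a single fit on the concatenated arrays.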

Finally, some system information:

  • OSX 10.12
  • Tensorflow 1.5.0 (GPU)
  • Keras 2.1.3
  • Python 2.7

Many thanks in advance!

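【Comments】:

  • I assume you are using this script: github.com/osh/KerasGAN/blob/master/mnist_gan.py, and I think the problem lies in how you initialise the GAN in init_gan itself. If you look at lines 102-115 of the original script, the GAN is built by stacking the generator and the discriminator, with the discriminator's weights frozen.
  • Thanks for your comment; you are quite right, I have been using it (among others) as a guide for my GAN. I do actually freeze the discriminator, at the bottom of the script.
  • May I ask whether your memory filled up, or you hit a similar error, when training with this script?
  • No, I didn't: I call a function that loads a new random batch from disk for each epoch. You can see how I use multiprocessing to speed this up here: stackoverflow.com/questions/48778842/python-multiprocessing-loop/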

Tags: tensorflow keras


【Solution 1】:

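The actual solution: I had not swapped my True/False classes when training the generator (as recommended at https://github.com/soumith/ganhacks), which I believe effectively made it perform gradient ascent.

A clarification of this would be welcome.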
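Some clarification on why the swap matters. Write s = D(G(z)) for the discriminator's probability that a generated frame is real. With the 2-class softmax, categorical cross-entropy is -log(s) for the label "real" and -log(1 - s) for the label "fake". Training the stacked GAN with "fake" labels therefore minimizes -log(1 - s), i.e. maximizes log(1 - s), which is gradient ascent on the generator's objective log(1 - D(G(z))) from the original minimax GAN formulation: the generator is pushed to make its frames look even more fake. Labelling them "real" minimizes -log(s) instead, the usual "non-saturating" generator loss. A small NumPy check of the gradient signs; `generator_loss` and `dloss_ds` are hypothetical helpers and the s values are made up:

```python
import numpy as np

def generator_loss(s, label):
    """Cross-entropy on generated frames, given s = D(G(z)) = P(real)."""
    return -np.log(s) if label == "real" else -np.log(1.0 - s)

def dloss_ds(s, label):
    """Analytic derivative of generator_loss with respect to s."""
    return -1.0 / s if label == "real" else 1.0 / (1.0 - s)

s = np.array([0.01, 0.1, 0.5, 0.9])  # discriminator's P(real) on generated frames

# Gradient descent moves s in the direction of -dloss/ds
print(np.all(-dloss_ds(s, "fake") < 0))  # True: frames are pushed towards "fake"
print(np.all(-dloss_ds(s, "real") > 0))  # True: frames are pushed towards "real"

# When D confidently rejects a frame (s = 0.01), the swapped labels give a far larger gradient
print(abs(dloss_ds(0.01, "real")), abs(dloss_ds(0.01, "fake")))
```

Keeping the "fake" labels and minimizing log(1 - s) directly would also point in the right direction, but that loss saturates when the discriminator is confident, which is why the swapped-label form is the common choice.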

【Comments】:
