【问题标题】:Keras Conditional GAN does not train well(Keras 条件 GAN 训练效果不佳)
【发布时间】:2017-05-17 16:42:31
【问题描述】:

我正在尝试基于 keras-dcgan (https://github.com/jacobgil/keras-dcgan) 上的 jacob 代码构建条件 GAN 模型。

我假设的模型架构如下图:

原论文: http://cs231n.stanford.edu/reports/2015/pdfs/jgauthie_final_report.pdf

对于生成器,我这样插入条件(这里的条件是一组 one-hot 向量):先将条件与噪声拼接,再把拼接后的结果输入生成器。

对于判别器,我把条件与模型中间 Flatten 层的输出拼接,以此插入条件。

我的代码可以运行,但生成的是一些随机图像,而不是指定的数字。哪一步出错了?是我插入条件的方式不对吗?

运行大约 5500 次迭代后的结果:

代码:

import warnings
warnings.filterwarnings('ignore')

from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Input, merge
from keras.layers import Reshape, concatenate
from keras.layers.core import Activation
from keras.models import Model
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
import tensorflow as tf
from PIL import Image
import argparse
import math
# Select the Theano-style ('th') image layout, i.e. channels first; the image
# input declared further down uses the matching shape (1, 28, 28).
K.set_image_dim_ordering('th')

# Class labels used as the generator's condition: every digit 0-9 repeated ten
# times in a row, 100 labels in total. They are one-hot encoded further down.
labels = np.repeat(np.arange(10), 10)

def dense_to_one_hot(labels_dense, num_classes=10):
    """Turn an array of integer class labels into a one-hot matrix.

    Args:
        labels_dense: NumPy array of integer labels; its first dimension is
            taken as the number of labels.
        num_classes: Width of each one-hot row.

    Returns:
        A float array of shape ``(labels_dense.shape[0], num_classes)`` filled
        with zeros except for a 1 at each label's position.
    """
    count = labels_dense.shape[0]
    encoded = np.zeros((count, num_classes))
    # Row r starts at flat position r * num_classes; adding the label gives
    # the flat position of the entry to switch on.
    row_starts = np.arange(count) * num_classes
    flat_positions = row_starts + labels_dense.ravel()
    encoded.flat[flat_positions] = 1
    return encoded

# One-hot encode the conditioning labels.
# y_p has one row per label (y_size rows) and one column per class (y_dim
# columns). The number of rows has to equal the batch size handed to train(),
# because every row is paired with one noise vector for the generator.
y_p = dense_to_one_hot(labels)
y_size, y_dim = y_p.shape


# Symbolic inputs shared by the models defined below:
#   g_inputs        - noise vector for the generator (100 values)
#   auxiliary_input - the condition, a one-hot vector of length y_dim
#   d_inputs        - image for the discriminator, shape (1, 28, 28)
g_inputs = Input(shape=(100,), dtype='float32')
auxiliary_input = Input(shape=(y_dim,), dtype='float32')
d_inputs = Input(shape=(1, 28, 28), dtype='float32')

def generator_model():
    """Build the conditional generator.

    The noise input (``g_inputs``) and the one-hot condition
    (``auxiliary_input``) are concatenated, passed through two dense layers,
    reshaped to a 128x7x7 feature map and upsampled twice, ending in a
    single-channel 28x28 image squashed by ``tanh``.

    Returns:
        A Keras ``Model`` mapping ``[g_inputs, auxiliary_input]`` to the
        generated image tensor.
    """
    x = concatenate([g_inputs, auxiliary_input])
    x = Dense(1024)(x)
    # Without a nonlinearity here the two Dense layers would collapse into a
    # single linear map.
    x = Activation('tanh')(x)
    x = Dense(128 * 7 * 7)(x)
    x = BatchNormalization()(x)
    x = Activation('tanh')(x)
    x = Reshape((128, 7, 7))(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Convolution2D(filters=64, kernel_size=(5, 5), padding='same')(x)
    # Feature maps are channels-first here, so the channel axis is 1; the
    # default axis=-1 would normalise per image column instead.
    x = BatchNormalization(axis=1)(x)
    x = Activation('tanh')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Convolution2D(filters=1, kernel_size=(5, 5), padding='same')(x)
    # NOTE(review): batch norm on the generator's output layer is kept from
    # the original; consider removing it if samples stay noisy.
    x = BatchNormalization(axis=1)(x)
    x = Activation('tanh')(x)
    return Model(inputs=[g_inputs, auxiliary_input], outputs=x)

def discriminator_model():
    """Build the conditional discriminator.

    The image input (``d_inputs``) goes through two conv / batch-norm / tanh /
    max-pool stages and is flattened; the one-hot condition
    (``auxiliary_input``) is concatenated to the flattened features before the
    dense head, which ends in a single sigmoid unit.

    Returns:
        A Keras ``Model`` mapping ``[d_inputs, auxiliary_input]`` to one
        sigmoid output per sample.
    """
    x = Convolution2D(filters=64, kernel_size=(5, 5), padding='same')(d_inputs)
    # Feature maps are channels-first here, so the channel axis is 1; the
    # default axis=-1 would normalise per image column instead.
    x = BatchNormalization(axis=1)(x)
    x = Activation('tanh')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Convolution2D(filters=128, kernel_size=(5, 5))(x)
    x = BatchNormalization(axis=1)(x)
    x = Activation('tanh')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    # The condition joins the image features only after the conv stack.
    x = concatenate([x, auxiliary_input])
    x = Dense(1024)(x)
    x = Activation('tanh')(x)
    x = Dense(1)(x)
    x = Activation('sigmoid')(x)
    return Model(inputs=[d_inputs, auxiliary_input], outputs=x)

def generator_containing_discriminator(generator, discriminator):
    """Chain the generator into the discriminator as one model.

    The generator is applied to ``[g_inputs, auxiliary_input]`` and its output,
    together with the same condition, is passed to the discriminator.

    Args:
        generator: Model mapping ``[noise, condition]`` to an image.
        discriminator: Model mapping ``[image, condition]`` to a score. Its
            ``trainable`` attribute is set to ``False`` by this function.

    Returns:
        A Keras ``Model`` mapping ``[g_inputs, auxiliary_input]`` to the
        discriminator's output for the generated image.
    """
    fake_images = generator([g_inputs, auxiliary_input])
    # Mark the discriminator as frozen before wiring it into the stack.
    discriminator.trainable = False
    verdict = discriminator([fake_images, auxiliary_input])
    stacked = Model(input=[g_inputs, auxiliary_input], output=verdict)
    return stacked

def combine_images(generated_images):
    """Tile a batch of images into a single 2-D mosaic.

    The batch is indexed as (image, channel, row, column); only channel 0 of
    each image is copied. The grid has ``int(sqrt(n))`` columns and as many
    rows as needed; cells without an image stay zero.

    Args:
        generated_images: 4-D NumPy array holding the batch.

    Returns:
        A 2-D array with the batch's dtype containing the tiled images in
        row-major order.
    """
    count = generated_images.shape[0]
    columns = int(math.sqrt(count))
    rows = int(math.ceil(float(count) / columns))
    tile_h = generated_images.shape[2]
    tile_w = generated_images.shape[3]
    canvas = np.zeros((rows * tile_h, columns * tile_w), dtype=generated_images.dtype)
    for position, picture in enumerate(generated_images):
        grid_row, grid_col = divmod(position, columns)
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = picture[0, :, :]
    return canvas


def train(BATCH_SIZE, y_prime):
    """Train the conditional GAN on MNIST.

    Every step first trains the discriminator on a batch of real digits (with
    their true labels) plus a batch of generated digits (conditioned on
    ``y_prime``), then trains the generator through the frozen discriminator.

    Args:
        BATCH_SIZE: Number of real images, and of generated images, per step.
        y_prime: One-hot condition matrix fed to the generator. It must have
            exactly ``BATCH_SIZE`` rows so that every noise vector is paired
            with one condition.

    Raises:
        ValueError: If ``len(y_prime)`` differs from ``BATCH_SIZE``.

    Side effects:
        Saves a PNG preview every 20 batches and overwrites the weight files
        ``generator`` and ``discriminator`` every 10 batches, using paths
        relative to the current working directory.
    """
    if len(y_prime) != BATCH_SIZE:
        raise ValueError(
            "y_prime must have BATCH_SIZE rows (got %d rows for BATCH_SIZE=%d)"
            % (len(y_prime), BATCH_SIZE)
        )

    # Must match the width of the generator's noise input.
    noise_dim = 100

    (X_train, y_train), _ = mnist.load_data()
    # Scale pixels to [-1, 1] and add the channel axis: (N, 1, 28, 28).
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])

    discriminator = discriminator_model()
    generator = generator_model()
    discriminator_on_generator = generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    # The stacked model is compiled while the discriminator is marked frozen;
    # the discriminator itself is compiled after being re-enabled.
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets are the same on every step: real images -> 1, generated -> 0
    # for the discriminator; all 1 when training the generator.
    d_targets = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)
    g_targets = np.array([1] * BATCH_SIZE)

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            start = index * BATCH_SIZE
            stop = start + BATCH_SIZE

            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, noise_dim))
            image_batch = X_train[start:stop]
            # Conditions for the discriminator batch: true labels for the real
            # half, y_prime for the generated half (same order as X below).
            y_batch = np.concatenate((dense_to_one_hot(y_train[start:stop]), y_prime))
            generated_images = generator.predict([noise, y_prime], verbose=0)

            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                Image.fromarray(image.astype(np.uint8)).save(str(epoch) + "_" + str(index) + ".png")

            X = np.concatenate((image_batch, generated_images))
            d_loss = discriminator.train_on_batch([X, y_batch], d_targets)
            print("batch %d d_loss : %f" % (index, d_loss))

            # Fresh noise for the generator update.
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, noise_dim))
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch([noise, y_prime], g_targets)
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))

            if index % 10 == 9:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)



# Start training with 100 samples per batch, conditioned on the fixed one-hot
# matrix y_p.
train(BATCH_SIZE=100, y_prime=y_p)

【问题讨论】:

  • 您是否检查过各层是否被正确冻结?这往往是很多错误的根源。

标签: python machine-learning keras


【解决方案1】:

这是我使用 Keras 构建条件 GAN (CGAN) 的代码:https://github.com/hklchung/GAN-GenerativeAdversarialNetwork/tree/master/CGAN

在 MNIST 上运行 5 个 epoch 后,我得到了这个: MNIST CGAN output

在 CelebA 数据集上训练 50 个 epoch 后: CelebA CGAN output

我的经验是,如果您在 20 个 epoch 后没有看到任何好的结果,则说明您的模型有问题,再训练它不会提高您的图像质量。

【讨论】:

    猜你喜欢
    • 2018-12-06
    • 2020-03-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-06-24
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多