【发布时间】:2019-09-17 08:54:48
【问题描述】:
我从 TF 2.0 文档中借用了一些代码，用来基于自定义数据集生成图像。代码见此处（原链接）。
由于文档使用的是 Keras，我想可以把鉴别器网络换成预训练网络（例如 InceptionV3），并且只训练顶层。我找到了这段代码（在一组新类别上微调 InceptionV3），但一直没弄明白如何用后者替换前者。我知道自己是在尝试用函数式（Functional）API 构建的模型替换 Sequential 模型，但我想两者在某种程度上是可以互相衔接的。不过，我并不常用 Keras。
我的问题是：如何把用 Sequential 模型编写的自定义 CNN 替换为用函数式 API 构建的预训练 CNN，并将其用作鉴别器？
编辑：如果有人能提供使用 GANEstimator 的示例就更好了，因为我更熟悉 TF 本身。
使用生成器生成随机图像
def make_generator_model():
    """Build the DCGAN generator: a 100-d noise vector -> a 28x28x3 image.

    The final Conv2DTranspose uses a tanh activation, so generated pixel
    values lie in [-1, 1]. The assertions check the intermediate output
    shapes (``None`` is the batch dimension).
    """
    net = tf.keras.Sequential()

    # Project the noise vector, then reshape it into a 7x7x256 feature map.
    net.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((7, 7, 256)))
    assert net.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    # Two Conv2DTranspose -> BatchNorm -> LeakyReLU stages:
    # 7x7 stays 7x7 (stride 1), then 7x7 -> 14x14 (stride 2).
    upsampling_stages = (
        (128, (1, 1), (None, 7, 7, 128)),
        (64, (2, 2), (None, 14, 14, 64)),
    )
    for n_filters, stride, expected_shape in upsampling_stages:
        net.add(layers.Conv2DTranspose(n_filters, (5, 5), strides=stride,
                                       padding='same', use_bias=False))
        assert net.output_shape == expected_shape
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Final stage: 14x14 -> 28x28 with 3 output channels.
    net.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same',
                                   use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 28, 28, 3)
    return net
# Build the generator and run it once, in inference mode, on a single
# 100-dimensional noise vector to obtain a sample image.
generator = make_generator_model()
noise = tf.random.normal(shape=(1, 100))
generated_image = generator(noise, training=False)
当前的鉴别器及其辅助函数（输出 tf.Tensor([[-0.0003378]], shape=(1, 1), dtype=float32)）
def make_discriminator_model(input_shape=(28, 28, 3)):
    """Build the small convolutional DCGAN discriminator.

    The model has two strided Conv2D -> LeakyReLU -> Dropout stages to
    downsample the image. These are followed by Flatten and a single
    ``Dense(1)`` unit with no activation, so the model outputs one raw
    logit per image. Pair it with ``BinaryCrossentropy(from_logits=True)``.

    Args:
        input_shape: ``(height, width, channels)`` of the images to classify.
            The default ``(28, 28, 3)`` matches the 3-channel 28x28 images
            produced by ``make_generator_model`` in this file. The channel
            count must match the images that are fed in, otherwise Keras
            rejects them at call time.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()
    # Each stride-2 'same'-padded convolution roughly halves height and width.
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=input_shape))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    # No activation: the output is a logit, not a probability.
    model.add(layers.Dense(1))
    return model
# Binary cross-entropy computed directly on logits
# (the discriminator's final Dense(1) has no sigmoid).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch of real and fake logits.

    Real images are scored against a target of 1 and generated images
    against a target of 0. The two cross-entropy terms are summed.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.

    Returns:
        The sum of the real-image and fake-image cross-entropy losses.
    """
    real_targets = tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)
    return (cross_entropy(real_targets, real_output)
            + cross_entropy(fake_targets, fake_output))
# Score the sample image with a freshly built (untrained) discriminator.
# The printed value is a raw logit, not a probability.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)
所需的鉴别器
def make_discriminator_model(input_shape=None):
    """Build a discriminator on a frozen, ImageNet-pretrained InceptionV3.

    This uses the Keras Functional API. The pre-trained backbone is loaded
    without its ImageNet classifier head, and all of its layers are frozen.
    A new trainable head (global average pooling and a single ``Dense(1)``
    unit with no activation) maps each image to one raw real/fake logit.
    That logit matches a ``BinaryCrossentropy(from_logits=True)`` loss.

    Args:
        input_shape: optional ``(height, width, 3)`` image shape passed to
            ``InceptionV3``. ``None`` (the default) leaves the spatial size
            unspecified. Keras documents that InceptionV3 inputs must be at
            least 75x75, so the 28x28 generator output in this file has to
            be resized before it can be scored.

    Returns:
        An uncompiled ``Model`` mapping a batch of images to one logit each.
    """
    # Create the base pre-trained model, without its ImageNet classifier head.
    base_model = InceptionV3(weights='imagenet', include_top=False,
                             input_shape=input_shape)
    # Freeze every pre-trained layer; only the top layers added below train.
    base_model.trainable = False
    # Top layers: pool the final feature map into a vector, then emit one logit.
    pooled = GlobalAveragePooling2D()(base_model.output)
    logit = Dense(1)(pooled)
    return Model(inputs=base_model.input, outputs=logit)
# NOTE(review): the original placeholder here said "COMPILE". If the GAN is
# trained with a custom loop that calls discriminator_loss directly, no
# compile() step is needed. Confirm against the training code.
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's binary cross-entropy loss.

    The loss does not depend on the discriminator's architecture. As long as
    the model emits one raw logit per image (e.g. a final ``Dense(1)`` with
    no activation), the same targets apply: real images should be scored as
    1 and generated images as 0. Uses the module-level ``cross_entropy``
    (``BinaryCrossentropy(from_logits=True)``).

    Args:
        real_output: discriminator logits for a batch of real images.
        fake_output: discriminator logits for a batch of generated images.

    Returns:
        The sum of the real-image loss and the fake-image loss.
    """
    # Real loss: real images should be classified as real (target 1).
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # Fake loss: generated images should be classified as fake (target 0).
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
# Draw one sample from the generator and score it with the discriminator.
# NOTE(review): InceptionV3 requires inputs of at least 75x75, but the
# generator emits 28x28 images. If make_discriminator_model is the
# InceptionV3 version, this call presumably fails until the image is
# resized/upsampled first.
noise = tf.random.normal(shape=(1, 100))
generated_image = generator(noise, training=False)
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)
所有导入
from __future__ import absolute_import, division, print_function, unicode_literals
try:
# %tensorflow_version only exists in Colab.
%tensorflow_version 2.x
except Exception:
pass
import tensorflow as tf
print('TF version: {}'.format(tf.__version__))
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from PIL import Image
from tensorflow.keras import layers
import time
from IPython import display
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import vgg16
import os.path
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras import backend as K
编辑：这是我最终得到的鉴别器！感谢 @pandrey。
def make_discriminator_model(input_shape=None):
    """Build a GAN discriminator on a frozen, ImageNet-pretrained InceptionV3.

    The backbone is loaded without its classifier head and marked
    non-trainable. Only the new head (global average pooling and a single
    ``Dense(1)`` unit) is trained. The head has no activation, so the model
    outputs raw logits, matching ``BinaryCrossentropy(from_logits=True)``.

    Args:
        input_shape: ``(height, width, 3)`` shape of the input images. When
            omitted, the module-level ``IMG_SHAPE`` is used, looked up at
            call time exactly as before. Keras documents that InceptionV3
            inputs must be at least 75x75.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model mapping images to one
        logit per image.
    """
    # NOTE(review): InceptionV3's ImageNet weights expect pixels scaled to
    # [-1, 1] (see inception_v3.preprocess_input). The generator's tanh output
    # is already in that range; confirm real images are scaled the same way.
    if input_shape is None:
        # Preserve the original behaviour of reading the global IMG_SHAPE.
        input_shape = IMG_SHAPE
    pre_trained = tf.keras.applications.InceptionV3(
        weights='imagenet', include_top=False, input_shape=input_shape
    )
    pre_trained.trainable = False  # mark all backbone weights as non-trainable
    model = tf.keras.Sequential([pre_trained])
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(1))  # one raw real/fake logit per image
    return model
【问题讨论】:
标签: python tensorflow keras generative-adversarial-network