3. Star Generative Adversarial Networks

3.1. MultiDomain ImagetoImage Translation

学习目标是训练一个能够在multiple domains之间相互生成的生成器GG

定义xx为输入图像,yy为生成图像,cc为target domain label,于是有G(x,c)yG(x, c)\rightarrow y

判别器DD包含两部分,一部分是常规的判别真假的判别器DsrcD_{src},另一部分是auxiliary classifier DclsD_{cls}

Figure3展示了StarGAN的训练过程
StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation(CVPR18)
Adversarial Loss
Ladv=Ex[logDsrc(x)]+Ex,c[log(1Dsrc(G(x,c)))](1) \begin{aligned} \mathcal{L}_{adv}=&\mathbb{E}_x\left [ \log D_{src}(x) \right ] +\\ &\mathbb{E}_{x,c}\left [ \log\left ( 1-D_{src}\left ( G(x,c) \right ) \right ) \right ] \qquad(1) \end{aligned}

Domain Classification Loss

对于判别器DD,需要正确地将real image xx预测为所对应的domain cc'
Lclsr=Ex,c[logDcls(cx)](2) \mathcal{L}_{cls}^r=\mathbb{E}_{x,c'}\left [ -\log D_{cls}\left ( c'\mid x \right ) \right ] \qquad(2)

对于生成器GG,需要最小化fake image G(x,c)G(x,c)被预测为domain
Lclsf=Ex,c[logDcls(cG(x,c))](3) \mathcal{L}_{cls}^f=\mathbb{E}_{x,c}\left [ -\log D_{cls}\left ( c\mid G(x,c) \right ) \right ] \qquad(3)

Reconstruction Loss

对于生成器GG,只考虑公式(1)和(3)无法保证GG只修改图像中与target domain有关的部分,修改与target domain无关的部分,因此引入文献[8, 32]中提出的cycle consistency loss

Lrec=Ex,c,cxG(G(x,c),c)1(4) \mathcal{L}_{rec}=\mathbb{E}_{x,c,c'}\left \| x-G\left ( G(x,c), c' \right ) \right \|_1 \qquad(4)

Full Objective
LD=Ladv+λclsLclsr(5) \mathcal{L}_D=-\mathcal{L}_{adv}+\lambda_{cls}\mathcal{L}_{cls}^r \qquad(5)
LG=Ladv+λclsLclsf+λrecLrec(6) \mathcal{L}_G=\mathcal{L}_{adv}+\lambda_{cls}\mathcal{L}_{cls}^f+\lambda_{rec}\mathcal{L}_{rec} \qquad(6)
注:DD需要最大化Ladv\mathcal{L}_{adv},所以加上了一个负号

实验中设置λcls=1\lambda_{cls}=1λrec=10\lambda_{rec}=10

3.2. Training with Multiple Datasets

如果涉及多个数据集,每个数据集的attribute是不一样的

Mask Vector

引入mask vector mm用于指示label中哪些分量是已知的

假设使用nn个数据集,则mask vector mm是一个nn维的one-hot向量,并且将domain label扩展为
c~=[c1,,cn,m](7) \tilde{c}=\left [ c_1,\cdots,c_n,m \right ] \qquad(7)
其中cic_i表示第ii个数据集的attribute的0-1向量

假设当前图像属于第kk个数据集,那么ckc_k为表示attribute的0-1向量,其它ci(ik)c_i(i\neq k)为全0向量

(个人认为这个mask vector的设计一般般)

4. Implementation

Improved GAN Training

为了使GAN的训练过程更加稳定,同时生成高质量的图像,将公式(1)替换为WGAN-gp的版本
Ladv=Ex[Dsrc(x)]Ex.c[Dsrc(G(x,c))]λgpEx^(x^Dsrc(x^)21)2(8) \begin{aligned} \mathcal{L}_{adv}=&\mathbb{E}_x\left [ D_{src}(x) \right ]-\mathbb{E}_{x.c}\left [ D_{src}(G(x,c)) \right ] \\ &- \lambda_{gp}\mathbb{E}_{\hat{x}}\left ( \left \| \nabla_{\hat{x}}D_{src}\left ( \hat{x} \right ) \right \|_2-1 \right )^2 \qquad(8) \end{aligned}
其中x^\hat{x}是一组真实图像和假图像的线性组合,实验中设置λgp=10\lambda_{gp}=10

Network Architecture

只对生成器GG使用instance normalization,判别器DD的结构为PatchGAN

相关文章: