官方pytorch代码实现
StarGAN-v2是一个可以实现图像到图像转换的GAN网络,前身是StarGAN.
相较于前作,功能上的一个不同是在同一个域的情况下添加了“样式”的转换,而之前的StarGAN做了域的转换,而样式没有变化,见下图(A)
图中D, E, F为StarGAN-v2在StarGAN baseline之上逐步添加样式转换的效果。可以看出,StarGAN-v2的多样性远大于前作StarGAN。
概念解释
在论文的介绍中,给出了域(domain)和风格(style)的定义。
- 域:一个图像集合,其中的图片可以被分类为同一种具有视觉区分度的类型。
- 风格:每个图像具有的独特外观。
网络模型
StarGAN-v2网络由四个模块组成:Generator(生成器),Mapping network(映射网络),Style encoder(样式编码器),Discriminator(判别器)
Generator(生成器)
生成器将输入图像x转换为输出图像:G(x; s)
其中s为样式编码,由映射网络F生成或者由样式编码器E生成。
Mapping network(映射网络)
根据隐编码z和域y,映射网络生成央视编码s = Fy(z),Fy代表映射网络对应域y的输出。映射网络F由多层感知机MLP和多个输出分支(与域的数量相同)组成。
Style encoder(样式编码器)
根据图像x和其对应的域y,样式编码器E提取出样式编码s = Ey(x),Ey代表样式编码器对应域y的输出。
Discriminator(判别器)
StarGAN-v2的判别器使用了多任务判别器,由多个输出分支构成。每个分支Dy学习一个二分类问题:判断图像x是域y的真实图像,还是由生成器生成的图像G(x; s)
损失函数
损失函数由四个子函数组成:对抗损失,样式重建损失,样式多样性损失,源保留损失(循环一致性损失)
对抗损失
在训练过程中,随机采样隐编码 和域
,然后生成目标样式编码
。生成器G使用图像x和样式编码
通过上面的对抗损失来学习如何生成图像
。
Dy表示D的输出对应了域y。
样式重建损失
为了促使生成器G在生成图片时利用样式编码
,论文使用了上面的样式重建损失。这和之前一些论文的损失函数类似,他们使用了多个encoder来学习图像到隐编码的映射。与他们不同的是,该论文只使用了一个encoder来学习不同域不同风格的输出。
样式多样性损失
为了进一步使生成器G生成多样化的图像,论文显式的增加了一个正则项损失,如上式。
源保留损失(循环一致性损失)
为了保证生成的图像能够合理地保留域无关的特征(比如姿势),论文使用了循环一致性损失
最终损失函数
最终的损失函数就是如下公式
网络结构
生成器
对于AFHQ数据集,生成器使用了4个下采样块,4个中间块和4个上采样块。上采样块和下采样块使用了IN(Instance Normalization)和AdaIN(Adaptive Instance Normalization). 样式编码注入到所有的AdaIN层,通过仿射变换提供了伸缩和位移向量。
映射网络
映射网络由K个输出分支的MLP构成,K是域的个数。前面的4个全连接层在所有的域之间共享,后面的4个全连接层是每个域独有不共享
样式编码器
样式编码器由K个输出分支的CNN构成,K为域的个数。前面6个残差块在所有的域间共享,后面接1个全连接层为每个域独有不共享。
判别器
判别器和样式编码器的网络结构相同
模型效果
根据样例图片生成图像
根据隐编码生成图像
左侧为CelebA-HQ数据集上的效果,右侧为AFHQ数据集上的效果
评估指标
Frechét inception distance (FID)
FID可以衡量两个图片的差异性。论文中使用了在ImageNet上预训练的Inception-V3网络的最后两层平均池化层的特征向量来计算。
Learned perceptual image patch similarity (LPIPS)
LPIPS可以衡量生成图片的多样性:通过计算在ImageNet预训练的AlexNet提取的特征的L1距离。