1.目的

仅使用一个模型来执行多个域的图像到图像的转换

2.贡献

  • 提出了一种全新的生成对抗网络StarGAN,该网络只使用一个生成器和一个鉴别器来学习多个域之间的映射,并从各个域的图像中有效地进行训练;
  • 演示了如何使用掩模向量方法(mask vector method)成功学习多个数据集之间的多域图像转换,并使得StarGAN控制所有可用的域标签;
  • 使用StarGAN进行面部属性转换和面部表情合成任务,并对结果进行了定性和定量分析,结果显示其优于基准线模型。

3.关键点

为了保证生成器G能够有效在多个域之间转换,目标域的标签随机给定。

4.网络结构

StarGAN快速阅读

网络的结构仿照Cycle-GAN的设置,使用两层步长为2的卷积层进行下采样(降维),6个残差块连接,然后使用两层步长为2的卷积层进行上采样。生成器使用了实例归一化,但是判别器没有用正则化。判别器网络文中使用的是patch-GAN。
文中在每一层都使用了实例归一化,除了最后的输出层
分类器的**函数使用了leakyrelu,负的一侧的斜率为0.01.

5.loss设置

对抗loss

生成器的目标是最小化对抗loss,判别器的目标是最大化对抗loss

StarGAN快速阅读

域分类loss

生成器和判别器的目标都是最小化域分类loss,域分类loss有两个,real image的域分类loss和fake image的域分类loss,前者是为了训练判别器,后者是为了训练生成器。对于一个给定的输入图片x(属于C1域)和域c,生成器的目标是输出一张图片y,恰好属于c域。

real image domain classification loss(训练D)

StarGAN快速阅读

fake image domain classification loss(训练G)

StarGAN快速阅读

重建loss

通过最小化域分类loss和对抗loss,生成器能够生成符合目标域的真实图片,但源域和目标域的图片内容可能不一致,因此引入了重建loss的概念。StarGAN快速阅读

意思就是利用已经生成的目标域图片与源域的域标签结合生成源域图片,然后计算此时生成的源域的图片和输入时的源域图片之间的L1loss,G的目标是最小化L1loss。

总的loss

此时LD与LG均是最小化。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V4WpKugL-1585618541984)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328113750889.png)]

补充

为了提高训练的效率和训练的稳定性以生成更高质量的image,文中将对抗loss换成了WGAN中的对抗loss

StarGAN快速阅读

原因

在(近似)最优判别器下,最小化生成器的loss等价于最小化StarGAN快速阅读StarGAN快速阅读之间的JS散度,而由于StarGAN快速阅读StarGAN快速阅读几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数StarGAN快速阅读,最终导致生成器的梯度(近似)为0,梯度消失。

WGAN的知识介绍:

改进后的GAN相比原始GAN的算法实现流程却只改了四点

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

6.不同数据集的域标签该如何表示

问题1:数据集1的图片标签有年龄,性别,头发颜色等信息,但却缺乏表情信息;数据集2的图片有年龄,性别,表情等信息,但是却缺乏头发颜色的信息。

解决办法:

StarGAN快速阅读

引入了mask vector ,假设有n个数据集,每个数据集标签或类别并集的数量为T,则建立一个T*n的向量。当使用数据集1的时候,c1的长度为T,使用后0、1表示数据集的类别或标签的信息,剩余的n-1个列向量全部置为0.

7.训练star-Gan时候的输入数据形式

生成器的输入包含两个部分,一部分是输入图像imgs,大小为(batch_size, n_channel, cols, rows);一部分是目标领域的标签domain,大小为(batch_size, n_dim)。为了将这两部拼接,需要通过repeat操作来对domain进行扩展,将其扩展为(batch_size, n_dim, cols, rows),因此,生成器输入的大小为(batch_size, n_channel + n_dim, cols, rows),生成器的输出为(batch_size, n_channel, cols, rows)。判别器的输入为图像imgs,大小为(batch_size, n_channel, cols, rows),判别器的输出分为两部分,一部分是图像的真假判断,大小为(batch_size, 1, s1, s2),另一部分为图像的类别划分,大小为(batch_size, n_dim)。

相关文章: