网络结构
论文中的网络结构图如下,embedding的提取直接使用预训练好的text encoder进行提取(不是本文重点)。提出的StackGAN整个模型包含2个GAN网络,分别用于两个阶段:
Stage1 :embedding+ noise 为输入,利用GAN输出低分辨率的64x64大小的影像;
Stage2 :embedding+ Stage I的低分辨率生成影像 为输入,利用GAN输高分辨率的256x256大小的影像
结合代码,stage I与stage II 的详细结构如下:
注意:其实代码中stage II 鉴别器输出的logit 有两种,分为condition 和uncondition,分别对应着有无引入embedding信息。(图中只显示了condition的logit输出)
每个阶段的GAN训练流程是相同的:
- 生成fake img;
- 训练鉴别器。考虑三种鉴别器输入,(1)real pairs:真实图像与对应的文本embedding,gt为 1;(2)wrong pairs:真实图像与不匹配的文本embedding,gt为 0;(3) fake pairs:生成图像与对应的文本embedding,gt为 0
- 训练生成器。鉴别器输入只考虑fake pairs,,gt为 1
网络训练
官方的pytorch实现有问题,stage I的生成器损失无法收敛。
在尝试1、提高D的lr同时降低G的lr 以及 反过来调整lr;2、提高G的通道数;3、使用改进的GAN损失函数形式 后,均无法正常收敛。