GAN学习笔记 (1):理论基础
下面图片/公式来自李宏毅老师课件:http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS18.html
理论基础
Generator
对于GAN的Generator而言,它是要接受一个任意输入,然后generates一个输出,这个generates出的输出我们希望它符合真实数据集的分布。如果我们希望输出是一张图片(i.e. 224*224 维的vector),为了更直观的表示,我们将这个vector简单地表示为一个点。这个点的值不同,它符合真实图片的分布的概率也就不同,generator的终极目标当然是generates出的“点”符合真实图片的分布。如下所示:
我们要我们generate出的“点”的分布符合真实的点的分布,一个很自然的想法就是构造一个衡量两个分布的函数然后用梯度下降去minimize它,如下图所示:Pdata和PG分别代表真实分布和我们生成的点所属的分布。
尴尬的是,我们generate出来的点的分布和真实的分布我们都不知道,如何计算divergence呢?Goodefellow告诉我们,构造一个网络来计算出这个divergence,即discriminator,discriminator的作用,就是量出这个divergence下面来看discriminator。
Discriminator
我们虽然既不知道PG也不知道Pdata,但我们可以从它们中采样,然后训练一个二分类器,给PG中的样本低分,给Pdata中的样本高分。概率的损失我们常用交叉熵损失函数,所以我们得到我们Discriminator的损时函数:
V ( G , D ) = E ( x ∼ P d a t a ) [ l o g D ( x ) ] + E ( x ∼ P G ) [ l o g ( 1 − D ( x ) ) ] V(G,D)=E_(x∼P_data ) [logD(x)]+E_(x∼P_G ) [log(1-D(x))] V(G,D)=E(x∼Pdata)[logD(x)]+E(x∼PG)[log(1−D(x))]
我们不能直接给0、1标签来训练分类器,因为给标签的做法就不是在找两个分布的差异了,而是单纯的训练一个判断真假的函数,你总能把它训练的很好。我们固定住G,给如果x来自真实分布我们希望D(x)高,如果x由我们生成,我们希望D(x)低,所以我们的目标就是让上述式子值最大,以此来代表两个分布的差异。看一段推导:
所以我们的目标就是maximize P d a t a ( x ) l o g D ( x ) + P G ( x ) l o g ( 1 − D ( x ) ) P_data (x)logD(x)+P_G (x)log(1-D(x)) Pdata(x)logD(x)+PG(x)log(1−D(x)),再看一段推导:
所以使得V最大的D为: ( P d a t a ( x ) ) / ( P d a t a ( x ) + P G ( x ) ) (P_data (x))/(P_data (x)+P_G (x) ) (Pdata(x))/(Pdata(x)+PG(x)),回代回原式得到结果:
所以优化目标的最大值就是真实分布和我们生成”点“的分布的JS散度!我们通过一个网络找出了两个分布的差异,这样就话题又回到了generator的优化身上。
summarize
对于generator而言我们的目标是:
对于上面这个式子,我们可以用下面这幅图来更好地理解:
如上图所示,首先固定住G,让D进行变化,找到V(G,D)的最大值,这个最大值就是js散度,衡量了两个分布的差异,然后再变换G找出目标函数的最小值。上图就是G3中绿色框画出的点。上述步骤,我们程序化地描述如下:
More
还有最后一个问题,如下式所示:
L(G)含有max操作,是一个分段函数,它可导吗?我们的做法是,让取得最大值的函数对变量求导,如下所示:
再来看一下我们的算法步骤:
每一步我们会得到新的G1,D1…但是我们得到G1的时候V(G1,D0)衡量的就不是上一步的js散度了,我们不要把G更新地太猛,要保持更新后的两个function很像。
In Practice
实际操作中我们当然不可能求期望,所以我们这样做:
最后我们看一下总体的算法:
以上就是全部理论内容啦!下一篇博客上代码!