正向传播

正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。如图所示。假设我们只输入一个xRd\boldsymbol{x} \in \mathbb{R}^{d}的样本,且先不考虑偏差项,这里的d=4
神经网络中的正向和反向传播问题
那么中间变量
z=W(1)x \boldsymbol{z}=\boldsymbol{W}^{(1)} x
其中W(1)Rh×d\boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d},是隐层的权重参数。再将中间变量zRh\boldsymbol{z} \in \mathbb{R}^{h}输入按元素运算**函数后,得到向量长度为hh的隐层变量
h=ϕ(z) \boldsymbol{h}=\phi(z)
隐层变量h\boldsymbol{h}也是一个中间变量。输出层的参数假设只有W(2)Rq×h\boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h},可以得到向量长度为qq的输出层变量
o=W(2)h \boldsymbol{o}=\boldsymbol{W}^{(2)} \boldsymbol{h}
假设损失函数为\ell,且样本标签是yy,可以计算单个数据样本的损失项为
L=(o,y) L=\ell(\boldsymbol{o}, y)
再损失项上添加正则化项,使用L2L_{2}范数正则化,引入超参数λ\lambda
s=λ2(W(1)F2+W(2)F2) s=\frac{\lambda}{2}\left(\left\|\boldsymbol{W}^{(1)}\right\|_{F}^{2}+\left\|\boldsymbol{W}^{(2)}\right\|_{F}^{2}\right)
最终,模型在给定的数据样本上,带正则化的损失定义为
J=L+s J=L+s
JJ称为数据样本的目标函数。

正向传播计算框图

神经网络中的正向和反向传播问题
从图中可以很清晰的看出整个传播的流向。从下面这个流向,我们也可以大体知道如何到最后的JJ

xW(1)zϕhW(2)OL+SJ \boldsymbol{x} \stackrel{\boldsymbol{W}^{(1)}}{\longrightarrow} \boldsymbol{z} \stackrel{\phi}{\longrightarrow} \boldsymbol{h} \stackrel{W^{(2)}}{\longrightarrow} \boldsymbol{O} \stackrel{\ell}{\longrightarrow} L \stackrel{+S}{\longrightarrow} J

反向传播

反向传播是指计算神经网络参数梯度的方法,总体而言,反向传播依据的是微积分中的链式法则对输入或输出X,Y,ZX, Y, Z为任意形状张量(这里为了推广,统一称张量)的函数Y=f(X)Z=g(Y)Y=f(X) 和 Z=g(Y),通过链式法则,可以得到如下:
ZX=prod(ZY,YX)\frac{\partial Z}{\partial X}=\operatorname{prod}\left(\frac{\partial Z}{\partial Y}, \frac{\partial Y}{\partial X}\right)
其中prod运算符根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法(为了满足最后的结果,进行通配)。

神经网络中的正向和反向传播问题
接下来,将讲解如何通过反向传播计算:

JW(1)JzJhJW(2)JOJL&JS \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} \leftarrow \frac{\partial J}{\partial \boldsymbol{z}} \leftarrow \frac{\partial J}{\partial \boldsymbol{h}} \leftarrow \frac{\partial J}{\partial \boldsymbol{W}^{(2)}} \leftarrow \frac{\partial J}{\partial \boldsymbol{O}} \leftarrow \frac{\partial J}{\partial L} \& \frac{\partial J}{\partial S}

首先计算的是

JL=1,Js=1 \frac{\partial J}{\partial L}=1, \quad \frac{\partial J}{\partial s}=1

然后,依据链式法则计算目标函数有关输出层变量的梯度JoRq\frac{\partial J}{\partial \boldsymbol{o}}\in \mathbb{R}^{q}

Jo=prod(JL,Lo)=Lo \frac{\partial J}{\partial \boldsymbol{o}}=\operatorname{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right)=\frac{\partial L}{\partial \boldsymbol{o}}

接下来计算的是JW(2)\frac{\partial J}{\partial \boldsymbol{W}^{(2)}},但是在计算之前,为了后续方便,我们先计算一下正则项有关两个参数的梯度:

sW(1)=λW(1),sW(2)=λW(2) \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}=\lambda \boldsymbol{W}^{(1)}, \quad \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}=\lambda \boldsymbol{W}^{(2)}

在计算JW(2)Rq×h\frac{\partial J}{\partial \boldsymbol{W}^{(2)}}\in \mathbb{R}^{q \times h}得到如下:

JW(2)=prod(Jo,oW(2))+prod(Js,sW(2))=Joh+λW(2) \frac{\partial J}{\partial \boldsymbol{W}^{(2)}}= \operatorname{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right)+\operatorname{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right)= \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^{\top}+\lambda \boldsymbol{W}^{(2)}

这里可能有人会觉得奇怪为什么是Jo\frac{\partial J}{\partial \boldsymbol{o}}而不是JL\frac{\partial J}{\partial L},可以对其进行证明一下(后续同理):

Jo,oW(2)=JL,LW(2) \frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}=\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{W}^{(2)}}

JL,LW(2)=LW(2)=Lo,oW(2) \frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{W}^{(2)}} = \frac{\partial L}{\partial \boldsymbol{W}^{(2)}}=\frac{\partial L}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}

由之前的式子Jo=prod(JL,Lo)=Lo\frac{\partial J}{\partial \boldsymbol{o}}=\operatorname{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right)=\frac{\partial L}{\partial \boldsymbol{o}}可以得到:

Lo,oW(2)=Jo,oW(2) \frac{\partial L}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}=\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}

沿着输出层继续反向传播,现在求的是JhRh\frac{\partial J}{\partial \boldsymbol{h}}\in \mathbb{R}^{h},可以的得到:

Jh=prod(Jo,oh)=W(2)Jo. \frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}}.

由于**函数ϕ\phi是按元素运算的,中间变量z\boldsymbol{z}的梯度JzRh\frac{\partial J}{\partial \boldsymbol{z}}\in \mathbb{R}^h的计算需要使用按元素乘法符号\odot:

Jz=prod(Jh,hz)=Jhϕ(z). \frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right).

最后,可以得到最开始的一层的模型参数的梯度J/W(1)Rh×d\partial J/\partial \boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d}。依据链式法则,可以得到:

JW(1)=prod(Jz,zW(1))+prod(Js,sW(1))=Jzx+λW(1). \frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}.

训练深度模型

在训练深度学习模型时候,正向传播和反向传播之间是相互依赖的。

一方面,正向传播的计算可能被依赖于模型参数的当前值,而这些模型参数是在反向传播的梯度计算后,通过算法迭代优化的。如,计算正则化项s=(λ/2)(W(1)F2+W(2)F2)s = (\lambda/2) \left(\|\boldsymbol{W}^{(1)}\|_F^2 + \|\boldsymbol{W}^{(2)}\|_F^2\right)依赖模型参数W(1)\boldsymbol{W}^{(1)}W(2)\boldsymbol{W}^{(2)}的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。

另一方面,反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值,是通过正向传播计算得到的,如,参数梯度J/W(2)=(J/o)h+λW(2)\partial J/\partial \boldsymbol{W}^{(2)} = (\partial J / \partial \boldsymbol{o}) \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)}的计算需要依赖隐层变量的当前值h\boldsymbol{h}。这个当前值是通过从输入层到输出层的正向传播计算并得到的。

因此,在模型参数初始化完成后,我们交替地进行正向传播和反向传播,并根据反向传播计算的梯度迭代模型参数。既然我们在反向传播中使用了正向传播中计算得到的中间变量来避免重复计算,那么这个复用也导致正向传播结束后不能立即释放中间变量内存。这也是训练要比预测占用更多内存的一个重要原因。另外需要指出的是,这些中间变量的个数大体上与网络层数线性相关,每个变量的大小跟批量大小和输入个数也是线性相关的,它们是导致较深的神经网络使用较大批量训练时更容易超内存的主要原因。

相关文章:

  • 2021-03-31
  • 2022-01-04
  • 2022-01-02
  • 2021-10-02
  • 2021-11-25
  • 2022-12-23
  • 2021-10-04
猜你喜欢
  • 2022-01-12
  • 2021-12-19
  • 2021-09-19
  • 2021-09-30
  • 2021-08-06
  • 2021-12-19
  • 2021-09-03
相关资源
相似解决方案