摘要:
本文主要讲述的RNN(循环神经网络)在训练过程中BP算法的推导。
在阅读本文之前希望读者看过我的另一篇文章BP算法心得体会。因为大部分的思路沿用的就是这篇文章的思路。
参考文章:
数学推导-1
数学推导-2
更新-2018-01-23:
之前写完这篇文章之后,回头看了一遍文章,发现在整个推导的过程都无视了时间维度的存在,所以后来查阅了相关的资料,发现目前网上有一部分RNN的推导过程和本文是一样的,比如上面给到的2篇参考文章,思路和本文是一致的。但是也存在另外一些版本的推导,其过程和本文的截然不同。
所以后来在参考了大神的代码后,重新思考了rnn的训练算法,因此重新写一个篇rnn和bptt供大家参考。
正文
RNN的一般原理介绍这里就不再重复了,本文关注的是RNN是如何利用BP算法来进行训练的。
推导
在推导BP算法之前,我们先做一些变量上的规定,这一步非常关键。
本文使用的RNN是只含一个隐藏层(多个隐藏层其实也是一样的道理)。其结构如下图所示:

(大家看到这个网络结构可能有些困惑,比如说,RNN是由多个网络组成的吗?这里值得注意的是,RNN就只由一个网络组成,图上有多个网络是在不同时刻的输入下的网络情况)
现在,作如下的一些规定:
vim是输入层第m个输入与隐藏层中第i个神经元所连接的权重。
uin是隐层自循环的权重(具体表现为上面结构图中那些紫色、绿色的线)
wkm是隐藏层中第m个神经元与输出层第k个神经元连接的权重。
网络中共有N(i)个输入单元,N(h)个隐藏层,N(o)个输出单元
netthi表示隐藏层第i个神经元在t时刻**前的输入。
具体为:netthi=∑N(i)m=1(vimxtm)+∑N(h)s=1(uisht−1s)
经过**后的输出为:hti=f(netthi)
nettyk表示输出层第k个神经元在t时刻**前的输入。
具体为:nettyk=∑N(h)m=1(wkmhtm)
经过**后的输出为:otk=f(nettyk)
这里同样地,为了方便推导,假设损失函数Et=0.5∗∑N(o)k=1(otk−ttk)2(本文也会说明使用其他损失函数的情况)
E=∑stept=1Et
首先我们需要解决的问题就是求出:
∂E∂uin,∂E∂wkm,∂E∂vim。
1.先来求最简单的∂E∂wkm:
和之前讲解BP的文章套路一样,我们可以对∂E∂wkm使用链式法则,具体如下:
∂E∂wkm=∂E∂nettyk∗∂nettyk∂wkm
对于等式右边第二项很好计算,∂nettyk∂wkm=htm
和之前一样,我们定义等式右边第一项为误差信号δtyk=∂E∂nettyk。
δtyk=∂E∂nettyk=∂E∂otk∗∂otk∂nettyk。(这一步的思路就是找到和nettyk有直接相关的变量)
故:δtyk=∂E∂nettyk=∂E∂otk∗∂otk∂nettyk=(otk−ttk)∗f′(nettyk)
所以,∂E∂wkm=δtyk∗htm=(otk−ttk)∗f′(nettyk)∗htm。
下面,我们推导∂E∂vim。
∂E∂vim=∂E∂netthi∗∂netthi∂vim
对于∂netthi∂vim=xtm
定义误差信号δthi=∂E∂netthi
δthi=∂E∂netthi=∂E∂hti∗∂hti∂netthi=∂E∂hti∗f′(netthi)。
整个RNN如果说最麻烦的推导,可能就是对于∂E∂hti的推导。
按照以前的思路我们容易想到:
∂E∂hti=∑N(o)k=1(∂E∂nettyk∗∂nettyk∂hti)=∑N(o)k=1(∂E∂nettyk∗wki)。
上面的推导对吗?如果推导到这里感觉没问题的话不妨思考一个问题,如果公式是这样,这里哪里体现了RNN具有“记忆”的功能?公式体现的只与当前时刻t有关。
我们注意到,和hti有直接函数关系的除了nettyk以外,其实还有一条等式,而恰恰是这条等式把每个时刻之间的关系串了起来。
就是:netthi=∑N(i)m=1(vimxtm)+∑N(h)s=1(uisht−1s)。
我把上式中的t->t+1,也就是往后推一个时刻,我们有:
nett+1hi=∑N(i)m=1(vimxt+1m)+∑N(h)s=1(uishts)。
也就是说,hti还和nett+1hi相关。所以上式应该改写成:
∂E∂hti=∑N(o)k=1(∂E∂nettyk∗∂nettyk∂hti)+∑N(h)s=1(∂E∂nett+1hs∗∂nett+1hs∂hti)=∑N(o)k=1(∂E∂nettyk∗wki)+∑N(h)s=1(∂E∂nett+1hs∗usi)
其实就是多了一项。这个是大家需要注意的!
对于∂E∂nettyk其为输出层的误差信号,上面已经求过了,即δtyk。
而∂E∂nett+1hs其实就是δt+1hs。这个就是t+1时刻的隐藏层的一个误差信号。而t时刻隐藏层的误差信号与t+1时刻隐藏层的误差信号有关,或者换句话说法,t时刻的隐藏层的误差信号积累了t+1时刻的误差,看到这里,其实我们就可以认识到一个问题,RNN确实具有一定的记忆能力。
Ok,把上式整理一下可以得到:
∂E∂hti=∑N(o)k=1(δtyk∗wki)+∑N(h)s=1(δt+1hs∗usi)
由于:δthi=∂E∂netthi=∂E∂hti∗∂hti∂netthi=∂E∂hti∗f′(netthi),替换掉∂E∂hti得到:
δthi=(∑N(o)k=1(δtyk∗wki)+∑N(h)s=1(δt+1hs∗usi))∗f′(netthi)
∂E∂vim=δthi∗xtm
最后一个是∂E∂uin。其实其和∂E∂vim是一样的,(因为位于同一层)具体可以参考下面:
∂E∂uin=∂E∂netthi∗∂netthi∂uin,可以看到等式右边第一项就是上面推导过的隐藏层误差信号δthi,而第二项就是ht−1n。
所以:∂E∂uin=∂E∂netthi∗∂netthi∂uin=δthi∗ht−1n。
小结
至此,RNN的bp算法算是推导完毕,我们如果看回整个推导过程,其实和前面文章介绍的BP没什么区别,最大的区别在于RNN具有时序性,所以在隐藏层的误差信号处理时需要格外的注意,下面,我们可以从结构图来看待这一个问题,这种角度也可以加深我们对所谓“反向传播”有多一个深刻的理解。

这个是上面的结构图,我关注一下“紫色”的线,紫色线连接的是t时刻隐藏层和t+1时刻的隐藏层。我们从误差传播的角度来看。对于t时刻隐藏层某一个神经元而言,其误差可以分为两部分来源,第一部分就是t时刻本身的(黑色线,连接隐藏层和输出层这些),另外一部分就是t+1时刻时候隐藏层和隐藏层(自循环层)。
而这两部分恰恰对应了上面公式的两个部分。
公式中红色部分就是t时刻的误差,蓝色部分就是来自于t+1时刻。
δthi=(∑N(o)k=1(δtyk∗wki)+∑N(h)s=1(δt+1hs∗usi))∗f′(netthi)
下面再补一副很简陋的图,想表达的意思和上面一样。

关于训练过程的细节-1
这里可能有人会疑问,计算t时刻误差需要用到t+1时刻的误差,这个不是有背常理吗?这里需要注意的,神经网络里面是先前向计算,然后反向传播误差。所以每次训练,先从t=0时刻前向计算至最后一个时刻t。然后从t时刻反向传播误差。所以这里需要保存每一个时刻隐藏层、输出层的输出。
关于训练过程的细节-2
最后一个时刻由于没有下一个时刻传回来的隐藏层误差,所以下式中蓝色一项为0。
δthi=(∑N(o)k=1(δtyk∗wki)+∑N(h)s=1(δt+1hs∗usi))∗f′(netthi)
即:
δthi=∑N(o)k=1(δtyk∗wki)。
关于损失函数
和之前BP算法推导一样,其实损失函数就只有在这一步中产生影响。
δtyk=∂E∂nettyk=∂E∂otk∗∂otk∂nettyk。
完全可以保留∂E∂otk这个记号,并不影响后面的计算。
至此,两篇关于BP算法的文章算是告一段落,希望大家能够从中学习到东西。