记录一下对RNN,LSTM,GRU基本原理(正向过程以及简单的反向过程)的个人理解


RNN

Recurrent Neural Networks,循环神经网络
(注意区别于recursive neural network,递归神经网络)

为了解决DNN存在着无法对时间序列上的变化进行建模的问题(如自然语言处理、语音识别、手写体识别),出现的另一种神经网络结构——循环神经网络RNN。

RNN结构

RNN,LSTM,GRU基本原理的个人理解

  • t层神经元的输入,除了其自身的输入xt,还包括上一层神经元的隐含层输出st1
  • 每一层的参数U,W,V都是共享的
    RNN,LSTM,GRU基本原理的个人理解
    RNN,LSTM,GRU基本原理的个人理解
  • 每一层并不一定都得有输入和输出,如对句子进行情感分析是多到一,文本翻译多到多,图片描述一到多

数学描述

(以下开始符号统一)
回忆一下单隐含层的前馈神经网络
输入为XRn×x(n个维度为x的向量)
隐含层输出为

H=ϕ(XWxh+bh)

输出层输入HRn×h
输出为
Y^=softmax(HWhy+by)

现在对XHY都加上时序下标
同时引入一个新权重WhhRh×h
得到RNN表达式
Ht=ϕ(XtWxh+Ht1Whh+bh)
Y^t=softmax(HtWhy+by)

H0通常置零

深层RNN和双向RNN

RNN,LSTM,GRU基本原理的个人理解
RNN,LSTM,GRU基本原理的个人理解

通过时间反向传播和随之带来的问题

输入为xtRx
不考虑偏置
隐含层变量为

ht=ϕ(Whxxt+Whhht1)

输出层变量为
ot=Wyhht

则损失函数为
L=1Tt=1T(ot,yt)

以一个三层为例
RNN,LSTM,GRU基本原理的个人理解
三个参数更新公式为

Whx=WhxηLWhx

Whh=WhhηLWhh

Wyh=WyhηLWyh

明显的
Lot=(ot,yt)Tot

根据链式法则
LWyh=t=1Tprod(Lot,otWyh)=t=1TLotht

先计算目标函数有关最终时刻隐含层变量的梯度
LhT=prod(LoT,oThT)=WyhLoT

假设ϕ(x)=x(RNN中用**函数relu还是tanh众说纷纭,有点玄学)

Lht=prod(Lht+1,ht+1ht)+prod(Lot,otht)=WhhLht+1+WyhLot

通项为
Lht=i=tT(Whh)TiWyhLoT+ti

注意上式,当每个时序训练数据样本的时序长度T较大或者时刻t较小,目标函数有关隐含层变量梯度较容易出现衰减和爆炸

LWhx=t=1Tprod(Lht,htWhx)=t=1TLhtxt

LWhh=t=1Tprod(Lht,htWhh)=t=1TLhtht1

梯度裁剪

为了应对梯度爆炸,一个常用的做法是如果梯度特别大,那么就投影到一个比较小的尺度上。θ为设定的裁剪“阈值”,为标量,若梯度的范数大于此阈值,将梯度缩小,若梯度的范数小于此阈值,梯度不变

g=min(θg,1)g


LSTM

RNN的隐含层变量梯度可能会出现衰减或爆炸。虽然梯度裁剪可以应对梯度爆炸,但无法解决梯度衰减。因此,给定一个时间序列,例如文本序列,循环神经网络在实际中其实较难捕捉两个时刻距离较大的文本元素(字或词)之间的依赖关系。
LSTM(long short-term memory)由Hochreiter和Schmidhuber在1997年被提出。

LSTM结构

这里两张图先不用细看,先着重记住公式后再回来看

RNN,LSTM,GRU基本原理的个人理解
RNN,LSTM,GRU基本原理的个人理解

数学描述

(同上,符号统一)
设隐含状态长度h,t时刻输入XtRn×xx维)及t1时刻隐含状态Ht1Rn×h,
输入门,遗忘门,输出门,候选细胞如下

It=σ(XtWxi+Ht1Whi+bi)

Ft=σ(XtWxf+Ht1Whf+bf)

Ot=σ(XtWxo+Ht1Who+bo)

C~t=tanh(XtWxc+Ht1Whc+bc)

(思考侯选细胞**函数的不同)
记忆细胞

Ct=FtCt1+ItC~t

想象,如果遗忘门一直近似1且输入门一直近似0,过去的细胞将一直通过时间保存并传递至当前时刻
隐含状态
Ht=Ottanh(Ct)

输出同RNN
Y^=softmax(HWhy+by)


GRU

由Cho、van Merrienboer、 Bahdanau和Bengio在2014年提出,比LSTM少一个门控,实验结果却相当

GRU结构

RNN,LSTM,GRU基本原理的个人理解

数学描述

设隐含状态长度h,t时刻输入XtRn×xx维)及t1时刻隐含状态Ht1Rn×h,
重置门,更新门如下

Rt=σ(XtWxr+Ht1Whr+br)

Zt=σ(XtWxz+Ht1Whz+bz)

候选隐含状态
H~t=tanh(XtWxh+RtHt1Whh+bh)

隐含状态
Ht=ZtHt1+(1Zt)H~t

输出
Y^=softmax(HWhy+by)

(无力吐槽csdn了,预览和实际用的不一套渲染,公式丑死)

相关文章: