本文主要是对tensorflow中lstm模型中的c,h进行解析。rnn_cell_impl.py

1.关于RNN模型

在rnn_cell_impl.py的tensorflow源码中,关于RNN部分实现的类主要是BasicRNNCell,
首先在build函数中,定义了两个变量_kernel和_bias。
关于RNN相关模型-tensorflow源码理解
其中_num_untis表示RNN cell 的untis数目。
所以在call函数中,hidden_state的更新如下所示:
关于RNN相关模型-tensorflow源码理解

从上面中可以看出,RNN首先将input与上一个state连接,然后与在build函数中定义的_kernal变量点乘,最后加上偏置项。

2. 关于LSTM模型

主要看BasicLSTMCell这个类,在build函数中,定义了两个参数_kernel与_bias
关于RNN相关模型-tensorflow源码理解
与RNN不同,参数_kernal与_bias的列都是_num_units的四倍,主要是因为后面要分成四个部分,分别为i,j,f,o。
因此,在call函数中,
关于RNN相关模型-tensorflow源码理解
在call函数中,i,j,f,o可以分别表示为:
关于RNN相关模型-tensorflow源码理解

所以,在上面的图中,最上面的横线表示C,最小面的横线表示h。

3. 关于GRU模型

在GRU模型中的build函数中,可以看到定义了四个参数:
关于RNN相关模型-tensorflow源码理解
因此,在call函数中,

关于RNN相关模型-tensorflow源码理解
从下面的图中可以看出,zt为u,r表示rt,
关于RNN相关模型-tensorflow源码理解

从tensorflow的源码来看,上面的公式中ht的求解有问题,所以参考维基百科,得到下面的公式:
关于RNN相关模型-tensorflow源码理解

相关文章:

  • 2021-07-11
  • 2021-06-02
  • 2022-12-23
  • 2022-03-09
  • 2022-12-23
  • 2021-09-12
  • 2021-06-07
猜你喜欢
  • 2021-05-08
  • 2021-12-16
  • 2022-01-01
  • 2021-08-11
  • 2021-12-09
  • 2021-10-08
  • 2021-11-13
相关资源
相似解决方案