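I'm a newbie, and I find the RNN gradient formulas too complicated, so I came up with a way to break the derivation into smaller pieces.

The idea is to use a syntax tree.

For the original article, see RNN反向求导详解_格物致知-CSDN博客.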

[Figure: detailed derivation of the RNN gradient formulas]

$$
\begin{aligned}
o_t&=\varphi(Vs_t)=\varphi(V\phi(Ws_{t-1}+Ux_t))\\
L_t&=\text{loss}(o_t,y_t)
\end{aligned}
$$

Let $o_t^*=Vs_t$ and $s_t^*=Ux_t+Ws_{t-1}$, so that

$o_t=\varphi(o_t^*)$ and $s_t=\phi(s_t^*)$.
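
To make these definitions concrete, here is a minimal pure-Python sketch of one forward step. The choice of tanh for $\phi$, a sigmoid for $\varphi$, a squared-error loss, the toy sizes, and the names `matvec`, `rnn_step` and `loss` are all assumptions made for illustration; the definitions above do not fix any of them.

```python
import math

def matvec(M, v):
    """Matrix-vector product M × v, with M stored as a list of rows."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def sigmoid(z):
    return [1.0 / (1.0 + math.exp(-z_i)) for z_i in z]

def tanh(z):
    return [math.tanh(z_i) for z_i in z]

def rnn_step(U, W, V, x_t, s_prev):
    """One forward step; returns (s_star, s_t, o_star, o_t)."""
    # s_t^* = U x_t + W s_{t-1}
    s_star = [a + b for a, b in zip(matvec(U, x_t), matvec(W, s_prev))]
    s_t = tanh(s_star)        # s_t = phi(s_t^*), phi assumed to be tanh
    o_star = matvec(V, s_t)   # o_t^* = V s_t
    o_t = sigmoid(o_star)     # o_t = varphi(o_t^*), varphi assumed to be a sigmoid
    return s_star, s_t, o_star, o_t

def loss(o_t, y_t):
    """Squared-error loss, an assumed stand-in for loss(o_t, y_t)."""
    return 0.5 * sum((o - y) ** 2 for o, y in zip(o_t, y_t))

# Toy sizes (assumed): 2 inputs, 4 hidden units, 3 outputs, so V is 3x4.
U = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6], [0.7, -0.8]]
W = [[0.05 * (i - j) for j in range(4)] for i in range(4)]
V = [[0.1 * (i + j) for j in range(4)] for i in range(3)]
x_t = [1.0, -1.0]
s_prev = [0.0, 0.0, 0.0, 0.0]
y_t = [1.0, 0.0, 0.0]

s_star, s_t, o_star, o_t = rnn_step(U, W, V, x_t, s_prev)
L_t = loss(o_t, y_t)
print(len(s_t), len(o_t), L_t)
```

Keeping the starred pre-activations `s_star` and `o_star` as separate return values mirrors the split into $s_t^*$ and $o_t^*$, which is what lets each node of the tree be differentiated on its own.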

Now draw $L_t$ as a syntax tree and differentiate it step by step.

[Figure: syntax tree of $L_t$]

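Below, $*$ denotes element-wise multiplication and $\times$ denotes matrix multiplication.

$$
\cfrac{\partial L_t}{\partial o_t^*}=\cfrac{\partial L_t}{\partial o_t}*\cfrac{\partial o_t}{\partial o_t^*}=\cfrac{\partial L_t}{\partial o_t}*\varphi'(o_t^*)\tag{1}
$$

The result of Eq. (1) is a vector with the same dimension as $o_t^*$. The element-wise product assumes that $\varphi$ acts on each component separately (a sigmoid, for example); for a softmax the factor $\cfrac{\partial o_t}{\partial o_t^*}$ would be a full matrix.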
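
Eq. (1) can be sanity-checked numerically. The sketch below assumes a sigmoid for $\varphi$ and a squared-error loss (neither is fixed by the text), so that $\partial L/\partial o = o - y$ and $\varphi'(o^*) = o(1-o)$; the helper names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss_from_o_star(o_star, y):
    """L as a function of o^*: sigmoid output and squared-error loss (both assumed)."""
    return 0.5 * sum((sigmoid(z) - y_i) ** 2 for z, y_i in zip(o_star, y))

def grad_o_star(o_star, y):
    """Eq. (1): dL/do^* = dL/do * varphi'(o^*), multiplied element by element."""
    o = [sigmoid(z) for z in o_star]
    dL_do = [o_i - y_i for o_i, y_i in zip(o, y)]   # derivative of the squared error
    dphi = [o_i * (1.0 - o_i) for o_i in o]         # derivative of the sigmoid
    return [a * b for a, b in zip(dL_do, dphi)]

def numeric_grad(f, v, eps=1e-6):
    """Central finite differences of a scalar function f with respect to the vector v."""
    g = []
    for i in range(len(v)):
        hi = v[:i] + [v[i] + eps] + v[i + 1:]
        lo = v[:i] + [v[i] - eps] + v[i + 1:]
        g.append((f(hi) - f(lo)) / (2 * eps))
    return g

o_star = [0.3, -1.2, 0.8]
y = [1.0, 0.0, 0.0]
analytic = grad_o_star(o_star, y)
numeric = numeric_grad(lambda z: loss_from_o_star(z, y), o_star)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(analytic, max_err)
```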
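$$
\cfrac{\partial L_t}{\partial V}=\cfrac{\partial L_t}{\partial o_t^*}[?]\cfrac{\partial o_t^*}{\partial V}\tag{2}
$$

Here $[?]$ stands for an operation that is still to be determined. Taken as a whole, Eq. (2) is the derivative of a scalar with respect to a matrix, which means differentiating the scalar with respect to every element of the matrix; the intermediate value $o_t^*$ is a vector.

The first factor, $\cfrac{\partial L_t}{\partial o_t^*}$, was already computed in Eq. (1); the second factor is the derivative of a matrix-vector product.

Since we are differentiating with respect to $V$, the result must have the same shape as $V$.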

Let's work through an example to see how the derivative comes out.

$$
\boldsymbol{o^*}=\boldsymbol{V}\times\boldsymbol{s}=
\begin{bmatrix}V_{11}&V_{12}&V_{13}&V_{14}\\V_{21}&V_{22}&V_{23}&V_{24}\\V_{31}&V_{32}&V_{33}&V_{34}\end{bmatrix}
\times
\begin{bmatrix}s_1\\s_2\\s_3\\s_4\end{bmatrix}=
\begin{bmatrix}V_{11}s_1+V_{12}s_2+V_{13}s_3+V_{14}s_4\\V_{21}s_1+V_{22}s_2+V_{23}s_3+V_{24}s_4\\V_{31}s_1+V_{32}s_2+V_{33}s_3+V_{34}s_4\end{bmatrix}=
\begin{bmatrix}o^*_1\\o^*_2\\o^*_3\end{bmatrix}\tag{3}
$$

$$
\begin{aligned}
\cfrac{\partial L}{\partial V_{11}}=\cfrac{\partial L}{\partial o^*_1}\cfrac{\partial o^*_1}{\partial V_{11}}&=\cfrac{\partial L}{\partial o^*_1}s_1\\
\cfrac{\partial L}{\partial V_{12}}=\cfrac{\partial L}{\partial o^*_1}\cfrac{\partial o^*_1}{\partial V_{12}}&=\cfrac{\partial L}{\partial o^*_1}s_2\\
&\vdots\\
\cfrac{\partial L}{\partial V_{34}}=\cfrac{\partial L}{\partial o^*_3}\cfrac{\partial o^*_3}{\partial V_{34}}&=\cfrac{\partial L}{\partial o^*_3}s_4
\end{aligned}
$$

$$
\cfrac{\partial L}{\partial V}=\begin{bmatrix}\cfrac{\partial L}{\partial o^*_1}\\\cfrac{\partial L}{\partial o^*_2}\\\cfrac{\partial L}{\partial o^*_3}\end{bmatrix}\times\begin{bmatrix}s_1&s_2&s_3&s_4\end{bmatrix}
$$
So Eq. (2) should be written as

$$
\cfrac{\partial L_t}{\partial V}=\cfrac{\partial L_t}{\partial o_t^*}\times\cfrac{\partial o_t^*}{\partial V}=\cfrac{\partial L_t}{\partial o_t^*}\times s_t^T\tag{4}
$$
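
The outer-product form of Eq. (4) can be checked against finite differences on the same $3\times 4$ shape as Eq. (3). This is only a sketch: the loss $L=\tfrac12\lVert o^*-y\rVert^2$ applied directly to $o^*$, the numeric values, and the helper names are assumptions chosen to keep the check short.

```python
def matvec(M, v):
    """Matrix-vector product M × v, with M stored as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]

def loss(V, s, y):
    """Assumed toy loss on o^* = V s: L = 0.5 * ||o^* - y||^2."""
    return 0.5 * sum((o - y_i) ** 2 for o, y_i in zip(matvec(V, s), y))

def grad_V(V, s, y):
    """Eq. (4): dL/dV = dL/do^* × s^T, an outer product with the shape of V."""
    delta = [o - y_i for o, y_i in zip(matvec(V, s), y)]  # dL/do^* for this toy loss
    return [[d * s_j for s_j in s] for d in delta]

def numeric_grad_V(V, s, y, eps=1e-6):
    """Differentiate the scalar L with respect to every element of V in turn."""
    G = []
    for i in range(len(V)):
        row = []
        for j in range(len(V[0])):
            V_hi = [r[:] for r in V]
            V_lo = [r[:] for r in V]
            V_hi[i][j] += eps
            V_lo[i][j] -= eps
            row.append((loss(V_hi, s, y) - loss(V_lo, s, y)) / (2 * eps))
        G.append(row)
    return G

V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
s = [0.5, -1.0, 0.25, 2.0]
y = [1.0, 0.0, -1.0]

analytic = grad_V(V, s, y)
numeric = numeric_grad_V(V, s, y)
max_err = max(abs(a - n)
              for row_a, row_n in zip(analytic, numeric)
              for a, n in zip(row_a, row_n))
print(max_err)
```

The numeric version is exactly the "differentiate with respect to each element" reading of a scalar-by-matrix derivative, which is why its result has the shape of $V$.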
Next, compute the derivative of $L_t$ with respect to $s_t$, again with reference to Eq. (3).

[Figure: detailed derivation of the RNN gradient formulas]
Image source: https://wenku.baidu.com/view/0c28ff2249d7c1c708a1284ac850ad02de8007c1.html

$$
\begin{aligned}
\cfrac{\partial L}{\partial s_1}&=\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_1} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_1} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_1}\end{bmatrix}\\
&\vdots\\
\cfrac{\partial L}{\partial s_4}&=\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_4} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_4} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_4}\end{bmatrix}
\end{aligned}
$$

$$
\begin{aligned}
\cfrac{\partial L}{\partial s}&=\begin{bmatrix}
\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_1} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_1} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_1}\\
\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_2} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_2} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_2}\\
\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_3} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_3} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_3}\\
\cfrac{\partial L}{\partial o_1^*}\cfrac{\partial o_1^*}{\partial s_4} + \cfrac{\partial L}{\partial o_2^*}\cfrac{\partial o_2^*}{\partial s_4} + \cfrac{\partial L}{\partial o_3^*}\cfrac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\\
&=\begin{bmatrix}
\cfrac{\partial o_1^*}{\partial s_1}&\cfrac{\partial o_2^*}{\partial s_1}&\cfrac{\partial o_3^*}{\partial s_1}\\
\cfrac{\partial o_1^*}{\partial s_2}&\cfrac{\partial o_2^*}{\partial s_2}&\cfrac{\partial o_3^*}{\partial s_2}\\
\cfrac{\partial o_1^*}{\partial s_3}&\cfrac{\partial o_2^*}{\partial s_3}&\cfrac{\partial o_3^*}{\partial s_3}\\
\cfrac{\partial o_1^*}{\partial s_4}&\cfrac{\partial o_2^*}{\partial s_4}&\cfrac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\times\begin{bmatrix}\cfrac{\partial L}{\partial o_1^*}\\\cfrac{\partial L}{\partial o_2^*}\\\cfrac{\partial L}{\partial o_3^*}\end{bmatrix}\\
&=?\times\cfrac{\partial L}{\partial o_t^*}
\end{aligned}\tag{5}
$$

To complete the last step of Eq. (5), we first need to settle how one vector is differentiated with respect to another.

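Reference: https://zhuanlan.zhihu.com/p/36448789

That article contains the following sentence:

"However, for convenience in practice, even when the vector $y$ is a column vector, the derivative is usually taken as if it were a row vector."

From this we can conclude that, in general, we differentiate a row vector with respect to a column vector.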

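Differentiating a row vector $X$ with respect to a column vector $Y$ yields a matrix whose width is the length of $X$ and whose height is the length of $Y$.

So the question-mark matrix in Eq. (5) is the derivative of the row vector $o_t^*$ with respect to the column vector $s_t$:

$$
\cfrac{\partial L}{\partial s_t}=\cfrac{\partial o_t^*}{\partial s_t}\times\cfrac{\partial L}{\partial o_t^*}\tag{6}
$$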
The factor $\cfrac{\partial o_t^*}{\partial s_t}$ in Eq. (6) can be evaluated further:

$$
\begin{aligned}
\cfrac{\partial o_t^*}{\partial s_t}&=\begin{bmatrix}
\cfrac{\partial o_1^*}{\partial s_1}&\cfrac{\partial o_2^*}{\partial s_1}&\cfrac{\partial o_3^*}{\partial s_1}\\
\cfrac{\partial o_1^*}{\partial s_2}&\cfrac{\partial o_2^*}{\partial s_2}&\cfrac{\partial o_3^*}{\partial s_2}\\
\cfrac{\partial o_1^*}{\partial s_3}&\cfrac{\partial o_2^*}{\partial s_3}&\cfrac{\partial o_3^*}{\partial s_3}\\
\cfrac{\partial o_1^*}{\partial s_4}&\cfrac{\partial o_2^*}{\partial s_4}&\cfrac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\\
&=\begin{bmatrix}V_{11}&V_{21}&V_{31}\\V_{12}&V_{22}&V_{32}\\V_{13}&V_{23}&V_{33}\\V_{14}&V_{24}&V_{34}\end{bmatrix}\\
&=V^T
\end{aligned}
$$

Substituting this result into Eq. (6) gives

$$
\cfrac{\partial L}{\partial s_t}=\cfrac{\partial o_t^*}{\partial s_t}\times\cfrac{\partial L}{\partial o_t^*}=V^T\times\cfrac{\partial L}{\partial o_t^*}\tag{7}
$$
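
Both claims, that the Jacobian in this layout equals $V^T$ and that Eq. (7) gives the gradient with respect to $s_t$, can be checked with finite differences. As before, this is a sketch under assumptions: the toy loss $L=\tfrac12\lVert o^*-y\rVert^2$, the numeric values, and the helper names are illustrative only.

```python
def matvec(M, v):
    """Matrix-vector product M × v, with M stored as a list of rows."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def loss(V, s, y):
    """Assumed toy loss on o^* = V s: L = 0.5 * ||o^* - y||^2."""
    return 0.5 * sum((o - y_i) ** 2 for o, y_i in zip(matvec(V, s), y))

def perturbed(v, j, eps):
    """Copy of v with eps added to component j."""
    return v[:j] + [v[j] + eps] + v[j + 1:]

def numeric_jacobian(V, s, eps=1e-6):
    """Row j, column i holds d o^*_i / d s_j: height len(s), width len(o^*)."""
    J = []
    for j in range(len(s)):
        hi = matvec(V, perturbed(s, j, eps))
        lo = matvec(V, perturbed(s, j, -eps))
        J.append([(a - b) / (2 * eps) for a, b in zip(hi, lo)])
    return J

def grad_s(V, s, y):
    """Eq. (7): dL/ds = V^T × dL/do^*."""
    delta = [o - y_i for o, y_i in zip(matvec(V, s), y)]  # dL/do^* for this toy loss
    return matvec(transpose(V), delta)

def numeric_grad_s(V, s, y, eps=1e-6):
    return [(loss(V, perturbed(s, j, eps), y) - loss(V, perturbed(s, j, -eps), y)) / (2 * eps)
            for j in range(len(s))]

V = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8], [-0.9, 1.0, 0.2, -0.1]]
s = [0.5, -1.0, 0.25, 2.0]
y = [1.0, 0.0, -1.0]

J = numeric_jacobian(V, s)
jac_err = max(abs(J[j][i] - V[i][j]) for i in range(3) for j in range(4))
analytic = grad_s(V, s, y)
numeric = numeric_grad_s(V, s, y)
grad_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(jac_err, grad_err)
```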
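At this point, all the techniques involved have been covered. Filling the derivative results into the syntax tree gives:
[Figure: syntax tree of $L_t$ annotated with the derivatives obtained so far]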

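Looking at the rest of the tree, the remaining structure is a simple repetition of the patterns above.

So let's just fill in a couple more!
[Figure: syntax tree of $L_t$ with a few more derivatives filled in]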
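
As a sketch of what repeating the patterns could look like one level further down the tree, assume $\phi$ is applied element-wise and look only at the node $s_t^*=Ux_t+Ws_{t-1}$. Reusing the element-wise rule of Eq. (1), the outer-product rule of Eq. (4), and the transpose rule of Eq. (7) would give:

$$
\begin{aligned}
\cfrac{\partial L_t}{\partial s_t^*}&=\cfrac{\partial L_t}{\partial s_t}*\phi'(s_t^*)\\
\cfrac{\partial L_t}{\partial U}&=\cfrac{\partial L_t}{\partial s_t^*}\times x_t^T\\
\cfrac{\partial L_t}{\partial W}&=\cfrac{\partial L_t}{\partial s_t^*}\times s_{t-1}^T\\
\cfrac{\partial L_t}{\partial s_{t-1}}&=W^T\times\cfrac{\partial L_t}{\partial s_t^*}
\end{aligned}
$$

The expressions for $U$ and $W$ are only the contributions from this one node. Because $s_{t-1}$ itself depends on $U$ and $W$, the full gradient also adds the terms found by continuing down the tree through $\cfrac{\partial L_t}{\partial s_{t-1}}$.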
