由于最近要求用numpy数组实现MLP算法, 对鸢尾花进行分类, 并且使用随机梯度下降法进行网络权值更新, 因此将其推导过程记下备忘
下图即为三层网络反向传播的简易推导
注: 一般都只在隐藏层加偏置, 因此输出层偏置忽略就行
1. 正向传播
正向传播的完整公式如下, 其中输入为 x i x_{i} xi , 本文即为鸢尾花的四个特征输入
隐藏层的网络权值为 w j i ( 1 ) w_{j i}^{(1)} wji(1) , **函数为 tanh ( x ) \tanh (x) tanh(x)
输出层的网络权值为
w
k
j
(
2
)
w_{k j}^{(2)}
wkj(2) , 输出label采用One-Hot编码, 最终的输出为
y
k
y_{k}
yk
a
j
=
∑
i
=
0
d
w
j
i
(
1
)
x
i
\begin{aligned} a_{j} =\sum_{i=0}^{d} w_{j i}^{(1)} x_{i} \end{aligned}
aj=i=0∑dwji(1)xi
z j = tanh ( a j ) \begin{aligned} z_{j} =\tanh \left(a_{j}\right) \end{aligned} \\ zj=tanh(aj)
y k = ∑ j = 0 K w k j ( 2 ) z j \begin{aligned} y_{k} &=\sum_{j=0}^{K} w_{k j}^{(2)} z_{j} \end{aligned} yk=j=0∑Kwkj(2)zj
2. 损失函数计算
对于输出的(3,1)的
l
a
b
e
l
:
y
k
label:y_{k}
label:yk , 采用均方差误差计算损失,
E
n
=
1
2
∑
k
=
1
L
(
y
k
−
t
k
)
2
E_{n} = \frac{1}{2} \sum_{k=1}^{L}\left(y_{k}-t_{k}\right)^{2}
En=21k=1∑L(yk−tk)2
3. 反向传播
3.1 输出层网络权值 w k j w_{k j} wkj的梯度求解
损失函数对输出层网络权值
w
k
j
w_{k j}
wkj的梯度为
∂
E
n
∂
w
k
j
(
2
)
\frac{\partial E_{n}}{\partial w_{k j}^{(2)} }
∂wkj(2)∂En ,
∂
E
n
∂
w
k
j
(
2
)
=
∂
E
n
∂
y
k
∂
y
k
∂
w
k
j
(
2
)
\begin{aligned} \frac{\partial E_{n}}{\partial w_{k j}^{(2)}}= \frac{\partial E_{n}}{\partial y_{k}} \frac{\partial y_{k}}{\partial w_{k j}^{(2)}} \end{aligned}
∂wkj(2)∂En=∂yk∂En∂wkj(2)∂yk
∂ E n ∂ w k j ( 2 ) = ( y k − t k ) ⊗ z j \frac{\partial E_{n}}{\partial w_{k j}^{(2)} } = (y_{k}-t_{k}) \otimes z_{j} ∂wkj(2)∂En=(yk−tk)⊗zj
3.2 隐藏层网络权值 w j i w_{j i} wji的梯度求解
损失函数对隐藏层网络权值
w
j
i
w_{j i}
wji的梯度为
∂
E
n
∂
w
j
i
(
1
)
\frac{\partial E_{n}}{\partial w_{j i}^{(1)} }
∂wji(1)∂En ,
∂
E
n
∂
w
j
i
(
1
)
=
∂
E
n
∂
z
j
∂
z
j
∂
a
j
∂
a
j
∂
w
j
i
(
1
)
\begin{aligned} \frac{\partial E_{n}}{\partial w_{j i}^{(1)} }= \frac{\partial E_{n}}{\partial z_{j}} \frac{\partial z_{j}}{\partial a_{j}} \frac{\partial a_{j}}{\partial w_{j i}^{(1)}} \end{aligned}
∂wji(1)∂En=∂zj∂En∂aj∂zj∂wji(1)∂aj
∂ E n ∂ w j i ( 1 ) = ( 1 − z j 2 ) ∗ ∑ k w k j ( y k − t k ) ⊗ x i \begin{aligned} \frac{\partial E_{n}}{\partial w_{j i}^{(1)}}= \left(1-z_{j}^{2}\right)* \sum_{k} w_{k j} (y_{k}-t_{k}) \otimes x_{i} \end{aligned} ∂wji(1)∂En=(1−zj2)∗k∑wkj(yk−tk)⊗xi
4. 网络权值更新公式
对于网络权值
w
k
j
w_{k j}
wkj的参数更新公式为
w
k
j
=
w
k
j
−
η
∂
E
n
∂
w
k
j
(
2
)
\begin{aligned} w_{k j} =w_{kj}-\eta \frac{\partial E_{n}}{\partial w_{k j}^{(2)}} \end{aligned}
wkj=wkj−η∂wkj(2)∂En
对于网络权值
w
j
i
w_{j i}
wji的参数更新公式为
w
j
i
=
w
j
i
−
η
∂
E
n
∂
w
j
i
(
1
)
\begin{aligned} w_{j i}=w_{j i }-\eta \frac{\partial E_{n}}{\partial w_{j i}^{(1)}} \end{aligned}
wji=wji−η∂wji(1)∂En
5. 交叉验证
交叉验证可以用于模型评估, 本次实验要求将数据集分为5个子集, 进行交叉验证, 来验证模型的优劣
使用时, 只需依次分别从5个子集中选取一个为测试集, 其他的四个子集为训练集, 进行网络的训练及评估
保证每一个子集都曾作为测试集, 其他为训练集, 该方法可以避免数据集划分不合理导致的网络过拟合, 即只在训练集表现很好, 测试集很差, 而对网络模型优劣的错误判断
本文仅供参考, 若有错误, 欢迎指出, 共同学习
本文首发于公众号: 钰哥lab
欢迎关注