2021-03-04
作者:董鑫
链接:https://www.zhihu.com/question/66200879/answer/870023448
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
自动求导 (Automatic Differentiation, AD) 的课程 (CS207),正好来回答一下。 其实不只是 TensorFlow,Pytorch 这些为深度学习设计的库用到 AD,很多物理,化学等基础科学计算软件也在大量的使用 AD。而且,其实TensorFlow、Pytorch 也并非只能用于deep learning,本质上他们是一种
Tensor computation built on a tape-based autograd system --引自Pytorch
自动求导分成两种模式,一种是 Forward Mode,另外一种是 Reverse Mode。一般的机器学习库用的后一种,原因后面说。
Forward Mode
基于的就是就基本的 链式法则 chain rule,
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEKysrJTVDbmFibGFfJTdCeCU3RGgrJTNEKyU1Q3N1bV8lN0JpJTNEMSU3RCU1RSU3Qm4lN0QlN0IlNUNmcmFjJTdCJTVDcGFydGlhbCtoJTdEJTdCJTVDcGFydGlhbCt5XyU3QmklN0QlN0QlNUNuYWJsYSt5XyU3QmklN0QlNUNsZWZ0JTI4eCU1Q3JpZ2h0JTI5JTdELislNUNlbmQlN0JhbGlnbiU3RCs%3D)
这个 Forward Mode 就是用 chain rule,像剥洋葱一样一层一层算出来
以
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rZiU1Q2xlZnQlMjh4JTVDcmlnaHQlMjkrJTNEK3grLSslNUNleHAlNUNsZWZ0JTI4LTIlNUNzaW4lNUUlN0IyJTdEJTVDbGVmdCUyODR4JTVDcmlnaHQlMjklNUNyaWdodCUyOS4rKw%3D%3D)
为例。 我们可以把他的计算图画出来。

假如我要 计算
,可以根据上面的图得到一个表格

那么上面这个表里,每一步我们既要算 forward 的值
,也要算 backward 的值
。
有没有办法同时把这两个值算出来呢?
首先引入一个新的概念,二元数。二元数其实跟复数差不多,也是一种实数的推广。我们回忆一下,一个复数可以写成这样的形式:
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rJTVDYmVnaW4lN0JhbGlnbiU3RCt6KyUzRCthKyUyQitpYislNUMraSU1RTIlM0QtMSslNUNlbmQlN0JhbGlnbiU3RCs%3D)
对于复数的理解,一个比较直观的例子就是。本来实数都是在一个实数轴(x轴)的。复部
相当于多了一个 y 轴出来。
那么二元数是这个亚子,
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEK3orJTNEK2ErJTJCKyU1Q2Vwc2lsb24rYislNUMrJTVDZXBzaWxvbiU1RTIlM0QwKyU1Q2VuZCU3QmFsaWduJTdEKw%3D%3D)
这个二元数很神奇的一个性质是,你带着他做运算,得出来的二元部
前面的系数,就是导数。举个栗子, 我们要求
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD15JTNEJTVDc2luJTI4eCUyOQ%3D%3D)
我们可以把
,所以
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rJTVDYmVnaW4lN0JhbGlnbiU3RCsrKyt5KyUyNislM0QrJTVDc2luJTVDbGVmdCUyOGErJTJCKyU1Q2Vwc2lsb24rYiU1Q3JpZ2h0JTI5KyslNUMlNUMrKysrKyUyNislM0QrJTVDc2luJTVDbGVmdCUyOGElNUNyaWdodCUyOSU1Q2NvcyU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOSslMkIrJTVDY29zJTVDbGVmdCUyOGElNUNyaWdodCUyOSU1Q3NpbiU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOS4rKyU1Q2VuZCU3QmFsaWduJTdEKw%3D%3D)
我们把上面的三角函数展开,
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rJTVDYmVnaW4lN0JhbGlnbiU3RCsrKyU1Q3NpbiU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOSslMjYlM0QrJTVDc3VtXyU3Qm4lM0QwJTdEJTVFJTdCJTVDaW5mdHklN0QlN0IlNUNsZWZ0JTI4LTElNUNyaWdodCUyOSU1RSU3Qm4lN0QlNUNkZnJhYyU3QiU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOSU1RSU3QjJuJTJCMSU3RCU3RCU3QiU1Q2xlZnQlMjgybiUyQjElNUNyaWdodCUyOSUyMSU3RCU3RCslM0QrJTVDZXBzaWxvbitiKyUyQislNUNkZnJhYyU3QiU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOSU1RSU3QjMlN0QlN0QlN0IzJTIxJTdEKyUyQislNUNjZG90cyslM0QrJTVDZXBzaWxvbitiKyU1QysrKyU1QyU1QyslNUNjb3MlNUNsZWZ0JTI4JTVDZXBzaWxvbitiJTVDcmlnaHQlMjkrJTI2JTNEKyU1Q3N1bV8lN0JuJTNEMCU3RCU1RSU3QiU1Q2luZnR5JTdEJTdCJTVDbGVmdCUyOC0xJTVDcmlnaHQlMjklNUUlN0JuJTdEJTVDZGZyYWMlN0IlNUNsZWZ0JTI4JTVDZXBzaWxvbitiJTVDcmlnaHQlMjklNUUlN0IybiU3RCU3RCU3QiU1Q2xlZnQlMjgybiU1Q3JpZ2h0JTI5JTIxJTdEJTdEKyUzRCsxKyUyQislNUNkZnJhYyU3QiU1Q2xlZnQlMjglNUNlcHNpbG9uK2IlNUNyaWdodCUyOSU1RSU3QjIlN0QlN0QlN0IyJTdEKyUyQislNUNjZG90cyslM0QrMS4rJTVDZW5kJTdCYWxpZ24lN0Qr)
得到
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNiZWdpbiU3QmFsaWduJTdEKysreSslMjYrJTNEKyU1Q3NpbiU1Q2xlZnQlMjhhJTVDcmlnaHQlMjkrJTJCKyU1Q2NvcyU1Q2xlZnQlMjhhJTVDcmlnaHQlMjkrYislNUNlcHNpbG9uLislNUNlbmQlN0JhbGlnbiU3RCs%3D)
可以看到,二元部
恰好就是原函数
的导数。
Reverse Mode
这个模式就比较简单和直接了。就是说,上面那个表里面,我每次只计算每个“小运算”的梯度(也是是那个图里面的每个节点),最后我再根据 chain rule 把“小运算”们的梯度串起来。其实 forward mode 和 reverse mode 并没有本质的区别,只是说,reverse mode在计算梯度先不考虑 chain rule,最后再用 chain rule 把梯度组起来。而前者则是直接就应用 chain rule 来算梯度。
下面总结一下 reverse mode 的流程:
- 创建计算图
- 计算前向传播的值及每个操作的梯度
- 这里没有
chain rule 的事
- 比如这个操作是乘法 $x_3 = x_1*x_2$,那么我们只需要把
算出来就好了
- 反向计算梯度从最后一个节点(操作)开始:
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0lNUNvdmVybGluZSU3QnglN0RfJTdCTiU3RCslM0QrJTVDZGZyYWMlN0IlNUNwYXJ0aWFsK2YlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCTiU3RCU3RCslM0QrMSslNUMrJTVDKyU1QyslMkNmJTNEeF9O)
- 根据
chain rule 逐层推进
- 假如有多条求导路径,我们要把他们加起来,例如
举个栗子,我们要计算函数
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rZiU1Q2xlZnQlMjh4JTJDeSU1Q3JpZ2h0JTI5KyUzRCt4eSslMkIrJTVDZXhwJTVDbGVmdCUyOHh5JTVDcmlnaHQlMjkr)
在点
的导数
首先还是先把计算图画出来

我们逐层的抽丝剥茧,
![[公式]](/default/index/img?u=aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD0rJTVDYmVnaW4lN0JhbGlnbiU3RCslNUNvdmVybGluZSU3QnglN0QlN0I1JTdEKyUyNiUzRCslNUNkZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreCU3QjUlN0QlN0QrJTNEKzElNUMlNUMrKyU1Q292ZXJsaW5lJTdCeCU3RCU3QjQlN0QrJTI2JTNEKyU1Q2RmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4JTdCNSU3RCU3RCU1Q2RmcmFjJTdCJTVDcGFydGlhbCt4XyU3QjUlN0QlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCNCU3RCU3RCslM0QrMSslNUNjZG90KzErJTNEKzElNUMlNUMrKyU1Q292ZXJsaW5lJTdCeCU3RCU3QjMlN0QrJTI2JTNEKyU1Q2RmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4JTdCNCU3RCU3RCU1Q2RmcmFjJTdCJTVDcGFydGlhbCt4XyU3QjQlN0QlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCMyU3RCU3RCslMkIrJTVDZGZyYWMlN0IlNUNwYXJ0aWFsK2YlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCNSU3RCU3RCU1Q2RmcmFjJTdCJTVDcGFydGlhbCt4XyU3QjUlN0QlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCMyU3RCU3RCUzRCsxKyU1Q2Nkb3QrZSU1RSU3QjIlN0QrJTJCKzElNUNjZG90KzErJTNEKzErJTJCK2UlNUUlN0IyJTdEJTVDJTVDKyU1Q292ZXJsaW5lJTdCeCU3RCU3QjIlN0QrJTI2JTNEKyU1Q2RmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4JTdCMyU3RCU3RCU1Q2RmcmFjJTdCJTVDcGFydGlhbCt4XyU3QjMlN0QlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCMiU3RCU3RCslM0QrJTVDbGVmdCUyODErJTJCK2UlNUUlN0IyJTdEJTVDcmlnaHQlMjl4XyU3QjElN0QrJTNEKysxKyUyQitlJTVFJTdCMiU3RCslM0QrJTVDZGZyYWMlN0IlNUNwYXJ0aWFsK2YlN0QlN0IlNUNwYXJ0aWFsK3klN0QrJTVDJTVDKyU1Q292ZXJsaW5lJTdCeCU3RCU3QjElN0QrJTI2JTNEKyU1Q2RmcmFjJTdCJTVDcGFydGlhbCtmJTdEJTdCJTVDcGFydGlhbCt4JTdCMyU3RCU3RCU1Q2RmcmFjJTdCJTVDcGFydGlhbCt4XyU3QjMlN0QlN0QlN0IlNUNwYXJ0aWFsK3hfJTdCMSU3RCU3RCslM0QrKyU1Q2xlZnQlMjgxKyUyQitlJTVFJTdCMiU3RCU1Q3JpZ2h0JTI5eF8lN0IyJTdEKyUzRCsrMislMkIrMmUlNUUlN0IyJTdEKyUzRCslNUNkZnJhYyU3QiU1Q3BhcnRpYWwrZiU3RCU3QiU1Q3BhcnRpYWwreCU3RCsrJTVDZW5kJTdCYWxpZ24lN0Qr)
总结
- 可以很清楚的看到,在训练人工神经网络时常用的
backpropagation 也是属于 reverse mode 的。
- 假如我们要计算的梯度的函数是
- 如果 n 是相对比较大的话,用
forward 比较省计算
- 如果 m 是相对比较大的话,用
reverse 比较省计算