Meta-Learning with Implicit Gradients -- NeurIPS 2019

Core Idea

A major challenge of the original MAML algorithm is that the outer loop (meta-update) must differentiate through the inner loop (gradient-based adaptation), which in general requires storing and computing higher-order derivatives. The core of this paper is to use implicit differentiation: the meta-gradient depends only on the *solution* of the inner-loop optimization, not on the entire optimization trajectory of the inner-loop optimizer.
Benefits: ① the meta-gradient computation (outer loop) is decoupled from the choice of inner-loop optimizer, so any inner optimizer can be used; ② multi-step inner-loop gradients no longer suffer from vanishing gradients or memory constraints.
[Figure: meta-gradient computation in MAML vs. first-order MAML vs. iMAML]
As the figure above shows, MAML computes the meta-gradient by differentiating through the inner-loop optimization path; first-order MAML simply approximates $\frac{d\phi_i}{d\theta}$ by $I$; iMAML instead derives an exact analytic expression for the meta-gradient by estimating local curvature (expressing the meta-gradient in terms of the inner-loop *solution* rather than derivatives of the solution process), with no need to differentiate through the optimization path.
The benefits: there is no need to store or differentiate through the optimization path, so many gradient steps can be taken efficiently in the inner loop; and the whole method is agnostic to the inner optimizer, requiring only an approximate solution of the inner-loop problem. This allows higher-order or even non-differentiable inner optimization methods.

Few-shot case formula

$$\overbrace{\theta_{\mathrm{ML}}^{*}:=\underset{\theta \in \Theta}{\operatorname{argmin}}\, F(\theta)}^{\text{outer-level}}, \quad \text{where } F(\theta)=\frac{1}{M} \sum_{i=1}^{M} \mathcal{L}\Big(\overbrace{\mathcal{A}lg\big(\theta, \mathcal{D}_{i}^{\mathrm{tr}}\big)}^{\text{inner-level}}, \mathcal{D}_{i}^{\mathrm{test}}\Big)$$

In this formula, $\mathcal{A}lg$ denotes the inner-loop algorithm, whose output is the task-adapted optimized parameters. To prevent overfitting, a regularization term can be added to the inner loop:
$$\mathcal{A}lg^{\star}\big(\theta, \mathcal{D}_{i}^{\mathrm{tr}}\big)=\underset{\phi^{\prime} \in \Phi}{\operatorname{argmin}}\; \mathcal{L}\big(\phi^{\prime}, \mathcal{D}_{i}^{\mathrm{tr}}\big)+\frac{\lambda}{2}\left\|\phi^{\prime}-\theta\right\|^{2}$$
Here $\theta$ is the meta-parameter we seek (i.e., the model initialization); it is treated as a constant during the inner loop and updated by gradient steps in the outer loop. The actual optimization variable of the inner loop is the adapted parameter $\phi^{\prime}$. The $\star$ denotes the exact solution; in practice, iterative gradient methods return only an approximate optimum. The bi-level optimization problem can then be rewritten as:
$$\theta_{\mathrm{ML}}^{*}:=\underset{\theta \in \Theta}{\operatorname{argmin}}\, F(\theta), \quad \text{where } F(\theta)=\frac{1}{M} \sum_{i=1}^{M} \mathcal{L}_{i}\big(\mathcal{A}lg_{i}^{\star}(\theta)\big),$$

$$\mathcal{A}lg_{i}^{\star}(\theta):=\underset{\phi^{\prime} \in \Phi}{\operatorname{argmin}}\, G_{i}\big(\phi^{\prime}, \theta\big), \quad \text{where } G_{i}\big(\phi^{\prime}, \theta\big)=\hat{\mathcal{L}}_{i}\big(\phi^{\prime}\big)+\frac{\lambda}{2}\left\|\phi^{\prime}-\theta\right\|^{2}$$

where

$$\mathcal{L}_{i}(\phi):=\mathcal{L}\big(\phi, \mathcal{D}_{i}^{\mathrm{test}}\big), \quad \hat{\mathcal{L}}_{i}(\phi):=\mathcal{L}\big(\phi, \mathcal{D}_{i}^{\mathrm{tr}}\big), \quad \mathcal{A}lg_{i}(\theta):=\mathcal{A}lg\big(\theta, \mathcal{D}_{i}^{\mathrm{tr}}\big)$$

Let $d$ and $\nabla$ denote the total derivative and the partial derivative (gradient), respectively. By the chain rule, the meta-gradient can be written as:
$$d_{\theta} \mathcal{L}_{i}\big(\mathcal{A}lg_{i}(\theta)\big)=\frac{d\,\mathcal{A}lg_{i}(\theta)}{d\theta}\, \nabla_{\phi} \mathcal{L}_{i}(\phi)\Big|_{\phi=\mathcal{A}lg_{i}(\theta)}=\frac{d\,\mathcal{A}lg_{i}(\theta)}{d\theta}\, \nabla_{\phi} \mathcal{L}_{i}\big(\mathcal{A}lg_{i}(\theta)\big)$$
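As a sanity check, the regularized inner-loop problem above can be sketched on a toy quadratic training loss, with plain gradient descent as the inner optimizer. All names here ($A$, $b$, `inner_loop`, the step counts) are illustrative choices, not from the paper's code:

```python
import numpy as np

# Toy inner loss for one task: L_hat(phi) = 0.5 * ||A phi - b||^2,
# so its gradient is A^T (A phi - b).
rng = np.random.default_rng(0)
A = rng.normal(size=(10, 5))
b = rng.normal(size=10)
lam = 2.0  # regularization strength lambda

def loss_grad(phi):
    return A.T @ (A @ phi - b)

def inner_loop(theta, steps=2000, lr=0.01):
    """Approximate Alg*(theta) = argmin_phi L_hat(phi) + lam/2 ||phi - theta||^2
    by plain gradient descent (any inner optimizer would do)."""
    phi = theta.copy()
    for _ in range(steps):
        phi -= lr * (loss_grad(phi) + lam * (phi - theta))
    return phi

theta = rng.normal(size=5)
phi_star = inner_loop(theta)

# At the optimum, the stationarity condition phi = theta - (1/lam) grad L_hat(phi)
# holds, so this residual should be near zero.
residual = phi_star - (theta - loss_grad(phi_star) / lam)
print(np.linalg.norm(residual))
```

Note that only the final `phi_star` is kept; nothing about the 2000 intermediate iterates needs to be stored, which is exactly what makes the implicit approach memory-efficient.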

Implicit MAML Algorithm

In the expression above, $\nabla_{\phi} \mathcal{L}_{i}(\mathcal{A}lg_{i}(\theta))=\nabla_{\phi} \mathcal{L}_{i}(\phi)\big|_{\phi=\mathcal{A}lg_{i}(\theta)}$ is easy to compute once $\mathcal{A}lg_{i}^{\star}(\theta)$ has been obtained (by gradient descent or any other optimizer). Computing $\frac{d\,\mathcal{A}lg_{i}(\theta)}{d\theta}$ is harder: differentiating directly through the update rule involves higher-order derivatives and requires recording the entire update trajectory. Instead, the result of the inner (adaptation) loop, $\phi_i = \mathcal{A}lg_{i}^{\star}$, is defined implicitly as the solution of the inner optimization problem. Its Jacobian can then be computed with no reference to the optimization path (Lemma 1):
$$\frac{d\,\mathcal{A}lg_{i}^{\star}(\theta)}{d\theta}=\left(I+\frac{1}{\lambda} \nabla_{\phi}^{2} \hat{\mathcal{L}}_{i}(\phi_{i})\right)^{-1}$$
Proof: since $\phi_i = \mathcal{A}lg_{i}^{\star}$ minimizes $G_{i}\left(\phi^{\prime}, \theta\right)=\hat{\mathcal{L}}_{i}\left(\phi^{\prime}\right)+\frac{\lambda}{2}\left\|\phi^{\prime}-\theta\right\|^{2}$, it satisfies the first-order necessary condition, i.e., the gradient vanishes:
$$\nabla_{\phi^{\prime}} G\big(\phi^{\prime}, \theta\big)\Big|_{\phi^{\prime}=\phi_{i}}=0 \;\Longrightarrow\; \nabla \hat{\mathcal{L}}(\phi_{i})+\lambda(\phi_{i}-\theta)=0 \;\Longrightarrow\; \phi_{i}=\theta-\frac{1}{\lambda} \nabla \hat{\mathcal{L}}(\phi_{i})$$

This is a standard implicit equation. When the derivative exists, differentiating both sides with respect to $\theta$ gives:
$$\frac{d\phi_{i}}{d\theta}=I-\frac{1}{\lambda} \nabla^{2} \hat{\mathcal{L}}(\phi_{i})\, \frac{d\phi_{i}}{d\theta} \;\Longrightarrow\; \left(I+\frac{1}{\lambda} \nabla^{2} \hat{\mathcal{L}}(\phi_{i})\right) \frac{d\phi_{i}}{d\theta}=I$$
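Lemma 1 can be verified numerically on a toy quadratic inner loss, where the inner solution has a closed form. The setup below is an illustrative sketch, not the paper's code:

```python
import numpy as np

# Toy inner loss: L_hat(phi) = 0.5 * ||A phi - b||^2, whose Hessian is the
# constant matrix A^T A.
rng = np.random.default_rng(1)
A = rng.normal(size=(10, 5))
b = rng.normal(size=10)
lam = 2.0
H = A.T @ A  # Hessian of L_hat

def alg_star(theta):
    # Closed-form inner solution: (A^T A + lam I) phi = A^T b + lam theta
    return np.linalg.solve(H + lam * np.eye(5), A.T @ b + lam * theta)

theta = rng.normal(size=5)

# Implicit Jacobian from Lemma 1: (I + (1/lam) H)^{-1}
J_implicit = np.linalg.inv(np.eye(5) + H / lam)

# Finite-difference Jacobian of the inner-solution map theta -> phi_i
eps = 1e-6
J_fd = np.column_stack([
    (alg_star(theta + eps * e) - alg_star(theta - eps * e)) / (2 * eps)
    for e in np.eye(5)
])

# The two Jacobians should agree to numerical precision.
print(np.max(np.abs(J_implicit - J_fd)))
```

The key point is that `J_implicit` is computed purely from the curvature at the solution point, with no knowledge of how `alg_star` was solved.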

Practical Algorithm

Evaluating $\frac{d\phi_{i}}{d\theta}=\frac{d\,\mathcal{A}lg_{i}^{\star}(\theta)}{d\theta}$ above poses two practical difficulties. First, $\mathcal{A}lg_{i}^{\star}$ is the exact solution, while inner-loop optimization typically yields only an approximate one. Second, the computation involves a matrix inverse and second derivatives, which are expensive for deep neural networks. The paper therefore simplifies the computation via approximation; the key accuracy condition is:
$$\left\|g_{i}-\left(I+\frac{1}{\lambda} \nabla_{\phi}^{2} \hat{\mathcal{L}}_{i}\left(\phi_{i}\right)\right)^{-1} \nabla_{\phi} \mathcal{L}_{i}\left(\phi_{i}\right)\right\| \leq \delta^{\prime}$$

where $g_{i}$ is the approximation of the meta-gradient $d_{\theta} \mathcal{L}_{i}(\mathcal{A}lg_{i}(\theta))$, and $\phi_{i}$ is an approximation of the optimum $\mathcal{A}lg_{i}^{\star}$, obtained by iterative gradient-based optimization. Finding such a $g_{i}$ can then be cast as the quadratic optimization problem:
$$\min_{w}\; \frac{1}{2} w^{\top}\left(I+\frac{1}{\lambda} \nabla_{\phi}^{2} \hat{\mathcal{L}}_{i}\left(\phi_{i}\right)\right) w-w^{\top} \nabla_{\phi} \mathcal{L}_{i}\left(\phi_{i}\right)$$

which can be solved efficiently by the conjugate gradient method. The procedure only needs Hessian-vector products $\nabla^{2} \hat{\mathcal{L}}_{i}(\phi_{i})\, v$ (where $v$ is a conjugate direction), never the explicit Hessian or its inverse.
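A minimal sketch of this CG-based meta-gradient computation on a toy quadratic loss. All variable names are illustrative; in practice the Hessian-vector product would come from automatic differentiation rather than an explicit matrix:

```python
import numpy as np

# Toy setup: train loss L_hat(phi) = 0.5 * ||A phi - b||^2 with Hessian A^T A.
rng = np.random.default_rng(2)
n = 5
A = rng.normal(size=(10, n))
b = rng.normal(size=10)
lam = 2.0
phi = rng.normal(size=n)      # stands in for the (approximate) inner solution
grad_test = rng.normal(size=n)  # stands in for grad_phi L_i(phi) on the test set

def hvp(v):
    # Hessian-vector product of L_hat at phi. Here the Hessian happens to be
    # A^T A; with autodiff this would be computed without forming any matrix.
    return A.T @ (A @ v)

def meta_gradient_cg(rhs, iters=50, tol=1e-12):
    """Solve (I + (1/lam) H) g = rhs by conjugate gradient, using only hvp."""
    matvec = lambda v: v + hvp(v) / lam
    g = np.zeros_like(rhs)
    r = rhs - matvec(g)   # initial residual
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Hp = matvec(p)
        alpha = rs / (p @ Hp)
        g += alpha * p
        r -= alpha * Hp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return g

g = meta_gradient_cg(grad_test)
# Check against an explicit solve (feasible only for tiny problems):
g_exact = np.linalg.solve(np.eye(n) + (A.T @ A) / lam, grad_test)
print(np.max(np.abs(g - g_exact)))
```

For an $n$-dimensional problem, CG converges in at most $n$ iterations in exact arithmetic, and in deep-learning practice a handful of iterations already gives a meta-gradient estimate within the $\delta^{\prime}$ tolerance above.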
