论文标题:
Thomas Bachlechner, Bodhisattwa Prasad Majumder, Huanru Henry Mao, Garrison W. Cottrell, Julian McAuley
论文作者:
Thomas Bachlechner, Bodhisattwa Prasad Majumder, Huanru Henry Mao, Garrison W. Cottrell, Julian McAuley
论文链接:
https://arxiv.org/abs/2003.04887
代码链接:
https://github.com/majumderb/rezero
在NLP领域,如Transformer的深度模型由于梯度消失/爆炸难以训练,往往需要花费很长的时间才能收敛。
为此,本文提出ReZero:在残差连接前增加一个权重,使模型能够更好接受到梯度信号,加快收敛速度。
这种方法能在上百层的Transformer上收敛,并在常见深层模型上大大缩短训练时间,同时取得相近的结果。
深度模型的训练问题
一般来说,模型越深效果越好。但是同时,模型越深也更难训练——即无法收敛到训练集上一个好的值。
普遍认为,这是深度模型的梯度消失/梯度爆炸现象导致的:梯度的指数级爆炸使得训练极其不稳定,指数级消失使得训练非常缓慢。
拿Transformer来说,模型越深,底层收到的梯度就越低,如下图所示(来自Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention):
比如看图中红色曲线,在高层的时候,梯度的范数尚且较大,但当来到底层的时候,梯度的范数就非常小,特别是在Decoder端,几乎趋于零。这说明了深层模型仍然有比较严重的梯度消失问题。
那么,这个问题的解决方案有什么呢?在计算机视觉上,Normalization和残差连接有很大帮助,但在NLP领域,它们的作用就相当有限,因为Transformer本身就内置了残差连接和Normalization技术。
为了能够更好地解决NLP模型上的训练问题,本文提出了一种非常简单的方法:ReZero——对每个残差连接加上一个可学习的系数——使得模型能在训练初期更加稳定,进而促进整个训练过程。
本方法在12层的Transformer上取得约56%的训练加速,在32层前馈网络上取得15倍的加速,在ResNet54上实现了32%的加速。
更重要的是,该方法能够直接在128层的Transformer上取得收敛,同时还能丢掉Transformer中的Normalization和训练时的warm-up环节。
在阅读完本文后,读者可以思考如下问题:
ReZero和Highway的区别;
和Transformer中warmup训练的联系。
ReZero:加权的残差连接
首先来简要回顾一下残差连接。对一个层的模型,第
层的输入(也即第
层的输出)可以表示为:
而ReZero做的很简单,它只是在残差连接前加了一个可学习的参数:
并且所有的都初始为0。我们通过一个简单的实验说明并分析该方法的有效性。
设现在有个模型有层,每一层只有一个神经元,而且所有层都共享这一个神经元
,那么,模型的输出就可以表示为:
显然,当时就是残差网络,当初始化
时就是ReZero。此时,输入对输出的Jacobian矩阵就是
。
此时,如果初始化且
,那么就趋近
,这就会导致输入很小的扰动造成输出的巨大变化,从而梯度非常不稳定。
而当时,这种不稳定就会大大缓解。下面我们用反向传播加以解释。设学习率为
,损失函数为
,那么参数
的更新就是:
在这个式子里,我们主要看学习率,参数
和
。当
,上式右边第二项(忽略损失项和输入)成为
,这个时候为了让训练更加稳定,学习率就需要正比于
。
但是对于过大的和接近-1的参数
,单纯依靠学习率的手动调整就难以解决这个问题了。
当初始化的时候,上述问题就可以得到解决。因为初始化
,从而第一轮的梯度更新没有更新参数
,但是
是可以更新的:
这样,在下一轮梯度更新的时候,由于此时不为0,则模型参数
就可以更新。把当前的
带入到上面更新
的式子中,我们有:
比较两式,其最大的不同在于括号内的损失函数。换言之,如果损失函数是合理的,那么当前的更新就不会导致过分的梯度波动(
),每次梯度下降使得模型参数既可以更新,又不会不稳定。
下图是模型拟合的对数等高线示意图,左图是使用二次损失的损失图,右图是对应的梯度范数。红线是初始化
和不同的
初始化在训练中的变化轨迹。我们注意到以下事实:
当以
初始化时,不同的
初始化最终都能达到很好的损失(红色)和梯度范数(白色);
观察右图,当以
,
初始化时,梯度非常小,意味着梯度消失,而只有在
这一条线上梯度范数比较好,然而此时模型根本无法收敛(左图);
这说明,模型既要保证平稳训练(梯度不能太大或太小),也要保证最终收敛(取得比较小的损失)。
实验与分析
训练全连接模型
首先我们在简单的全连接模型上实验。该全连接模型**函数为ReLU,层数为32,数据集为CIFAR-10,比较的模型有:
FC:单纯的全连接
FC+Res:全连接和残差结合
FC+Norm:全连接和normalization
FC+Res+Pre-Norm: 全连接和残差和pre-normalization
FC+Res+Post-Norm:全连接和残差和post-normalization
FC+ReZero:
注意到,ReZero没有使用Normalization。下图是各个模型的收敛情况:
可以看到,ReZero比其他三个模型能够快7~15倍收敛,而且,当我们把全连接层加到10K层,ReZero依然可以收敛,而其他模型却无法收敛。
训练Transformer
我们的重点是训练Transformer,一个被普遍认为非常难以训练的模型。我们把Transformer的结构修正为:
同样,我们丢掉了Transformer本来的Normalization,而且也不需要warm-up。下图呈现了Transformer输出-输入的Jacobian矩阵中的奇异值在不同层的模型下的分布。
先看左图,当层数较少(4)时,保持在
附近(图中取了对数),这就表明输入的变化既不会造成输出的太大变化,也能使得模型得以训练。
而当层数增加,尤其是增加到64层时,该值非常小,这说明无论输入怎么改变,模型的输出都八九不离十,模型难以训练。
再看右图是使用ReZero对64层的Transformer训练的情况。在一开始的时候,由于被初始化为0,模型参数无法更新,值
都是1,而当模型开始训练后,
也能保持以1为中心的分布,使得模型能够训练下去。
下面来具体看ReZero的速度提升和最终的效果。下表是12层的Transformer使用各方法在enwiki8数据集上的收敛情况。可以看到,ReZero只需要8800次迭代,而原始Transformer需要13690次迭代,速度方面加快了56%。
下表是各方法在enwiki8测试集上的结果。128层的Transformer能够在ReZero下收敛,并且取得了和Char-level的Transformer非常相似的结果,而参数量只有后者的一半。
最后我们来可视化64层Transformer在训练过程中,各层的系数参数的变化情况,如下图所示。
由于我们初始化,故在训练的开始阶段所有层的系数都比较小,这可以使得各层训练稳定。
当模型训练到一定阶段之后,输入的扰动不会造成输出的剧烈改变,系数开始增大,这是模型加速训练、训练的阶段。
在此之后,系数再一次减小,使得模型训练进入第三个阶段,直到最终收敛。
训练ResNet
最后,我们也将ReZero应用到了ResNet56上。在CIFAR-10上的结果显示,使用ReZero可以降低基线模型的Error rate,同时也加快了约30%的收敛速度。
小结
本文提出了一个非常简单但是非常有效的残差网络变体——ReZero。和残差网络不同,它是在残差连接前增加一个可学习的系数,并初始化为0,使得模型在训练初期更加稳定。
实验表明,运用这种方法可以直接在非常深的Transformer上收敛,而且丢掉了原有的Normalization和warmup。如果它能经得住大量实验的检验,那么ReZero非常有希望成为今后Transformer的标配。
思考讨论
-
ReZero和Highway的区别;Highway的公式是
,其中
表示各部分前馈的信息量。
在开始的时候,
初始化为负值,这时候,
就趋近于0,也就是上式趋近于
,这实际上是和ReZero的想法是一致的——在训练开始的时候保持稳定有助于深层模型的收敛。
不同的是,ReZero更加简单粗暴——直接使用一个可学习的系数而不是所谓的门控,降低了模型学习的难度,简化了复杂度,从而更有利于模型控制训练过程。
-
和Transformer中warmup训练的联系。在我们的回答神经网络中 warmup 策略为什么有效;有什么理论解释么?中,我们说:“刚开始模型对数据的“分布”理解为零,在第一轮训练的时候,每个数据点对模型来说都是新的,模型会很快地进行数据分布修正,如果这时候学习率就很大,极有可能导致开始的时候就对该数据“过拟合”,后面要通过多轮训练才能拉回来,浪费时间。
当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据点看过几遍了,或者说对当前的batch而言有了一些正确的先验,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。
这个过程就可以看做是warmup。那么为什么之后还要decay呢?当模型训到一定阶段后(比如十个epoch),模型的分布就已经比较固定了,或者说能学到的新东西就比较少了。
如果还沿用较大的学习率,就会破坏这种稳定性,用我们通常的话说,就是已经接近loss的local optimal了,为了靠近这个point,我们就要慢慢来。”
从本文最后的可视化图中可见,我们的解释是完全符合这种直觉的。所以说,ReZero和warmup的作用都是为了训练的稳定,只是一个选择了自动学习,一个选择手动调整。
????
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。