Momentum Contrast for Unsupervised Visual Representation Learning
期待代码开源
Summary
- 提出了Momentum Contrast(MoCo)的无监督视觉表示学习方法;
- 把对比学习比作查字典过程,把字典当作队列,引入动量更新;
- 对end-to-end、Memory Bank、MoCo三种对比学习方式进行了比较;
- 在ImageNet和Instagram 数据集上进行大规模的训练和对比实验,并将训练后的特征迁移至下游任务进行实验。
Problem Statement
- 无监督表示学习在NLP领域中取得巨大的成功如GPT、BERT,但有监督预训练仍是计算机视觉主流方法。
Methods
1.Contrastive Learning as Dictionary Look-up
Contrastive learning 即对比学习,可以将其视为训练一个编码器进行字典查询的任务。
假设有一个编码的查询q以及一组编码的样本(字典中的键值),假设q和一个字典里单独的键值匹配,Contrastive loss 即对比损失,它的取值低表示q和键值(positive key)相似而和字典中其它所有键值(negative keys)不相似。
定义一种对比损失函数InfoNCE,形式为:
用来控制concentration level of distribution。
对比损失作为无监督的目标函数用来训练编码器网络来表示查询(queries)和键值(keys),
查询表示为,为一个编码器网络,为查询样本,同样有,输入的具体形式由特定的任务决定。
2.Momentum Contrast
从上述分析来看,对比学习是一种在高维连续输入(如图片)中建立离散字典的方法,字典是动态的,键值是随机采样得到的,并且key encoder在训练中进行更新。假设好的特征可以通过包含大量negative样本的字典中学习而来,并且key encoder能够在更新中尽可能保持一致,基于这种思想作者提出了MoCo算法。
Dictionary as a queue. 方法的核心是将词典保持为数据样本队列。这样可以重新利用当前mini-batch中已编码的键值。同时队列能够将字典大小和mini-batch大小进行解耦,字典大小可以远远大于mini-batch的大小,可被当作超参数。由于mini-batch遵循先进先出的准则,字典总是表示一个所有数据的子集。
Momentum update. 使用队列可以扩充字典的大小,但是对键值编码器key encoder进行反向传播变得更难(梯度会在队列中的所有数据进行传播)。而简单地将query encoder 直接复制给key encoder ,这样快速地改变key encoder会破坏键值表示的一致性。于是作者提出动量更新方法:
只有通过反向传播更新,的变换更加平滑,这样一来,尽管队列中的键值被不同的编码器进行编码,但是这些编码器的差别很小,在实验中,大的动量(例如0.999)往往效果好于小的动量(例如0.9),意味着缓慢变化的key encoder是利用好队列的关键所在。
Relations to previous mechanisms.
以上三种方法的不同之处在于对键值的保持方式以及键值编码器的更新方法的不同。
a方法,字典大小和mini-batch大小相同,受限于GPU显存,对大的mini-batch进行优化也是挑战,有些pretexts进行了一些调整,能够使用更大的字典,但是这样不方便进行迁移使用。
b方法,Memory Bank包含数据集所有数据的特征表示,从Memory Bank中采样数据不需要进行反向传播,所以能支持比较大的字典,然而一个样本的特征表示只在它出现时才在Memory Bank更新,因此具有更少的一致性,而且它的更新只是进行特征表示的更新,不涉及encoder。
3.Pretext Task
将一对查询query和以及键值key组成样本对,如果它们出自同一图像,那么是正样本对,否则为负样本对。查询和键值分别编码自和。在随机数据增强下从同一图像中任意提取两个"view"构建正样本对,负样本取自队列。
Technical details. 使用ResNet作为编码器,最后一层输出为128D向量,即查询query和键值key的表示。
Shuffling BN. 在实验中发现Batch Norm会阻止模型学到良好的特征表示。模型似乎会欺骗pretext task并容易找到低损失的解决方案。可能是因为由BN导致的intra-batch communication among samples泄露了信息。
作者通过Shuffling BN来解决该问题。在训练时使用多个GPU,在每个GPU上分别进行BN(常规操作),对于键值编码器,在当前mini-batch中打乱样本的顺序,再把它们送到GPU上分别进行BN,然后再恢复样本的顺序;对于查询编码器,不改变样本的顺序。这能够保证用于计算查询和其正键值的批统计值出自两个不同的子集。
Experiment
ImageNet-1M (IN-1M)、Instagram-1B (IG-1B)10亿图片数据集。
训练:使用ResNet-50,SGD优化器,weight_decay=0.0001,momentum=0.9
对于(1N-1M):mini-batch size=256,8GPUs,初始学习率为0.03,训练200epochs,在120~160epoch时将学习率乘以0.1,花费72小时;
对于(1G-1B):mini-batch size=1024,64GPUs,初始学习率为0.12,指数衰减(每62.5k iterations乘以0.9)训练1.25M iterations,花费6天。
上图为end-to-end、memory bank、MoCo三种对比损失方法在ImageNet线性分类评价下的对比结果。
上图为ImageNet上,MoCo和其它方法在线性分类评价下的对比结果。
上图为PASCAL VOC trainval07+12上进行微调的目标检测结果。
上图为end-to-end、memory bank、MoCo三种对比损失方法在PASCAL VOC目标检测的结果。
上图为MoCo与之前方法在PASCAL VOC trainval2007上微调的目标检测结果对比。
上图为在COCO上微调的目标检测和实例分割结果。
上图为MoCo和ImageNet上有监督预训练并微调的任务的对比。
相关工作
-
Pretext tasks. The term “pretext” implies that the task being solved is not of genuine interest, but is solved only for the true purpose of learning a good data representation.
术语“pretext”表示要解决的任务不是真正意义上的,而是仅出于学习良好数据表示的真正目的。
-
loss functions. 常被用来研究pretext task的独立性,包括但不限于
-
Contrastive losses [1] measure the similarities of sample pairs in a representation space
-
Adversarial losses [2] measure the difference between probability distributions,GAN和NCE有着相关联系
-
部分论文引用情况
-
Dimensionality reduction by learning an invariant mapping. In CVPR, 2006.[1] (Contrastive Learning)
-
Generative adversarial nets. In NIPS, 2014.[2]
-
Unsupervised feature learning via non-parametric instance discrimination. In CVPR, 2018. [3]
…
Notes
-
本文受[3]的启发挺多的。
-
[3]的作者在Improving Generalization via Scalable Neighborhood Component Analysis这篇文章里提到了动量更新Memory Bank的方法。
-
作者在文章实验部分"Shuffling BN"中写道"batch norm prevents the model from learning good representations"。