1. 前言

在深度学习中,自监督学习或半监督学习是一个很有意思并且非常吸引人的领域。监督学习的损失函数比较直观和简洁,跟真值有直接的联系。而自监督学习或半监督学习的误差函数的构建比较复杂,需要从原始数据中挖掘出有价值的信息作为损失函数。并且损失函数的设计高度依赖于领域知识。但是这一类学习方法的好处是可以在标注数据较少甚至没有的场合完成网络的训练。对比学习(Contrastive Learning)是近年来在自监督学习中比较火热的领域。相关知识可以参考这篇知乎。这篇博文主要去简单分析一篇对比学习的论文“A Simple Framework for Contrastive Learning of Visual Representations”。这篇文章用对比学习做图像目标分类。在讲解这篇文章的时候,我们以图像目标分类作为讲解例子。

2. 对比学习

对比学习归属于自监督学习,所以对比学习是没有真值标签的。对比学习的示意图如下所示。以图像目标分类为例子讲解下图。假设xx是一张印有小狗的图像。我们要设计一个网络Net()\text {Net}(\cdot),使得它有图像分类的能力,即z=Net(x)z={Net}(x)zz表示小狗所在的那一类。由上图所示,网络的结构是Net()=g()f()\text {Net}(\cdot)=g(\cdot) \circ f(\cdot)。其中h=f(x)h=f(x)表示图像的特征。

小白科研笔记:理解对比学习
图1:对比学习的示意图

上图可见xi=t(x)x_i=t(x)以及xj=t(x)x_j=t'(x)。以及t,tTt,t' \in TTT表示二维图像数据增强的操作集合。这个集合包括,图像旋转,图像对称,图像噪声,图像剪裁等等操作,示意图如下所示。举个例子,比如tt可以是图像旋转,tt'可以是图像噪声。那么数据增强后的图像xix_ixjx_j对应下图中的(f)(h)。当然,原始图像xx对应下图中的(a)

小白科研笔记:理解对比学习
图2:二维图像数据增强的种种操作

把数据增强后的图像xix_ixjx_j放入网络里面,可以输出zi=Net(xi)z_i=\text{Net}(x_i)zj=Net(xj)z_j=\text{Net}(x_j),表示网络从这些图像中预测的结果。因为对比学习没有真值,那应该怎样设计误差函数,指导ziz_izjz_j趋于正确的分类呢?论文作者认为,误差函数应该设计为ziz_izjz_j的差异度。不管ziz_izjz_j预测成什么,我只希望zi=zjz_i=z_j,除此之外别我他求。要么这两张图都预测为猫,要么这两张图都预测为狗。论文作者认为,只有当zi=zjz_i=z_j的时候,网络才能真正地从两幅数据增强的图像中学到它们之间的通用特征。而通用特征则是目标识别的关键。

肯定会有读者质疑,如果仅仅是以zi=zjz_i=z_j作为误差函数标准,万一训练结果是zi=zj=z_i=z_j=猫,这该怎么办呢?别着急。后面会去讲它(在讲后面对比学习的伪代码的时候会做出交代)!

怎样去设计一个误差函数去衡量ziz_izjz_j的差异度呢?在多分类问题里面,ziz_izjz_j都是指one-hot向量。论文中使用余弦距离/向量积来衡量它们之间的差异,定义si,j=ziTzj/(zizj)s_{i,j}=z_i^Tz_j/(\Vert z_i\Vert \Vert z_j \Vert)。如果ziz_izjz_j越发地相近,那么si,j1s_{i,j}\rightarrow 1

接下来,看一下对比学习计算的伪代码:

小白科研笔记:理解对比学习
看到计算si,js_{i,j}这段代码为止,差不多都能理解。l(i,j)l(i,j)log(Softmax())-\log(\text{Softmax}(\cdot))的变体,也能去理解。在训练过程中,需要使L0L \rightarrow 0,这意味着l(2k1,2k)0l(2k-1,2k) \rightarrow 0以及l(2k,2k1)0l(2k,2k-1) \rightarrow 0。结合l(i,j)l(i,j)的定义,l(2k1,2k)0l(2k-1,2k) \rightarrow 0表示s2k1,2k1s_{2k-1,2k} \rightarrow 1并且s2k1,m0(m2k1,2k)s_{2k-1,m}\rightarrow 0 (m \not= 2k-1,2k)。这说明网络会强制要求:同一张图片经过数据增强变形后得到两张增强图片,必须要从这两张图片中挖掘出共同的特征(误差函数告诉我们:非得是这两张同源图片,其他图片都是不行的)。经过这般学习后,f()f(\cdot)会有泛化能力很强的表征能力。对于一张图片,不管这个图片做了怎样的数据增强处理,f()f(\cdot)都会稳定地提取到这张图片最为本质的特征。

伪代码的最后一行也是个骚操作,保留f()f(\cdot),丢掉g()g(\cdot)。丢掉g()g(\cdot)其实是可以理解的。这对应着前面讲的那个问题:如果仅仅是以zi=zjz_i=z_j作为误差函数标准,万一训练结果是zi=zj=z_i=z_j=猫,这该怎么办呢?因为缺乏真值作为监督,g()g(\cdot)会产生一些奇怪的结果,比如把所有狗的照片识别为猫,所有猫的照片识别为猪,所有猪的照片识别为狗。所以丢掉g()g(\cdot)是正确的。保留f()f(\cdot)的原因前面也已经讲了,f()f(\cdot)有泛化能力很强的表征能力。

注意看这篇论文的标题A Simple Framework for Contrastive Learning of Visual Representations。这篇论文主要的目的是学习提取特征(Learning of Visual Representations),所以说它的核心目的即是获取f()f(\cdot)

3. 结束语

当然,对于一个完整的图片目标分类算法,通过对比学习得到f()f(\cdot)后,可以通过监督学习的方法重新训练g()g(\cdot)。因为f()f(\cdot)拥有强大且稳定的表示特征的能力,那么整体网络Net\text{Net}也是泛化能力强的。这大概是这篇文章的意思。

相关文章:

  • 2021-12-30
  • 2021-09-12
  • 2021-11-11
  • 2021-12-24
  • 2021-05-01
  • 2021-07-01
  • 2021-11-07
  • 2021-12-17
猜你喜欢
  • 2021-10-28
  • 2021-08-04
  • 2022-01-10
  • 2021-06-01
  • 2021-11-11
  • 2022-01-18
  • 2021-04-07
相关资源
相似解决方案