前言
在采用度量学习处理小样本学习问题中,通过学习查询样本(query)与支持样本(support)之间的特征相似性比较,来确定query的类别。但是这种方法有一个缺点,就是query在与支持集中的样本进行比较时,是一个类一个类的进行比较的,也就是query与支持集中某个类的样本比较完,再与另一个类进行比较,没有考虑整体的关系。这样造成的后果是,不能辨别哪个维度的特征与当前任务是最相关的。
本文对基于度量学习的方法做出了改进,将整个支持集的上下文信息都考虑进来,也即将整个支持集视为一个整体。通过这种改进,可以找到与每个任务最相关的特征维度。通过下图来对这种改进做一个说明:
图(a)是5-way 1-shot的小样本学习任务,其中只有两个简单的特征维度:颜色和形状,最终的目标是确定query属于支持集中的哪个类。可以看到query是一个绿色的圆形,由于支持集中每个类的样本的颜色都是不同的,而有些样本的形状是相同的,所以在该任务中最相关的特征就是颜色,因此query所属的类应该是(iii)。但是如果在计算query的特征相似性时没有考虑整个支持集,而是一个类一个类单独地看,那么query与类(i)(iii)(iv)(v)就会有相同的相似度,准确率就会降低。只有在考虑整个支持集的上下文信息的情况下,我们才发现颜色是最相关的特征。这个例子就启发我们要遍历支持集中的所有类别以找到具有类间唯一性(inter-class uniqueness)的特征。
图(b)是具有multi-shot的小样本学习任务,可以看到在类(ii)中,大多数样本的颜色都是相同的,而它们的形状各有不同,因此在类(ii)中,颜色是最相关的特征,也即具有类内通用性(intra-class commonality)的特征。
从以上两个例子可以看出,利用类间唯一性和类内通用性,可以找出与某任务最相关的特征。在一个类中,通过对特征向量取平均,可以缓解类内样本间的不同,从而得到类内的共享特征,即具有类内通用性的特征。
为了结合类间唯一性和类内通用性,本文提出一个类别遍历模块(category traversal module,CTM),它能够在遍历类间和类内之后,选出最具有相关性的特征。CTM主要由两部分组成:
- concentrator unit,用于提取类内具有通用性的embedding;
- projector unit,通过考虑concentrator unit的输出,得到类间具有唯一性的embedding。
下图描述了CTM如何应用到现有的基于度量学习的小样本学习算法中,它可以被看作是一种即插即用的模块:
CTM的实现
CTM模块将支持集的特征作为输入,然后通过concentrator和projector生成一个mask ,这个mask 被应用于支持集和查询集的降维特征,生成与当前任务相关的维度的改进特征,最终这些改进的特征被送入一个度量学习器(metric learner)中。模型的整体结构如下图:
1. concentrator
concentrator的目标是找到一个类中的所有样本共享的通用特征,也即具有类内通用性的特征。将特征提取器的输出表示为,其中表示通道的数量,表示空间大小,那么concentrator可以被定义为:
其中表示输出的通道数量,表示输出的空间大小。其实输入首先被送入一个CNN模块中进行降维,然后对每个类内的样本取平均,从而获得最终输出。在1-shot情况下就不用进行平均操作了,因为每个类中只有一个样本。
这个CNN模块可以是一个简单的CNN层,也可以是一个ResNet块,降维的目的是移除样本之间的差异,从而能够在类内提取出具有通用性的特征。也就是说,从,到,就是一种降维,也就是适当程度的下采样,在经过降维之后,再对类内样本取平均,以得到最终输出。实验证明这种方法比原型网络中直接对样本取平均的方法要好,而当时,原型网络中的取平均方案也可以看作是concentrator的一种特殊情况。
2. projector
projector的目的是屏蔽掉其它不相关的特征,然后选择一个对于当前任务来说最具有判别性的特征。在给定由concentrator输出的特征的情况下,projector同时考虑支持集中所有类别的concentrator特征,然后得到输出:
其中是的一个reshaped version,,分别是输出的通道数量和空间大小。为了实现跨类遍历的目标,首先将中的类别原型连结到,然后在连结后的特征上应用一个CNN来生成一个大小为的映射,最终在上应用softmax来生成一个mask ,这个用于对与该任务最相关的特征维度进行mask。
3. reshaper
为了让projector的输出能够影响feature embedding ,需要将网络中这些模块之间的shape进行匹配。在每个样本上应用一个reshaper网络:
4. 改进的特征的生成
在concentrator和projector的支持下,经过跨类遍历可以生成一个mask:,然后要将这个与支持集和查询集的feature embedding结合起来,生成改进的特征,才能实现CTM的效果。而这种结合,也就是的生成方式,在支持集和查询集中是不同的。
在查询集中,因为query没有类别标签,因此这种结合只是对embedding和进行简单的元素乘法(element-wise multiplication):;
在支持集中,由于我们知道query的类别标签,因此要么直接将 mask到embedding上,要么保持,然后用对concentrator的输出进行mask:
实验证明option 1的效果更好。