Prototype Rectification for Few-shot Learning 文献阅读
本文是我对《Prototype Rectification for Few-shot Learning》一文的理解,难免有不足之处,欢迎大家多多交流,批评指正~
算法框图
算法步骤
- 首先基于支持集(支持集中样本全部被标注)训练一个feature extractor F_θ (x) 和一个基于cosine similarity 的分类器C(∙│W)。
(注:基于cosine similarity 的分类器并不是简单计算cosine,其中仍有参数待训练,此处W应表示feature extractor中的参数?)。
C(F_θ (x)│W)=Softmax(τ∙Cos(F_θ (x),W))
其中,W表示待学习的参数,τ表示一个标量参数;
特征提取器的目标函数是在监督分类任务上最小化对数似然函数:
L(θ,W│D)=E[-logC(F_θ (x)│W)]
其中,D表示支持集。 - 为了避免由于再次微调导致的过拟合情况(为什么需要再次微调?以及怎么微调?),在训练完成的支持集中直接计算基础类原型(N-way, K-shot)。
P_n=1/K ∑_(i=1)^K▒X ̅_(i,n)
X ̅ 表示支持集的归一化特征。
在此基础上,对查询集中的样本进行分类(通过余弦相似度直接划分到距离最近的基础类原型); - 对基础类原型进行更新,使其不断逼近理想的原型。
文章主要贡献
通过少样本获得的原型与理想的原型之间存在误差。本文分为类内偏差与跨类偏差两大类。
- 类内偏差(intra-class bias):对于每一类,类内偏差表示理想的原型与实际求得的原型之间的差距。
B_intra=E_(X’~p_(X’ ) ) [X^’ ]-E_(X~p_X ) [X]
其中,p_(X^’ )表示理想的样本分布,p_X表示小样本分布。
因为获得p_(X^’ )是不可能的,因此本文采用查询集中分类置信度(使用前面训练好的C(F_θ (x)│W))最高的Z个样本扩充支持集S,进一步更新每一类的原型。
S’=S∪Q_pseudoZ
考虑到在使用伪标记的查询集样本增强支持集时,有一些查询集样本是被错误分类的,为了缓解这种影响,我们使用带有权重的样本和作为原型(不同于前面的直接平均)。
P_n’=∑_(i=1)(Z+K)▒〖ω_(i,n) 〖X ̅^’〗(i,n) 〗
此处,X ̅^'表示增强后的支持集的归一化特征,ω(i,n)表示增强后的查询集样本与基础原型的相关性。
ω_(i,n)=(exp(ε∙Cos(〖X ̅’〗_(i,n),P_n)))/(∑_(j=1)(K+Z)▒〖exp(ε∙Cos(〖X ̅^’〗(j,n),P_n))〗)
如上为ω(i,n)的表达式,也就是分类置信度越高的样本具有更大的权重(通过以上方法获得的P_n^'比P_n更接近于理想原型,文中说的)。 - 跨类偏差(cross-bias):跨类偏差表示支持集与查询集之间的距离。(在本文所讨论的情况下,因为支持集与查询集是满足同种分布的,所以它们之间的距离并不属于域自适应的问题。)
B_cross=E_(X_s~p_S ) [X_s ]-E_(X_q~p_Q ) [X_q ]
这里,p_S和p_Q分别表示支持集与查询集的分布。
为了减小跨类偏差,本文在每一个归一化查询集特征后添加一个转移量ξ,定义如下:
ξ=1/|S| ∑_i^(|S|)▒X ̅_(i,s) -1/|Q| ∑_j^(|Q|)▒X ̅_(j,q)
理论推导
为什么使用伪标记样本是合理的?
对于某一类的基础原型表示:
P=(∑_i^T▒X ̅_i^’ )/T
其中,X表示某一类的特征,P表示某一类的原型,T=K+Z。
目标函数为(即最大化与分类精度正相关的理想余弦相似度,不太理解):
max〖E_P [E_X [Cos(P,X)]]〗
本文通过数学上的推导得到了下式:
E_P [E_X [Cos(P,X)]]≥(E[X ̅ ]∙E[P])/√(E[‖P‖2^2 ] )=(∑(i=1)^D▒〖E[(x_i ) ̅ ]^2 〗)/√(1/T ∑_(i=1)^D▒Var[(x_i ) ̅ ] +∑_(i=1)^D▒〖E[(x_i ) ̅ ]^2 〗)
X ̅_i^‘∈S’,并且S’是从X采样的子集。X ̅与P ̅分别表示归一化后的特征与原型。P与X ̅都是 D维特征矢量,即P=[p_1,p_2,…p_D],X ̅=[x ̅_1,x ̅_2,…x ̅_D]。(在本文的方法中,假设每一维向量相互独立)。
通过上式可以看到,E_P [E_X [Cos(P,X)]]与T成正相关,也就是目标函数max〖E_P [E_X [Cos(P,X)]]〗
的下边界与T成正相关。
因此引入伪标记样本的合理性在于它可以提升预期性能的下限(如下为原文中描述)。
为什么增大Z可以等同于提高T?值得思考。
实验结果
(减小类内偏差带来的提升更显著。)
伪样本数目取8时,效果最好;
T-SNE可视化((a)通过类内偏差校正后,获得的原型与理想的原型位置更加接近;(b)通过类间偏差校正后,查询集的分布更接近于支持集原型):
总结
在本文中,作者试图解决利用少样本数据获得的原型表征与理想的原型表征之间的偏差问题,认为应从类内偏差与跨类偏差两部分解决。在解决类内偏差时,作者提出了利用伪标记样本的方法,用查询集样本作为伪标记样本。在解决类间偏差时,作者提出了引入一个偏移量。偏移量是通过支持集与查询集的样本特征矢量均值之差得到。本文的贡献在于给出了偏差的分类及数学描述,且方法简单有效。