文章目录
本文记录一下三种常用的loss function:Contrastive Loss,Triplet Loss,Focal Loss。其中前面两个可以认为是ranking loss类型,Focal Loss是针对正负样本极其不均衡情况下的一种cross entropy loss的升级版。本文主要参考资料是[2][3][5]。
我们平时ML任务的时候,用的最多的是cross entropy loss或者MSE loss。需要有一个明确的目标,比如一个具体的数值或者是一个具体的分类类别。但是ranking loss实际上是一种metric learning,他们学习的相对距离,相关关系,而对具体数值不是很关心。ranking loss 有非常多的叫法,但是他们的公式实际上非常一致的。大概有两类,一类是输入pair 对,另外一种是输入三元组结构。
1. Contrastive Loss (对比loss)
在孪生神经网络(siamese network)中,其采用的损失函数是contrastive loss,这种损失函数可以有效的处理孪生神经网络中的paired data的关系。contrastive loss的表达式如下:
其中,代表两个样本特征的欧氏距离,y为两个样本是否匹配的标签,y=1代表两个样本相似或者匹配,y=0则代表不匹配,margin为设定的阈值。这种损失函数最初来源于Yann LeCun的Dimensionality Reduction by Learning an Invariant Mapping,主要是用在降维中,即本来相似的样本,在经过降维(特征提取)后,在特征空间中,两个样本仍旧相似;而原本不相似的样本,在经过降维后,在特征空间中,两个样本仍旧不相似。
观察上述的contrastive loss的表达式可以发现,这种损失函数可以很好的表达成对样本的匹配程度,也能够很好用于训练提取特征的模型。当y=1(即样本相似)时,损失函数只剩下,即原本相似的样本,如果在特征空间的欧式距离较大,则说明当前的模型不好,因此加大损失。而当y=0时(即样本不相似)时,损失函数为,即当样本不相似时,其特征空间的欧式距离反而小的话,损失值会变大,这也正好符号我们的要求。其中margin是一个超参,相当于是给loss定了一个上届(margin平方),如果d大于等于margin,那么说明已经优化的很好了,loss=0了。
这张图表示的就是损失函数值与样本特征的欧式距离之间的关系,其中红色虚线表示的是相似样本的损失值,蓝色实线表示的不相似样本的损失值。
2. Triplet Loss(三元loss)
Triplet loss最初是在 FaceNet: A Unified Embedding for Face Recognition and Clustering 论文中提出的,可以学到较好的人脸的embedding。
Softmax是确定的分类,需要有真实的标注label。而有的时候我们不一定知道label,但是知道正样本对和负样本对——比如两张照片是同一个人,或者不是同一个人。
输入是一个三元组 <a, p, n>
- a: anchor,表示一个基准样本
- p: positive, 与 a 是同一类别的样本,比如就是同一个人的照片
- n: negative, 与 a 是不同类别的样本,比如就是不同人的照片
Triplet Loss形式:
其中表示距离函数,一般指在Embedding下的欧式距离计算。很显然,Triplet-Loss是希望让a和p的距离尽可能小,而a和n的距离尽可能大,但是具体而言和的数值是多少,并没有规定,只要考察他们之间的相对距离。网上有一张图片说明了几种相对关系。
如果我们给定了一个a和p,以及参数,那么我们就可以考察negative点的位置,会出现三种case(如果是Easy negative的三元组我们叫做Easy triplet,其他类似):
- Easy negatives(绿色区域): 即 ,这种情况不需要优化(无法优化,Loss为0),天然a, p的距离很近, a, n的距离远。
- hard negatives(红色区域):,也就是说negative点反而比较近,说明距离估计的不准,这个时候loss比较大。
- Semi-hard negatives(橙色区域):, 即a, n的距离靠的很近,但是因为我们有一个margin,使得loss依然是正的。这种情况下,其实是说在这个三元组里面比较p和n离a的距离差不多,比较容易混淆。
FaceNet论文中是随机选取semi-hard triplets进行训练的, (也可以选择 hard triplets或者两者一起进行训练)
在线训练时产生样本:
虽然可以离线把triplet数据都产生(配对)好,但实际使用采用此方法,即在线对一个Batch去产生。产生时又分为两种策略Batch All和Batch Hard (是在一篇行人重识别的论文中提到的[7],假设一个batch中有张图片, 其中个身份的人,每个身份的人张图片(比如)。
-
Batch All:计算batch_size中所有valid的hard triplet 和 semi-hard triplet(valid是指a,p,n三个都不能相同,需要是不同图片), 然后取平均得到Loss。理论上最多可以产生 个 triplets:PK个 anchor,K-1 个 positive,PK-K 个 negative。但是因为很多是easy triplets的情况,所以平均会导致Loss很小,easy triplets对我们是不需要的。所以是对所有valid的hard triplet和semi-hard triplet对求平均。
-
Batch Hard:对于每一个anchor,选择距离最大的和距离最大的,所以只有个三元组triplets来求loss。
3. Focal Loss[5]
Focal Loss的提出是用来解决一阶段目标检测算法面对的极端不平衡前景和背景目标(框)数量,作者表示可能有1:1000。原论文主要是关心处理简单的二分类问题,即前景背景检测,但是loss本身也非常容易扩展到多分类问题。
基本的交叉熵loss(针对一个样本,其中表示模型的输出概率,且只要针对ground truth那个类别的输出概率):
α-balanced focal lossloss:
其中,我们发现在原来的cross entropy loss基础上,加上了一个调节因子,,这个调节因子的作用是:
- When an example is misclassified and pt
is small, the modulating factor is near 1 and the loss is unaffected. As , the factor goes to 0 and the loss for well-classified examples is down-weighted - The focusing parameter smoothly adjusts the rate at which easy examples are downweighted.
实验中发现,一般是最有效的。这个调节因子可以让容易分的样本的loss降低重要性。比如,当,以及,那么这个样本就是比较容易分对,它的loss较Cross Entropy Loss就降低了100倍;而,loss就小了1000倍;而对的难分样本,loss只是降低了4倍,相当于重要性变大了。另外,还多了一个weighting factor ,作者表示:In practice may be set by inverse class frequency or treated as a hyperparameter to set by cross validation。实际在实验中,作者是看成一个超参的,需要调节,比如用0.25, 0.5, 0.75都有试过。如果二分类,就设 for class 1 and for class −1。
3.1 引申讨论:其他形式的Focal Loss
作者提出,实际上并不只是上面的Focal Loss可以有效,应该存在一批定义方法,他们的效果是类似的。在[5]附录中作者又给出了一种Focal Loss*的定义方法:
我们定义:
其中,表示样本分对or分错。给出了一种新的loss形式:: Focal Loss*
下图是本文中几种loss曲线,以及他们的梯度:
具体的梯度公式推导是:
最后作者给出的结果是,几种Focal Loss的作用差不了太多,但是比原始的Cross Entropy显著要好一些。作者认为和Focal Loss有相同作用的其他loss应该是等效效果的。
参考资料
[1] Triplet Loss, Ranking Loss, Margin Loss
[2] Contrastive Loss (对比损失)
[3] Triplet-Loss原理及其实现、应用
[4] Loss Rank Mining: A General Hard Example Mining Method for Real-time Detectors
[5] Focal Loss for Dense Object Detection
[6] FaceNet: A Unified Embedding for Face Recognition and Clustering
[7] In Defense of the Triplet Loss for Person Re-Identification
[8] 文献阅读 - Dimensionality Reduction by Learning an Invariant Mapping
[9] 目标检测focal loss 和 loss rank mining笔记