1. 概述

导读:这篇文章是在检测模型上使用知识蒸馏,从而实现在减小检测模型尺寸与推理时间的同时,尽可能提升小模型的检测性能。这篇文章是基于Tiny-YOLO的检测模型,但是在知识蒸馏的部分做了较多的工作,归纳为:(1)objectness scaled distillation:按照目标是否为检测模型的置信度给蒸馏的网络添加权重参数,相当于是objectness-ware;(2)FM_NMS(feature map NMS):在给定的窗格内保留最好的检测框,从而提升teacher传递信息的质量,抑制无关检测结果。使用文章的方法其在Tiny_YOLO的baseline 54.2 mAP提升到59.4 mAP,并且在速度上快了20%。

在下图中将文章的方法与其它的一些检测算法进行比较:
《Object detection at 200 Frames Per Second》论文笔记

2. 方法设计

2.1 蒸馏网络的结构

文章所用到的蒸馏网络的结构如下图所示:
《Object detection at 200 Frames Per Second》论文笔记
注意到,文章并没有将teacher与student对应特征层的feature直接迁移,而是使用FM-NMS进行信息的提取与增强之后再进行迁移,这样可以避免有效信息被覆盖与student网络的过拟合

2.2 FM-NMS

作者在这篇文章中提到,目标检测是与极大值抑制息息相关的,而这一步是独立于端到端训练的网络的,在经过极大值抑制之前的特征图上检测网络在该区域的响应是很稠密的(相邻位置处得到的疑似检测结果很多)。若是直接将NMS之前的特征用于监督student model将会导致过拟合与性能下降。对此文章提出了特征图的非极大值抑制,从而抑制这些稠密的响应结果。对此文章提出了FM-NMS的操作,其流程见下图所示:
《Object detection at 200 Frames Per Second》论文笔记
在文章作者在teacher的feature map上使用333*3方格进行滑动,选择其中目标置信度最高的值对应的检测信息作为改点处的输出,而其它的点信息就会被抑制。从而得到一个有效信息高度浓缩的特征图。

2.3 Objectness Scaled Loss

这部分包含了检测网络本身的损失函数也包含了对应的蒸馏网络的损失。对于原始的YOLO损失函数其定义为三个部分损失的相加(目标置信度、分类、检测框):
YYOLO=fobj(oigt,oi^)+fcl(pigt,pi^)+fbb(bigt,bi^)Y_{YOLO}=f_{obj}(o_i^{gt},\hat{o_i})+f_{cl}(p_i^{gt},\hat{p_i})+f_{bb}(b_i^{gt},\hat{b_i})
文章在进行蒸馏的时候考虑到了蒸馏目标的置信度问题,从而有效抑制了对于背景区域的蒸馏,从而提升检测网络的性能,这里首先对于置信度损失,其定义为:
fobjComb(oigt,oi^,oiT)=fobj(oigt,oi^)+λDfobj(oiT,oi^)f_{obj}^{Comb}(o_i^{gt},\hat{o_i},o_i^T)=f_{obj}(o_i^{gt},\hat{o_i})+\lambda_D\cdot f_{obj}(o_i^T,\hat{o_i})
对于分类的损失:
fclComb(pigt,pi^,piT,oiT^)=fcl(pigt,pi^)+oiT^λDfcl(piT,pi^)f_{cl}^{Comb}(p_i^{gt},\hat{p_i},p_i^T,\hat{o_i^T})=f_{cl}(p_i^{gt},\hat{p_i})+\hat{o_i^T}\cdot \lambda_D\cdot f_{cl}(p_i^T,\hat{p_i})
对于检测框边界的损失:
fbbComb(bigt,bi^,biT,oiT^)=fbb(bigt,bi^)+oiT^λDfbb(biT,bi^)f_{bb}^{Comb}(b_i^{gt},\hat{b_i},b_i^T,\hat{o_i^T})=f_{bb}(b_i^{gt},\hat{b_i})+\hat{o_i^T}\cdot \lambda_D\cdot f_{bb}(b_i^T,\hat{b_i})
对应的总的损失函数定义为:
Lfinal=fobjComb(oigt,oi^,oiT)+fclComb(pigt,pi^,piT,oiT^)+fbbComb(bigt,bi^,biT,oiT^)L_{final}=f_{obj}^{Comb}(o_i^{gt},\hat{o_i},o_i^T)+f_{cl}^{Comb}(p_i^{gt},\hat{p_i},p_i^T,\hat{o_i^T})+f_{bb}^{Comb}(b_i^{gt},\hat{b_i},b_i^T,\hat{o_i^T})

3. 实验结果

文中提到的两点改进对于检测性能的影响:
《Object detection at 200 Frames Per Second》论文笔记

相关文章: