MultiPoseNet: Fast Multi-Person Pose Estimation using Pose Residual Network.
论文思路大致解读
论文提出的网络结构大概分成三部分:
- 首先第一部分是Backbone网络,用于提取图片在多尺度下的特征;
- 第二部分包括两个分开、独立的网络,其中一个用来检测图片中所有的人体关键点(keypoint_subnet),另外一个用来图片中的行人检测(person_detect subnet)
- 第三部分即文章的核心部分,提出的残差网络(PRN,Pose Residual Network),概括来说就是一个聚类算法,将第二部分检测的到的所有关键点依据行人检测结果进行聚类,得到每个人的人体关键点聚集。
论文的网络结构如下图所示:
论文具体部分详解
Backbone网络: Backbone网络为后面的关键点检测和行人检测网络提供图像特征,在论文中使用了ResNet网络结构,并加入了两个FPN(Feature Pyramid Networks)网络结构,一个用于后面的关键点检测,一个用于后面的行人检测。论文作者在ResNet网络的最后一个residual block提取特征并计算相应的FPN特征,具体的ResNet网络文中采用了ResNet-50和ResNet-101两种网络,ResNet-50更快,ResNet-101会慢一些但相比ResNet-50在COCO数据集上提高了大约1.6mAP检测结果。
Keypoint Subnet: 关键点检测网络,输入为前面FPN网络的输出特征,输出为关键点热图和分割结果热图。关键点检测网络如下图所示:关键点检测网络还是使用了FPN用于前面步骤的特征点检测,Keypoint Subnet网络第三部分(即d=128那里)将前面传来的特征进行上采样,使得上面三个feature map大小和最下面的一致(从上到下依次为D2、D3、D4、D5),最后将D2-D5进行concatenated得到维度为512的feature map,接着通过一个33的卷积和ReLU进行smoothing,最后通过一个11的卷积得到(K+1)层的热度图输出,K是标注的人体关键点类别数量,+1是person segmentation mask。
Person Detection Subnet:文中使用了 FPN+Focal Loss 模型用来行人检测,其实就是完全套用了RetinaNet结构,输出是N*5,N是图片内行人数量,5是bounding box的四个坐标加对应的confidence。具体网络结构可以参考原文:RetinaNet
PRN: 首先是固定尺寸的输入,所以需要先将关键点检测的输出和行人检测的输出裁剪到一个固定大小的值,然后再对其进行关键点到行人的映射。PRN使用叫“residual correction”的方法来对关键点进行映射(应该就是聚类的意思),使得同一个人的关键点映射到同一类。PRN的具体结构应该是一个多层感知机(MLP),文中给出最终采用的结构是:包含1024神经元的多层感知机,0.5概率的dropout,以及输入和输出中间一个residual connection。PRN采用的函数计算公式没有详细介绍,只给出了一个公式:,具体代码可以看文章开头PRN网络链接。
结果
文中代码运行环境是在GTX1080Ti上,Backbone应该是采用了ResNet-50,才达到了在COCO数据集上平均23FPS的效果。ResNet-101准确度会比ResNet-50高些,但速度会慢些。准确度来说应该是目前所有Bottom-Top方法里最高的,只有两个Top-Down的方法准确度比它高,但在多人姿态估计上,速度比它慢多了。
论文最后有比较是关键点检测结果对最终结果影响大还是行人检测对最终结果影响大,发现关键点检测结果影响非常大。在提供bounding box真值的情况下,网络检测出来的关键点 + GT(bounding box),其最终结果相比 GT(keypoints)+ GT(bounding box),AP值相差将近24.(两个GT的AP值为89.4,而 网络检测出的关键点+ GT box仅有65.1)。所以论文最后也提出了一个建议,就是使用更强的特征提取网络 (例如ResNext)来提高最终结果。(有可能会增加计算复杂度,降低运行速度)
论文里所提到的一些网络结构参考源:
我自己也对这篇论文做了一个简单的复现,代码很杂,但结构比较清晰。目前成功复现了最后的PRN网络,根据官方提供的测试方法,也达到了它说的那个结果。但keypoint_subnet 和 person_detect subnet这两个网络,由于刚开始接触深度学习这方面,所以进展比较缓慢。
我的github地址。
注意::
我在复现这篇paper的时候,发现官方给的PRN网络代码有些问题。具体来讲就是上文的PRN网络代码里,src文件夹下的eval.py文件,在对PRN网络的输出结果做点的预测的时候,官方代码里却给出了用gt-points的值来给bbox_keypoints赋值。具体代码见eval.py第200行和205行。其实我认为代码里第209~220行才是正确的预测点的方法。另外issue里也有人说‘ This repo is a scam’,目前而言我也不确定这篇论文究竟有没有作假,但就其官方提供的关于PRN网络的测试脚本来看,其结果是很有问题的。