Please credit the source when reposting:
https://www.cnblogs.com/darkknightzh/p/11486185.html
Paper:
https://arxiv.org/abs/1603.06937
Official Torch code (not examined in detail):
https://github.com/princeton-vl/pose-hg-demo
Third-party PyTorch code (in models/StackedHourGlass.py):
https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation
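The paper uses multi-scale features to estimate human pose.

The building blocks are shown in the figure below:
- Each sub-network is called an hourglass network, because of its hourglass-shaped structure.
- Stacking several of these structures gives the stacked hourglass.
- Stacking lets each module re-estimate the pose and the features over the whole image.

The overall pipeline is also shown in the figure below. The input image first passes through a fully convolutional network (fcn) that extracts features. These features then go through several stacked hourglasses, which produce the final heatmaps.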
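The following arithmetic makes the "fcn" stage concrete. It traces the spatial size through the stem used by the third-party code below: a 7×7 conv with stride 2 and padding 3, then a 2×2 max pool.

This is a sketch, not code from either repository. It assumes:
- a 256×256 input, the resolution used in the paper;
- Residual modules that preserve spatial size.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_resolution(input_size: int = 256) -> list[int]:
    """Spatial size after each downsampling step of the stem (assumed layout)."""
    sizes = [input_size]
    # BnReluConv(3, 64, kernelSize=7, stride=2, padding=3)
    sizes.append(conv_out(sizes[-1], kernel=7, stride=2, padding=3))
    # res1 / res2 / res3 are assumed to keep the spatial size unchanged
    sizes.append(conv_out(sizes[-1], kernel=2, stride=2))  # MaxPool2d(2, 2)
    return sizes


print(stem_resolution(256))  # [256, 128, 64]
```

Under these assumptions, the hourglasses work on 64×64 feature maps. The heatmaps therefore come out at one quarter of the input resolution.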
The hourglass module is shown in the figure below. Each box in it is a residual module, which is shown in the figure after that.
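The recursive shape of an hourglass can be illustrated with a toy 1-D version in plain Python. The recursion is: pool down, recurse, upsample, then add the skip branch.

This is only a structural sketch, and the function names are hypothetical. In it:
- the residual modules in each box are replaced by the identity;
- max pooling halves the length;
- nearest-neighbour upsampling doubles it.

```python
def max_pool(x: list[float]) -> list[float]:
    """1-D max pooling with window 2 and stride 2."""
    return [max(x[i], x[i + 1]) for i in range(0, len(x), 2)]


def upsample(x: list[float]) -> list[float]:
    """Nearest-neighbour upsampling by a factor of 2."""
    return [v for v in x for _ in range(2)]


def hourglass(x: list[float], num_reductions: int) -> list[float]:
    """Toy hourglass: the skip branch keeps the resolution; the low branch is
    pooled, processed recursively, upsampled and added back.
    Residual modules are omitted (identity)."""
    assert len(x) % (2 ** num_reductions) == 0, "length must be divisible by 2**num_reductions"
    up1 = list(x)                 # skip (upper) branch at the current resolution
    low = max_pool(x)             # go one level down
    if num_reductions > 1:
        low = hourglass(low, num_reductions - 1)  # recurse into the next level
    up2 = upsample(low)           # come back up to the current resolution
    return [a + b for a, b in zip(up1, up2)]


print(hourglass([1, 2, 3, 4], 2))  # [7, 8, 11, 12]
```

The output has the same size as the input, which is what allows hourglasses to be stacked. Assuming 64×64 input features and numReductions=4, the bottleneck of the real module would be 4×4.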
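The stacked hourglass uses intermediate supervision: every hourglass outputs its own heatmaps (shown in blue).
- During training, each of these heatmaps is compared with the ground-truth heatmap using an MSE loss. The losses are summed to give the total loss.
- At inference time, only the heatmaps from the last hourglass are used.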
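Here is a minimal sketch of the training loss and the inference rule described above. It makes two simplifying assumptions:
- heatmaps are flattened into plain Python lists;
- a single joint is used, with toy values.

The helper names are hypothetical, and the sketch is not taken from either repository.

```python
def mse(pred: list[float], target: list[float]) -> float:
    """Mean squared error between two flattened heatmaps of equal size."""
    assert len(pred) == len(target)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def intermediate_supervision_loss(outputs: list[list[float]], target: list[float]) -> float:
    """Training loss: the same ground-truth heatmap supervises every stack's output,
    and the per-stack MSE losses are summed."""
    return sum(mse(out, target) for out in outputs)


def inference_heatmap(outputs: list[list[float]]) -> list[float]:
    """At inference time only the last hourglass's prediction is used."""
    return outputs[-1]


target = [0.0, 1.0, 0.0, 0.0]                             # toy flattened ground truth
outputs = [[0.0, 0.5, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]    # predictions of stack 1 and 2
print(intermediate_supervision_loss(outputs, target))     # 0.0625
print(inference_heatmap(outputs))                         # [0.0, 1.0, 0.0, 0.0]
```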
2. stacked hourglass
The stacked hourglass structure is shown in the figure below. It is drawn for the configuration nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17.
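For this configuration, the expected tensor shapes can be worked out by hand. The sketch below does so under the following assumptions:
- 64×64 feature maps, i.e. a 256×256 input;
- each hourglass keeps nChannels channels at every level;
- each reduction halves the height and width.

The function is illustrative only and does not exist in the repository.

```python
def tensor_shapes(batch: int, n_channels: int = 256, n_stack: int = 2,
                  num_reductions: int = 4, n_joints: int = 17, feat_res: int = 64) -> dict:
    """Expected (N, C, H, W) shapes for the configuration above (assumed layout)."""
    bottleneck = feat_res // (2 ** num_reductions)  # each reduction halves H and W
    return {
        "features": (batch, n_channels, feat_res, feat_res),         # input to each hourglass
        "bottleneck": (batch, n_channels, bottleneck, bottleneck),   # lowest hourglass level
        "heatmaps": [(batch, n_joints, feat_res, feat_res)] * n_stack,  # one per stack
    }


shapes = tensor_shapes(batch=8)
print(shapes["bottleneck"])     # (8, 256, 4, 4)
print(len(shapes["heatmaps"]))  # 2
```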
The code is as follows:
```python
import torch.nn as nn
# Also requires, from the repository: the building-block module imported as `M`
# (providing BnReluConv and Residual) and the Hourglass class defined in the same file.


class StackedHourGlass(nn.Module):
    """Stacked hourglass network with intermediate supervision."""

    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize=7, stride=2, padding=3)  # BN+ReLU+conv

        # in != out channels: 1x1 conv(x) + three BN+ReLU+conv layers applied to x
        self.res1 = M.Residual(64, 128)
        self.mp = nn.MaxPool2d(2, 2)
        # in == out channels: x + three BN+ReLU+conv layers applied to x
        self.res2 = M.Residual(128, 128)
        # if nChannels == 128: x + 3*(BN+ReLU+conv)(x);
        # otherwise the 1x1 conv shortcut is used: 1x1 conv(x) + 3*(BN+ReLU+conv)(x)
        self.res3 = M.Residual(128, self.nChannels)

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [], [], [], [], [], []

        for _ in range(self.nStack):  # number of stacked hourglasses
            _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            _ResidualModules = []
            for _ in range(self.nModules):
                # in == out channels: x + 3*(BN+ReLU+conv)(x)
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)  # nModules residual modules in sequence
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))  # BN+ReLU+conv
            # 1x1 conv: features -> heatmaps
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints, 1))
            # 1x1 conv, channel count unchanged
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels, 1))
            # 1x1 conv: heatmaps -> features
            _jointstochan.append(nn.Conv2d(self.nJoints, self.nChannels, 1))

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))  # heatmaps of stack i (intermediate supervision)
            x1 = self.lin2[i](x1)
            x = x + x1 + self.jointstochan[i](out[i])  # sum the features for the next stack

        return out  # list of nStack heatmap tensors
```
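Per stack, the forward loop above can be summarized as below. The notation is introduced here only for illustration:
- $x_i$ is the feature map entering stack $i$;
- $\mathrm{HG}_i$ is the hourglass and $\mathrm{Res}_i$ the nModules residual modules;
- $\mathrm{lin1}_i$ is the BN+ReLU+conv block;
- $W_i^{\mathrm{lin2}}$, $W_i^{c\to j}$ and $W_i^{j\to c}$ are the 1×1 convolutions `lin2`, `chantojoints` and `jointstochan`.

$$
\begin{aligned}
\tilde{x}_i &= \mathrm{lin1}_i\big(\mathrm{Res}_i(\mathrm{HG}_i(x_i))\big),\\
h_i &= W_i^{c\to j} * \tilde{x}_i,\\
x_{i+1} &= x_i + W_i^{\mathrm{lin2}} * \tilde{x}_i + W_i^{j\to c} * h_i .
\end{aligned}
$$

The returned list is $(h_1, \dots, h_{\mathrm{nStack}})$, and this is the list the intermediate-supervision loss sums over. In this code, $x_{\mathrm{nStack}+1}$ is computed after the last stack but is never used by the returned outputs.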