Please credit the source when reposting:

https://www.cnblogs.com/darkknightzh/p/11486185.html

Paper:

https://arxiv.org/abs/1603.06937

Official Torch code (I have not read it in detail):

https://github.com/princeton-vl/pose-hg-demo

Third-party PyTorch code (in models/StackedHourGlass.py):

https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation

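The paper uses multi-scale features to estimate human pose. Each sub-network, called an hourglass network, has an hourglass-shaped structure, and several of them stacked end to end form the stacked hourglass. Stacking lets each module re-estimate the pose and the features over the whole image. As shown in the figure below, the input image first passes through a fully convolutional network (FCN) to produce features, which then pass through the stacked hourglasses to produce the final heatmaps.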

[Figure: overview of the stacked hourglass network]

A single hourglass is shown in the first figure below. Each box in it is a residual module, shown in the second figure below.

[Figure: a single hourglass module]

[Figure: the residual module]
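To make the two figures concrete, here is a minimal PyTorch sketch of a residual module and a recursive hourglass built from it. It is not the code of the third-party repository: the names `bn_relu_conv`, `ResidualSketch` and `HourglassSketch` are made up for illustration, and the bottleneck width (half the output channels) and the single residual block per branch are assumptions. The nearest-neighbor upsampling followed by element-wise addition follows the description in the paper.

```python
import torch
import torch.nn as nn


def bn_relu_conv(in_ch, out_ch, kernel_size=1, padding=0):
    """BN -> ReLU -> conv (pre-activation ordering)."""
    return nn.Sequential(
        nn.BatchNorm2d(in_ch),
        nn.ReLU(),
        nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding),
    )


class ResidualSketch(nn.Module):
    """Skip path plus three BN-ReLU-conv layers (1x1, 3x3, 1x1)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        mid = out_ch // 2  # assumed bottleneck width
        self.body = nn.Sequential(
            bn_relu_conv(in_ch, mid, 1),
            bn_relu_conv(mid, mid, 3, padding=1),
            bn_relu_conv(mid, out_ch, 1),
        )
        # Identity skip when the channel counts match, otherwise a 1x1 conv.
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        return self.skip(x) + self.body(x)


class HourglassSketch(nn.Module):
    """Pool down `depth` times, then upsample back and add the skip branches."""

    def __init__(self, channels, depth):
        super().__init__()
        self.skip = ResidualSketch(channels, channels)  # branch kept at this resolution
        self.pool = nn.MaxPool2d(2, 2)
        self.down = ResidualSketch(channels, channels)
        # Recurse until the lowest resolution, where a plain residual block is used.
        self.inner = (HourglassSketch(channels, depth - 1) if depth > 1
                      else ResidualSketch(channels, channels))
        self.up = ResidualSketch(channels, channels)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        low = self.up(self.inner(self.down(self.pool(x))))
        return self.skip(x) + self.upsample(low)


x = torch.randn(1, 8, 64, 64)
y = HourglassSketch(channels=8, depth=4)(x)  # output has the same shape as the input
```

Because every pooling step is undone by one upsampling step, the hourglass returns a tensor of the same shape as its input, which is what allows several of them to be chained.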

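The hourglass network uses intermediate supervision. Every hourglass produces its own heatmaps (shown in blue). During training, the MSE loss between each of these heatmaps and the ground-truth heatmaps is computed, and the per-hourglass losses are summed to give the total loss. At inference, only the heatmaps from the last hourglass are used.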

[Figure: intermediate supervision]
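The training and inference behavior described above can be sketched in a few lines. This is only an illustration under assumptions: the helper names `intermediate_supervision_loss` and `predict` are hypothetical, `outputs` stands for the list of per-hourglass heatmaps returned by the network, and equal weighting of the stacks is assumed.

```python
import torch
import torch.nn.functional as F


def intermediate_supervision_loss(outputs, target):
    """Sum the MSE between every hourglass's heatmaps and the ground truth."""
    return sum(F.mse_loss(heatmaps, target) for heatmaps in outputs)


def predict(outputs):
    """At inference, only the heatmaps of the last hourglass are used."""
    return outputs[-1]


# Toy data: 2 stacks, batch of 1, 17 joints, 64x64 heatmaps.
target = torch.zeros(1, 17, 64, 64)
outputs = [torch.ones(1, 17, 64, 64), torch.full((1, 17, 64, 64), 2.0)]
loss = intermediate_supervision_loss(outputs, target)  # 1.0 + 4.0
```

Since every hourglass contributes to the loss, the earlier hourglasses receive a direct gradient signal instead of relying only on what flows back from the final output.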

2. stacked hourglass

The stacked hourglass structure is shown in the figure below (nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17):

[Figure: stacked hourglass structure]

The code is as follows:

import torch.nn as nn
# M (building blocks such as BnReluConv and Residual) and Hourglass
# come from the third-party repository linked above.

class StackedHourGlass(nn.Module):
    """Stem, then nStack hourglasses, each followed by a heatmap prediction head."""
    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv

        self.res1 = M.Residual(64, 128)  # in != out channels: 1*1 conv on the skip path + 3*(BN+ReLU+conv)
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.Residual(128, 128)  # in == out channels: x + 3*(BN+ReLU+conv)
        # in == out channels: x + 3*(BN+ReLU+conv); otherwise 1*1 conv on the skip path + 3*(BN+ReLU+conv)
        self.res3 = M.Residual(128, self.nChannels)

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]

        for _ in range(self.nStack):  # number of stacked hourglasses
            _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            _ResidualModules = []
            for _ in range(self.nModules):
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))   # in == out channels: x + 3*(BN+ReLU+conv)
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)   # nModules residual modules in sequence
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))       # BN+ReLU+conv
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))  # 1*1 conv, features -> joint heatmaps
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))        # 1*1 conv, channel count unchanged
            _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))   # 1*1 conv, joint heatmaps -> features

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        # Stem: reduce the resolution and raise the channel count to nChannels
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []  # heatmaps predicted after each hourglass

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))
            x1 = self.lin2[i](x1)
            x = x + x1 + self.jointstochan[i](out[i])   # sum the input, the features and the re-projected heatmaps

        return out
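As a rough check of the resolutions involved, the sketch below traces the spatial size through the stem (`self.start` and `self.mp`) and through the poolings inside one hourglass. The 256x256 input size and the helper name `conv_out` are assumptions for illustration, and the residual modules are assumed to preserve the spatial size.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 256                                             # assumed input resolution
after_start = conv_out(size, kernel=7, stride=2, padding=3)   # self.start: 7x7 conv, stride 2
stem_size = conv_out(after_start, kernel=2, stride=2)         # self.mp: 2x2 max pooling

# Each of the numReductions poolings inside an hourglass halves the size again.
num_reductions = 4
sizes = [stem_size // 2 ** k for k in range(num_reductions + 1)]

print(after_start, stem_size, sizes)
```

Under these assumptions each entry of `out` has shape (batch, nJoints, 64, 64), and the lowest resolution inside an hourglass with numReductions=4 is 4x4.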
