Tutorial on GoogLeNet-Based Image Classification
2018-06-26 15:50:29
本文旨在通过案例学习 GoogLeNet 及其 Inception 结构的定义，并介绍如何保存和读取这类复杂模型。
1. GoogLeNet 的结构：
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along dim 1.

    Every convolution is followed by ``BatchNorm2d`` and an in-place ``ReLU``.
    All branches keep the spatial size of the input (1x1 convs, 3x3 convs with
    ``padding=1``, and a stride-1 3x3 max-pool with ``padding=1``), which is
    what allows their outputs to be concatenated.

    Args:
        in_planes: number of input channels.
        kernel_1_x: output channels of the 1x1 branch.
        kernel_3_in: channels of the 1x1 reduction before the 3x3 conv.
        kernel_3_x: output channels of the 3x3 branch.
        kernel_5_in: channels of the 1x1 reduction before the "5x5" branch.
        kernel_5_x: output channels of the "5x5" branch.
        pool_planes: output channels of the 1x1 conv after max-pooling.

    The output has ``kernel_1_x + kernel_3_x + kernel_5_x + pool_planes``
    channels.

    The layer order inside each ``nn.Sequential`` determines the
    ``state_dict`` keys (e.g. ``b3.6.weight``). Keep that order unchanged so
    that previously saved checkpoints still load.
    """

    def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x,
                 kernel_5_in, kernel_5_x, pool_planes):
        super().__init__()

        # Branch 1: 1x1 conv.
        self.b1 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
            nn.BatchNorm2d(kernel_1_x),
            nn.ReLU(inplace=True),
        )

        # Branch 2: 1x1 conv (channel reduction) -> 3x3 conv.
        self.b2 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
            nn.BatchNorm2d(kernel_3_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_3_x),
            nn.ReLU(inplace=True),
        )

        # Branch 3: 1x1 conv (channel reduction) -> two stacked 3x3 convs.
        # There is no literal 5x5 kernel here: two 3x3 convs together cover a
        # 5x5 receptive field (3 + 3 - 1 = 5), which is why the parameters
        # are named kernel_5_*.
        self.b3 = nn.Sequential(
            nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
            nn.BatchNorm2d(kernel_5_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(inplace=True),
            nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
            nn.BatchNorm2d(kernel_5_x),
            nn.ReLU(inplace=True),
        )

        # Branch 4: 3x3 max-pool (stride 1, size-preserving) -> 1x1 conv.
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_planes, pool_planes, kernel_size=1),
            nn.BatchNorm2d(pool_planes),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1.

        Branch outputs are concatenated in the order b1, b2, b3, b4.
        """
        # NOTE(review): assumes x is batched (N, C, H, W) so that dim 1 is the
        # channel axis; an unbatched (C, H, W) input would concatenate along H.
        branches = (self.b1, self.b2, self.b3, self.b4)
        return torch.cat([branch(x) for branch in branches], dim=1)