ys99

1.文章原文地址

SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

2.文章摘要

语义分割具有非常广泛的应用，从场景理解、目标相互关系推断到自动驾驶。早期依赖于低层视觉线索的方法已经迅速地被流行的机器学习算法所取代。特别是最近，深度学习在手写数字识别、语音识别、图像分类和目标检测上取得了巨大成功。如今，语义分割（对每个像素进行归类）是一个活跃的研究领域。最近有一些方法直接采用为图像分类而设计的网络结构来完成语义分割任务，虽然结果十分鼓舞人心，但还是比较粗糙。其首要原因是最大池化和下采样减小了特征图的分辨率。我们设计SegNet的动机来自于：分割任务需要将低分辨率的特征图映射回输入分辨率并进行像素级分类，而且这个映射必须产生有利于准确定位边界的特征。

3.网络结构

4.PyTorch实现

  1 import torch.nn as nn
  2 import torch
  3 
  4 class conv2DBatchNormRelu(nn.Module):
  5     def __init__(self,in_channels,out_channels,kernel_size,stride,padding,bias=True,dilation=1,is_batchnorm=True):
  6         super(conv2DBatchNormRelu,self).__init__()
  7         if is_batchnorm:
  8             self.cbr_unit=nn.Sequential(
  9                 nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
 10                           bias=bias,dilation=dilation),
 11                 nn.BatchNorm2d(out_channels),
 12                 nn.ReLU(inplace=True),
 13             )
 14         else:
 15             self.cbr_unit=nn.Sequential(
 16                 nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
 17                           bias=bias, dilation=dilation),
 18                 nn.ReLU(inplace=True)
 19             )
 20 
 21     def forward(self,inputs):
 22         outputs=self.cbr_unit(inputs)
 23         return outputs
 24 
 25 class segnetDown2(nn.Module):
 26     def __init__(self,in_channels,out_channels):
 27         super(segnetDown2,self).__init__()
 28         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 29         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 30         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
 31 
 32     def forward(self,inputs):
 33         outputs=self.conv1(inputs)
 34         outputs=self.conv2(outputs)
 35         unpooled_shape=outputs.size()
 36         outputs,indices=self.maxpool_with_argmax(outputs)
 37         return outputs,indices,unpooled_shape
 38 
 39 class segnetDown3(nn.Module):
 40     def __init__(self,in_channels,out_channels):
 41         super(segnetDown3,self).__init__()
 42         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 43         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 44         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 45         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
 46 
 47     def forward(self,inputs):
 48         outputs=self.conv1(inputs)
 49         outputs=self.conv2(outputs)
 50         outputs=self.conv3(outputs)
 51         unpooled_shape=outputs.size()
 52         outputs,indices=self.maxpool_with_argmax(outputs)
 53         return outputs,indices,unpooled_shape
 54 
 55 
 56 class segnetUp2(nn.Module):
 57     def __init__(self,in_channels,out_channels):
 58         super(segnetUp2,self).__init__()
 59         self.unpool=nn.MaxUnpool2d(2,2)
 60         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 61         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 62 
 63     def forward(self,inputs,indices,output_shape):
 64         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
 65         outputs=self.conv1(outputs)
 66         outputs=self.conv2(outputs)
 67         return outputs
 68 
 69 class segnetUp3(nn.Module):
 70     def __init__(self,in_channels,out_channels):
 71         super(segnetUp3,self).__init__()
 72         self.unpool=nn.MaxUnpool2d(2,2)
 73         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 74         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 75         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 76 
 77     def forward(self,inputs,indices,output_shape):
 78         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
 79         outputs=self.conv1(outputs)
 80         outputs=self.conv2(outputs)
 81         outputs=self.conv3(outputs)
 82         return outputs
 83 
 84 class segnet(nn.Module):
 85     def __init__(self,in_channels=3,num_classes=21):
 86         super(segnet,self).__init__()
 87         self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
 88         self.down2=segnetDown2(64,128)
 89         self.down3=segnetDown3(128,256)
 90         self.down4=segnetDown3(256,512)
 91         self.down5=segnetDown3(512,512)
 92 
 93         self.up5=segnetUp3(512,512)
 94         self.up4=segnetUp3(512,256)
 95         self.up3=segnetUp3(256,128)
 96         self.up2=segnetUp2(128,64)
 97         self.up1=segnetUp2(64,64)
 98         self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1)
 99 
100     def forward(self,inputs):
101         down1,indices_1,unpool_shape1=self.down1(inputs)
102         down2,indices_2,unpool_shape2=self.down2(down1)
103         down3,indices_3,unpool_shape3=self.down3(down2)
104         down4,indices_4,unpool_shape4=self.down4(down3)
105         down5,indices_5,unpool_shape5=self.down5(down4)
106 
107         up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
108         up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
109         up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
110         up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
111         up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
112         outputs=self.finconv(up1)
113 
114         return outputs
115 
if __name__ == "__main__":
    # Smoke test: push one all-ones 224x224 RGB batch through a default
    # network, then print the output size followed by the module tree.
    model = segnet()
    dummy_batch = torch.ones(1, 3, 224, 224)
    scores = model(dummy_batch)
    print(scores.size())
    print(model)

参考

https://github.com/meetshah1995/pytorch-semseg

相关文章: