目录
- 前言
- 背景介绍
- Segnet核心剖析
- [池化索引(pooling Indices)](#池化索引(pooling Indices))
- 其他细节
- Segnet模型代码
- 结语
- 参考资料
前言
本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!带大家深入语义分割的领域,将从原理,代码深入讲解,希望大家能从中有所收获,其中很多内容都包含着自己的一些想法以及理解,如果有错误的地方欢迎大家批评指正。
论文名称: SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation
论文地址:SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation
欢迎来到语义分割系列第三篇哦,本文将继续带大家来学习语义分割领域的经典模型:Segnet。

背景介绍
同样的我们继续来看看Segnet出现的历史背景,其是Segnet是与FCN同时期的研究工作,不过稍晚些于FCN,所以其在论文中很多的实验都是与FCN进行比较分析的。像上节我们讲的针对于医疗图像分割的Unet一样,Segnet也是针对于某些特定的任务的,其主要针对于无人驾驶任务和AR任务的,尤其是无人驾驶任务。
那么对于无人驾驶任务和AR任务,最重要的是什么呢?那肯定是实时性和准确性了,特别对于无人驾驶任务,我们需要在很短的时间内处理,并且需要具有相当不错的准确性。并且其准确性有个非常重要的方面就是其边界的准确性,需要车辆能够精确地识别和分割道路、行人、车辆、交通标志、车道线等元素。
再看当时方法的局限,首先关于手工方法就不说了,难以应对复杂的情况。而当时典型的语义分割模型FCN同样也存在着局限性,FCN在上采样的时候使用的是反卷积,其会丢失非常重要的空间信息,这会导致边界模糊,而这种情况不是我们希望看到的。其次FCN的计算量是比较大的,难以满足实时性的要求。
Segnet核心剖析
是的,你遇到这些问题你怎么办?希望大家都能够设身处地去思考,我们才能够明白每一项工作的创新意义,同时进行深入思考,很多时候你可能也会有自己的想法,每个创新性的想法都是不经意间的,希望大家都能够有思维的碰撞。好了,话说回来,我们来看看Segnet的作者是怎么做的,看看人家如何进行解决的。
池化索引(pooling Indices)
其实我认为Segnet最核心的创新部分就是采用了池化索引(pooling Indices)的方法了。什么意思呢,看下图,我们来详细讲解一下。

Segnet同样的也是使用了编码器-解码器 的结构。Segnet的上采样方式与FCN不同,FCN采用的是反卷积的方式,而Segnet**每次池化操作时,它不仅保留了池化后的特征图,还记录了每个池化区域中最大值的位置(即池化索引)。在解码器部分,它将这些池化索引传递过来,用于指导上采样过程。**通过这样的方式,编码器在低分辨率下检测到的语义信息(通过卷积特征)和原始图像中的精确空间位置信息(通过池化索引)得以结合,显著提高了分割边界的定位精度。
并且非常重要的一点,SegNet通过使用池化索引进行指导的上采样,避免了使用计算量大的反卷积操作。这使得SegNet模型参数量少,计算复杂度低,非常适合需要实时处理的场景,就像无人驾驶或者AR场景。
通过pooling Indices获得的上采样feature map是稀疏的,其实从图中就能看出,所以在解码器部分会有对应的卷积结构,将稀疏的feature map通过卷积变成稠密的feature map。
当然其弊端也会非常明显,在上采样的过程中缺少了学习的过程,所以其在精度上是无法与当时的FCN进行比较的,但是综合精度、实时性、轻便性来说,Segnet则是更好的考量。
其他细节
编码器解码器的对称结构
其实在其原论文中多次强调了编码器解码器的对称结构的重要性,因为像在FCN中,其编码器解码器的结构是及其不对称的,也就是下采样上采样的过程不对称,其编码器参数有134M但解码器参数仅仅只有0.5M。这就会导致个什么情况呢?解码器的部分的参数相当难以训练。其实如果看过FCN论文也会发现,其详细讲述了FCN的训练过程,也说明了其是难以训练的。
所以设计了一个端到端的、编码器网络中每个编码器都被逐步连接到解码器网络中的SegNet。这种想法很简单,也就是保存多个尺度上提取到的特征和全局的上下文信息,为上采样时提供更多的可用信息,从而保留更多高频细节,实现精细的分割。结构图如下所示:

Segnet模型代码
首先是我们的crop函数,为什么需要用到这个,因为在测试的时候,我们不会对图像进行resize操作的,所以其就不一定是32的倍数,在下采样的过程中可能会出现从45->22的情况,但是上采样过程中就会变成22->44,这样就会造成shape的不匹配,所以需要对齐两者的shape大小。
python
def crop(upsampled, bypass):
h1, w1 = upsampled.shape[2], upsampled.shape[3]
h2, w2 = bypass.shape[2], bypass.shape[3]
# 计算差值
deltah = h2 - h1
deltaw = w2 - w1
# 计算填充的起始和结束位置
# 对于高度
pad_top = deltah // 2
pad_bottom = deltah - pad_top
# 对于宽度
pad_left = deltaw // 2
pad_right = deltaw - pad_left
# 对 upsampled 进行中心填充
upsampled_padded = F.pad(upsampled, (pad_left, pad_right, pad_top, pad_bottom), "constant", 0)
return upsampled_padded
然后就是我们的Segnet模型代码了。其实还是非常好理解的,其编码器的结构就是VGG的结构,只不过其在maxpooling的时候需要保存索引,然后就是解码器的结构,其实就是对编码器做个对称就行了。写好模型参数之后,非常重要的,记得要进行参数的初始化哈,这样能够利于之后的训练过程。
python
class SegNet(nn.Module):
def __init__(self,num_classes=12):
super(SegNet, self).__init__()
self.encoder1 = nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.encoder2 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
self.encoder3 = nn.Sequential(
nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
)
self.encoder4 = nn.Sequential(
nn.Conv2d(256,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
)
self.encoder5 = nn.Sequential(
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
)
self.decoder1 = nn.Sequential(
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
)
self.decoder2 = nn.Sequential(
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
)
self.decoder3 = nn.Sequential(
nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
self.decoder4 = nn.Sequential(
nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.decoder5 = nn.Sequential(
nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,num_classes,kernel_size=1),
)
self.max_pool = nn.MaxPool2d(2,2,return_indices=True)
self.max_uppool = nn.MaxUnpool2d(2,2)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
x1 = self.encoder1(x)
x,pool_indices1 = self.max_pool(x1)
x2 = self.encoder2(x)
x,pool_indices2 = self.max_pool(x2)
x3 = self.encoder3(x)
x,pool_indices3 = self.max_pool(x3)
x4 = self.encoder4(x)
x,pool_indices4 = self.max_pool(x4)
x5 = self.encoder5(x)
x,pool_indices5 = self.max_pool(x5)
x = self.max_uppool(x,pool_indices5)
x = crop(x, x5)
x = self.decoder1(x)
x = self.max_uppool(x,pool_indices4)
x = crop(x, x4)
x = self.decoder2(x)
x = self.max_uppool(x,pool_indices3)
x = crop(x, x3)
x = self.decoder3(x)
x = self.max_uppool(x,pool_indices2)
x = crop(x, x2)
x = self.decoder4(x)
x = self.max_uppool(x,pool_indices1)
x = crop(x, x1)
x = self.decoder5(x)
return x
结语
希望上列所述内容对你有所帮助,如果有错误的地方欢迎大家批评指正!
并且如果可以的话希望大家能够三连鼓励一下,谢谢大家!
如果你觉得讲的还不错想转载,可以直接转载,不过麻烦指出本文来源出处即可,谢谢!
参考资料
本文参考了下列的文章内容,集百家之长汇聚于此,同时包含自己的思考想法