【语义分割专栏】3:Segnet原理篇

目录

前言

本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!带大家深入语义分割的领域,将从原理,代码深入讲解,希望大家能从中有所收获,其中很多内容都包含着自己的一些想法以及理解,如果有错误的地方欢迎大家批评指正。

论文名称: SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

论文地址:SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

欢迎来到语义分割系列第三篇哦,本文将继续带大家来学习语义分割领域的经典模型:Segnet。

背景介绍

同样的我们继续来看看Segnet出现的历史背景,其是Segnet是与FCN同时期的研究工作,不过稍晚些于FCN,所以其在论文中很多的实验都是与FCN进行比较分析的。像上节我们讲的针对于医疗图像分割的Unet一样,Segnet也是针对于某些特定的任务的,其主要针对于无人驾驶任务和AR任务的,尤其是无人驾驶任务。

那么对于无人驾驶任务和AR任务,最重要的是什么呢?那肯定是实时性和准确性了,特别对于无人驾驶任务,我们需要在很短的时间内处理,并且需要具有相当不错的准确性。并且其准确性有个非常重要的方面就是其边界的准确性,需要车辆能够精确地识别和分割道路、行人、车辆、交通标志、车道线等元素。

再看当时方法的局限,首先关于手工方法就不说了,难以应对复杂的情况。而当时典型的语义分割模型FCN同样也存在着局限性,FCN在上采样的时候使用的是反卷积,其会丢失非常重要的空间信息,这会导致边界模糊,而这种情况不是我们希望看到的。其次FCN的计算量是比较大的,难以满足实时性的要求。

Segnet核心剖析

是的,你遇到这些问题你怎么办?希望大家都能够设身处地去思考,我们才能够明白每一项工作的创新意义,同时进行深入思考,很多时候你可能也会有自己的想法,每个创新性的想法都是不经意间的,希望大家都能够有思维的碰撞。好了,话说回来,我们来看看Segnet的作者是怎么做的,看看人家如何进行解决的。

池化索引(pooling Indices)

其实我认为Segnet最核心的创新部分就是采用了池化索引(pooling Indices)的方法了。什么意思呢,看下图,我们来详细讲解一下。

Segnet同样的也是使用了编码器-解码器 的结构。Segnet的上采样方式与FCN不同,FCN采用的是反卷积的方式,而Segnet**每次池化操作时,它不仅保留了池化后的特征图,还记录了每个池化区域中最大值的位置(即池化索引)。在解码器部分,它将这些池化索引传递过来,用于指导上采样过程。**通过这样的方式,编码器在低分辨率下检测到的语义信息(通过卷积特征)和原始图像中的精确空间位置信息(通过池化索引)得以结合,显著提高了分割边界的定位精度。

并且非常重要的一点,SegNet通过使用池化索引进行指导的上采样,避免了使用计算量大的反卷积操作。这使得SegNet模型参数量少,计算复杂度低,非常适合需要实时处理的场景,就像无人驾驶或者AR场景。

通过pooling Indices获得的上采样feature map是稀疏的,其实从图中就能看出,所以在解码器部分会有对应的卷积结构,将稀疏的feature map通过卷积变成稠密的feature map。

当然其弊端也会非常明显,在上采样的过程中缺少了学习的过程,所以其在精度上是无法与当时的FCN进行比较的,但是综合精度、实时性、轻便性来说,Segnet则是更好的考量。

其他细节

编码器解码器的对称结构

其实在其原论文中多次强调了编码器解码器的对称结构的重要性,因为像在FCN中,其编码器解码器的结构是及其不对称的,也就是下采样上采样的过程不对称,其编码器参数有134M但解码器参数仅仅只有0.5M。这就会导致个什么情况呢?解码器的部分的参数相当难以训练。其实如果看过FCN论文也会发现,其详细讲述了FCN的训练过程,也说明了其是难以训练的。

所以设计了一个端到端的、编码器网络中每个编码器都被逐步连接到解码器网络中的SegNet。这种想法很简单,也就是保存多个尺度上提取到的特征和全局的上下文信息,为上采样时提供更多的可用信息,从而保留更多高频细节,实现精细的分割。结构图如下所示:

Segnet模型代码

首先是我们的crop函数,为什么需要用到这个,因为在测试的时候,我们不会对图像进行resize操作的,所以其就不一定是32的倍数,在下采样的过程中可能会出现从45->22的情况,但是上采样过程中就会变成22->44,这样就会造成shape的不匹配,所以需要对齐两者的shape大小。

python 复制代码
def crop(upsampled, bypass):

    h1, w1 = upsampled.shape[2], upsampled.shape[3]
    h2, w2 = bypass.shape[2], bypass.shape[3]

    # 计算差值
    deltah = h2 - h1
    deltaw = w2 - w1

    # 计算填充的起始和结束位置
    # 对于高度
    pad_top = deltah // 2
    pad_bottom = deltah - pad_top
    # 对于宽度
    pad_left = deltaw // 2
    pad_right = deltaw - pad_left

    # 对 upsampled 进行中心填充
    upsampled_padded = F.pad(upsampled, (pad_left, pad_right, pad_top, pad_bottom), "constant", 0)

    return upsampled_padded

然后就是我们的Segnet模型代码了。其实还是非常好理解的,其编码器的结构就是VGG的结构,只不过其在maxpooling的时候需要保存索引,然后就是解码器的结构,其实就是对编码器做个对称就行了。写好模型参数之后,非常重要的,记得要进行参数的初始化哈,这样能够利于之后的训练过程。

python 复制代码
class SegNet(nn.Module):
    def __init__(self,num_classes=12):
        super(SegNet, self).__init__()
        self.encoder1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.encoder4 = nn.Sequential(
            nn.Conv2d(256,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.encoder5 = nn.Sequential(
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        self.decoder1 = nn.Sequential(
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        self.decoder2 = nn.Sequential(
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.decoder3 = nn.Sequential(
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.decoder4 = nn.Sequential(
            nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.decoder5 = nn.Sequential(
            nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64,num_classes,kernel_size=1),
        )

        self.max_pool = nn.MaxPool2d(2,2,return_indices=True)
        self.max_uppool = nn.MaxUnpool2d(2,2)

        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x1 = self.encoder1(x)
        x,pool_indices1 = self.max_pool(x1)
        x2 = self.encoder2(x)
        x,pool_indices2 = self.max_pool(x2)
        x3 = self.encoder3(x)
        x,pool_indices3 = self.max_pool(x3)
        x4 = self.encoder4(x)
        x,pool_indices4 = self.max_pool(x4)
        x5 = self.encoder5(x)
        x,pool_indices5 = self.max_pool(x5)

        x = self.max_uppool(x,pool_indices5)
        x = crop(x, x5)
        x = self.decoder1(x)
        x = self.max_uppool(x,pool_indices4)
        x = crop(x, x4)
        x = self.decoder2(x)
        x = self.max_uppool(x,pool_indices3)
        x = crop(x, x3)
        x = self.decoder3(x)
        x = self.max_uppool(x,pool_indices2)
        x = crop(x, x2)
        x = self.decoder4(x)
        x = self.max_uppool(x,pool_indices1)
        x = crop(x, x1)
        x = self.decoder5(x)

        return x

结语

希望上列所述内容对你有所帮助,如果有错误的地方欢迎大家批评指正!

并且如果可以的话希望大家能够三连鼓励一下,谢谢大家!

如果你觉得讲的还不错想转载,可以直接转载,不过麻烦指出本文来源出处即可,谢谢!

参考资料

本文参考了下列的文章内容,集百家之长汇聚于此,同时包含自己的思考想法

SegNet图像分割网络直观详解 - 知乎

相关推荐
蓝婷儿1 小时前
Python 机器学习核心入门与实战进阶 Day 1 - 分类 vs 回归
python·机器学习·分类
Devil枫2 小时前
Kotlin扩展函数与属性
开发语言·python·kotlin
程序员阿超的博客3 小时前
Python 数据分析与机器学习入门 (八):用 Scikit-Learn 跑通第一个机器学习模型
python·机器学习·数据分析·scikit-learn·入门教程·python教程
xingshanchang4 小时前
PyTorch 不支持旧GPU的异常状态与解决方案:CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH
人工智能·pytorch·python
reddingtons5 小时前
Adobe Firefly AI驱动设计:实用技巧与创新思维路径
大数据·人工智能·adobe·illustrator·photoshop·premiere·indesign
CertiK5 小时前
IBW 2025: CertiK首席商务官出席,探讨AI与Web3融合带来的安全挑战
人工智能·安全·web3
Deepoch6 小时前
Deepoc 大模型在无人机行业应用效果的方法
人工智能·科技·ai·语言模型·无人机
Deepoch6 小时前
Deepoc 大模型:无人机行业的智能变革引擎
人工智能·科技·算法·ai·动态规划·无人机
kngines6 小时前
【字节跳动】数据挖掘面试题0003:有一个文件,每一行是一个数字,如何用 MapReduce 进行排序和求每个用户每个页面停留时间
人工智能·数据挖掘·mapreduce·面试题
Binary_ey6 小时前
AR衍射光波导设计遇瓶颈,OAS 光学软件来破局
人工智能·软件需求·光学软件·光波导