【语义分割专栏】2:U-net原理篇(由浅入深)

目录

前言

本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!带大家深入语义分割的领域,将从原理,代码深入讲解,希望大家能从中有所收获,其中很多内容都包含着自己的一些想法以及理解,如果有错误的地方欢迎大家批评指正。

论文名称:《U-Net: Convolutional Networks for Biomedical Image Segmentation》

论文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation

欢迎继续来到语义分割专栏系列的第二篇,本文将继续带大家来学习语义分割领域的经典模型:U-net

背景介绍

在上一篇中我们已经详细的介绍过了FCN的原理以及代码实现,本篇中我们要介绍的是U-net,是遵循着FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。

首先我们来看U-net出现的历史背景,U-net的设计初衷是为了解决医学图像分割中的挑战,但是因为其优秀合理的架构,故其也被广泛应用于各种语义分割任务中,包括卫星图像分析、农业图像处理、自动驾驶、遥感图像等。

首先我们来看当时在医学图像分割面临的哪些困难?

首先就是数据稀缺与数据标注困难 :因为其专业性,所以医学图像标注成本都会比较高,需要专业医生;同时处于伦理道德方面,其可用于训练的图像数据数量远少于自然图像数据(如ImageNet等);其次就是分割精度需求高 :因为在医疗领域,即使像素级的微小误差也可能影响诊断结果;所以我们的模型就需要能够兼顾全局上下文局部细节

那么如果是你,在当时的背景条件下,你希望能够在医学图像分割中取得更好的结果,你会怎么做?首先我们来想,我们的数据训练量不会很多,那么我们的模型肯定不能够过于复杂了,其网络设计一定是参数比较小 的,适合小样本训练的,但是去设计个咋样的架构才是合适的呢?这就是很具有创新性的一步了。其次分割精度需求高,那么我们的模型肯定要能够兼顾上下文的信息和局部的细节信息。其实这个在FCN已经给出了个答案,通过跳跃连接的方式,兼顾上下文信息,但是U-net的跳跃连接方式又有着一些差别

U-net核心剖析

是的,你遇到这些问题你怎么办?希望大家都能够设身处地去思考,我们才能够明白每一项工作的创新意义,同时进行深入思考,很多时候你可能也会有自己的想法,每个创新性的想法都是不经意间的,希望大家都能够有思维的碰撞。好了,话说回来,我们来看看U-net的作者时怎么做的,就是那些问题,看看人家如何进行解决的。

编码解码结构(U形状)

我始终觉得U-net是一个超级符合对称美学的架构,左侧的结构是特征提取部分,右侧的结构就是上采样部分,当然也有人将其称为编码器解码器结构。同时由于其网络的整体结构像一个大写的英文字母U,所以叫做U-net。

在Encoder部分,我们将输入图像经过多个卷积和池化操作,逐步提取其语义特征。每经过一个 2x2 最大池化层(红色箭头,采用的是无填充的方式),所以feature map的尺寸会减半。

在Decoder部分,通过 2x2 反卷积操作(绿色箭头)将特征图尺寸逐步放大两倍。上采样后的特征图与编码器部分对应尺度的特征图进行concat融合(蓝色箭头代表 3x3 卷积操作,步长为 1,有效填充,每次操作后特征图尺寸减少 2)。为了进行拼接,需要对尺寸较大的feature map进行crop操作(灰色箭头),使其与上采样后的特征图尺寸匹配。

最终输出层使用 1x1 卷积层进行分类,输出两层,分别代表前景和背景。输入图像为 572x572,输出图像为 388x388,说明经过网络处理后,输出结果与原图尺寸不完全对应。(这里大家可能会有点疑问,为什么语义分割任务会输入输出的shape不匹配,这个后面会有说明。)

卷积模式

这里我们可以看到一个细节,我们在Encoder部分的时候卷积之后,feature map的shape都会减小,这是因为其采用的卷积模式是valid。卷积一共有三个模式,分别是full mode、same mode、valid mode

  • full mode:从卷积核刚开始与我们的图像进行相交的时候就开始卷积操作。
  • same mode:当卷积核的中心(K)与图像的边角重合时,我们就开始做卷积运算,可见卷积核的运动范围比full模式小了一圈。当然了这里的same还有一个意思,就是当卷积之后输出的feature map尺寸保持不变(相对于输入图片)。
  • valid mode:当卷积核全部在图像里面的时候,进行卷积运算,可见卷积核的移动范围较same更小了。

这里我们我用蓝色表示卷积核,橙色表示我们的图像部分,白色部分位填充,相信通过下图能够更加清晰的了解不同的卷积方式了。

跳跃连接

不知道大家还记不记得我们之前在FCN中也讲到了跳跃连接的,我们这里回顾下:

首先我们需要明白一个事情就是,我们的网络在进行特征提取的时候是从低级语义信息不断不断的进行提取到最后的输出的高级语义信息的。网络的低层提取的语义信息更多代表了图像的纹理、边缘等一些显性的信息,网络的高层所提取的一些语义信息更多的就是其数据核心的抽象的语义信息了。那我们最后进行语义分割的特征图的语义信息肯定损失了很多关键的细节、边缘信息了,并且最后还会有上采样的过程,这个现象就会更加加剧。我们想要最后进行语义分割也能够有些这些细节信息怎么办?

这就是跳跃连接了。想要低层语义信息,直接把低层的语义信息加回来不就好了,简单粗暴,但是同样的也非常有效。这就是我们在FCN中跳跃连接的方式,直接将对应位置的信息进行相加,即就是相当于是add操作。

但是在U-net中的跳跃连接方式是concat,从图中也能看出,我们是将之前的低级语义信息与我们在后来提取到的高级语义信息进行通道上的相加了,不是对应位置像素直接相加。那么二者有什么区别呢?

add

我们来看,以下是 keras 中对 add 的实现源码,pytorch的封装更复杂一些,不过原理都是一样的,看这个就行:

python 复制代码
def _merge_function(self, inputs):
    output = inputs[0]
    for i in range(1, len(inputs)):
        output += inputs[i]
    return output

其中 inputs 为待融合的特征图,inputs[0]、inputs[1]......等的通道数一样,且特征图宽与高也一样。

从代码中可以很容易地看出,add 方式有以下特点

  1. 做的是对应通道对应位置的值的相加,通道数不变
  2. 描述图像的特征个数不变,但是每个特征下的信息却增加了。

concat

同样的,我们通过阅读下面代码实例帮助理解 concat 的工作原理:

python 复制代码
import torch

# 创建两个张量
t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# 沿第1维拼接
result_1 = torch.cat([t1, t2], dim=1)
print(result_1)
# 输出: tensor([[ 1,  2,  3,  7,  8,  9],
#               [ 4,  5,  6, 10, 11, 12]])

在模型网路当中,数据通常为 4 个维度,即 num×channels×height×width ,因此默认值 1 表示的是 channels 通道进行拼接。如:

python 复制代码
combine = torch.cat([d1, add1, add2, add3, add4], 1)

从代码中可以很容易地看出,concat 方式有以下特点:

  1. 做的是通道的合并,通道数变多了
  2. 描述图像的特征个数变多,但是每个特征下的信息却不变。

所以到这里,我们就能够很清晰的知道add操作和concat操作的不同了。

操作 描述 优点 缺点 补充
add - 相当于加了一种prior - 要求两路输入的对应通道特征图语义类似 - 计算量少 - 特征提取能力差 - 对应通道信息类似时,可融合多通道信息 - 尺度不一致时,小尺度特征可能被淹没
concat - 通过训练学习整合两个特征图通道之间的信息 - 特征提取能力强 - 计算量大(是add的2倍) - 能提取更合适的信息,效果更好

其他细节

overlap-tile策略

因为医学图像是一般都是相当大的,我们在分割的时候就不可能将原图直接输入网络,所以需要用一个滑动窗口把原图扫一遍,使用原图的切片进行训练或测试。可以看图,其中红框标出来的是要分割区域。但是我们在切图时要包含周围区域,overlap另一个重要原因是周围overlap部分可以为分割区域边缘部分提供纹理等信息。

但是这样的策略会带来一个问题,图像边界的图像块没有周围像素,卷积会使图像边缘处的信息丢失。因此其对周围像素采用了镜像扩充。下图中红框部分为原始图片,其周围扩充的像素点均由原图沿白线对称得到。这样,边界图像块也能得到准确的预测。

另一个问题是,这样的操作会带来图像重叠问题,即第一块图像周围的部分会和第二块图像重叠。所以还记得我们之前讲解网络结构的时候吗?其输入图像为 572x572,但是最终的输出图像为 388x388,我认为就是通过这样的方式和我们在concat时候的crop操作来让模型只关注图像的黄色区域内的部分。

弹性形变

为了解决任务中数据缺乏的问题,我们常常都是会采用一些数据增强的方法来扩充数据集。常见的增强方式包括对图像进行旋转、平移等仿射变换,或进行镜像处理。在此基础上,U-net 论文中使用一种更适合医学图像的数据增强方式------弹性变换。该方法最初在 MNIST 手写数字识别任务中使用,发现通过对原图进行弹性变形可以显著提升模型识别准确率。因为U-Net 处理的图像数据来自细胞组织,而细胞边界本身就具有自然的、不规则形变特性,因此使用弹性变换可以模拟真实情况下的结构畸变,从而提升模型的泛化能力。

弹性变换的基本原理是:为图像的每个像素坐标引入一个在 (−1,1)(-1, 1) 区间内的随机扰动,这些扰动通过高斯滤波平滑后,再乘以一个缩放系数来控制最终的形变幅度。最终,原图中位置 (x,y)(x, y) 的像素被映射到新的位置 (x+δ_x,y+δ_y)(x + \\delta_x, y + \\delta_y),新图像的像素值通过插值从原图获得,即新位置的值来自原图对应位置的值。

图示中展示了在相同扰动强度下,不同高斯标准差带来的形变效果。结果表明,第二幅图的形变效果在真实感和增强效果之间达到了较好的平衡。

这个时候我们在回头看,最初的核心两个问题:数据稀缺与数据标注困难分割精度需求高

通过设计了轻巧的U型网络,采用了大量数据增强的方式,使得其能够更好的适应小样本的任务。通过多尺度融合 + 跳跃连接,提升了对小物体和边界的感知能力;并且跳跃连接还能够避免深层网络中"语义信息丰富但空间信息丢失"的问题,从而能够保证分割精度。

U-net模型代码

这里同样的 我自己也尝试去复现了U-net模型代码,当然细节上跟原论文中的U-net不是完全一样,原来的U-net模型是适用于医学图像分割任务,所以其有部分设计也是为了医学图像分割设计的,我这里复现的U-net代码更适合普遍的语义分割任务,其输入输出的shape大小是相同的。

首先是我将所有的上采样下采样中的卷积部分集成到了一起,看模型结构能够看出,每个部分都是两次卷积,所以代码如下,就在设置不同stage的时候设置好输入输出通道即可。

python 复制代码
class Down_Up_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Down_Up_Conv, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv_block(x)

然后这是跳跃连接的代码,同时我们采取了crop操作。我们通过获取两个feature map的长宽,然后再对齐之后进行再通道维上的拼接,代码如下,还是比较好理解的。

python 复制代码
def crop_and_concat(upsampled, bypass):
    """
    将两个 feature map 在 H 和 W 上对齐后拼接(dim=1)
    - upsampled: 解码器上采样后的特征图 (N, C1, H1, W1)
    - bypass: 编码器传来的特征图 (N, C2, H2, W2)
    """
    h1, w1 = upsampled.shape[2], upsampled.shape[3]
    h2, w2 = bypass.shape[2], bypass.shape[3]

    # 计算差值
    delta_h = h2 - h1
    delta_w = w2 - w1

    # 对 encoder 输出进行中心裁剪
    bypass_cropped = bypass[:, :,
                     delta_h // 2: delta_h // 2 + h1,
                     delta_w // 2: delta_w // 2 + w1]

    # 拼接通道维
    return torch.cat([upsampled, bypass_cropped], dim=1)

然后就是搭建我们的U-net模型了,这还是比较容易的,将encoder部分的五个阶段的下采样卷积定义好,注意通道数的变换,然后就是Decoder的上采样的过程,我们使用的是转置卷积,上采样后还有卷积过程,所以我们按照U-net的模型图搭建即可。注意,我这里是把maxpooling给摘出来了的,每个下采样卷积之后都会有一个maxpooling层,这个可别忘了,在forward里面有体现。定义好模型参数之后就是模型参数的初始化了,这个步骤可千万不能忘。

python 复制代码
class UNet(nn.Module):
    def __init__(self, num_classes=2):
        super(UNet, self).__init__()
        self.stage_down1=Down_Up_Conv(3, 64)
        self.stage_down2=Down_Up_Conv(64, 128)
        self.stage_down3=Down_Up_Conv(128, 256)
        self.stage_down4=Down_Up_Conv(256, 512)
        self.stage_down5=Down_Up_Conv(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2,padding=1)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2,padding=1)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2,padding=1)

        self.stage_up4=Down_Up_Conv(1024, 512)
        self.stage_up3=Down_Up_Conv(512, 256)
        self.stage_up2=Down_Up_Conv(256, 128)
        self.stage_up1=Down_Up_Conv(128, 64)
        self.stage_out=Down_Up_Conv(64, num_classes)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        stage1 = self.stage_down1(x)
        x = self.maxpool(stage1)
        stage2 = self.stage_down2(x)
        x = self.maxpool(stage2)
        stage3 = self.stage_down3(x)
        x = self.maxpool(stage3)
        stage4 = self.stage_down4(x)
        x = self.maxpool(stage4)
        stage5 = self.stage_down5(x)

        x = self.up4(stage5)

        x = self.stage_up4(crop_and_concat(x, stage4))
        x = self.up3(x)
        x = self.stage_up3(crop_and_concat(x, stage3))
        x = self.up2(x)
        x = self.stage_up2(crop_and_concat(x, stage2))
        x = self.up1(x)
        x = self.stage_up1(crop_and_concat(x, stage1))
        out = self.stage_out(x)
        return out

结语

希望上列所述内容对你有所帮助,如果有错误的地方欢迎大家批评指正!

并且如果可以的话希望大家能够三连鼓励一下,谢谢大家!

如果你觉得讲的还不错想转载,可以直接转载,不过麻烦指出本文来源出处即可,谢谢!

参考资料

本文参考了下列的文章内容,集百家之长汇聚于此,同时包含自己的思考想法

UNET详解和UNET++介绍(零基础)-CSDN博客

图像分割必备知识点 | Unet详解 理论+ 代码 - 知乎

深度学习系列-UNet网络 - 知乎

相关推荐
yzx9910133 小时前
Python开发系统项目
人工智能·python·深度学习·django
高效匠人3 小时前
人工智能-Chain of Thought Prompting(思维链提示,简称CoT)
人工智能
要努力啊啊啊4 小时前
GaLore:基于梯度低秩投影的大语言模型高效训练方法详解一
论文阅读·人工智能·语言模型·自然语言处理
先做个垃圾出来………4 小时前
《机器学习系统设计》
人工智能·机器学习
my_q5 小时前
机器学习与深度学习08-随机森林02
深度学习·随机森林·机器学习
s153355 小时前
6.RV1126-OPENCV 形态学基础膨胀及腐蚀
人工智能·opencv·计算机视觉
jndingxin5 小时前
OpenCV CUDA模块特征检测------角点检测的接口createMinEigenValCorner()
人工智能·opencv·计算机视觉
Tianyanxiao5 小时前
宇树科技更名“股份有限公司”深度解析:机器人企业IPO前奏与资本化路径
人工智能
道可云5 小时前
道可云人工智能每日资讯|北京农业人工智能与机器人研究院揭牌
人工智能·机器人·ar·deepseek
不爱吃山楂罐头5 小时前
第三十三天打卡复习
python·深度学习