3.deeplabv3+的深层网络结构的实现

在第一篇文章中我们提到"在encoder部分,主要包括了backbone(DCNN)、ASPP两大部分",在这里的backbone就是mobilenetv2网络结构和xception网络结构,而ASPP结构就是深层网络结构,其网络结构如下:

ASPP网络结构的原理其实很简单,可以看博文1.deeplabv3+网络结构及原理-CSDN博客,该博文有介绍。以上网络结构里的rate表示空洞卷积核的大小,显然,该网络结构总共5层卷积处理,之后再将不同的层用concat堆叠,最后再用1x1的卷积核整合特征,转换为图片中绿色的层。

下面深层网络结构的代码如下:

复制代码
#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        # -----------------------------------------#
        #   一共五个分支
        # -----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # -----------------------------------------#
        #   第五个分支,全局平均池化+卷积
        # -----------------------------------------#
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)

        # -----------------------------------------#
        #   将五个分支的内容堆叠起来
        #   然后1x1卷积整合特征。
        # -----------------------------------------#
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
相关推荐
静听夜半雨1 小时前
CANoe入门——3、新建LIN工程及LIN DataBase(LDF文件)的创建
网络·数据库·c++·编辑器
Jackilina_Stone1 小时前
【网工第6版】第5章 网络互联⑧
网络·软考·网工·第5章 网络互联
电鱼智能的电小鱼2 小时前
基于 EFISH-SBC-RK3588 的无人机通信云端数据处理模块方案‌
linux·网络·人工智能·嵌入式硬件·无人机·边缘计算
夜空晚星灿烂2 小时前
http通信之axios vs fecth该如何选择?
网络·网络协议·http
爱的叹息2 小时前
【前端】基于 Promise 的 HTTP 客户端工具Axios 详解
前端·网络·网络协议·http
christine-rr2 小时前
【25软考网工】第三章(4)生成树协议、广播风暴和MAC地址表震荡
网络·网络工程师·软考·考试
迷路的小绅士2 小时前
网络安全概述:定义、重要性与发展历程
网络·安全·web安全
昊昊昊昊昊明3 小时前
10天学会嵌入式技术之51单片机-day-7
linux·运维·网络
达斯维达的大眼睛3 小时前
如何在Linux用libevent写一个聊天服务器
linux·运维·服务器·网络
Zhuai-行淮3 小时前
施磊老师基于muduo网络库的集群聊天服务器(七)
服务器·网络·php