3.deeplabv3+的深层网络结构的实现

在第一篇文章中我们提到"在encoder部分,主要包括了backbone(DCNN)、ASPP两大部分",在这里的backbone就是mobilenetv2网络结构和xception网络结构,而ASPP结构就是深层网络结构,其网络结构如下:

ASPP网络结构的原理其实很简单,可以看博文1.deeplabv3+网络结构及原理-CSDN博客,该博文有介绍。以上网络结构里的rate表示空洞卷积核的大小,显然,该网络结构总共5层卷积处理,之后再将不同的层用concat堆叠,最后再用1x1的卷积核整合特征,转换为图片中绿色的层。

下面深层网络结构的代码如下:

复制代码
#-----------------------------------------#
#   ASPP特征提取模块
#   利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        # -----------------------------------------#
        #   一共五个分支
        # -----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # -----------------------------------------#
        #   第五个分支,全局平均池化+卷积
        # -----------------------------------------#
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)

        # -----------------------------------------#
        #   将五个分支的内容堆叠起来
        #   然后1x1卷积整合特征。
        # -----------------------------------------#
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
相关推荐
G_H_S_3_5 分钟前
【网络运维】Linux:正则表达式
linux·运维·网络·正则表达式
腾科张老师2 小时前
OSPF 典型组网
网络·智能路由器
2301_801673017 小时前
8.19笔记
网络·安全
三坛海会大神55511 小时前
计算机网络参考模型与子网划分
网络·计算机网络
云卓SKYDROID11 小时前
无人机激光测距技术应用与挑战
网络·无人机·吊舱·高科技·云卓科技
iナナ17 小时前
传输层协议——UDP和TCP
网络·网络协议·tcp/ip·udp
华强笔记18 小时前
Linux内存管理系统性总结
linux·运维·网络
iY_n20 小时前
Linux网络基础
linux·网络·arm开发
EggrollOrz20 小时前
网络编程day3
网络
想睡hhh21 小时前
网络基础——Socket编程预备
网络