Pytorch实现轻量去雾网络

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

基于特征增强的双重注意力去雾网络

介绍:

创新点

网络结构

Ghost模块

RFB结构

双重注意力

残差组模块和SOS增强模块

密集融合模块

结果展示

数据集与实验设置


本文所有资源均可在该地址处获取。

基于特征增强的双重注意力去雾网络

介绍:

本文复现了一个轻量级的图像去雾网络,不需要租赁服务器,使用自己的电脑就可以完成模型训练及测试。该网络基于编解码器建立,在编解码部分使用特征增强以改善整体的恢复效果。

创新点

1、该网络基于编解码结构建立,在网络高维部分结合通道和空间注意力建立双重注意力特征增强模块。
2、该网络为了实现模型轻量化,引入Ghost模块替代非线性卷积;设计RFB(Receptive Field Block)模块以扩大感受野,实现对不同尺度特征的充分融合
3、实验表明,该网络对均匀雾上的复原效果较好

网络结构

网络由编码器、双重注意力特征增强模块和解码器三个部分组成。

如图所示,编码器由步长为1的浅层卷积和四个不同尺度大小的特征提取模块组成。

解码器由四个对应尺度大小的特征复原模块和一个步长为1的尾卷积组成。

网络高维部分为双重注意力特征增强模块,该模块由三个模块串联而成,分别为Ghost模块、RFB模块和双重注意力模块。

之后将详细介绍每个模块的结构。

Ghost模块

这个模块是GhostNet: More Features from Cheap Operations这篇论文提出的,被这篇论文引用用来实现网络的轻量级。Ghost的想法很简单,因为存在特征图相似度很高的情况,所以Ghost的作者认为这种相似的特征图是可以利用其中一张进行简单的线性运算来得到的,这样就可以用更少的参数获得更多的特征图。

Ghost的结构图看起来复杂,但是整体结构很好理解,如图所示,整体利用普通的逐点卷积(1×1)压缩通道数,之后再利用逐层卷积(3×3)来获得相似的特征图,最后将所有的结果堆叠得到输出特征图。

代码实现如下,需要注意通道数在过程中的变化,设置好合适的卷积核大小。

class GhostModule(nn.Module):
    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super(GhostModule, self).__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)  # ratio = oup / intrinsic
        new_channels = init_channels * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, padding=dw_size // 2, groups=init_channels, bias=False),
            # groups 分组卷积
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Sequential(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.oup, :, :]

RFB结构

本文引入该结构用于扩大网络的感受野,该结构有三个并行分支,结构如图所示:

从结构图中很清晰的可以看出,三个分支的卷积尺度不一样,进而三个分支的作用也不一样。第一个分支叠加一个1×1卷积和两个3×3卷积,并使用空洞率为5的3×3卷积(空洞率使用参数dilation=5来实现)以扩大采样范围,获取整体的图像特征。第二个分支堆叠一个1×1卷积和一个3×3卷积,并使用空洞率为3的3×3卷积来进一步扩大采样范围,以提取常规大小的图像特征。第三个分支仅叠加一个1×1卷积核一个空洞率为1的3×3卷积来获得图像的细节特征。最后,将三个分支的结果拼接以获得图像的多尺度特征,进而保证特征的完整性。

实现代码如下:

class RFB(nn.Module):
    def __init__(self, channel):
        super(RFB, self).__init__()
        self.conv1 = nn.Conv2d(channel, channel, 1, 1)
        self.conv13 = nn.Conv2d(channel*3, channel, 1, 1)
        self.conv3 = nn.Conv2d(channel, channel, 3, 1)
        self.convp5 = nn.Conv2d(channel, channel, 3, 1, padding=5, dilation=5)  # 膨胀卷积
        self.convp3 = nn.Conv2d(channel, channel, 3, 1, padding=3, dilation=5)
        self.convp1 = nn.Conv2d(channel, channel, 3, 1, padding=1, dilation=1)

    def forward(self, x):
        feature1 = x
        feature2 = x
        feature3 = x

        feature1 = self.conv1(feature1)
        feature1 = self.conv3(feature1)
        feature1 = self.conv3(feature1)
        feature1 = self.convp5(feature1)

        feature2 = self.conv1(feature2)
        feature2 = self.conv3(feature2)
        feature2 = self.convp3(feature2)

        feature3 = self.conv1(feature3)
        feature3 = self.convp1(feature3)

        feature1 = F.interpolate(feature1, size=(16, 16), mode='bilinear', align_corners=True)
        feature2 = F.interpolate(feature2, size=(16, 16), mode='bilinear', align_corners=True)
        # print(feature1.size())
        # print(feature2.size())
        # print(feature3.size())
        feature = torch.cat((feature1, feature2, feature3), 1)
        # print(feature.size())
        feature = self.conv13(feature)

        x = self.conv1(x)
        x = x + feature
        return x

双重注意力

地址类似,本文受CBAM启发建立了双重注意力模块,结构如图所示:

可以很清晰的看到,结构大体与CBAM相似,该模块将通道注意力和空间注意力串联,来分别实现通道和空间维度的注意力权值,实现代码也跟CBAM类似,具体如下:

class Dual_Attention(nn.Module):
    def __init__(self, channel, r=16):
        super(Dual_Attention, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.net = nn.Sequential(
            nn.Conv2d(channel, channel // r, 1),
            nn.ReLU(),
            nn.Conv2d(channel // r, channel, 1)
        )
        self.Sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(2, 1, 7, padding=7 // 2, bias=False)

    def forward(self, x):
        x1 = self.maxpool(x)
        x2 = self.avgpool(x)

        x1 = self.net(x1)
        x2 = self.net(x2)

        x3 = x1 + x2
        x3 = self.Sigmoid(x3)
        x3 = x3 * x

        avg_out = torch.mean(x3, dim=1, keepdim=True)
        max_out, _ = torch.max(x3, dim=1, keepdim=True)  # 就是返回两个值,下划线表示被忽略的值
        out = torch.cat((avg_out, max_out), 1)

        out = self.conv(out)
        out = self.Sigmoid(out)
        out = out*x3

        return out

残差组模块和SOS增强模块

残差组模块和SOS增强模块的结构也很轻量,编码器使用残差组模块,解码器使用SOS增强模块,结构如图所示。

残差组模块使用三个bottleneck结构组成,以防止梯度爆炸。SOS增强模块把之前的特征作为输入,以防止细节的丢失,并逐步细化特征。残差组的代码如下所示。

class Residual(nn.Module):
    def __init__(self, inchannels, outchannels):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(inchannels, outchannels, 1)
        self.conv2 = nn.Conv2d(inchannels, inchannels, 1, padding=1)
        self.conv3 = nn.Conv2d(outchannels, inchannels, 3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # print(x.size())
        x1 = self.conv1(x)
        x1 = self.relu(x1)
        # print(x1.size())
        x1 = self.conv3(x1)
        x1 = self.relu(x1)
        # print(x1.size())
        x1 = self.conv2(x1)
        # print(x1.size())

        x2 = x1 + x
        x2 = self.relu(x2)
        # print(x2.size())

        x3 = self.conv1(x2)
        x3 = self.relu(x3)
        x3 = self.conv3(x3)
        x3 = self.relu(x3)
        x3 = self.conv2(x3)

        x4 = x2 + x3
        x4 = self.relu(x4)

        x5 = self.conv1(x4)
        x5 = self.relu(x5)
        x5 = self.conv3(x5)
        x5 = self.relu(x5)
        x5 = self.conv2(x5)

        out = x + x4 + x5

        return out

密集融合模块

编解码网络会在图像传输过程中丢失细节特征,然而本文在编码器和解码器部分使用密集融合模块来弥补非相邻层之间缺少连接的问题,进而改善细节丢失问题。这个模块很有意思,接下来将用图的方式详细讲。官方给出的图是下面这样的:

比较难懂,但是结合整体的结构图来理解,其实就是对每一个不同尺度的特征图进行处理。具体来说,假设输入图像的通道数为3,那么经过一次浅层卷积之后通道数变为16,此时输入到第一层编码器模块中,得到通道数为32大小的特征图输入到第一个密集融合模块中,该模块的结构如图:

进一步输入到第二层编码器模块中,得到通道数为64大小的特征图输入到第二个密集融合模块中,该模块的结构如图:

进一步输入到第三层编码器模块中,得到通道数大小为128的特征图,输入到第三个密集融合模块中,该模块的结构如图:

之后,输入到最后一层解码器模块中,得到通道数大小为256的特征图,输入到第四个密集融合模块中,该模块的结构如图:

相同,与编码器的结构类似,解码器中的密集融合模块也如此,只有维度大小是从256到32。通道数为256的特征图输入到解码器中,第一层解码器将其变为128大小的,输入到解码器中第一个密集融合模块中,该模块的结构如图所示:

之后,经过第二层解码器变为通道数为64的特征图,输入到解码器中第二个密集融合模块中,该模块的结构如图所示:

接下来,将该模块的输出输入到第三层解码器中,得到通道数为32的特征图,将其输入到解码器中的第三个密集融合模块中,该模块的结构如图所示:

最后,将该模块的输出输入到最后一层解码器中,得到通道数为16的特征图,将其输入到解码器中的最后一个密集融合模块中,该模块的结构如图所示:

结果展示

数据集与实验设置

基于Pytorch框架复现,使用显存大小为2GB的GeForce MX250显卡在O-Haze数据集上跑的模型,结果如下,第一张为原始模糊图像,第二张为该方法恢复的图像,第三章为原始清晰图像。可以看到,模型效果仍有待提升。

有雾图像 复原后 无雾图像

​​

相关推荐
Srlua7 分钟前
辅助任务改进社交帖子多模态分类
人工智能·python
兔子的洋葱圈8 分钟前
Python的3D可视化库【vedo】2-5 (plotter模块) 坐标转换、场景导出、添加控件
python·3d·数据可视化
drebander16 分钟前
基于 Python 将 PDF 转 Markdown 并拆解为 JSON,支持自定义标题处理
python·pdf·json
L_cl26 分钟前
【NLP 15、深度学习处理文本】
人工智能·深度学习
2401_871151071 小时前
十二月第14讲:使用Python实现两组数据纵向排序
开发语言·python·算法
知新_ROL1 小时前
通过解调使用正则化相位跟踪技术进行相位解包裹
人工智能·算法·机器学习
一位小说男主1 小时前
可解释性方法:从理论到实践的深度剖析(续上文)
人工智能·深度学习·机器学习
Cachel wood1 小时前
Vue.js前端框架教程5:Vue数据拷贝和数组函数
linux·前端·vue.js·python·阿里云·前端框架·云计算
martian6651 小时前
深入详解神经网络基础知识——理解前馈神经网络( FNN)、卷积神经网络(CNN)和循环神经网络(RNN)等概念及应用
人工智能·深度学习·神经网络
程序猿人大林1 小时前
C# opencvsharp 流程化-脚本化-(2)ROI
人工智能·计算机视觉·c#