【图像去噪】论文复现:代替ReLU!Pytorch实现即插即用激活函数模块xUnit,并插入到DnCNN中实现xDnCNN!

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中)

本文亮点:

  • 实现三种xUnit模块:xUnit(论文默认)、xUnitS(轻量化)、xUnitD(密集型),随取随用
  • xUnit模块可以加入到任意去噪模型 中,替代ReLU激活函数
  • 测试结果与论文中所述观点基本一致

文章目录


前言

论文题目:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration ------ xUnit:学习空间激活函数进行高效图像恢复

论文地址:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration

论文源码:https://github.com/kligvasser/xUnit

对应的论文精读:【图像去噪】论文精读:xUnit: Learning a Spatial Activation Function for Efficient Image Restoration

只需要源码中的xUnit结构实现,并不需要其他的。本文将xUnit模块插入到DnCNN中实现xDnCNN。

一、xUnit结构实现

xUnit结构回顾:

  • xUnit默认结构:BN+RL+CD+BN+GS
  • xUnit轻量结构:CD+BN+GS
  • xUnit密集结构:CD+BN+RL+CD+BN+GS

代码实现如下,命名为activations.py

python 复制代码
import torch.nn as nn

class xUnit(nn.Module):
    def __init__(self, num_features=64, kernel_size=9, batch_norm=False):
        super(xUnit, self).__init__()
        # xUnit
        self.features = nn.Sequential(
            nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity,
            nn.ReLU(),
            nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
            nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity,
            nn.Sigmoid()
        )

    def forward(self, x):
        a = self.features(x)
        r = x * a
        return r

class xUnitS(nn.Module):
    def __init__(self, num_features=64, kernel_size=9, batch_norm=False):
        super(xUnitS, self).__init__()
        # slim xUnit
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
            nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
            nn.Sigmoid()
        )

    def forward(self, x):
        a = self.features(x)
        r = x * a
        return r

class xUnitD(nn.Module):
    def __init__(self, num_features=64, kernel_size=9, batch_norm=False):
        super(xUnitD, self).__init__()
        # dense xUnit
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=1, padding=0),
            nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_features, out_channels=num_features, kernel_size=kernel_size, padding=(kernel_size // 2), groups=num_features),
            nn.BatchNorm2d(num_features=num_features) if batch_norm else Identity(),
            nn.Sigmoid()
        )

    def forward(self, x):
        a = self.features(x)
        r = x * a
        return r

class Identity(nn.Module):
    def __init__(self,):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

二、xDnCNN结构实现

先回顾以下DnCNN的网络结构:

代码实现如下:

python 复制代码
class DnCNN(nn.Module):
    def __init__(self, num_layers=17, num_features=64):
        super(DnCNN, self).__init__()
        layers = [nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1),
                                nn.ReLU(inplace=True))]
        for i in range(num_layers - 2):
            layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                                        nn.BatchNorm2d(num_features),
                                        nn.ReLU(inplace=True)))
        layers.append(nn.Conv2d(num_features, 3, kernel_size=3, padding=1))
        self.layers = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, inputs):
        y = inputs
        residual = self.layers(y)
        return y - residual

上面是一个17层的DnCNN实现,使用xUnit代替DnCNN中的ReLU,同时减少卷积层数为9层,称作xDnCNN。

实现如下:

python 复制代码
from torch import nn
from activations import xUnit, xUnitD, xUnitS

class xDnCNN(nn.Module):
    def __init__(self, num_layers=9, num_features=64):
        super(xDnCNN, self).__init__()
        layers = [nn.Sequential(nn.Conv2d(3, num_features, kernel_size=3, stride=1, padding=1),
                                xUnit(num_features, batch_norm=True))]
        for i in range(num_layers - 2):
            layers.append(nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
                                        nn.BatchNorm2d(num_features),
                                        xUnit(num_features, batch_norm=True)))
        layers.append(nn.Conv2d(num_features, 3, kernel_size=3, padding=1))
        self.layers = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, inputs):
        y = inputs
        residual = self.layers(y)
        return y - residual

至此,xUnit模块嵌入了模型中,除了DnCNN,其他有ReLU激活的模型都可以以此方法替代。

DnCNN于xDnCNN结构对比:

  • DnCNN:[Conv+ReLU]+15个[Conv+BN+ReLU]+Conv,共17个卷积层,16个ReLU,15个BN
  • xDnCNN:[Conv+xUnit(BN+ReLU+Conv+BN+Sigmoid)]+7个[Conv+BN+xUnit(BN+ReLU+Conv+BN+Sigmoid)] + Conv,共17个卷积层,8个ReLU,23个BN

区别为:xDnCNN少了8个ReLU,多了8个BN,并且xUnit中的Conv卷积核为9×9,而其他均为3×3。

虽然模型性能区别不能简单地以模块数量多少而论,但也能从中发现一些端倪。

  • 卷积层个数相同,卷积核越大,参数不一定越多。也受整体层数影响,卷积核更大,整体层数更少,虽然在单层的参数量更多,但总体的参数量更少。
  • 本质上,是增大感受野以增强特征提取能力(论文图5、图9)。只是套了这么一个xUnit的壳,实际上就是改变结构,只不过把这一堆统称为xUnit。
  • 给我们的启示:尝试把一堆组件绑在一起作为一个整体,调整其中的某个参数(e.g.卷积核,就可以减少整体层数了),看看能不能有所提升。

xUnit的作用:减少模型参数、性能几乎不变、纹理细节提升!

三、结果展示

性能对比:

|------------|--------------|--------|
| Methods | DnCNN-S | xDnCNN |
| parameters | 559363(559K) | 29.08 |
| σ=50 | 306947(307K) | 29.03 |

视觉展示(论文图7):


至此本文结束。

如果本文对你有所帮助,请点赞收藏,创作不易,感谢您的支持!

相关推荐
m0_743106461 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106461 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
数据小爬虫@2 小时前
深入解析:使用 Python 爬虫获取苏宁商品详情
开发语言·爬虫·python
健胃消食片片片片2 小时前
Python爬虫技术:高效数据收集与深度挖掘
开发语言·爬虫·python
井底哇哇4 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控5 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
ℳ₯㎕ddzོꦿ࿐5 小时前
解决Python 在 Flask 开发模式下定时任务启动两次的问题
开发语言·python·flask
CodeClimb5 小时前
【华为OD-E卷 - 第k个排列 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od