DnCNN 介绍及基于Pytorch复现

文章目录

  • [DnCNN 介绍及基于Pytorch复现](#DnCNN 介绍及基于Pytorch复现)
  • [1. DnCNN 论文阅读](#1. DnCNN 论文阅读)
    • [1.1 核心思想:残差学习](#1.1 核心思想:残差学习)
    • [1.2 Batch Normalization (BN)](#1.2 Batch Normalization (BN))
    • [1.3 网络架构](#1.3 网络架构)
  • [2. DnCNN 基于Pytorch复现](#2. DnCNN 基于Pytorch复现)
    • [2.1 项目代码说明:](#2.1 项目代码说明:)
    • [2.2 DnCNN 网络结构和损失函数](#2.2 DnCNN 网络结构和损失函数)
    • [2.3 高版本Pytorch运行过程报错修正](#2.3 高版本Pytorch运行过程报错修正)
      • [2.3.1 compare_psnr, compare_ssim报错:修改方式如下](#2.3.1 compare_psnr, compare_ssim报错:修改方式如下)
      • [2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改](#2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改)

DnCNN 介绍及基于Pytorch复现

1. DnCNN 论文阅读

论文地址Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising

GitHub 源码DnCNN GitHub Repository 由论文作者提供的官方实现

《Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising》是2017年由Xie, Zoran, and Wu等人提出的一篇论文,主要探索了使用深度卷积神经网络(CNN)来进行图像去噪的方法。与传统的高斯噪声去噪方法相比,提出的方法不仅仅依赖于高斯噪声模型,而是采用了残差学习来改进去噪性能。

提出了卷积神经网络结合残差学习来进行图像降噪,直接学习图像噪声,可以更好的降噪。

  • 强调了residual learning(残差学习)和batch normalization(批量标准化)在图像复原中相辅相成的作用,可以在较深的网络的条件下,依然能带来快的收敛和好的性能。
  • 文章提出DnCNN,在高斯去噪问题下,用单模型应对不同程度的高斯噪音;甚至可以用单模型应对高斯去噪、超分辨率、JPEG去锁三个领域的问题。

1.1 核心思想:残差学习

残差学习 是本论文的核心创新。与直接学习从含噪图像到干净图像的映射不同,该方法通过学习图像噪声的残差(即噪声部分与干净图像的差值)来进行去噪。这一思想的优势在于,噪声部分通常较为简单,且不需要网络学习整个图像的高维特征。

数学表达

假设输入图像 y = x + n { y = x + n } y=x+n,其中:

  • y { y } y是含噪图像。
  • x { x } x是干净图像。
  • n { n } n是噪声。

传统方法的目标是从 y { y } y恢复出 x { x } x,即:
x ^ = f ( y ) {\hat{x} = f(y)} x^=f(y)

其中 f { f } f是去噪网络。

而残差学习方法则先通过CNN学习噪声部分 n { n } n的残差:
n ^ = g ( y ) {\hat{n} = g(y)} n^=g(y)

然后从含噪图像中减去预测的残差,得到去噪图像:
x ^ = y − n ^ {\hat{x} = y - \hat{n}} x^=y−n^

通过这种方式,网络可以专注于预测噪声,通常比直接恢复干净图像更有效。

1.2 Batch Normalization (BN)

内部协变量移位(internal covariate shift):深层神经网络在做非线性变换前的激活输入值,随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

Batch Normalization (BN) 是一种用于加速神经网络训练并改善性能的技术。BN 在每一层的输入上应用标准化操作,确保其均值为零、方差为一。标准化的目的是缓解 internal covariate shift 内部协变量偏移(即每层输入分布随训练过程而变化),从而加速收敛并提高模型稳定性。

具体而言,BN 对输入数据 x {x } x 进行标准化操作,将其转化为一个均值为 0、方差为 1 的数据:
x ^ = x − μ σ {\hat{x} = \frac{x - \mu}{\sigma}} x^=σx−μ

其中:

  • μ {\mu} μ和 σ { \sigma } σ分别是批量数据的均值和标准差。

BN 还引入了可训练的尺度( γ {\gamma } γ)和偏移( β { \beta } β)参数,帮助恢复标准化后数据的原始分布:
y = γ x ^ + β {y = {\gamma} {\hat{x} }+ {\beta}} y=γx^+β

其中, γ {\gamma } γ 和 β { \beta } β 是可学习的参数,允许网络通过训练自动调节。

DnCNN 中的 Batch Normalization

在 DnCNN 中,Batch Normalization 的应用是该网络的一个重要组成部分。DnCNN 通过多个卷积层和 Batch Normalization 层来学习图像中的噪声,并通过残差学习来预测噪声的残差。具体地,Batch Normalization 被用于以下几个方面:

  1. 加速训练过程

由于 BN 能够保持每一层输入的均值和方差稳定,网络的训练过程变得更加平滑。这样,DNN 不需要过多地依赖于较低的学习率,且能够通过较大的学习率快速收敛。对于图像去噪任务,快速的收敛速度和更少的训练时间是非常重要的,尤其是在大规模数据集上。

  1. 缓解梯度消失/爆炸问题

深层神经网络训练中的常见问题是梯度消失或梯度爆炸,特别是在较深的网络中。

Batch Normalization 在每一层的输入上进行标准化,能够有效防止梯度消失或爆炸现象,使得网络能够稳定训练。DNN 通过 BN 获得更为稳定的激活分布,从而加快训练速度。

  1. 增强模型的鲁棒性

通过标准化每一层的输入,BN 改变了网络的训练动态,使得网络不容易受到初始化参数或输入数据分布变化的影响。这意味着即使网络初始参数随机化,或输入数据分布发生变化,模型依然能够稳定地训练。对于去噪任务,图像的不同区域或噪声类型可能会存在较大的差异,BN 能够提高网络对这些变化的鲁棒性。

  1. 减少对权重初始化的依赖

Batch Normalization 使得网络的训练过程不那么依赖于精心设计的权重初始化。由于 BN 能够在每一层进行标准化,它减轻了网络对初始化参数的敏感性。使用 BN 后,可以通过较大的学习率进行训练,从而更好地利用网络的表现。

  1. 改进的去噪效果

在 DnCNN 中,网络的目标是通过学习输入图像和噪声之间的残差来恢复干净图像。使用 Batch Normalization 有助于更好地学习噪声的特征,并提高网络对复杂噪声的去除能力。BN 帮助每一层输入数据保持稳定,从而提高了去噪网络在高噪声环境中的鲁棒性。

1.3 网络架构

在网络架构设计中,作者对VGG网络进行了修改,使其适合于图像去噪,并根据最先进的去噪方法中使用的有效补丁大小来设置网络的深度。在模型学习中,采用残差学习公式,并将其纳入批归一化,以快速训练和提高去噪性能。

假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

对于具有一定噪声水平的高斯去噪,将DnCNN的接受场大小设置为35×35,相应的深度为17。对于其他一般的图像去噪任务,采用一个更大的接受域,并将深度设置为20。

第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)

第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu

第三部分:Conv(3 * 3 * 64)

每一层都zero padding,使得每一层的输入、输出尺寸保持一致。以此防止产生人工边界(boundary artifacts)。第二部分每一层在卷积与reLU之间都加了批量标准化(batch normalization、BN)。

2. DnCNN 基于Pytorch复现

Pytorch实现代码DnCNN GitHub Pytorch
Pytorch实现 数据集DnCNN GitHub Pytorch dataset

2.1 项目代码说明:

  • data目录:存放训练集和测试集
  • models目录:存放训练的模型
  • results目录:存放处理后的图像
  • data_generator.py:数据预处理
  • main_test.py: 基于模型测试数据集,输出去噪结果并计算测试集PSNR和SSIM
  • main_train.py: 训练网络

运行方式:

1、放置好训练集和测试集

2、运行main_train.py进行训练(时间较长,请耐性等待)

3、运行main_test.py进行测试

2.2 DnCNN 网络结构和损失函数

第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)

第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu

第三部分:Conv(3 * 3 * 64)

python 复制代码
class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

损失函数计算:

假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

python 复制代码
class sum_squared_error(_Loss): 
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)

优化器:

由于Pytorch与Matlab不同,Pytorch版本中作者将SGD改为Adam,学习率调整为1e-3。

2.3 高版本Pytorch运行过程报错修正

2.3.1 compare_psnr, compare_ssim报错:修改方式如下

from skimage.metrics import structural_similarity as compare_ssim

from skimage.metrics import peak_signal_noise_ratio as compare_psnr

2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改

相关推荐
西西弗Sisyphus18 分钟前
全面掌握Python时间处理
python·time
java1234_小锋2 小时前
一周学会Flask3 Python Web开发-http响应状态码
python·flask·flask3
Elastic 中国社区官方博客3 小时前
Elasticsearch 混合搜索 - Hybrid Search
大数据·人工智能·elasticsearch·搜索引擎·ai·语言模型·全文检索
@心都3 小时前
机器学习数学基础:29.t检验
人工智能·机器学习
9命怪猫3 小时前
DeepSeek底层揭秘——微调
人工智能·深度学习·神经网络·ai·大模型
奔跑吧邓邓子3 小时前
【Python爬虫(12)】正则表达式:Python爬虫的进阶利刃
爬虫·python·正则表达式·进阶·高级
码界筑梦坊4 小时前
基于Flask的京东商品信息可视化分析系统的设计与实现
大数据·python·信息可视化·flask·毕业设计
pianmian14 小时前
python绘图之箱型图
python·信息可视化·数据分析
csbDD4 小时前
2025年网络安全(黑客技术)三个月自学手册
linux·网络·python·安全·web安全
kcarly5 小时前
KTransformers如何通过内核级优化、多GPU并行策略和稀疏注意力等技术显著加速大语言模型的推理速度?
人工智能·语言模型·自然语言处理