文章目录
- [DnCNN 介绍及基于Pytorch复现](#DnCNN 介绍及基于Pytorch复现)
- [1. DnCNN 论文阅读](#1. DnCNN 论文阅读)
-
- [1.1 核心思想:残差学习](#1.1 核心思想:残差学习)
- [1.2 Batch Normalization (BN)](#1.2 Batch Normalization (BN))
- [1.3 网络架构](#1.3 网络架构)
- [2. DnCNN 基于Pytorch复现](#2. DnCNN 基于Pytorch复现)
-
- [2.1 项目代码说明:](#2.1 项目代码说明:)
- [2.2 DnCNN 网络结构和损失函数](#2.2 DnCNN 网络结构和损失函数)
- [2.3 高版本Pytorch运行过程报错修正](#2.3 高版本Pytorch运行过程报错修正)
-
- [2.3.1 compare_psnr, compare_ssim报错:修改方式如下](#2.3.1 compare_psnr, compare_ssim报错:修改方式如下)
- [2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改](#2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改)
DnCNN 介绍及基于Pytorch复现
1. DnCNN 论文阅读
论文地址 :Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
GitHub 源码 :DnCNN GitHub Repository 由论文作者提供的官方实现
《Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising》是2017年由Xie, Zoran, and Wu等人提出的一篇论文,主要探索了使用深度卷积神经网络(CNN)来进行图像去噪的方法。与传统的高斯噪声去噪方法相比,提出的方法不仅仅依赖于高斯噪声模型,而是采用了残差学习来改进去噪性能。
提出了卷积神经网络结合残差学习来进行图像降噪,直接学习图像噪声,可以更好的降噪。
- 强调了residual learning(残差学习)和batch normalization(批量标准化)在图像复原中相辅相成的作用,可以在较深的网络的条件下,依然能带来快的收敛和好的性能。
- 文章提出DnCNN,在高斯去噪问题下,用单模型应对不同程度的高斯噪音;甚至可以用单模型应对高斯去噪、超分辨率、JPEG去锁三个领域的问题。
1.1 核心思想:残差学习
残差学习 是本论文的核心创新。与直接学习从含噪图像到干净图像的映射不同,该方法通过学习图像噪声的残差(即噪声部分与干净图像的差值)来进行去噪。这一思想的优势在于,噪声部分通常较为简单,且不需要网络学习整个图像的高维特征。
数学表达 :
假设输入图像 y = x + n { y = x + n } y=x+n,其中:
- y { y } y是含噪图像。
- x { x } x是干净图像。
- n { n } n是噪声。
传统方法的目标是从 y { y } y恢复出 x { x } x,即:
x ^ = f ( y ) {\hat{x} = f(y)} x^=f(y)
其中 f { f } f是去噪网络。
而残差学习方法则先通过CNN学习噪声部分 n { n } n的残差:
n ^ = g ( y ) {\hat{n} = g(y)} n^=g(y)
然后从含噪图像中减去预测的残差,得到去噪图像:
x ^ = y − n ^ {\hat{x} = y - \hat{n}} x^=y−n^
通过这种方式,网络可以专注于预测噪声,通常比直接恢复干净图像更有效。
1.2 Batch Normalization (BN)
内部协变量移位(internal covariate shift):深层神经网络在做非线性变换前的激活输入值,随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。
Batch Normalization (BN) 是一种用于加速神经网络训练并改善性能的技术。BN 在每一层的输入上应用标准化操作,确保其均值为零、方差为一。标准化的目的是缓解 internal covariate shift 内部协变量偏移(即每层输入分布随训练过程而变化),从而加速收敛并提高模型稳定性。
具体而言,BN 对输入数据 x {x } x 进行标准化操作,将其转化为一个均值为 0、方差为 1 的数据:
x ^ = x − μ σ {\hat{x} = \frac{x - \mu}{\sigma}} x^=σx−μ
其中:
- μ {\mu} μ和 σ { \sigma } σ分别是批量数据的均值和标准差。
BN 还引入了可训练的尺度( γ {\gamma } γ)和偏移( β { \beta } β)参数,帮助恢复标准化后数据的原始分布:
y = γ x ^ + β {y = {\gamma} {\hat{x} }+ {\beta}} y=γx^+β
其中, γ {\gamma } γ 和 β { \beta } β 是可学习的参数,允许网络通过训练自动调节。
DnCNN 中的 Batch Normalization
在 DnCNN 中,Batch Normalization 的应用是该网络的一个重要组成部分。DnCNN 通过多个卷积层和 Batch Normalization 层来学习图像中的噪声,并通过残差学习来预测噪声的残差。具体地,Batch Normalization 被用于以下几个方面:
- 加速训练过程
由于 BN 能够保持每一层输入的均值和方差稳定,网络的训练过程变得更加平滑。这样,DNN 不需要过多地依赖于较低的学习率,且能够通过较大的学习率快速收敛。对于图像去噪任务,快速的收敛速度和更少的训练时间是非常重要的,尤其是在大规模数据集上。
- 缓解梯度消失/爆炸问题
深层神经网络训练中的常见问题是梯度消失或梯度爆炸,特别是在较深的网络中。
Batch Normalization 在每一层的输入上进行标准化,能够有效防止梯度消失或爆炸现象,使得网络能够稳定训练。DNN 通过 BN 获得更为稳定的激活分布,从而加快训练速度。
- 增强模型的鲁棒性
通过标准化每一层的输入,BN 改变了网络的训练动态,使得网络不容易受到初始化参数或输入数据分布变化的影响。这意味着即使网络初始参数随机化,或输入数据分布发生变化,模型依然能够稳定地训练。对于去噪任务,图像的不同区域或噪声类型可能会存在较大的差异,BN 能够提高网络对这些变化的鲁棒性。
- 减少对权重初始化的依赖
Batch Normalization 使得网络的训练过程不那么依赖于精心设计的权重初始化。由于 BN 能够在每一层进行标准化,它减轻了网络对初始化参数的敏感性。使用 BN 后,可以通过较大的学习率进行训练,从而更好地利用网络的表现。
- 改进的去噪效果
在 DnCNN 中,网络的目标是通过学习输入图像和噪声之间的残差来恢复干净图像。使用 Batch Normalization 有助于更好地学习噪声的特征,并提高网络对复杂噪声的去除能力。BN 帮助每一层输入数据保持稳定,从而提高了去噪网络在高噪声环境中的鲁棒性。
1.3 网络架构
在网络架构设计中,作者对VGG网络进行了修改,使其适合于图像去噪,并根据最先进的去噪方法中使用的有效补丁大小来设置网络的深度。在模型学习中,采用残差学习公式,并将其纳入批归一化,以快速训练和提高去噪性能。
假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

对于具有一定噪声水平的高斯去噪,将DnCNN的接受场大小设置为35×35,相应的深度为17。对于其他一般的图像去噪任务,采用一个更大的接受域,并将深度设置为20。
第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)
第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu
第三部分:Conv(3 * 3 * 64)
每一层都zero padding,使得每一层的输入、输出尺寸保持一致。以此防止产生人工边界(boundary artifacts)。第二部分每一层在卷积与reLU之间都加了批量标准化(batch normalization、BN)。
2. DnCNN 基于Pytorch复现
Pytorch实现代码 :DnCNN GitHub Pytorch
Pytorch实现 数据集 :DnCNN GitHub Pytorch dataset
2.1 项目代码说明:

- data目录:存放训练集和测试集
- models目录:存放训练的模型
- results目录:存放处理后的图像
- data_generator.py:数据预处理
- main_test.py: 基于模型测试数据集,输出去噪结果并计算测试集PSNR和SSIM
- main_train.py: 训练网络
运行方式:
1、放置好训练集和测试集
2、运行main_train.py进行训练(时间较长,请耐性等待)
3、运行main_test.py进行测试
2.2 DnCNN 网络结构和损失函数
第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)
第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu
第三部分:Conv(3 * 3 * 64)
python
class DnCNN(nn.Module):
def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
super(DnCNN, self).__init__()
kernel_size = 3
padding = 1
layers = []
layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
layers.append(nn.ReLU(inplace=True))
for _ in range(depth-2):
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
self._initialize_weights()
def forward(self, x):
y = x
out = self.dncnn(x)
return y-out
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.orthogonal_(m.weight)
print('init weight')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
损失函数计算:
假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

python
class sum_squared_error(_Loss):
"""
Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
The backward is defined as: input-target
"""
def __init__(self, size_average=None, reduce=None, reduction='sum'):
super(sum_squared_error, self).__init__(size_average, reduce, reduction)
def forward(self, input, target):
# return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)
优化器:
由于Pytorch与Matlab不同,Pytorch版本中作者将SGD改为Adam,学习率调整为1e-3。
2.3 高版本Pytorch运行过程报错修正
2.3.1 compare_psnr, compare_ssim报错:修改方式如下
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改
