DnCNN 介绍及基于Pytorch复现

文章目录

  • [DnCNN 介绍及基于Pytorch复现](#DnCNN 介绍及基于Pytorch复现)
  • [1. DnCNN 论文阅读](#1. DnCNN 论文阅读)
    • [1.1 核心思想:残差学习](#1.1 核心思想:残差学习)
    • [1.2 Batch Normalization (BN)](#1.2 Batch Normalization (BN))
    • [1.3 网络架构](#1.3 网络架构)
  • [2. DnCNN 基于Pytorch复现](#2. DnCNN 基于Pytorch复现)
    • [2.1 项目代码说明:](#2.1 项目代码说明:)
    • [2.2 DnCNN 网络结构和损失函数](#2.2 DnCNN 网络结构和损失函数)
    • [2.3 高版本Pytorch运行过程报错修正](#2.3 高版本Pytorch运行过程报错修正)
      • [2.3.1 compare_psnr, compare_ssim报错:修改方式如下](#2.3.1 compare_psnr, compare_ssim报错:修改方式如下)
      • [2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改](#2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改)

DnCNN 介绍及基于Pytorch复现

1. DnCNN 论文阅读

论文地址Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising

GitHub 源码DnCNN GitHub Repository 由论文作者提供的官方实现

《Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising》是2017年由Xie, Zoran, and Wu等人提出的一篇论文,主要探索了使用深度卷积神经网络(CNN)来进行图像去噪的方法。与传统的高斯噪声去噪方法相比,提出的方法不仅仅依赖于高斯噪声模型,而是采用了残差学习来改进去噪性能。

提出了卷积神经网络结合残差学习来进行图像降噪,直接学习图像噪声,可以更好的降噪。

  • 强调了residual learning(残差学习)和batch normalization(批量标准化)在图像复原中相辅相成的作用,可以在较深的网络的条件下,依然能带来快的收敛和好的性能。
  • 文章提出DnCNN,在高斯去噪问题下,用单模型应对不同程度的高斯噪音;甚至可以用单模型应对高斯去噪、超分辨率、JPEG去锁三个领域的问题。

1.1 核心思想:残差学习

残差学习 是本论文的核心创新。与直接学习从含噪图像到干净图像的映射不同,该方法通过学习图像噪声的残差(即噪声部分与干净图像的差值)来进行去噪。这一思想的优势在于,噪声部分通常较为简单,且不需要网络学习整个图像的高维特征。

数学表达

假设输入图像 y = x + n { y = x + n } y=x+n,其中:

  • y { y } y是含噪图像。
  • x { x } x是干净图像。
  • n { n } n是噪声。

传统方法的目标是从 y { y } y恢复出 x { x } x,即:
x ^ = f ( y ) {\hat{x} = f(y)} x^=f(y)

其中 f { f } f是去噪网络。

而残差学习方法则先通过CNN学习噪声部分 n { n } n的残差:
n ^ = g ( y ) {\hat{n} = g(y)} n^=g(y)

然后从含噪图像中减去预测的残差,得到去噪图像:
x ^ = y − n ^ {\hat{x} = y - \hat{n}} x^=y−n^

通过这种方式,网络可以专注于预测噪声,通常比直接恢复干净图像更有效。

1.2 Batch Normalization (BN)

内部协变量移位(internal covariate shift):深层神经网络在做非线性变换前的激活输入值,随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值WU+B是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。

Batch Normalization (BN) 是一种用于加速神经网络训练并改善性能的技术。BN 在每一层的输入上应用标准化操作,确保其均值为零、方差为一。标准化的目的是缓解 internal covariate shift 内部协变量偏移(即每层输入分布随训练过程而变化),从而加速收敛并提高模型稳定性。

具体而言,BN 对输入数据 x {x } x 进行标准化操作,将其转化为一个均值为 0、方差为 1 的数据:
x ^ = x − μ σ {\hat{x} = \frac{x - \mu}{\sigma}} x^=σx−μ

其中:

  • μ {\mu} μ和 σ { \sigma } σ分别是批量数据的均值和标准差。

BN 还引入了可训练的尺度( γ {\gamma } γ)和偏移( β { \beta } β)参数,帮助恢复标准化后数据的原始分布:
y = γ x ^ + β {y = {\gamma} {\hat{x} }+ {\beta}} y=γx^+β

其中, γ {\gamma } γ 和 β { \beta } β 是可学习的参数,允许网络通过训练自动调节。

DnCNN 中的 Batch Normalization

在 DnCNN 中,Batch Normalization 的应用是该网络的一个重要组成部分。DnCNN 通过多个卷积层和 Batch Normalization 层来学习图像中的噪声,并通过残差学习来预测噪声的残差。具体地,Batch Normalization 被用于以下几个方面:

  1. 加速训练过程

由于 BN 能够保持每一层输入的均值和方差稳定,网络的训练过程变得更加平滑。这样,DNN 不需要过多地依赖于较低的学习率,且能够通过较大的学习率快速收敛。对于图像去噪任务,快速的收敛速度和更少的训练时间是非常重要的,尤其是在大规模数据集上。

  1. 缓解梯度消失/爆炸问题

深层神经网络训练中的常见问题是梯度消失或梯度爆炸,特别是在较深的网络中。

Batch Normalization 在每一层的输入上进行标准化,能够有效防止梯度消失或爆炸现象,使得网络能够稳定训练。DNN 通过 BN 获得更为稳定的激活分布,从而加快训练速度。

  1. 增强模型的鲁棒性

通过标准化每一层的输入,BN 改变了网络的训练动态,使得网络不容易受到初始化参数或输入数据分布变化的影响。这意味着即使网络初始参数随机化,或输入数据分布发生变化,模型依然能够稳定地训练。对于去噪任务,图像的不同区域或噪声类型可能会存在较大的差异,BN 能够提高网络对这些变化的鲁棒性。

  1. 减少对权重初始化的依赖

Batch Normalization 使得网络的训练过程不那么依赖于精心设计的权重初始化。由于 BN 能够在每一层进行标准化,它减轻了网络对初始化参数的敏感性。使用 BN 后,可以通过较大的学习率进行训练,从而更好地利用网络的表现。

  1. 改进的去噪效果

在 DnCNN 中,网络的目标是通过学习输入图像和噪声之间的残差来恢复干净图像。使用 Batch Normalization 有助于更好地学习噪声的特征,并提高网络对复杂噪声的去除能力。BN 帮助每一层输入数据保持稳定,从而提高了去噪网络在高噪声环境中的鲁棒性。

1.3 网络架构

在网络架构设计中,作者对VGG网络进行了修改,使其适合于图像去噪,并根据最先进的去噪方法中使用的有效补丁大小来设置网络的深度。在模型学习中,采用残差学习公式,并将其纳入批归一化,以快速训练和提高去噪性能。

假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

对于具有一定噪声水平的高斯去噪,将DnCNN的接受场大小设置为35×35,相应的深度为17。对于其他一般的图像去噪任务,采用一个更大的接受域,并将深度设置为20。

第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)

第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu

第三部分:Conv(3 * 3 * 64)

每一层都zero padding,使得每一层的输入、输出尺寸保持一致。以此防止产生人工边界(boundary artifacts)。第二部分每一层在卷积与reLU之间都加了批量标准化(batch normalization、BN)。

2. DnCNN 基于Pytorch复现

Pytorch实现代码DnCNN GitHub Pytorch
Pytorch实现 数据集DnCNN GitHub Pytorch dataset

2.1 项目代码说明:

  • data目录:存放训练集和测试集
  • models目录:存放训练的模型
  • results目录:存放处理后的图像
  • data_generator.py:数据预处理
  • main_test.py: 基于模型测试数据集,输出去噪结果并计算测试集PSNR和SSIM
  • main_train.py: 训练网络

运行方式:

1、放置好训练集和测试集

2、运行main_train.py进行训练(时间较长,请耐性等待)

3、运行main_test.py进行测试

2.2 DnCNN 网络结构和损失函数

第一部分:Conv(3 * 3 * c * 64)+ReLu (c代表图片通道数)

第二部分:Conv(3 * 3 * 64 * 64)+BN(batch normalization)+ReLu

第三部分:Conv(3 * 3 * 64)

python 复制代码
class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

损失函数计算:

假设DnCNN的输入是噪声观测 y = x + n { y = x + n } y=x+n,采用残差学习公式来训练残差映射 R ( y ) = n {R(y)=n} R(y)=n,则 x = y − R ( y ) {x=y-R(y)} x=y−R(y);期望残差图像与噪声输入的估计残差之间的平均均方误差:

python 复制代码
class sum_squared_error(_Loss): 
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)

优化器:

由于Pytorch与Matlab不同,Pytorch版本中作者将SGD改为Adam,学习率调整为1e-3。

2.3 高版本Pytorch运行过程报错修正

2.3.1 compare_psnr, compare_ssim报错:修改方式如下

from skimage.metrics import structural_similarity as compare_ssim

from skimage.metrics import peak_signal_noise_ratio as compare_psnr

2.3.2 源码报错图像报错 OSError: cannot write mode F as PNG修改

相关推荐
人工智能转人机6 分钟前
16day-人工智能-机器学习-特征工程
人工智能·学习·机器学习·特征工程
Python×CATIA工业智造9 分钟前
Pycaita二次开发基础代码解析:参数化模板创建与设计表驱动建模
python·pycharm·pycatia
这张生成的图像能检测吗10 分钟前
(论文速读)探索多模式大型语言模型的视觉缺陷
人工智能·深度学习·算法·计算机视觉·语言模型·自然语言处理
白应穷奇14 分钟前
编写高性能数据处理代码 01
后端·python
小蜜蜂爱编程18 分钟前
opencv 阈值分割函数
人工智能·opencv·计算机视觉
机器之心25 分钟前
闹玩呢!首届大模型对抗赛,DeepSeek、Kimi第一轮被淘汰了
人工智能·openai
新智元30 分钟前
Claude Opus 4.1 代码实测惊人!OpenAI 开源模型却只会写屎山?
人工智能·openai
攻城狮7号32 分钟前
GPT-5的诞生之痛:AI帝国的现实危机
人工智能·深度学习·openai·gpt-5·sam altman
新智元34 分钟前
奥特曼深夜官宣:OpenAI 重回开源!两大推理模型追平 o4-mini,号称世界最强
人工智能·openai
稚肩37 分钟前
最优化中常见的优化理论
人工智能