Python图像处理【22】基于卷积神经网络的图像去雾

基于卷积神经网络的图像去雾

    • [0. 前言](#0. 前言)
    • [1. 渐进特征融合网络](#1. 渐进特征融合网络)
    • [2. 图像去雾](#2. 图像去雾)
      • [2.1 网络构建](#2.1 网络构建)
      • [2.2 模型测试](#2.2 模型测试)
    • 小结
    • 系列链接

0. 前言

单图像去雾 (dehazing) 是一个具有挑战性的图像恢复问题。为了解决这个问题,大多数算法都采用经典的大气散射模型,该模型是一种基于单一散射和均匀大气介质假设的简化物理模型,但现实环境中的雾霾表述更加复杂。

1. 渐进特征融合网络

在本节中,我们将学习如何使用输入自适应端到端深度学习预训练去雾模型,即渐进特征融合网络 (Progressive Feature Fusion Network, PFFNet),并通过使用 Pytorch 来执行模糊图像的去雾操作。渐进特征融合所采用的 U-Net 架构编码器 - 解码器网络,可直接学习从模糊图像到清晰图像的高度非线性转换函数。深度神经网架构如下图所示:

从以上体系结构图可以看出:

  • 编码器由五个卷积层组成,每个卷积层之后都有非线性 ReLU 激活函数;第一层用于从原始模糊图像中相对较大的局部感受野上的提取特征,然后,依次执行四次下采样卷积操作,以获取图像金字塔
  • 特征转换模块由基于残差的模块组成,深层网络可以表示非常复杂的特征,也可以学习到许多不同尺度的特征,但同时,在使用反向传播进行训练时,经常会遇到消失的梯度问题,而残差网络就是为了解决这一问题而被提出的,可以用于训练更深的网络
  • 解码器由四个反卷积层和一个卷积层组成,与编码器相反,解码器的反卷积层顺序堆叠以恢复图像结构细节

2. 图像去雾

2.1 网络构建

(1) 首先下载预训练网络模型,并导入所需的库,模块和函数:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
#from torchviz import make_dot
import matplotlib.pylab as plt 

(2) 定义与深神经网络中不同层相对应的 ConvLayerUpsampleConvLayer 类,所有网络层都继承自 Pytorchnn.module 类;每个层都需要实现自己的 init() (用于初始化参数/成员变量/层)和 forward() 方法(定义前向传播过程中的计算):

python 复制代码
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(UpsampleConvLayer, self).__init__()
      reflection_padding = kernel_size // 2
      self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
      self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

(3) 接下来,我们用两个 ConvLayer 类实例定义类 ResidualBlock,在 ConvLayer 类实例之间使用 PReLU 激活函数,该类同样继承自 nn.module,并定义 forward() 方法用于前向传播:

python 复制代码
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out 

(4) 定义继承自 nn.conv2d 类的 MeanShift 类,通过将 requires_grad 的参数设置为 False,冻结 MeanShift 层:

python 复制代码
class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, sign):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.bias.data = float(sign) * torch.Tensor(rgb_mean) * rgb_range

        # Freeze the MeanShift layer
        for params in self.parameters():
            params.requires_grad = False

(5) 最后,根据所定义的神经网络层定义深度神经网络类 Net,该类同样需要定义 init() 方法。网络使用了五个 ConvLayer,然后使用四个 UPSampleconvLayer,最后通过 ConvLayer 层后输出,网络使用 LeakyReLU 作为激活函数。

同样,需要定义向前传播方法 forward(),并在每个激活函数后使用双线性上采样:

python 复制代码
class Net(nn.Module):
    def __init__(self, res_blocks=18):
        super(Net, self).__init__()

        rgb_mean = (0.5204, 0.5167, 0.5129)
        self.sub_mean = MeanShift(1., rgb_mean, -1)
        self.add_mean = MeanShift(1., rgb_mean, 1)

        self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
        self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)

        self.dehaze = nn.Sequential()
        for i in range(1, res_blocks):
            self.dehaze.add_module('res%d' % i, ResidualBlock(256))

        self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
        self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
        self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
        self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)

        self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)
()
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.relu(self.conv_input(x))
        res2x = self.relu(self.conv2x(x))
        res4x = self.relu(self.conv4x(res2x))

        res8x = self.relu(self.conv8x(res4x))
        res16x = self.relu(self.conv16x(res8x))

        res_dehaze = res16x
        res16x = self.dehaze(res16x)
        res16x = torch.add(res_dehaze, res16x)

        res16x = self.relu(self.convd16x(res16x))
        res16x = F.upsample(res16x, res8x.size()[2:], mode='bilinear')
        res8x = torch.add(res16x, res8x)

        res8x = self.relu(self.convd8x(res8x))
        res8x = F.upsample(res8x, res4x.size()[2:], mode='bilinear')
        res4x = torch.add(res8x, res4x)

        res4x = self.relu(self.convd4x(res4x))
        res4x = F.upsample(res4x, res2x.size()[2:], mode='bilinear')
        res2x = torch.add(res4x, res2x)

        res2x = self.relu(self.convd2x(res2x))
        res2x = F.upsample(res2x, x.size()[2:], mode='bilinear')
        x = torch.add(res2x, x)

        x = self.conv_output(x)

        return x

(6) 定义预训练模型参数位置以及模型使用的残差块数量:

python 复制代码
rb = 13
checkpoint = "I-HAZE_O-HAZE.pth"

(7) 实例化 Net() 类并使用 load_state_dict() 方法从检查点加载预训练权重。由于我们不需要训练模型,因此使用测试模式:

python 复制代码
net = Net(rb)
net.load_state_dict(torch.load(checkpoint)['state_dict'])
net.eval()

2.2 模型测试

(1) 接下来,使用 open() 函数读取输入图像:

python 复制代码
im_path = "pic.png"
im = Image.open(im_path)
h, w = im.size
print(h, w)

(2) 使用 torchvision.transforms 模块中的 ToTensor() 将图像转换为张量对象以输入网络,然后使用输入图像在模型上运行正向传递过程计算输出,最后将输出转换为图像:

python 复制代码
imt = ToTensor()(im)
imt = Variable(imt).view(1, -1, w, h)
#im = im.cuda()
with torch.no_grad():
    imt = net(imt)
out = torch.clamp(imt, 0., 1.)
out = out.cpu()
out = out.data[0]
out = ToPILImage()(out)

def plot_image(image, title=None, sz=10):
    plt.imshow(image)
    plt.title(title, size=sz)
    plt.axis('off')
plt.figure(figsize=(20,10))
plt.subplot(121), plot_image(im, 'hazed input')
plt.subplot(122), plot_image(out, 'de-hazed output')
plt.tight_layout()
plt.show() 

小结

图像去雾已成为计算机视觉的重要研究方向,在雾、霾等恶劣天气下拍摄的的图像通常由于大气散射的作用,图像质量严重下降使颜色偏灰白色,对比度降低,物体特征难以辨认,还会影响图像的分析与处理。因此,需要使用图像去雾技术来增强或修复图像,以改善视觉效果并便于图像的后续处理。在本节中,我们学习了一种基于卷积神经网络的图像去雾模型,通过使用训练后的模型可以显著改善图像视觉效果。

系列链接

Python图像处理【1】图像与视频处理基础
Python图像处理【2】探索Python图像处理库
Python图像处理【3】Python图像处理库应用
Python图像处理【4】图像线性变换
Python图像处理【5】图像扭曲/逆扭曲
Python图像处理【6】通过哈希查找重复和类似的图像
Python图像处理【7】采样、卷积与离散傅里叶变换
Python图像处理【8】使用低通滤波器模糊图像
Python图像处理【9】使用高通滤波器执行边缘检测
Python图像处理【10】基于离散余弦变换的图像压缩
Python图像处理【11】利用反卷积执行图像去模糊
Python图像处理【12】基于小波变换执行图像去噪
Python图像处理【13】使用PIL执行图像降噪
Python图像处理【14】基于非线性滤波器的图像去噪
Python图像处理【15】基于非锐化掩码锐化图像
Python图像处理【16】OpenCV直方图均衡化
Python图像处理【17】指纹增强和细节提取
Python图像处理【18】边缘检测详解
Python图像处理【19】基于霍夫变换的目标检测
Python图像处理【20】图像金字塔
Python图像处理【21】基于卷积神经网络增强微光图像

相关推荐
小赖同学啊19 分钟前
物联网数据安全区块链服务
开发语言·python·区块链
码荼40 分钟前
学习开发之hashmap
java·python·学习·哈希算法·个人开发·小白学开发·不花钱不花时间crud
小陈phd2 小时前
李宏毅机器学习笔记——梯度下降法
人工智能·python·机器学习
kk爱闹2 小时前
【挑战14天学完python和pytorch】- day01
android·pytorch·python
Blossom.1182 小时前
机器学习在智能建筑中的应用:能源管理与环境优化
人工智能·python·深度学习·神经网络·机器学习·机器人·sklearn
亚力山大抵2 小时前
实验六-使用PyMySQL数据存储的Flask登录系统-实验七-集成Flask-SocketIO的实时通信系统
后端·python·flask
showyoui2 小时前
Python 闭包(Closure)实战总结
开发语言·python
amazinging3 小时前
北京-4年功能测试2年空窗-报培训班学测开-第四十一天
python·学习·appium
amazinging3 小时前
北京-4年功能测试2年空窗-报培训班学测开-第三十九天
python·学习·appium
m0_723140233 小时前
Python训练营-Day42
python