第一步:PSD介绍
以往的研究主要集中在具有合成模糊图像的训练模型上,当模型用于真实世界的模糊图像时,会导致性能下降。
为了解决上述问题,提高去雾的泛化性能,作者提出了一种Principled Synthetic-to-real Dehazing (PSD)框架。
本文提出的PSD适用于将现有的去雾模型推广到实际领域,包括两个阶段:有监督的预训练 和无监督的微调。
预训练阶段,作者将选定的去雾模型主干修改为一个基于物理模型的网络,并用合成数据训练该网络。利用设计良好的主干,我们可以得到一个预先训练的模型,在合成域上具有良好的去雾性能。
微调阶段 ,作者利用真实的模糊图像以无监督的方式训练模型。
本文的贡献:
- 作者将真实世界的去雾任务重新定义为一个合成到真实的泛化框架:首先一个在合成配对数据上预先训练的去雾模型主干,真实的模糊图像随后将被利用以一种无监督的方式微调模型。PSD易于使用,可以以大多数深度去雾模型为骨干。
- 由于没有清晰的真实图像作为监督,作者利用几个流行的、有充分根据的物理先验来指导微调。作者将它们合并成一个预先的损失committee,作为具体任务的代理指导,这一部分也是PSD的核心。
- 性能达到SOTA
第二步:PSD网络结构
首先对两个框架大的方向做一个整体概述:
Pre-training
首先采用目前性能最好的框架之一作为网络的主干
然后我们将主干修改为一个基于物理的网络,根据一个单一的雾图同时生成干净的图像 J,传输图 t 和大气光 A,为了共同优化这三个分量,作者加入了一个重建损失,它引导网络输出服从物理散射模型。
在这个阶段,只使用标记的合成数据进行训练,最终得到一个在合成域上预训练的模型。
Fine-tuning
作者利用未标记的真实数据 将预训练模型从合成域推广到真实域。受去雾强物理背景的启发,作者认为一个高质量的无雾图像应该遵循一些特定的统计规则 ,这些规则可以从图像先验中推导出来。此外,单一先验提供的物理知识并不总是可靠的,所以作者的目标是找到多个先验的组合,希望它们能够相互补充。
基于上述,作者设计了一个先验损失committee来作为任务特定的代理指导,用于训练未标记的真实数据。
此外,作者应用了一种learning without forgetting (LwF)的方法,该方法通过将原始任务的训练数据(即合成的模糊图像)通过网络运转到同真实的模糊数据一起,从而强行使得模型记忆合成领域的知识。
第三步:模型代码展示
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class BlockUNet1(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False, relu=False, drop=False, bn=True):
super(BlockUNet1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
self.dropout = nn.Dropout2d(0.5)
self.batch = nn.InstanceNorm2d(out_channels)
self.upsample = upsample
self.relu = relu
self.drop = drop
self.bn = bn
def forward(self, x):
if self.relu == True:
y = F.relu(x)
elif self.relu == False:
y = F.leaky_relu(x, 0.2)
if self.upsample == True:
y = self.deconv(y)
if self.bn == True:
y = self.batch(y)
if self.drop == True:
y = self.dropout(y)
elif self.upsample == False:
y = self.conv(y)
if self.bn == True:
y = self.batch(y)
if self.drop == True:
y = self.dropout(y)
return y
class G2(nn.Module):
def __init__(self, in_channels, out_channels):
super(G2, self).__init__()
self.conv = nn.Conv2d(in_channels, 8, 4, 2, 1, bias=False)
self.layer1 = BlockUNet1(8, 16)
self.layer2 = BlockUNet1(16, 32)
self.layer3 = BlockUNet1(32, 64)
self.layer4 = BlockUNet1(64, 64)
self.layer5 = BlockUNet1(64, 64)
self.layer6 = BlockUNet1(64, 64)
self.layer7 = BlockUNet1(64, 64)
self.dlayer7 = BlockUNet1(64, 64, True, True, True, False)
self.dlayer6 = BlockUNet1(128, 64, True, True, True)
self.dlayer5 = BlockUNet1(128, 64, True, True, True)
self.dlayer4 = BlockUNet1(128, 64, True, True)
self.dlayer3 = BlockUNet1(128, 32, True, True)
self.dlayer2 = BlockUNet1(64, 16, True, True)
self.dlayer1 = BlockUNet1(32, 8, True, True)
self.relu = nn.ReLU()
self.dconv = nn.ConvTranspose2d(16, out_channels, 4, 2, 1, bias=False)
self.lrelu = nn.LeakyReLU(0.2)
def forward(self, x):
y1 = self.conv(x)
y2 = self.layer1(y1)
y3 = self.layer2(y2)
y4 = self.layer3(y3)
y5 = self.layer4(y4)
y6 = self.layer5(y5)
y7 = self.layer6(y6)
y8 = self.layer7(y7)
dy8 = self.dlayer7(y8)
concat7 = torch.cat([dy8, y7], 1)
dy7 = self.dlayer6(concat7)
concat6 = torch.cat([dy7, y6], 1)
dy6 = self.dlayer5(concat6)
concat5 = torch.cat([dy6, y5], 1)
dy5 = self.dlayer4(concat5)
concat4 = torch.cat([dy5, y4], 1)
dy4 = self.dlayer3(concat4)
concat3 = torch.cat([dy4, y3], 1)
dy3 = self.dlayer2(concat3)
concat2 = torch.cat([dy3, y2], 1)
dy2 = self.dlayer1(concat2)
concat1 = torch.cat([dy2, y1], 1)
out = self.relu(concat1)
out = self.dconv(out)
out = self.lrelu(out)
return F.avg_pool2d(out, (out.shape[2], out.shape[3]))
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size//2), bias=bias)
class PALayer(nn.Module):
def __init__(self, channel):
super(PALayer, self).__init__()
self.pa = nn.Sequential(
nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.pa(x)
return x * y
class CALayer(nn.Module):
def __init__(self, channel):
super(CALayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.ca = nn.Sequential(
nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.ca(y)
return x * y
class Block(nn.Module):
def __init__(self, conv, dim, kernel_size,):
super(Block, self).__init__()
self.conv1=conv(dim, dim, kernel_size, bias=True)
self.act1=nn.ReLU(inplace=True)
self.conv2=conv(dim,dim,kernel_size,bias=True)
self.calayer=CALayer(dim)
self.palayer=PALayer(dim)
def forward(self, x):
res=self.act1(self.conv1(x))
res=res+x
res=self.conv2(res)
res=self.calayer(res)
res=self.palayer(res)
res += x
return res
class Group(nn.Module):
def __init__(self, conv, dim, kernel_size, blocks):
super(Group, self).__init__()
modules = [ Block(conv, dim, kernel_size) for _ in range(blocks)]
modules.append(conv(dim, dim, kernel_size))
self.gp = nn.Sequential(*modules)
def forward(self, x):
res = self.gp(x)
res += x
return res
class FFANet(nn.Module):
def __init__(self,gps,blocks,conv=default_conv):
super(FFANet, self).__init__()
self.gps=gps
self.dim=64
kernel_size=3
pre_process = [conv(3, self.dim, kernel_size)]
assert self.gps==3
self.g1= Group(conv, self.dim, kernel_size,blocks=blocks)
self.g2= Group(conv, self.dim, kernel_size,blocks=blocks)
self.g3= Group(conv, self.dim, kernel_size,blocks=blocks)
self.ca=nn.Sequential(*[
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(self.dim*self.gps,self.dim//16,1,padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(self.dim//16, self.dim*self.gps, 1, padding=0, bias=True),
nn.Sigmoid()
])
self.palayer=PALayer(self.dim)
self.conv_J_1 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
self.conv_J_2 = nn.Conv2d(64, 3, 3, 1, 1, bias=False)
self.conv_T_1 = nn.Conv2d(64, 16, 3, 1, 1, bias=False)
self.conv_T_2 = nn.Conv2d(16, 1, 3, 1, 1, bias=False)
post_precess = [
conv(self.dim, self.dim, kernel_size),
conv(self.dim, 3, kernel_size)]
self.pre = nn.Sequential(*pre_process)
self.post = nn.Sequential(*post_precess)
self.ANet = G2(3, 3)
def forward(self, x1, x2=0, Val=False):
x = self.pre(x1)
res1=self.g1(x)
res2=self.g2(res1)
res3=self.g3(res2)
w=self.ca(torch.cat([res1,res2,res3],dim=1))
w=w.view(-1,self.gps,self.dim)[:,:,:,None,None]
out=w[:,0,::]*res1+w[:,1,::]*res2+w[:,2,::]*res3
out=self.palayer(out)
out_J = self.conv_J_1(out)
out_J = self.conv_J_2(out_J)
out_J = out_J + x1
out_T = self.conv_T_1(out)
out_T = self.conv_T_2(out_T)
if Val == False:
out_A = self.ANet(x1)
else:
out_A = self.ANet(x2)
out_I = out_T * out_J + (1 - out_T) * out_A
#x=self.post(out)
return out, out_J, out_T, out_A, out_I
if __name__ == "__main__":
net=FFA(gps=3,blocks=19)
print(net)
第四步:运行
第五步:整个工程的内容
代码的下载路径 (新窗口打开链接): 基于深度学习神经网络的AI图像PSD去雾系统源码
有问题可以私信或者留言,有问必答