![](https://i-blog.csdnimg.cn/direct/8a76d952a07c47ba81eaccbdd70196c1.png)
目录
[一 SCI](#一 SCI)
[1 SCI网络结构](#1 SCI网络结构)
[2 SCI损失函数](#2 SCI损失函数)
[3 实验](#3 实验)
[二 SCI效果](#二 SCI效果)
[1 下载代码](#1 下载代码)
[2 运行](#2 运行)
一 SCI
![](https://i-blog.csdnimg.cn/direct/38226f3dab0d4dbf80c6a6a37bb629ba.png)
💜论文题目 :++++Toward Fast, Flexible, and Robust Low-Light Image Enhancement++++
💚论文地址 :https://arxiv.org/pdf/2204.10137
💙代码地址 :https://github.com/vis-opt-group/SCI
【摘要】 现有的++++低照度图像增强技术++++ 大多不仅难以兼顾视觉质量和计算效率,而且在未知复杂场景下往往失效。在本文中,我们开发了一种新的自校准照明( Self-Calibrated Illumination,SCI )学习框架 ,用于实际低照度场景中快速、灵活和鲁棒的亮化图像 。具体来说,我们建立了一个带有权重共享的级联光照学习过程来处理这个任务。考虑到级联模式的计算负担,我们构造了自校准模块,实现了每个阶段结果之间的收敛,产生了仅使用单个基本块进行推理(但在以前的工作中尚未得到利用)的增益,极大地减少了计算开销。然后,我们定义了++++无监督训练损失++++ ,以提高模型的能力,使其能够适应一般的场景。进一步,我们对挖掘SCI固有属性(现有工作中的欠缺),包括操作不敏感适应性(在不同的设置下获得稳定的性能)和模型无关通用性(可以应用于现有的基于光照的工作中,以提高性能)进行了全面探索。最后,大量的实验和消融研究充分表明了我们在质量和效率上的优越性。在低照度人脸检测和夜间语义分割上的应用充分显示了SCI的潜在实用价值。
见图1。最近最先进的方法和我们的方法的比较。
Kin D [ 34 ]是具有代表性的成对监督方法。
EnGAN [ 11 ]考虑了非成对监督学习。
Zero DCE [ 7 ]和RUAS [ 14 ]引入了无监督学习。
我们的方法(只包含3个大小为3 × 3的卷积)也属于无监督学习。如放大区域显示的那样,这些比较方法出现了不正确的曝光,颜色失真和结构不足以降低视觉质量。相比之下,我们的结果呈现出生动的颜色和清晰的轮廓。进一步,我们报告了( b )中的计算效率( SIZE , FLOPs和TIME)和( c )中的增强( PSNR , SSIM和EME)、检测( mAP )和分割( mIoU )三种任务中5种度量指标的数值得分,可以很容易地观察到我们的方法明显优于其他方法。
![](https://i-blog.csdnimg.cn/direct/a7f38cac291d43c791fc0845d5aa6dd3.png)
更具体地说,论文的主要贡献可以归结为:
① 开发了一个**++++权重共享的光照学习自校准模块++++** ,以保证每个阶段的结果之间的收敛性,提高曝光稳定性 ,并大幅降低计算负担。据我们所知,利用学习过程加速低照度图像增强算法是第一项工作。
② 在自校准模块的作用下,我们定义了无监督的训练损失 来约束每个阶段的输出,从而赋予模型对不同场景的适应能力。属性分析表明,++++SCI具有操作不敏感的自适应性和模型无关的一般性++++,这是现有工作中没有发现的。
③ 进行了大量的实验来说明我们相对于其他先进方法的优越性。在++++暗人脸检测和夜间语义分割++++上的应用进一步展示了我们的实用价值。简而言之,在基于网络的低照度图像增强领域,SCI重新定义了视觉质量、计算效率和下游任务性能的峰值点。
1 SCI网络结构
见图2 。SCI的整个框架。在训练阶段,SCI由光照估计和自校准模块组成 。在原始低照度输入中加入自校准模块映射,作为下一阶段照度估计的输入。注意这两个模块分别是整个训练过程中的共享参数。在测试阶段,只使用了单一的光照估计模块。
![](https://i-blog.csdnimg.cn/direct/590bac30666a4c61bb07f42a53090fb3.png)
❤️核心代码( model.py )
主要由EnhanceNetwork 、 CalibrateNetwork 、 Network 、 Finetunemodel四个类组成。
① EnhanceNetwork: 对输入图像进行增强。
🦋🦋🦋通过多次堆叠卷积块,来学习图像的特征。增强后的图像通过与输入相加并进行截断,以确保像素值在合理范围内。
首先,通过__init__初始化超参数和网络层。
然后,将输入图像通过3*3的卷积层,得到特征 fea,对特征fea多次应用相同的卷积块进行叠加,通过输出卷积层获得最终的特征 fea。
接着 ,将生成的特征与输入图像相加,得到增强后的图像,通过clamp 函数将图像像素值限制在 0.0001 和 1 之间。
最后 ,返回增强后的图像 illu。
② CalibrateNetwork : 定义了一个校准网络 。
🦋🦋🦋在前向传播时,输入经过一系列卷积操作后,对于最终的特征 fea再与原始输入相减 ,得到最终的增益调整结果delta。
③ Network : 组合了上述图像增强网络 (EnhanceNetwork) 和校准网络(CalibrateNetwork),并多次执行这两个操作。
首先,初始化网络结构,并创建EnhanceNetwork、CalibrateNetwork以及loss损失函数的实例。
接着,定义权重初始化的方法。
然后,通过多次迭代,每次迭代中进行下述的步骤:
◆ 将当前输入保存到列表中。
◆ 使用图像增强网络 EnhanceNetwork 处理当前输入,得到增强后的图像。
◆ 计算增强前后的比例,并将比例值限制在 [0, 1] 范围内。
◆ 使用校准网络 CalibrateNetwork 对比例进行校准,得到校准值。
◆ 将原始输入与校准值相加,得到下一阶段的输入。
◆ 将当前阶段的增强图像、比例、输入和校准值的绝对值保存到对应的列表中。
◆ 返回四个列表,分别包含不同阶段的增强图像、比例、输入和校准值。
最后,计算损失。
④ Finetunemodel : 进行模型的微调。
python
import torch
import torch.nn as nn
from loss import LossFunction
class EnhanceNetwork(nn.Module):
def __init__(self, layers, channels):
super(EnhanceNetwork, self).__init__()
kernel_size = 3
dilation = 1
padding = int((kernel_size - 1) / 2) * dilation
self.in_conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.ReLU()
)
self.conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.blocks = nn.ModuleList()
for i in range(layers):
self.blocks.append(self.conv)
self.out_conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
)
def forward(self, input):
fea = self.in_conv(input)
for conv in self.blocks:
fea = fea + conv(fea)
fea = self.out_conv(fea)
illu = fea + input
illu = torch.clamp(illu, 0.0001, 1)
return illu
class CalibrateNetwork(nn.Module):
def __init__(self, layers, channels):
super(CalibrateNetwork, self).__init__()
kernel_size = 3
dilation = 1
padding = int((kernel_size - 1) / 2) * dilation
self.layers = layers
self.in_conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.convs = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU(),
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
nn.BatchNorm2d(channels),
nn.ReLU()
)
self.blocks = nn.ModuleList()
for i in range(layers):
self.blocks.append(self.convs)
self.out_conv = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
nn.Sigmoid()
)
def forward(self, input):
fea = self.in_conv(input)
for conv in self.blocks:
fea = fea + conv(fea)
fea = self.out_conv(fea)
delta = input - fea
return delta
class Network(nn.Module):
def __init__(self, stage=3):
super(Network, self).__init__()
self.stage = stage
self.enhance = EnhanceNetwork(layers=1, channels=3)
self.calibrate = CalibrateNetwork(layers=3, channels=16)
self._criterion = LossFunction()
def weights_init(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1., 0.02)
def forward(self, input):
ilist, rlist, inlist, attlist = [], [], [], []
input_op = input
for i in range(self.stage):
inlist.append(input_op)
i = self.enhance(input_op)
r = input / i
r = torch.clamp(r, 0, 1)
att = self.calibrate(r)
input_op = input + att
ilist.append(i)
rlist.append(r)
attlist.append(torch.abs(att))
return ilist, rlist, inlist, attlist
def _loss(self, input):
i_list, en_list, in_list, _ = self(input)
loss = 0
for i in range(self.stage):
loss += self._criterion(in_list[i], i_list[i])
return loss
class Finetunemodel(nn.Module):
def __init__(self, weights):
super(Finetunemodel, self).__init__()
self.enhance = EnhanceNetwork(layers=1, channels=3)
self._criterion = LossFunction()
base_weights = torch.load(weights)
pretrained_dict = base_weights
model_dict = self.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
self.load_state_dict(model_dict)
def weights_init(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1., 0.02)
def forward(self, input):
i = self.enhance(input)
r = input / i
r = torch.clamp(r, 0, 1)
return i, r
def _loss(self, input):
i, r = self(input)
loss = self._criterion(input, i)
return loss
2 SCI损失函数
❗❗❗总损失函数:Ltotal = αLf + βLs
Lf :表示保真度;Ls :表示平滑损失。
++++🌸保真度损失++++ 是为了保证估计的照度与每级输入之间的像素级一致性,表示如公式(4)所示。
![](https://i-blog.csdnimg.cn/direct/2bc732699d9140c0ad2d10e66ad5f416.png)
++++🌸光照的平滑特性++++ 在这个任务[ 7、34]中是一个广泛的共识。这里我们采用一个具有空间变化 l1范数的光滑项[ 4 ],表示如公式(5)所示。
![](https://i-blog.csdnimg.cn/direct/06f34f9245234fa39344ed9f2067a752.png)
❤️核心代码( l oss .py )
SCI使用的是无监督损失训练,由fifidelity loss 和smoothing loss的线性组合构成。
◆ Fidelity Loss
🌸采用均方误差损失函数 nn.MSELoss 计算输入图像 input 与增强后的图像 illu 之间的均方误差。
正则化项基于像素梯度和其指数权重的计算。
◆ Smooth Loss
🌸采用 SmoothLoss 类的实例 self.smooth_loss 计算输入图像 input 与增强后的图像 illu 之间的光滑损失。
通过 YCbCr 色彩空间的梯度计算来衡量图像的光滑性。
python
import torch
import torch.nn as nn
class LossFunction(nn.Module):
def __init__(self):
super(LossFunction, self).__init__()
self.l2_loss = nn.MSELoss()
self.smooth_loss = SmoothLoss()
def forward(self, input, illu):
Fidelity_Loss = self.l2_loss(illu, input)
Smooth_Loss = self.smooth_loss(input, illu)
return 1.5*Fidelity_Loss + Smooth_Loss
class SmoothLoss(nn.Module):
def __init__(self):
super(SmoothLoss, self).__init__()
self.sigma = 10
def rgb2yCbCr(self, input_im):
im_flat = input_im.contiguous().view(-1, 3).float()
mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
temp = im_flat.mm(mat) + bias
out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
return out
# output: output input:input
def forward(self, input, output):
self.output = output
self.input = self.rgb2yCbCr(input)
sigma_color = -1.0 / (2 * self.sigma * self.sigma)
w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1,
keepdim=True) * sigma_color)
w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1,
keepdim=True) * sigma_color)
w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1,
keepdim=True) * sigma_color)
w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1,
keepdim=True) * sigma_color)
w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1,
keepdim=True) * sigma_color)
w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1,
keepdim=True) * sigma_color)
w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1,
keepdim=True) * sigma_color)
w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1,
keepdim=True) * sigma_color)
w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1,
keepdim=True) * sigma_color)
p = 1.0
pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1, keepdim=True)
pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1, keepdim=True)
pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1, keepdim=True)
pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1, keepdim=True)
pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1, keepdim=True)
pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1, keepdim=True)
pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1, keepdim=True)
pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1, keepdim=True)
pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1, keepdim=True)
pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1, keepdim=True)
pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1, keepdim=True)
pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1, keepdim=True)
ReguTerm1 = torch.mean(pixel_grad1) \
+ torch.mean(pixel_grad2) \
+ torch.mean(pixel_grad3) \
+ torch.mean(pixel_grad4) \
+ torch.mean(pixel_grad5) \
+ torch.mean(pixel_grad6) \
+ torch.mean(pixel_grad7) \
+ torch.mean(pixel_grad8) \
+ torch.mean(pixel_grad9) \
+ torch.mean(pixel_grad10) \
+ torch.mean(pixel_grad11) \
+ torch.mean(pixel_grad12) \
+ torch.mean(pixel_grad13) \
+ torch.mean(pixel_grad14) \
+ torch.mean(pixel_grad15) \
+ torch.mean(pixel_grad16) \
+ torch.mean(pixel_grad17) \
+ torch.mean(pixel_grad18) \
+ torch.mean(pixel_grad19) \
+ torch.mean(pixel_grad20) \
+ torch.mean(pixel_grad21) \
+ torch.mean(pixel_grad22) \
+ torch.mean(pixel_grad23) \
+ torch.mean(pixel_grad24)
total_term = ReguTerm1
return total_term
3 实验
见图7。在LSRW数据集上对当前最先进的低照度图像增强方法进行了视觉比较。
![](https://i-blog.csdnimg.cn/direct/914eb17d44984d3bb87f7221858195ac.png)
见图8。在一些具有挑战性的实例上进行视觉比较。更多的结果可以在补充材料中找到。
![](https://i-blog.csdnimg.cn/direct/5f4497af55264cdab83b1b3df908dcb1.png)
二 SCI效果
Requirements:python3.7 pytorch==1.8.0 cuda11.1
1 下载代码
git clone https://github.com/vis-opt-group/SCI.git
2 运行
python3 test.py
原图:
![](https://i-blog.csdnimg.cn/direct/3e02a013048d4803bddbcd145f7f6911.png)
效果图:
![](https://i-blog.csdnimg.cn/direct/add6eff61ee34c03b8cfdc300f3a47e8.png)
至此,本文分享的内容就结束啦💕💕💕💕💕💕。