【论文阅读笔记】SCI算法与代码 | 低照度图像增强 | 2022.4.21

目录

[一 SCI](#一 SCI)

[1 SCI网络结构](#1 SCI网络结构)

核心代码(model.py)

[2 SCI损失函数](#2 SCI损失函数)

核心代码(loss.py)

[3 实验](#3 实验)

[二 SCI效果](#二 SCI效果)

[1 下载代码](#1 下载代码)

[2 运行](#2 运行)



一 SCI

💜论文题目++++Toward Fast, Flexible, and Robust Low-Light Image Enhancement++++

💚论文地址https://arxiv.org/pdf/2204.10137

💙代码地址https://github.com/vis-opt-group/SCI

【摘要】 现有的++++低照度图像增强技术++++ 大多不仅难以兼顾视觉质量和计算效率,而且在未知复杂场景下往往失效。在本文中,我们开发了一种新的自校准照明( Self-Calibrated Illumination,SCI )学习框架用于实际低照度场景中快速、灵活和鲁棒的亮化图像 。具体来说,我们建立了一个带有权重共享的级联光照学习过程来处理这个任务。考虑到级联模式的计算负担,我们构造了自校准模块,实现了每个阶段结果之间的收敛,产生了仅使用单个基本块进行推理(但在以前的工作中尚未得到利用)的增益,极大地减少了计算开销。然后,我们定义了++++无监督训练损失++++ ,以提高模型的能力,使其能够适应一般的场景。进一步,我们对挖掘SCI固有属性(现有工作中的欠缺),包括操作不敏感适应性(在不同的设置下获得稳定的性能)和模型无关通用性(可以应用于现有的基于光照的工作中,以提高性能)进行了全面探索。最后,大量的实验和消融研究充分表明了我们在质量和效率上的优越性。在低照度人脸检测和夜间语义分割上的应用充分显示了SCI的潜在实用价值。

见图1。最近最先进的方法和我们的方法的比较。

Kin D [ 34 ]是具有代表性的成对监督方法。

EnGAN [ 11 ]考虑了非成对监督学习。

Zero DCE [ 7 ]和RUAS [ 14 ]引入了无监督学习。

我们的方法(只包含3个大小为3 × 3的卷积)也属于无监督学习。如放大区域显示的那样,这些比较方法出现了不正确的曝光,颜色失真和结构不足以降低视觉质量。相比之下,我们的结果呈现出生动的颜色和清晰的轮廓。进一步,我们报告了( b )中的计算效率( SIZE , FLOPs和TIME)和( c )中的增强( PSNR , SSIM和EME)、检测( mAP )和分割( mIoU )三种任务中5种度量指标的数值得分,可以很容易地观察到我们的方法明显优于其他方法。

更具体地说,论文的主要贡献可以归结为:

开发了一个**++++权重共享的光照学习自校准模块++++** ,以保证每个阶段的结果之间的收敛性,提高曝光稳定性并大幅降低计算负担。据我们所知,利用学习过程加速低照度图像增强算法是第一项工作。

在自校准模块的作用下,我们定义了无监督的训练损失 来约束每个阶段的输出,从而赋予模型对不同场景的适应能力。属性分析表明,++++SCI具有操作不敏感的自适应性和模型无关的一般性++++,这是现有工作中没有发现的。

进行了大量的实验来说明我们相对于其他先进方法的优越性。在++++暗人脸检测和夜间语义分割++++上的应用进一步展示了我们的实用价值。简而言之,在基于网络的低照度图像增强领域,SCI重新定义了视觉质量、计算效率和下游任务性能的峰值点。

1 SCI网络结构

见图2 。SCI的整个框架。在训练阶段,SCI由光照估计和自校准模块组成 。在原始低照度输入中加入自校准模块映射,作为下一阶段照度估计的输入。注意这两个模块分别是整个训练过程中的共享参数。在测试阶段,只使用了单一的光照估计模块

❤️核心代码( model.py

主要由EnhanceNetwork CalibrateNetwork Network Finetunemodel四个类组成。

EnhanceNetwork: 对输入图像进行增强。

🦋🦋🦋通过多次堆叠卷积块,来学习图像的特征。增强后的图像通过与输入相加并进行截断,以确保像素值在合理范围内。

首先,通过__init__初始化超参数和网络层。

然后,将输入图像通过3*3的卷积层,得到特征 fea,对特征fea多次应用相同的卷积块进行叠加,通过输出卷积层获得最终的特征 fea。

接着 ,将生成的特征与输入图像相加,得到增强后的图像,通过clamp 函数将图像像素值限制在 0.0001 和 1 之间。

最后 ,返回增强后的图像 illu

CalibrateNetwork 定义了一个校准网络

🦋🦋🦋在前向传播时,输入经过一系列卷积操作后,对于最终的特征 fea再与原始输入相减 ,得到最终的增益调整结果delta

Network 组合了上述图像增强网络 (EnhanceNetwork) 和校准网络(CalibrateNetwork),并多次执行这两个操作

首先,初始化网络结构,并创建EnhanceNetwork、CalibrateNetwork以及loss损失函数的实例。

接着,定义权重初始化的方法。

然后,通过多次迭代,每次迭代中进行下述的步骤:

◆ 将当前输入保存到列表中。

◆ 使用图像增强网络 EnhanceNetwork 处理当前输入,得到增强后的图像。

◆ 计算增强前后的比例,并将比例值限制在 [0, 1] 范围内。

◆ 使用校准网络 CalibrateNetwork 对比例进行校准,得到校准值。

◆ 将原始输入与校准值相加,得到下一阶段的输入。

◆ 将当前阶段的增强图像、比例、输入和校准值的绝对值保存到对应的列表中。

◆ 返回四个列表,分别包含不同阶段的增强图像、比例、输入和校准值。

最后,计算损失。

Finetunemodel 进行模型的微调。

python 复制代码
import torch
import torch.nn as nn
from loss import LossFunction



class EnhanceNetwork(nn.Module):
    def __init__(self, layers, channels):
        super(EnhanceNetwork, self).__init__()

        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        self.blocks = nn.ModuleList()
        for i in range(layers):
            self.blocks.append(self.conv)

        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)
        fea = self.out_conv(fea)

        illu = fea + input
        illu = torch.clamp(illu, 0.0001, 1)

        return illu


class CalibrateNetwork(nn.Module):
    def __init__(self, layers, channels):
        super(CalibrateNetwork, self).__init__()
        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation
        self.layers = layers

        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.blocks = nn.ModuleList()
        for i in range(layers):
            self.blocks.append(self.convs)

        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)

        fea = self.out_conv(fea)
        delta = input - fea

        return delta



class Network(nn.Module):

    def __init__(self, stage=3):
        super(Network, self).__init__()
        self.stage = stage
        self.enhance = EnhanceNetwork(layers=1, channels=3)
        self.calibrate = CalibrateNetwork(layers=3, channels=16)
        self._criterion = LossFunction()

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):

        ilist, rlist, inlist, attlist = [], [], [], []
        input_op = input
        for i in range(self.stage):
            inlist.append(input_op)
            i = self.enhance(input_op)
            r = input / i
            r = torch.clamp(r, 0, 1)
            att = self.calibrate(r)
            input_op = input + att
            ilist.append(i)
            rlist.append(r)
            attlist.append(torch.abs(att))

        return ilist, rlist, inlist, attlist

    def _loss(self, input):
        i_list, en_list, in_list, _ = self(input)
        loss = 0
        for i in range(self.stage):
            loss += self._criterion(in_list[i], i_list[i])
        return loss



class Finetunemodel(nn.Module):

    def __init__(self, weights):
        super(Finetunemodel, self).__init__()
        self.enhance = EnhanceNetwork(layers=1, channels=3)
        self._criterion = LossFunction()

        base_weights = torch.load(weights)
        pretrained_dict = base_weights
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    def forward(self, input):
        i = self.enhance(input)
        r = input / i
        r = torch.clamp(r, 0, 1)
        return i, r


    def _loss(self, input):
        i, r = self(input)
        loss = self._criterion(input, i)
        return loss

2 SCI损失函数

❗❗❗总损失函数:Ltotal = αLf + βLs

Lf :表示保真度;Ls :表示平滑损失。

++++🌸保真度损失++++ 是为了保证估计的照度与每级输入之间的像素级一致性,表示如公式(4)所示。

++++🌸光照的平滑特性++++ 在这个任务[ 7、34]中是一个广泛的共识。这里我们采用一个具有空间变化 l1范数的光滑项[ 4 ],表示如公式(5)所示。

❤️核心代码( l oss .py

SCI使用的是无监督损失训练,由fifidelity losssmoothing loss的线性组合构成。

◆ Fidelity Loss

🌸采用均方误差损失函数 nn.MSELoss 计算输入图像 input 与增强后的图像 illu 之间的均方误差。

正则化项基于像素梯度和其指数权重的计算。

◆ Smooth Loss

🌸采用 SmoothLoss 类的实例 self.smooth_loss 计算输入图像 input 与增强后的图像 illu 之间的光滑损失。

通过 YCbCr 色彩空间的梯度计算来衡量图像的光滑性。

python 复制代码
import torch
import torch.nn as nn

class LossFunction(nn.Module):
    def __init__(self):
        super(LossFunction, self).__init__()
        self.l2_loss = nn.MSELoss()
        self.smooth_loss = SmoothLoss()

    def forward(self, input, illu):
        Fidelity_Loss = self.l2_loss(illu, input)
        Smooth_Loss = self.smooth_loss(input, illu)
        return 1.5*Fidelity_Loss + Smooth_Loss



class SmoothLoss(nn.Module):
    def __init__(self):
        super(SmoothLoss, self).__init__()
        self.sigma = 10

    def rgb2yCbCr(self, input_im):
        im_flat = input_im.contiguous().view(-1, 3).float()
        mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
        bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
        temp = im_flat.mm(mat) + bias
        out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
        return out

    # output: output      input:input
    def forward(self, input, output):
        self.output = output
        self.input = self.rgb2yCbCr(input)
        sigma_color = -1.0 / (2 * self.sigma * self.sigma)
        w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1,
                                 keepdim=True) * sigma_color)
        w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1,
                                  keepdim=True) * sigma_color)
        w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1,
                                  keepdim=True) * sigma_color)
        p = 1.0

        pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
        pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
        pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
        pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
        pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
        pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
        pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
        pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
        pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
        pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
        pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
        pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
        pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1, keepdim=True)
        pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1, keepdim=True)
        pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1, keepdim=True)
        pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1, keepdim=True)
        pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1, keepdim=True)
        pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1, keepdim=True)
        pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1, keepdim=True)
        pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1, keepdim=True)
        pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1, keepdim=True)
        pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1, keepdim=True)
        pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1, keepdim=True)
        pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1, keepdim=True)

        ReguTerm1 = torch.mean(pixel_grad1) \
                    + torch.mean(pixel_grad2) \
                    + torch.mean(pixel_grad3) \
                    + torch.mean(pixel_grad4) \
                    + torch.mean(pixel_grad5) \
                    + torch.mean(pixel_grad6) \
                    + torch.mean(pixel_grad7) \
                    + torch.mean(pixel_grad8) \
                    + torch.mean(pixel_grad9) \
                    + torch.mean(pixel_grad10) \
                    + torch.mean(pixel_grad11) \
                    + torch.mean(pixel_grad12) \
                    + torch.mean(pixel_grad13) \
                    + torch.mean(pixel_grad14) \
                    + torch.mean(pixel_grad15) \
                    + torch.mean(pixel_grad16) \
                    + torch.mean(pixel_grad17) \
                    + torch.mean(pixel_grad18) \
                    + torch.mean(pixel_grad19) \
                    + torch.mean(pixel_grad20) \
                    + torch.mean(pixel_grad21) \
                    + torch.mean(pixel_grad22) \
                    + torch.mean(pixel_grad23) \
                    + torch.mean(pixel_grad24)
        total_term = ReguTerm1
        return total_term

3 实验

见图7。在LSRW数据集上对当前最先进的低照度图像增强方法进行了视觉比较。

见图8。在一些具有挑战性的实例上进行视觉比较。更多的结果可以在补充材料中找到。

二 SCI效果

Requirements:python3.7 pytorch==1.8.0 cuda11.1

1 下载代码

git clone https://github.com/vis-opt-group/SCI.git

2 运行

python3 test.py

原图:

效果图:

至此,本文分享的内容就结束啦💕💕💕💕💕💕。

相关推荐
JeffreyGu.5 分钟前
软考中级软件设计师如何两个月通过
笔记
老黄浅谈质量6 分钟前
PyCharm结合DeepSeek-R1
ide·python·ai·pycharm·deepseek
qwq_ovo_pwp7 分钟前
题解 洛谷 Luogu P1828 [USACO3.2] 香甜的黄油 Sweet Butter 最短路 堆优化Dijkstra Floyd C++
数据结构·c++·算法·图论·最短路
深度安全实验室13 分钟前
AI-学习路线图-PyTorch-我是土堆
人工智能
孤寂大仙v14 分钟前
蓝耘智算平台部署deepseek-助力深度学习
人工智能·深度学习
CodeClimb19 分钟前
【华为OD-E卷 - 120 分割数组的最大差值 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
阿里云云原生23 分钟前
从云原生到 AI 原生,谈谈我经历的网关发展历程和趋势
人工智能·云原生
硕风和炜26 分钟前
【LeetCode: 378. 有序矩阵中第 K 小的元素 + 二分】
java·算法·leetcode·面试·矩阵·二分
追逐梦想永不停27 分钟前
公司配置内网穿透方法笔记(二):FTP内网穿透方法
笔记
三月七(爱看动漫的程序员)28 分钟前
基础链的使用
网络·数据库·人工智能·语言模型·自然语言处理·prompt·智能路由器