超分辨率重建 | CVPR 2024 DarkIR:轻量级低光照图像增强与去模糊模型(代码实践)

论文名称:DarkIR: Robust Low-Light Image Restoration

论文原文 (Paper)https://arxiv.org/pdf/2412.13443
官方代码 (Code)https://github.com/cidautai/DarkIR

超分辨率重建 | CVPR 2025 DarkIR:轻量级低光照图像增强与去模糊模型(论文精度)

摘要: 现有的目标检测和图像恢复模型在低光照、夜间或运动模糊场景下往往表现不佳。本文复现了CVPR 2024 最新论文 DarkIR 中的核心组件:EBlock (基于频域的低光增强编码块)DBlock (基于大感受野的去模糊解码块) 。这两个模块设计轻量且高效,非常适合作为即插即用模块集成到网络中,用于提升模型在暗光增强去模糊任务中的鲁棒性。


目录

    • [[超分辨率重建 | CVPR 2025 DarkIR:轻量级低光照图像增强与去模糊模型(论文精度)](https://editor.csdn.net/md/?articleId=154455320)](#超分辨率重建 | CVPR 2025 DarkIR:轻量级低光照图像增强与去模糊模型(论文精度))
    • [一、 论文理论与模块解析](#一、 论文理论与模块解析)
      • [1. 论文背景与痛点](#1. 论文背景与痛点)
      • [2. 核心模块原理](#2. 核心模块原理)
        • [(1) EBlock:基于频域的低光增强 (Encoder Block)](#(1) EBlock:基于频域的低光增强 (Encoder Block))
        • [(2) DBlock:大感受野去模糊 (Decoder Block)](#(2) DBlock:大感受野去模糊 (Decoder Block))
    • [二、 核心代码复现](#二、 核心代码复现)
    • [三、 结果验证与引流](#三、 结果验证与引流)
      • [1. 独立运行测试](#1. 独立运行测试)

一、 论文理论与模块解析

1. 论文背景与痛点

在夜间摄影或安防监控中,图像通常面临"低光照(Low-light)"和"运动模糊(Blur)"的双重挑战。传统的图像增强方法通常将这两个问题分开处理,导致处理效率低且效果不连贯。

  • 痛点: 目前主流的 YOLO 系列或 CNN 网络在处理夜间模糊图像时,特征提取能力会大幅下降,导致漏检或误检。
  • DarkIR 的解决方案: 提出了一种多任务低光照图像恢复网络,通过在频域 处理光照信息,在空域处理模糊信息,实现了端到端的高效恢复。

2. 核心模块原理

本文提取了 DarkIR 中最核心的两个模块,均已封装为即插即用的 nn.Module,可直接替换主干网络(Backbone)或特征融合层(Neck)。

(1) EBlock:基于频域的低光增强 (Encoder Block)
  • 核心原理: 论文指出,图像的"光照/亮度"信息主要集中在频域的**幅度谱(Amplitude)**中。
  • 代码实现: EBlock 内部集成了 FreMLP (Frequency MLP)。它利用 FFT (快速傅里叶变换) 将特征转换到频域,仅对幅度谱进行增强,保持相位谱(结构信息)不变,从而在不破坏物体结构的前提下提亮特征。
  • 适用位置: 适合放置在网络的 浅层 Backbone,用于在特征提取初期改善光照条件。
(2) DBlock:大感受野去模糊 (Decoder Block)
  • 核心原理: 去模糊任务需要较大的感受野来捕捉上下文信息。
  • 代码实现: DBlock 引入了多分支的 扩张卷积 (Dilated Convolution) 结构(代码中的 dilations 参数)。通过并行使用不同扩张率(如 1, 3, 5)的卷积,模块能够模拟"大核卷积"的效果,捕捉不同尺度的模糊特征,同时保持较低的计算量。
  • 适用位置: 适合放置在网络的 深层 BackboneNeck/Head 部分,用于细化特征并去除模糊干扰。

二、 核心代码复现

本部分提供了完整的、经过测试的 PyTorch 代码。代码包含 FreMLP(频域处理)、EBlock(编码块)和 DBlock(解码块)的完整实现,并附带了必要的 LayerNorm 和 Gate 机制。

python 复制代码
"""
DarkIR模型的核心组件测试文件
包含编码块(EBlock)和解码块(DBlock)的定义和测试
"""
import torch
import torch.nn as nn


# LayerNorm2d: 2D层归一化实现
class LayerNormFunction(torch.autograd.Function):
    """自定义2D层归一化的前向和反向传播函数"""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None


class LayerNorm2d(nn.Module):
    """2D层归一化模块,对通道维度进行归一化"""

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class SimpleGate(nn.Module):
    """简单门控机制:将输入通道分为两部分并逐元素相乘"""
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class FreMLP(nn.Module):
    """频域MLP:在频域对幅度谱进行处理,保持相位不变"""
    def __init__(self, nc, expand=2):
        super(FreMLP, self).__init__()
        self.process1 = nn.Sequential(
            nn.Conv2d(nc, expand * nc, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(expand * nc, nc, 1, 1, 0)
        )

    def forward(self, x):
        _, _, H, W = x.shape
        x_freq = torch.fft.rfft2(x, norm='backward')
        mag = torch.abs(x_freq)
        pha = torch.angle(x_freq)
        mag = self.process1(mag)
        real = mag * torch.cos(pha)
        imag = mag * torch.sin(pha)
        x_out = torch.complex(real, imag)
        x_out = torch.fft.irfft2(x_out, s=(H, W), norm='backward')
        return x_out


class Branch(nn.Module):
    """多分支结构中的单个分支:使用深度可分离卷积,支持不同的空洞率"""
    def __init__(self, c, DW_Expand, dilation=1):
        super().__init__()
        self.dw_channel = DW_Expand * c
        self.branch = nn.Sequential(
            nn.Conv2d(in_channels=self.dw_channel, out_channels=self.dw_channel, kernel_size=3, padding=dilation, stride=1, groups=self.dw_channel,
                      bias=True, dilation=dilation)
        )

    def forward(self, input):
        return self.branch(input)


class DBlock(nn.Module):
    """解码块:包含多头分支注意力机制和前馈网络(FFN)"""
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, dilations=[1], extra_depth_wise=False):
        super().__init__()
        self.dw_channel = DW_Expand * c

        self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation=1)
        self.extra_conv = nn.Conv2d(self.dw_channel, self.dw_channel, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity()
        self.branches = nn.ModuleList()
        for dilation in dilations:
            self.branches.append(Branch(self.dw_channel, DW_Expand=1, dilation=dilation))

        assert len(dilations) == len(self.branches)
        self.dw_channel = DW_Expand * c
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True, dilation=1),
        )
        self.sg1 = SimpleGate()
        self.sg2 = SimpleGate()
        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation=1)
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp, adapter=None):
        y = inp
        x = self.norm1(inp)
        x = self.extra_conv(self.conv1(x))
        z = 0
        for branch in self.branches:
            z += branch(x)

        z = self.sg1(z)
        x = self.sca(z) * z
        x = self.conv3(x)
        y = inp + self.beta * x
        x = self.conv4(self.norm2(y))
        x = self.sg2(x)
        x = self.conv5(x)
        x = y + x * self.gamma
        return x


class EBlock(nn.Module):
    """编码块:包含多头分支注意力机制和频域处理模块"""
    def __init__(self, c, DW_Expand=2, dilations=[1], extra_depth_wise=False):
        super().__init__()
        self.dw_channel = DW_Expand * c
        self.extra_conv = nn.Conv2d(c, c, kernel_size=3, padding=1, stride=1, groups=c, bias=True, dilation=1) if extra_depth_wise else nn.Identity()
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=self.dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation=1)

        self.branches = nn.ModuleList()
        for dilation in dilations:
            self.branches.append(Branch(c, DW_Expand, dilation=dilation))

        assert len(dilations) == len(self.branches)
        self.dw_channel = DW_Expand * c
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=self.dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True, dilation=1),
        )
        self.sg1 = SimpleGate()
        self.conv3 = nn.Conv2d(in_channels=self.dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True, dilation=1)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)
        self.freq = FreMLP(nc=c, expand=2)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        y = inp
        x = self.norm1(inp)
        x = self.conv1(self.extra_conv(x))
        z = 0
        for branch in self.branches:
            z += branch(x)

        z = self.sg1(z)
        x = self.sca(z) * z
        x = self.conv3(x)
        y = inp + self.beta * x
        x_step2 = self.norm2(y)
        x_freq = self.freq(x_step2)
        x = y * x_freq
        x = y + x * self.gamma
        return x


# 测试代码:验证EBlock和DBlock的输入输出形状
# 输入输出格式: B C H W (Batch, Channels, Height, Width)
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 32, 64, 64).to(device)

    # 创建编码块和解码块实例
    eblock = EBlock(32, DW_Expand=2, dilations=[1, 3, 5], extra_depth_wise=True).to(device)
    dblock = DBlock(32, DW_Expand=2, FFN_Expand=2, dilations=[1, 3], extra_depth_wise=True).to(device)

    y_e = eblock(x)
    y_d = dblock(x)

    # print("\n", eblock)
    print("EBlock input:", x.shape)
    print("EBlock output:", y_e.shape)
    # print("\n", dblock)
    print("DBlock input:", x.shape)
    print("DBlock output:", y_d.shape)
    print("\n")

三、 结果验证与引流

1. 独立运行测试

将上述代码保存并运行。可以看到模块成功处理了 (1, 32, 64, 64) 的输入,且输出尺寸保持不变,证明了其即插即用的特性。


声明: 本专栏提供的均为 独立、可运行 的 Python 模块代码,旨在帮助大家快速复现论文、优化模型。

相关推荐
老吴学AI18 小时前
系列报告十:(Menlo)《2025: The State of Generative AI in the Enterprise》
人工智能·vibe coding
喜欢吃豆18 小时前
深度解析:FFmpeg 远程流式解复用原理与工程实践
人工智能·架构·ffmpeg·大模型·音视频·多模态
ChaITSimpleLove18 小时前
AI时代编程范式:“游击战”与“阵地战”的灵活应用
人工智能·ai编程范式·战略思维·战术思维·灵活策略·游击战与阵地战
hacker70718 小时前
精进Excel图表:AI赋能,成为Excel图表高手
人工智能·信息可视化·excel
OpenBayes18 小时前
HY-MT1.5-1.8B 支持多语言神经机器翻译;Med-Banana-50K 提供医学影像编辑基准数据
人工智能·深度学习·自然语言处理·数据集·机器翻译·图像生成
综合热讯18 小时前
脑机接口赋能 认知障碍诊疗迈入精准时代
人工智能·机器学习·数据挖掘
victory043118 小时前
pytorch 矩阵乘法和实际存储形状的差异
人工智能·pytorch·矩阵
之歆18 小时前
Spring AI入门到实战到原理源码-多模型协作智能客服系统
java·人工智能·spring
盛世宏博北京18 小时前
《可复制推广:智慧档案馆 “十防” 安全防护体系建设指南》
网络·人工智能·web安全·智慧档案