Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题

Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题


一、问题场景:Transformer效果好,但高分辨率图片跑不动

上一篇我们实现了一个简化版 SwinIR,用 Transformer 思路提升复杂纹理恢复能力。

但很快就会遇到一个真实工程问题:

Transformer去噪效果不错,但图像稍微大一点显存就爆。

比如输入从 128x128 提升到 256x256,显存占用会明显上升。

如果处理真实业务图片,比如 1024x1024,普通全局注意力几乎不可用。

这就是图像恢复任务里非常关键的问题:

如何在高分辨率图像上使用Transformer?

Restormer就是为这类图像恢复任务设计的代表模型之一。


二、Restormer解决什么问题?

Restormer的核心目标是:

在保持Transformer建模能力的同时,降低高分辨率图像恢复的计算压力。

它的关键思想包括:

  • 使用卷积保留局部结构
  • 使用通道维度注意力降低复杂度
  • 使用门控前馈网络增强表达
  • 适合去噪、去雨、去模糊等图像恢复任务

这一篇我们实现一个简化版 Restormer Block,用于图像去噪实战。


三、为什么普通Self-Attention不适合高分辨率图像?

普通 Self-Attention 的复杂度和 token 数量平方相关。

如果图像大小是 H x W,token数是:

text 复制代码
N = H * W

注意力复杂度接近:

text 复制代码

当输入为 512x512 时,N 非常大,计算基本不可接受。

Restormer采用更适合图像恢复的设计,避免直接做巨大空间注意力。


四、工程目录结构

复制代码
restormer_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── mini_restormer.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、核心模块一:LayerNorm2d

Transformer通常用 LayerNorm,但图像特征是 BCHW 格式。

这里实现一个适合图像的 LayerNorm2d。

python 复制代码
import torch
import torch.nn as nn


class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)

        x = (x - mean) / torch.sqrt(var + self.eps)

        weight = self.weight.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)

        return x * weight + bias

六、核心模块二:通道注意力

这里实现一个简化版通道注意力,用来让模型判断哪些特征通道更重要。

python 复制代码
class ChannelAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.net = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weight = self.pool(x)
        weight = self.net(weight)
        return x * weight

七、核心模块三:门控前馈网络

Restormer里很重要的一个思想是 Gated Feed Forward。

简单理解:

不是所有特征都直接通过,而是通过门控机制筛选。

python 复制代码
class GatedFeedForward(nn.Module):
    def __init__(self, channels):
        super().__init__()

        hidden = channels * 2

        self.project_in = nn.Conv2d(channels, hidden * 2, 1)
        self.depthwise = nn.Conv2d(hidden * 2, hidden * 2, 3, padding=1, groups=hidden * 2)
        self.project_out = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        x = self.project_in(x)
        x = self.depthwise(x)

        x1, x2 = x.chunk(2, dim=1)
        x = torch.nn.functional.gelu(x1) * x2

        return self.project_out(x)

八、Mini Restormer Block

把 LayerNorm、Attention、GatedFFN 组合起来。

python 复制代码
class RestormerBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()

        self.norm1 = LayerNorm2d(channels)
        self.attn = ChannelAttention(channels)

        self.norm2 = LayerNorm2d(channels)
        self.ffn = GatedFeedForward(channels)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

九、完整Mini Restormer模型

models/mini_restormer.py

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)

        x = (x - mean) / torch.sqrt(var + self.eps)

        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)


class ChannelAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)

        self.net = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        weight = self.net(self.pool(x))
        return x * weight


class GatedFeedForward(nn.Module):
    def __init__(self, channels):
        super().__init__()

        hidden = channels * 2

        self.project_in = nn.Conv2d(channels, hidden * 2, 1)
        self.depthwise = nn.Conv2d(hidden * 2, hidden * 2, 3, padding=1, groups=hidden * 2)
        self.project_out = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        x = self.project_in(x)
        x = self.depthwise(x)

        x1, x2 = x.chunk(2, dim=1)
        x = F.gelu(x1) * x2

        return self.project_out(x)


class RestormerBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()

        self.norm1 = LayerNorm2d(channels)
        self.attn = ChannelAttention(channels)

        self.norm2 = LayerNorm2d(channels)
        self.ffn = GatedFeedForward(channels)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class MiniRestormerDenoise(nn.Module):
    def __init__(self, in_channels=1, channels=64, num_blocks=6):
        super().__init__()

        self.head = nn.Conv2d(in_channels, channels, 3, padding=1)

        self.body = nn.Sequential(*[
            RestormerBlock(channels)
            for _ in range(num_blocks)
        ])

        self.tail = nn.Conv2d(channels, in_channels, 3, padding=1)

    def forward(self, x):
        residual = x

        feat = self.head(x)
        feat = self.body(feat)

        noise = self.tail(feat)

        return residual - noise

十、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.mini_restormer import MiniRestormerDenoise


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = DenoiseDataset("data/train", patch_size=128)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

    model = MiniRestormerDenoise().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    criterion = torch.nn.L1Loss()

    for epoch in range(1, 101):
        model.train()
        total_loss = 0

        for noisy, clean in loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            pred = model(noisy)
            loss = criterion(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"mini_restormer_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

十一、数据集代码

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class DenoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        sigma = random.choice([15, 25, 35, 50])
        noisy = torch.clamp(clean + torch.randn_like(clean) * sigma / 255.0, 0, 1)

        return noisy, clean

十二、为什么Restormer比普通Transformer更适合图像恢复?

普通Transformer直接对空间token做全局注意力,代价非常高。

Restormer类结构更重视:

  • 局部卷积
  • 通道交互
  • 门控特征
  • 残差学习

这使它更适合高分辨率恢复任务。

本文实现的是简化版,重点是理解结构思想,而不是完全复现论文细节。


十三、踩坑记录

坑1:LayerNorm维度写错

图像是 BCHW,不是 NLP 中的 BNC。

如果直接用 nn.LayerNorm(channels),很容易维度不匹配。

本文用的是自定义 LayerNorm2d。


坑2:GatedFFN通道数对不上

project_in 输出 hidden * 2,后面要 chunk 成两半。

如果通道设置错误,会报维度错误。


坑3:训练初期输出偏暗

原因可能是残差预测不稳定。

解决:

  • 降低学习率
  • 使用梯度裁剪
  • 输出后 clamp
  • 使用L1Loss

十四、效果验证

MiniRestormer相比普通UNet,主要优势是:

  • 纹理保留更好
  • 高噪声下更稳
  • 平坦区域更自然
  • 不容易出现过度平滑
模型 高分辨率适应性 纹理恢复 训练成本
UNet 中等
MiniSwinIR 一般
MiniRestormer 较好 中高

十五、适合收藏总结

MiniRestormer流程

  1. Conv提取浅层特征
  2. RestormerBlock建模
  3. Channel Attention筛选特征
  4. GatedFFN增强表达
  5. 预测噪声残差
  6. noisy - noise得到结果

避坑清单

  • LayerNorm要适配BCHW
  • GatedFFN通道要算对
  • 学习率不要太大
  • 高分辨率建议patch训练
  • 推理结果必须clamp

十六、优化建议

可以继续升级:

  • 多尺度Encoder-Decoder结构
  • 更接近原版MDTA注意力
  • 加像素shuffle上采样
  • 加真实噪声数据微调
  • 支持彩色RGB图像

结尾总结

Restormer代表的是图像恢复模型的一个重要方向:

不再简单套用NLP Transformer,而是针对图像恢复任务重新设计高效结构。

如果你已经掌握了 UNet 和 SwinIR,Restormer是非常值得继续深入的模型。


下一篇预告

Pytorch图像去噪实战(十一):Diffusion扩散模型图像去噪入门,从噪声预测理解生成式去噪

相关推荐
数据与后端架构提升之路1 小时前
自动驾驶数据闭环中,Video Clip 的多模态特征到底怎么提取?
人工智能·机器学习·自动驾驶
audyxiao0011 小时前
智能交通顶刊TITS论文分享|一种基于文本提示引导的多模态大语言模型的交通流预测框架
人工智能·深度学习·多模态大模型
动物园猫1 小时前
工业粉尘检测数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类
yongui478341 小时前
基于卷积神经网络(CNN)的盲源分离MATLAB实现
人工智能·matlab·cnn
TTGGGFF1 小时前
深度学习如何重塑三维重建:从任务定义到工程落地全流程解析
人工智能·深度学习
AI技术增长1 小时前
Pytorch图像去噪实战(七):Noise2Noise自监督图像去噪实战,没有干净图也能训练模型
人工智能·pytorch·python
广州灵眸科技有限公司2 小时前
瑞芯微(EASY EAI)RV1126B AI算法开发流程
人工智能·算法·机器学习
生信碱移10 小时前
PACells:这个方法可以鉴定疾病/预后相关的重要细胞亚群,作者提供的代码流程可以学习起来了,甚至兼容转录组与 ATAC 两种数据类型!
人工智能·学习·算法·机器学习·数据挖掘·数据分析·r语言
jay神11 小时前
VisDrone2019-DET 无人机小目标检测数据集
人工智能·深度学习·yolo·目标检测·计算机视觉·毕业设计·无人机