Pytorch图像去噪实战(九):SwinIR图像去噪实战,用Transformer解决CNN纹理恢复不足问题

Pytorch图像去噪实战(九):SwinIR图像去噪实战,用Transformer解决CNN纹理恢复不足问题


一、问题场景:CNN模型去噪稳定,但复杂纹理恢复不够自然

前面我们已经实现了 DnCNN、UNet、ResUNet、Attention UNet、FFDNet、CBDNet 以及自监督去噪方法。

这些方法大多基于 CNN。

CNN在图像去噪中非常稳定,但它有一个天然问题:

卷积更擅长局部建模,对复杂纹理和长距离依赖的表达能力有限。

比如:

  • 头发纹理
  • 布料纹理
  • 建筑线条
  • 草地、树叶
  • 医学图像细微结构

普通CNN模型容易出现两种结果:

  • 去噪强了,纹理被抹平
  • 保纹理,噪声又残留

为了解决这个问题,我们引入 Transformer 思路。

这一篇我们实现一个简化版 SwinIR 风格的图像去噪模型。


二、为什么Transformer适合图像恢复?

CNN的卷积核通常是局部的,比如 3x3。

虽然堆叠多层可以扩大感受野,但对长距离关系的建模仍然不够直接。

Transformer的优势是:

可以建模更大范围内的像素关系。

Swin Transformer进一步通过窗口注意力降低计算量,让它更适合图像任务。

SwinIR就是把 Swin Transformer 用到图像恢复任务中的代表模型。


三、工程化理解SwinIR

完整SwinIR实现比较复杂,包括:

  • Patch Embedding
  • Window Attention
  • Shifted Window
  • Residual Swin Transformer Block
  • Reconstruction Head

为了适合实战入门,我们先实现一个简化版思想:

text 复制代码
Conv浅层特征提取
-> Window Attention建模局部上下文
-> 残差连接
-> Conv重建图像

这样既能理解核心思想,也方便后续扩展。


四、工程目录结构

复制代码
swinir_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── mini_swinir.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、Window Attention模块实现

这里实现一个简化版窗口注意力,不做 shifted window,先帮助理解核心流程。

models/mini_swinir.py

python 复制代码
import torch
import torch.nn as nn


class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()

        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, c = x.shape

        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, self.num_heads, c // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).reshape(b, n, c)
        out = self.proj(out)

        return out

六、Transformer Block实现

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, mlp_ratio=2):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

七、Mini SwinIR完整模型

python 复制代码
import torch
import torch.nn as nn


class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()

        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, c = x.shape

        qkv = self.qkv(x)
        qkv = qkv.reshape(b, n, 3, self.num_heads, c // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = attn @ v
        out = out.transpose(1, 2).reshape(b, n, c)
        return self.proj(out)


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class MiniSwinIRDenoise(nn.Module):
    def __init__(self, in_channels=1, dim=64, depth=4):
        super().__init__()

        self.head = nn.Conv2d(in_channels, dim, 3, padding=1)

        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads=4)
            for _ in range(depth)
        ])

        self.tail = nn.Conv2d(dim, in_channels, 3, padding=1)

    def forward(self, x):
        residual = x

        feat = self.head(x)
        b, c, h, w = feat.shape

        tokens = feat.flatten(2).transpose(1, 2)

        for block in self.blocks:
            tokens = block(tokens)

        feat = tokens.transpose(1, 2).reshape(b, c, h, w)

        noise = self.tail(feat)

        return residual - noise

八、数据集代码

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class DenoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        sigma = random.choice([15, 25, 50])
        noise = torch.randn_like(clean) * sigma / 255.0

        noisy = torch.clamp(clean + noise, 0.0, 1.0)

        return noisy, clean

九、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.mini_swinir import MiniSwinIRDenoise


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = DenoiseDataset("data/train", patch_size=128)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    model = MiniSwinIRDenoise().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    criterion = torch.nn.L1Loss()

    for epoch in range(1, 81):
        model.train()
        total_loss = 0

        for noisy, clean in loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            pred = model(noisy)
            loss = criterion(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"mini_swinir_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

十、为什么Transformer训练更吃显存?

Transformer要计算 token 之间的注意力关系。

如果输入是 128x128,token数量就是:

text 复制代码
128 * 128 = 16384

注意力计算复杂度接近:

text 复制代码
N * N

所以显存会很高。

完整SwinIR使用 window attention 来降低计算量。

本文的简化版本适合入门理解,不建议直接用超大图训练。


十一、工程优化建议

如果显存不够,可以:

1. 减小patch size

python 复制代码
patch_size = 64

2. 减小dim

python 复制代码
dim = 32

3. 减少depth

python 复制代码
depth = 2

4. 使用混合精度训练

python 复制代码
torch.cuda.amp.autocast()

十二、踩坑记录

坑1:输入图太大,显存直接爆

Transformer不要一开始就用 256 或 512 图。

建议:

text 复制代码
64 -> 128 -> 256

逐步测试。


坑2:训练初期loss震荡

Transformer对学习率更敏感。

建议:

python 复制代码
lr = 2e-4

如果不稳,降到:

python 复制代码
lr = 1e-4

坑3:小数据集容易过拟合

Transformer参数表达能力强,小数据容易记忆训练集。

解决:

  • 数据增强
  • 权重衰减
  • 随机噪声强度
  • 提前停止

十三、效果验证

SwinIR类模型的优势主要体现在复杂纹理区域:

模型 平坦区域 复杂纹理 显存
UNet 一般
ResUNet 较好
MiniSwinIR 更自然

如果只是普通噪声,UNet已经够用。

如果图像纹理复杂,Transformer结构更有优势。


十四、适合收藏总结

Transformer去噪完整流程

  1. Conv提取浅层特征
  2. 展平成token
  3. Transformer建模上下文
  4. 恢复为图像特征
  5. 预测噪声残差
  6. noisy - noise得到clean

避坑清单

  • 不要直接上大图
  • Transformer更吃显存
  • 学习率要谨慎
  • 小数据容易过拟合
  • 建议从Mini版本开始

十五、优化方向

可以继续升级:

  • 实现真正Window Partition
  • 加Shifted Window
  • 使用Residual Swin Block
  • 加Patch Embedding
  • 直接复现完整SwinIR

结尾总结

SwinIR类方法真正解决的是 CNN 的表达上限问题。

CNN强在稳定,Transformer强在建模复杂关系。

在图像去噪任务中,如果你追求更自然的纹理恢复,Transformer是必须学习的一条路线。


下一篇预告

Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer处理高分辨率图像

相关推荐
西西弗Sisyphus1 小时前
Transformer 架构里关于 Attention 概念的澄清
transformer·attention·注意力机制·注意力·self-attention
毕胜客源码1 小时前
卷积神经网络的手势识别系统(有技术文档)深度学习 图像识别 卷积神经网络 Django python 人工智能
人工智能·python·深度学习·cnn·django
Jmayday2 小时前
Pytorch:CNN进行图象分类案例
人工智能·pytorch·cnn
郝学胜-神的一滴2 小时前
深度学习核心:损失函数完全解析 —— 从原理到 PyTorch 实战
人工智能·pytorch·python·深度学习·机器学习
AI技术增长2 小时前
Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题
pytorch·深度学习·机器学习·cnn·transformer
yongui478342 小时前
基于卷积神经网络(CNN)的盲源分离MATLAB实现
人工智能·matlab·cnn
AI技术增长2 小时前
Pytorch图像去噪实战(七):Noise2Noise自监督图像去噪实战,没有干净图也能训练模型
人工智能·pytorch·python
这张生成的图像能检测吗13 小时前
(论文速读)IMSE-IGA-CNN-Transformer
人工智能·深度学习·cnn·transformer·故障诊断·预测模型·时序模型
大连好光景14 小时前
《从函数到大模型速通》
rnn·cnn·transformer