Pytorch图像去噪实战(九):SwinIR图像去噪实战,用Transformer解决CNN纹理恢复不足问题
一、问题场景:CNN模型去噪稳定,但复杂纹理恢复不够自然
前面我们已经实现了 DnCNN、UNet、ResUNet、Attention UNet、FFDNet、CBDNet 以及自监督去噪方法。
这些方法大多基于 CNN。
CNN在图像去噪中非常稳定,但它有一个天然问题:
卷积更擅长局部建模,对复杂纹理和长距离依赖的表达能力有限。
比如:
- 头发纹理
- 布料纹理
- 建筑线条
- 草地、树叶
- 医学图像细微结构
普通CNN模型容易出现两种结果:
- 去噪强了,纹理被抹平
- 保纹理,噪声又残留
为了解决这个问题,我们引入 Transformer 思路。
这一篇我们实现一个简化版 SwinIR 风格的图像去噪模型。
二、为什么Transformer适合图像恢复?
CNN的卷积核通常是局部的,比如 3x3。
虽然堆叠多层可以扩大感受野,但对长距离关系的建模仍然不够直接。
Transformer的优势是:
可以建模更大范围内的像素关系。
Swin Transformer进一步通过窗口注意力降低计算量,让它更适合图像任务。
SwinIR就是把 Swin Transformer 用到图像恢复任务中的代表模型。
三、工程化理解SwinIR
完整SwinIR实现比较复杂,包括:
- Patch Embedding
- Window Attention
- Shifted Window
- Residual Swin Transformer Block
- Reconstruction Head
为了适合实战入门,我们先实现一个简化版思想:
text
Conv浅层特征提取
-> Window Attention建模局部上下文
-> 残差连接
-> Conv重建图像
这样既能理解核心思想,也方便后续扩展。
四、工程目录结构
swinir_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── mini_swinir.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
五、Window Attention模块实现
这里实现一个简化版窗口注意力,不做 shifted window,先帮助理解核心流程。
models/mini_swinir.py
python
import torch
import torch.nn as nn
class WindowAttention(nn.Module):
def __init__(self, dim, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
b, n, c = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(b, n, 3, self.num_heads, c // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = attn @ v
out = out.transpose(1, 2).reshape(b, n, c)
out = self.proj(out)
return out
六、Transformer Block实现
python
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads=4, mlp_ratio=2):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.GELU(),
nn.Linear(dim * mlp_ratio, dim)
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
七、Mini SwinIR完整模型
python
import torch
import torch.nn as nn
class WindowAttention(nn.Module):
def __init__(self, dim, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
b, n, c = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(b, n, 3, self.num_heads, c // self.num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = attn @ v
out = out.transpose(1, 2).reshape(b, n, c)
return self.proj(out)
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads=4):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = WindowAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 2),
nn.GELU(),
nn.Linear(dim * 2, dim)
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class MiniSwinIRDenoise(nn.Module):
def __init__(self, in_channels=1, dim=64, depth=4):
super().__init__()
self.head = nn.Conv2d(in_channels, dim, 3, padding=1)
self.blocks = nn.ModuleList([
TransformerBlock(dim, num_heads=4)
for _ in range(depth)
])
self.tail = nn.Conv2d(dim, in_channels, 3, padding=1)
def forward(self, x):
residual = x
feat = self.head(x)
b, c, h, w = feat.shape
tokens = feat.flatten(2).transpose(1, 2)
for block in self.blocks:
tokens = block(tokens)
feat = tokens.transpose(1, 2).reshape(b, c, h, w)
noise = self.tail(feat)
return residual - noise
八、数据集代码
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class DenoiseDataset(Dataset):
def __init__(self, root_dir, patch_size=128):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.patch_size = patch_size
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
w, h = img.size
if w >= self.patch_size and h >= self.patch_size:
x = random.randint(0, w - self.patch_size)
y = random.randint(0, h - self.patch_size)
img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
else:
img = img.resize((self.patch_size, self.patch_size))
clean = self.to_tensor(img)
sigma = random.choice([15, 25, 50])
noise = torch.randn_like(clean) * sigma / 255.0
noisy = torch.clamp(clean + noise, 0.0, 1.0)
return noisy, clean
九、训练代码
python
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.mini_swinir import MiniSwinIRDenoise
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = DenoiseDataset("data/train", patch_size=128)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
model = MiniSwinIRDenoise().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = torch.nn.L1Loss()
for epoch in range(1, 81):
model.train()
total_loss = 0
for noisy, clean in loader:
noisy = noisy.to(device)
clean = clean.to(device)
pred = model(noisy)
loss = criterion(pred, clean)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"mini_swinir_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
十、为什么Transformer训练更吃显存?
Transformer要计算 token 之间的注意力关系。
如果输入是 128x128,token数量就是:
text
128 * 128 = 16384
注意力计算复杂度接近:
text
N * N
所以显存会很高。
完整SwinIR使用 window attention 来降低计算量。
本文的简化版本适合入门理解,不建议直接用超大图训练。
十一、工程优化建议
如果显存不够,可以:
1. 减小patch size
python
patch_size = 64
2. 减小dim
python
dim = 32
3. 减少depth
python
depth = 2
4. 使用混合精度训练
python
torch.cuda.amp.autocast()
十二、踩坑记录
坑1:输入图太大,显存直接爆
Transformer不要一开始就用 256 或 512 图。
建议:
text
64 -> 128 -> 256
逐步测试。
坑2:训练初期loss震荡
Transformer对学习率更敏感。
建议:
python
lr = 2e-4
如果不稳,降到:
python
lr = 1e-4
坑3:小数据集容易过拟合
Transformer参数表达能力强,小数据容易记忆训练集。
解决:
- 数据增强
- 权重衰减
- 随机噪声强度
- 提前停止
十三、效果验证
SwinIR类模型的优势主要体现在复杂纹理区域:
| 模型 | 平坦区域 | 复杂纹理 | 显存 |
|---|---|---|---|
| UNet | 好 | 一般 | 中 |
| ResUNet | 好 | 较好 | 中 |
| MiniSwinIR | 好 | 更自然 | 高 |
如果只是普通噪声,UNet已经够用。
如果图像纹理复杂,Transformer结构更有优势。
十四、适合收藏总结
Transformer去噪完整流程
- Conv提取浅层特征
- 展平成token
- Transformer建模上下文
- 恢复为图像特征
- 预测噪声残差
- noisy - noise得到clean
避坑清单
- 不要直接上大图
- Transformer更吃显存
- 学习率要谨慎
- 小数据容易过拟合
- 建议从Mini版本开始
十五、优化方向
可以继续升级:
- 实现真正Window Partition
- 加Shifted Window
- 使用Residual Swin Block
- 加Patch Embedding
- 直接复现完整SwinIR
结尾总结
SwinIR类方法真正解决的是 CNN 的表达上限问题。
CNN强在稳定,Transformer强在建模复杂关系。
在图像去噪任务中,如果你追求更自然的纹理恢复,Transformer是必须学习的一条路线。
下一篇预告
Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer处理高分辨率图像