Pytorch图像去噪实战(十):Restormer图像去噪实战,用高效Transformer解决高分辨率去噪问题
一、问题场景:Transformer效果好,但高分辨率图片跑不动
上一篇我们实现了一个简化版 SwinIR,用 Transformer 思路提升复杂纹理恢复能力。
但很快就会遇到一个真实工程问题:
Transformer去噪效果不错,但图像稍微大一点显存就爆。
比如输入从 128x128 提升到 256x256,显存占用会明显上升。
如果处理真实业务图片,比如 1024x1024,普通全局注意力几乎不可用。
这就是图像恢复任务里非常关键的问题:
如何在高分辨率图像上使用Transformer?
Restormer就是为这类图像恢复任务设计的代表模型之一。
二、Restormer解决什么问题?
Restormer的核心目标是:
在保持Transformer建模能力的同时,降低高分辨率图像恢复的计算压力。
它的关键思想包括:
- 使用卷积保留局部结构
- 使用通道维度注意力降低复杂度
- 使用门控前馈网络增强表达
- 适合去噪、去雨、去模糊等图像恢复任务
这一篇我们实现一个简化版 Restormer Block,用于图像去噪实战。
三、为什么普通Self-Attention不适合高分辨率图像?
普通 Self-Attention 的复杂度和 token 数量平方相关。
如果图像大小是 H x W,token数是:
text
N = H * W
注意力复杂度接近:
text
N²
当输入为 512x512 时,N 非常大,计算基本不可接受。
Restormer采用更适合图像恢复的设计,避免直接做巨大空间注意力。
四、工程目录结构
restormer_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── mini_restormer.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
五、核心模块一:LayerNorm2d
Transformer通常用 LayerNorm,但图像特征是 BCHW 格式。
这里实现一个适合图像的 LayerNorm2d。
python
import torch
import torch.nn as nn
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(channels))
self.bias = nn.Parameter(torch.zeros(channels))
self.eps = eps
def forward(self, x):
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
weight = self.weight.view(1, -1, 1, 1)
bias = self.bias.view(1, -1, 1, 1)
return x * weight + bias
六、核心模块二:通道注意力
这里实现一个简化版通道注意力,用来让模型判断哪些特征通道更重要。
python
class ChannelAttention(nn.Module):
def __init__(self, channels):
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.net = nn.Sequential(
nn.Conv2d(channels, channels // 4, 1),
nn.ReLU(inplace=True),
nn.Conv2d(channels // 4, channels, 1),
nn.Sigmoid()
)
def forward(self, x):
weight = self.pool(x)
weight = self.net(weight)
return x * weight
七、核心模块三:门控前馈网络
Restormer里很重要的一个思想是 Gated Feed Forward。
简单理解:
不是所有特征都直接通过,而是通过门控机制筛选。
python
class GatedFeedForward(nn.Module):
def __init__(self, channels):
super().__init__()
hidden = channels * 2
self.project_in = nn.Conv2d(channels, hidden * 2, 1)
self.depthwise = nn.Conv2d(hidden * 2, hidden * 2, 3, padding=1, groups=hidden * 2)
self.project_out = nn.Conv2d(hidden, channels, 1)
def forward(self, x):
x = self.project_in(x)
x = self.depthwise(x)
x1, x2 = x.chunk(2, dim=1)
x = torch.nn.functional.gelu(x1) * x2
return self.project_out(x)
八、Mini Restormer Block
把 LayerNorm、Attention、GatedFFN 组合起来。
python
class RestormerBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.norm1 = LayerNorm2d(channels)
self.attn = ChannelAttention(channels)
self.norm2 = LayerNorm2d(channels)
self.ffn = GatedFeedForward(channels)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
九、完整Mini Restormer模型
models/mini_restormer.py
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(channels))
self.bias = nn.Parameter(torch.zeros(channels))
self.eps = eps
def forward(self, x):
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
class ChannelAttention(nn.Module):
def __init__(self, channels):
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.net = nn.Sequential(
nn.Conv2d(channels, channels // 4, 1),
nn.ReLU(inplace=True),
nn.Conv2d(channels // 4, channels, 1),
nn.Sigmoid()
)
def forward(self, x):
weight = self.net(self.pool(x))
return x * weight
class GatedFeedForward(nn.Module):
def __init__(self, channels):
super().__init__()
hidden = channels * 2
self.project_in = nn.Conv2d(channels, hidden * 2, 1)
self.depthwise = nn.Conv2d(hidden * 2, hidden * 2, 3, padding=1, groups=hidden * 2)
self.project_out = nn.Conv2d(hidden, channels, 1)
def forward(self, x):
x = self.project_in(x)
x = self.depthwise(x)
x1, x2 = x.chunk(2, dim=1)
x = F.gelu(x1) * x2
return self.project_out(x)
class RestormerBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.norm1 = LayerNorm2d(channels)
self.attn = ChannelAttention(channels)
self.norm2 = LayerNorm2d(channels)
self.ffn = GatedFeedForward(channels)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
class MiniRestormerDenoise(nn.Module):
def __init__(self, in_channels=1, channels=64, num_blocks=6):
super().__init__()
self.head = nn.Conv2d(in_channels, channels, 3, padding=1)
self.body = nn.Sequential(*[
RestormerBlock(channels)
for _ in range(num_blocks)
])
self.tail = nn.Conv2d(channels, in_channels, 3, padding=1)
def forward(self, x):
residual = x
feat = self.head(x)
feat = self.body(feat)
noise = self.tail(feat)
return residual - noise
十、训练代码
python
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.mini_restormer import MiniRestormerDenoise
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = DenoiseDataset("data/train", patch_size=128)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
model = MiniRestormerDenoise().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = torch.nn.L1Loss()
for epoch in range(1, 101):
model.train()
total_loss = 0
for noisy, clean in loader:
noisy = noisy.to(device)
clean = clean.to(device)
pred = model(noisy)
loss = criterion(pred, clean)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"mini_restormer_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
十一、数据集代码
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class DenoiseDataset(Dataset):
def __init__(self, root_dir, patch_size=128):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.patch_size = patch_size
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
w, h = img.size
if w >= self.patch_size and h >= self.patch_size:
x = random.randint(0, w - self.patch_size)
y = random.randint(0, h - self.patch_size)
img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
else:
img = img.resize((self.patch_size, self.patch_size))
clean = self.to_tensor(img)
sigma = random.choice([15, 25, 35, 50])
noisy = torch.clamp(clean + torch.randn_like(clean) * sigma / 255.0, 0, 1)
return noisy, clean
十二、为什么Restormer比普通Transformer更适合图像恢复?
普通Transformer直接对空间token做全局注意力,代价非常高。
Restormer类结构更重视:
- 局部卷积
- 通道交互
- 门控特征
- 残差学习
这使它更适合高分辨率恢复任务。
本文实现的是简化版,重点是理解结构思想,而不是完全复现论文细节。
十三、踩坑记录
坑1:LayerNorm维度写错
图像是 BCHW,不是 NLP 中的 BNC。
如果直接用 nn.LayerNorm(channels),很容易维度不匹配。
本文用的是自定义 LayerNorm2d。
坑2:GatedFFN通道数对不上
project_in 输出 hidden * 2,后面要 chunk 成两半。
如果通道设置错误,会报维度错误。
坑3:训练初期输出偏暗
原因可能是残差预测不稳定。
解决:
- 降低学习率
- 使用梯度裁剪
- 输出后 clamp
- 使用L1Loss
十四、效果验证
MiniRestormer相比普通UNet,主要优势是:
- 纹理保留更好
- 高噪声下更稳
- 平坦区域更自然
- 不容易出现过度平滑
| 模型 | 高分辨率适应性 | 纹理恢复 | 训练成本 |
|---|---|---|---|
| UNet | 好 | 中等 | 低 |
| MiniSwinIR | 一般 | 好 | 高 |
| MiniRestormer | 较好 | 好 | 中高 |
十五、适合收藏总结
MiniRestormer流程
- Conv提取浅层特征
- RestormerBlock建模
- Channel Attention筛选特征
- GatedFFN增强表达
- 预测噪声残差
- noisy - noise得到结果
避坑清单
- LayerNorm要适配BCHW
- GatedFFN通道要算对
- 学习率不要太大
- 高分辨率建议patch训练
- 推理结果必须clamp
十六、优化建议
可以继续升级:
- 多尺度Encoder-Decoder结构
- 更接近原版MDTA注意力
- 加像素shuffle上采样
- 加真实噪声数据微调
- 支持彩色RGB图像
结尾总结
Restormer代表的是图像恢复模型的一个重要方向:
不再简单套用NLP Transformer,而是针对图像恢复任务重新设计高效结构。
如果你已经掌握了 UNet 和 SwinIR,Restormer是非常值得继续深入的模型。
下一篇预告
Pytorch图像去噪实战(十一):Diffusion扩散模型图像去噪入门,从噪声预测理解生成式去噪