Pytorch图像去噪实战(三):ResUNet图像去噪模型实战,解决UNet深层训练不稳定问题
一、问题场景:UNet效果提升了,但训练深一点就不稳定
上一篇我们用 UNet 替代 DnCNN,解决了很多细节丢失问题。
但当我想继续提升模型能力时,很快遇到了新的坑:
UNet层数加深后,loss下降变慢,训练后期震荡,甚至出现输出发灰、边缘断裂的问题。
一开始我以为是学习率问题,于是尝试:
- lr 从 1e-4 降到 1e-5
- batch size 从 8 改成 4
- epoch 从 50 加到 100
结果发现问题只是缓解,并没有根本解决。
后面定位发现:
当UNet变深后,普通卷积块训练难度增加,梯度传递不够稳定。
这时就需要引入一个非常经典的结构:Residual Block 残差块。
二、真实问题分析:为什么普通UNet加深后容易不稳定?
UNet本身已经有 skip connection,但它的 skip 是 encoder 到 decoder 的跨层连接。
而每个卷积块内部,仍然是普通卷积堆叠:
python
Conv -> BN -> ReLU -> Conv -> BN -> ReLU
如果模型变深,这种结构会出现几个问题:
1. 梯度传递路径变长
深层网络训练时,梯度需要穿过很多层,容易衰减。
2. 低层信息容易被覆盖
图像去噪不是分类任务,浅层纹理信息非常重要。
普通卷积块可能会过度变换特征。
3. 输出容易偏平滑
当网络表达能力增强但约束不足时,很容易走向"平滑解"。
三、解决方案:把UNet中的卷积块替换为残差块
残差结构的核心思想:
不直接学习完整映射,而是学习输入和输出之间的残差。
公式可以理解为:
output = F(x) + x
这样做的好处是:
- 梯度更容易回传
- 特征不容易被破坏
- 深层网络更容易训练
- 对图像恢复任务非常友好
四、ResUNet整体结构
ResUNet = UNet主体结构 + Residual Block
整体仍然是:
Encoder -> Bottleneck -> Decoder
区别是:
普通 DoubleConv 替换为 ResBlock。
五、工程目录结构
resunet_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── resunet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
六、残差块实现
models/resunet.py
python
import torch
import torch.nn as nn
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels)
)
self.shortcut = nn.Identity()
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
out = self.conv(x)
shortcut = self.shortcut(x)
return self.relu(out + shortcut)
七、完整ResUNet模型代码
python
import torch
import torch.nn as nn
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels)
)
self.shortcut = nn.Identity()
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.main(x) + self.shortcut(x))
class ResUNetDenoise(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.pool = nn.MaxPool2d(2)
self.enc1 = ResBlock(in_channels, 64)
self.enc2 = ResBlock(64, 128)
self.enc3 = ResBlock(128, 256)
self.bottleneck = ResBlock(256, 512)
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.dec3 = ResBlock(512, 256)
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec2 = ResBlock(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec1 = ResBlock(128, 64)
self.out = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
b = self.bottleneck(self.pool(e3))
d3 = self.up3(b)
d3 = torch.cat([d3, e3], dim=1)
d3 = self.dec3(d3)
d2 = self.up2(d3)
d2 = torch.cat([d2, e2], dim=1)
d2 = self.dec2(d2)
d1 = self.up1(d2)
d1 = torch.cat([d1, e1], dim=1)
d1 = self.dec1(d1)
return self.out(d1)
八、数据集代码
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class DenoiseDataset(Dataset):
def __init__(self, root_dir, patch_size=128):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.patch_size = patch_size
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
w, h = img.size
if w >= self.patch_size and h >= self.patch_size:
left = random.randint(0, w - self.patch_size)
top = random.randint(0, h - self.patch_size)
img = img.crop((left, top, left + self.patch_size, top + self.patch_size))
else:
img = img.resize((self.patch_size, self.patch_size))
clean = self.to_tensor(img)
sigma = random.choice([15, 25, 35, 50])
noise = torch.randn_like(clean) * sigma / 255.0
noisy = torch.clamp(clean + noise, 0.0, 1.0)
return noisy, clean
九、训练代码
python
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.resunet import ResUNetDenoise
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = DenoiseDataset("data/train")
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
model = ResUNetDenoise().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = torch.nn.L1Loss()
for epoch in range(1, 61):
model.train()
total_loss = 0
for noisy, clean in loader:
noisy = noisy.to(device)
clean = clean.to(device)
pred = model(noisy)
loss = criterion(pred, clean)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"resunet_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
十、为什么这里使用AdamW?
普通 Adam 能训练,但我在实验中发现:
- Adam 前期下降快
- 后期容易震荡
- 输出有时出现轻微伪影
AdamW 加入权重衰减后更稳:
python
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
对于图像恢复任务,AdamW通常比Adam更适合做中后期稳定训练。
十一、加入梯度裁剪,避免训练突然爆炸
当网络变深后,偶尔会出现 loss 突然变大的情况。
可以加入梯度裁剪:
python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
放在:
python
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
这一步在工程里非常实用。
十二、效果验证
实际对比:
| 模型 | 收敛稳定性 | 细节恢复 | 训练难度 |
|---|---|---|---|
| UNet | 中等 | 较好 | 低 |
| ResUNet | 更稳定 | 更好 | 中等 |
ResUNet最大的优势不是"肉眼效果瞬间暴涨",而是:
当模型变深时,它依然更容易训练。
十三、踩坑记录
坑1:残差连接通道不一致
如果 in_channels != out_channels,不能直接相加。
必须用 1x1 卷积对齐:
python
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
坑2:BatchNorm在小batch下不稳定
如果 batch_size 很小,比如 1 或 2,BatchNorm可能表现不好。
可以替换为:
python
nn.GroupNorm(8, channels)
这是后续优化方向。
坑3:残差块不是越多越好
如果数据量很小,模型太深反而过拟合。
建议:
- 小数据:3层encoder足够
- 中等数据:4层encoder
- 大数据:再考虑更深结构
十四、适合收藏总结
ResUNet适合什么场景?
- 普通UNet训练不稳
- 图像细节恢复要求高
- 噪声类型比较复杂
- 模型需要加深
避坑清单
- 残差相加前通道必须一致
- 小batch慎用BatchNorm
- AdamW比Adam更稳
- 深层模型建议加梯度裁剪
- 数据少时不要盲目加深
十五、优化建议
ResUNet还能继续升级:
- GroupNorm替代BatchNorm
- 加SE注意力模块
- 加多尺度监督
- 用感知损失增强纹理
- 改成Residual Attention UNet
结尾总结
UNet解决了多尺度问题,ResUNet进一步解决了深层训练稳定性问题。
真正做工程时,不要盲目追最新模型。
很多时候,一个训练稳定、结构清晰、可维护的 ResUNet,比复杂但不可控的大模型更适合落地。
下一篇预告
Pytorch图像去噪实战(四):Attention UNet图像去噪,让模型学会关注边缘和纹理区域