Pytorch图像去噪实战(二):用UNet解决DnCNN细节丢失问题(结构解析+完整代码+踩坑总结)
一、问题场景:DnCNN能降噪,但细节被抹掉了
上一篇我们用 DnCNN 完成了图像去噪的入门实践。模型确实能把噪声压下去,但在真实项目里,我很快遇到一个新问题:
图像看起来干净了,但边缘、纹理、细线条也一起被抹掉了。
比如在 OCR、医学影像、老照片修复这类任务中,去噪不是简单地"让图片变平滑",而是要做到:
- 噪声减少
- 边缘保留
- 纹理不丢
- 结构不变形
我一开始以为是训练轮数不够,后来加大 epoch 后发现:
loss 继续下降,但图像越来越糊。
这说明问题不只是训练不充分,而是模型结构本身表达能力有限。
因此,这一篇我们换一个更适合图像恢复任务的结构:UNet。
二、真实问题分析:为什么DnCNN容易丢细节?
DnCNN本质上是一个普通卷积堆叠网络,它的问题主要有三个:
1. 感受野有限
浅层卷积看到的是局部区域,对大面积噪声、压缩块、复杂纹理恢复能力有限。
2. 没有显式多尺度建模
图像去噪并不是只看一个尺度:
- 小尺度:像素噪声
- 中尺度:纹理噪声
- 大尺度:光照不均、压缩块
DnCNN对多尺度信息建模能力较弱。
3. 细节容易在卷积中被平滑
MSE损失本身就倾向于生成平均解,如果网络没有结构保护细节,最后输出就容易发糊。
三、解决方案:使用UNet进行图像去噪
UNet最早用于医学图像分割,但后来被大量用于图像恢复任务,比如:
- 图像去噪
- 图像超分辨率
- 图像修复
- 扩散模型中的噪声预测网络
UNet的核心优势是:
下采样负责理解全局结构,上采样负责恢复图像细节,跳跃连接负责保留浅层信息。
四、UNet结构理解
UNet可以拆成三部分:
1. Encoder编码器
不断下采样,提取高级语义特征。
2. Bottleneck瓶颈层
在最低分辨率处整合全局信息。
3. Decoder解码器
逐步上采样,恢复原始图像尺寸。
4. Skip Connection跳跃连接
把编码器的浅层细节传给解码器。
这一步非常关键。
如果没有跳跃连接,模型在上采样时只能凭低分辨率特征"猜细节",很容易生成模糊结果。
五、项目结构设计
建议保持下面的项目结构:
unet_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── unet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
这个结构适合后续扩展,比如加入 Attention UNet、ResUNet、SwinIR 等模型。
六、UNet模型完整实现
models/unet.py
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class UNetDenoise(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.enc1 = DoubleConv(in_channels, 64)
self.enc2 = DoubleConv(64, 128)
self.enc3 = DoubleConv(128, 256)
self.pool = nn.MaxPool2d(2)
self.bottleneck = DoubleConv(256, 512)
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.dec3 = DoubleConv(512, 256)
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.dec2 = DoubleConv(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.dec1 = DoubleConv(128, 64)
self.out = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
b = self.bottleneck(self.pool(e3))
d3 = self.up3(b)
d3 = torch.cat([d3, e3], dim=1)
d3 = self.dec3(d3)
d2 = self.up2(d3)
d2 = torch.cat([d2, e2], dim=1)
d2 = self.dec2(d2)
d1 = self.up1(d2)
d1 = torch.cat([d1, e1], dim=1)
d1 = self.dec1(d1)
return self.out(d1)
七、数据集构建
dataset.py
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class DenoiseDataset(Dataset):
def __init__(self, image_dir, patch_size=128, sigma_list=(15, 25, 50)):
self.image_paths = [
os.path.join(image_dir, name)
for name in os.listdir(image_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.patch_size = patch_size
self.sigma_list = sigma_list
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.image_paths)
def random_crop(self, img):
w, h = img.size
if w < self.patch_size or h < self.patch_size:
img = img.resize((self.patch_size, self.patch_size))
return img
left = random.randint(0, w - self.patch_size)
top = random.randint(0, h - self.patch_size)
return img.crop((left, top, left + self.patch_size, top + self.patch_size))
def __getitem__(self, index):
img = Image.open(self.image_paths[index]).convert("L")
img = self.random_crop(img)
clean = self.to_tensor(img)
sigma = random.choice(self.sigma_list)
noise = torch.randn_like(clean) * sigma / 255.0
noisy = torch.clamp(clean + noise, 0.0, 1.0)
return noisy, clean
八、训练代码
train.py
python
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.unet import UNetDenoise
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = DenoiseDataset("data/train")
train_loader = DataLoader(
train_dataset,
batch_size=8,
shuffle=True,
num_workers=4
)
model = UNetDenoise().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.L1Loss()
for epoch in range(1, 51):
model.train()
total_loss = 0
for noisy, clean in train_loader:
noisy = noisy.to(device)
clean = clean.to(device)
pred = model(noisy)
loss = criterion(pred, clean)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch}/50], Loss: {avg_loss:.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"unet_denoise_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
九、为什么这里用L1Loss而不是MSELoss?
这是一个工程里非常重要的细节。
我一开始直接用 MSELoss,结果发现:
- loss下降很快
- 图像也很干净
- 但细节明显发糊
原因是:
MSE对大误差惩罚更重,容易让模型学习平均化结果。
图像恢复任务里,L1Loss通常更稳一些,能更好保留边缘。
实际建议:
python
loss = L1Loss
如果想进一步提升效果,可以组合:
python
loss = l1_loss + 0.1 * ssim_loss
十、效果评估代码
eval.py
python
import math
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.unet import UNetDenoise
def calc_psnr(pred, target):
mse = torch.mean((pred - target) ** 2)
if mse.item() == 0:
return 100
return 20 * math.log10(1.0 / math.sqrt(mse.item()))
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetDenoise().to(device)
model.load_state_dict(torch.load("unet_denoise_epoch_50.pth", map_location=device))
model.eval()
img = Image.open("test.png").convert("L")
transform = transforms.ToTensor()
clean = transform(img).unsqueeze(0).to(device)
noise = torch.randn_like(clean) * 25 / 255.0
noisy = torch.clamp(clean + noise, 0.0, 1.0)
with torch.no_grad():
pred = model(noisy)
pred = torch.clamp(pred, 0.0, 1.0)
print("PSNR:", calc_psnr(pred, clean))
result = torch.cat([noisy.cpu(), pred.cpu(), clean.cpu()], dim=0)
vutils.save_image(result, "compare.png", nrow=3)
if __name__ == "__main__":
main()
十一、踩坑记录
坑1:上采样后尺寸对不上
如果输入图像尺寸不是 2 的倍数,多次下采样后会出现尺寸不一致。
解决方式:
- 训练时使用固定 patch,比如 128x128
- 输入尺寸尽量是 16 或 32 的倍数
坑2:显存爆炸
UNet比DnCNN显存占用明显更大。
解决方式:
- batch_size 从 4 或 8 开始
- patch_size 不要一开始就用 512
- 使用混合精度训练
坑3:输出偏灰
原因通常是没有 clamp。
推理时必须加:
python
pred = torch.clamp(pred, 0.0, 1.0)
十二、验证结果
在相同噪声强度 sigma=25 下,实际测试结果大致如下:
| 模型 | PSNR提升 | 视觉效果 |
|---|---|---|
| DnCNN | 中等 | 边缘略糊 |
| UNet | 更高 | 细节更清晰 |
UNet在纹理、边缘、文字类图像上的表现明显好于普通卷积堆叠模型。
十三、适合收藏总结
UNet去噪完整流程
- 准备干净图像
- 随机裁剪patch
- 添加多强度噪声
- 构建UNet模型
- 使用L1Loss训练
- 用PSNR和视觉效果共同评估
避坑清单
- 输入尺寸最好是 2 的倍数
- 不要盲目使用大图训练
- L1Loss通常比MSE更适合细节恢复
- 推理结果一定要 clamp
- batch_size要根据显存调整
十四、优化建议
UNet已经比DnCNN强很多,但仍然有改进空间:
- 加残差结构:ResUNet
- 加注意力机制:Attention UNet
- 加多尺度监督:Deep Supervision
- 换Transformer结构:SwinIR / Restormer
结尾总结
UNet真正解决的是 DnCNN 的结构短板:
DnCNN偏局部卷积,UNet具备多尺度恢复能力。
在真实工程里,如果只是做入门实验,DnCNN足够;但如果你要做更稳定的图像去噪系统,UNet才是更值得投入的基础模型。
下一篇预告
Pytorch图像去噪实战(三):ResUNet去噪模型实战,用残差结构解决深层网络训练不稳定问题