Pytorch图像去噪实战(二):用UNet解决DnCNN细节丢失问题(结构解析+完整代码+踩坑总结)

Pytorch图像去噪实战(二):用UNet解决DnCNN细节丢失问题(结构解析+完整代码+踩坑总结)


一、问题场景:DnCNN能降噪,但细节被抹掉了

上一篇我们用 DnCNN 完成了图像去噪的入门实践。模型确实能把噪声压下去,但在真实项目里,我很快遇到一个新问题:

图像看起来干净了,但边缘、纹理、细线条也一起被抹掉了。

比如在 OCR、医学影像、老照片修复这类任务中,去噪不是简单地"让图片变平滑",而是要做到:

  • 噪声减少
  • 边缘保留
  • 纹理不丢
  • 结构不变形

我一开始以为是训练轮数不够,后来加大 epoch 后发现:

loss 继续下降,但图像越来越糊。

这说明问题不只是训练不充分,而是模型结构本身表达能力有限。

因此,这一篇我们换一个更适合图像恢复任务的结构:UNet


二、真实问题分析:为什么DnCNN容易丢细节?

DnCNN本质上是一个普通卷积堆叠网络,它的问题主要有三个:

1. 感受野有限

浅层卷积看到的是局部区域,对大面积噪声、压缩块、复杂纹理恢复能力有限。

2. 没有显式多尺度建模

图像去噪并不是只看一个尺度:

  • 小尺度:像素噪声
  • 中尺度:纹理噪声
  • 大尺度:光照不均、压缩块

DnCNN对多尺度信息建模能力较弱。

3. 细节容易在卷积中被平滑

MSE损失本身就倾向于生成平均解,如果网络没有结构保护细节,最后输出就容易发糊。


三、解决方案:使用UNet进行图像去噪

UNet最早用于医学图像分割,但后来被大量用于图像恢复任务,比如:

  • 图像去噪
  • 图像超分辨率
  • 图像修复
  • 扩散模型中的噪声预测网络

UNet的核心优势是:

下采样负责理解全局结构,上采样负责恢复图像细节,跳跃连接负责保留浅层信息。


四、UNet结构理解

UNet可以拆成三部分:

1. Encoder编码器

不断下采样,提取高级语义特征。

2. Bottleneck瓶颈层

在最低分辨率处整合全局信息。

3. Decoder解码器

逐步上采样,恢复原始图像尺寸。

4. Skip Connection跳跃连接

把编码器的浅层细节传给解码器。

这一步非常关键。

如果没有跳跃连接,模型在上采样时只能凭低分辨率特征"猜细节",很容易生成模糊结果。


五、项目结构设计

建议保持下面的项目结构:

复制代码
unet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── unet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

这个结构适合后续扩展,比如加入 Attention UNet、ResUNet、SwinIR 等模型。


六、UNet模型完整实现

models/unet.py

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class UNetDenoise(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        b = self.bottleneck(self.pool(e3))

        d3 = self.up3(b)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.out(d1)

七、数据集构建

dataset.py

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class DenoiseDataset(Dataset):
    def __init__(self, image_dir, patch_size=128, sigma_list=(15, 25, 50)):
        self.image_paths = [
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.sigma_list = sigma_list
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def random_crop(self, img):
        w, h = img.size
        if w < self.patch_size or h < self.patch_size:
            img = img.resize((self.patch_size, self.patch_size))
            return img

        left = random.randint(0, w - self.patch_size)
        top = random.randint(0, h - self.patch_size)

        return img.crop((left, top, left + self.patch_size, top + self.patch_size))

    def __getitem__(self, index):
        img = Image.open(self.image_paths[index]).convert("L")
        img = self.random_crop(img)
        clean = self.to_tensor(img)

        sigma = random.choice(self.sigma_list)
        noise = torch.randn_like(clean) * sigma / 255.0
        noisy = torch.clamp(clean + noise, 0.0, 1.0)

        return noisy, clean

八、训练代码

train.py

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.unet import UNetDenoise


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = DenoiseDataset("data/train")
    train_loader = DataLoader(
        train_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4
    )

    model = UNetDenoise().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.L1Loss()

    for epoch in range(1, 51):
        model.train()
        total_loss = 0

        for noisy, clean in train_loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            pred = model(noisy)
            loss = criterion(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch}/50], Loss: {avg_loss:.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"unet_denoise_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

九、为什么这里用L1Loss而不是MSELoss?

这是一个工程里非常重要的细节。

我一开始直接用 MSELoss,结果发现:

  • loss下降很快
  • 图像也很干净
  • 但细节明显发糊

原因是:

MSE对大误差惩罚更重,容易让模型学习平均化结果。

图像恢复任务里,L1Loss通常更稳一些,能更好保留边缘。

实际建议:

python 复制代码
loss = L1Loss

如果想进一步提升效果,可以组合:

python 复制代码
loss = l1_loss + 0.1 * ssim_loss

十、效果评估代码

eval.py

python 复制代码
import math
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.unet import UNetDenoise


def calc_psnr(pred, target):
    mse = torch.mean((pred - target) ** 2)
    if mse.item() == 0:
        return 100
    return 20 * math.log10(1.0 / math.sqrt(mse.item()))


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNetDenoise().to(device)
    model.load_state_dict(torch.load("unet_denoise_epoch_50.pth", map_location=device))
    model.eval()

    img = Image.open("test.png").convert("L")
    transform = transforms.ToTensor()

    clean = transform(img).unsqueeze(0).to(device)
    noise = torch.randn_like(clean) * 25 / 255.0
    noisy = torch.clamp(clean + noise, 0.0, 1.0)

    with torch.no_grad():
        pred = model(noisy)
        pred = torch.clamp(pred, 0.0, 1.0)

    print("PSNR:", calc_psnr(pred, clean))

    result = torch.cat([noisy.cpu(), pred.cpu(), clean.cpu()], dim=0)
    vutils.save_image(result, "compare.png", nrow=3)


if __name__ == "__main__":
    main()

十一、踩坑记录

坑1:上采样后尺寸对不上

如果输入图像尺寸不是 2 的倍数,多次下采样后会出现尺寸不一致。

解决方式:

  • 训练时使用固定 patch,比如 128x128
  • 输入尺寸尽量是 16 或 32 的倍数

坑2:显存爆炸

UNet比DnCNN显存占用明显更大。

解决方式:

  • batch_size 从 4 或 8 开始
  • patch_size 不要一开始就用 512
  • 使用混合精度训练

坑3:输出偏灰

原因通常是没有 clamp。

推理时必须加:

python 复制代码
pred = torch.clamp(pred, 0.0, 1.0)

十二、验证结果

在相同噪声强度 sigma=25 下,实际测试结果大致如下:

模型 PSNR提升 视觉效果
DnCNN 中等 边缘略糊
UNet 更高 细节更清晰

UNet在纹理、边缘、文字类图像上的表现明显好于普通卷积堆叠模型。


十三、适合收藏总结

UNet去噪完整流程

  1. 准备干净图像
  2. 随机裁剪patch
  3. 添加多强度噪声
  4. 构建UNet模型
  5. 使用L1Loss训练
  6. 用PSNR和视觉效果共同评估

避坑清单

  • 输入尺寸最好是 2 的倍数
  • 不要盲目使用大图训练
  • L1Loss通常比MSE更适合细节恢复
  • 推理结果一定要 clamp
  • batch_size要根据显存调整

十四、优化建议

UNet已经比DnCNN强很多,但仍然有改进空间:

  • 加残差结构:ResUNet
  • 加注意力机制:Attention UNet
  • 加多尺度监督:Deep Supervision
  • 换Transformer结构:SwinIR / Restormer

结尾总结

UNet真正解决的是 DnCNN 的结构短板:

DnCNN偏局部卷积,UNet具备多尺度恢复能力。

在真实工程里,如果只是做入门实验,DnCNN足够;但如果你要做更稳定的图像去噪系统,UNet才是更值得投入的基础模型。


下一篇预告

Pytorch图像去噪实战(三):ResUNet去噪模型实战,用残差结构解决深层网络训练不稳定问题

相关推荐
RD_daoyi1 小时前
GEO时代:AI 重构下,SEO的本质与破局之路
人工智能·重构
GJGCY1 小时前
金融AI Agent平台技术路线与落地能力对比:7家主流智能体优缺点分析
人工智能·ai·金融·数字化·智能体
直奔標竿1 小时前
Java开发者AI转型第二十二课!Spring AI 个人知识库实战(一)——架构搭建与核心契约落地
java·人工智能·后端·spring·架构
益企联工程项目管理软件1 小时前
2026工程管理软件推荐:7款工具助力工程项目数字化升级!
大数据·人工智能·云原生·项目管理·制造
dFObBIMmai1 小时前
CSS如何检测页面浮动元素位置_使用审查工具与clear
jvm·数据库·python
熊猫钓鱼>_>1 小时前
大型复杂远程AI Agent应用:从架构困局到进化突围
人工智能·ai·架构·开源·大模型·llm·agent
qq_460978401 小时前
实现 Svelte 中基于数组索引的 details 元素单开单关交互
jvm·数据库·python
AI前沿资讯1 小时前
支持视频动作迁移的AI 3D平台有哪些?2026全维度测评
人工智能·3d
AwesomeCPA1 小时前
Claude Code 实战分享(1):从“代码助手“到“AI 协调者“
人工智能