Pytorch图像去噪实战(三):ResUNet图像去噪模型实战,解决UNet深层训练不稳定问题

Pytorch图像去噪实战(三):ResUNet图像去噪模型实战,解决UNet深层训练不稳定问题


一、问题场景:UNet效果提升了,但训练深一点就不稳定

上一篇我们用 UNet 替代 DnCNN,解决了很多细节丢失问题。

但当我想继续提升模型能力时,很快遇到了新的坑:

UNet层数加深后,loss下降变慢,训练后期震荡,甚至出现输出发灰、边缘断裂的问题。

一开始我以为是学习率问题,于是尝试:

  • lr 从 1e-4 降到 1e-5
  • batch size 从 8 改成 4
  • epoch 从 50 加到 100

结果发现问题只是缓解,并没有根本解决。

后面定位发现:

当UNet变深后,普通卷积块训练难度增加,梯度传递不够稳定。

这时就需要引入一个非常经典的结构:Residual Block 残差块


二、真实问题分析:为什么普通UNet加深后容易不稳定?

UNet本身已经有 skip connection,但它的 skip 是 encoder 到 decoder 的跨层连接。

而每个卷积块内部,仍然是普通卷积堆叠:

python 复制代码
Conv -> BN -> ReLU -> Conv -> BN -> ReLU

如果模型变深,这种结构会出现几个问题:

1. 梯度传递路径变长

深层网络训练时,梯度需要穿过很多层,容易衰减。

2. 低层信息容易被覆盖

图像去噪不是分类任务,浅层纹理信息非常重要。

普通卷积块可能会过度变换特征。

3. 输出容易偏平滑

当网络表达能力增强但约束不足时,很容易走向"平滑解"。


三、解决方案:把UNet中的卷积块替换为残差块

残差结构的核心思想:

不直接学习完整映射,而是学习输入和输出之间的残差。

公式可以理解为:

output = F(x) + x

这样做的好处是:

  • 梯度更容易回传
  • 特征不容易被破坏
  • 深层网络更容易训练
  • 对图像恢复任务非常友好

四、ResUNet整体结构

ResUNet = UNet主体结构 + Residual Block

整体仍然是:

Encoder -> Bottleneck -> Decoder

区别是:

普通 DoubleConv 替换为 ResBlock。


五、工程目录结构

复制代码
resunet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── resunet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

六、残差块实现

models/resunet.py

python 复制代码
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Identity()

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        shortcut = self.shortcut(x)
        return self.relu(out + shortcut)

七、完整ResUNet模型代码

python 复制代码
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Identity()

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.main(x) + self.shortcut(x))


class ResUNetDenoise(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.pool = nn.MaxPool2d(2)

        self.enc1 = ResBlock(in_channels, 64)
        self.enc2 = ResBlock(64, 128)
        self.enc3 = ResBlock(128, 256)

        self.bottleneck = ResBlock(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = ResBlock(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = ResBlock(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = ResBlock(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        b = self.bottleneck(self.pool(e3))

        d3 = self.up3(b)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        return self.out(d1)

八、数据集代码

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class DenoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            left = random.randint(0, w - self.patch_size)
            top = random.randint(0, h - self.patch_size)
            img = img.crop((left, top, left + self.patch_size, top + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        sigma = random.choice([15, 25, 35, 50])
        noise = torch.randn_like(clean) * sigma / 255.0
        noisy = torch.clamp(clean + noise, 0.0, 1.0)

        return noisy, clean

九、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.resunet import ResUNetDenoise


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = DenoiseDataset("data/train")
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

    model = ResUNetDenoise().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    criterion = torch.nn.L1Loss()

    for epoch in range(1, 61):
        model.train()
        total_loss = 0

        for noisy, clean in loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            pred = model(noisy)
            loss = criterion(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"resunet_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

十、为什么这里使用AdamW?

普通 Adam 能训练,但我在实验中发现:

  • Adam 前期下降快
  • 后期容易震荡
  • 输出有时出现轻微伪影

AdamW 加入权重衰减后更稳:

python 复制代码
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

对于图像恢复任务,AdamW通常比Adam更适合做中后期稳定训练。


十一、加入梯度裁剪,避免训练突然爆炸

当网络变深后,偶尔会出现 loss 突然变大的情况。

可以加入梯度裁剪:

python 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

放在:

python 复制代码
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

这一步在工程里非常实用。


十二、效果验证

实际对比:

模型 收敛稳定性 细节恢复 训练难度
UNet 中等 较好
ResUNet 更稳定 更好 中等

ResUNet最大的优势不是"肉眼效果瞬间暴涨",而是:

当模型变深时,它依然更容易训练。


十三、踩坑记录

坑1:残差连接通道不一致

如果 in_channels != out_channels,不能直接相加。

必须用 1x1 卷积对齐:

python 复制代码
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

坑2:BatchNorm在小batch下不稳定

如果 batch_size 很小,比如 1 或 2,BatchNorm可能表现不好。

可以替换为:

python 复制代码
nn.GroupNorm(8, channels)

这是后续优化方向。


坑3:残差块不是越多越好

如果数据量很小,模型太深反而过拟合。

建议:

  • 小数据:3层encoder足够
  • 中等数据:4层encoder
  • 大数据:再考虑更深结构

十四、适合收藏总结

ResUNet适合什么场景?

  • 普通UNet训练不稳
  • 图像细节恢复要求高
  • 噪声类型比较复杂
  • 模型需要加深

避坑清单

  • 残差相加前通道必须一致
  • 小batch慎用BatchNorm
  • AdamW比Adam更稳
  • 深层模型建议加梯度裁剪
  • 数据少时不要盲目加深

十五、优化建议

ResUNet还能继续升级:

  • GroupNorm替代BatchNorm
  • 加SE注意力模块
  • 加多尺度监督
  • 用感知损失增强纹理
  • 改成Residual Attention UNet

结尾总结

UNet解决了多尺度问题,ResUNet进一步解决了深层训练稳定性问题。

真正做工程时,不要盲目追最新模型。

很多时候,一个训练稳定、结构清晰、可维护的 ResUNet,比复杂但不可控的大模型更适合落地。


下一篇预告

Pytorch图像去噪实战(四):Attention UNet图像去噪,让模型学会关注边缘和纹理区域

相关推荐
TDengine (老段)1 小时前
工业软件的未来:构建在工业数据底座之上的 AI Agent
大数据·数据库·人工智能·时序数据库·tdengine
aLTttY1 小时前
Spring Boot集成AI大模型实战:从0到1打造智能应用
人工智能·spring boot·后端
FlyIer5561 小时前
2026 个人网站建站软件实测
人工智能
Yuer20251 小时前
Case-X01豆包意图识别能力压力测试
人工智能·edca os
木枷1 小时前
SuffixDecoding: Extreme Speculative Decoding forEmerging AI Applications
人工智能
qq_白羊座1 小时前
提示词工程|大语言模型核心参数设置(含数值范围+适用场景)
人工智能
小苑同学1 小时前
《大模型的结构》
人工智能·自然语言处理
动物园猫1 小时前
高质量人体检测与行人识别数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类
Awu12271 小时前
🍎Claude.md 是啥?让你的 AI 助手乖乖听你的话
人工智能·ai编程·claude