Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题

Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题


一、问题场景:合成噪声训练很好,真实图片效果却很差

前面我们训练的模型大多基于一个假设:

噪声是高斯噪声。

也就是训练时这样造数据:

python 复制代码
noise = torch.randn_like(clean) * sigma / 255.0

这在论文实验里很常见,也方便复现。

但在真实工程里,我遇到一个非常现实的问题:

模型在合成噪声测试集上PSNR很高,但处理真实手机照片、截图、扫描件时效果明显变差。

真实噪声往往不是简单高斯噪声,它可能包含:

  • 传感器噪声
  • JPEG压缩噪声
  • 低光噪声
  • 颜色偏移
  • 局部噪声不均匀
  • 锐化产生的伪影

因此,只靠合成高斯噪声训练出来的模型,很容易出现泛化不足。

这一篇我们参考 CBDNet 的思路,做一个更接近真实噪声场景的去噪模型。


二、CBDNet解决什么问题?

CBDNet的核心思想可以概括为:

先估计噪声,再根据噪声分布进行去噪。

它不是假设整张图噪声强度一样,而是认为不同区域噪声可能不同。

这非常符合真实情况。

比如一张夜景照片:

  • 暗部噪声很重
  • 亮部噪声较轻
  • 边缘区域可能有压缩伪影

普通模型无法区分这些区域,而 CBDNet 会先预测一张 noise map。


三、整体架构设计

我们实现一个简化版 CBDNet,分成两个子网络:

1. Noise Estimation Network

输入 noisy image,输出 noise map。

2. Denoising Network

输入 noisy image + noise map,输出 clean image。

整体流程:

text 复制代码
noisy -> noise estimation -> noise map
noisy + noise map -> denoising network -> clean

四、工程目录结构

复制代码
cbdnet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── cbdnet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、数据构建:模拟更真实的噪声

真实噪声比高斯噪声复杂。

这里我们用一个工程上常见的简化方式:

  • 随机高斯噪声
  • 随机JPEG压缩
  • 随机噪声强度
  • 局部噪声变化

dataset.py

python 复制代码
import os
import random
import io
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class RealisticNoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def jpeg_compress(self, img):
        quality = random.randint(30, 95)
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        return Image.open(buffer).convert("L")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        if random.random() < 0.5:
            img = self.jpeg_compress(img)

        base = self.to_tensor(img)

        sigma = random.uniform(5, 50) / 255.0
        noise = torch.randn_like(base) * sigma

        noisy = torch.clamp(base + noise, 0.0, 1.0)

        noise_map = torch.ones_like(clean) * sigma

        return noisy, noise_map, clean

六、CBDNet模型实现

models/cbdnet.py

python 复制代码
import torch
import torch.nn as nn


class NoiseEstimationNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)


class DenoisingNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 1, 3, padding=1)
        )

    def forward(self, noisy, noise_map):
        x = torch.cat([noisy, noise_map], dim=1)
        residual = self.net(x)
        return noisy - residual


class CBDNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.noise_estimator = NoiseEstimationNet()
        self.denoiser = DenoisingNet()

    def forward(self, noisy):
        noise_map = self.noise_estimator(noisy)
        clean = self.denoiser(noisy, noise_map)
        return clean, noise_map

七、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import RealisticNoiseDataset
from models.cbdnet import CBDNet


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = RealisticNoiseDataset("data/train")
    loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)

    model = CBDNet().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    image_loss = torch.nn.L1Loss()
    noise_loss = torch.nn.L1Loss()

    for epoch in range(1, 61):
        model.train()
        total_loss = 0

        for noisy, gt_noise_map, clean in loader:
            noisy = noisy.to(device)
            gt_noise_map = gt_noise_map.to(device)
            clean = clean.to(device)

            pred_clean, pred_noise_map = model(noisy)

            loss_img = image_loss(pred_clean, clean)
            loss_noise = noise_loss(pred_noise_map, gt_noise_map)

            loss = loss_img + 0.2 * loss_noise

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"cbdnet_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

八、为什么要监督noise map?

很多人实现类似结构时,只训练最终输出,不监督 noise map。

这样会导致一个问题:

noise estimator 学不到明确含义,只变成一个中间黑盒特征。

我们这里加入:

python 复制代码
loss_noise = L1(pred_noise_map, gt_noise_map)

目的不是让 noise map 完全精确,而是给它一个稳定训练方向。


九、推理代码

python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.cbdnet import CBDNet


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CBDNet().to(device)
model.load_state_dict(torch.load("cbdnet_epoch_60.pth", map_location=device))
model.eval()

img = Image.open("real_noisy.png").convert("L")
to_tensor = transforms.ToTensor()

noisy = to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    pred, noise_map = model(noisy)
    pred = torch.clamp(pred, 0.0, 1.0)

vutils.save_image(pred.cpu(), "cbdnet_result.png")
vutils.save_image(noise_map.cpu(), "estimated_noise_map.png")

十、真实噪声任务中的重要经验

1. 不要只训练高斯噪声

高斯噪声只是最干净的实验设定,不能代表真实图片。

2. 压缩噪声必须加入

现实中的图片大部分都经历过压缩。

3. 不要过度追PSNR

真实噪声下,没有干净GT时,PSNR不一定可用。

肉眼效果、业务指标更重要。

比如 OCR 场景要看识别率,而不是只看图像指标。


十一、踩坑记录

坑1:noise map全变成常数

原因:

  • noise loss权重太小
  • 数据噪声变化不足

解决:

python 复制代码
loss = loss_img + 0.2 * loss_noise

坑2:真实图片去噪后发糊

原因:

  • 训练噪声太单一
  • L1Loss仍然偏平滑

解决:

  • 加JPEG噪声
  • 加随机噪声强度
  • 加边缘损失

坑3:噪声估计图没有意义

noise map不是一定要和真实噪声完全一致,它的价值在于给 denoiser 提供区域性噪声提示。

不要把它当成最终产品,而是中间引导。


十二、效果验证

在合成噪声测试中,CBDNet不一定显著超过UNet。

但在真实噪声场景中,它通常更稳。

场景 UNet CBDNet
高斯噪声 表现好 表现好
JPEG压缩 一般 更稳
低光噪声 容易残留噪点 更自然
真实截图 有伪影 更干净

十三、适合收藏总结

CBDNet完整流程

  1. 输入真实带噪图
  2. 先预测 noise map
  3. 拼接 noisy 和 noise map
  4. 再进行去噪
  5. 输出 clean image

避坑清单

  • 不要只用高斯噪声
  • 加入JPEG压缩增强
  • noise map需要辅助监督
  • 真实场景不要迷信PSNR
  • 业务指标更重要

十四、优化建议

可以继续改进:

  • Noise Estimator改成UNet结构
  • Denoiser改成ResUNet
  • 加感知损失
  • 加OCR识别损失
  • 使用真实噪声数据集微调

结尾总结

CBDNet真正解决的是一个非常工程化的问题:

模型在实验数据上很好,但真实图片不好用。

它的关键不是某个复杂模块,而是建模方式变了:

先估计噪声,再根据噪声去恢复图像。

这是图像去噪从"实验室模型"走向"真实工程"的重要一步。


下一篇预告

Pytorch图像去噪实战(七):Noise2Noise自监督去噪实战,没有干净图也能训练模型

相关推荐
博.闻广见6 小时前
AI_概率统计-2.常见分布
人工智能·机器学习
小糖学代码6 小时前
LLM系列:2.pytorch入门:8.神经网络的损失函数(criterion)
人工智能·深度学习·神经网络
Jmayday6 小时前
Pytorch:RNN理论基础
pytorch·rnn·深度学习
谭欣辰8 小时前
C++快速幂完整实战讲解
算法·决策树·机器学习
AI周红伟8 小时前
周红伟:GPT-Image-2深度解析:从技术原理到实战教程,为什么它能让整个AI圈炸锅?
人工智能·gpt·深度学习·机器学习·语言模型·openclaw
*Lisen8 小时前
从零手写 FlashAttention(PyTorch实现 + 原理推导)
人工智能·pytorch·python
端平入洛9 小时前
梯度是什么:PyTorch 自动求导详解
人工智能·深度学习
Uopiasd1234oo9 小时前
上下文引导模块改进YOLOv26局部与全局特征融合能力双重提升
深度学习·yolo·机器学习
哥布林学者9 小时前
深度学习进阶(十四)ConvNeXt
机器学习·ai