Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题

Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题


一、问题场景:合成噪声训练很好,真实图片效果却很差

前面我们训练的模型大多基于一个假设:

噪声是高斯噪声。

也就是训练时这样造数据:

python 复制代码
noise = torch.randn_like(clean) * sigma / 255.0

这在论文实验里很常见,也方便复现。

但在真实工程里,我遇到一个非常现实的问题:

模型在合成噪声测试集上PSNR很高,但处理真实手机照片、截图、扫描件时效果明显变差。

真实噪声往往不是简单高斯噪声,它可能包含:

  • 传感器噪声
  • JPEG压缩噪声
  • 低光噪声
  • 颜色偏移
  • 局部噪声不均匀
  • 锐化产生的伪影

因此,只靠合成高斯噪声训练出来的模型,很容易出现泛化不足。

这一篇我们参考 CBDNet 的思路,做一个更接近真实噪声场景的去噪模型。


二、CBDNet解决什么问题?

CBDNet的核心思想可以概括为:

先估计噪声,再根据噪声分布进行去噪。

它不是假设整张图噪声强度一样,而是认为不同区域噪声可能不同。

这非常符合真实情况。

比如一张夜景照片:

  • 暗部噪声很重
  • 亮部噪声较轻
  • 边缘区域可能有压缩伪影

普通模型无法区分这些区域,而 CBDNet 会先预测一张 noise map。


三、整体架构设计

我们实现一个简化版 CBDNet,分成两个子网络:

1. Noise Estimation Network

输入 noisy image,输出 noise map。

2. Denoising Network

输入 noisy image + noise map,输出 clean image。

整体流程:

text 复制代码
noisy -> noise estimation -> noise map
noisy + noise map -> denoising network -> clean

四、工程目录结构

复制代码
cbdnet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── cbdnet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、数据构建:模拟更真实的噪声

真实噪声比高斯噪声复杂。

这里我们用一个工程上常见的简化方式:

  • 随机高斯噪声
  • 随机JPEG压缩
  • 随机噪声强度
  • 局部噪声变化

dataset.py

python 复制代码
import os
import random
import io
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class RealisticNoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def jpeg_compress(self, img):
        quality = random.randint(30, 95)
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        return Image.open(buffer).convert("L")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        if random.random() < 0.5:
            img = self.jpeg_compress(img)

        base = self.to_tensor(img)

        sigma = random.uniform(5, 50) / 255.0
        noise = torch.randn_like(base) * sigma

        noisy = torch.clamp(base + noise, 0.0, 1.0)

        noise_map = torch.ones_like(clean) * sigma

        return noisy, noise_map, clean

六、CBDNet模型实现

models/cbdnet.py

python 复制代码
import torch
import torch.nn as nn


class NoiseEstimationNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)


class DenoisingNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 1, 3, padding=1)
        )

    def forward(self, noisy, noise_map):
        x = torch.cat([noisy, noise_map], dim=1)
        residual = self.net(x)
        return noisy - residual


class CBDNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.noise_estimator = NoiseEstimationNet()
        self.denoiser = DenoisingNet()

    def forward(self, noisy):
        noise_map = self.noise_estimator(noisy)
        clean = self.denoiser(noisy, noise_map)
        return clean, noise_map

七、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import RealisticNoiseDataset
from models.cbdnet import CBDNet


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = RealisticNoiseDataset("data/train")
    loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)

    model = CBDNet().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    image_loss = torch.nn.L1Loss()
    noise_loss = torch.nn.L1Loss()

    for epoch in range(1, 61):
        model.train()
        total_loss = 0

        for noisy, gt_noise_map, clean in loader:
            noisy = noisy.to(device)
            gt_noise_map = gt_noise_map.to(device)
            clean = clean.to(device)

            pred_clean, pred_noise_map = model(noisy)

            loss_img = image_loss(pred_clean, clean)
            loss_noise = noise_loss(pred_noise_map, gt_noise_map)

            loss = loss_img + 0.2 * loss_noise

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"cbdnet_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

八、为什么要监督noise map?

很多人实现类似结构时,只训练最终输出,不监督 noise map。

这样会导致一个问题:

noise estimator 学不到明确含义,只变成一个中间黑盒特征。

我们这里加入:

python 复制代码
loss_noise = L1(pred_noise_map, gt_noise_map)

目的不是让 noise map 完全精确,而是给它一个稳定训练方向。


九、推理代码

python 复制代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.cbdnet import CBDNet


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CBDNet().to(device)
model.load_state_dict(torch.load("cbdnet_epoch_60.pth", map_location=device))
model.eval()

img = Image.open("real_noisy.png").convert("L")
to_tensor = transforms.ToTensor()

noisy = to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    pred, noise_map = model(noisy)
    pred = torch.clamp(pred, 0.0, 1.0)

vutils.save_image(pred.cpu(), "cbdnet_result.png")
vutils.save_image(noise_map.cpu(), "estimated_noise_map.png")

十、真实噪声任务中的重要经验

1. 不要只训练高斯噪声

高斯噪声只是最干净的实验设定,不能代表真实图片。

2. 压缩噪声必须加入

现实中的图片大部分都经历过压缩。

3. 不要过度追PSNR

真实噪声下,没有干净GT时,PSNR不一定可用。

肉眼效果、业务指标更重要。

比如 OCR 场景要看识别率,而不是只看图像指标。


十一、踩坑记录

坑1:noise map全变成常数

原因:

  • noise loss权重太小
  • 数据噪声变化不足

解决:

python 复制代码
loss = loss_img + 0.2 * loss_noise

坑2:真实图片去噪后发糊

原因:

  • 训练噪声太单一
  • L1Loss仍然偏平滑

解决:

  • 加JPEG噪声
  • 加随机噪声强度
  • 加边缘损失

坑3:噪声估计图没有意义

noise map不是一定要和真实噪声完全一致,它的价值在于给 denoiser 提供区域性噪声提示。

不要把它当成最终产品,而是中间引导。


十二、效果验证

在合成噪声测试中,CBDNet不一定显著超过UNet。

但在真实噪声场景中,它通常更稳。

场景 UNet CBDNet
高斯噪声 表现好 表现好
JPEG压缩 一般 更稳
低光噪声 容易残留噪点 更自然
真实截图 有伪影 更干净

十三、适合收藏总结

CBDNet完整流程

  1. 输入真实带噪图
  2. 先预测 noise map
  3. 拼接 noisy 和 noise map
  4. 再进行去噪
  5. 输出 clean image

避坑清单

  • 不要只用高斯噪声
  • 加入JPEG压缩增强
  • noise map需要辅助监督
  • 真实场景不要迷信PSNR
  • 业务指标更重要

十四、优化建议

可以继续改进:

  • Noise Estimator改成UNet结构
  • Denoiser改成ResUNet
  • 加感知损失
  • 加OCR识别损失
  • 使用真实噪声数据集微调

结尾总结

CBDNet真正解决的是一个非常工程化的问题:

模型在实验数据上很好,但真实图片不好用。

它的关键不是某个复杂模块,而是建模方式变了:

先估计噪声,再根据噪声去恢复图像。

这是图像去噪从"实验室模型"走向"真实工程"的重要一步。


下一篇预告

Pytorch图像去噪实战(七):Noise2Noise自监督去噪实战,没有干净图也能训练模型

相关推荐
通信小呆呆2 天前
当算法有了“五感”:多模态数据融合如何向人体感官协同学习?
人工智能·学习·算法·机器学习·机器人
程序猿追2 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
xiao5kou4chang6kai42 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia12 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC2 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
code_pgf2 天前
端到端自动驾驶 BEV stack
人工智能·机器学习·自动驾驶
闵孚龙2 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
Godspeed Zhao2 天前
Level 4自动驾驶系统设计3——功能与场景3
人工智能·机器学习·自动驾驶
H178535090962 天前
SolidWorks第四部分_直接实体建模特征9_替换面原理
线性代数·算法·机器学习·3d建模·solidworks
Godspeed Zhao2 天前
现代智能汽车系统——智驾SoC之框架版图
人工智能·机器学习·自动驾驶·汽车·soc