Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复

Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复


一、问题场景:传统去噪模型能用,但上限开始明显

前面我们已经做了 DnCNN、UNet、ResUNet、Attention UNet、FFDNet、CBDNet、Noise2Noise、Noise2Void、SwinIR、Restormer。

这些模型有一个共同特点:

大多数都是直接学习 noisy -> clean,或者 noisy -> noise。

在普通图像去噪任务里,这样已经够用。

但当我处理一些复杂图像时,问题开始变明显:

  • 高噪声图像细节恢复差
  • 老照片去噪后纹理不自然
  • 真实噪声图像容易残留伪影
  • 强去噪后图像发糊
  • 模型对未知噪声泛化不足

这时就会接触到一个更强的方向:Diffusion Model 扩散模型

扩散模型不是简单做一次映射,而是学习一个逐步去噪过程。


二、Diffusion去噪和普通去噪有什么区别?

普通去噪模型:

text 复制代码
noisy_image -> clean_image

Diffusion模型:

text 复制代码
clean_image -> 不断加噪 -> pure noise
pure noise -> 逐步去噪 -> clean_image

在训练阶段,它学习的是:

给定某一步的带噪图像,预测其中的噪声。

也就是:

text 复制代码
x_t -> noise

这和 DnCNN 的"预测噪声"思想有相似之处,但 Diffusion 更进一步,把噪声过程拆成了很多步。


三、核心思想:前向加噪与反向去噪

1. 前向过程

从干净图像 x0 开始,逐步加入噪声:

text 复制代码
x0 -> x1 -> x2 -> ... -> xT

最后 xT 接近纯噪声。

2. 反向过程

模型学习从 xT 一步步恢复:

text 复制代码
xT -> xT-1 -> ... -> x0

训练目标通常是预测噪声 epsilon。


四、工程目录结构

text 复制代码
diffusion_denoise/
├── data/
│   └── train/
├── models/
│   └── simple_unet.py
├── diffusion.py
├── dataset.py
├── train.py
├── sample.py
└── utils.py

五、数据集准备

这里先做灰度图像去噪,方便理解扩散模型流程。

python 复制代码
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, root_dir, image_size=64):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")
        return self.transform(img)

六、扩散过程实现

diffusion.py

python 复制代码
import torch


class GaussianDiffusion:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        self.timesteps = timesteps
        self.device = device

        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def add_noise(self, x0, t):
        noise = torch.randn_like(x0)

        alpha_bar = self.alpha_bars[t].view(-1, 1, 1, 1)

        noisy = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1.0 - alpha_bar) * noise

        return noisy, noise

七、时间步编码

Diffusion模型必须知道当前是第几步噪声。

python 复制代码
import torch
import torch.nn as nn
import math


class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, t):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)

        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

        return self.mlp(emb)

八、简化版UNet噪声预测网络

models/simple_unet.py

python 复制代码
import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, t):
        t = t.float().view(-1, 1) / 1000.0
        return self.net(t)


class SimpleDenoiseUNet(nn.Module):
    def __init__(self, channels=1, base=64, time_dim=128):
        super().__init__()

        self.time_mlp = TimeEmbedding(time_dim)

        self.conv1 = nn.Conv2d(channels, base, 3, padding=1)
        self.conv2 = nn.Conv2d(base, base, 3, padding=1)
        self.conv3 = nn.Conv2d(base, channels, 3, padding=1)

        self.time_proj = nn.Linear(time_dim, base)

        self.act = nn.SiLU()

    def forward(self, x, t):
        time_emb = self.time_mlp(t)
        time_emb = self.time_proj(time_emb).view(x.size(0), -1, 1, 1)

        h = self.act(self.conv1(x))
        h = h + time_emb
        h = self.act(self.conv2(h))

        return self.conv3(h)

九、训练代码

train.py

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import ImageDataset
from diffusion import GaussianDiffusion
from models.simple_unet import SimpleDenoiseUNet


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = ImageDataset("data/train", image_size=64)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = SimpleDenoiseUNet().to(device)
    diffusion = GaussianDiffusion(timesteps=1000, device=device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    criterion = torch.nn.MSELoss()

    for epoch in range(1, 101):
        model.train()
        total_loss = 0

        for x0 in loader:
            x0 = x0.to(device)

            t = torch.randint(0, diffusion.timesteps, (x0.size(0),), device=device)

            xt, noise = diffusion.add_noise(x0, t)

            pred_noise = model(xt, t)

            loss = criterion(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"diffusion_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

十、为什么Diffusion训练预测noise,而不是预测clean?

这是很多人第一次学扩散模型时最容易疑惑的地方。

如果直接预测 clean:

text 复制代码
model(x_t, t) -> x0

模型在高噪声阶段很难恢复完整图像。

而预测 noise:

text 复制代码
model(x_t, t) -> epsilon

训练目标更稳定,也更符合扩散模型的数学推导。

工程上看,预测噪声还有一个优点:

loss更稳定,模型更容易收敛。


十一、采样过程简化实现

下面写一个简化版采样逻辑,帮助理解反向去噪。

python 复制代码
import torch
import torchvision.utils as vutils
from diffusion import GaussianDiffusion
from models.simple_unet import SimpleDenoiseUNet


@torch.no_grad()
def sample():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SimpleDenoiseUNet().to(device)
    model.load_state_dict(torch.load("diffusion_epoch_100.pth", map_location=device))
    model.eval()

    diffusion = GaussianDiffusion(timesteps=1000, device=device)

    x = torch.randn(16, 1, 64, 64).to(device)

    for i in reversed(range(diffusion.timesteps)):
        t = torch.full((x.size(0),), i, device=device, dtype=torch.long)

        pred_noise = model(x, t)

        beta = diffusion.betas[i]
        alpha = diffusion.alphas[i]
        alpha_bar = diffusion.alpha_bars[i]

        x = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_bar)) * pred_noise
        )

        if i > 0:
            noise = torch.randn_like(x)
            x = x + torch.sqrt(beta) * noise

    x = torch.clamp(x, 0.0, 1.0)
    vutils.save_image(x.cpu(), "diffusion_samples.png", nrow=4)


if __name__ == "__main__":
    sample()

十二、踩坑记录

坑1:时间步没有输入模型

Diffusion模型必须知道 t。

如果只输入 x_t,不输入 t,模型不知道当前噪声强度,训练会非常差。


坑2:学习率过大导致loss震荡

扩散模型训练比普通UNet更敏感。

建议:

python 复制代码
lr = 2e-4

如果不稳定,降到:

python 复制代码
lr = 1e-4

坑3:图像尺寸一开始不要太大

Diffusion训练成本高。

建议从:

text 复制代码
64x64

开始,流程跑通后再放大。


十三、适合收藏总结

Diffusion去噪训练流程

  1. 读取干净图像
  2. 随机采样时间步 t
  3. 根据 t 给图像加噪
  4. 模型预测噪声
  5. 用真实噪声监督训练
  6. 推理时逐步反向去噪

避坑清单

  • 必须输入时间步
  • 训练目标建议预测noise
  • 初期图像尺寸别太大
  • 学习率不要过高
  • 采样速度较慢是正常现象

十四、优化建议

可以继续升级:

  • 更完整UNet结构
  • 加Attention模块
  • 使用DDIM加速采样
  • 支持条件去噪
  • 使用真实噪声数据微调

结尾总结

Diffusion模型的核心不是"一个更大的UNet",而是一套新的去噪建模方式:

把图像恢复拆成多个连续的小步骤,让模型逐步从噪声中恢复结构。

如果你已经理解 DnCNN 的残差噪声预测,那么学习 Diffusion 会更容易,因为它本质上也是在学噪声,只是把这个过程做得更细。


下一篇预告

Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程

相关推荐
电科一班林耿超7 小时前
机器学习大师课 第 4 课:分类问题入门 —— 逻辑回归(垃圾邮件分类实战)
人工智能·机器学习·分类·逻辑回归
小怪兽会微笑7 小时前
世界模型Genie 论文解读
人工智能·深度学习·agi
前端技术7 小时前
机器学习性能评估_指标偏差与工程实践
机器学习·性能优化·混淆矩阵·交叉验证·分布偏移
nonono7 小时前
深度学习基础——(3)视觉处理基础实战【CNN实现CIFAR10 多分类】
深度学习·分类·cnn
RWKV元始智能16 小时前
RWKV超并发项目教程,RWKV-LM训练提速40%
人工智能·rnn·深度学习·自然语言处理·开源
AI技术增长17 小时前
Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题
pytorch·深度学习·机器学习
博.闻广见19 小时前
AI_概率统计-2.常见分布
人工智能·机器学习
小糖学代码19 小时前
LLM系列:2.pytorch入门:8.神经网络的损失函数(criterion)
人工智能·深度学习·神经网络
Jmayday19 小时前
Pytorch:RNN理论基础
pytorch·rnn·深度学习