Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程

Pytorch图像去噪实战(十二):DDPM图像去噪完整训练流程,构建可复现扩散模型工程


一、问题场景:扩散模型能跑,但工程代码很容易写乱

上一篇我们从最小实现理解了 Diffusion 的核心逻辑。

但如果真正放到项目里,会很快遇到问题:

  • beta schedule 写在训练脚本里,后续不好改
  • 采样逻辑和训练逻辑混在一起
  • 模型保存与恢复不规范
  • 训练参数不可复现
  • 后续无法扩展 DDIM、条件去噪、彩色图像

很多人学扩散模型时,能写出一个 demo,但很难整理成工程。

这一篇我们重点做一件事:

把 DDPM 图像去噪流程整理成一个可复现、可扩展的工程结构。


二、DDPM核心训练目标

DDPM训练目标仍然是预测噪声:

text 复制代码
epsilon_theta(x_t, t) ≈ epsilon

训练时:

  1. 从数据集中取 clean image x0
  2. 随机采样时间步 t
  3. 根据 t 给 x0 加噪得到 xt
  4. 模型输入 xt 和 t
  5. 模型预测 noise
  6. 使用 MSELoss 训练

三、推荐工程结构

text 复制代码
ddpm_denoise/
├── configs/
│   └── train_config.py
├── data/
│   └── train/
├── models/
│   └── unet.py
├── diffusion/
│   └── ddpm.py
├── dataset.py
├── train.py
├── sample.py
└── utils.py

这个结构相比简单 demo 有几个好处:

  • 模型独立
  • 扩散过程独立
  • 配置独立
  • 训练和采样分离
  • 后续扩展方便

四、配置文件

configs/train_config.py

python 复制代码
class TrainConfig:
    image_size = 64
    channels = 1

    batch_size = 32
    num_workers = 4

    epochs = 100
    lr = 2e-4

    timesteps = 1000
    beta_start = 1e-4
    beta_end = 0.02

    save_interval = 10
    data_dir = "data/train"
    save_dir = "checkpoints"

配置单独抽出来,最大的好处是:

实验参数不会散落在代码里。

后面复现实验时非常重要。


五、数据集代码

dataset.py

python 复制代码
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, image_size=64, channels=1):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]

        if channels == 1:
            self.mode = "L"
        else:
            self.mode = "RGB"

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert(self.mode)
        return self.transform(img)

六、DDPM扩散类封装

diffusion/ddpm.py

python 复制代码
import torch


class DDPM:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        self.timesteps = timesteps
        self.device = device

        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)

        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)

    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)

        sqrt_alpha_bar = self.sqrt_alpha_bars[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alpha_bars[t].view(-1, 1, 1, 1)

        xt = sqrt_alpha_bar * x0 + sqrt_one_minus * noise

        return xt, noise

    @torch.no_grad()
    def p_sample(self, model, x, t):
        beta = self.betas[t]
        alpha = self.alphas[t]
        alpha_bar = self.alpha_bars[t]

        batch_t = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)

        pred_noise = model(x, batch_t)

        mean = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1.0 - alpha_bar)) * pred_noise
        )

        if t > 0:
            noise = torch.randn_like(x)
            return mean + torch.sqrt(beta) * noise

        return mean

七、UNet噪声预测模型

models/unet.py

python 复制代码
import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, t):
        t = t.float().view(-1, 1) / 1000.0
        return self.net(t)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_proj = nn.Linear(time_dim, out_channels)

        self.shortcut = nn.Identity()
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

        self.act = nn.SiLU()

    def forward(self, x, t_emb):
        h = self.act(self.conv1(x))

        time = self.time_proj(t_emb).view(x.size(0), -1, 1, 1)
        h = h + time

        h = self.conv2(self.act(h))

        return h + self.shortcut(x)


class DDPMUNet(nn.Module):
    def __init__(self, channels=1, base=64, time_dim=128):
        super().__init__()

        self.time_mlp = TimeEmbedding(time_dim)

        self.down1 = ResidualBlock(channels, base, time_dim)
        self.down2 = ResidualBlock(base, base * 2, time_dim)

        self.pool = nn.MaxPool2d(2)

        self.mid = ResidualBlock(base * 2, base * 2, time_dim)

        self.up = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.up_block = ResidualBlock(base * 2, base, time_dim)

        self.out = nn.Conv2d(base, channels, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)

        d1 = self.down1(x, t_emb)
        d2 = self.down2(self.pool(d1), t_emb)

        mid = self.mid(d2, t_emb)

        u = self.up(mid)
        u = torch.cat([u, d1], dim=1)
        u = self.up_block(u, t_emb)

        return self.out(u)

八、训练脚本

train.py

python 复制代码
import os
import torch
from torch.utils.data import DataLoader

from configs.train_config import TrainConfig
from dataset import ImageFolderDataset
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM


def train():
    cfg = TrainConfig()

    os.makedirs(cfg.save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = ImageFolderDataset(
        root_dir=cfg.data_dir,
        image_size=cfg.image_size,
        channels=cfg.channels
    )

    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers
    )

    model = DDPMUNet(channels=cfg.channels).to(device)
    diffusion = DDPM(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        device=device
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    criterion = torch.nn.MSELoss()

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        total_loss = 0

        for x0 in loader:
            x0 = x0.to(device)

            t = torch.randint(0, cfg.timesteps, (x0.size(0),), device=device)

            xt, noise = diffusion.q_sample(x0, t)

            pred_noise = model(xt, t)
            loss = criterion(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch [{epoch}/{cfg.epochs}], Loss: {avg_loss:.6f}")

        if epoch % cfg.save_interval == 0:
            path = os.path.join(cfg.save_dir, f"ddpm_epoch_{epoch}.pth")
            torch.save(model.state_dict(), path)


if __name__ == "__main__":
    train()

九、采样脚本

sample.py

python 复制代码
import torch
import torchvision.utils as vutils

from configs.train_config import TrainConfig
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM


@torch.no_grad()
def sample():
    cfg = TrainConfig()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DDPMUNet(channels=cfg.channels).to(device)
    model.load_state_dict(torch.load("checkpoints/ddpm_epoch_100.pth", map_location=device))
    model.eval()

    diffusion = DDPM(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        device=device
    )

    x = torch.randn(16, cfg.channels, cfg.image_size, cfg.image_size).to(device)

    for t in reversed(range(cfg.timesteps)):
        x = diffusion.p_sample(model, x, t)

    x = torch.clamp(x, 0.0, 1.0)
    vutils.save_image(x.cpu(), "ddpm_sample.png", nrow=4)


if __name__ == "__main__":
    sample()

十、为什么要做工程拆分?

很多扩散模型代码一开始写在一个文件里,能跑,但很难维护。

工程拆分带来的好处:

  • diffusion类可复用
  • UNet可替换
  • config方便调参
  • train和sample互不干扰
  • 后续DDIM可以直接扩展

这也是从"能跑demo"到"能做项目"的关键一步。


十一、踩坑记录

坑1:采样结果全是噪声

常见原因:

  • 模型训练不够
  • 时间步输入错误
  • beta schedule太激进
  • 采样公式写错

建议先用小数据集验证过拟合能力。


坑2:loss下降但采样效果差

DDPM的loss下降不代表马上能生成好图。

采样质量通常需要更多训练轮数。


坑3:训练太慢

DDPM采样慢是正常现象,因为要从 T 逐步采样。

后续可以使用 DDIM 或减少 timesteps。


十二、适合收藏总结

DDPM工程化流程

  1. 配置文件管理参数
  2. Dataset加载图像
  3. DDPM类负责加噪和采样
  4. UNet预测噪声
  5. train.py训练模型
  6. sample.py生成结果

避坑清单

  • 不要把所有代码写一个文件
  • 时间步必须正确传入
  • beta schedule要稳定
  • 采样结果差不一定是loss问题
  • 先用小尺寸图跑通

十三、优化建议

后续可以继续做:

  • DDIM加速采样
  • 条件Diffusion去噪
  • 彩色图像支持
  • EMA模型权重
  • 混合精度训练

结尾总结

DDPM不是一个单独模型,而是一套完整的扩散训练和采样框架。

如果你只是写一个demo,很容易跑通;但如果要长期做系列实验,就必须从一开始整理好工程结构。

这一篇的重点不是追求最强效果,而是把DDPM搭成一个稳定可复现的项目骨架。


下一篇预告

Pytorch图像去噪实战(十三):DDIM加速采样,让扩散模型去噪从1000步降到50步

相关推荐
卷Java4 小时前
上下文压缩
开发语言·windows·python
本地化文档4 小时前
setuptools-docs-l10n
python·github·gitcode
梦想不只是梦与想4 小时前
Python 属性访问的 MRO 规则
python·mro规则
Ulyanov4 小时前
基于 Python 的三维动态导弹攻防演示系统设计与实现:从架构到实战的深度剖析
开发语言·python·qt·架构·雷达电子对抗
蔡俊锋4 小时前
AI时代:人类从操控者到旁观者的蜕变
人工智能·深度学习·hermes·ai团队·ai团队知识沉淀
Leinwin4 小时前
Claude 四月宕机七次:从一次事故看企业级 AI 部署的容灾设计
后端·python·flask
棉猴4 小时前
Python海龟绘图之绘制文本
javascript·python·html·write·turtle·海龟绘图·输出文本
渣渣盟4 小时前
大数据技术栈全景图:从零到一的入门路线(深度实战版)
大数据·hadoop·python·flink·spark
AI医影跨模态组学4 小时前
如何将深度学习超声影像特征与乳腺癌腋窝淋巴结治疗响应的生物学机制建立关联,并进一步解释其预测pCR与个体化治疗的机制联系
人工智能·深度学习·论文·医学·医学影像·影像组学·医学科研