【第五章:计算机视觉-项目实战之生成式算法实战:扩散模型】3.生成式算法实战:扩散模型-(3)DDPM模型训练与推理

第五章:计算机视觉-项目实战之生成式算法实战:扩散模型

第三部分:生成式算法实战:扩散模型

第三节:DDPM模型训练与推理


一、章节导读

在前两节中,我们分别介绍了扩散模型训练流程的总体思路 以及数据读取与预处理

本节将进入核心内容------DDPM(Denoising Diffusion Probabilistic Model)模型的训练与推理

这一节的目标是让你:

  • 理解 DDPM 的训练目标与损失函数

  • 掌握 噪声添加与反向去噪过程的实现逻辑

  • 学会如何 使用PyTorch实现完整的训练与推理流程

  • 了解如何 从噪声生成高质量图像


二、DDPM模型的训练原理回顾

扩散模型的训练核心思想是:

"教模型从噪声中一步步还原原始图像。"

这个过程分为两个阶段:

  1. 前向扩散(Forward Process)

    在每一步中,向图像中逐渐加入高斯噪声,直到完全变成随机噪声。

    公式如下:

    累积后可得:

    其中:

    • :噪声步的强度;

  2. 反向去噪(Reverse Process)

    模型学习预测噪声,从而重建

    训练目标为:


三、模型结构设计

DDPM通常使用 U-Net 结构来完成噪声预测任务:

  • 输入 :含噪图像 + 时间步 (t)

  • 输出 :预测噪声

  • 特征

    • 多尺度卷积编码器 + 解码器结构

    • Skip Connection 保留细节信息

    • 时间步通过位置编码注入

简化版U-Net定义:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.out = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def forward(self, x, t):
        # 简化:暂不引入时间编码
        d1 = self.down1(x)
        d2 = self.down2(d1)
        u1 = self.up1(d2)
        x = self.out(u1 + d1)  # skip connection
        return x

四、训练流程实现

1. 训练步骤伪代码

python 复制代码
for epoch in range(num_epochs):
    for x_0 in dataloader:
        x_0 = x_0.to(device)
        
        # 随机采样时间步 t
        t = torch.randint(0, T, (x_0.size(0),), device=device).long()
        
        # 采样噪声
        noise = torch.randn_like(x_0)
        
        # 前向扩散
        alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        
        # 模型预测噪声
        noise_pred = model(x_t, t)
        
        # 损失函数
        loss = F.mse_loss(noise_pred, noise)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

2. 学习率调度与优化器

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

五、模型推理(采样生成)

在训练完成后,我们从纯噪声开始,依次反向采样:

其中

推理流程示例:

python 复制代码
@torch.no_grad()
def sample(model, image_size, T, device):
    model.eval()
    x_t = torch.randn(1, 3, image_size, image_size, device=device)

    for t in reversed(range(T)):
        t_tensor = torch.tensor([t], device=device)
        noise_pred = model(x_t, t_tensor)

        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]
        beta_t = betas[t]
        
        if t > 0:
            z = torch.randn_like(x_t)
        else:
            z = 0
        
        x_t = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
        ) + torch.sqrt(beta_t) * z
    
    return x_t

输出可视化:

python 复制代码
from torchvision.utils import save_image
generated = sample(model, image_size=64, T=1000, device=device)
save_image((generated + 1) / 2, "generated_sample.png")

六、训练与推理效果对比

阶段 输入 输出 说明
训练 原始图像 + 时间步 噪声预测 学习"去噪映射"
推理 随机噪声 清晰图像 逐步"逆扩散"生成图像

通常训练若干个 epoch 后,模型能从完全随机噪声中生成结构清晰、颜色自然的图像。


七、可视化生成过程

为理解模型行为,可将生成过程中的图像保存下来:

python 复制代码
images = []
x_t = torch.randn(1, 3, 64, 64, device=device)
for t in reversed(range(T)):
    ...
    images.append(x_t.cpu())

# 可用matplotlib展示部分阶段
import matplotlib.pyplot as plt
plt.figure(figsize=(12,3))
for i, step in enumerate([T-1, T//2, 10, 0]):
    plt.subplot(1,4,i+1)
    plt.imshow((images[step][0].permute(1,2,0)+1)/2)
    plt.axis('off')
plt.show()

八、训练技巧与优化建议

优化项 说明
学习率warm-up 前几步使用小学习率防止梯度爆炸
EMA(Exponential Moving Average) 保存模型权重平滑版本提升稳定性
混合精度训练(AMP) 减少显存占用并加速训练
梯度裁剪 防止训练不稳定
分布式训练(DDP) 大规模数据集必备

九、小结

内容 核心要点
训练目标 学习从噪声中恢复原图像的映射
模型结构 U-Net 多尺度特征提取 + 时间步嵌入
训练逻辑 随机采样时间步,添加噪声并预测
推理过程 从随机噪声逆向生成图像
实践建议 保持数据一致性、优化加载与EMA策略

本节收获

  • 掌握了DDPM模型的训练与推理流程全貌

  • 理解了扩散模型的核心数学原理与反向采样机制

  • 能够实现一个可运行的DDPM生成模型原型

  • 为下一节的**模型优化与高级应用(如条件生成、Guided Diffusion)**奠定了基础。

相关推荐
乐迪信息2 小时前
乐迪信息:智慧煤矿输送带安全如何保障?AI摄像机全天候识别
大数据·运维·人工智能·安全·自动化·视觉检测
知孤云出岫2 小时前
为 AI / LLM / Agent 构建安全基础
人工智能·安全
独自破碎E3 小时前
Leetcode2166-设计位集
java·数据结构·算法
阿里云大数据AI技术3 小时前
云栖实录|人工智能+大数据平台加速企业模型后训练
大数据·人工智能
ARM+FPGA+AI工业主板定制专家3 小时前
基于JETSON/RK3588机器人高动态双目视觉系统方案
人工智能·机器学习·fpga开发·机器人·自动驾驶
Swift社区3 小时前
LeetCode 396 - 旋转函数 (Rotate Function)
算法·leetcode·职场和发展
海琴烟Sunshine3 小时前
leetcode 88.合并两个有序数组
python·算法·leetcode
Cikiss3 小时前
LeetCode160.相交链表【最通俗易懂版双指针】
java·数据结构·算法·链表
东方芷兰3 小时前
LLM 笔记 —— 08 Embeddings(One-hot、Word、Word2Vec、Glove、FastText)
人工智能·笔记·神经网络·语言模型·自然语言处理·word·word2vec