【第五章:计算机视觉-项目实战之生成式算法实战:扩散模型】3.生成式算法实战:扩散模型-(3)DDPM模型训练与推理

第五章:计算机视觉-项目实战之生成式算法实战:扩散模型

第三部分:生成式算法实战:扩散模型

第三节:DDPM模型训练与推理


一、章节导读

在前两节中,我们分别介绍了扩散模型训练流程的总体思路 以及数据读取与预处理

本节将进入核心内容------DDPM(Denoising Diffusion Probabilistic Model)模型的训练与推理

这一节的目标是让你:

  • 理解 DDPM 的训练目标与损失函数

  • 掌握 噪声添加与反向去噪过程的实现逻辑

  • 学会如何 使用PyTorch实现完整的训练与推理流程

  • 了解如何 从噪声生成高质量图像


二、DDPM模型的训练原理回顾

扩散模型的训练核心思想是:

"教模型从噪声中一步步还原原始图像。"

这个过程分为两个阶段:

  1. 前向扩散(Forward Process)

    在每一步中,向图像中逐渐加入高斯噪声,直到完全变成随机噪声。

    公式如下:

    累积后可得:

    其中:

    • :噪声步的强度;

  2. 反向去噪(Reverse Process)

    模型学习预测噪声,从而重建

    训练目标为:


三、模型结构设计

DDPM通常使用 U-Net 结构来完成噪声预测任务:

  • 输入 :含噪图像 + 时间步 (t)

  • 输出 :预测噪声

  • 特征

    • 多尺度卷积编码器 + 解码器结构

    • Skip Connection 保留细节信息

    • 时间步通过位置编码注入

简化版U-Net定义:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.out = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def forward(self, x, t):
        # 简化:暂不引入时间编码
        d1 = self.down1(x)
        d2 = self.down2(d1)
        u1 = self.up1(d2)
        x = self.out(u1 + d1)  # skip connection
        return x

四、训练流程实现

1. 训练步骤伪代码

python 复制代码
for epoch in range(num_epochs):
    for x_0 in dataloader:
        x_0 = x_0.to(device)
        
        # 随机采样时间步 t
        t = torch.randint(0, T, (x_0.size(0),), device=device).long()
        
        # 采样噪声
        noise = torch.randn_like(x_0)
        
        # 前向扩散
        alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        
        # 模型预测噪声
        noise_pred = model(x_t, t)
        
        # 损失函数
        loss = F.mse_loss(noise_pred, noise)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

2. 学习率调度与优化器

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

五、模型推理(采样生成)

在训练完成后,我们从纯噪声开始,依次反向采样:

其中

推理流程示例:

python 复制代码
@torch.no_grad()
def sample(model, image_size, T, device):
    model.eval()
    x_t = torch.randn(1, 3, image_size, image_size, device=device)

    for t in reversed(range(T)):
        t_tensor = torch.tensor([t], device=device)
        noise_pred = model(x_t, t_tensor)

        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]
        beta_t = betas[t]
        
        if t > 0:
            z = torch.randn_like(x_t)
        else:
            z = 0
        
        x_t = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
        ) + torch.sqrt(beta_t) * z
    
    return x_t

输出可视化:

python 复制代码
from torchvision.utils import save_image
generated = sample(model, image_size=64, T=1000, device=device)
save_image((generated + 1) / 2, "generated_sample.png")

六、训练与推理效果对比

阶段 输入 输出 说明
训练 原始图像 + 时间步 噪声预测 学习"去噪映射"
推理 随机噪声 清晰图像 逐步"逆扩散"生成图像

通常训练若干个 epoch 后,模型能从完全随机噪声中生成结构清晰、颜色自然的图像。


七、可视化生成过程

为理解模型行为,可将生成过程中的图像保存下来:

python 复制代码
images = []
x_t = torch.randn(1, 3, 64, 64, device=device)
for t in reversed(range(T)):
    ...
    images.append(x_t.cpu())

# 可用matplotlib展示部分阶段
import matplotlib.pyplot as plt
plt.figure(figsize=(12,3))
for i, step in enumerate([T-1, T//2, 10, 0]):
    plt.subplot(1,4,i+1)
    plt.imshow((images[step][0].permute(1,2,0)+1)/2)
    plt.axis('off')
plt.show()

八、训练技巧与优化建议

优化项 说明
学习率warm-up 前几步使用小学习率防止梯度爆炸
EMA(Exponential Moving Average) 保存模型权重平滑版本提升稳定性
混合精度训练(AMP) 减少显存占用并加速训练
梯度裁剪 防止训练不稳定
分布式训练(DDP) 大规模数据集必备

九、小结

内容 核心要点
训练目标 学习从噪声中恢复原图像的映射
模型结构 U-Net 多尺度特征提取 + 时间步嵌入
训练逻辑 随机采样时间步,添加噪声并预测
推理过程 从随机噪声逆向生成图像
实践建议 保持数据一致性、优化加载与EMA策略

本节收获

  • 掌握了DDPM模型的训练与推理流程全貌

  • 理解了扩散模型的核心数学原理与反向采样机制

  • 能够实现一个可运行的DDPM生成模型原型

  • 为下一节的**模型优化与高级应用(如条件生成、Guided Diffusion)**奠定了基础。

相关推荐
共享家952718 小时前
数独系列算法
算法·深度优先
rengang6618 小时前
04-深度学习的基本概念:涵盖深度学习中的关键术语和原理
人工智能·深度学习
杨成功18 小时前
大语言模型(LLM)学习笔记
人工智能·llm
文火冰糖的硅基工坊18 小时前
[人工智能-大模型-122]:模型层 - RNN是通过神经元还是通过张量时间记录状态信息?时间状态信息是如何被更新的?
人工智能·rnn·深度学习
Dev7z18 小时前
基于深度学习的中国交通警察手势识别与指令优先级判定系统
人工智能·深度学习
阿_旭18 小时前
复杂环境下驾驶员注意力实时检测: 双目深度补偿 + 双向 LSTM
人工智能·lstm·驾驶员注意力
liebe1*119 小时前
C语言程序代码(四)
c语言·数据结构·算法
进击的圆儿19 小时前
递归专题4 - 网格DFS与回溯
数据结构·算法·递归回溯
程序猿202319 小时前
Python每日一练---第一天:买卖股票的最佳时机
算法
Elastic 中国社区官方博客19 小时前
Elastic AI agent builder 介绍(三)
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索