第五章:计算机视觉-项目实战之生成式算法实战:扩散模型
第三部分:生成式算法实战:扩散模型
第三节:DDPM模型训练与推理
一、章节导读
在前两节中,我们分别介绍了扩散模型训练流程的总体思路 以及数据读取与预处理 。
本节将进入核心内容------DDPM(Denoising Diffusion Probabilistic Model)模型的训练与推理。
这一节的目标是让你:
- 
理解 DDPM 的训练目标与损失函数;
 - 
掌握 噪声添加与反向去噪过程的实现逻辑;
 - 
学会如何 使用PyTorch实现完整的训练与推理流程;
 - 
了解如何 从噪声生成高质量图像。
 
二、DDPM模型的训练原理回顾
扩散模型的训练核心思想是:
"教模型从噪声中一步步还原原始图像。"
这个过程分为两个阶段:
- 
前向扩散(Forward Process)
在每一步中,向图像中逐渐加入高斯噪声,直到完全变成随机噪声。
公式如下:
累积后可得:
其中:
- 
:噪声步的强度;
 - 
;
 - 
。
 
 - 
 - 
反向去噪(Reverse Process)
模型学习预测噪声
,从而重建
。
训练目标为:
 
三、模型结构设计
DDPM通常使用 U-Net 结构来完成噪声预测任务:
- 
输入 :含噪图像
+ 时间步 (t)
 - 
输出 :预测噪声
 - 
特征:
- 
多尺度卷积编码器 + 解码器结构
 - 
Skip Connection 保留细节信息
 - 
时间步通过位置编码注入
 
 - 
 
简化版U-Net定义:
            
            
              python
              
              
            
          
          import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1),
            nn.ReLU()
        )
        self.out = nn.Conv2d(base_channels, out_channels, 3, padding=1)
    def forward(self, x, t):
        # 简化:暂不引入时间编码
        d1 = self.down1(x)
        d2 = self.down2(d1)
        u1 = self.up1(d2)
        x = self.out(u1 + d1)  # skip connection
        return x
        四、训练流程实现
1. 训练步骤伪代码
            
            
              python
              
              
            
          
          for epoch in range(num_epochs):
    for x_0 in dataloader:
        x_0 = x_0.to(device)
        
        # 随机采样时间步 t
        t = torch.randint(0, T, (x_0.size(0),), device=device).long()
        
        # 采样噪声
        noise = torch.randn_like(x_0)
        
        # 前向扩散
        alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
        
        # 模型预测噪声
        noise_pred = model(x_t, t)
        
        # 损失函数
        loss = F.mse_loss(noise_pred, noise)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        2. 学习率调度与优化器
            
            
              python
              
              
            
          
          optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
        五、模型推理(采样生成)
在训练完成后,我们从纯噪声开始,依次反向采样:
其中。
推理流程示例:
            
            
              python
              
              
            
          
          @torch.no_grad()
def sample(model, image_size, T, device):
    model.eval()
    x_t = torch.randn(1, 3, image_size, image_size, device=device)
    for t in reversed(range(T)):
        t_tensor = torch.tensor([t], device=device)
        noise_pred = model(x_t, t_tensor)
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bars[t]
        beta_t = betas[t]
        
        if t > 0:
            z = torch.randn_like(x_t)
        else:
            z = 0
        
        x_t = (1 / torch.sqrt(alpha_t)) * (
            x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
        ) + torch.sqrt(beta_t) * z
    
    return x_t
        输出可视化:
            
            
              python
              
              
            
          
          from torchvision.utils import save_image
generated = sample(model, image_size=64, T=1000, device=device)
save_image((generated + 1) / 2, "generated_sample.png")
        六、训练与推理效果对比
| 阶段 | 输入 | 输出 | 说明 | 
|---|---|---|---|
| 训练 | 原始图像 + 时间步 | 噪声预测 | 学习"去噪映射" | 
| 推理 | 随机噪声 | 清晰图像 | 逐步"逆扩散"生成图像 | 
通常训练若干个 epoch 后,模型能从完全随机噪声中生成结构清晰、颜色自然的图像。
七、可视化生成过程
为理解模型行为,可将生成过程中的图像保存下来:
            
            
              python
              
              
            
          
          images = []
x_t = torch.randn(1, 3, 64, 64, device=device)
for t in reversed(range(T)):
    ...
    images.append(x_t.cpu())
# 可用matplotlib展示部分阶段
import matplotlib.pyplot as plt
plt.figure(figsize=(12,3))
for i, step in enumerate([T-1, T//2, 10, 0]):
    plt.subplot(1,4,i+1)
    plt.imshow((images[step][0].permute(1,2,0)+1)/2)
    plt.axis('off')
plt.show()
        八、训练技巧与优化建议
| 优化项 | 说明 | 
|---|---|
| 学习率warm-up | 前几步使用小学习率防止梯度爆炸 | 
| EMA(Exponential Moving Average) | 保存模型权重平滑版本提升稳定性 | 
| 混合精度训练(AMP) | 减少显存占用并加速训练 | 
| 梯度裁剪 | 防止训练不稳定 | 
| 分布式训练(DDP) | 大规模数据集必备 | 
九、小结
| 内容 | 核心要点 | 
|---|---|
| 训练目标 | 学习从噪声中恢复原图像的映射 | 
| 模型结构 | U-Net 多尺度特征提取 + 时间步嵌入 | 
| 训练逻辑 | 随机采样时间步,添加噪声并预测 | 
| 推理过程 | 从随机噪声逆向生成图像 | 
| 实践建议 | 保持数据一致性、优化加载与EMA策略 | 
本节收获
- 
掌握了DDPM模型的训练与推理流程全貌;
 - 
理解了扩散模型的核心数学原理与反向采样机制;
 - 
能够实现一个可运行的DDPM生成模型原型;
 - 
为下一节的**模型优化与高级应用(如条件生成、Guided Diffusion)**奠定了基础。