第五章:计算机视觉-项目实战之生成式算法实战:扩散模型
第三部分:生成式算法实战:扩散模型
第三节:DDPM模型训练与推理
一、章节导读
在前两节中,我们分别介绍了扩散模型训练流程的总体思路 以及数据读取与预处理 。
本节将进入核心内容------DDPM(Denoising Diffusion Probabilistic Model)模型的训练与推理。
这一节的目标是让你:
-
理解 DDPM 的训练目标与损失函数;
-
掌握 噪声添加与反向去噪过程的实现逻辑;
-
学会如何 使用PyTorch实现完整的训练与推理流程;
-
了解如何 从噪声生成高质量图像。
二、DDPM模型的训练原理回顾
扩散模型的训练核心思想是:
"教模型从噪声中一步步还原原始图像。"
这个过程分为两个阶段:
-
前向扩散(Forward Process)
在每一步中,向图像中逐渐加入高斯噪声,直到完全变成随机噪声。
公式如下:
累积后可得:
其中:
-
:噪声步的强度;
-
;
-
。
-
-
反向去噪(Reverse Process)
模型学习预测噪声
,从而重建
。
训练目标为:
三、模型结构设计
DDPM通常使用 U-Net 结构来完成噪声预测任务:
-
输入 :含噪图像
+ 时间步 (t)
-
输出 :预测噪声
-
特征:
-
多尺度卷积编码器 + 解码器结构
-
Skip Connection 保留细节信息
-
时间步通过位置编码注入
-
简化版U-Net定义:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleUNet(nn.Module):
def __init__(self, in_channels=3, out_channels=3, base_channels=64):
super().__init__()
self.down1 = nn.Sequential(
nn.Conv2d(in_channels, base_channels, 3, padding=1),
nn.ReLU(),
nn.Conv2d(base_channels, base_channels, 3, padding=1),
nn.ReLU()
)
self.down2 = nn.Sequential(
nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
nn.ReLU()
)
self.up1 = nn.Sequential(
nn.ConvTranspose2d(base_channels * 2, base_channels, 4, stride=2, padding=1),
nn.ReLU()
)
self.out = nn.Conv2d(base_channels, out_channels, 3, padding=1)
def forward(self, x, t):
# 简化:暂不引入时间编码
d1 = self.down1(x)
d2 = self.down2(d1)
u1 = self.up1(d2)
x = self.out(u1 + d1) # skip connection
return x
四、训练流程实现
1. 训练步骤伪代码
python
for epoch in range(num_epochs):
for x_0 in dataloader:
x_0 = x_0.to(device)
# 随机采样时间步 t
t = torch.randint(0, T, (x_0.size(0),), device=device).long()
# 采样噪声
noise = torch.randn_like(x_0)
# 前向扩散
alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
# 模型预测噪声
noise_pred = model(x_t, t)
# 损失函数
loss = F.mse_loss(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
2. 学习率调度与优化器
python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
五、模型推理(采样生成)
在训练完成后,我们从纯噪声开始,依次反向采样:
其中。
推理流程示例:
python
@torch.no_grad()
def sample(model, image_size, T, device):
model.eval()
x_t = torch.randn(1, 3, image_size, image_size, device=device)
for t in reversed(range(T)):
t_tensor = torch.tensor([t], device=device)
noise_pred = model(x_t, t_tensor)
alpha_t = alphas[t]
alpha_bar_t = alpha_bars[t]
beta_t = betas[t]
if t > 0:
z = torch.randn_like(x_t)
else:
z = 0
x_t = (1 / torch.sqrt(alpha_t)) * (
x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * noise_pred
) + torch.sqrt(beta_t) * z
return x_t
输出可视化:
python
from torchvision.utils import save_image
generated = sample(model, image_size=64, T=1000, device=device)
save_image((generated + 1) / 2, "generated_sample.png")
六、训练与推理效果对比
阶段 | 输入 | 输出 | 说明 |
---|---|---|---|
训练 | 原始图像 + 时间步 | 噪声预测 | 学习"去噪映射" |
推理 | 随机噪声 | 清晰图像 | 逐步"逆扩散"生成图像 |
通常训练若干个 epoch 后,模型能从完全随机噪声中生成结构清晰、颜色自然的图像。
七、可视化生成过程
为理解模型行为,可将生成过程中的图像保存下来:
python
images = []
x_t = torch.randn(1, 3, 64, 64, device=device)
for t in reversed(range(T)):
...
images.append(x_t.cpu())
# 可用matplotlib展示部分阶段
import matplotlib.pyplot as plt
plt.figure(figsize=(12,3))
for i, step in enumerate([T-1, T//2, 10, 0]):
plt.subplot(1,4,i+1)
plt.imshow((images[step][0].permute(1,2,0)+1)/2)
plt.axis('off')
plt.show()
八、训练技巧与优化建议
优化项 | 说明 |
---|---|
学习率warm-up | 前几步使用小学习率防止梯度爆炸 |
EMA(Exponential Moving Average) | 保存模型权重平滑版本提升稳定性 |
混合精度训练(AMP) | 减少显存占用并加速训练 |
梯度裁剪 | 防止训练不稳定 |
分布式训练(DDP) | 大规模数据集必备 |
九、小结
内容 | 核心要点 |
---|---|
训练目标 | 学习从噪声中恢复原图像的映射 |
模型结构 | U-Net 多尺度特征提取 + 时间步嵌入 |
训练逻辑 | 随机采样时间步,添加噪声并预测 |
推理过程 | 从随机噪声逆向生成图像 |
实践建议 | 保持数据一致性、优化加载与EMA策略 |
本节收获
-
掌握了DDPM模型的训练与推理流程全貌;
-
理解了扩散模型的核心数学原理与反向采样机制;
-
能够实现一个可运行的DDPM生成模型原型;
-
为下一节的**模型优化与高级应用(如条件生成、Guided Diffusion)**奠定了基础。