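Contents
- Overview
- History of Generative Models
- Core Ideas
- Mathematical Foundations
- DDPM in Detail
- Accelerated Sampling with DDIM
- Score-Based Models
- Conditional Generation
- Latent Diffusion and Stable Diffusion
- Classifier-Free Guidance in Detail
- Image Generation Applications
- Video Generation
- 3D and Audio Generation
- Complete Code Implementation
- Training Tips
- Inference Optimization
- References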
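1. Overview
1.1 What is a diffusion model
A diffusion model is a probabilistic generative model that generates data by learning to reverse a gradual noising process. It has two parts:
- Forward process: noise is added to the data step by step until (approximately) pure Gaussian noise remains
- Reverse process: a learned model removes the noise step by step to recover data from noise
1.2 Timeline
| Year | Milestone |
|---|---|
| 2015 | Sohl-Dickstein et al. introduce the diffusion model concept |
| 2020 | DDPM achieves high-quality image generation |
| 2021 | DDIM accelerates sampling; Score SDE unifies the theory |
| 2022 | Latent Diffusion / Stable Diffusion; DALL-E 2 |
| 2023 | SDXL; Consistency Models; video generation |
| 2024 | Sora video generation; Flow Matching gains wide adoption |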
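2. History of Generative Models
2.1 Comparison of mainstream generative models
| Model | Strengths | Weaknesses |
|---|---|---|
| GAN | Fast sampling, high quality | Unstable training, mode collapse |
| VAE | Explicit latent space, stable training | Blurry samples |
| Flow | Exact likelihood | High compute cost, restricted architectures |
| Autoregressive | Exact likelihood | Slow sampling |
| Diffusion | Stable training, high quality | Relatively slow sampling |
2.2 Why diffusion models
- Training is simple and stable (the network only has to predict noise)
- Sample quality is among the best of current generative models
- Less prone to mode collapse
- Solid theoretical foundation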
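3. Core Ideas
3.1 Intuition
Forward process: a photo gradually fades into noise
Reverse process: a sharp photo is recovered from noise
3.2 Key formulas
Forward (closed form, with ε ~ N(0, I) and ᾱ_t the cumulative product of α_s = 1-β_s for s = 1..t):
x_t = √(ᾱ_t) * x_0 + √(1-ᾱ_t) * ε
Reverse (learned, with z ~ N(0, I)):
x_{t-1} = μ_θ(x_t, t) + σ_t * z
Training objective: predict the noise
L = E[||ε - ε_θ(x_t, t)||²]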
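4. Mathematical Foundations
4.1 Gaussian distributions
Univariate: N(μ, σ²)
Multivariate: N(μ, Σ)
4.2 KL divergence
D_KL(p||q) = ∫ p(x) log(p(x)/q(x)) dx
The KL divergence between two Gaussians has a closed form, which is what makes the terms of the DDPM variational bound tractable.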
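As a small illustration of the closed form, the sketch below evaluates the KL divergence between two univariate Gaussians. The function name `kl_gaussian` and the example values are assumptions made for illustration only.

```python
import math

def kl_gaussian(mu_p, sigma_p, mu_q, sigma_q):
    """KL( N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2) ) for scalar Gaussians."""
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)

kl_same = kl_gaussian(0.0, 1.0, 0.0, 1.0)    # identical distributions
kl_shift = kl_gaussian(1.0, 1.0, 0.0, 1.0)   # same variance, mean shifted by 1
```

Identical distributions give zero, and shifting the mean by one standard deviation at unit variance gives 0.5.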
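4.3 Score function
s(x) = ∇_x log p(x)
Because the gradient is taken with respect to x, the normalizing constant of p(x) drops out. This property is the basis of score-based models.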
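To make the point about the normalizing constant concrete, here is a minimal sketch that differentiates an unnormalized Gaussian log-density numerically and compares it with the analytic score. The helper names and the values of `mu`, `sigma`, and `x` are illustrative assumptions.

```python
import math

def unnormalized_log_density(x, mu=1.0, sigma=2.0):
    # log of exp(-(x-mu)^2 / (2 sigma^2)), deliberately WITHOUT the 1/(sigma*sqrt(2*pi)) factor
    return -((x - mu) ** 2) / (2 * sigma ** 2)

def numerical_score(log_p, x, h=1e-5):
    # central finite difference approximating d/dx log p(x)
    return (log_p(x + h) - log_p(x - h)) / (2 * h)

def analytic_score(x, mu=1.0, sigma=2.0):
    # score of the normalized density N(mu, sigma^2)
    return -(x - mu) / sigma ** 2

x = 3.0
score_fd = numerical_score(unnormalized_log_density, x)
score_exact = analytic_score(x)
```

The two values agree even though the normalizing factor was never computed.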
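4.4 Markov chains
The diffusion process is a Markov chain: x_t depends only on x_{t-1}.
5. DDPM in Detail
5.1 Forward process
q(x_t|x_{t-1}) = N(x_t; √(1-β_t)x_{t-1}, β_t I)
Property: with α_t = 1-β_t and ᾱ_t = α_1·α_2·…·α_t, x_t can be computed directly from x_0 without simulating the intermediate steps:
x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε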
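The closed form can be checked by composing the single-step transitions. The sketch below assumes a linear schedule from 1e-4 to 0.02 over T = 1000 steps (the defaults used in the code later in this document) and tracks the coefficient on x_0 and the accumulated noise variance step by step.

```python
import math

T = 1000
# Linear beta schedule (assumed values: 1e-4 to 0.02)
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

signal_coef = 1.0   # coefficient multiplying x_0 after t single steps
noise_var = 0.0     # variance of the accumulated Gaussian noise after t single steps
alpha_bar = 1.0     # running product of alpha_s = 1 - beta_s

for beta in betas:
    alpha = 1.0 - beta
    # one step: x_t = sqrt(alpha) * x_{t-1} + sqrt(beta) * eps
    signal_coef *= math.sqrt(alpha)
    noise_var = alpha * noise_var + beta
    alpha_bar *= alpha

closed_form_signal = math.sqrt(alpha_bar)
closed_form_noise_var = 1.0 - alpha_bar
```

After all T steps the iterated and closed-form quantities coincide, and `alpha_bar` is close to zero, so x_T is almost pure noise.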
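5.2 Reverse process
p_θ(x_{t-1}|x_t) = N(x_{t-1}; μ_θ(x_t,t), σ_t²I)
The variance σ_t² is fixed rather than learned, for example to β_t or to the posterior variance β_t(1-ᾱ_{t-1})/(1-ᾱ_t) (the choice used in the code of Section 14.1).
The neural network predicts the noise ε_θ(x_t, t), from which the mean is computed:
μ_θ = (1/√α_t)(x_t - β_t/√(1-ᾱ_t) * ε_θ)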
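The mean formula comes from the posterior q(x_{t-1}|x_t, x_0) after substituting x_0 = (x_t - √(1-ᾱ_t)ε)/√(ᾱ_t). The following sketch checks numerically that the two forms agree. The function names and the scalar test values are assumptions for illustration, and the x_0 form of the posterior mean is the one given in Ho et al. (2020).

```python
import math

def posterior_mean_x0_form(x_t, x_0, alpha_t, alpha_bar_t, alpha_bar_prev):
    # Mean of q(x_{t-1} | x_t, x_0), written in terms of x_0
    beta_t = 1.0 - alpha_t
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return coef_x0 * x_0 + coef_xt * x_t

def posterior_mean_eps_form(x_t, eps, alpha_t, alpha_bar_t):
    # The same mean, written in terms of the noise eps
    beta_t = 1.0 - alpha_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alpha_t)

alpha_bar_prev, alpha_t = 0.6, 0.95
alpha_bar_t = alpha_bar_prev * alpha_t
x_0, eps = 0.8, -1.3
x_t = math.sqrt(alpha_bar_t) * x_0 + math.sqrt(1.0 - alpha_bar_t) * eps

mean_from_x0 = posterior_mean_x0_form(x_t, x_0, alpha_t, alpha_bar_t, alpha_bar_prev)
mean_from_eps = posterior_mean_eps_form(x_t, eps, alpha_t, alpha_bar_t)
```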
5.3 Training objective
Simplified loss:
L_simple = E[||ε - ε_θ(x_t, t)||²]
5.4 Training algorithm
repeat:
    sample x_0 ~ q(x_0)
    sample t ~ Uniform({1,...,T})
    sample ε ~ N(0,I)
    compute x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε
    compute the loss ||ε - ε_θ(x_t, t)||²
    update θ with a gradient descent step
until converged
5.5 Sampling algorithm
x_T ~ N(0,I)
for t = T,...,1:
    z ~ N(0,I) if t>1 else z=0
    x_{t-1} = μ_θ(x_t,t) + σ_t * z
return x_0
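The sampling loop can be tried end to end on a one-dimensional toy problem where no network is needed. The sketch below assumes the data distribution is a single Gaussian N(2, 0.5²), for which the optimal noise prediction is available in closed form, and uses the posterior variance for σ_t². All names and constants are illustrative assumptions, and timesteps are zero-based as in the code of Section 14.

```python
import math
import random

random.seed(0)

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = []
running = 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

DATA_MEAN, DATA_STD = 2.0, 0.5   # toy data distribution: x_0 ~ N(2, 0.5^2)

def optimal_eps(x_t, t):
    # For Gaussian data the marginal q(x_t) is N(sqrt(ab)*m, ab*s^2 + 1 - ab),
    # so the exact noise prediction is E[eps | x_t] = sqrt(1-ab) * (x_t - mean) / var.
    ab = alpha_bars[t]
    marginal_mean = math.sqrt(ab) * DATA_MEAN
    marginal_var = ab * DATA_STD ** 2 + (1.0 - ab)
    return math.sqrt(1.0 - ab) * (x_t - marginal_mean) / marginal_var

def sample_one():
    x = random.gauss(0.0, 1.0)                      # x_T ~ N(0, 1)
    for t in reversed(range(T)):                    # zero-based timesteps T-1..0
        eps = optimal_eps(x, t)
        mean = (x - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps) / math.sqrt(alphas[t])
        if t > 0:
            ab_prev = alpha_bars[t - 1]
            sigma = math.sqrt(betas[t] * (1.0 - ab_prev) / (1.0 - alpha_bars[t]))
            x = mean + sigma * random.gauss(0.0, 1.0)
        else:
            x = mean                                # no noise on the final step
    return x

samples = [sample_one() for _ in range(400)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
```

The empirical mean and standard deviation of `samples` should land close to 2 and 0.5. In a real model `optimal_eps` is replaced by the trained network ε_θ.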
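6. Accelerated Sampling with DDIM
6.1 Core idea
DDPM sampling takes T steps (typically 1000), which is slow.
DDIM defines a family of non-Markovian forward processes that share the DDPM training objective, so an already trained model can be sampled on a shorter subsequence of timesteps.
6.2 DDIM sampling formula
x_{t-1} = √(ᾱ_{t-1}) * predicted_x0 + √(1-ᾱ_{t-1}-σ²) * predicted_noise + σ * z
where:
predicted_x0 = (x_t - √(1-ᾱ_t)*ε_θ) / √(ᾱ_t)
predicted_noise = ε_θ(x_t, t)
6.3 Deterministic sampling
When σ=0 the sampling process is fully deterministic: the same x_T always yields the same x_0, which makes interpolation in the latent (noise) space meaningful.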
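A single DDIM update is short enough to write out directly. The sketch below is a scalar version under the assumption that the noise prediction is exact. In that case the predicted x_0 equals the true x_0, and with σ=0 the update lands exactly on the forward-process point at the earlier timestep, even for a large jump. The function name and test values are illustrative.

```python
import math

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, sigma=0.0, z=0.0):
    """One DDIM update from timestep t to an earlier timestep (scalar version)."""
    predicted_x0 = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    direction = math.sqrt(1.0 - alpha_bar_prev - sigma ** 2) * eps_pred
    x_prev = math.sqrt(alpha_bar_prev) * predicted_x0 + direction + sigma * z
    return x_prev, predicted_x0

x_0, eps = 0.7, -0.4
alpha_bar_t, alpha_bar_prev = 0.2, 0.8   # a large jump between two noise levels
x_t = math.sqrt(alpha_bar_t) * x_0 + math.sqrt(1.0 - alpha_bar_t) * eps

x_prev, x0_hat = ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev)
expected_x_prev = math.sqrt(alpha_bar_prev) * x_0 + math.sqrt(1.0 - alpha_bar_prev) * eps
```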
6.4 Speed-up
- DDPM: 1000 steps
- DDIM: 50-100 steps (comparable quality)
- Further reduction: 10-20 steps (slight quality loss)
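7. Score-Based Models
7.1 Score matching
Learn the score function of the data distribution, s_θ(x) ≈ ∇_x log p(x).
Loss (denoising score matching):
L = E[||s_θ(x_t, t) - ∇_{x_t} log q(x_t|x_0)||²]
Key connection: predicting the noise is equivalent to predicting the score
s_θ(x_t, t) = -ε_θ(x_t, t) / √(1-ᾱ_t)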
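The equivalence follows because q(x_t|x_0) is Gaussian with mean √(ᾱ_t)x_0 and variance 1-ᾱ_t. A quick numerical check, with illustrative scalar values chosen as assumptions:

```python
import math

def conditional_score(x_t, x_0, alpha_bar_t):
    # q(x_t | x_0) = N(sqrt(ab) * x_0, 1 - ab), so the score is -(x_t - mean) / variance
    return -(x_t - math.sqrt(alpha_bar_t) * x_0) / (1.0 - alpha_bar_t)

alpha_bar_t = 0.37
x_0, eps = 1.2, 0.9
x_t = math.sqrt(alpha_bar_t) * x_0 + math.sqrt(1.0 - alpha_bar_t) * eps

score = conditional_score(x_t, x_0, alpha_bar_t)
score_from_eps = -eps / math.sqrt(1.0 - alpha_bar_t)
```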
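7.2 SDE framework
Forward SDE:
dx = f(x,t)dt + g(t)dw
Reverse SDE:
dx = [f(x,t) - g(t)²∇_x log p_t(x)]dt + g(t)dw̄
Here w is a standard Wiener process, w̄ is a Wiener process running backward in time, and p_t is the marginal density of x at time t.
7.3 Three SDE variants: VP, VE, and sub-VP
- VP (Variance Preserving): the continuous-time limit of DDPM
- VE (Variance Exploding): the continuous-time limit of SMLD
- sub-VP: a variant of VP whose variance is bounded above by that of the VP SDE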
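To see what "variance preserving" means in practice, the sketch below simulates the forward VP SDE, dx = -½β(t)x dt + √β(t) dw, with the Euler-Maruyama method. The linear β(t) from 0.1 to 20 on t in [0, 1], the step count, and the starting point are assumptions chosen for illustration.

```python
import math
import random

random.seed(0)

BETA_MIN, BETA_MAX = 0.1, 20.0   # assumed linear beta(t) on t in [0, 1]

def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t

def simulate_vp_forward(x0, n_steps=200):
    """Euler-Maruyama for dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw."""
    dt = 1.0 / n_steps
    x = x0
    for i in range(n_steps):
        t = i * dt
        drift = -0.5 * beta(t) * x
        diffusion = math.sqrt(beta(t))
        x = x + drift * dt + diffusion * math.sqrt(dt) * random.gauss(0.0, 1.0)
    return x

finals = [simulate_vp_forward(1.5) for _ in range(2000)]
mean_final = sum(finals) / len(finals)
var_final = sum((x - mean_final) ** 2 for x in finals) / len(finals)
```

Starting from the fixed point x_0 = 1.5, the paths end up with mean near 0 and variance near 1, i.e. approximately standard Gaussian noise.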
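8. Conditional Generation
8.1 Classifier Guidance
Train a classifier p(y|x_t) on noisy inputs and use its gradient to steer generation:
ε̃ = ε_θ(x_t,t) - √(1-ᾱ_t) * s * ∇_{x_t} log p(y|x_t)
8.2 Classifier-Free Guidance
Train a single network both with and without the condition (the condition is randomly replaced by a null token ∅), then combine the two predictions:
ε̃ = ε_θ(x_t,t,∅) + s * (ε_θ(x_t,t,c) - ε_θ(x_t,t,∅))
s>1 strengthens the effect of the condition; s=7.5 is a common choice.
8.3 Text conditioning
Encode the text with CLIP or T5 and inject the embeddings into the UNet through cross-attention.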
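The cross-attention mechanism can be sketched without any deep learning library. In the simplified version below, image positions act as queries and text tokens provide keys and values. The learned projection matrices of a real UNet layer are omitted, and all names and toy values are assumptions for illustration.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries, keys, values):
    """Single-head attention softmax(Q K^T / sqrt(d)) V on plain Python lists.

    queries: image tokens, shape (n_img, d)
    keys, values: text tokens, shapes (n_txt, d) and (n_txt, d_v)
    """
    d = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
    return outputs

image_tokens = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # 3 spatial positions
text_keys = [[4.0, 0.0], [0.0, 4.0]]                  # 2 text tokens
text_values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
attended = cross_attention(image_tokens, text_keys, text_values)
```

Each image position receives a weighted mix of the text values: the first position attends mostly to the first text token, and the third position attends to both equally.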
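9. Latent Diffusion and Stable Diffusion
9.1 Diffusion in latent space
Running diffusion directly in pixel space is computationally expensive. The solution:
- Train a VAE that compresses images into a latent space
- Run the diffusion process in that latent space
- Decode back to pixel space with the VAE decoder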
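A rough sense of the saving comes from counting tensor elements. The sketch below assumes a 512×512 RGB input, an 8x downsampling VAE, and 4 latent channels (the latent shape used by the sampling code later in this document). It only counts elements and is not a measurement of actual compute.

```python
def tensor_elements(channels, height, width):
    return channels * height * width

image_size = 512       # assumed input resolution
downsample = 8         # assumed spatial downsampling factor of the VAE
latent_channels = 4    # latent channels, matching the (1, 4, 64, 64) latent used later

pixel_elems = tensor_elements(3, image_size, image_size)
latent_elems = tensor_elements(latent_channels,
                               image_size // downsample,
                               image_size // downsample)
reduction = pixel_elems / latent_elems
```

Under these assumptions the UNet processes 48 times fewer values per sample than it would in pixel space.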
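9.2 Stable Diffusion architecture
text → [CLIP/T5] → text embeddings
↓
noise → [UNet + cross-attention] → denoised latent → [VAE decoder] → image
Key components:
- VAE: 4-8x spatial downsampling (Stable Diffusion uses 8x)
- UNet: a timestep-conditioned network with cross-attention
- Text encoder: CLIP or OpenCLIP
9.3 Training procedure
1. Train the VAE (reconstruction + KL + perceptual losses)
2. Freeze the VAE and train the UNet
- image → latent z_0
- add noise to obtain z_t
- the UNet predicts the noise, conditioned on the text embeddings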
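10. Classifier-Free Guidance in Detail
10.1 Training
With probability p_uncond, replace the condition with the null condition (typically p_uncond = 0.1-0.2), so that one network learns both the conditional and the unconditional noise prediction.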
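Condition dropout during training can be sketched as follows. Representing the null condition by an empty string, and the names `NULL_TOKEN` and `drop_conditions`, are assumptions for illustration. Real pipelines may use a learned null embedding instead.

```python
import random

NULL_TOKEN = ""   # assumed representation of the empty condition

def drop_conditions(prompts, p_uncond=0.1, rng=random):
    """Replace each prompt with the null condition with probability p_uncond."""
    return [NULL_TOKEN if rng.random() < p_uncond else p for p in prompts]

rng = random.Random(0)
batch = ["a cat"] * 10000
dropped = drop_conditions(batch, p_uncond=0.1, rng=rng)
drop_rate = dropped.count(NULL_TOKEN) / len(dropped)
```

Over a large batch, roughly a fraction p_uncond of the prompts ends up unconditional.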
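10.2 Inference
ε_guided = ε_uncond + w * (ε_cond - ε_uncond)
w=0 gives the unconditional prediction, w=1 is equivalent to standard conditional generation, and w>1 strengthens adherence to the condition.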
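The guidance formula is a per-element extrapolation from the unconditional prediction toward the conditional one. A minimal sketch on flat lists, with made-up prediction values:

```python
def guided_eps(eps_uncond, eps_cond, w):
    """Classifier-free guidance on flat lists of noise predictions."""
    return [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]

eps_uncond = [0.10, -0.20, 0.30]
eps_cond = [0.30, -0.10, 0.00]

unconditional = guided_eps(eps_uncond, eps_cond, w=0.0)
plain = guided_eps(eps_uncond, eps_cond, w=1.0)
strong = guided_eps(eps_uncond, eps_cond, w=7.5)
```

With w=1 the result equals the conditional prediction. With w=7.5 the difference between the two predictions is amplified 7.5 times. Each sampling step needs two network evaluations, one per prediction.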
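10.3 Effect
- w=1: diverse samples that may follow the condition only loosely
- w=7-12: a good balance for many text-to-image models
- very large w: strict adherence to the condition, but reduced diversity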
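11. Image Generation Applications
11.1 Text-to-image
Representative models: DALL-E 2/3, Stable Diffusion, Midjourney
Pipeline:
text prompt → text encoding → diffusion sampling → high-resolution image
11.2 Image editing
- Inpainting: regenerate a masked region
- Outpainting: extend the image beyond its borders
- Style transfer: render the content in a different style
- Image-to-image: translate one image into another
11.3 Super-resolution
low-resolution image → conditional diffusion → high-resolution image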
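12. Video Generation
12.1 Approaches
- Spatiotemporal UNet: 3D convolutions plus temporal attention
- Extending image models: add temporal layers to a pretrained image model
- Cascaded generation: generate at low resolution first, then super-resolve
12.2 Representative models
- Sora (OpenAI): clips up to 60 seconds
- Runway Gen-2/3
- Pika Labs
- Stable Video Diffusion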
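13. 3D and Audio Generation
13.1 3D generation
- DreamFusion: 2D diffusion prior + NeRF
- Magic3D: two-stage generation
- Point-E: point cloud generation
13.2 Audio generation
- AudioLDM: latent diffusion for audio
- MusicLDM: music generation
- Riffusion: diffusion on spectrograms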
14. Complete Code Implementation
14.1 DDPM implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# ========== Timestep embedding ==========
class SinusoidalPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=device) * -emb)
        emb = time[:, None] * emb[None, :]
        return torch.cat([emb.sin(), emb.cos()], dim=-1)


# ========== Residual block ==========
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        # GroupNorm(32, ...) requires channel counts divisible by 32
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        # add the projected timestep embedding to every spatial position
        h = h + self.time_proj(t)[:, :, None, None]
        h = F.silu(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.skip(x)


# ========== Attention block ==========
class AttnBlock(nn.Module):
    def __init__(self, ch, heads=8):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.attn = nn.MultiheadAttention(ch, heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        residual = x
        # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token
        x = self.norm(x).reshape(b, c, -1).transpose(1, 2)
        x, _ = self.attn(x, x, x)
        # reshape instead of view: the transposed tensor is not contiguous
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return x + residual


# ========== UNet ==========
class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, base_ch=128,
                 ch_mult=(1, 2, 4, 8), attn_res=(2,), res_blocks=2):
        # attn_res lists the level indices (not pixel resolutions) that get attention
        super().__init__()
        time_dim = base_ch * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbedding(base_ch),
            nn.Linear(base_ch, time_dim), nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Encoder
        self.downs = nn.ModuleList()
        chs = [base_ch]  # channel count of every skip connection, in order
        ch = base_ch
        for level, mult in enumerate(ch_mult):
            out = base_ch * mult
            for _ in range(res_blocks):
                layers = [ResBlock(ch, out, time_dim)]
                if level in attn_res:
                    layers.append(AttnBlock(out))
                self.downs.append(nn.ModuleList(layers))
                ch = out
                chs.append(ch)
            if level < len(ch_mult) - 1:
                # strided convolution halves the spatial resolution
                self.downs.append(nn.ModuleList([nn.Conv2d(ch, ch, 3, 2, 1)]))
                chs.append(ch)

        # Middle
        self.mid = nn.ModuleList([
            ResBlock(ch, ch, time_dim),
            AttnBlock(ch),
            ResBlock(ch, ch, time_dim)
        ])

        # Decoder
        self.ups = nn.ModuleList()
        for level, mult in reversed(list(enumerate(ch_mult))):
            out = base_ch * mult
            for i in range(res_blocks + 1):
                skip = chs.pop()
                layers = [ResBlock(ch + skip, out, time_dim)]
                if level in attn_res:
                    layers.append(AttnBlock(out))
                if level > 0 and i == res_blocks:
                    layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
                self.ups.append(nn.ModuleList(layers))
                ch = out

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch), nn.SiLU(),
            nn.Conv2d(ch, out_ch, 3, padding=1)
        )

    def forward(self, x, t):
        t = self.time_mlp(t.float())
        x = self.in_conv(x)
        skips = [x]
        for block in self.downs:
            for layer in block:
                x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
            skips.append(x)
        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        for block in self.ups:
            x = torch.cat([x, skips.pop()], 1)
            for layer in block:
                # only ResBlock takes the timestep embedding
                x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
        return self.out(x)


# ========== DDPM ==========
class DDPM:
    def __init__(self, model, T=1000, beta1=1e-4, beta2=0.02):
        self.model = model
        self.T = T
        self.betas = torch.linspace(beta1, beta2, T)
        self.alphas = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, 0)
        self.alpha_bar_prev = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
        self.posterior_var = self.betas * (1 - self.alpha_bar_prev) / (1 - self.alpha_bar)

    def q_sample(self, x0, t, eps=None):
        if eps is None:
            eps = torch.randn_like(x0)
        # move the schedule to the data device BEFORE indexing with t;
        # indexing a CPU tensor with CUDA indices raises an error
        a = self.alpha_bar.to(x0.device)[t][:, None, None, None]
        return a.sqrt() * x0 + (1 - a).sqrt() * eps

    def loss(self, x0):
        b = x0.shape[0]
        # timesteps are zero-based here: t in {0, ..., T-1}
        t = torch.randint(0, self.T, (b,), device=x0.device)
        eps = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps)
        eps_pred = self.model(xt, t)
        return F.mse_loss(eps_pred, eps)

    @torch.no_grad()
    def sample(self, shape, device):
        x = torch.randn(shape, device=device)
        for t in reversed(range(self.T)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            eps_pred = self.model(x, t_batch)
            a = self.alphas[t]
            ab = self.alpha_bar[t]
            b = self.betas[t]
            mean = (1 / a.sqrt()) * (x - b / (1 - ab).sqrt() * eps_pred)
            if t > 0:
                var = self.posterior_var[t]
                x = mean + var.sqrt() * torch.randn_like(x)
            else:
                x = mean  # no noise on the final step
        return x
```
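14.2 Latent Diffusion implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentDiffusion(nn.Module):
    # Assumptions: `vae` follows the diffusers AutoencoderKL interface, `unet(z, t, text_emb)`
    # returns the predicted noise, `text_encoder(text)` accepts raw strings and returns an
    # object with `.last_hidden_state`, and `ddpm` is a DDPM schedule object from Section 14.1.
    def __init__(self, vae, unet, text_encoder, ddpm):
        super().__init__()
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.ddpm = ddpm  # added: the original used self.ddpm without defining it
        self.latent_scale = 0.18215  # latent scaling factor used by Stable Diffusion

    @property
    def device(self):
        # added: the original used self.device without defining it
        return next(self.unet.parameters()).device

    def encode(self, x):
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample() * self.latent_scale

    def decode(self, z):
        return self.vae.decode(z / self.latent_scale).sample

    def forward(self, x, text):
        # Encode images to latents; the VAE is frozen, so no gradients are needed
        with torch.no_grad():
            z0 = self.encode(x)
        # Text condition
        text_emb = self.text_encoder(text).last_hidden_state
        # Diffusion training step
        b = z0.shape[0]
        t = torch.randint(0, self.ddpm.T, (b,), device=z0.device)
        eps = torch.randn_like(z0)
        zt = self.ddpm.q_sample(z0, t, eps)
        # Noise prediction
        eps_pred = self.unet(zt, t, text_emb)
        return F.mse_loss(eps_pred, eps)

    def ddim_step(self, z, eps, t, t_prev):
        # Deterministic DDIM update (sigma = 0) from timestep t to t_prev;
        # added as a sketch because the original called an undefined ddim_step
        ab = self.ddpm.alpha_bar.to(z.device)
        ab_t = ab[t]
        ab_prev = ab[t_prev] if t_prev >= 0 else torch.ones((), device=z.device)
        z0_pred = (z - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        return ab_prev.sqrt() * z0_pred + (1 - ab_prev).sqrt() * eps

    @torch.no_grad()
    def generate(self, text, steps=50, guidance_scale=7.5):
        # Text encoding (computed once, reused at every step)
        cond = self.text_encoder(text).last_hidden_state
        uncond = self.text_encoder("").last_hidden_state
        # Initial noise
        z = torch.randn(1, 4, 64, 64, device=self.device)
        # DDIM sampling on an evenly spaced, descending subsequence of training timesteps
        timesteps = torch.linspace(self.ddpm.T - 1, 0, steps).long().tolist()
        for i, t in enumerate(timesteps):
            t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
            t_batch = torch.tensor([t], device=self.device)
            # Classifier-Free Guidance
            eps_uncond = self.unet(z, t_batch, uncond)
            eps_cond = self.unet(z, t_batch, cond)
            eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
            # Denoise
            z = self.ddim_step(z, eps, t, t_prev)
        # Decode
        return self.decode(z)
```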
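15. Training Tips
15.1 Choosing a noise schedule
- Linear schedule: the classic choice
- Cosine schedule: destroys information more gradually, which improves modeling of the low-noise region
- Sigmoid schedule: a compromise between the two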
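The difference between the linear and cosine schedules is easiest to see by comparing ᾱ_t at the midpoint of the chain. The sketch below assumes the linear defaults used elsewhere in this document (1e-4 to 0.02, T = 1000) and the cosine schedule of Nichol & Dhariwal (2021) with offset s = 0.008. The document does not say which cosine variant it means, so that choice is an assumption.

```python
import math

def linear_alpha_bar(T, beta_start=1e-4, beta_end=0.02):
    """alpha_bar[t] for t = 0..T under a linear beta schedule (alpha_bar[0] = 1)."""
    alpha_bar = [1.0]
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        alpha_bar.append(alpha_bar[-1] * (1.0 - beta))
    return alpha_bar

def cosine_alpha_bar(T, s=0.008):
    """alpha_bar[t] for t = 0..T under the cosine schedule (assumed variant)."""
    def f(t):
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(T + 1)]

T = 1000
lin = linear_alpha_bar(T)
cos = cosine_alpha_bar(T)
mid_linear, mid_cosine = lin[T // 2], cos[T // 2]
```

Halfway through the chain the linear schedule has already removed most of the signal, while the cosine schedule still retains about half of it.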
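15.2 Loss variants
```python
# Simplified loss
loss = F.mse_loss(eps_pred, eps)

# Weighted loss (emphasize certain timesteps); `weights` has shape (T,)
per_elem = F.mse_loss(eps_pred, eps, reduction='none')       # (B, C, H, W)
loss = (per_elem * weights[t][:, None, None, None]).mean()   # broadcast per-sample weights

# SNR-based weighting; note that snr / (snr + 1) equals alpha_bar
snr = alpha_bar[t] / (1 - alpha_bar[t])                      # (B,)
snr_weights = snr / (snr + 1)
loss = (per_elem * snr_weights[:, None, None, None]).mean()
```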
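15.3 EMA
```python
import torch


class EMA:
    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        # detach so the shadow copies are plain tensors outside the autograd graph
        self.shadow = {name: param.detach().clone()
                       for name, param in model.named_parameters()}
        self.backup = {}

    @torch.no_grad()
    def update(self):
        # shadow <- decay * shadow + (1 - decay) * param, in place
        for name, param in self.model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def apply(self):
        # swap the EMA weights in for evaluation, keeping a backup of the live weights
        self.backup = {name: param.detach().clone()
                       for name, param in self.model.named_parameters()}
        for name, param in self.model.named_parameters():
            param.copy_(self.shadow[name])  # copy, so the shadow is never aliased

    @torch.no_grad()
    def restore(self):
        for name, param in self.model.named_parameters():
            param.copy_(self.backup[name])
```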
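15.4 Mixed-precision training
```python
# Assumes `dataloader`, `optimizer`, `ddpm`, and `device` are already defined.
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for x, _ in dataloader:
    x = x.to(device)
    optimizer.zero_grad(set_to_none=True)  # clear gradients from the previous step
    with autocast():                       # run the forward pass in mixed precision
        loss = ddpm.loss(x)
    scaler.scale(loss).backward()          # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                 # unscales gradients, skips the step on inf/NaN
    scaler.update()
```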
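16. Inference Optimization
16.1 Choosing a sampler
| Sampler | Steps | Notes |
|---|---|---|
| DDPM | 1000 | Classic, slow |
| DDIM | 50-100 | Faster, optionally deterministic |
| DPM-Solver | 10-20 | ODE solver |
| UniPC | 5-10 | Training-free acceleration |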
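16.2 Quantization
```python
import torch
import torch.nn as nn

# Dynamic INT8 quantization. Dynamic quantization covers nn.Linear (and recurrent
# layers) but not nn.Conv2d, so convolutions stay in floating point here.
model_int8 = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
```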
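16.3 Caching
Cache intermediate results that do not change between sampling steps (for example the text embeddings of the prompt) instead of recomputing them.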
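One simple way to cache such results is memoization. The sketch below uses `functools.lru_cache` around a hypothetical `encode_text` function that stands in for a real text encoder. The function, its toy output, and the sampling loop are all assumptions for illustration.

```python
from functools import lru_cache

encoder_calls = 0   # counts how often the (expensive) encoder actually runs

@lru_cache(maxsize=128)
def encode_text(prompt):
    """Hypothetical stand-in for a text encoder; returns a hashable tuple."""
    global encoder_calls
    encoder_calls += 1
    return tuple(float(ord(ch)) for ch in prompt)

def sample(prompt, steps=50):
    for _ in range(steps):
        cond = encode_text(prompt)   # served from the cache after the first call
        uncond = encode_text("")
        # ... a denoising step using cond / uncond would go here ...
    return cond

sample("a red bicycle")
sample("a red bicycle")
```

Across two runs of 50 steps the encoder body executes only twice, once for the prompt and once for the empty prompt. In practice the same effect is often achieved by computing the embeddings once before the sampling loop.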
17. References
Core papers
- DDPM: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- DDIM: "Denoising Diffusion Implicit Models" (Song et al., 2021)
- Score SDE: "Score-Based Generative Modeling through SDEs" (Song et al., 2021)
- LDM: "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., 2022)
- DALL-E 2: "Hierarchical Text-Conditional Image Generation with CLIP Latents" (Ramesh et al., 2022)
- Classifier-Free Guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)
Open-source projects
- Stable Diffusion: https://github.com/CompVis/stable-diffusion
- Diffusers: https://github.com/huggingface/diffusers
- OpenAI DALL-E: https://github.com/openai/DALL-E
Recommended resources
- Lilian Weng: "What are Diffusion Models?"
- The Annotated Diffusion Model (HuggingFace)
- Stanford CS236: Deep Generative Models