摘要
扩散模型(Diffusion Models)是当前生成式AI领域最核心的技术之一,在图像生成、音频合成、分子设计等任务中展现出超越GAN和VAE的生成质量。本文从数学原理出发,严格推导前向扩散与反向去噪过程,并基于PyTorch实现一个完整的、可运行的扩散模型训练与采样代码。文章涵盖所有关键细节:噪声调度、损失函数、采样策略、常见陷阱及解决方案。全文无图,纯逻辑推导与工程实现,适合具备深度学习基础并希望深入理解扩散模型底层机制的读者。
应用场景
扩散模型的核心能力是从噪声中恢复出高保真数据分布。当前主流应用包括:
- 文本到图像生成(如Stable Diffusion、DALL-E 3)
- 图像超分辨率、修复、编辑
- 音频生成(如AudioLDM)
- 分子构象生成
- 时间序列预测与插值
- 三维点云生成
任何需要从随机噪声中生成高质量、多样化样本的任务,扩散模型均可适配。
核心原理
扩散模型包含两个过程:
-
前向扩散过程:对原始数据x_0逐步添加高斯噪声,经过T步后变为纯噪声x_T ~ N(0, I)。这是一个马尔可夫链,每一步的转移核为: q(x_t | x_{t-1}) = N(x_t; sqrt(1 - beta_t) * x_{t-1}, beta_t * I) 其中beta_t是预定义的噪声调度(noise schedule),通常从1e-4线性增加到0.02。
利用重参数化技巧,可以直接从x_0得到任意时刻t的x_t: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon 其中alpha_t = 1 - beta_t,alpha_bar_t = prod_{i=1}^t alpha_i,epsilon ~ N(0, I)。
-
反向去噪过程:学习一个神经网络epsilon_theta(x_t, t)来预测添加的噪声epsilon,然后逐步去除噪声,从x_T恢复出x_0。反向过程也是一个马尔可夫链,其转移核为: p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I) 其中mu_theta的推导基于变分下界,最终简化为: mu_theta(x_t, t) = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t)) sigma_t^2 = beta_t(DDPM原始设置)或使用更复杂的调度。
损失函数:训练时最小化预测噪声与真实噪声之间的均方误差: L = E_{t, x_0, epsilon} \|\| epsilon - epsilon_theta( sqrt(alpha_bar_t) \* x_0 + sqrt(1 - alpha_bar_t) \* epsilon, t ) \|\|\^2
采样过程:从x_T ~ N(0, I)开始,对t从T到1迭代: x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t)) + sigma_t * z 其中z ~ N(0, I),当t=1时z=0。
详细步骤
1. 定义噪声调度
线性调度:beta_t从beta_1线性增加到beta_T。计算alpha_t和alpha_bar_t。
2. 构建神经网络
使用UNet结构,包含下采样、上采样和跳跃连接。输入为带噪图像x_t和时间步t。时间步通过正弦位置编码嵌入,并注入到每个残差块中。
3. 训练循环
- 从数据集中采样x_0
- 随机采样t ~ Uniform(1, T)
- 采样噪声epsilon ~ N(0, I)
- 计算x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
- 网络预测epsilon_pred = epsilon_theta(x_t, t)
- 计算损失MSE(epsilon, epsilon_pred)
- 反向传播更新参数
4. 采样(推理)
- 从标准正态分布采样x_T
- 对t从T到1:
- 预测噪声epsilon_pred = epsilon_theta(x_t, t)
- 计算x_{t-1}的均值mu
- 若t>1,添加噪声sigma_t * z
- 返回x_0
完整可运行代码
以下代码基于PyTorch实现一个简化的扩散模型,在MNIST数据集上训练并生成手写数字。代码可直接运行,包含详细注释。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import math
# ---------- 1. 噪声调度 ----------
def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
"""线性噪声调度,返回beta_t, alpha_t, alpha_bar_t"""
betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0) # 累积乘积
return betas, alphas, alphas_bar
# ---------- 2. 时间嵌入 ----------
class SinusoidalPosEmb(nn.Module):
"""正弦位置编码,将时间步t映射为embedding向量"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
# t: [batch_size],取值范围[0, T-1]
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None].float() * emb[None, :] # [batch, half_dim]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # [batch, dim]
return emb
# ---------- 3. 简易UNet ----------
class SimpleUNet(nn.Module):
"""适用于MNIST的轻量UNet,输入通道1,输出通道1"""
def __init__(self, T, time_emb_dim=128):
super().__init__()
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU(),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)
# 下采样
self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
self.enc2 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
self.enc3 = nn.Conv2d(128, 256, 3, padding=1, stride=2)
# 中间
self.mid = nn.Conv2d(256, 256, 3, padding=1)
# 上采样
self.dec3 = nn.ConvTranspose2d(256 + 256, 128, 4, stride=2, padding=1)
self.dec2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)
self.dec1 = nn.Conv2d(64 + 64, 1, 3, padding=1)
# 时间embedding映射到各层
self.time_proj1 = nn.Linear(time_emb_dim, 64)
self.time_proj2 = nn.Linear(time_emb_dim, 128)
self.time_proj3 = nn.Linear(time_emb_dim, 256)
self.time_proj_mid = nn.Linear(time_emb_dim, 256)
self.time_proj_d3 = nn.Linear(time_emb_dim, 128)
self.time_proj_d2 = nn.Linear(time_emb_dim, 64)
def forward(self, x, t):
# x: [batch, 1, 28, 28], t: [batch]
time_emb = self.time_mlp(t) # [batch, time_emb_dim]
# 编码
e1 = self.enc1(x) # [batch, 64, 28, 28]
e1 = e1 + self.time_proj1(time_emb)[:, :, None, None]
e1 = F.relu(e1)
e2 = self.enc2(e1) # [batch, 128, 14, 14]
e2 = e2 + self.time_proj2(time_emb)[:, :, None, None]
e2 = F.relu(e2)
e3 = self.enc3(e2) # [batch, 256, 7, 7]
e3 = e3 + self.time_proj3(time_emb)[:, :, None, None]
e3 = F.relu(e3)
# 中间
m = self.mid(e3) # [batch, 256, 7, 7]
m = m + self.time_proj_mid(time_emb)[:, :, None, None]
m = F.relu(m)
# 解码(跳跃连接)
d3 = torch.cat([m, e3], dim=1) # [batch, 512, 7, 7]
d3 = self.dec3(d3) # [batch, 128, 14, 14]
d3 = d3 + self.time_proj_d3(time_emb)[:, :, None, None]
d3 = F.relu(d3)
d2 = torch.cat([d3, e2], dim=1) # [batch, 256, 14, 14]
d2 = self.dec2(d2) # [batch, 64, 28, 28]
d2 = d2 + self.time_proj_d2(time_emb)[:, :, None, None]
d2 = F.relu(d2)
d1 = torch.cat([d2, e1], dim=1) # [batch, 128, 28, 28]
out = self.dec1(d1) # [batch, 1, 28, 28]
return out
# ---------- 4. 扩散模型封装 ----------
class DiffusionModel:
def __init__(self, T=1000, device='cuda'):
self.T = T
self.device = device
self.betas, self.alphas, self.alphas_bar = linear_beta_schedule(T)
self.betas = self.betas.to(device)
self.alphas = self.alphas.to(device)
self.alphas_bar = self.alphas_bar.to(device)
# 预计算一些常数
self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)
self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
self.posterior_variance = self.betas # sigma_t^2 = beta_t
self.model = SimpleUNet(T).to(device)
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
def train_step(self, x0):
"""单步训练"""
batch_size = x0.shape[0]
t = torch.randint(0, self.T, (batch_size,), device=self.device).long()
noise = torch.randn_like(x0)
# 前向加噪
sqrt_ab = self.sqrt_alphas_bar[t][:, None, None, None]
sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t][:, None, None, None]
xt = sqrt_ab * x0 + sqrt_one_minus_ab * noise
# 预测噪声
noise_pred = self.model(xt, t.float())
# 损失
loss = F.mse_loss(noise_pred, noise)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
@torch.no_grad()
def sample(self, batch_size=16):
"""DDPM采样,从x_T逐步去噪到x_0"""
x = torch.randn(batch_size, 1, 28, 28, device=self.device)
for t in reversed(range(self.T)):
t_tensor = torch.full((batch_size,), t, device=self.device, dtype=torch.float32)
noise_pred = self.model(x, t_tensor)
# 计算均值
sqrt_recip_alpha = self.sqrt_recip_alphas[t]
beta = self.betas[t]
sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t]
x_mean = sqrt_recip_alpha * (x - (beta / sqrt_one_minus_ab) * noise_pred)
if t > 0:
noise = torch.randn_like(x)
sigma = torch.sqrt(self.posterior_variance[t])
x = x_mean + sigma * noise
else:
x = x_mean
return x
# ---------- 5. 训练与采样 ----------
def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
# 模型初始化
diffusion = DiffusionModel(T=200, device=device) # T=200加速演示
n_epochs = 5
# 训练
for epoch in range(n_epochs):
total_loss = 0.0
for batch_idx, (images, _) in enumerate(dataloader):
images = images.to(device)
loss = diffusion.train_step(images)
total_loss += loss
if batch_idx % 100 == 0:
print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss:.6f}')
avg_loss = total_loss / len(dataloader)
print(f'Epoch {epoch+1} finished, Average Loss: {avg_loss:.6f}')
# 采样
print('Generating samples...')
samples = diffusion.sample(batch_size=16)
# 将样本从[-1,1]反归一化到[0,1]用于保存
samples = (samples + 1.0) / 2.0
samples = torch.clamp(samples, 0.0, 1.0)
# 保存为numpy数组,可用于可视化
samples_np = samples.cpu().numpy()
np.save('generated_mnist.npy', samples_np)
print(f'Samples saved to generated_mnist.npy, shape: {samples_np.shape}')
if __name__ == '__main__':
main()
运行结果说明
代码运行后,控制台会输出每个epoch的loss,loss通常从0.5左右下降到0.05以下。采样完成后,会生成一个numpy文件generated_mnist.npy,形状为(16, 1, 28, 28),包含16张生成的手写数字图像。每张图像的值在0, 1范围内,可直接使用matplotlib的imshow显示,肉眼可辨识出清晰的数字轮廓。若增加T至1000并延长训练epoch,生成质量会显著提升,接近真实MNIST样本。
常见问题与避坑
1. 训练不收敛或loss震荡
- 原因:学习率过大或batch size过小。建议使用Adam优化器,lr=1e-4,batch size>=64。
- 检查:确保输入x0归一化到-1, 1(使用Normalize(0.5, 0.5)),因为噪声epsilon的尺度是标准正态,若x0尺度不匹配会导致梯度异常。
2. 生成样本全黑或全白
- 原因:采样时未正确使用噪声调度。常见错误:忘记对x_t进行缩放,或sigma_t计算错误。
- 检查:确保采样循环中x_mean的计算公式正确,且t>0时添加的噪声标准差为sqrt(beta_t)。
3. 生成样本模糊
- 原因:T太小(如<100)或网络容量不足。扩散模型需要足够多的步数才能精细恢复细节。
- 解决:增加T至1000,并加深UNet(增加通道数或层数)。
4. 显存溢出
- 原因:batch size过大或T过大导致中间变量过多。
- 解决:减小batch size,使用梯度累积,或使用FP16混合精度训练。
5. 训练时loss为NaN
- 原因:数值不稳定。检查alpha_bar_t是否出现0(当T很大时累积乘积可能下溢)。
- 解决:使用float64或对alpha_bar_t添加eps(如1e-12)。更推荐使用cosine调度替代线性调度。
6. 采样速度慢
- 原因:DDPM需要逐步迭代T次,T=1000时很慢。
- 解决:使用DDIM采样(确定性采样,可减少步数至50-100),或使用DPM-solver等加速方法。
7. 模型过拟合
- 原因:数据集太小或网络容量过大。
- 解决:增加数据增强(随机翻转、旋转),或使用dropout。
总结
本文从数学推导到工程实现完整呈现了扩散模型的核心细节。关键在于理解前向过程的闭合形式(x_t与x_0的直接关系)和反向过程的变分下界简化。训练时只需预测噪声,采样时逐步去噪。代码实现了完整的训练与采样流程,可直接运行生成MNIST数字。
扩散模型相比GAN的优势在于训练稳定、模式覆盖全面,但采样速度慢是主要瓶颈。工业界常通过DDIM、LCM、蒸馏等技术加速。理解本文的基础实现后,可进一步阅读DDPM、DDIM、Score SDE等论文,深入探索更高级的变体与应用。
建议读者在运行代码后,尝试修改T、噪声调度类型(如cosine调度)、网络结构,观察生成质量的变化,从而建立对扩散模型各组件作用的直观理解。