PyTorch从零搭建DDPM:时间嵌入+UNet网络+扩散调度完整复现

摘要

扩散模型(Diffusion Models)是当前生成式AI领域最核心的技术之一,在图像生成、音频合成、分子设计等任务中展现出超越GAN和VAE的生成质量。本文从数学原理出发,严格推导前向扩散与反向去噪过程,并基于PyTorch实现一个完整的、可运行的扩散模型训练与采样代码。文章涵盖所有关键细节:噪声调度、损失函数、采样策略、常见陷阱及解决方案。全文无图,纯逻辑推导与工程实现,适合具备深度学习基础并希望深入理解扩散模型底层机制的读者。

应用场景

扩散模型的核心能力是从噪声中恢复出高保真数据分布。当前主流应用包括:

  • 文本到图像生成(如Stable Diffusion、DALL-E 3)
  • 图像超分辨率、修复、编辑
  • 音频生成(如AudioLDM)
  • 分子构象生成
  • 时间序列预测与插值
  • 三维点云生成

任何需要从随机噪声中生成高质量、多样化样本的任务,扩散模型均可适配。

核心原理

扩散模型包含两个过程:

  1. 前向扩散过程:对原始数据x_0逐步添加高斯噪声,经过T步后变为纯噪声x_T ~ N(0, I)。这是一个马尔可夫链,每一步的转移核为: q(x_t | x_{t-1}) = N(x_t; sqrt(1 - beta_t) * x_{t-1}, beta_t * I) 其中beta_t是预定义的噪声调度(noise schedule),通常从1e-4线性增加到0.02。

    利用重参数化技巧,可以直接从x_0得到任意时刻t的x_t: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon 其中alpha_t = 1 - beta_t,alpha_bar_t = prod_{i=1}^t alpha_i,epsilon ~ N(0, I)。

  2. 反向去噪过程:学习一个神经网络epsilon_theta(x_t, t)来预测添加的噪声epsilon,然后逐步去除噪声,从x_T恢复出x_0。反向过程也是一个马尔可夫链,其转移核为: p_theta(x_{t-1} | x_t) = N(x_{t-1}; mu_theta(x_t, t), sigma_t^2 * I) 其中mu_theta的推导基于变分下界,最终简化为: mu_theta(x_t, t) = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t)) sigma_t^2 = beta_t(DDPM原始设置)或使用更复杂的调度。

损失函数:训练时最小化预测噪声与真实噪声之间的均方误差: L = E_{t, x_0, epsilon} \|\| epsilon - epsilon_theta( sqrt(alpha_bar_t) \* x_0 + sqrt(1 - alpha_bar_t) \* epsilon, t ) \|\|\^2

采样过程:从x_T ~ N(0, I)开始,对t从T到1迭代: x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_theta(x_t, t)) + sigma_t * z 其中z ~ N(0, I),当t=1时z=0。

详细步骤

1. 定义噪声调度

线性调度:beta_t从beta_1线性增加到beta_T。计算alpha_t和alpha_bar_t。

2. 构建神经网络

使用UNet结构,包含下采样、上采样和跳跃连接。输入为带噪图像x_t和时间步t。时间步通过正弦位置编码嵌入,并注入到每个残差块中。

3. 训练循环

  • 从数据集中采样x_0
  • 随机采样t ~ Uniform(1, T)
  • 采样噪声epsilon ~ N(0, I)
  • 计算x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
  • 网络预测epsilon_pred = epsilon_theta(x_t, t)
  • 计算损失MSE(epsilon, epsilon_pred)
  • 反向传播更新参数

4. 采样(推理)

  • 从标准正态分布采样x_T
  • 对t从T到1:
    • 预测噪声epsilon_pred = epsilon_theta(x_t, t)
    • 计算x_{t-1}的均值mu
    • 若t>1,添加噪声sigma_t * z
  • 返回x_0

完整可运行代码

以下代码基于PyTorch实现一个简化的扩散模型,在MNIST数据集上训练并生成手写数字。代码可直接运行,包含详细注释。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import math

# ---------- 1. 噪声调度 ----------
def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """线性噪声调度,返回beta_t, alpha_t, alpha_bar_t"""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)  # 累积乘积
    return betas, alphas, alphas_bar

# ---------- 2. 时间嵌入 ----------
class SinusoidalPosEmb(nn.Module):
    """正弦位置编码,将时间步t映射为embedding向量"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        # t: [batch_size],取值范围[0, T-1]
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None].float() * emb[None, :]  # [batch, half_dim]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # [batch, dim]
        return emb

# ---------- 3. 简易UNet ----------
class SimpleUNet(nn.Module):
    """适用于MNIST的轻量UNet,输入通道1,输出通道1"""
    def __init__(self, T, time_emb_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # 下采样
        self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
        self.enc3 = nn.Conv2d(128, 256, 3, padding=1, stride=2)

        # 中间
        self.mid = nn.Conv2d(256, 256, 3, padding=1)

        # 上采样
        self.dec3 = nn.ConvTranspose2d(256 + 256, 128, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(64 + 64, 1, 3, padding=1)

        # 时间embedding映射到各层
        self.time_proj1 = nn.Linear(time_emb_dim, 64)
        self.time_proj2 = nn.Linear(time_emb_dim, 128)
        self.time_proj3 = nn.Linear(time_emb_dim, 256)
        self.time_proj_mid = nn.Linear(time_emb_dim, 256)
        self.time_proj_d3 = nn.Linear(time_emb_dim, 128)
        self.time_proj_d2 = nn.Linear(time_emb_dim, 64)

    def forward(self, x, t):
        # x: [batch, 1, 28, 28], t: [batch]
        time_emb = self.time_mlp(t)  # [batch, time_emb_dim]

        # 编码
        e1 = self.enc1(x)  # [batch, 64, 28, 28]
        e1 = e1 + self.time_proj1(time_emb)[:, :, None, None]
        e1 = F.relu(e1)

        e2 = self.enc2(e1)  # [batch, 128, 14, 14]
        e2 = e2 + self.time_proj2(time_emb)[:, :, None, None]
        e2 = F.relu(e2)

        e3 = self.enc3(e2)  # [batch, 256, 7, 7]
        e3 = e3 + self.time_proj3(time_emb)[:, :, None, None]
        e3 = F.relu(e3)

        # 中间
        m = self.mid(e3)  # [batch, 256, 7, 7]
        m = m + self.time_proj_mid(time_emb)[:, :, None, None]
        m = F.relu(m)

        # 解码(跳跃连接)
        d3 = torch.cat([m, e3], dim=1)  # [batch, 512, 7, 7]
        d3 = self.dec3(d3)  # [batch, 128, 14, 14]
        d3 = d3 + self.time_proj_d3(time_emb)[:, :, None, None]
        d3 = F.relu(d3)

        d2 = torch.cat([d3, e2], dim=1)  # [batch, 256, 14, 14]
        d2 = self.dec2(d2)  # [batch, 64, 28, 28]
        d2 = d2 + self.time_proj_d2(time_emb)[:, :, None, None]
        d2 = F.relu(d2)

        d1 = torch.cat([d2, e1], dim=1)  # [batch, 128, 28, 28]
        out = self.dec1(d1)  # [batch, 1, 28, 28]
        return out

# ---------- 4. 扩散模型封装 ----------
class DiffusionModel:
    def __init__(self, T=1000, device='cuda'):
        self.T = T
        self.device = device
        self.betas, self.alphas, self.alphas_bar = linear_beta_schedule(T)
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_bar = self.alphas_bar.to(device)

        # 预计算一些常数
        self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.posterior_variance = self.betas  # sigma_t^2 = beta_t

        self.model = SimpleUNet(T).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

    def train_step(self, x0):
        """单步训练"""
        batch_size = x0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=self.device).long()
        noise = torch.randn_like(x0)

        # 前向加噪
        sqrt_ab = self.sqrt_alphas_bar[t][:, None, None, None]
        sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t][:, None, None, None]
        xt = sqrt_ab * x0 + sqrt_one_minus_ab * noise

        # 预测噪声
        noise_pred = self.model(xt, t.float())

        # 损失
        loss = F.mse_loss(noise_pred, noise)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @torch.no_grad()
    def sample(self, batch_size=16):
        """DDPM采样,从x_T逐步去噪到x_0"""
        x = torch.randn(batch_size, 1, 28, 28, device=self.device)
        for t in reversed(range(self.T)):
            t_tensor = torch.full((batch_size,), t, device=self.device, dtype=torch.float32)
            noise_pred = self.model(x, t_tensor)

            # 计算均值
            sqrt_recip_alpha = self.sqrt_recip_alphas[t]
            beta = self.betas[t]
            sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t]
            x_mean = sqrt_recip_alpha * (x - (beta / sqrt_one_minus_ab) * noise_pred)

            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = x_mean + sigma * noise
            else:
                x = x_mean
        return x

# ---------- 5. 训练与采样 ----------
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # 归一化到[-1, 1]
    ])
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

    # 模型初始化
    diffusion = DiffusionModel(T=200, device=device)  # T=200加速演示
    n_epochs = 5

    # 训练
    for epoch in range(n_epochs):
        total_loss = 0.0
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            loss = diffusion.train_step(images)
            total_loss += loss
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss:.6f}')
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1} finished, Average Loss: {avg_loss:.6f}')

    # 采样
    print('Generating samples...')
    samples = diffusion.sample(batch_size=16)
    # 将样本从[-1,1]反归一化到[0,1]用于保存
    samples = (samples + 1.0) / 2.0
    samples = torch.clamp(samples, 0.0, 1.0)

    # 保存为numpy数组,可用于可视化
    samples_np = samples.cpu().numpy()
    np.save('generated_mnist.npy', samples_np)
    print(f'Samples saved to generated_mnist.npy, shape: {samples_np.shape}')

if __name__ == '__main__':
    main()

运行结果说明

代码运行后,控制台会输出每个epoch的loss,loss通常从0.5左右下降到0.05以下。采样完成后,会生成一个numpy文件generated_mnist.npy,形状为(16, 1, 28, 28),包含16张生成的手写数字图像。每张图像的值在0, 1范围内,可直接使用matplotlib的imshow显示,肉眼可辨识出清晰的数字轮廓。若增加T至1000并延长训练epoch,生成质量会显著提升,接近真实MNIST样本。

常见问题与避坑

1. 训练不收敛或loss震荡

  • 原因:学习率过大或batch size过小。建议使用Adam优化器,lr=1e-4,batch size>=64。
  • 检查:确保输入x0归一化到-1, 1(使用Normalize(0.5, 0.5)),因为噪声epsilon的尺度是标准正态,若x0尺度不匹配会导致梯度异常。

2. 生成样本全黑或全白

  • 原因:采样时未正确使用噪声调度。常见错误:忘记对x_t进行缩放,或sigma_t计算错误。
  • 检查:确保采样循环中x_mean的计算公式正确,且t>0时添加的噪声标准差为sqrt(beta_t)。

3. 生成样本模糊

  • 原因:T太小(如<100)或网络容量不足。扩散模型需要足够多的步数才能精细恢复细节。
  • 解决:增加T至1000,并加深UNet(增加通道数或层数)。

4. 显存溢出

  • 原因:batch size过大或T过大导致中间变量过多。
  • 解决:减小batch size,使用梯度累积,或使用FP16混合精度训练。

5. 训练时loss为NaN

  • 原因:数值不稳定。检查alpha_bar_t是否出现0(当T很大时累积乘积可能下溢)。
  • 解决:使用float64或对alpha_bar_t添加eps(如1e-12)。更推荐使用cosine调度替代线性调度。

6. 采样速度慢

  • 原因:DDPM需要逐步迭代T次,T=1000时很慢。
  • 解决:使用DDIM采样(确定性采样,可减少步数至50-100),或使用DPM-solver等加速方法。

7. 模型过拟合

  • 原因:数据集太小或网络容量过大。
  • 解决:增加数据增强(随机翻转、旋转),或使用dropout。

总结

本文从数学推导到工程实现完整呈现了扩散模型的核心细节。关键在于理解前向过程的闭合形式(x_t与x_0的直接关系)和反向过程的变分下界简化。训练时只需预测噪声,采样时逐步去噪。代码实现了完整的训练与采样流程,可直接运行生成MNIST数字。

扩散模型相比GAN的优势在于训练稳定、模式覆盖全面,但采样速度慢是主要瓶颈。工业界常通过DDIM、LCM、蒸馏等技术加速。理解本文的基础实现后,可进一步阅读DDPM、DDIM、Score SDE等论文,深入探索更高级的变体与应用。

建议读者在运行代码后,尝试修改T、噪声调度类型(如cosine调度)、网络结构,观察生成质量的变化,从而建立对扩散模型各组件作用的直观理解。

相关推荐
Bigfish_coding1 小时前
前端转agent-【python】-06 长期记忆(向量数据库 + 嵌入)
人工智能
小林ixn1 小时前
别再手写Prompt了!用AI Loop实现自动化自我迭代,效率提升10倍
人工智能·自动化运维
说了很好1 小时前
逐行注释DDPM源码:正向加噪、逆向去噪、MSE损失全流程复现
人工智能
Dilee1 小时前
Spring AI 1.1.7 接入 MCP:Filesystem Server 最小 Demo
人工智能·后端
Token炼金师1 小时前
大模型推理超参数原理详解
人工智能
Token炼金师1 小时前
大模型训练超参数:从Loss曲面到收敛策略的底层逻辑
人工智能
后端小肥肠1 小时前
Skill 囤了一堆却用不起来?我用 Codex 写了个整理神器
人工智能·agent
魏祖潇1 小时前
从"会聊天"到"能干活":用 OpenCode 给自己找个 AI 搭子
人工智能
子兮曰1 小时前
AI Coding Method Map:一张图看懂 AI 编程的完整链路
前端·人工智能·后端