扩散模型 (Diffusion Models) 详解

目录

  1. 概述
  2. 生成模型发展史
  3. 核心思想
  4. 数学基础
  5. DDPM详解
  6. DDIM加速采样
  7. Score-Based模型
  8. 条件生成
  9. [Latent Diffusion与Stable Diffusion](#Latent Diffusion与Stable Diffusion)
  10. [Classifier-Free Guidance](#Classifier-Free Guidance)
  11. 图像生成应用
  12. 视频生成
  13. 3D与音频生成
  14. 完整代码实现
  15. 训练技巧
  16. 推理优化
  17. 参考资料

1. 概述

1.1 什么是扩散模型

扩散模型是一类基于概率的生成模型,通过学习逐步去噪过程来生成数据。其核心包括:

  • 前向过程:逐步向数据添加噪声,直至变为纯高斯噪声
  • 反向过程:学习从噪声中逐步恢复原始数据

1.2 发展历程

时间 事件
2015 Sohl-Dickstein首次提出扩散模型概念
2020 DDPM实现高质量图像生成
2021 DDIM加速采样;Score SDE统一理论
2022 Latent Diffusion/Stable Diffusion;DALL-E 2
2023 SDXL;Consistency Models;视频生成
2024 Sora视频生成;Flow Matching

2. 生成模型发展史

2.1 主流生成模型对比

模型 优点 缺点
GAN 生成快、质量高 训练不稳定、模式坍塌
VAE 有隐空间、可训练 生成模糊
Flow 精确似然 计算量大、架构受限
自回归 精确建模 生成慢
Diffusion 训练稳定、质量高 生成较慢

2.2 为什么选择扩散模型

  • 训练过程简单稳定(只预测噪声)
  • 生成质量目前最高
  • 不易模式坍塌
  • 理论基础坚实

3. 核心思想

3.1 直观理解

前向过程:照片逐渐褪色变成噪声

反向过程:从噪声中恢复出清晰照片

3.2 关键公式

前向:

复制代码
x_t = √(ᾱ_t) * x_0 + √(1-ᾱ_t) * ε

反向(学习):

复制代码
x_{t-1} = μ_θ(x_t, t) + σ_t * z

训练目标:预测噪声

复制代码
L = E[||ε - ε_θ(x_t, t)||²]

4. 数学基础

4.1 高斯分布

一元:N(μ, σ²)

多元:N(μ, Σ)

4.2 KL散度

复制代码
D_KL(p||q) = ∫ p(x) log(p(x)/q(x)) dx

高斯分布间的KL散度有解析解。

4.3 得分函数

复制代码
s(x) = ∇_x log p(x)

不需要归一化常数,是Score-Based模型的基础。

4.4 马尔可夫链

扩散过程是马尔可夫链:x_t只依赖x_{t-1}


5. DDPM详解

5.1 前向过程

复制代码
q(x_t|x_{t-1}) = N(x_t; √(1-β_t)x_{t-1}, β_t I)

性质:可以直接从x_0计算x_t
x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε

5.2 反向过程

复制代码
p_θ(x_{t-1}|x_t) = N(x_{t-1}; μ_θ(x_t,t), σ_t²I)

神经网络预测噪声ε_θ,然后计算均值:
μ_θ = (1/√α_t)(x_t - β_t/√(1-ᾱ_t) * ε_θ)

5.3 训练目标

简化损失:

复制代码
L_simple = E[||ε - ε_θ(x_t, t)||²]

5.4 训练算法

复制代码
repeat:
    采样x_0 ~ q(x_0)
    采样t ~ Uniform(1,T)
    采样ε ~ N(0,I)
    计算x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε
    计算损失 ||ε - ε_θ(x_t, t)||²
    梯度下降更新θ

5.5 采样算法

复制代码
x_T ~ N(0,I)
for t = T,...,1:
    z ~ N(0,I) if t>1 else z=0
    x_{t-1} = μ_θ(x_t,t) + σ_t * z
return x_0

6. DDIM加速采样

6.1 核心思想

DDPM采样需要T步(通常1000步),太慢。

DDIM定义一族非马尔可夫前向过程,支持跳步采样。

6.2 DDIM采样公式

复制代码
x_{t-1} = √(ᾱ_{t-1}) * predicted_x0 + √(1-ᾱ_{t-1}-σ²) * predicted_noise + σ * z

其中:
predicted_x0 = (x_t - √(1-ᾱ_t)*ε_θ) / √(ᾱ_t)

6.3 确定性采样

当σ=0时,采样过程完全确定(无随机性),便于隐空间插值。

6.4 加速效果

  • DDPM:1000步
  • DDIM:50-100步(质量相近)
  • 进一步:10-20步(质量略有下降)

7. Score-Based模型

7.1 得分匹配

学习数据分布的得分函数s_θ(x) ≈ ∇_x log p(x)

损失函数:

复制代码
L = E[||s_θ(x_t, t) - ∇_{x_t} log q(x_t|x_0)||²]

关键联系:预测噪声与预测得分等价

复制代码
s_θ(x_t, t) = -ε_θ(x_t, t) / √(1-ᾱ_t)

7.2 SDE框架

前向SDE:

复制代码
dx = f(x,t)dt + g(t)dw

反向SDE:

复制代码
dx = [f(x,t) - g(t)²∇_x log p_t(x)]dt + g(t)dw̄

7.3 三种VP/VE/子VP SDE

  • VP (Variance Preserving):对应DDPM
  • VE (Variance Exploding):对应SMLD
  • sub-VP:介于两者之间

8. 条件生成

8.1 Classifier Guidance

训练一个分类器p(y|x_t),用其梯度引导生成:

复制代码
ε̃ = ε_θ(x_t,t) - √(1-ᾱ_t) * s * ∇_{x_t} log p(y|x_t)

8.2 Classifier-Free Guidance

同时训练条件和无条件模型:

复制代码
ε̃ = ε_θ(x_t,t,∅) + s * (ε_θ(x_t,t,c) - ε_θ(x_t,t,∅))

s>1增强条件效果,常用s=7.5

8.3 文本条件

使用CLIP或T5编码文本,通过交叉注意力注入UNet。


9. Latent Diffusion与Stable Diffusion

9.1 潜空间扩散

在像素空间做扩散计算量太大。解决方案:

  1. 训练VAE将图像压缩到潜空间
  2. 在潜空间做扩散
  3. 用VAE解码回像素空间

9.2 Stable Diffusion架构

复制代码
文本 → [CLIP/T5] → 文本嵌入
                        ↓
噪声 → [UNet+交叉注意力] → 去噪潜变量 → [VAE解码器] → 图像

关键组件:

  • VAE:4-8倍下采样
  • UNet:带交叉注意力的时间条件网络
  • 文本编码器:CLIP或OpenCLIP

9.3 训练流程

复制代码
1. 训练VAE(重建+KL+感知损失)
2. 冻结VAE,训练UNet
   - 图像→潜变量z_0
   - 添加噪声得到z_t
   - UNet预测噪声,以文本嵌入为条件

10. Classifier-Free Guidance详解

10.1 训练

以概率p_uncond将条件置空(通常p_uncond=0.1-0.2)

10.2 推理

复制代码
ε_guided = ε_uncond + w * (ε_cond - ε_uncond)

w=1等价于标准条件生成,w>1增强条件遵循。

10.3 效果

  • w=1:多样但可能不遵循条件
  • w=7-12:很好的平衡
  • w很大:严格遵循条件但多样性降低

11. 图像生成应用

11.1 文本到图像

代表模型:DALL-E 2/3、Stable Diffusion、Midjourney

流程:

复制代码
文本描述 → 文本编码 → 扩散生成 → 高分辨率图像

11.2 图像编辑

  • Inpainting:遮罩区域重新生成
  • Outpainting:扩展图像边界
  • Style Transfer:风格迁移
  • Image-to-Image:图像到图像转换

11.3 超分辨率

复制代码
低分辨率图像 → 条件扩散 → 高分辨率图像

12. 视频生成

12.1 方法

  • 时空UNet:3D卷积+时间注意力
  • 图像模型扩展:在图像模型上加时间层
  • 级联生成:先低分辨率再超分

12.2 代表模型

  • Sora (OpenAI):最长60秒
  • Runway Gen-2/3
  • Pika Labs
  • Stable Video Diffusion

13. 3D与音频生成

13.1 3D生成

  • DreamFusion:2D扩散+NeRF
  • Magic3D:两阶段生成
  • Point-E:点云生成

13.2 音频生成

  • AudioLDM:音频潜空间扩散
  • MusicLDM:音乐生成
  • Riffusion:频谱图扩散

14. 完整代码实现

14.1 DDPM实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ========== 时间嵌入 ==========
class SinusoidalPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=device) * -emb)
        emb = time[:, None] * emb[None, :]
        return torch.cat([emb.sin(), emb.cos()], dim=-1)

# ========== 残差块 ==========
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, dropout=0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h = h + self.time_proj(t)[:, :, None, None]
        h = F.silu(self.norm2(h))
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.skip(x)

# ========== 注意力块 ==========
class AttnBlock(nn.Module):
    def __init__(self, ch, heads=8):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)
        self.attn = nn.MultiheadAttention(ch, heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        residual = x
        x = self.norm(x).view(b, c, -1).transpose(1, 2)
        x, _ = self.attn(x, x, x)
        x = x.transpose(1, 2).view(b, c, h, w)
        return x + residual

# ========== UNet ==========
class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, base_ch=128,
                 ch_mult=(1,2,4,8), attn_res=(2,), res_blocks=2):
        super().__init__()

        time_dim = base_ch * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbedding(base_ch),
            nn.Linear(base_ch, time_dim), nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )

        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # 编码器
        self.downs = nn.ModuleList()
        chs = [base_ch]
        ch = base_ch
        for level, mult in enumerate(ch_mult):
            out = base_ch * mult
            for _ in range(res_blocks):
                layers = [ResBlock(ch, out, time_dim)]
                if level in attn_res:
                    layers.append(AttnBlock(out))
                self.downs.append(nn.ModuleList(layers))
                ch = out
                chs.append(ch)
            if level < len(ch_mult) - 1:
                self.downs.append(nn.ModuleList([nn.Conv2d(ch, ch, 3, 2, 1)]))
                chs.append(ch)

        # 中间
        self.mid = nn.ModuleList([
            ResBlock(ch, ch, time_dim),
            AttnBlock(ch),
            ResBlock(ch, ch, time_dim)
        ])

        # 解码器
        self.ups = nn.ModuleList()
        for level, mult in reversed(list(enumerate(ch_mult))):
            out = base_ch * mult
            for i in range(res_blocks + 1):
                skip = chs.pop()
                layers = [ResBlock(ch + skip, out, time_dim)]
                if level in attn_res:
                    layers.append(AttnBlock(out))
                if level > 0 and i == res_blocks:
                    layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
                self.ups.append(nn.ModuleList(layers))
                ch = out

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch), nn.SiLU(),
            nn.Conv2d(ch, out_ch, 3, padding=1)
        )

    def forward(self, x, t):
        t = self.time_mlp(t.float())
        x = self.in_conv(x)

        skips = [x]
        for block in self.downs:
            for layer in block:
                x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)
            skips.append(x)

        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResBlock) else layer(x)

        for block in self.ups:
            x = torch.cat([x, skips.pop()], 1)
            for layer in block:
                if isinstance(layer, ResBlock):
                    x = layer(x, t)
                elif isinstance(layer, AttnBlock):
                    x = layer(x)
                else:
                    x = layer(x)

        return self.out(x)

# ========== DDPM ==========
class DDPM:
    def __init__(self, model, T=1000, beta1=1e-4, beta2=0.02):
        self.model = model
        self.T = T
        self.betas = torch.linspace(beta1, beta2, T)
        self.alphas = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, 0)
        self.alpha_bar_prev = F.pad(self.alpha_bar[:-1], (1,0), value=1.0)
        self.posterior_var = self.betas * (1 - self.alpha_bar_prev) / (1 - self.alpha_bar)

    def q_sample(self, x0, t, eps=None):
        if eps is None:
            eps = torch.randn_like(x0)
        a = self.alpha_bar[t][:, None, None, None].to(x0.device)
        return a.sqrt() * x0 + (1 - a).sqrt() * eps

    def loss(self, x0):
        b = x0.shape[0]
        t = torch.randint(0, self.T, (b,), device=x0.device)
        eps = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps)
        eps_pred = self.model(xt, t)
        return F.mse_loss(eps_pred, eps)

    @torch.no_grad()
    def sample(self, shape, device):
        x = torch.randn(shape, device=device)
        for t in reversed(range(self.T)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            eps_pred = self.model(x, t_batch)

            a = self.alphas[t]
            ab = self.alpha_bar[t]
            b = self.betas[t]

            mean = (1 / a.sqrt()) * (x - b / (1 - ab).sqrt() * eps_pred)

            if t > 0:
                var = self.posterior_var[t]
                x = mean + var.sqrt() * torch.randn_like(x)
            else:
                x = mean
        return x

14.2 Latent Diffusion实现

python 复制代码
class LatentDiffusion(nn.Module):
    def __init__(self, vae, unet, text_encoder):
        super().__init__()
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.latent_scale = 0.18215  # SD的缩放因子

    def encode(self, x):
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample() * self.latent_scale

    def decode(self, z):
        return self.vae.decode(z / self.latent_scale).sample

    def forward(self, x, text):
        # 编码
        z0 = self.encode(x)

        # 文本条件
        text_emb = self.text_encoder(text).last_hidden_state

        # 扩散训练
        b = z0.shape[0]
        t = torch.randint(0, 1000, (b,), device=z0.device)
        eps = torch.randn_like(z0)
        zt = self.ddpm.q_sample(z0, t, eps)

        # 预测
        eps_pred = self.unet(zt, t, text_emb)
        return F.mse_loss(eps_pred, eps)

    @torch.no_grad()
    def generate(self, text, steps=50, guidance_scale=7.5):
        # 文本编码
        cond = self.text_encoder(text).last_hidden_state
        uncond = self.text_encoder("").last_hidden_state

        # 初始噪声
        z = torch.randn(1, 4, 64, 64).to(self.device)

        # DDIM采样
        for t in reversed(range(steps)):
            t_batch = torch.tensor([t], device=self.device)

            # Classifier-Free Guidance
            eps_uncond = self.unet(z, t_batch, uncond)
            eps_cond = self.unet(z, t_batch, cond)
            eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

            # 去噪
            z = self.ddim_step(z, eps, t)

        # 解码
        return self.decode(z)

15. 训练技巧

15.1 噪声调度选择

  • 线性调度:经典选择
  • 余弦调度:更好的低噪声区域建模
  • Sigmoid调度:平衡两者

15.2 损失函数变体

python 复制代码
# 简化损失
loss = F.mse_loss(eps_pred, eps)

# 加权损失(强调某些时间步)
loss = F.mse_loss(eps_pred, eps, reduction='none')
loss = (loss * weights[t]).mean()

# SNR加权
snr = alpha_bar / (1 - alpha_bar)
weights = snr / (snr + 1)
loss = (loss * weights).mean()

15.3 EMA

python 复制代码
class EMA:
    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        self.shadow = {name: param.clone() for name, param in model.named_parameters()}

    def update(self):
        for name, param in self.model.named_parameters():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param

    def apply(self):
        self.backup = {name: param.clone() for name, param in self.model.named_parameters()}
        for name, param in self.model.named_parameters():
            param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            param.data = self.backup[name]

15.4 混合精度训练

python 复制代码
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for x, _ in dataloader:
    with autocast():
        loss = ddpm.loss(x)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

16. 推理优化

16.1 采样器选择

采样器 步数 特点
DDPM 1000 经典,慢
DDIM 50-100 加速,可确定性
DPM-Solver 10-20 ODE求解器
UniPC 5-10 无训练加速

16.2 量化

python 复制代码
# INT8量化
model_int8 = torch.quantization.quantize_dynamic(
    model, {nn.Linear, nn.Conv2d}, dtype=torch.qint8
)

16.3 缓存优化

缓存中间结果,避免重复计算。


17. 参考资料

核心论文

  1. DDPM: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
  2. DDIM: "Denoising Diffusion Implicit Models" (Song et al., 2021)
  3. Score SDE: "Score-Based Generative Modeling through SDEs" (Song et al., 2021)
  4. LDM: "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., 2022)
  5. DALL-E 2: "Hierarchical Text-Conditional Image Generation with CLIP Latents" (Ramesh et al., 2022)
  6. Classifier-Free Guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)

开源项目

推荐资源

  • Lilian Weng: "What are Diffusion Models?"
  • The Annotated Diffusion Model (HuggingFace)
  • Stanford CS236: Deep Generative Models

相关推荐
凌波粒6 小时前
深度学习入门(鱼书)第3章笔记——神经网络
笔记·深度学习·神经网络
Dymc10 小时前
【论文解析】用神经网络给优化器“热身“——面向 UAV-UGV 交接任务的学习加速轨迹规划
人工智能·神经网络·学习
星恒随风10 小时前
从机器学习基础到 MLP(下):神经网络为什么能起作用?
人工智能·笔记·神经网络·学习·机器学习
ZHW_AI课题组1 天前
基于MLP神经网络的红酒品质回归预测
人工智能·神经网络·机器学习·回归
人工智能培训1 天前
探析数字孪生的核心特性与应用价值
人工智能·深度学习·神经网络·机器学习·生成对抗网络
Yunzenn1 天前
深度分析字节最新研究cola-DLM第 06 章:分块因果 DiT 先验 —— 在隐空间里做 Flow Matching
人工智能·rnn·深度学习·神经网络·生成对抗网络·架构·transformer
通信小呆呆1 天前
维度分数傅里叶时频图 + 图神经网络:突破传统时频分析的目标识别与杂波抑制新框架
人工智能·神经网络·算法
EnCi Zheng1 天前
09aa-偏置是什么?
人工智能·pytorch·python·深度学习·神经网络
烟雨江南7851 天前
嘈杂工业场景下的自适应VAD与双码本声纹识别鉴权系统:基于端侧轻量化神经网络与向量量化(VQ)重构
人工智能·深度学习·神经网络·算法·语音识别