去噪扩散:从随机噪声到高保真图像的数学之路

扩散模型通过逐步向数据添加高斯噪声学习数据分布,再通过逆向去噪过程生成新样本。本文从变分推断和随机微分方程两个视角推导训练目标,剖析DDPM、DDIM、Score-based模型的统一框架,并给出生产级实现中的关键工程细节。

1. 核心思想:正向加噪与逆向去噪

扩散模型的核心直觉:如果能把一张图像逐步变成纯噪声,那么学会这个过程的逆过程,就能从噪声中重建图像。

flowchart LR A["x0 [原始图像]"] -->|"q(xt|x0)"| B["x1 [轻微噪声]"] B -->|"逐步加噪"| C["xT [纯高斯噪声]"] D["xT [随机噪声]"] -->|"pθ(xt-1|xt)"| E["xT-1"] E -->|"逐步去噪"| F["x0 [生成图像]"] classDef default fill:#000000,stroke:#ffffff,color:#ffffff,stroke-width:2px class A,F fill:#1a1a2e,stroke:#6c5ce7 class C,D fill:#1a1a2e,stroke:#fd79a8
python 复制代码
# 来源:DDPM (Ho et al., 2020) 核心加噪过程
import torch
import torch.nn as nn

def forward_process(x0, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    """
    正向加噪:q(xt|x0) = N(xt; sqrt(alpha_cumprod_t)*x0, (1-alpha_cumprod_t)*I)
    一步到位直接从x0跳到xt,无需逐步迭代
    """
    # 从预计算的调度系数中获取当前时间步的值
    sqrt_alpha = sqrt_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    sqrt_one_minus_alpha = sqrt_one_minus_alphas_cumprod[t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    
    # 采样标准高斯噪声
    noise = torch.randn_like(x0)
    
    # 一步生成xt:原始图像的缩放 + 噪声的缩放
    xt = sqrt_alpha * x0 + sqrt_one_minus_alpha * noise
    return xt, noise

2. 变分下界推导:为什么优化MSE

DDPM的训练目标本质上是最小化模型预测的噪声与真实噪声之间的均方误差,这背后是变分下界(ELBO)的推导。

flowchart TD A["log pθ(x0) [对数似然]"] --> B["≥ ELBO"] B --> C["Σt KL(q(xt-1|xt,x0) || pθ(xt-1|xt))"] C --> D["重构项 t=0"] C --> E["先验匹配项 t=T"] C --> F["去噪匹配项 1
python 复制代码
# 来源:DDPM (Ho et al., 2020) / diffusers/schedulers/scheduling_ddpm.py
# 简化后的核心训练目标:预测噪声的MSE损失
def ddpm_training_loss(model, x0, t, noise_scheduler):
    """
    训练目标:最小化 ||eps - eps_theta(sqrt(alpha_t)*x0 + sqrt(1-alpha_t)*eps, t)||^2
    直觉:给定带噪图像xt和时间步t,模型预测添加的噪声eps
    """
    # 采样噪声并生成带噪图像
    noise = torch.randn_like(x0)
    xt = noise_scheduler.add_noise(x0, noise, t)
    
    # 模型预测噪声
    noise_pred = model(xt, t)
    
    # MSE损失:预测噪声 vs 真实噪声
    loss = nn.functional.mse_loss(noise_pred, noise)
    return loss

3. 噪声调度:α的衰减策略

噪声调度决定了信噪比随时间的衰减曲线,直接影响生成质量。

flowchart LR A["线性调度 βt线性增长"] --> D["SNR 急剧衰减 前几步丢失大量信息"] B["余弦调度 αt按余弦衰减"] --> E["SNR 缓慢衰减 信息保留更均匀"] B --> F["DDPM 默认使用 生成质量提升2-3%"] classDef default fill:#000000,stroke:#ffffff,color:#ffffff,stroke-width:2px class E fill:#1a1a2e,stroke:#6c5ce7
python 复制代码
# 来源:diffusers/schedulers/scheduling_ddpm.py (HuggingFace)
import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    """
    余弦噪声调度:相比线性调度,在早期保留更多图像信息
    公式:alpha_cumprod_t = f(t) / f(0),其中 f(t) = cos((t/T + s)/(1+s) * pi/2)^2
    """
    steps = timesteps + 1
    t = np.linspace(0, timesteps, steps, dtype=np.float64)
    # 计算累积余弦值
    alphas_cumprod = np.cos((t / timesteps + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    
    # 从累积乘积推导出βt
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # 裁剪到合理范围,防止数值不稳定
    return np.clip(betas, 0.0001, 0.9999)

4. 采样策略:DDPM vs DDIM

DDPM采样需要数百到上千步,DDIM通过非马尔可夫过程实现跳步采样,在10-50步内达到近似质量。

sequenceDiagram participant N as 噪声 xT participant M as 噪声预测器 εθ participant D as 去噪步骤 participant I as 生成图像 x0 N->>M: 输入 (xT, t=T) M->>D: 输出 εθ(xT, T) D->>D: 计算 xT-1 (或跳步到 xT-k) D->>M: 输入 (xT-1, t=T-1) Note over D: DDIM: 确定性采样<br/>DDPM: 随机采样 D->>I: 最终输出 x0
python 复制代码
# 来源:DDIM (Song et al., 2021) / diffusers/schedulers/scheduling_ddim.py
def ddim_step(model_output, timestep, sample, scheduler):
    """
    DDIM采样核心:确定性映射,支持跳步
    关键差异:去掉了DDPM中的随机噪声项,变为隐式概率模型
    """
    # 计算累积乘积项
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    # 前一时间步的alpha(支持跳步)
    alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - scheduler.step_size] \
        if timestep >= scheduler.step_size else scheduler.final_alpha_cumprod
    
    # 从噪声预测推导x0的估计值
    # x0_pred = (xt - sqrt(1-alpha_t) * eps_theta) / sqrt(alpha_t)
    beta_prod_t = 1 - alpha_prod_t
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    
    # DDIM确定性更新公式
    # x_{t-1} = sqrt(alpha_{t-1}) * x0_pred + sqrt(1-alpha_{t-1}) * eps_theta
    pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
    prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
    return prev_sample

5. Score-Based模型:随机微分方程视角

扩散模型可以统一到随机微分方程(SDE)框架下,正向过程是一个SDE,逆向过程是time-reversed SDE。

flowchart TD A[&#34;正向SDE&#160;dx = f(x,t)dt + g(t)dw&#34;] --> B[&#34;逆向SDE&#160;dx = [f-g²∇x log p]dt + g(t)dẁ&#34;] B --> C[&#34;得分函数 sθ = ∇x log pθ(xt)&#34;] C --> D[&#34;训练目标&#160;||sθ - ∇xt log p(xt|x0)||²&#34;] D --> E[&#34;去噪得分匹配&#160;避免直接计算得分&#34;] classDef default fill:#000000,stroke:#ffffff,color:#ffffff,stroke-width:2px class E fill:#1a1a2e,stroke:#6c5ce7
python 复制代码
# 来源:Score-based SDE (Song et al., 2021)
def score_matching_loss(score_model, x0, t, noise_scheduler):
    """
    去噪得分匹配:避免直接计算数据分布的得分函数
    关键等式:grad_xt log q(xt|x0) = -(xt - sqrt(alpha_t)*x0) / (1-alpha_t)
    因此只需让score_model预测 -eps_theta / sqrt(1-alpha_t)
    """
    noise = torch.randn_like(x0)
    # 计算带噪样本
    xt = noise_scheduler.add_noise(x0, noise, t)
    
    # 目标得分:对数密度的梯度
    # q(xt|x0) = N(sqrt(alpha_t)*x0, (1-alpha_t)*I)
    # grad_xt log q(xt|x0) = -(xt - sqrt(alpha_t)*x0) / (1-alpha_t) = -eps/(1-alpha_t)
    sigma = (1 - noise_scheduler.alphas_cumprod[t]) ** 0.5
    target_score = -noise / sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    
    # 模型预测得分
    predicted_score = score_model(xt, t)
    
    # 加权MSE:不同时间步的权重不同
    loss = torch.mean((predicted_score - target_score) ** 2)
    return loss

6. U-Net架构与时间步嵌入

扩散模型的骨干网络是U-Net,通过时间步嵌入和交叉注意力实现条件生成。

flowchart TD A[&#34;输入&#160;xt + t&#34;] --> B[&#34;ResBlock&#160;+ 时间嵌入&#34;] B --> C[&#34;ResBlock&#160;+ 时间嵌入&#34;] C --> D[&#34;Self-Attention&#160;+ 交叉注意力&#34;] D --> E[&#34;ResBlock&#160;+ 时间嵌入&#34;] E --> F[&#34;ResBlock&#160;+ 时间嵌入&#34;] F --> G[&#34;输出&#160;εθ(xt,t)&#34;] H[&#34;条件输入&#160;text/image&#34;] --> D classDef default fill:#000000,stroke:#ffffff,color:#ffffff,stroke-width:2px class G fill:#1a1a2e,stroke:#6c5ce7
python 复制代码
# 来源:diffusers/models/unet_2d_condition.py (HuggingFace)
import torch
import torch.nn as nn

class SinusoidalPositionEmbedding(nn.Module):
    """
    正弦位置编码:将标量时间步t映射为高维向量
    不同频率的正弦/余弦函数确保每个时间步有唯一表示
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        half_dim = self.dim // 2
        # 频率递减的正弦位置编码
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        # 拼接正弦和余弦分量
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

class ResBlock(nn.Module):
    """
    带时间嵌入的残差块:将时间信息注入到空间特征中
    """
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # 时间嵌入投影:将时间向量映射到特征维度
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()
    
    def forward(self, x, t_emb):
        h = self.norm1(x)
        h = nn.functional.silu(h)
        h = self.conv1(h)
        # 关键:将时间嵌入加到空间特征上
        h = h + self.time_mlp(nn.functional.silu(t_emb))[:, :, None, None]
        h = self.norm2(h)
        h = nn.functional.silu(h)
        h = self.conv2(h)
        return h + self.residual_conv(x)

7. 生产级优化:Classifier-Free Guidance

Classifier-Free Guidance (CFG) 是目前提升生成质量最实用的技术,无需额外分类器。

flowchart LR A[&#34;无条件预测&#160;εθ(xt, ∅)&#34;] --> C[&#34;ε̂ = εθ(xt,∅) + s·(εθ(xt,c) - εθ(xt,∅))&#34;] B[&#34;条件预测&#160;εθ(xt, c)&#34;] --> C C --> D[&#34;s=1: 标准采样&#160;s>1: 增强条件&#160;s
python 复制代码
# 来源:Classifier-Free Guidance (Ho & Salimans, 2022)
def classifier_free_guidance(noise_pred_cond, noise_pred_uncond, guidance_scale):
    """
    CFG核心公式:ε̂ = ε_uncond + s · (ε_cond - ε_uncond)
    直觉:放大条件预测与无条件预测的差值,增强条件控制力度
    guidance_scale=1: 退化为标准采样
    guidance_scale=7.5: Stable Diffusion默认值
    """
    # 计算条件偏移量
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    return noise_pred

8. 训练与推理的工程实践

flowchart TD A[&#34;数据准备&#160;图像+文本对&#34;] --> B[&#34;VAE编码&#160;像素→潜空间&#34;] B --> C[&#34;U-Net训练&#160;预测噪声&#34;] C --> D[&#34;EMA权重更新&#160;指数移动平均&#34;] D --> E[&#34;推理阶段&#34;] E --> F[&#34;CLIP编码文本&#34;] F --> G[&#34;DDIM采样&#160;50步&#34;] G --> H[&#34;VAE解码&#160;潜空间→像素&#34;] classDef default fill:#000000,stroke:#ffffff,color:#ffffff,stroke-width:2px class H fill:#1a1a2e,stroke:#6c5ce7
python 复制代码
# 来源:Stable Diffusion训练流程 (Rombach et al., 2022)
def train_step(vae, unet, text_encoder, images, texts, noise_scheduler, optimizer):
    """
    Stable Diffusion单步训练流程
    关键:在VAE潜空间而非像素空间做扩散,大幅降低计算量
    """
    # 1. 将图像编码到潜空间(压缩比8x)
    latents = vae.encode(images).latent_dist.sample()
    latents = latents * 0.18215  # 缩放因子,稳定训练
    
    # 2. 采样时间步和噪声
    t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (images.shape[0],))
    noise = torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, noise, t)
    
    # 3. 文本编码
    text_embeddings = text_encoder(texts)
    
    # 4. 模型前向传播
    noise_pred = unet(noisy_latents, t, text_embeddings)
    
    # 5. 计算损失并反向传播
    loss = nn.functional.mse_loss(noise_pred, noise)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()

总结

扩散模型通过变分推断将生成建模转化为噪声预测问题,DDPM/DDIM/Score-based模型在SDE框架下统一。工程上,潜空间扩散+CFG+余弦调度是当前生产部署的标准组合,在10-50步采样内即可达到高质量生成。

相关推荐
这个DBA有点耶1 小时前
AI写的SQL跑崩了生产库,这锅谁背?
数据库·人工智能·程序员
阿里云大数据AI技术1 小时前
阿里云 EMR AI 助手正式发布:从问答工具到全栈智能运维助手
运维·人工智能
AlbertZein2 小时前
别被模型宣传骗了,真实 Agent 任务一跑就知道
aigc·openai·ai编程
Larcher2 小时前
从零搭建 MCP 服务——让 AI 拥有无限扩展能力
人工智能·程序员
zzzzzz3102 小时前
你的 AI 写的 React 烂透了?这个 8000+ Star 的开源工具能揪出 90% 的「Agent 屎山」
人工智能
小星AI2 小时前
MCP协议超详细教程,从入门到实战
人工智能
小星AI2 小时前
Kimi Code CLI 超详细教程,附源码
人工智能·agent
小碗细面3 小时前
让 AI Agent 真正读懂你的资料:我开源了 source-skill-pipeline
aigc·ai编程·claude
牧艺3 小时前
Cursor Rules / Skills 分层设计:让 Agent 像「团队新同事」
前端·人工智能·cursor