一致性模型深度解析

目录

  • [第一篇:一致性模型 --- 单步生成的数学基础](#第一篇:一致性模型 — 单步生成的数学基础)
  • [第二篇:一致性蒸馏 --- 从预训练扩散模型学习](#第二篇:一致性蒸馏 — 从预训练扩散模型学习)
  • 第三篇:一致性模型的前沿发展
  • 参考文献

第一篇:一致性模型 --- 单步生成的数学基础

1. 引言

扩散模型(Diffusion Models)在图像生成领域取得了巨大成功,但其采样过程需要数百步迭代去噪,计算代价高昂。即使是加速采样方法(如 DDIM、DPM-Solver),通常也需要 10-50 步才能获得高质量样本。

一致性模型(Consistency Models, Song et al., 2023) 提出了一种全新的生成范式:学习一个函数,将 ODE 轨迹上的任意点直接映射到轨迹起点(数据)。这一设计实现了:

  1. 单步生成:一次前向传播即可生成高质量样本
  2. 多步精炼:支持多步采样以换取更高质量
  3. 数学优雅:基于概率流 ODE 的几何性质,定义简洁

2. 理论基础 --- 概率流 ODE

2.1 扩散过程的 ODE 视角

扩散模型有两种等价的连续时间表述:

随机微分方程(SDE)(前向):

d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt

概率流 ODE(PF-ODE)(前向):

d x t d t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log ⁡ p t ( x t ) \frac{d\mathbf{x}_t}{dt} = \mathbf{f}(\mathbf{x}_t, t) - \frac{1}{2} g(t)^2 \nabla \log p_t(\mathbf{x}_t) dtdxt=f(xt,t)−21g(t)2∇logpt(xt)

其中 ∇ log ⁡ p t ( x t ) \nabla \log p_t(\mathbf{x}_t) ∇logpt(xt) 是分数函数(score function)

关键性质 :PF-ODE 与 SDE 具有相同的边际分布 p t ( x ) p_t(\mathbf{x}) pt(x),但 PF-ODE 是确定性的------给定初始条件 x 0 \mathbf{x}_0 x0,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 唯一确定。

2.2 方差保持(VP)扩散的 PF-ODE

对于 VP 扩散(DDPM):

f ( x t , t ) = − 1 2 β ( t ) x t , g ( t ) = β ( t ) \mathbf{f}(\mathbf{x}_t, t) = -\frac{1}{2} \beta(t) \mathbf{x}_t, \quad g(t) = \sqrt{\beta(t)} f(xt,t)=−21β(t)xt,g(t)=β(t)

PF-ODE 为:

d x t d t = − 1 2 β ( t ) x t − 1 2 β ( t ) ∇ log ⁡ p t ( x t ) \frac{d\mathbf{x}_t}{dt} = -\frac{1}{2} \beta(t) \mathbf{x}_t - \frac{1}{2} \beta(t) \nabla \log p_t(\mathbf{x}_t) dtdxt=−21β(t)xt−21β(t)∇logpt(xt)

定义去噪函数 D ( x t , t ) \mathbf{D}(\mathbf{x}_t, t) D(xt,t):

D ( x t , t ) = x t + σ t 2 ∇ log ⁡ p t ( x t ) α t \mathbf{D}(\mathbf{x}_t, t) = \frac{\mathbf{x}_t + \sigma_t^2 \nabla \log p_t(\mathbf{x}_t)}{\alpha_t} D(xt,t)=αtxt+σt2∇logpt(xt)

其中 α t = exp ⁡ ( − 1 2 ∫ 0 t β ( s ) d s ) \alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right) αt=exp(−21∫0tβ(s)ds), σ t 2 = 1 − α t 2 \sigma_t^2 = 1 - \alpha_t^2 σt2=1−αt2。

则 PF-ODE 可改写为:

d x t d t = − α ˙ t α t x t + α ˙ t σ t α t D ( x t , t ) \frac{d\mathbf{x}_t}{dt} = -\frac{\dot{\alpha}_t}{\alpha_t} \mathbf{x}_t + \frac{\dot{\alpha}_t \sigma_t}{\alpha_t} \mathbf{D}(\mathbf{x}_t, t) dtdxt=−αtα˙txt+αtα˙tσtD(xt,t)

2.3 PF-ODE 轨迹的性质

定理 :对于 VP 扩散的 PF-ODE,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 具有以下性质:

  1. 唯一性 :给定 x 0 \mathbf{x}_0 x0,轨迹唯一确定
  2. 连续性 : x t \mathbf{x}_t xt 关于 t t t 连续
  3. 边界行为 : x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata, x T ≈ N ( 0 , I ) \mathbf{x}_T \approx \mathcal{N}(0, \mathbf{I}) xT≈N(0,I)

直觉 :PF-ODE 将数据分布 p data p_{\text{data}} pdata 连续地"变形"为高斯噪声 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)。生成过程就是从噪声沿轨迹"走回去"。


3. 一致性模型的核心定义

3.1 一致性函数

定义 (一致性函数 f θ f_\theta fθ):对于 PF-ODE 的任意轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,一致性函数满足:

f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ t , t ′ ∈ 0 , T f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t'}, t'), \quad \forall t, t' \in 0, T fθ(xt,t)=fθ(xt′,t′),∀t,t′∈0,T

轨迹上的所有点映射到同一个值 ------轨迹的起点 x 0 \mathbf{x}0 x0(或等价地, x ϵ \mathbf{x}\epsilon xϵ 以避免数值问题)。

边界条件

f θ ( x ϵ , ϵ ) = x ϵ f_\theta(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon fθ(xϵ,ϵ)=xϵ

其中 ϵ > 0 \epsilon > 0 ϵ>0 是一个很小的常数(如 ϵ = 0.002 \epsilon = 0.002 ϵ=0.002),用于避免 t = 0 t = 0 t=0 处的数值不稳定性。

3.2 自洽性(Self-Consistency)

一致性模型的核心约束是自洽性

f θ ( x t , t ) = f θ ( x t + δ , t + δ ) , ∀ δ > 0 f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t+\delta}, t + \delta), \quad \forall \delta > 0 fθ(xt,t)=fθ(xt+δ,t+δ),∀δ>0

这意味着:沿 ODE 轨迹前进任意步,一致性函数的输出不变。

几何直觉:一致性函数将整条 ODE 轨迹"压缩"为一个点(起点)。不同轨迹映射到不同的点,但同一轨迹上的所有点映射到同一个点。

3.3 生成过程

利用一致性函数,生成过程极其简单:

单步生成

x 0 = f θ ( x T , T ) , x T ∼ N ( 0 , I ) \mathbf{x}0 = f\theta(\mathbf{x}_T, T), \quad \mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) x0=fθ(xT,T),xT∼N(0,I)

一次前向传播,从噪声直接生成数据!

多步生成(迭代精炼):

x ϵ = f θ ( x T , T ) \mathbf{x}\epsilon = f\theta(\mathbf{x}_T, T) xϵ=fθ(xT,T)

x t n + 1 = f θ ( x t n , t n ) + t n 2 − ϵ 2 ⋅ z , z ∼ N ( 0 , I ) \mathbf{x}{t{n+1}} = f_\theta(\mathbf{x}_{t_n}, t_n) + \sqrt{t_n^2 - \epsilon^2} \cdot \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) xtn+1=fθ(xtn,tn)+tn2−ϵ2 ⋅z,z∼N(0,I)

多步生成通过在轨迹上添加少量噪声再映射,实现逐步精炼。


4. 训练方法 --- 一致性蒸馏

4.1 一致性蒸馏(Consistency Distillation)

核心思想:利用预训练的扩散模型(教师)生成轨迹上的相邻点对,训练一致性模型(学生)使它们映射到相同值。

训练数据 :对于每个数据点 x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata,通过 PF-ODE 求解器获得相邻点对 ( x t n + 1 , x t n ) (\mathbf{x}{t{n+1}}, \mathbf{x}_{t_n}) (xtn+1,xtn)。

损失函数

L CD ( θ , θ − ) = E d ( f θ ( x t n + 1 , t n + 1 ) ,   f θ − ( x t n , t n ) ) \mathcal{L}_{\text{CD}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_{t_{n+1}}, t_{n+1}), \\, f_{\\theta\^-}(\\mathbf{x}_{t_n}, t_n) \\right) \\right LCD(θ,θ−)=Ed(fθ(xtn+1,tn+1),fθ−(xtn,tn))

其中:

  • f θ f_\theta fθ 是在线网络(正在更新)
  • f θ − f_{\theta^-} fθ− 是目标网络(EMA 更新)
  • d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(⋅,⋅) 是距离度量(如 L 2 L_2 L2、LPIPS)
  • x t n \mathbf{x}{t_n} xtn 由 x t n + 1 \mathbf{x}{t_{n+1}} xtn+1 通过一步 PF-ODE 求解器获得

目标网络更新

θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ

其中 μ \mu μ 是 EMA 衰减率(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)。

4.2 数学正确性

定理 :当 f θ f_\theta fθ 满足自洽性时, L CD = 0 \mathcal{L}_{\text{CD}} = 0 LCD=0。

证明

由自洽性:

f θ ( x t n + 1 , t n + 1 ) = f θ ( x t n , t n ) f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}) = f_\theta(\mathbf{x}_{t_n}, t_n) fθ(xtn+1,tn+1)=fθ(xtn,tn)

当 θ − = θ \theta^- = \theta θ−=θ 时(EMA 已收敛):

d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) = d ( f θ ( x t n + 1 , t n + 1 ) , f θ ( x t n , t n ) ) = 0 d(f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_{\theta^-}(\mathbf{x}{t_n}, t_n)) = d(f\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_\theta(\mathbf{x}_{t_n}, t_n)) = 0 d(fθ(xtn+1,tn+1),fθ−(xtn,tn))=d(fθ(xtn+1,tn+1),fθ(xtn,tn))=0

■ \blacksquare ■

4.3 一步 PF-ODE 求解

为了获得相邻点对,需要一步 PF-ODE 求解。使用 DDIM 求解器:

x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)

其中 D ϕ \mathbf{D}_\phi Dϕ 是预训练的去噪模型(教师), α t , σ t \alpha_t, \sigma_t αt,σt 是噪声调度参数。


5. 完整可运行实现

5.1 一致性模型核心实现

python 复制代码
"""
一致性模型 (Consistency Models) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple, Optional, List
from dataclasses import dataclass


@dataclass
class ConsistencyConfig:
    """一致性模型配置"""
    data_dim: int = 2
    hidden_dim: int = 256
    time_dim: int = 64
    num_layers: int = 6
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0
    num_timesteps: int = 40  # 时间步离散化数量
    ema_decay: float = 0.9999


def get_sigmas_karras(
    sigma_min: float, sigma_max: float, rho: float, num_steps: int
) -> torch.Tensor:
    """Karras 噪声调度 (Karras et al., 2022)"""
    inv_rho = 1.0 / rho
    steps = torch.arange(num_steps, dtype=torch.float64) / (num_steps - 1)
    sigmas = (sigma_max ** inv_rho + steps * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    return sigmas.float()


def get_alpha_sigma(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """VP 扩散的 alpha 和 sigma 参数"""
    alpha = torch.cos(t * math.pi / 2)
    sigma = torch.sin(t * math.pi / 2)
    return alpha, sigma


class SinusoidalTimeEmbedding(nn.Module):
    """正弦时间嵌入"""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


class ConsistencyModel(nn.Module):
    """一致性模型网络"""

    def __init__(self, config: ConsistencyConfig):
        super().__init__()
        self.config = config

        self.time_embed = SinusoidalTimeEmbedding(config.time_dim)
        self.input_proj = nn.Linear(config.data_dim + config.time_dim, config.hidden_dim)

        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(config.hidden_dim),
                nn.SiLU(),
                nn.Linear(config.hidden_dim, config.hidden_dim),
                nn.LayerNorm(config.hidden_dim),
                nn.SiLU(),
                nn.Linear(config.hidden_dim, config.hidden_dim),
            )
            for _ in range(config.num_layers)
        ])

        self.output_proj = nn.Linear(config.hidden_dim, config.data_dim)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: (B, D) 带噪数据
        t: (B,) 时间步 (0=数据, 1=噪声)
        """
        t_emb = self.time_embed(t)
        h = torch.cat([x, t_emb], dim=-1)
        h = self.input_proj(h)

        for block in self.blocks:
            h = h + block(h)

        return self.output_proj(h)

5.2 一致性蒸馏训练

python 复制代码
class ConsistencyDistillation:
    """一致性蒸馏训练器"""

    def __init__(
        self,
        model: ConsistencyModel,
        teacher_model: nn.Module,
        config: ConsistencyConfig,
        device: torch.device,
    ):
        self.model = model
        self.teacher = teacher_model
        self.config = config
        self.device = device

        # 创建目标网络 (EMA)
        self.target_model = ConsistencyModel(config).to(device)
        self.target_model.load_state_dict(model.state_dict())

        # 冻结教师模型
        for param in self.teacher.parameters():
            param.requires_grad = False

        # 噪声调度
        self.sigmas = get_sigmas_karras(
            config.sigma_min, config.sigma_max, config.rho, config.num_timesteps
        ).to(device)

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=1e-4, weight_decay=0.0
        )

    def add_noise(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """添加噪声: x_t = alpha_t * x + sigma_t * noise"""
        alpha, sigma = get_alpha_sigma(t)
        noise = torch.randn_like(x)
        x_t = alpha.unsqueeze(-1) * x + sigma.unsqueeze(-1) * noise
        return x_t, noise

    def one_step_denoise(
        self, x_t: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor
    ) -> torch.Tensor:
        """使用教师模型进行一步去噪 (DDIM 求解器)"""
        with torch.no_grad():
            # 教师模型预测去噪结果
            x_denoised = self.teacher(x_t, t)

            # DDIM 一步更新
            alpha_t, sigma_t = get_alpha_sigma(t)
            alpha_prev, sigma_prev = get_alpha_sigma(t_prev)

            # x_{t_prev} = alpha_{t_prev} * x_denoised + sigma_{t_prev} * (x_t - alpha_t * x_denoised) / sigma_t
            x_prev = (
                alpha_prev.unsqueeze(-1) * x_denoised
                + sigma_prev.unsqueeze(-1) * (x_t - alpha_t.unsqueeze(-1) * x_denoised) / sigma_t.unsqueeze(-1)
            )

        return x_prev

    def compute_loss(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """计算一致性蒸馏损失"""
        B = x.shape[0]

        # 随机采样时间步对 (t_{n+1}, t_n)
        # 使用 Karras 调度的离散时间步
        n = torch.randint(0, self.config.num_timesteps - 1, (B,), device=self.device)
        t = self.sigmas[n]  # t_{n+1}
        t_prev = self.sigmas[n + 1]  # t_n (更接近数据)

        # 添加噪声
        x_t, noise = self.add_noise(x, t)

        # 教师模型一步去噪
        x_prev = self.one_step_denoise(x_t, t, t_prev)

        # 在线网络预测
        pred_online = self.model(x_t, t)

        # 目标网络预测
        with torch.no_grad():
            pred_target = self.target_model(x_prev, t_prev)

        # 一致性损失 (L2)
        loss = F.mse_loss(pred_online, pred_target)

        metrics = {
            "loss": loss.item(),
            "t_mean": t.mean().item(),
            "t_prev_mean": t_prev.mean().item(),
        }

        return loss, metrics

    def train_step(self, x: torch.Tensor) -> dict:
        """执行一步训练"""
        self.model.train()
        self.optimizer.zero_grad()

        loss, metrics = self.compute_loss(x)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # EMA 更新目标网络
        with torch.no_grad():
            for param, target_param in zip(
                self.model.parameters(), self.target_model.parameters()
            ):
                target_param.data.mul_(self.config.ema_decay).add_(
                    param.data, alpha=1 - self.config.ema_decay
                )

        return metrics

    @torch.no_grad()
    def sample(
        self, num_samples: int, num_steps: int = 1
    ) -> torch.Tensor:
        """生成样本"""
        self.model.eval()

        # 从噪声开始
        x = torch.randn(num_samples, self.config.data_dim, device=self.device)

        if num_steps == 1:
            # 单步生成
            t = torch.ones(num_samples, device=self.device)
            x = self.model(x, t)
        else:
            # 多步生成
            timesteps = get_sigmas_karras(
                self.config.sigma_min, self.config.sigma_max,
                self.config.rho, num_steps
            ).to(self.device)

            for i in range(len(timesteps) - 1):
                t = timesteps[i].expand(num_samples)
                x = self.model(x, t)

                # 添加少量噪声 (用于迭代精炼)
                if i < len(timesteps) - 2:
                    noise = torch.randn_like(x)
                    t_next = timesteps[i + 1]
                    _, sigma = get_alpha_sigma(t_next)
                    x = x + sigma * noise * 0.5

        return x

5.3 实验代码

python 复制代码
def experiment_consistency_2d():
    """在 2D 双月数据上训练一致性模型"""
    from sklearn.datasets import make_moons

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 生成数据
    data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data = torch.tensor(data, dtype=torch.float32).to(device)

    config = ConsistencyConfig(data_dim=2, hidden_dim=256, num_layers=6)

    # 教师模型 (预训练的去噪模型)
    teacher = ConsistencyModel(config).to(device)

    # 学生模型 (一致性模型)
    student = ConsistencyModel(config).to(device)

    # 初始化教师模型 (模拟预训练)
    # 实际应用中,这里加载预训练的扩散模型
    teacher.load_state_dict(student.state_dict())

    trainer = ConsistencyDistillation(student, teacher, config, device)

    # 训练
    print("一致性蒸馏训练...")
    batch_size = 256
    losses = []

    for step in range(5000):
        idx = torch.randint(0, data.shape[0], (batch_size,))
        x_batch = data[idx]

        metrics = trainer.train_step(x_batch)
        losses.append(metrics["loss"])

        if (step + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-100:])
            print(f"Step {step+1} | Loss: {avg_loss:.6f}")

    # 生成样本
    print("\n生成样本...")
    samples_1step = trainer.sample(num_samples=1000, num_steps=1)
    samples_5step = trainer.sample(num_samples=1000, num_steps=5)

    return trainer, samples_1step.cpu(), samples_5step.cpu()

6. 一致性模型的理论性质

6.1 表达能力定理

定理(Song et al., 2023):一致性模型的表达能力足够强大,可以精确表示任意 PF-ODE 轨迹。

证明思路 :对于任意 PF-ODE 轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,定义 f ∗ ( x t , t ) = x 0 f^*(\mathbf{x}t, t) = \mathbf{x}0 f∗(xt,t)=x0。这是一个合法的一致性函数,且满足边界条件 f ∗ ( x ϵ , ϵ ) = x ϵ f^*(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon f∗(xϵ,ϵ)=xϵ。

当神经网络 f θ f_\theta fθ 的容量足够大时,它可以任意逼近 f ∗ f^* f∗。 ■ \blacksquare ■

6.2 单步生成的误差分析

定理:单步生成的误差上界为:

E ∥ x 0 − f θ ( x T , T ) ∥ 2 ≤ L CD ( θ , θ − ) + O ( Δ t ) \mathbb{E}\left\\\|\\mathbf{x}_0 - f_\\theta(\\mathbf{x}_T, T)\\\|\^2\\right \leq \mathcal{L}_{\text{CD}}(\theta, \theta^-) + O(\Delta t) E∥x0−fθ(xT,T)∥2≤LCD(θ,θ−)+O(Δt)

其中 Δ t \Delta t Δt 是时间离散化的步长。

直觉 :一致性蒸馏损失越小,单步生成的质量越高。时间步离散化越细( Δ t \Delta t Δt 越小),误差越小。

6.3 与扩散模型的关系

特性 扩散模型 一致性模型
生成步数 10-1000 步 1-5 步
训练目标 去噪 自洽性
采样方式 迭代去噪 直接映射
理论基础 SDE/PF-ODE PF-ODE 轨迹
质量-速度权衡 高质量但慢 快但略低质量

7. 一致性模型数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性模型 (Consistency Models) 数学总结                                ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 概率流 ODE:                                                                         ║
║     dx_t/dt = f(x_t, t) - ½g(t)² ∇log p_t(x_t)                                       ║
║     轨迹: x_0 → x_T (数据 → 噪声)                                                     ║
║                                                                                        ║
║  2. 一致性函数定义:                                                                     ║
║     f_θ(x_t, t) = f_θ(x_{t'}, t')    ∀t, t' ∈ [0, T]  (同一轨迹→同一输出)            ║
║     边界条件: f_θ(x_ε, ε) = x_ε                                                       ║
║                                                                                        ║
║  3. 单步生成:                                                                           ║
║     x_0 = f_θ(x_T, T),    x_T ~ N(0, I)                                               ║
║     一次前向传播, 从噪声直接生成数据                                                    ║
║                                                                                        ║
║  4. 一致性蒸馏损失:                                                                     ║
║     L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ]                   ║
║     θ⁻ = EMA(θ),  x_{t_n} 由一步 DDIM 求解器获得                                       ║
║                                                                                        ║
║  5. 正确性:                                                                             ║
║     f_θ 满足自洽性 ⟹ L_CD = 0                                                          ║
║                                                                                        ║
║  6. 误差上界:                                                                           ║
║     E[‖x₀ - f_θ(x_T, T)‖²] ≤ L_CD + O(Δt)                                            ║
║                                                                                        ║
║  7. VP 扩散参数:                                                                        ║
║     α_t = cos(πt/2),  σ_t = sin(πt/2)                                                 ║
║     x_t = α_t·x₀ + σ_t·ε,  ε ~ N(0, I)                                               ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第二篇:一致性蒸馏 --- 从预训练扩散模型学习

1. 引言

一致性蒸馏(Consistency Distillation)是训练一致性模型的主要方法,它利用预训练的扩散模型作为教师,通过蒸馏的方式训练学生模型。本篇深入分析蒸馏过程的数学细节和实现技巧。


2. 蒸馏过程的数学分析

2.1 教师-学生框架

教师模型 D ϕ \mathbf{D}_\phi Dϕ:预训练的去噪模型,满足:

D ϕ ( x t , t ) ≈ E x 0 ∣ x t \mathbf{D}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ(xt,t)≈Ex0∣xt

学生模型 f θ f_\theta fθ:一致性模型,满足自洽性。

蒸馏目标:利用教师模型生成轨迹上的相邻点对,训练学生模型使它们映射到相同值。

2.2 DDIM 求解器

DDIM(Denoising Diffusion Implicit Models, Song et al., 2020)是一种确定性的 PF-ODE 求解器:

x t n = α t n ( x t n + 1 − σ t n + 1 ϵ ^ α t n + 1 ) ⏟ x ^ 0 + σ t n ϵ ^ \mathbf{x}{t{n}} = \alpha_{t_n} \underbrace{\left(\frac{\mathbf{x}{t{n+1}} - \sigma_{t_{n+1}} \hat{\mathbf{\epsilon}}}{\alpha_{t_{n+1}}}\right)}_{\hat{\mathbf{x}}0} + \sigma{t_n} \hat{\mathbf{\epsilon}} xtn=αtnx^0 (αtn+1xtn+1−σtn+1ϵ^)+σtnϵ^

其中 ϵ ^ = ( x t n + 1 − α t n + 1 x ^ 0 ) / σ t n + 1 \hat{\mathbf{\epsilon}} = (\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \hat{\mathbf{x}}0) / \sigma{t_{n+1}} ϵ^=(xtn+1−αtn+1x^0)/σtn+1 是预测的噪声。

等价形式(使用去噪函数):

x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)

2.3 时间步离散化

Karras 调度(Karras et al., 2022):

t i = ( t max ⁡ 1 / ρ + i N − 1 ( t min ⁡ 1 / ρ − t max ⁡ 1 / ρ ) ) ρ t_i = \left(t_{\max}^{1/\rho} + \frac{i}{N-1} (t_{\min}^{1/\rho} - t_{\max}^{1/\rho})\right)^\rho ti=(tmax1/ρ+N−1i(tmin1/ρ−tmax1/ρ))ρ

其中 ρ = 7 \rho = 7 ρ=7 控制时间步的分布(更多步集中在噪声端)。

直觉:噪声端的 ODE 曲率更大,需要更细的时间步离散化。


3. 训练技巧

3.1 EMA 更新

目标网络使用 EMA(Exponential Moving Average)更新:

θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ

为什么需要 EMA:如果直接使用在线网络作为目标,训练会不稳定------两个相同的网络相互"追逐",导致发散。EMA 提供了稳定的目标。

3.2 距离度量

L2 距离

d ( x , y ) = ∥ x − y ∥ 2 2 d(\mathbf{x}, \mathbf{y}) = \|\mathbf{x} - \mathbf{y}\|_2^2 d(x,y)=∥x−y∥22

LPIPS 距离(用于图像):

d LPIPS ( x , y ) = ∑ l ∥ feat l ( x ) − feat l ( y ) ∥ 2 2 d_{\text{LPIPS}}(\mathbf{x}, \mathbf{y}) = \sum_l \|\text{feat}_l(\mathbf{x}) - \text{feat}_l(\mathbf{y})\|_2^2 dLPIPS(x,y)=l∑∥featl(x)−featl(y)∥22

LPIPS 使用预训练网络的特征距离,更符合人类感知。

3.3 梯度裁剪

一致性蒸馏的梯度可能很大,需要梯度裁剪:

g ← g ⋅ min ⁡ ( 1 , c ∥ g ∥ ) \mathbf{g} \leftarrow \mathbf{g} \cdot \min\left(1, \frac{c}{\|\mathbf{g}\|}\right) g←g⋅min(1,∥g∥c)

其中 c c c 是裁剪阈值(如 c = 1.0 c = 1.0 c=1.0)。


4. 一致性蒸馏的收敛性

4.1 收敛定理

定理(非正式):在适当条件下,一致性蒸馏收敛到教师模型的 PF-ODE 轨迹。

条件

  1. 教师模型足够好( D ϕ ≈ E x 0 ∣ x t \mathbf{D}_\phi \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ≈Ex0∣xt
  2. 时间步离散化足够细( Δ t → 0 \Delta t \to 0 Δt→0)
  3. 网络容量足够大
  4. EMA 衰减率 μ \mu μ 适当(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)

4.2 蒸馏误差的分解

总蒸馏误差可分解为:

Total Error = Approximation Error ⏟ 网络容量不足 + Discretization Error ⏟ 时间步离散化 + Optimization Error ⏟ 训练不充分 \text{Total Error} = \underbrace{\text{Approximation Error}}{\text{网络容量不足}} + \underbrace{\text{Discretization Error}}{\text{时间步离散化}} + \underbrace{\text{Optimization Error}}_{\text{训练不充分}} Total Error=网络容量不足 Approximation Error+时间步离散化 Discretization Error+训练不充分 Optimization Error

  • 逼近误差 : O ( 1 / n ) O(1/\sqrt{n}) O(1/n ), n n n 是网络参数量
  • 离散化误差 : O ( Δ t ) O(\Delta t) O(Δt), Δ t \Delta t Δt 是时间步间隔
  • 优化误差:随训练步数减小

5. 一致性蒸馏数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性蒸馏 (Consistency Distillation) 数学总结                         ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. DDIM 求解器:                                                                        ║
║     x_{t_n} = α_{t_n}·D_φ(x_{t_{n+1}}, t_{n+1})                                      ║
║             + σ_{t_n}·(x_{t_{n+1}} - α_{t_{n+1}}·D_φ) / σ_{t_{n+1}}                  ║
║                                                                                        ║
║  2. 蒸馏损失:                                                                           ║
║     L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ]                   ║
║     x_{t_n} 由一步 DDIM 获得, θ⁻ = EMA(θ)                                             ║
║                                                                                        ║
║  3. EMA 更新:                                                                           ║
║     θ⁻ ← μ·θ⁻ + (1-μ)·θ,    μ = 0.9999                                               ║
║     提供稳定的训练目标                                                                  ║
║                                                                                        ║
║  4. 时间步调度 (Karras):                                                                ║
║     t_i = (t_max^{1/ρ} + i/(N-1)·(t_min^{1/ρ} - t_max^{1/ρ}))^ρ                      ║
║     ρ = 7, 噪声端步长更细                                                              ║
║                                                                                        ║
║  5. 误差分解:                                                                           ║
║     Total = Approximation + Discretization + Optimization                              ║
║     = O(1/√n) + O(Δt) + O(1/√T)                                                       ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第三篇:一致性模型的前沿发展

1. 引言

一致性模型自 2023 年提出以来,已经发展出多个重要变体和应用方向。


2. 一致性训练(Consistency Training)

2.1 无需教师的训练

一致性蒸馏需要预训练的教师模型。一致性训练(Consistency Training) 直接从数据训练,无需教师。

核心思想:利用 SDE 的随机性生成相邻点对。

前向 SDE

d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt

对于同一个 x 0 \mathbf{x}_0 x0,两次独立的 SDE 采样得到 x t \mathbf{x}t xt 和 x t ′ \mathbf{x}{t'} xt′( t ≈ t ′ t \approx t' t≈t′),它们在同一条轨迹附近。

一致性训练损失

L CT ( θ , θ − ) = E d ( f θ ( x t , t ) ,   f θ − ( x t ′ , t ′ ) ) \mathcal{L}_{\text{CT}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_t, t), \\, f_{\\theta\^-}(\\mathbf{x}_{t'}, t') \\right) \\right LCT(θ,θ−)=Ed(fθ(xt,t),fθ−(xt′,t′))

2.2 数学挑战

一致性训练的理论保证弱于一致性蒸馏,因为 SDE 轨迹不完全确定------两次独立采样的点可能不在同一条 PF-ODE 轨迹上。

缓解策略

  1. 使用较小的时间步间隔 ∣ t − t ′ ∣ |t - t'| ∣t−t′∣
  2. 使用较大的 EMA 衰减率
  3. 使用更稳定的距离度量

3. 进阶变体

3.1 渐进蒸馏(Progressive Distillation)

思想:逐步减少采样步数,每轮将步数减半。

流程

  1. 训练 N N N 步的扩散模型
  2. 蒸馏为 N / 2 N/2 N/2 步
  3. 蒸馏为 N / 4 N/4 N/4 步
  4. ...直到 1 步

优势:每轮蒸馏的难度更低,训练更稳定。

3.2 一致性模型 + Latent Diffusion

将一致性模型应用于 Latent Diffusion(如 Stable Diffusion):

z 0 = f θ ( z T , T ) , x 0 = Decoder ( z 0 ) \mathbf{z}0 = f\theta(\mathbf{z}_T, T), \quad \mathbf{x}_0 = \text{Decoder}(\mathbf{z}_0) z0=fθ(zT,T),x0=Decoder(z0)

优势:在低维潜在空间中操作,计算效率更高。

3.3 一致性模型 + Classifier-Free Guidance

将 CFG 应用于一致性模型:

f θ cfg ( x t , t , c ) = ( 1 + w ) f θ ( x t , t , c ) − w f θ ( x t , t , ∅ ) f_\theta^{\text{cfg}}(\mathbf{x}t, t, c) = (1 + w) f\theta(\mathbf{x}t, t, c) - w f\theta(\mathbf{x}_t, t, \varnothing) fθcfg(xt,t,c)=(1+w)fθ(xt,t,c)−wfθ(xt,t,∅)

其中 w w w 是引导强度, c c c 是条件(如文本), ∅ \varnothing ∅ 是空条件。


4. 一致性模型与其他方法的对比

方法 采样步数 训练方式 质量 速度
DDPM 1000 去噪 最高 最慢
DDIM 10-50 去噪 中等
Flow Matching 10-50 速度场 中等
一致性蒸馏 1-5 蒸馏 中高
一致性训练 1-5 直接训练 最快

5. 前沿研究方向

5.1 音频一致性模型

将一致性模型应用于音频生成(AudioLDM、MusicGen),实现单步音频合成。

5.2 视频一致性模型

将一致性模型应用于视频生成,利用时间维度的一致性。

5.3 3D 一致性模型

将一致性模型应用于 3D 生成(NeRF、3D Gaussian Splatting),实现单步 3D 重建。

5.4 一致性模型的理论深化

  1. 最优传输视角:将一致性模型与最优传输理论联系
  2. 信息论分析:分析一致性模型的信息压缩率
  3. 收敛速率:改进一致性蒸馏的收敛速率分析

6. 前沿发展数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性模型前沿发展 数学总结                                             ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 一致性训练 (无教师):                                                                ║
║     L_CT = E[ d(f_θ(x_t, t), f_{θ⁻}(x_{t'}, t')) ]                                   ║
║     x_t, x_{t'} 由同一 x₀ 的两次独立 SDE 采样获得                                      ║
║                                                                                        ║
║  2. 渐进蒸馏:                                                                           ║
║     N 步 → N/2 步 → N/4 步 → ... → 1 步                                               ║
║     每轮蒸馏难度更低, 训练更稳定                                                        ║
║                                                                                        ║
║  3. CFG 引导:                                                                           ║
║     f_θ^{cfg}(x_t, t, c) = (1+w)·f_θ(x_t, t, c) - w·f_θ(x_t, t, ∅)                  ║
║                                                                                        ║
║  4. Latent 一致性:                                                                      ║
║     z₀ = f_θ(z_T, T),  x₀ = Decoder(z₀)                                              ║
║     在低维潜在空间操作, 计算效率更高                                                     ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

参考文献

核心论文

  1. Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. ICML 2023.
  2. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021.
  3. Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.

扩散模型基础

  1. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
  2. Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS 2022.

加速采样

  1. Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., & Zhu, J. (2022). DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. NeurIPS 2022.
  2. Salimans, T., & Ho, J. (2022). Progressive Distillation for Fast Sampling of Diffusion Models. ICLR 2022.

应用

  1. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022.
  2. Saharia, C., Chan, W., et al. (2022). Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. NeurIPS 2022.

一致性模型深度解析

目录

  • [第一篇:一致性模型 --- 单步生成的数学基础](#第一篇:一致性模型 — 单步生成的数学基础)
  • [第二篇:一致性蒸馏 --- 从预训练扩散模型学习](#第二篇:一致性蒸馏 — 从预训练扩散模型学习)
  • 第三篇:一致性模型的前沿发展
  • 参考文献

第一篇:一致性模型 --- 单步生成的数学基础

1. 引言

扩散模型(Diffusion Models)在图像生成领域取得了巨大成功,但其采样过程需要数百步迭代去噪,计算代价高昂。即使是加速采样方法(如 DDIM、DPM-Solver),通常也需要 10-50 步才能获得高质量样本。

一致性模型(Consistency Models, Song et al., 2023) 提出了一种全新的生成范式:学习一个函数,将 ODE 轨迹上的任意点直接映射到轨迹起点(数据)。这一设计实现了:

  1. 单步生成:一次前向传播即可生成高质量样本
  2. 多步精炼:支持多步采样以换取更高质量
  3. 数学优雅:基于概率流 ODE 的几何性质,定义简洁

2. 理论基础 --- 概率流 ODE

2.1 扩散过程的 ODE 视角

扩散模型有两种等价的连续时间表述:

随机微分方程(SDE)(前向):

d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt

概率流 ODE(PF-ODE)(前向):

d x t d t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log ⁡ p t ( x t ) \frac{d\mathbf{x}_t}{dt} = \mathbf{f}(\mathbf{x}_t, t) - \frac{1}{2} g(t)^2 \nabla \log p_t(\mathbf{x}_t) dtdxt=f(xt,t)−21g(t)2∇logpt(xt)

其中 ∇ log ⁡ p t ( x t ) \nabla \log p_t(\mathbf{x}_t) ∇logpt(xt) 是分数函数(score function)

关键性质 :PF-ODE 与 SDE 具有相同的边际分布 p t ( x ) p_t(\mathbf{x}) pt(x),但 PF-ODE 是确定性的------给定初始条件 x 0 \mathbf{x}_0 x0,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 唯一确定。

2.2 方差保持(VP)扩散的 PF-ODE

对于 VP 扩散(DDPM):

f ( x t , t ) = − 1 2 β ( t ) x t , g ( t ) = β ( t ) \mathbf{f}(\mathbf{x}_t, t) = -\frac{1}{2} \beta(t) \mathbf{x}_t, \quad g(t) = \sqrt{\beta(t)} f(xt,t)=−21β(t)xt,g(t)=β(t)

PF-ODE 为:

d x t d t = − 1 2 β ( t ) x t − 1 2 β ( t ) ∇ log ⁡ p t ( x t ) \frac{d\mathbf{x}_t}{dt} = -\frac{1}{2} \beta(t) \mathbf{x}_t - \frac{1}{2} \beta(t) \nabla \log p_t(\mathbf{x}_t) dtdxt=−21β(t)xt−21β(t)∇logpt(xt)

定义去噪函数 D ( x t , t ) \mathbf{D}(\mathbf{x}_t, t) D(xt,t):

D ( x t , t ) = x t + σ t 2 ∇ log ⁡ p t ( x t ) α t \mathbf{D}(\mathbf{x}_t, t) = \frac{\mathbf{x}_t + \sigma_t^2 \nabla \log p_t(\mathbf{x}_t)}{\alpha_t} D(xt,t)=αtxt+σt2∇logpt(xt)

其中 α t = exp ⁡ ( − 1 2 ∫ 0 t β ( s ) d s ) \alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right) αt=exp(−21∫0tβ(s)ds), σ t 2 = 1 − α t 2 \sigma_t^2 = 1 - \alpha_t^2 σt2=1−αt2。

则 PF-ODE 可改写为:

d x t d t = − α ˙ t α t x t + α ˙ t σ t α t D ( x t , t ) \frac{d\mathbf{x}_t}{dt} = -\frac{\dot{\alpha}_t}{\alpha_t} \mathbf{x}_t + \frac{\dot{\alpha}_t \sigma_t}{\alpha_t} \mathbf{D}(\mathbf{x}_t, t) dtdxt=−αtα˙txt+αtα˙tσtD(xt,t)

2.3 PF-ODE 轨迹的性质

定理 :对于 VP 扩散的 PF-ODE,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 具有以下性质:

  1. 唯一性 :给定 x 0 \mathbf{x}_0 x0,轨迹唯一确定
  2. 连续性 : x t \mathbf{x}_t xt 关于 t t t 连续
  3. 边界行为 : x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata, x T ≈ N ( 0 , I ) \mathbf{x}_T \approx \mathcal{N}(0, \mathbf{I}) xT≈N(0,I)

直觉 :PF-ODE 将数据分布 p data p_{\text{data}} pdata 连续地"变形"为高斯噪声 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)。生成过程就是从噪声沿轨迹"走回去"。


3. 一致性模型的核心定义

3.1 一致性函数

定义 (一致性函数 f θ f_\theta fθ):对于 PF-ODE 的任意轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,一致性函数满足:

f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ t , t ′ ∈ 0 , T f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t'}, t'), \quad \forall t, t' \in 0, T fθ(xt,t)=fθ(xt′,t′),∀t,t′∈0,T

轨迹上的所有点映射到同一个值 ------轨迹的起点 x 0 \mathbf{x}0 x0(或等价地, x ϵ \mathbf{x}\epsilon xϵ 以避免数值问题)。

边界条件

f θ ( x ϵ , ϵ ) = x ϵ f_\theta(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon fθ(xϵ,ϵ)=xϵ

其中 ϵ > 0 \epsilon > 0 ϵ>0 是一个很小的常数(如 ϵ = 0.002 \epsilon = 0.002 ϵ=0.002),用于避免 t = 0 t = 0 t=0 处的数值不稳定性。

3.2 自洽性(Self-Consistency)

一致性模型的核心约束是自洽性

f θ ( x t , t ) = f θ ( x t + δ , t + δ ) , ∀ δ > 0 f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t+\delta}, t + \delta), \quad \forall \delta > 0 fθ(xt,t)=fθ(xt+δ,t+δ),∀δ>0

这意味着:沿 ODE 轨迹前进任意步,一致性函数的输出不变。

几何直觉:一致性函数将整条 ODE 轨迹"压缩"为一个点(起点)。不同轨迹映射到不同的点,但同一轨迹上的所有点映射到同一个点。

3.3 生成过程

利用一致性函数,生成过程极其简单:

单步生成

x 0 = f θ ( x T , T ) , x T ∼ N ( 0 , I ) \mathbf{x}0 = f\theta(\mathbf{x}_T, T), \quad \mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) x0=fθ(xT,T),xT∼N(0,I)

一次前向传播,从噪声直接生成数据!

多步生成(迭代精炼):

x ϵ = f θ ( x T , T ) \mathbf{x}\epsilon = f\theta(\mathbf{x}_T, T) xϵ=fθ(xT,T)

x t n + 1 = f θ ( x t n , t n ) + t n 2 − ϵ 2 ⋅ z , z ∼ N ( 0 , I ) \mathbf{x}{t{n+1}} = f_\theta(\mathbf{x}_{t_n}, t_n) + \sqrt{t_n^2 - \epsilon^2} \cdot \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) xtn+1=fθ(xtn,tn)+tn2−ϵ2 ⋅z,z∼N(0,I)

多步生成通过在轨迹上添加少量噪声再映射,实现逐步精炼。


4. 训练方法 --- 一致性蒸馏

4.1 一致性蒸馏(Consistency Distillation)

核心思想:利用预训练的扩散模型(教师)生成轨迹上的相邻点对,训练一致性模型(学生)使它们映射到相同值。

训练数据 :对于每个数据点 x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata,通过 PF-ODE 求解器获得相邻点对 ( x t n + 1 , x t n ) (\mathbf{x}{t{n+1}}, \mathbf{x}_{t_n}) (xtn+1,xtn)。

损失函数

L CD ( θ , θ − ) = E d ( f θ ( x t n + 1 , t n + 1 ) ,   f θ − ( x t n , t n ) ) \mathcal{L}_{\text{CD}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_{t_{n+1}}, t_{n+1}), \\, f_{\\theta\^-}(\\mathbf{x}_{t_n}, t_n) \\right) \\right LCD(θ,θ−)=Ed(fθ(xtn+1,tn+1),fθ−(xtn,tn))

其中:

  • f θ f_\theta fθ 是在线网络(正在更新)
  • f θ − f_{\theta^-} fθ− 是目标网络(EMA 更新)
  • d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(⋅,⋅) 是距离度量(如 L 2 L_2 L2、LPIPS)
  • x t n \mathbf{x}{t_n} xtn 由 x t n + 1 \mathbf{x}{t_{n+1}} xtn+1 通过一步 PF-ODE 求解器获得

目标网络更新

θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ

其中 μ \mu μ 是 EMA 衰减率(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)。

4.2 数学正确性

定理 :当 f θ f_\theta fθ 满足自洽性时, L CD = 0 \mathcal{L}_{\text{CD}} = 0 LCD=0。

证明

由自洽性:

f θ ( x t n + 1 , t n + 1 ) = f θ ( x t n , t n ) f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}) = f_\theta(\mathbf{x}_{t_n}, t_n) fθ(xtn+1,tn+1)=fθ(xtn,tn)

当 θ − = θ \theta^- = \theta θ−=θ 时(EMA 已收敛):

d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) = d ( f θ ( x t n + 1 , t n + 1 ) , f θ ( x t n , t n ) ) = 0 d(f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_{\theta^-}(\mathbf{x}{t_n}, t_n)) = d(f\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_\theta(\mathbf{x}_{t_n}, t_n)) = 0 d(fθ(xtn+1,tn+1),fθ−(xtn,tn))=d(fθ(xtn+1,tn+1),fθ(xtn,tn))=0

■ \blacksquare ■

4.3 一步 PF-ODE 求解

为了获得相邻点对,需要一步 PF-ODE 求解。使用 DDIM 求解器:

x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)

其中 D ϕ \mathbf{D}_\phi Dϕ 是预训练的去噪模型(教师), α t , σ t \alpha_t, \sigma_t αt,σt 是噪声调度参数。


5. 完整可运行实现

5.1 一致性模型核心实现

python 复制代码
"""
一致性模型 (Consistency Models) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple, Optional, List
from dataclasses import dataclass


@dataclass
class ConsistencyConfig:
    """一致性模型配置"""
    data_dim: int = 2
    hidden_dim: int = 256
    time_dim: int = 64
    num_layers: int = 6
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0
    num_timesteps: int = 40  # 时间步离散化数量
    ema_decay: float = 0.9999


def get_sigmas_karras(
    sigma_min: float, sigma_max: float, rho: float, num_steps: int
) -> torch.Tensor:
    """Karras 噪声调度 (Karras et al., 2022)"""
    inv_rho = 1.0 / rho
    steps = torch.arange(num_steps, dtype=torch.float64) / (num_steps - 1)
    sigmas = (sigma_max ** inv_rho + steps * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    return sigmas.float()


def get_alpha_sigma(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """VP 扩散的 alpha 和 sigma 参数"""
    alpha = torch.cos(t * math.pi / 2)
    sigma = torch.sin(t * math.pi / 2)
    return alpha, sigma


class SinusoidalTimeEmbedding(nn.Module):
    """正弦时间嵌入"""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


class ConsistencyModel(nn.Module):
    """一致性模型网络"""

    def __init__(self, config: ConsistencyConfig):
        super().__init__()
        self.config = config

        self.time_embed = SinusoidalTimeEmbedding(config.time_dim)
        self.input_proj = nn.Linear(config.data_dim + config.time_dim, config.hidden_dim)

        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(config.hidden_dim),
                nn.SiLU(),
                nn.Linear(config.hidden_dim, config.hidden_dim),
                nn.LayerNorm(config.hidden_dim),
                nn.SiLU(),
                nn.Linear(config.hidden_dim, config.hidden_dim),
            )
            for _ in range(config.num_layers)
        ])

        self.output_proj = nn.Linear(config.hidden_dim, config.data_dim)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: (B, D) 带噪数据
        t: (B,) 时间步 (0=数据, 1=噪声)
        """
        t_emb = self.time_embed(t)
        h = torch.cat([x, t_emb], dim=-1)
        h = self.input_proj(h)

        for block in self.blocks:
            h = h + block(h)

        return self.output_proj(h)

5.2 一致性蒸馏训练

python 复制代码
class ConsistencyDistillation:
    """一致性蒸馏训练器"""

    def __init__(
        self,
        model: ConsistencyModel,
        teacher_model: nn.Module,
        config: ConsistencyConfig,
        device: torch.device,
    ):
        self.model = model
        self.teacher = teacher_model
        self.config = config
        self.device = device

        # 创建目标网络 (EMA)
        self.target_model = ConsistencyModel(config).to(device)
        self.target_model.load_state_dict(model.state_dict())

        # 冻结教师模型
        for param in self.teacher.parameters():
            param.requires_grad = False

        # 噪声调度
        self.sigmas = get_sigmas_karras(
            config.sigma_min, config.sigma_max, config.rho, config.num_timesteps
        ).to(device)

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=1e-4, weight_decay=0.0
        )

    def add_noise(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """添加噪声: x_t = alpha_t * x + sigma_t * noise"""
        alpha, sigma = get_alpha_sigma(t)
        noise = torch.randn_like(x)
        x_t = alpha.unsqueeze(-1) * x + sigma.unsqueeze(-1) * noise
        return x_t, noise

    def one_step_denoise(
        self, x_t: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor
    ) -> torch.Tensor:
        """使用教师模型进行一步去噪 (DDIM 求解器)"""
        with torch.no_grad():
            # 教师模型预测去噪结果
            x_denoised = self.teacher(x_t, t)

            # DDIM 一步更新
            alpha_t, sigma_t = get_alpha_sigma(t)
            alpha_prev, sigma_prev = get_alpha_sigma(t_prev)

            # x_{t_prev} = alpha_{t_prev} * x_denoised + sigma_{t_prev} * (x_t - alpha_t * x_denoised) / sigma_t
            x_prev = (
                alpha_prev.unsqueeze(-1) * x_denoised
                + sigma_prev.unsqueeze(-1) * (x_t - alpha_t.unsqueeze(-1) * x_denoised) / sigma_t.unsqueeze(-1)
            )

        return x_prev

    def compute_loss(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """计算一致性蒸馏损失"""
        B = x.shape[0]

        # 随机采样时间步对 (t_{n+1}, t_n)
        # 使用 Karras 调度的离散时间步
        n = torch.randint(0, self.config.num_timesteps - 1, (B,), device=self.device)
        t = self.sigmas[n]  # t_{n+1}
        t_prev = self.sigmas[n + 1]  # t_n (更接近数据)

        # 添加噪声
        x_t, noise = self.add_noise(x, t)

        # 教师模型一步去噪
        x_prev = self.one_step_denoise(x_t, t, t_prev)

        # 在线网络预测
        pred_online = self.model(x_t, t)

        # 目标网络预测
        with torch.no_grad():
            pred_target = self.target_model(x_prev, t_prev)

        # 一致性损失 (L2)
        loss = F.mse_loss(pred_online, pred_target)

        metrics = {
            "loss": loss.item(),
            "t_mean": t.mean().item(),
            "t_prev_mean": t_prev.mean().item(),
        }

        return loss, metrics

    def train_step(self, x: torch.Tensor) -> dict:
        """执行一步训练"""
        self.model.train()
        self.optimizer.zero_grad()

        loss, metrics = self.compute_loss(x)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        # EMA 更新目标网络
        with torch.no_grad():
            for param, target_param in zip(
                self.model.parameters(), self.target_model.parameters()
            ):
                target_param.data.mul_(self.config.ema_decay).add_(
                    param.data, alpha=1 - self.config.ema_decay
                )

        return metrics

    @torch.no_grad()
    def sample(
        self, num_samples: int, num_steps: int = 1
    ) -> torch.Tensor:
        """生成样本"""
        self.model.eval()

        # 从噪声开始
        x = torch.randn(num_samples, self.config.data_dim, device=self.device)

        if num_steps == 1:
            # 单步生成
            t = torch.ones(num_samples, device=self.device)
            x = self.model(x, t)
        else:
            # 多步生成
            timesteps = get_sigmas_karras(
                self.config.sigma_min, self.config.sigma_max,
                self.config.rho, num_steps
            ).to(self.device)

            for i in range(len(timesteps) - 1):
                t = timesteps[i].expand(num_samples)
                x = self.model(x, t)

                # 添加少量噪声 (用于迭代精炼)
                if i < len(timesteps) - 2:
                    noise = torch.randn_like(x)
                    t_next = timesteps[i + 1]
                    _, sigma = get_alpha_sigma(t_next)
                    x = x + sigma * noise * 0.5

        return x

5.3 实验代码

python 复制代码
def experiment_consistency_2d():
    """在 2D 双月数据上训练一致性模型"""
    from sklearn.datasets import make_moons

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 生成数据
    data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data = torch.tensor(data, dtype=torch.float32).to(device)

    config = ConsistencyConfig(data_dim=2, hidden_dim=256, num_layers=6)

    # 教师模型 (预训练的去噪模型)
    teacher = ConsistencyModel(config).to(device)

    # 学生模型 (一致性模型)
    student = ConsistencyModel(config).to(device)

    # 初始化教师模型 (模拟预训练)
    # 实际应用中,这里加载预训练的扩散模型
    teacher.load_state_dict(student.state_dict())

    trainer = ConsistencyDistillation(student, teacher, config, device)

    # 训练
    print("一致性蒸馏训练...")
    batch_size = 256
    losses = []

    for step in range(5000):
        idx = torch.randint(0, data.shape[0], (batch_size,))
        x_batch = data[idx]

        metrics = trainer.train_step(x_batch)
        losses.append(metrics["loss"])

        if (step + 1) % 1000 == 0:
            avg_loss = np.mean(losses[-100:])
            print(f"Step {step+1} | Loss: {avg_loss:.6f}")

    # 生成样本
    print("\n生成样本...")
    samples_1step = trainer.sample(num_samples=1000, num_steps=1)
    samples_5step = trainer.sample(num_samples=1000, num_steps=5)

    return trainer, samples_1step.cpu(), samples_5step.cpu()

6. 一致性模型的理论性质

6.1 表达能力定理

定理(Song et al., 2023):一致性模型的表达能力足够强大,可以精确表示任意 PF-ODE 轨迹。

证明思路 :对于任意 PF-ODE 轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,定义 f ∗ ( x t , t ) = x 0 f^*(\mathbf{x}t, t) = \mathbf{x}0 f∗(xt,t)=x0。这是一个合法的一致性函数,且满足边界条件 f ∗ ( x ϵ , ϵ ) = x ϵ f^*(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon f∗(xϵ,ϵ)=xϵ。

当神经网络 f θ f_\theta fθ 的容量足够大时,它可以任意逼近 f ∗ f^* f∗。 ■ \blacksquare ■

6.2 单步生成的误差分析

定理:单步生成的误差上界为:

E ∥ x 0 − f θ ( x T , T ) ∥ 2 ≤ L CD ( θ , θ − ) + O ( Δ t ) \mathbb{E}\left\\\|\\mathbf{x}_0 - f_\\theta(\\mathbf{x}_T, T)\\\|\^2\\right \leq \mathcal{L}_{\text{CD}}(\theta, \theta^-) + O(\Delta t) E∥x0−fθ(xT,T)∥2≤LCD(θ,θ−)+O(Δt)

其中 Δ t \Delta t Δt 是时间离散化的步长。

直觉 :一致性蒸馏损失越小,单步生成的质量越高。时间步离散化越细( Δ t \Delta t Δt 越小),误差越小。

6.3 与扩散模型的关系

特性 扩散模型 一致性模型
生成步数 10-1000 步 1-5 步
训练目标 去噪 自洽性
采样方式 迭代去噪 直接映射
理论基础 SDE/PF-ODE PF-ODE 轨迹
质量-速度权衡 高质量但慢 快但略低质量

7. 一致性模型数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性模型 (Consistency Models) 数学总结                                ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 概率流 ODE:                                                                         ║
║     dx_t/dt = f(x_t, t) - ½g(t)² ∇log p_t(x_t)                                       ║
║     轨迹: x_0 → x_T (数据 → 噪声)                                                     ║
║                                                                                        ║
║  2. 一致性函数定义:                                                                     ║
║     f_θ(x_t, t) = f_θ(x_{t'}, t')    ∀t, t' ∈ [0, T]  (同一轨迹→同一输出)            ║
║     边界条件: f_θ(x_ε, ε) = x_ε                                                       ║
║                                                                                        ║
║  3. 单步生成:                                                                           ║
║     x_0 = f_θ(x_T, T),    x_T ~ N(0, I)                                               ║
║     一次前向传播, 从噪声直接生成数据                                                    ║
║                                                                                        ║
║  4. 一致性蒸馏损失:                                                                     ║
║     L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ]                   ║
║     θ⁻ = EMA(θ),  x_{t_n} 由一步 DDIM 求解器获得                                       ║
║                                                                                        ║
║  5. 正确性:                                                                             ║
║     f_θ 满足自洽性 ⟹ L_CD = 0                                                          ║
║                                                                                        ║
║  6. 误差上界:                                                                           ║
║     E[‖x₀ - f_θ(x_T, T)‖²] ≤ L_CD + O(Δt)                                            ║
║                                                                                        ║
║  7. VP 扩散参数:                                                                        ║
║     α_t = cos(πt/2),  σ_t = sin(πt/2)                                                 ║
║     x_t = α_t·x₀ + σ_t·ε,  ε ~ N(0, I)                                               ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第二篇:一致性蒸馏 --- 从预训练扩散模型学习

1. 引言

一致性蒸馏(Consistency Distillation)是训练一致性模型的主要方法,它利用预训练的扩散模型作为教师,通过蒸馏的方式训练学生模型。本篇深入分析蒸馏过程的数学细节和实现技巧。


2. 蒸馏过程的数学分析

2.1 教师-学生框架

教师模型 D ϕ \mathbf{D}_\phi Dϕ:预训练的去噪模型,满足:

D ϕ ( x t , t ) ≈ E x 0 ∣ x t \mathbf{D}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ(xt,t)≈Ex0∣xt

学生模型 f θ f_\theta fθ:一致性模型,满足自洽性。

蒸馏目标:利用教师模型生成轨迹上的相邻点对,训练学生模型使它们映射到相同值。

2.2 DDIM 求解器

DDIM(Denoising Diffusion Implicit Models, Song et al., 2020)是一种确定性的 PF-ODE 求解器:

x t n = α t n ( x t n + 1 − σ t n + 1 ϵ ^ α t n + 1 ) ⏟ x ^ 0 + σ t n ϵ ^ \mathbf{x}{t{n}} = \alpha_{t_n} \underbrace{\left(\frac{\mathbf{x}{t{n+1}} - \sigma_{t_{n+1}} \hat{\mathbf{\epsilon}}}{\alpha_{t_{n+1}}}\right)}_{\hat{\mathbf{x}}0} + \sigma{t_n} \hat{\mathbf{\epsilon}} xtn=αtnx^0 (αtn+1xtn+1−σtn+1ϵ^)+σtnϵ^

其中 ϵ ^ = ( x t n + 1 − α t n + 1 x ^ 0 ) / σ t n + 1 \hat{\mathbf{\epsilon}} = (\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \hat{\mathbf{x}}0) / \sigma{t_{n+1}} ϵ^=(xtn+1−αtn+1x^0)/σtn+1 是预测的噪声。

等价形式(使用去噪函数):

x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)

2.3 时间步离散化

Karras 调度(Karras et al., 2022):

t i = ( t max ⁡ 1 / ρ + i N − 1 ( t min ⁡ 1 / ρ − t max ⁡ 1 / ρ ) ) ρ t_i = \left(t_{\max}^{1/\rho} + \frac{i}{N-1} (t_{\min}^{1/\rho} - t_{\max}^{1/\rho})\right)^\rho ti=(tmax1/ρ+N−1i(tmin1/ρ−tmax1/ρ))ρ

其中 ρ = 7 \rho = 7 ρ=7 控制时间步的分布(更多步集中在噪声端)。

直觉:噪声端的 ODE 曲率更大,需要更细的时间步离散化。


3. 训练技巧

3.1 EMA 更新

目标网络使用 EMA(Exponential Moving Average)更新:

θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ

为什么需要 EMA:如果直接使用在线网络作为目标,训练会不稳定------两个相同的网络相互"追逐",导致发散。EMA 提供了稳定的目标。

3.2 距离度量

L2 距离

d ( x , y ) = ∥ x − y ∥ 2 2 d(\mathbf{x}, \mathbf{y}) = \|\mathbf{x} - \mathbf{y}\|_2^2 d(x,y)=∥x−y∥22

LPIPS 距离(用于图像):

d LPIPS ( x , y ) = ∑ l ∥ feat l ( x ) − feat l ( y ) ∥ 2 2 d_{\text{LPIPS}}(\mathbf{x}, \mathbf{y}) = \sum_l \|\text{feat}_l(\mathbf{x}) - \text{feat}_l(\mathbf{y})\|_2^2 dLPIPS(x,y)=l∑∥featl(x)−featl(y)∥22

LPIPS 使用预训练网络的特征距离,更符合人类感知。

3.3 梯度裁剪

一致性蒸馏的梯度可能很大,需要梯度裁剪:

g ← g ⋅ min ⁡ ( 1 , c ∥ g ∥ ) \mathbf{g} \leftarrow \mathbf{g} \cdot \min\left(1, \frac{c}{\|\mathbf{g}\|}\right) g←g⋅min(1,∥g∥c)

其中 c c c 是裁剪阈值(如 c = 1.0 c = 1.0 c=1.0)。


4. 一致性蒸馏的收敛性

4.1 收敛定理

定理(非正式):在适当条件下,一致性蒸馏收敛到教师模型的 PF-ODE 轨迹。

条件

  1. 教师模型足够好( D ϕ ≈ E x 0 ∣ x t \mathbf{D}_\phi \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ≈Ex0∣xt
  2. 时间步离散化足够细( Δ t → 0 \Delta t \to 0 Δt→0)
  3. 网络容量足够大
  4. EMA 衰减率 μ \mu μ 适当(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)

4.2 蒸馏误差的分解

总蒸馏误差可分解为:

Total Error = Approximation Error ⏟ 网络容量不足 + Discretization Error ⏟ 时间步离散化 + Optimization Error ⏟ 训练不充分 \text{Total Error} = \underbrace{\text{Approximation Error}}{\text{网络容量不足}} + \underbrace{\text{Discretization Error}}{\text{时间步离散化}} + \underbrace{\text{Optimization Error}}_{\text{训练不充分}} Total Error=网络容量不足 Approximation Error+时间步离散化 Discretization Error+训练不充分 Optimization Error

  • 逼近误差 : O ( 1 / n ) O(1/\sqrt{n}) O(1/n ), n n n 是网络参数量
  • 离散化误差 : O ( Δ t ) O(\Delta t) O(Δt), Δ t \Delta t Δt 是时间步间隔
  • 优化误差:随训练步数减小

5. 一致性蒸馏数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性蒸馏 (Consistency Distillation) 数学总结                         ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. DDIM 求解器:                                                                        ║
║     x_{t_n} = α_{t_n}·D_φ(x_{t_{n+1}}, t_{n+1})                                      ║
║             + σ_{t_n}·(x_{t_{n+1}} - α_{t_{n+1}}·D_φ) / σ_{t_{n+1}}                  ║
║                                                                                        ║
║  2. 蒸馏损失:                                                                           ║
║     L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ]                   ║
║     x_{t_n} 由一步 DDIM 获得, θ⁻ = EMA(θ)                                             ║
║                                                                                        ║
║  3. EMA 更新:                                                                           ║
║     θ⁻ ← μ·θ⁻ + (1-μ)·θ,    μ = 0.9999                                               ║
║     提供稳定的训练目标                                                                  ║
║                                                                                        ║
║  4. 时间步调度 (Karras):                                                                ║
║     t_i = (t_max^{1/ρ} + i/(N-1)·(t_min^{1/ρ} - t_max^{1/ρ}))^ρ                      ║
║     ρ = 7, 噪声端步长更细                                                              ║
║                                                                                        ║
║  5. 误差分解:                                                                           ║
║     Total = Approximation + Discretization + Optimization                              ║
║     = O(1/√n) + O(Δt) + O(1/√T)                                                       ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第三篇:一致性模型的前沿发展

1. 引言

一致性模型自 2023 年提出以来,已经发展出多个重要变体和应用方向。


2. 一致性训练(Consistency Training)

2.1 无需教师的训练

一致性蒸馏需要预训练的教师模型。一致性训练(Consistency Training) 直接从数据训练,无需教师。

核心思想:利用 SDE 的随机性生成相邻点对。

前向 SDE

d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt

对于同一个 x 0 \mathbf{x}_0 x0,两次独立的 SDE 采样得到 x t \mathbf{x}t xt 和 x t ′ \mathbf{x}{t'} xt′( t ≈ t ′ t \approx t' t≈t′),它们在同一条轨迹附近。

一致性训练损失

L CT ( θ , θ − ) = E d ( f θ ( x t , t ) ,   f θ − ( x t ′ , t ′ ) ) \mathcal{L}_{\text{CT}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_t, t), \\, f_{\\theta\^-}(\\mathbf{x}_{t'}, t') \\right) \\right LCT(θ,θ−)=Ed(fθ(xt,t),fθ−(xt′,t′))

2.2 数学挑战

一致性训练的理论保证弱于一致性蒸馏,因为 SDE 轨迹不完全确定------两次独立采样的点可能不在同一条 PF-ODE 轨迹上。

缓解策略

  1. 使用较小的时间步间隔 ∣ t − t ′ ∣ |t - t'| ∣t−t′∣
  2. 使用较大的 EMA 衰减率
  3. 使用更稳定的距离度量

3. 进阶变体

3.1 渐进蒸馏(Progressive Distillation)

思想:逐步减少采样步数,每轮将步数减半。

流程

  1. 训练 N N N 步的扩散模型
  2. 蒸馏为 N / 2 N/2 N/2 步
  3. 蒸馏为 N / 4 N/4 N/4 步
  4. ...直到 1 步

优势:每轮蒸馏的难度更低,训练更稳定。

3.2 一致性模型 + Latent Diffusion

将一致性模型应用于 Latent Diffusion(如 Stable Diffusion):

z 0 = f θ ( z T , T ) , x 0 = Decoder ( z 0 ) \mathbf{z}0 = f\theta(\mathbf{z}_T, T), \quad \mathbf{x}_0 = \text{Decoder}(\mathbf{z}_0) z0=fθ(zT,T),x0=Decoder(z0)

优势:在低维潜在空间中操作,计算效率更高。

3.3 一致性模型 + Classifier-Free Guidance

将 CFG 应用于一致性模型:

f θ cfg ( x t , t , c ) = ( 1 + w ) f θ ( x t , t , c ) − w f θ ( x t , t , ∅ ) f_\theta^{\text{cfg}}(\mathbf{x}t, t, c) = (1 + w) f\theta(\mathbf{x}t, t, c) - w f\theta(\mathbf{x}_t, t, \varnothing) fθcfg(xt,t,c)=(1+w)fθ(xt,t,c)−wfθ(xt,t,∅)

其中 w w w 是引导强度, c c c 是条件(如文本), ∅ \varnothing ∅ 是空条件。


4. 一致性模型与其他方法的对比

方法 采样步数 训练方式 质量 速度
DDPM 1000 去噪 最高 最慢
DDIM 10-50 去噪 中等
Flow Matching 10-50 速度场 中等
一致性蒸馏 1-5 蒸馏 中高
一致性训练 1-5 直接训练 最快

5. 前沿研究方向

5.1 音频一致性模型

将一致性模型应用于音频生成(AudioLDM、MusicGen),实现单步音频合成。

5.2 视频一致性模型

将一致性模型应用于视频生成,利用时间维度的一致性。

5.3 3D 一致性模型

将一致性模型应用于 3D 生成(NeRF、3D Gaussian Splatting),实现单步 3D 重建。

5.4 一致性模型的理论深化

  1. 最优传输视角:将一致性模型与最优传输理论联系
  2. 信息论分析:分析一致性模型的信息压缩率
  3. 收敛速率:改进一致性蒸馏的收敛速率分析

6. 前沿发展数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    一致性模型前沿发展 数学总结                                             ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 一致性训练 (无教师):                                                                ║
║     L_CT = E[ d(f_θ(x_t, t), f_{θ⁻}(x_{t'}, t')) ]                                   ║
║     x_t, x_{t'} 由同一 x₀ 的两次独立 SDE 采样获得                                      ║
║                                                                                        ║
║  2. 渐进蒸馏:                                                                           ║
║     N 步 → N/2 步 → N/4 步 → ... → 1 步                                               ║
║     每轮蒸馏难度更低, 训练更稳定                                                        ║
║                                                                                        ║
║  3. CFG 引导:                                                                           ║
║     f_θ^{cfg}(x_t, t, c) = (1+w)·f_θ(x_t, t, c) - w·f_θ(x_t, t, ∅)                  ║
║                                                                                        ║
║  4. Latent 一致性:                                                                      ║
║     z₀ = f_θ(z_T, T),  x₀ = Decoder(z₀)                                              ║
║     在低维潜在空间操作, 计算效率更高                                                     ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

参考文献

核心论文

  1. Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. ICML 2023.
  2. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021.
  3. Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.

扩散模型基础

  1. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
  2. Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS 2022.

加速采样

  1. Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., & Zhu, J. (2022). DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. NeurIPS 2022.
  2. Salimans, T., & Ho, J. (2022). Progressive Distillation for Fast Sampling of Diffusion Models. ICLR 2022.

应用

  1. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022.
  2. Saharia, C., Chan, W., et al. (2022). Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. NeurIPS 2022.
相关推荐
RisunJan1 小时前
Linux命令-patch (为开放源代码软件安装补丁程序)
linux·服务器·算法
cxr8281 小时前
基于人工智能的超材料逆向设计
人工智能·材料逆向设计合成
霸道流氓气质1 小时前
Spring AI Alibaba Skills 完整实战:从零构建智能会议助手
java·人工智能·spring
眠りたいです1 小时前
LangChainv1:agent快速上手与中间件认识
人工智能·python·中间件·langchain·langgraph
JJJennie7771 小时前
从苹果 2026 落地场景,看系统级 Agent 时代的隐私边界与 MAI Gateway 的企业Token治理
人工智能·gateway·apple
一条大祥脚1 小时前
ABC460贪心|多源BFS|数论|计数|线段树|树的直径
算法·宽度优先
甲维斯1 小时前
我超!Claude Fable真来了,比Mythos还强?!
人工智能
三叶草4351 小时前
Claude Code 接入 DeepSeek强强联合
人工智能
AI程序员1 小时前
Loop Engineering:你不再 prompt agent,而是设计 prompt agent 的系统
人工智能