目录
- [第一篇:一致性模型 --- 单步生成的数学基础](#第一篇:一致性模型 — 单步生成的数学基础)
- [第二篇:一致性蒸馏 --- 从预训练扩散模型学习](#第二篇:一致性蒸馏 — 从预训练扩散模型学习)
- 第三篇:一致性模型的前沿发展
- 参考文献
第一篇:一致性模型 --- 单步生成的数学基础
1. 引言
扩散模型(Diffusion Models)在图像生成领域取得了巨大成功,但其采样过程需要数百步迭代去噪,计算代价高昂。即使是加速采样方法(如 DDIM、DPM-Solver),通常也需要 10-50 步才能获得高质量样本。
一致性模型(Consistency Models, Song et al., 2023) 提出了一种全新的生成范式:学习一个函数,将 ODE 轨迹上的任意点直接映射到轨迹起点(数据)。这一设计实现了:
- 单步生成:一次前向传播即可生成高质量样本
- 多步精炼:支持多步采样以换取更高质量
- 数学优雅:基于概率流 ODE 的几何性质,定义简洁
2. 理论基础 --- 概率流 ODE
2.1 扩散过程的 ODE 视角
扩散模型有两种等价的连续时间表述:
随机微分方程(SDE)(前向):
d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt
概率流 ODE(PF-ODE)(前向):
d x t d t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log p t ( x t ) \frac{d\mathbf{x}_t}{dt} = \mathbf{f}(\mathbf{x}_t, t) - \frac{1}{2} g(t)^2 \nabla \log p_t(\mathbf{x}_t) dtdxt=f(xt,t)−21g(t)2∇logpt(xt)
其中 ∇ log p t ( x t ) \nabla \log p_t(\mathbf{x}_t) ∇logpt(xt) 是分数函数(score function)。
关键性质 :PF-ODE 与 SDE 具有相同的边际分布 p t ( x ) p_t(\mathbf{x}) pt(x),但 PF-ODE 是确定性的------给定初始条件 x 0 \mathbf{x}_0 x0,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 唯一确定。
2.2 方差保持(VP)扩散的 PF-ODE
对于 VP 扩散(DDPM):
f ( x t , t ) = − 1 2 β ( t ) x t , g ( t ) = β ( t ) \mathbf{f}(\mathbf{x}_t, t) = -\frac{1}{2} \beta(t) \mathbf{x}_t, \quad g(t) = \sqrt{\beta(t)} f(xt,t)=−21β(t)xt,g(t)=β(t)
PF-ODE 为:
d x t d t = − 1 2 β ( t ) x t − 1 2 β ( t ) ∇ log p t ( x t ) \frac{d\mathbf{x}_t}{dt} = -\frac{1}{2} \beta(t) \mathbf{x}_t - \frac{1}{2} \beta(t) \nabla \log p_t(\mathbf{x}_t) dtdxt=−21β(t)xt−21β(t)∇logpt(xt)
定义去噪函数 D ( x t , t ) \mathbf{D}(\mathbf{x}_t, t) D(xt,t):
D ( x t , t ) = x t + σ t 2 ∇ log p t ( x t ) α t \mathbf{D}(\mathbf{x}_t, t) = \frac{\mathbf{x}_t + \sigma_t^2 \nabla \log p_t(\mathbf{x}_t)}{\alpha_t} D(xt,t)=αtxt+σt2∇logpt(xt)
其中 α t = exp ( − 1 2 ∫ 0 t β ( s ) d s ) \alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right) αt=exp(−21∫0tβ(s)ds), σ t 2 = 1 − α t 2 \sigma_t^2 = 1 - \alpha_t^2 σt2=1−αt2。
则 PF-ODE 可改写为:
d x t d t = − α ˙ t α t x t + α ˙ t σ t α t D ( x t , t ) \frac{d\mathbf{x}_t}{dt} = -\frac{\dot{\alpha}_t}{\alpha_t} \mathbf{x}_t + \frac{\dot{\alpha}_t \sigma_t}{\alpha_t} \mathbf{D}(\mathbf{x}_t, t) dtdxt=−αtα˙txt+αtα˙tσtD(xt,t)
2.3 PF-ODE 轨迹的性质
定理 :对于 VP 扩散的 PF-ODE,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 具有以下性质:
- 唯一性 :给定 x 0 \mathbf{x}_0 x0,轨迹唯一确定
- 连续性 : x t \mathbf{x}_t xt 关于 t t t 连续
- 边界行为 : x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata, x T ≈ N ( 0 , I ) \mathbf{x}_T \approx \mathcal{N}(0, \mathbf{I}) xT≈N(0,I)
直觉 :PF-ODE 将数据分布 p data p_{\text{data}} pdata 连续地"变形"为高斯噪声 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)。生成过程就是从噪声沿轨迹"走回去"。
3. 一致性模型的核心定义
3.1 一致性函数
定义 (一致性函数 f θ f_\theta fθ):对于 PF-ODE 的任意轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,一致性函数满足:
f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ t , t ′ ∈ 0 , T f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t'}, t'), \quad \forall t, t' \in 0, T fθ(xt,t)=fθ(xt′,t′),∀t,t′∈0,T
即轨迹上的所有点映射到同一个值 ------轨迹的起点 x 0 \mathbf{x}0 x0(或等价地, x ϵ \mathbf{x}\epsilon xϵ 以避免数值问题)。
边界条件:
f θ ( x ϵ , ϵ ) = x ϵ f_\theta(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon fθ(xϵ,ϵ)=xϵ
其中 ϵ > 0 \epsilon > 0 ϵ>0 是一个很小的常数(如 ϵ = 0.002 \epsilon = 0.002 ϵ=0.002),用于避免 t = 0 t = 0 t=0 处的数值不稳定性。
3.2 自洽性(Self-Consistency)
一致性模型的核心约束是自洽性:
f θ ( x t , t ) = f θ ( x t + δ , t + δ ) , ∀ δ > 0 f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t+\delta}, t + \delta), \quad \forall \delta > 0 fθ(xt,t)=fθ(xt+δ,t+δ),∀δ>0
这意味着:沿 ODE 轨迹前进任意步,一致性函数的输出不变。
几何直觉:一致性函数将整条 ODE 轨迹"压缩"为一个点(起点)。不同轨迹映射到不同的点,但同一轨迹上的所有点映射到同一个点。
3.3 生成过程
利用一致性函数,生成过程极其简单:
单步生成:
x 0 = f θ ( x T , T ) , x T ∼ N ( 0 , I ) \mathbf{x}0 = f\theta(\mathbf{x}_T, T), \quad \mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) x0=fθ(xT,T),xT∼N(0,I)
一次前向传播,从噪声直接生成数据!
多步生成(迭代精炼):
x ϵ = f θ ( x T , T ) \mathbf{x}\epsilon = f\theta(\mathbf{x}_T, T) xϵ=fθ(xT,T)
x t n + 1 = f θ ( x t n , t n ) + t n 2 − ϵ 2 ⋅ z , z ∼ N ( 0 , I ) \mathbf{x}{t{n+1}} = f_\theta(\mathbf{x}_{t_n}, t_n) + \sqrt{t_n^2 - \epsilon^2} \cdot \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) xtn+1=fθ(xtn,tn)+tn2−ϵ2 ⋅z,z∼N(0,I)
多步生成通过在轨迹上添加少量噪声再映射,实现逐步精炼。
4. 训练方法 --- 一致性蒸馏
4.1 一致性蒸馏(Consistency Distillation)
核心思想:利用预训练的扩散模型(教师)生成轨迹上的相邻点对,训练一致性模型(学生)使它们映射到相同值。
训练数据 :对于每个数据点 x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata,通过 PF-ODE 求解器获得相邻点对 ( x t n + 1 , x t n ) (\mathbf{x}{t{n+1}}, \mathbf{x}_{t_n}) (xtn+1,xtn)。
损失函数:
L CD ( θ , θ − ) = E d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) \mathcal{L}_{\text{CD}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_{t_{n+1}}, t_{n+1}), \\, f_{\\theta\^-}(\\mathbf{x}_{t_n}, t_n) \\right) \\right LCD(θ,θ−)=Ed(fθ(xtn+1,tn+1),fθ−(xtn,tn))
其中:
- f θ f_\theta fθ 是在线网络(正在更新)
- f θ − f_{\theta^-} fθ− 是目标网络(EMA 更新)
- d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(⋅,⋅) 是距离度量(如 L 2 L_2 L2、LPIPS)
- x t n \mathbf{x}{t_n} xtn 由 x t n + 1 \mathbf{x}{t_{n+1}} xtn+1 通过一步 PF-ODE 求解器获得
目标网络更新:
θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ
其中 μ \mu μ 是 EMA 衰减率(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)。
4.2 数学正确性
定理 :当 f θ f_\theta fθ 满足自洽性时, L CD = 0 \mathcal{L}_{\text{CD}} = 0 LCD=0。
证明:
由自洽性:
f θ ( x t n + 1 , t n + 1 ) = f θ ( x t n , t n ) f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}) = f_\theta(\mathbf{x}_{t_n}, t_n) fθ(xtn+1,tn+1)=fθ(xtn,tn)
当 θ − = θ \theta^- = \theta θ−=θ 时(EMA 已收敛):
d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) = d ( f θ ( x t n + 1 , t n + 1 ) , f θ ( x t n , t n ) ) = 0 d(f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_{\theta^-}(\mathbf{x}{t_n}, t_n)) = d(f\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_\theta(\mathbf{x}_{t_n}, t_n)) = 0 d(fθ(xtn+1,tn+1),fθ−(xtn,tn))=d(fθ(xtn+1,tn+1),fθ(xtn,tn))=0
■ \blacksquare ■
4.3 一步 PF-ODE 求解
为了获得相邻点对,需要一步 PF-ODE 求解。使用 DDIM 求解器:
x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)
其中 D ϕ \mathbf{D}_\phi Dϕ 是预训练的去噪模型(教师), α t , σ t \alpha_t, \sigma_t αt,σt 是噪声调度参数。
5. 完整可运行实现
5.1 一致性模型核心实现
python
"""
一致性模型 (Consistency Models) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple, Optional, List
from dataclasses import dataclass
@dataclass
class ConsistencyConfig:
"""一致性模型配置"""
data_dim: int = 2
hidden_dim: int = 256
time_dim: int = 64
num_layers: int = 6
sigma_min: float = 0.002
sigma_max: float = 80.0
rho: float = 7.0
num_timesteps: int = 40 # 时间步离散化数量
ema_decay: float = 0.9999
def get_sigmas_karras(
sigma_min: float, sigma_max: float, rho: float, num_steps: int
) -> torch.Tensor:
"""Karras 噪声调度 (Karras et al., 2022)"""
inv_rho = 1.0 / rho
steps = torch.arange(num_steps, dtype=torch.float64) / (num_steps - 1)
sigmas = (sigma_max ** inv_rho + steps * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
return sigmas.float()
def get_alpha_sigma(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""VP 扩散的 alpha 和 sigma 参数"""
alpha = torch.cos(t * math.pi / 2)
sigma = torch.sin(t * math.pi / 2)
return alpha, sigma
class SinusoidalTimeEmbedding(nn.Module):
"""正弦时间嵌入"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t.unsqueeze(-1) * emb.unsqueeze(0)
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
class ConsistencyModel(nn.Module):
"""一致性模型网络"""
def __init__(self, config: ConsistencyConfig):
super().__init__()
self.config = config
self.time_embed = SinusoidalTimeEmbedding(config.time_dim)
self.input_proj = nn.Linear(config.data_dim + config.time_dim, config.hidden_dim)
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(config.hidden_dim),
nn.SiLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
nn.LayerNorm(config.hidden_dim),
nn.SiLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
)
for _ in range(config.num_layers)
])
self.output_proj = nn.Linear(config.hidden_dim, config.data_dim)
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
x: (B, D) 带噪数据
t: (B,) 时间步 (0=数据, 1=噪声)
"""
t_emb = self.time_embed(t)
h = torch.cat([x, t_emb], dim=-1)
h = self.input_proj(h)
for block in self.blocks:
h = h + block(h)
return self.output_proj(h)
5.2 一致性蒸馏训练
python
class ConsistencyDistillation:
"""一致性蒸馏训练器"""
def __init__(
self,
model: ConsistencyModel,
teacher_model: nn.Module,
config: ConsistencyConfig,
device: torch.device,
):
self.model = model
self.teacher = teacher_model
self.config = config
self.device = device
# 创建目标网络 (EMA)
self.target_model = ConsistencyModel(config).to(device)
self.target_model.load_state_dict(model.state_dict())
# 冻结教师模型
for param in self.teacher.parameters():
param.requires_grad = False
# 噪声调度
self.sigmas = get_sigmas_karras(
config.sigma_min, config.sigma_max, config.rho, config.num_timesteps
).to(device)
self.optimizer = torch.optim.AdamW(
model.parameters(), lr=1e-4, weight_decay=0.0
)
def add_noise(
self, x: torch.Tensor, t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""添加噪声: x_t = alpha_t * x + sigma_t * noise"""
alpha, sigma = get_alpha_sigma(t)
noise = torch.randn_like(x)
x_t = alpha.unsqueeze(-1) * x + sigma.unsqueeze(-1) * noise
return x_t, noise
def one_step_denoise(
self, x_t: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor
) -> torch.Tensor:
"""使用教师模型进行一步去噪 (DDIM 求解器)"""
with torch.no_grad():
# 教师模型预测去噪结果
x_denoised = self.teacher(x_t, t)
# DDIM 一步更新
alpha_t, sigma_t = get_alpha_sigma(t)
alpha_prev, sigma_prev = get_alpha_sigma(t_prev)
# x_{t_prev} = alpha_{t_prev} * x_denoised + sigma_{t_prev} * (x_t - alpha_t * x_denoised) / sigma_t
x_prev = (
alpha_prev.unsqueeze(-1) * x_denoised
+ sigma_prev.unsqueeze(-1) * (x_t - alpha_t.unsqueeze(-1) * x_denoised) / sigma_t.unsqueeze(-1)
)
return x_prev
def compute_loss(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, dict]:
"""计算一致性蒸馏损失"""
B = x.shape[0]
# 随机采样时间步对 (t_{n+1}, t_n)
# 使用 Karras 调度的离散时间步
n = torch.randint(0, self.config.num_timesteps - 1, (B,), device=self.device)
t = self.sigmas[n] # t_{n+1}
t_prev = self.sigmas[n + 1] # t_n (更接近数据)
# 添加噪声
x_t, noise = self.add_noise(x, t)
# 教师模型一步去噪
x_prev = self.one_step_denoise(x_t, t, t_prev)
# 在线网络预测
pred_online = self.model(x_t, t)
# 目标网络预测
with torch.no_grad():
pred_target = self.target_model(x_prev, t_prev)
# 一致性损失 (L2)
loss = F.mse_loss(pred_online, pred_target)
metrics = {
"loss": loss.item(),
"t_mean": t.mean().item(),
"t_prev_mean": t_prev.mean().item(),
}
return loss, metrics
def train_step(self, x: torch.Tensor) -> dict:
"""执行一步训练"""
self.model.train()
self.optimizer.zero_grad()
loss, metrics = self.compute_loss(x)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
# EMA 更新目标网络
with torch.no_grad():
for param, target_param in zip(
self.model.parameters(), self.target_model.parameters()
):
target_param.data.mul_(self.config.ema_decay).add_(
param.data, alpha=1 - self.config.ema_decay
)
return metrics
@torch.no_grad()
def sample(
self, num_samples: int, num_steps: int = 1
) -> torch.Tensor:
"""生成样本"""
self.model.eval()
# 从噪声开始
x = torch.randn(num_samples, self.config.data_dim, device=self.device)
if num_steps == 1:
# 单步生成
t = torch.ones(num_samples, device=self.device)
x = self.model(x, t)
else:
# 多步生成
timesteps = get_sigmas_karras(
self.config.sigma_min, self.config.sigma_max,
self.config.rho, num_steps
).to(self.device)
for i in range(len(timesteps) - 1):
t = timesteps[i].expand(num_samples)
x = self.model(x, t)
# 添加少量噪声 (用于迭代精炼)
if i < len(timesteps) - 2:
noise = torch.randn_like(x)
t_next = timesteps[i + 1]
_, sigma = get_alpha_sigma(t_next)
x = x + sigma * noise * 0.5
return x
5.3 实验代码
python
def experiment_consistency_2d():
"""在 2D 双月数据上训练一致性模型"""
from sklearn.datasets import make_moons
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 生成数据
data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
data = (data - data.mean(axis=0)) / data.std(axis=0)
data = torch.tensor(data, dtype=torch.float32).to(device)
config = ConsistencyConfig(data_dim=2, hidden_dim=256, num_layers=6)
# 教师模型 (预训练的去噪模型)
teacher = ConsistencyModel(config).to(device)
# 学生模型 (一致性模型)
student = ConsistencyModel(config).to(device)
# 初始化教师模型 (模拟预训练)
# 实际应用中,这里加载预训练的扩散模型
teacher.load_state_dict(student.state_dict())
trainer = ConsistencyDistillation(student, teacher, config, device)
# 训练
print("一致性蒸馏训练...")
batch_size = 256
losses = []
for step in range(5000):
idx = torch.randint(0, data.shape[0], (batch_size,))
x_batch = data[idx]
metrics = trainer.train_step(x_batch)
losses.append(metrics["loss"])
if (step + 1) % 1000 == 0:
avg_loss = np.mean(losses[-100:])
print(f"Step {step+1} | Loss: {avg_loss:.6f}")
# 生成样本
print("\n生成样本...")
samples_1step = trainer.sample(num_samples=1000, num_steps=1)
samples_5step = trainer.sample(num_samples=1000, num_steps=5)
return trainer, samples_1step.cpu(), samples_5step.cpu()
6. 一致性模型的理论性质
6.1 表达能力定理
定理(Song et al., 2023):一致性模型的表达能力足够强大,可以精确表示任意 PF-ODE 轨迹。
证明思路 :对于任意 PF-ODE 轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,定义 f ∗ ( x t , t ) = x 0 f^*(\mathbf{x}t, t) = \mathbf{x}0 f∗(xt,t)=x0。这是一个合法的一致性函数,且满足边界条件 f ∗ ( x ϵ , ϵ ) = x ϵ f^*(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon f∗(xϵ,ϵ)=xϵ。
当神经网络 f θ f_\theta fθ 的容量足够大时,它可以任意逼近 f ∗ f^* f∗。 ■ \blacksquare ■
6.2 单步生成的误差分析
定理:单步生成的误差上界为:
E ∥ x 0 − f θ ( x T , T ) ∥ 2 ≤ L CD ( θ , θ − ) + O ( Δ t ) \mathbb{E}\left\\\|\\mathbf{x}_0 - f_\\theta(\\mathbf{x}_T, T)\\\|\^2\\right \leq \mathcal{L}_{\text{CD}}(\theta, \theta^-) + O(\Delta t) E∥x0−fθ(xT,T)∥2≤LCD(θ,θ−)+O(Δt)
其中 Δ t \Delta t Δt 是时间离散化的步长。
直觉 :一致性蒸馏损失越小,单步生成的质量越高。时间步离散化越细( Δ t \Delta t Δt 越小),误差越小。
6.3 与扩散模型的关系
| 特性 | 扩散模型 | 一致性模型 |
|---|---|---|
| 生成步数 | 10-1000 步 | 1-5 步 |
| 训练目标 | 去噪 | 自洽性 |
| 采样方式 | 迭代去噪 | 直接映射 |
| 理论基础 | SDE/PF-ODE | PF-ODE 轨迹 |
| 质量-速度权衡 | 高质量但慢 | 快但略低质量 |
7. 一致性模型数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性模型 (Consistency Models) 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 概率流 ODE: ║
║ dx_t/dt = f(x_t, t) - ½g(t)² ∇log p_t(x_t) ║
║ 轨迹: x_0 → x_T (数据 → 噪声) ║
║ ║
║ 2. 一致性函数定义: ║
║ f_θ(x_t, t) = f_θ(x_{t'}, t') ∀t, t' ∈ [0, T] (同一轨迹→同一输出) ║
║ 边界条件: f_θ(x_ε, ε) = x_ε ║
║ ║
║ 3. 单步生成: ║
║ x_0 = f_θ(x_T, T), x_T ~ N(0, I) ║
║ 一次前向传播, 从噪声直接生成数据 ║
║ ║
║ 4. 一致性蒸馏损失: ║
║ L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ] ║
║ θ⁻ = EMA(θ), x_{t_n} 由一步 DDIM 求解器获得 ║
║ ║
║ 5. 正确性: ║
║ f_θ 满足自洽性 ⟹ L_CD = 0 ║
║ ║
║ 6. 误差上界: ║
║ E[‖x₀ - f_θ(x_T, T)‖²] ≤ L_CD + O(Δt) ║
║ ║
║ 7. VP 扩散参数: ║
║ α_t = cos(πt/2), σ_t = sin(πt/2) ║
║ x_t = α_t·x₀ + σ_t·ε, ε ~ N(0, I) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第二篇:一致性蒸馏 --- 从预训练扩散模型学习
1. 引言
一致性蒸馏(Consistency Distillation)是训练一致性模型的主要方法,它利用预训练的扩散模型作为教师,通过蒸馏的方式训练学生模型。本篇深入分析蒸馏过程的数学细节和实现技巧。
2. 蒸馏过程的数学分析
2.1 教师-学生框架
教师模型 D ϕ \mathbf{D}_\phi Dϕ:预训练的去噪模型,满足:
D ϕ ( x t , t ) ≈ E x 0 ∣ x t \mathbf{D}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ(xt,t)≈Ex0∣xt
学生模型 f θ f_\theta fθ:一致性模型,满足自洽性。
蒸馏目标:利用教师模型生成轨迹上的相邻点对,训练学生模型使它们映射到相同值。
2.2 DDIM 求解器
DDIM(Denoising Diffusion Implicit Models, Song et al., 2020)是一种确定性的 PF-ODE 求解器:
x t n = α t n ( x t n + 1 − σ t n + 1 ϵ ^ α t n + 1 ) ⏟ x ^ 0 + σ t n ϵ ^ \mathbf{x}{t{n}} = \alpha_{t_n} \underbrace{\left(\frac{\mathbf{x}{t{n+1}} - \sigma_{t_{n+1}} \hat{\mathbf{\epsilon}}}{\alpha_{t_{n+1}}}\right)}_{\hat{\mathbf{x}}0} + \sigma{t_n} \hat{\mathbf{\epsilon}} xtn=αtnx^0 (αtn+1xtn+1−σtn+1ϵ^)+σtnϵ^
其中 ϵ ^ = ( x t n + 1 − α t n + 1 x ^ 0 ) / σ t n + 1 \hat{\mathbf{\epsilon}} = (\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \hat{\mathbf{x}}0) / \sigma{t_{n+1}} ϵ^=(xtn+1−αtn+1x^0)/σtn+1 是预测的噪声。
等价形式(使用去噪函数):
x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)
2.3 时间步离散化
Karras 调度(Karras et al., 2022):
t i = ( t max 1 / ρ + i N − 1 ( t min 1 / ρ − t max 1 / ρ ) ) ρ t_i = \left(t_{\max}^{1/\rho} + \frac{i}{N-1} (t_{\min}^{1/\rho} - t_{\max}^{1/\rho})\right)^\rho ti=(tmax1/ρ+N−1i(tmin1/ρ−tmax1/ρ))ρ
其中 ρ = 7 \rho = 7 ρ=7 控制时间步的分布(更多步集中在噪声端)。
直觉:噪声端的 ODE 曲率更大,需要更细的时间步离散化。
3. 训练技巧
3.1 EMA 更新
目标网络使用 EMA(Exponential Moving Average)更新:
θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ
为什么需要 EMA:如果直接使用在线网络作为目标,训练会不稳定------两个相同的网络相互"追逐",导致发散。EMA 提供了稳定的目标。
3.2 距离度量
L2 距离:
d ( x , y ) = ∥ x − y ∥ 2 2 d(\mathbf{x}, \mathbf{y}) = \|\mathbf{x} - \mathbf{y}\|_2^2 d(x,y)=∥x−y∥22
LPIPS 距离(用于图像):
d LPIPS ( x , y ) = ∑ l ∥ feat l ( x ) − feat l ( y ) ∥ 2 2 d_{\text{LPIPS}}(\mathbf{x}, \mathbf{y}) = \sum_l \|\text{feat}_l(\mathbf{x}) - \text{feat}_l(\mathbf{y})\|_2^2 dLPIPS(x,y)=l∑∥featl(x)−featl(y)∥22
LPIPS 使用预训练网络的特征距离,更符合人类感知。
3.3 梯度裁剪
一致性蒸馏的梯度可能很大,需要梯度裁剪:
g ← g ⋅ min ( 1 , c ∥ g ∥ ) \mathbf{g} \leftarrow \mathbf{g} \cdot \min\left(1, \frac{c}{\|\mathbf{g}\|}\right) g←g⋅min(1,∥g∥c)
其中 c c c 是裁剪阈值(如 c = 1.0 c = 1.0 c=1.0)。
4. 一致性蒸馏的收敛性
4.1 收敛定理
定理(非正式):在适当条件下,一致性蒸馏收敛到教师模型的 PF-ODE 轨迹。
条件:
- 教师模型足够好( D ϕ ≈ E x 0 ∣ x t \mathbf{D}_\phi \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ≈Ex0∣xt)
- 时间步离散化足够细( Δ t → 0 \Delta t \to 0 Δt→0)
- 网络容量足够大
- EMA 衰减率 μ \mu μ 适当(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)
4.2 蒸馏误差的分解
总蒸馏误差可分解为:
Total Error = Approximation Error ⏟ 网络容量不足 + Discretization Error ⏟ 时间步离散化 + Optimization Error ⏟ 训练不充分 \text{Total Error} = \underbrace{\text{Approximation Error}}{\text{网络容量不足}} + \underbrace{\text{Discretization Error}}{\text{时间步离散化}} + \underbrace{\text{Optimization Error}}_{\text{训练不充分}} Total Error=网络容量不足 Approximation Error+时间步离散化 Discretization Error+训练不充分 Optimization Error
- 逼近误差 : O ( 1 / n ) O(1/\sqrt{n}) O(1/n ), n n n 是网络参数量
- 离散化误差 : O ( Δ t ) O(\Delta t) O(Δt), Δ t \Delta t Δt 是时间步间隔
- 优化误差:随训练步数减小
5. 一致性蒸馏数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性蒸馏 (Consistency Distillation) 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. DDIM 求解器: ║
║ x_{t_n} = α_{t_n}·D_φ(x_{t_{n+1}}, t_{n+1}) ║
║ + σ_{t_n}·(x_{t_{n+1}} - α_{t_{n+1}}·D_φ) / σ_{t_{n+1}} ║
║ ║
║ 2. 蒸馏损失: ║
║ L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ] ║
║ x_{t_n} 由一步 DDIM 获得, θ⁻ = EMA(θ) ║
║ ║
║ 3. EMA 更新: ║
║ θ⁻ ← μ·θ⁻ + (1-μ)·θ, μ = 0.9999 ║
║ 提供稳定的训练目标 ║
║ ║
║ 4. 时间步调度 (Karras): ║
║ t_i = (t_max^{1/ρ} + i/(N-1)·(t_min^{1/ρ} - t_max^{1/ρ}))^ρ ║
║ ρ = 7, 噪声端步长更细 ║
║ ║
║ 5. 误差分解: ║
║ Total = Approximation + Discretization + Optimization ║
║ = O(1/√n) + O(Δt) + O(1/√T) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第三篇:一致性模型的前沿发展
1. 引言
一致性模型自 2023 年提出以来,已经发展出多个重要变体和应用方向。
2. 一致性训练(Consistency Training)
2.1 无需教师的训练
一致性蒸馏需要预训练的教师模型。一致性训练(Consistency Training) 直接从数据训练,无需教师。
核心思想:利用 SDE 的随机性生成相邻点对。
前向 SDE:
d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt
对于同一个 x 0 \mathbf{x}_0 x0,两次独立的 SDE 采样得到 x t \mathbf{x}t xt 和 x t ′ \mathbf{x}{t'} xt′( t ≈ t ′ t \approx t' t≈t′),它们在同一条轨迹附近。
一致性训练损失:
L CT ( θ , θ − ) = E d ( f θ ( x t , t ) , f θ − ( x t ′ , t ′ ) ) \mathcal{L}_{\text{CT}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_t, t), \\, f_{\\theta\^-}(\\mathbf{x}_{t'}, t') \\right) \\right LCT(θ,θ−)=Ed(fθ(xt,t),fθ−(xt′,t′))
2.2 数学挑战
一致性训练的理论保证弱于一致性蒸馏,因为 SDE 轨迹不完全确定------两次独立采样的点可能不在同一条 PF-ODE 轨迹上。
缓解策略:
- 使用较小的时间步间隔 ∣ t − t ′ ∣ |t - t'| ∣t−t′∣
- 使用较大的 EMA 衰减率
- 使用更稳定的距离度量
3. 进阶变体
3.1 渐进蒸馏(Progressive Distillation)
思想:逐步减少采样步数,每轮将步数减半。
流程:
- 训练 N N N 步的扩散模型
- 蒸馏为 N / 2 N/2 N/2 步
- 蒸馏为 N / 4 N/4 N/4 步
- ...直到 1 步
优势:每轮蒸馏的难度更低,训练更稳定。
3.2 一致性模型 + Latent Diffusion
将一致性模型应用于 Latent Diffusion(如 Stable Diffusion):
z 0 = f θ ( z T , T ) , x 0 = Decoder ( z 0 ) \mathbf{z}0 = f\theta(\mathbf{z}_T, T), \quad \mathbf{x}_0 = \text{Decoder}(\mathbf{z}_0) z0=fθ(zT,T),x0=Decoder(z0)
优势:在低维潜在空间中操作,计算效率更高。
3.3 一致性模型 + Classifier-Free Guidance
将 CFG 应用于一致性模型:
f θ cfg ( x t , t , c ) = ( 1 + w ) f θ ( x t , t , c ) − w f θ ( x t , t , ∅ ) f_\theta^{\text{cfg}}(\mathbf{x}t, t, c) = (1 + w) f\theta(\mathbf{x}t, t, c) - w f\theta(\mathbf{x}_t, t, \varnothing) fθcfg(xt,t,c)=(1+w)fθ(xt,t,c)−wfθ(xt,t,∅)
其中 w w w 是引导强度, c c c 是条件(如文本), ∅ \varnothing ∅ 是空条件。
4. 一致性模型与其他方法的对比
| 方法 | 采样步数 | 训练方式 | 质量 | 速度 |
|---|---|---|---|---|
| DDPM | 1000 | 去噪 | 最高 | 最慢 |
| DDIM | 10-50 | 去噪 | 高 | 中等 |
| Flow Matching | 10-50 | 速度场 | 高 | 中等 |
| 一致性蒸馏 | 1-5 | 蒸馏 | 中高 | 快 |
| 一致性训练 | 1-5 | 直接训练 | 中 | 最快 |
5. 前沿研究方向
5.1 音频一致性模型
将一致性模型应用于音频生成(AudioLDM、MusicGen),实现单步音频合成。
5.2 视频一致性模型
将一致性模型应用于视频生成,利用时间维度的一致性。
5.3 3D 一致性模型
将一致性模型应用于 3D 生成(NeRF、3D Gaussian Splatting),实现单步 3D 重建。
5.4 一致性模型的理论深化
- 最优传输视角:将一致性模型与最优传输理论联系
- 信息论分析:分析一致性模型的信息压缩率
- 收敛速率:改进一致性蒸馏的收敛速率分析
6. 前沿发展数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性模型前沿发展 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 一致性训练 (无教师): ║
║ L_CT = E[ d(f_θ(x_t, t), f_{θ⁻}(x_{t'}, t')) ] ║
║ x_t, x_{t'} 由同一 x₀ 的两次独立 SDE 采样获得 ║
║ ║
║ 2. 渐进蒸馏: ║
║ N 步 → N/2 步 → N/4 步 → ... → 1 步 ║
║ 每轮蒸馏难度更低, 训练更稳定 ║
║ ║
║ 3. CFG 引导: ║
║ f_θ^{cfg}(x_t, t, c) = (1+w)·f_θ(x_t, t, c) - w·f_θ(x_t, t, ∅) ║
║ ║
║ 4. Latent 一致性: ║
║ z₀ = f_θ(z_T, T), x₀ = Decoder(z₀) ║
║ 在低维潜在空间操作, 计算效率更高 ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
参考文献
核心论文
- Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. ICML 2023.
- Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021.
- Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
扩散模型基础
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
- Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS 2022.
加速采样
- Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., & Zhu, J. (2022). DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. NeurIPS 2022.
- Salimans, T., & Ho, J. (2022). Progressive Distillation for Fast Sampling of Diffusion Models. ICLR 2022.
应用
- Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022.
- Saharia, C., Chan, W., et al. (2022). Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. NeurIPS 2022.
一致性模型深度解析
目录
- [第一篇:一致性模型 --- 单步生成的数学基础](#第一篇:一致性模型 — 单步生成的数学基础)
- [第二篇:一致性蒸馏 --- 从预训练扩散模型学习](#第二篇:一致性蒸馏 — 从预训练扩散模型学习)
- 第三篇:一致性模型的前沿发展
- 参考文献
第一篇:一致性模型 --- 单步生成的数学基础
1. 引言
扩散模型(Diffusion Models)在图像生成领域取得了巨大成功,但其采样过程需要数百步迭代去噪,计算代价高昂。即使是加速采样方法(如 DDIM、DPM-Solver),通常也需要 10-50 步才能获得高质量样本。
一致性模型(Consistency Models, Song et al., 2023) 提出了一种全新的生成范式:学习一个函数,将 ODE 轨迹上的任意点直接映射到轨迹起点(数据)。这一设计实现了:
- 单步生成:一次前向传播即可生成高质量样本
- 多步精炼:支持多步采样以换取更高质量
- 数学优雅:基于概率流 ODE 的几何性质,定义简洁
2. 理论基础 --- 概率流 ODE
2.1 扩散过程的 ODE 视角
扩散模型有两种等价的连续时间表述:
随机微分方程(SDE)(前向):
d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt
概率流 ODE(PF-ODE)(前向):
d x t d t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log p t ( x t ) \frac{d\mathbf{x}_t}{dt} = \mathbf{f}(\mathbf{x}_t, t) - \frac{1}{2} g(t)^2 \nabla \log p_t(\mathbf{x}_t) dtdxt=f(xt,t)−21g(t)2∇logpt(xt)
其中 ∇ log p t ( x t ) \nabla \log p_t(\mathbf{x}_t) ∇logpt(xt) 是分数函数(score function)。
关键性质 :PF-ODE 与 SDE 具有相同的边际分布 p t ( x ) p_t(\mathbf{x}) pt(x),但 PF-ODE 是确定性的------给定初始条件 x 0 \mathbf{x}_0 x0,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 唯一确定。
2.2 方差保持(VP)扩散的 PF-ODE
对于 VP 扩散(DDPM):
f ( x t , t ) = − 1 2 β ( t ) x t , g ( t ) = β ( t ) \mathbf{f}(\mathbf{x}_t, t) = -\frac{1}{2} \beta(t) \mathbf{x}_t, \quad g(t) = \sqrt{\beta(t)} f(xt,t)=−21β(t)xt,g(t)=β(t)
PF-ODE 为:
d x t d t = − 1 2 β ( t ) x t − 1 2 β ( t ) ∇ log p t ( x t ) \frac{d\mathbf{x}_t}{dt} = -\frac{1}{2} \beta(t) \mathbf{x}_t - \frac{1}{2} \beta(t) \nabla \log p_t(\mathbf{x}_t) dtdxt=−21β(t)xt−21β(t)∇logpt(xt)
定义去噪函数 D ( x t , t ) \mathbf{D}(\mathbf{x}_t, t) D(xt,t):
D ( x t , t ) = x t + σ t 2 ∇ log p t ( x t ) α t \mathbf{D}(\mathbf{x}_t, t) = \frac{\mathbf{x}_t + \sigma_t^2 \nabla \log p_t(\mathbf{x}_t)}{\alpha_t} D(xt,t)=αtxt+σt2∇logpt(xt)
其中 α t = exp ( − 1 2 ∫ 0 t β ( s ) d s ) \alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right) αt=exp(−21∫0tβ(s)ds), σ t 2 = 1 − α t 2 \sigma_t^2 = 1 - \alpha_t^2 σt2=1−αt2。
则 PF-ODE 可改写为:
d x t d t = − α ˙ t α t x t + α ˙ t σ t α t D ( x t , t ) \frac{d\mathbf{x}_t}{dt} = -\frac{\dot{\alpha}_t}{\alpha_t} \mathbf{x}_t + \frac{\dot{\alpha}_t \sigma_t}{\alpha_t} \mathbf{D}(\mathbf{x}_t, t) dtdxt=−αtα˙txt+αtα˙tσtD(xt,t)
2.3 PF-ODE 轨迹的性质
定理 :对于 VP 扩散的 PF-ODE,轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T 具有以下性质:
- 唯一性 :给定 x 0 \mathbf{x}_0 x0,轨迹唯一确定
- 连续性 : x t \mathbf{x}_t xt 关于 t t t 连续
- 边界行为 : x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata, x T ≈ N ( 0 , I ) \mathbf{x}_T \approx \mathcal{N}(0, \mathbf{I}) xT≈N(0,I)
直觉 :PF-ODE 将数据分布 p data p_{\text{data}} pdata 连续地"变形"为高斯噪声 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)。生成过程就是从噪声沿轨迹"走回去"。
3. 一致性模型的核心定义
3.1 一致性函数
定义 (一致性函数 f θ f_\theta fθ):对于 PF-ODE 的任意轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,一致性函数满足:
f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ t , t ′ ∈ 0 , T f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t'}, t'), \quad \forall t, t' \in 0, T fθ(xt,t)=fθ(xt′,t′),∀t,t′∈0,T
即轨迹上的所有点映射到同一个值 ------轨迹的起点 x 0 \mathbf{x}0 x0(或等价地, x ϵ \mathbf{x}\epsilon xϵ 以避免数值问题)。
边界条件:
f θ ( x ϵ , ϵ ) = x ϵ f_\theta(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon fθ(xϵ,ϵ)=xϵ
其中 ϵ > 0 \epsilon > 0 ϵ>0 是一个很小的常数(如 ϵ = 0.002 \epsilon = 0.002 ϵ=0.002),用于避免 t = 0 t = 0 t=0 处的数值不稳定性。
3.2 自洽性(Self-Consistency)
一致性模型的核心约束是自洽性:
f θ ( x t , t ) = f θ ( x t + δ , t + δ ) , ∀ δ > 0 f_\theta(\mathbf{x}t, t) = f\theta(\mathbf{x}_{t+\delta}, t + \delta), \quad \forall \delta > 0 fθ(xt,t)=fθ(xt+δ,t+δ),∀δ>0
这意味着:沿 ODE 轨迹前进任意步,一致性函数的输出不变。
几何直觉:一致性函数将整条 ODE 轨迹"压缩"为一个点(起点)。不同轨迹映射到不同的点,但同一轨迹上的所有点映射到同一个点。
3.3 生成过程
利用一致性函数,生成过程极其简单:
单步生成:
x 0 = f θ ( x T , T ) , x T ∼ N ( 0 , I ) \mathbf{x}0 = f\theta(\mathbf{x}_T, T), \quad \mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) x0=fθ(xT,T),xT∼N(0,I)
一次前向传播,从噪声直接生成数据!
多步生成(迭代精炼):
x ϵ = f θ ( x T , T ) \mathbf{x}\epsilon = f\theta(\mathbf{x}_T, T) xϵ=fθ(xT,T)
x t n + 1 = f θ ( x t n , t n ) + t n 2 − ϵ 2 ⋅ z , z ∼ N ( 0 , I ) \mathbf{x}{t{n+1}} = f_\theta(\mathbf{x}_{t_n}, t_n) + \sqrt{t_n^2 - \epsilon^2} \cdot \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) xtn+1=fθ(xtn,tn)+tn2−ϵ2 ⋅z,z∼N(0,I)
多步生成通过在轨迹上添加少量噪声再映射,实现逐步精炼。
4. 训练方法 --- 一致性蒸馏
4.1 一致性蒸馏(Consistency Distillation)
核心思想:利用预训练的扩散模型(教师)生成轨迹上的相邻点对,训练一致性模型(学生)使它们映射到相同值。
训练数据 :对于每个数据点 x 0 ∼ p data \mathbf{x}0 \sim p{\text{data}} x0∼pdata,通过 PF-ODE 求解器获得相邻点对 ( x t n + 1 , x t n ) (\mathbf{x}{t{n+1}}, \mathbf{x}_{t_n}) (xtn+1,xtn)。
损失函数:
L CD ( θ , θ − ) = E d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) \mathcal{L}_{\text{CD}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_{t_{n+1}}, t_{n+1}), \\, f_{\\theta\^-}(\\mathbf{x}_{t_n}, t_n) \\right) \\right LCD(θ,θ−)=Ed(fθ(xtn+1,tn+1),fθ−(xtn,tn))
其中:
- f θ f_\theta fθ 是在线网络(正在更新)
- f θ − f_{\theta^-} fθ− 是目标网络(EMA 更新)
- d ( ⋅ , ⋅ ) d(\cdot, \cdot) d(⋅,⋅) 是距离度量(如 L 2 L_2 L2、LPIPS)
- x t n \mathbf{x}{t_n} xtn 由 x t n + 1 \mathbf{x}{t_{n+1}} xtn+1 通过一步 PF-ODE 求解器获得
目标网络更新:
θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ
其中 μ \mu μ 是 EMA 衰减率(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)。
4.2 数学正确性
定理 :当 f θ f_\theta fθ 满足自洽性时, L CD = 0 \mathcal{L}_{\text{CD}} = 0 LCD=0。
证明:
由自洽性:
f θ ( x t n + 1 , t n + 1 ) = f θ ( x t n , t n ) f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}) = f_\theta(\mathbf{x}_{t_n}, t_n) fθ(xtn+1,tn+1)=fθ(xtn,tn)
当 θ − = θ \theta^- = \theta θ−=θ 时(EMA 已收敛):
d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x t n , t n ) ) = d ( f θ ( x t n + 1 , t n + 1 ) , f θ ( x t n , t n ) ) = 0 d(f_\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_{\theta^-}(\mathbf{x}{t_n}, t_n)) = d(f\theta(\mathbf{x}{t{n+1}}, t_{n+1}), f_\theta(\mathbf{x}_{t_n}, t_n)) = 0 d(fθ(xtn+1,tn+1),fθ−(xtn,tn))=d(fθ(xtn+1,tn+1),fθ(xtn,tn))=0
■ \blacksquare ■
4.3 一步 PF-ODE 求解
为了获得相邻点对,需要一步 PF-ODE 求解。使用 DDIM 求解器:
x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)
其中 D ϕ \mathbf{D}_\phi Dϕ 是预训练的去噪模型(教师), α t , σ t \alpha_t, \sigma_t αt,σt 是噪声调度参数。
5. 完整可运行实现
5.1 一致性模型核心实现
python
"""
一致性模型 (Consistency Models) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple, Optional, List
from dataclasses import dataclass
@dataclass
class ConsistencyConfig:
"""一致性模型配置"""
data_dim: int = 2
hidden_dim: int = 256
time_dim: int = 64
num_layers: int = 6
sigma_min: float = 0.002
sigma_max: float = 80.0
rho: float = 7.0
num_timesteps: int = 40 # 时间步离散化数量
ema_decay: float = 0.9999
def get_sigmas_karras(
sigma_min: float, sigma_max: float, rho: float, num_steps: int
) -> torch.Tensor:
"""Karras 噪声调度 (Karras et al., 2022)"""
inv_rho = 1.0 / rho
steps = torch.arange(num_steps, dtype=torch.float64) / (num_steps - 1)
sigmas = (sigma_max ** inv_rho + steps * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
return sigmas.float()
def get_alpha_sigma(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""VP 扩散的 alpha 和 sigma 参数"""
alpha = torch.cos(t * math.pi / 2)
sigma = torch.sin(t * math.pi / 2)
return alpha, sigma
class SinusoidalTimeEmbedding(nn.Module):
"""正弦时间嵌入"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t.unsqueeze(-1) * emb.unsqueeze(0)
return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
class ConsistencyModel(nn.Module):
"""一致性模型网络"""
def __init__(self, config: ConsistencyConfig):
super().__init__()
self.config = config
self.time_embed = SinusoidalTimeEmbedding(config.time_dim)
self.input_proj = nn.Linear(config.data_dim + config.time_dim, config.hidden_dim)
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(config.hidden_dim),
nn.SiLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
nn.LayerNorm(config.hidden_dim),
nn.SiLU(),
nn.Linear(config.hidden_dim, config.hidden_dim),
)
for _ in range(config.num_layers)
])
self.output_proj = nn.Linear(config.hidden_dim, config.data_dim)
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
x: (B, D) 带噪数据
t: (B,) 时间步 (0=数据, 1=噪声)
"""
t_emb = self.time_embed(t)
h = torch.cat([x, t_emb], dim=-1)
h = self.input_proj(h)
for block in self.blocks:
h = h + block(h)
return self.output_proj(h)
5.2 一致性蒸馏训练
python
class ConsistencyDistillation:
"""一致性蒸馏训练器"""
def __init__(
self,
model: ConsistencyModel,
teacher_model: nn.Module,
config: ConsistencyConfig,
device: torch.device,
):
self.model = model
self.teacher = teacher_model
self.config = config
self.device = device
# 创建目标网络 (EMA)
self.target_model = ConsistencyModel(config).to(device)
self.target_model.load_state_dict(model.state_dict())
# 冻结教师模型
for param in self.teacher.parameters():
param.requires_grad = False
# 噪声调度
self.sigmas = get_sigmas_karras(
config.sigma_min, config.sigma_max, config.rho, config.num_timesteps
).to(device)
self.optimizer = torch.optim.AdamW(
model.parameters(), lr=1e-4, weight_decay=0.0
)
def add_noise(
self, x: torch.Tensor, t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""添加噪声: x_t = alpha_t * x + sigma_t * noise"""
alpha, sigma = get_alpha_sigma(t)
noise = torch.randn_like(x)
x_t = alpha.unsqueeze(-1) * x + sigma.unsqueeze(-1) * noise
return x_t, noise
def one_step_denoise(
self, x_t: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor
) -> torch.Tensor:
"""使用教师模型进行一步去噪 (DDIM 求解器)"""
with torch.no_grad():
# 教师模型预测去噪结果
x_denoised = self.teacher(x_t, t)
# DDIM 一步更新
alpha_t, sigma_t = get_alpha_sigma(t)
alpha_prev, sigma_prev = get_alpha_sigma(t_prev)
# x_{t_prev} = alpha_{t_prev} * x_denoised + sigma_{t_prev} * (x_t - alpha_t * x_denoised) / sigma_t
x_prev = (
alpha_prev.unsqueeze(-1) * x_denoised
+ sigma_prev.unsqueeze(-1) * (x_t - alpha_t.unsqueeze(-1) * x_denoised) / sigma_t.unsqueeze(-1)
)
return x_prev
def compute_loss(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, dict]:
"""计算一致性蒸馏损失"""
B = x.shape[0]
# 随机采样时间步对 (t_{n+1}, t_n)
# 使用 Karras 调度的离散时间步
n = torch.randint(0, self.config.num_timesteps - 1, (B,), device=self.device)
t = self.sigmas[n] # t_{n+1}
t_prev = self.sigmas[n + 1] # t_n (更接近数据)
# 添加噪声
x_t, noise = self.add_noise(x, t)
# 教师模型一步去噪
x_prev = self.one_step_denoise(x_t, t, t_prev)
# 在线网络预测
pred_online = self.model(x_t, t)
# 目标网络预测
with torch.no_grad():
pred_target = self.target_model(x_prev, t_prev)
# 一致性损失 (L2)
loss = F.mse_loss(pred_online, pred_target)
metrics = {
"loss": loss.item(),
"t_mean": t.mean().item(),
"t_prev_mean": t_prev.mean().item(),
}
return loss, metrics
def train_step(self, x: torch.Tensor) -> dict:
"""执行一步训练"""
self.model.train()
self.optimizer.zero_grad()
loss, metrics = self.compute_loss(x)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
# EMA 更新目标网络
with torch.no_grad():
for param, target_param in zip(
self.model.parameters(), self.target_model.parameters()
):
target_param.data.mul_(self.config.ema_decay).add_(
param.data, alpha=1 - self.config.ema_decay
)
return metrics
@torch.no_grad()
def sample(
self, num_samples: int, num_steps: int = 1
) -> torch.Tensor:
"""生成样本"""
self.model.eval()
# 从噪声开始
x = torch.randn(num_samples, self.config.data_dim, device=self.device)
if num_steps == 1:
# 单步生成
t = torch.ones(num_samples, device=self.device)
x = self.model(x, t)
else:
# 多步生成
timesteps = get_sigmas_karras(
self.config.sigma_min, self.config.sigma_max,
self.config.rho, num_steps
).to(self.device)
for i in range(len(timesteps) - 1):
t = timesteps[i].expand(num_samples)
x = self.model(x, t)
# 添加少量噪声 (用于迭代精炼)
if i < len(timesteps) - 2:
noise = torch.randn_like(x)
t_next = timesteps[i + 1]
_, sigma = get_alpha_sigma(t_next)
x = x + sigma * noise * 0.5
return x
5.3 实验代码
python
def experiment_consistency_2d():
"""在 2D 双月数据上训练一致性模型"""
from sklearn.datasets import make_moons
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 生成数据
data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
data = (data - data.mean(axis=0)) / data.std(axis=0)
data = torch.tensor(data, dtype=torch.float32).to(device)
config = ConsistencyConfig(data_dim=2, hidden_dim=256, num_layers=6)
# 教师模型 (预训练的去噪模型)
teacher = ConsistencyModel(config).to(device)
# 学生模型 (一致性模型)
student = ConsistencyModel(config).to(device)
# 初始化教师模型 (模拟预训练)
# 实际应用中,这里加载预训练的扩散模型
teacher.load_state_dict(student.state_dict())
trainer = ConsistencyDistillation(student, teacher, config, device)
# 训练
print("一致性蒸馏训练...")
batch_size = 256
losses = []
for step in range(5000):
idx = torch.randint(0, data.shape[0], (batch_size,))
x_batch = data[idx]
metrics = trainer.train_step(x_batch)
losses.append(metrics["loss"])
if (step + 1) % 1000 == 0:
avg_loss = np.mean(losses[-100:])
print(f"Step {step+1} | Loss: {avg_loss:.6f}")
# 生成样本
print("\n生成样本...")
samples_1step = trainer.sample(num_samples=1000, num_steps=1)
samples_5step = trainer.sample(num_samples=1000, num_steps=5)
return trainer, samples_1step.cpu(), samples_5step.cpu()
6. 一致性模型的理论性质
6.1 表达能力定理
定理(Song et al., 2023):一致性模型的表达能力足够强大,可以精确表示任意 PF-ODE 轨迹。
证明思路 :对于任意 PF-ODE 轨迹 { x t } t ∈ 0 , T \{\mathbf{x}t\}{t \in 0,T} {xt}t∈0,T,定义 f ∗ ( x t , t ) = x 0 f^*(\mathbf{x}t, t) = \mathbf{x}0 f∗(xt,t)=x0。这是一个合法的一致性函数,且满足边界条件 f ∗ ( x ϵ , ϵ ) = x ϵ f^*(\mathbf{x}\epsilon, \epsilon) = \mathbf{x}\epsilon f∗(xϵ,ϵ)=xϵ。
当神经网络 f θ f_\theta fθ 的容量足够大时,它可以任意逼近 f ∗ f^* f∗。 ■ \blacksquare ■
6.2 单步生成的误差分析
定理:单步生成的误差上界为:
E ∥ x 0 − f θ ( x T , T ) ∥ 2 ≤ L CD ( θ , θ − ) + O ( Δ t ) \mathbb{E}\left\\\|\\mathbf{x}_0 - f_\\theta(\\mathbf{x}_T, T)\\\|\^2\\right \leq \mathcal{L}_{\text{CD}}(\theta, \theta^-) + O(\Delta t) E∥x0−fθ(xT,T)∥2≤LCD(θ,θ−)+O(Δt)
其中 Δ t \Delta t Δt 是时间离散化的步长。
直觉 :一致性蒸馏损失越小,单步生成的质量越高。时间步离散化越细( Δ t \Delta t Δt 越小),误差越小。
6.3 与扩散模型的关系
| 特性 | 扩散模型 | 一致性模型 |
|---|---|---|
| 生成步数 | 10-1000 步 | 1-5 步 |
| 训练目标 | 去噪 | 自洽性 |
| 采样方式 | 迭代去噪 | 直接映射 |
| 理论基础 | SDE/PF-ODE | PF-ODE 轨迹 |
| 质量-速度权衡 | 高质量但慢 | 快但略低质量 |
7. 一致性模型数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性模型 (Consistency Models) 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 概率流 ODE: ║
║ dx_t/dt = f(x_t, t) - ½g(t)² ∇log p_t(x_t) ║
║ 轨迹: x_0 → x_T (数据 → 噪声) ║
║ ║
║ 2. 一致性函数定义: ║
║ f_θ(x_t, t) = f_θ(x_{t'}, t') ∀t, t' ∈ [0, T] (同一轨迹→同一输出) ║
║ 边界条件: f_θ(x_ε, ε) = x_ε ║
║ ║
║ 3. 单步生成: ║
║ x_0 = f_θ(x_T, T), x_T ~ N(0, I) ║
║ 一次前向传播, 从噪声直接生成数据 ║
║ ║
║ 4. 一致性蒸馏损失: ║
║ L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ] ║
║ θ⁻ = EMA(θ), x_{t_n} 由一步 DDIM 求解器获得 ║
║ ║
║ 5. 正确性: ║
║ f_θ 满足自洽性 ⟹ L_CD = 0 ║
║ ║
║ 6. 误差上界: ║
║ E[‖x₀ - f_θ(x_T, T)‖²] ≤ L_CD + O(Δt) ║
║ ║
║ 7. VP 扩散参数: ║
║ α_t = cos(πt/2), σ_t = sin(πt/2) ║
║ x_t = α_t·x₀ + σ_t·ε, ε ~ N(0, I) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第二篇:一致性蒸馏 --- 从预训练扩散模型学习
1. 引言
一致性蒸馏(Consistency Distillation)是训练一致性模型的主要方法,它利用预训练的扩散模型作为教师,通过蒸馏的方式训练学生模型。本篇深入分析蒸馏过程的数学细节和实现技巧。
2. 蒸馏过程的数学分析
2.1 教师-学生框架
教师模型 D ϕ \mathbf{D}_\phi Dϕ:预训练的去噪模型,满足:
D ϕ ( x t , t ) ≈ E x 0 ∣ x t \mathbf{D}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ(xt,t)≈Ex0∣xt
学生模型 f θ f_\theta fθ:一致性模型,满足自洽性。
蒸馏目标:利用教师模型生成轨迹上的相邻点对,训练学生模型使它们映射到相同值。
2.2 DDIM 求解器
DDIM(Denoising Diffusion Implicit Models, Song et al., 2020)是一种确定性的 PF-ODE 求解器:
x t n = α t n ( x t n + 1 − σ t n + 1 ϵ ^ α t n + 1 ) ⏟ x ^ 0 + σ t n ϵ ^ \mathbf{x}{t{n}} = \alpha_{t_n} \underbrace{\left(\frac{\mathbf{x}{t{n+1}} - \sigma_{t_{n+1}} \hat{\mathbf{\epsilon}}}{\alpha_{t_{n+1}}}\right)}_{\hat{\mathbf{x}}0} + \sigma{t_n} \hat{\mathbf{\epsilon}} xtn=αtnx^0 (αtn+1xtn+1−σtn+1ϵ^)+σtnϵ^
其中 ϵ ^ = ( x t n + 1 − α t n + 1 x ^ 0 ) / σ t n + 1 \hat{\mathbf{\epsilon}} = (\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \hat{\mathbf{x}}0) / \sigma{t_{n+1}} ϵ^=(xtn+1−αtn+1x^0)/σtn+1 是预测的噪声。
等价形式(使用去噪函数):
x t n = α t n D ϕ ( x t n + 1 , t n + 1 ) + σ t n x t n + 1 − α t n + 1 D ϕ ( x t n + 1 , t n + 1 ) σ t n + 1 \mathbf{x}{t_n} = \alpha{t_n} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1}) + \sigma_{t_n} \frac{\mathbf{x}{t{n+1}} - \alpha_{t_{n+1}} \mathbf{D}\phi(\mathbf{x}{t_{n+1}}, t_{n+1})}{\sigma_{t_{n+1}}} xtn=αtnDϕ(xtn+1,tn+1)+σtnσtn+1xtn+1−αtn+1Dϕ(xtn+1,tn+1)
2.3 时间步离散化
Karras 调度(Karras et al., 2022):
t i = ( t max 1 / ρ + i N − 1 ( t min 1 / ρ − t max 1 / ρ ) ) ρ t_i = \left(t_{\max}^{1/\rho} + \frac{i}{N-1} (t_{\min}^{1/\rho} - t_{\max}^{1/\rho})\right)^\rho ti=(tmax1/ρ+N−1i(tmin1/ρ−tmax1/ρ))ρ
其中 ρ = 7 \rho = 7 ρ=7 控制时间步的分布(更多步集中在噪声端)。
直觉:噪声端的 ODE 曲率更大,需要更细的时间步离散化。
3. 训练技巧
3.1 EMA 更新
目标网络使用 EMA(Exponential Moving Average)更新:
θ − ← μ θ − + ( 1 − μ ) θ \theta^- \leftarrow \mu \theta^- + (1 - \mu) \theta θ−←μθ−+(1−μ)θ
为什么需要 EMA:如果直接使用在线网络作为目标,训练会不稳定------两个相同的网络相互"追逐",导致发散。EMA 提供了稳定的目标。
3.2 距离度量
L2 距离:
d ( x , y ) = ∥ x − y ∥ 2 2 d(\mathbf{x}, \mathbf{y}) = \|\mathbf{x} - \mathbf{y}\|_2^2 d(x,y)=∥x−y∥22
LPIPS 距离(用于图像):
d LPIPS ( x , y ) = ∑ l ∥ feat l ( x ) − feat l ( y ) ∥ 2 2 d_{\text{LPIPS}}(\mathbf{x}, \mathbf{y}) = \sum_l \|\text{feat}_l(\mathbf{x}) - \text{feat}_l(\mathbf{y})\|_2^2 dLPIPS(x,y)=l∑∥featl(x)−featl(y)∥22
LPIPS 使用预训练网络的特征距离,更符合人类感知。
3.3 梯度裁剪
一致性蒸馏的梯度可能很大,需要梯度裁剪:
g ← g ⋅ min ( 1 , c ∥ g ∥ ) \mathbf{g} \leftarrow \mathbf{g} \cdot \min\left(1, \frac{c}{\|\mathbf{g}\|}\right) g←g⋅min(1,∥g∥c)
其中 c c c 是裁剪阈值(如 c = 1.0 c = 1.0 c=1.0)。
4. 一致性蒸馏的收敛性
4.1 收敛定理
定理(非正式):在适当条件下,一致性蒸馏收敛到教师模型的 PF-ODE 轨迹。
条件:
- 教师模型足够好( D ϕ ≈ E x 0 ∣ x t \mathbf{D}_\phi \approx \mathbb{E}\\mathbf{x}_0 \| \\mathbf{x}_t Dϕ≈Ex0∣xt)
- 时间步离散化足够细( Δ t → 0 \Delta t \to 0 Δt→0)
- 网络容量足够大
- EMA 衰减率 μ \mu μ 适当(如 μ = 0.9999 \mu = 0.9999 μ=0.9999)
4.2 蒸馏误差的分解
总蒸馏误差可分解为:
Total Error = Approximation Error ⏟ 网络容量不足 + Discretization Error ⏟ 时间步离散化 + Optimization Error ⏟ 训练不充分 \text{Total Error} = \underbrace{\text{Approximation Error}}{\text{网络容量不足}} + \underbrace{\text{Discretization Error}}{\text{时间步离散化}} + \underbrace{\text{Optimization Error}}_{\text{训练不充分}} Total Error=网络容量不足 Approximation Error+时间步离散化 Discretization Error+训练不充分 Optimization Error
- 逼近误差 : O ( 1 / n ) O(1/\sqrt{n}) O(1/n ), n n n 是网络参数量
- 离散化误差 : O ( Δ t ) O(\Delta t) O(Δt), Δ t \Delta t Δt 是时间步间隔
- 优化误差:随训练步数减小
5. 一致性蒸馏数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性蒸馏 (Consistency Distillation) 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. DDIM 求解器: ║
║ x_{t_n} = α_{t_n}·D_φ(x_{t_{n+1}}, t_{n+1}) ║
║ + σ_{t_n}·(x_{t_{n+1}} - α_{t_{n+1}}·D_φ) / σ_{t_{n+1}} ║
║ ║
║ 2. 蒸馏损失: ║
║ L_CD = E[ d(f_θ(x_{t_{n+1}}, t_{n+1}), f_{θ⁻}(x_{t_n}, t_n)) ] ║
║ x_{t_n} 由一步 DDIM 获得, θ⁻ = EMA(θ) ║
║ ║
║ 3. EMA 更新: ║
║ θ⁻ ← μ·θ⁻ + (1-μ)·θ, μ = 0.9999 ║
║ 提供稳定的训练目标 ║
║ ║
║ 4. 时间步调度 (Karras): ║
║ t_i = (t_max^{1/ρ} + i/(N-1)·(t_min^{1/ρ} - t_max^{1/ρ}))^ρ ║
║ ρ = 7, 噪声端步长更细 ║
║ ║
║ 5. 误差分解: ║
║ Total = Approximation + Discretization + Optimization ║
║ = O(1/√n) + O(Δt) + O(1/√T) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第三篇:一致性模型的前沿发展
1. 引言
一致性模型自 2023 年提出以来,已经发展出多个重要变体和应用方向。
2. 一致性训练(Consistency Training)
2.1 无需教师的训练
一致性蒸馏需要预训练的教师模型。一致性训练(Consistency Training) 直接从数据训练,无需教师。
核心思想:利用 SDE 的随机性生成相邻点对。
前向 SDE:
d x t = f ( x t , t ) d t + g ( t ) d w t d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t, t) dt + g(t) d\mathbf{w}_t dxt=f(xt,t)dt+g(t)dwt
对于同一个 x 0 \mathbf{x}_0 x0,两次独立的 SDE 采样得到 x t \mathbf{x}t xt 和 x t ′ \mathbf{x}{t'} xt′( t ≈ t ′ t \approx t' t≈t′),它们在同一条轨迹附近。
一致性训练损失:
L CT ( θ , θ − ) = E d ( f θ ( x t , t ) , f θ − ( x t ′ , t ′ ) ) \mathcal{L}_{\text{CT}}(\theta, \theta^-) = \mathbb{E}\left d\\left( f_\\theta(\\mathbf{x}_t, t), \\, f_{\\theta\^-}(\\mathbf{x}_{t'}, t') \\right) \\right LCT(θ,θ−)=Ed(fθ(xt,t),fθ−(xt′,t′))
2.2 数学挑战
一致性训练的理论保证弱于一致性蒸馏,因为 SDE 轨迹不完全确定------两次独立采样的点可能不在同一条 PF-ODE 轨迹上。
缓解策略:
- 使用较小的时间步间隔 ∣ t − t ′ ∣ |t - t'| ∣t−t′∣
- 使用较大的 EMA 衰减率
- 使用更稳定的距离度量
3. 进阶变体
3.1 渐进蒸馏(Progressive Distillation)
思想:逐步减少采样步数,每轮将步数减半。
流程:
- 训练 N N N 步的扩散模型
- 蒸馏为 N / 2 N/2 N/2 步
- 蒸馏为 N / 4 N/4 N/4 步
- ...直到 1 步
优势:每轮蒸馏的难度更低,训练更稳定。
3.2 一致性模型 + Latent Diffusion
将一致性模型应用于 Latent Diffusion(如 Stable Diffusion):
z 0 = f θ ( z T , T ) , x 0 = Decoder ( z 0 ) \mathbf{z}0 = f\theta(\mathbf{z}_T, T), \quad \mathbf{x}_0 = \text{Decoder}(\mathbf{z}_0) z0=fθ(zT,T),x0=Decoder(z0)
优势:在低维潜在空间中操作,计算效率更高。
3.3 一致性模型 + Classifier-Free Guidance
将 CFG 应用于一致性模型:
f θ cfg ( x t , t , c ) = ( 1 + w ) f θ ( x t , t , c ) − w f θ ( x t , t , ∅ ) f_\theta^{\text{cfg}}(\mathbf{x}t, t, c) = (1 + w) f\theta(\mathbf{x}t, t, c) - w f\theta(\mathbf{x}_t, t, \varnothing) fθcfg(xt,t,c)=(1+w)fθ(xt,t,c)−wfθ(xt,t,∅)
其中 w w w 是引导强度, c c c 是条件(如文本), ∅ \varnothing ∅ 是空条件。
4. 一致性模型与其他方法的对比
| 方法 | 采样步数 | 训练方式 | 质量 | 速度 |
|---|---|---|---|---|
| DDPM | 1000 | 去噪 | 最高 | 最慢 |
| DDIM | 10-50 | 去噪 | 高 | 中等 |
| Flow Matching | 10-50 | 速度场 | 高 | 中等 |
| 一致性蒸馏 | 1-5 | 蒸馏 | 中高 | 快 |
| 一致性训练 | 1-5 | 直接训练 | 中 | 最快 |
5. 前沿研究方向
5.1 音频一致性模型
将一致性模型应用于音频生成(AudioLDM、MusicGen),实现单步音频合成。
5.2 视频一致性模型
将一致性模型应用于视频生成,利用时间维度的一致性。
5.3 3D 一致性模型
将一致性模型应用于 3D 生成(NeRF、3D Gaussian Splatting),实现单步 3D 重建。
5.4 一致性模型的理论深化
- 最优传输视角:将一致性模型与最优传输理论联系
- 信息论分析:分析一致性模型的信息压缩率
- 收敛速率:改进一致性蒸馏的收敛速率分析
6. 前沿发展数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ 一致性模型前沿发展 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 一致性训练 (无教师): ║
║ L_CT = E[ d(f_θ(x_t, t), f_{θ⁻}(x_{t'}, t')) ] ║
║ x_t, x_{t'} 由同一 x₀ 的两次独立 SDE 采样获得 ║
║ ║
║ 2. 渐进蒸馏: ║
║ N 步 → N/2 步 → N/4 步 → ... → 1 步 ║
║ 每轮蒸馏难度更低, 训练更稳定 ║
║ ║
║ 3. CFG 引导: ║
║ f_θ^{cfg}(x_t, t, c) = (1+w)·f_θ(x_t, t, c) - w·f_θ(x_t, t, ∅) ║
║ ║
║ 4. Latent 一致性: ║
║ z₀ = f_θ(z_T, T), x₀ = Decoder(z₀) ║
║ 在低维潜在空间操作, 计算效率更高 ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
参考文献
核心论文
- Song, Y., Dhariwal, P., Chen, M., & Sutskever, I. (2023). Consistency Models. ICML 2023.
- Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021.
- Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
扩散模型基础
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
- Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS 2022.
加速采样
- Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., & Zhu, J. (2022). DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. NeurIPS 2022.
- Salimans, T., & Ho, J. (2022). Progressive Distillation for Fast Sampling of Diffusion Models. ICLR 2022.
应用
- Rombach, R., Blattmann, A., Lorenz, D., Esser, P., & Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022.
- Saharia, C., Chan, W., et al. (2022). Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. NeurIPS 2022.