目录
- [第一篇:Mamba --- 选择性状态空间模型](#第一篇:Mamba — 选择性状态空间模型)
- [第二篇:Flow Matching 与 Rectified Flows](#第二篇:Flow Matching 与 Rectified Flows)
- [第三篇:GRPO --- 群体相对策略优化](#第三篇:GRPO — 群体相对策略优化)
- 参考文献
第一篇:Mamba --- 选择性状态空间模型
1. 引言
Transformer 架构自 2017 年问世以来统治了深度学习领域,但其核心的自注意力机制存在 O ( L 2 ) O(L^2) O(L2) 的计算复杂度和 O ( L ) O(L) O(L) 的显存占用( L L L 为序列长度),在处理长序列时成为瓶颈。
Mamba (Gu & Dao, 2023)提出了一种全新的序列建模范式------选择性状态空间模型(Selective State Space Model) ,通过将经典控制论中的状态空间模型与数据依赖的选择机制相结合,实现了 O ( L ) O(L) O(L) 的线性复杂度,同时在语言、基因组学、音频等模态上达到了与 Transformer 相当甚至更优的性能。
2. 理论基础 --- 连续时间状态空间模型
2.1 线性时不变系统(LTI System)
状态空间模型(SSM)源自控制论,描述一个 N N N 维隐状态 h ( t ) ∈ R N \mathbf{h}(t) \in \mathbb{R}^N h(t)∈RN 如何随时间演化。连续时间形式为:
d h ( t ) d t = A h ( t ) + B x ( t ) \frac{d\mathbf{h}(t)}{dt} = \mathbf{A}\mathbf{h}(t) + \mathbf{B}x(t) dtdh(t)=Ah(t)+Bx(t)
y ( t ) = C h ( t ) + D ⋅ x ( t ) y(t) = \mathbf{C}\mathbf{h}(t) + D \cdot x(t) y(t)=Ch(t)+D⋅x(t)
其中:
| 符号 | 维度 | 含义 |
|---|---|---|
| A \mathbf{A} A | R N × N \mathbb{R}^{N \times N} RN×N | 状态转移矩阵,控制隐状态的演化动力学 |
| B \mathbf{B} B | R N × 1 \mathbb{R}^{N \times 1} RN×1 | 输入投影矩阵,决定输入如何注入状态 |
| C \mathbf{C} C | R 1 × N \mathbb{R}^{1 \times N} R1×N | 输出投影矩阵,决定如何从状态读取输出 |
| D D D | R \mathbb{R} R | 直通项(通常忽略或设为残差连接) |
| x ( t ) x(t) x(t) | R \mathbb{R} R | 标量输入信号 |
| y ( t ) y(t) y(t) | R \mathbb{R} R | 标量输出信号 |
物理直觉 :想象一个弹簧-阻尼系统, A \mathbf{A} A 决定了系统的固有动力学(如弹簧常数、阻尼系数), B \mathbf{B} B 是外部施加的力, C \mathbf{C} C 是我们观察系统的方式。
2.2 离散化:从连续到离散
为了在数字计算机上处理离散序列 x = ( x 1 , x 2 , ... , x L ) \mathbf{x} = (x_1, x_2, \ldots, x_L) x=(x1,x2,...,xL),我们需要对连续系统进行零阶保持(Zero-Order Hold, ZOH) 离散化。
给定步长 Δ > 0 \Delta > 0 Δ>0(可理解为采样间隔),ZOH 离散化假设输入在一个采样间隔内保持不变: x ( t ) = x k , t ∈ [ k Δ , ( k + 1 ) Δ ) x(t) = x_k, \quad t \in [k\Delta, (k+1)\Delta) x(t)=xk,t∈[kΔ,(k+1)Δ)。
离散化后的系统变为:
h k = A ˉ h k − 1 + B ˉ x k \mathbf{h}k = \bar{\mathbf{A}} \mathbf{h}{k-1} + \bar{\mathbf{B}} x_k hk=Aˉhk−1+Bˉxk
y k = C h k y_k = \mathbf{C} \mathbf{h}_k yk=Chk
其中离散化参数为:
A ˉ = exp ( Δ A ) \bar{\mathbf{A}} = \exp(\Delta \mathbf{A}) Aˉ=exp(ΔA)
B ˉ = ( Δ A ) − 1 ( exp ( Δ A ) − I ) ⋅ Δ B ≈ Δ B \bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B} \approx \Delta \mathbf{B} Bˉ=(ΔA)−1(exp(ΔA)−I)⋅ΔB≈ΔB
A ˉ \bar{\mathbf{A}} Aˉ 的矩阵指数解释 : exp ( Δ A ) = ∑ k = 0 ∞ ( Δ A ) k k ! \exp(\Delta \mathbf{A}) = \sum_{k=0}^{\infty} \frac{(\Delta \mathbf{A})^k}{k!} exp(ΔA)=∑k=0∞k!(ΔA)k 是状态转移矩阵的自然推广。它精确地描述了在 Δ \Delta Δ 时间内状态如何从 h k − 1 \mathbf{h}_{k-1} hk−1 演化到 h k \mathbf{h}_k hk。
2.3 全局卷积视角
离散 SSM 的一个关键洞察是:线性递推等价于全局卷积。
展开递推关系:
y k = C A ˉ k B ˉ x 0 + C A ˉ k − 1 B ˉ x 1 + ⋯ + C B ˉ x k y_k = \mathbf{C}\bar{\mathbf{A}}^k \bar{\mathbf{B}} x_0 + \mathbf{C}\bar{\mathbf{A}}^{k-1} \bar{\mathbf{B}} x_1 + \cdots + \mathbf{C}\bar{\mathbf{B}} x_k yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CBˉxk
定义卷积核:
K ˉ = ( C B ˉ , C A ˉ B ˉ , C A ˉ 2 B ˉ , ... , C A ˉ L − 1 B ˉ ) ∈ R L \bar{\mathbf{K}} = (\mathbf{C}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}, \ldots, \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}}) \in \mathbb{R}^L Kˉ=(CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1Bˉ)∈RL
则输出可表示为:
y = K ˉ ∗ x \mathbf{y} = \bar{\mathbf{K}} * \mathbf{x} y=Kˉ∗x
其中 ∗ * ∗ 为卷积操作。通过 FFT,卷积可在 O ( L log L ) O(L \log L) O(LlogL) 时间内计算。
2.4 HiPPO 初始化:解决长程依赖
A \mathbf{A} A 矩阵的初始化至关重要。HiPPO(High-order Polynomial Projection Operators) 框架(Gu et al., 2020)提供了理论上最优的初始化方式。
HiPPO 矩阵的数学定义:
A n k = − { ( 2 n + 1 ) 1 / 2 ( 2 k + 1 ) 1 / 2 if n > k n + 1 if n = k 0 if n < k \mathbf{A}_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} Ank=−⎩ ⎨ ⎧(2n+1)1/2(2k+1)1/2n+10if n>kif n=kif n<k
直觉解释 :HiPPO 矩阵将输入信号在线上投影到正交多项式基上,使得隐状态 h ( t ) \mathbf{h}(t) h(t) 始终保存了过去输入的最优低秩近似。这赋予了 SSM 理论上的长程记忆能力。
2.5 固定参数 SSM 的根本局限
在 S4 等经典 SSM 中, A , B , C \mathbf{A}, \mathbf{B}, \mathbf{C} A,B,C 对所有输入都是固定的(LTI 性质)。这带来了两个根本问题:
问题 1:无法进行内容感知推理
固定参数的 SSM 本质上是一个时不变滤波器。它无法根据输入内容选择性地记住或遗忘信息。例如:
输入: "The capital of France is Paris. The capital of Germany is ..."
一个内容无关的 SSM 对所有 token 施加相同的衰减模式,无法根据语义选择性地保留 "France → Paris" 的关联以便回答后续问题。
问题 2:无法实现选择性机制
在语言模型中,关键能力之一是根据上下文选择性地处理信息------有时需要长程依赖(指代消解),有时需要短程关注(局部语法)。固定 SSM 的卷积核 K ˉ \bar{\mathbf{K}} Kˉ 是预计算的,无法适应不同的输入。
3. Mamba 的核心创新 --- 选择性机制
3.1 从固定到动态:输入依赖参数
Mamba 的核心思想极其简洁:让 SSM 的参数依赖于输入。
具体地,Mamba 将离散 SSM 的参数 B , C , Δ \mathbf{B}, \mathbf{C}, \Delta B,C,Δ 从固定常数变为输入的函数:
B k = Linear B ( x k ) ∈ R N \mathbf{B}_k = \text{Linear}_B(x_k) \in \mathbb{R}^N Bk=LinearB(xk)∈RN
C k = Linear C ( x k ) ∈ R N \mathbf{C}_k = \text{Linear}_C(x_k) \in \mathbb{R}^N Ck=LinearC(xk)∈RN
Δ k = softplus ( Linear Δ ( x k ) ) ∈ R > 0 \Delta_k = \text{softplus}(\text{Linear}\Delta(x_k)) \in \mathbb{R}{>0} Δk=softplus(LinearΔ(xk))∈R>0
其中 softplus ( z ) = ln ( 1 + e z ) \text{softplus}(z) = \ln(1 + e^z) softplus(z)=ln(1+ez) 确保 Δ k > 0 \Delta_k > 0 Δk>0。
关键变化:
- B k \mathbf{B}_k Bk 现在是 N N N 维向量(而非标量),每个状态维度有独立的输入投影
- C k \mathbf{C}_k Ck 同样是 N N N 维向量
- Δ k \Delta_k Δk 是标量,控制对当前输入的"关注程度"
3.2 选择机制的数学分析
Δ k \Delta_k Δk 的作用可以精确分析。离散化后:
A ˉ k = exp ( Δ k A ) \bar{\mathbf{A}}_k = \exp(\Delta_k \mathbf{A}) Aˉk=exp(ΔkA)
B ˉ k = Δ k B k \bar{\mathbf{B}}_k = \Delta_k \mathbf{B}_k Bˉk=ΔkBk
考虑 A \mathbf{A} A 为对角矩阵 A = diag ( a 1 , a 2 , ... , a N ) \mathbf{A} = \text{diag}(a_1, a_2, \ldots, a_N) A=diag(a1,a2,...,aN)(Mamba 的实现选择),则:
A ˉ k , n n = exp ( Δ k ⋅ a n ) \bar{A}_{k,nn} = \exp(\Delta_k \cdot a_n) Aˉk,nn=exp(Δk⋅an)
当 Δ k \Delta_k Δk 很大时:
- A ˉ k → 0 \bar{\mathbf{A}}_k \to \mathbf{0} Aˉk→0(状态被重置)
- B ˉ k \bar{\mathbf{B}}_k Bˉk 很大(输入被强烈注入)
- 效果 :模型选择性地忽略 历史状态,聚焦于当前输入
当 Δ k \Delta_k Δk 很小时:
- A ˉ k → I \bar{\mathbf{A}}_k \to \mathbf{I} Aˉk→I(状态几乎不变)
- B ˉ k → 0 \bar{\mathbf{B}}_k \to \mathbf{0} Bˉk→0(输入几乎不注入)
- 效果 :模型选择性地保留 历史状态,忽略当前输入
这正是选择机制的本质 : Δ k \Delta_k Δk 作为"门控",动态决定每个时间步是记忆还是遗忘。
3.3 选择性 SSM 的递推公式
综合以上,Mamba 的完整递推公式为:
A ˉ k = exp ( Δ k ⋅ diag ( a 1 , ... , a N ) ) \bar{\mathbf{A}}_k = \exp(\Delta_k \cdot \text{diag}(a_1, \ldots, a_N)) Aˉk=exp(Δk⋅diag(a1,...,aN))
B ˉ k = Δ k ⋅ B k \bar{\mathbf{B}}_k = \Delta_k \cdot \mathbf{B}_k Bˉk=Δk⋅Bk
h k = A ˉ k ⊙ h k − 1 + B ˉ k ⊙ x k \mathbf{h}_k = \bar{\mathbf{A}}k \odot \mathbf{h}{k-1} + \bar{\mathbf{B}}_k \odot x_k hk=Aˉk⊙hk−1+Bˉk⊙xk
y k = C k ⋅ h k y_k = \mathbf{C}_k \cdot \mathbf{h}_k yk=Ck⋅hk
其中 ⊙ \odot ⊙ 表示逐元素乘法(因为 A \mathbf{A} A 是对角的)。
复杂度分析:
| 指标 | 复杂度 |
|---|---|
| 每步计算量 | O ( N ) O(N) O(N)(向量逐元素乘法和点积) |
| 总计算量 | O ( L N ) O(LN) O(LN) |
| 空间复杂度 | O ( N ) O(N) O(N)(只需保存当前状态) |
3.4 为什么选择性破坏了卷积效率
在经典 SSM 中,参数固定 → 卷积核 K ˉ \bar{\mathbf{K}} Kˉ 可预计算 → 可用 FFT 加速。
但在 Mamba 中,参数依赖输入 → 每个时间步的 A ˉ k , B ˉ k \bar{\mathbf{A}}_k, \bar{\mathbf{B}}_k Aˉk,Bˉk 不同 → 无法预计算卷积核 → 卷积加速失效。
这看似是一个严重的效率倒退,但 Mamba 通过硬件感知的并行扫描算法巧妙地解决了这个问题。
4. 硬件感知并行扫描
4.1 问题定义
选择性 SSM 的递推:
h k = A ˉ k ⊙ h k − 1 + B ˉ k ⊙ x k \mathbf{h}_k = \bar{\mathbf{A}}k \odot \mathbf{h}{k-1} + \bar{\mathbf{B}}_k \odot x_k hk=Aˉk⊙hk−1+Bˉk⊙xk
这是一个线性递推(Linear Recurrence),形式为:
h k = a k ⊙ h k − 1 + b k \mathbf{h}_k = \mathbf{a}k \odot \mathbf{h}{k-1} + \mathbf{b}_k hk=ak⊙hk−1+bk
其中 a k = A ˉ k ∈ R N \mathbf{a}_k = \bar{\mathbf{A}}_k \in \mathbb{R}^N ak=Aˉk∈RN, b k = B ˉ k ⊙ x k ∈ R N \mathbf{b}_k = \bar{\mathbf{B}}_k \odot x_k \in \mathbb{R}^N bk=Bˉk⊙xk∈RN。
串行计算需要 O ( L ) O(L) O(L) 步,无法利用 GPU 的大规模并行性。
4.2 并行前缀扫描(Parallel Prefix Scan)
并行扫描是解决线性递推的经典并行算法。定义扫描算子 ∙ \bullet ∙:
( a 1 , b 1 ) ∙ ( a 2 , b 2 ) = ( a 2 ⊙ a 1 , a 2 ⊙ b 1 + b 2 ) (\mathbf{a}_1, \mathbf{b}_1) \bullet (\mathbf{a}_2, \mathbf{b}_2) = (\mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{b}_2) (a1,b1)∙(a2,b2)=(a2⊙a1,a2⊙b1+b2)
关键性质 :该算子满足结合律。
证明 :设三个连续元素 ( a i , b i ) (\mathbf{a}_i, \mathbf{b}_i) (ai,bi),验证结合律:
( a 1 , b 1 ) ∙ ( a 2 , b 2 ) ∙ ( a 3 , b 3 ) (\\mathbf{a}_1, \\mathbf{b}_1) \\bullet (\\mathbf{a}_2, \\mathbf{b}_2) \bullet (\mathbf{a}_3, \mathbf{b}_3) (a1,b1)∙(a2,b2)∙(a3,b3)
= ( a 2 ⊙ a 1 , a 2 ⊙ b 1 + b 2 ) ∙ ( a 3 , b 3 ) = (\mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{b}_2) \bullet (\mathbf{a}_3, \mathbf{b}_3) =(a2⊙a1,a2⊙b1+b2)∙(a3,b3)
= ( a 3 ⊙ a 2 ⊙ a 1 , a 3 ⊙ a 2 ⊙ b 1 + a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a3⊙a2⊙a1,a3⊙a2⊙b1+a3⊙b2+b3)
类似地:
( a 1 , b 1 ) ∙ ( a 2 , b 2 ) ∙ ( a 3 , b 3 ) (\mathbf{a}_1, \mathbf{b}_1) \bullet (\\mathbf{a}_2, \\mathbf{b}_2) \\bullet (\\mathbf{a}_3, \\mathbf{b}_3) (a1,b1)∙(a2,b2)∙(a3,b3)
= ( a 1 , b 1 ) ∙ ( a 3 ⊙ a 2 , a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_1, \mathbf{b}_1) \bullet (\mathbf{a}_3 \odot \mathbf{a}_2, \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a1,b1)∙(a3⊙a2,a3⊙b2+b3)
= ( a 3 ⊙ a 2 ⊙ a 1 , a 3 ⊙ a 2 ⊙ b 1 + a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a3⊙a2⊙a1,a3⊙a2⊙b1+a3⊙b2+b3)
两者相等,结合律成立。 ■ \blacksquare ■
4.3 Blelloch 并行扫描算法
利用结合律,我们可以将 L L L 步串行递推转化为 O ( log L ) O(\log L) O(logL) 深度的并行计算:
阶段 1:上扫(Up-Sweep) --- 自底向上归约
层级 0: (a₁,b₁) (a₂,b₂) (a₃,b₃) (a₄,b₄) (a₅,b₅) (a₆,b₆) (a₇,b₇) (a₈,b₈)
层级 1: ● ● ● ●
层级 2: ● ●
层级 3: ●
每一层将相邻的两个元素通过扫描算子合并,直到根节点。
阶段 2:下扫(Down-Sweep) --- 自顶向下分发
从根节点开始,将累积结果分发到每个叶子节点,使得位置 k k k 的结果为前 k k k 个元素的累积积。
| 指标 | 复杂度 |
|---|---|
| 总工作量 | O ( L ) O(L) O(L)(与串行相同) |
| 并行深度 | O ( log L ) O(\log L) O(logL) |
| 加速比 | O ( L / log L ) O(L / \log L) O(L/logL) |
4.4 硬件感知优化:分块扫描
在实际 GPU 实现中,纯粹的并行扫描面临两个问题:
- 内存带宽瓶颈:全局内存访问延迟高
- GPU 层级结构:线程 → Warp (32) → Block → Grid
Mamba 的硬件感知实现采用分块策略:
序列: [块1: 64元素] [块2: 64元素] [块3: 64元素] ...
───────────── ───────────── ─────────────
内部: 串行扫描 内部: 串行扫描 内部: 串行扫描
(利用 SRAM) (利用 SRAM) (利用 SRAM)
↓ ↓ ↓
块间: 并行扫描 (通过 HBM 传递状态)
- 块内:在 SRAM(共享内存)中执行串行扫描,避免 HBM 访问
- 块间:执行粗粒度并行扫描,仅传递块边界状态
- IO 复杂度 : O ( L N / B ) O(LN/B) O(LN/B),其中 B B B 为 SRAM 容量
5. Mamba 架构设计
5.1 整体架构
Mamba 采用了简化的架构,没有使用传统的注意力机制:
输入 x ∈ ℝ^(B×L×D)
↓
[线性投影] → x' ∈ ℝ^(B×L×E) (扩展维度 E = 2D 或 4D)
↓
[Conv1D 1×1] → 局部特征提取
↓
[SiLU 激活]
↓
[选择性 SSM] → 核心计算
↓
[门控乘法] ← 从 x' 分出的门控分支
↓
[线性投影] → y ∈ ℝ^(B×L×D)
↓
残差连接 + LayerNorm
5.2 门控机制
Mamba 借鉴了 GLU(Gated Linear Unit)的思想:
MambaBlock ( x ) = Linear ( SiLU ( Conv1D ( Linear ( x ) ) ) ⊙ SSM ( SiLU ( Conv1D ( Linear ( x ) ) ) ) ) \text{MambaBlock}(\mathbf{x}) = \text{Linear}(\text{SiLU}(\text{Conv1D}(\text{Linear}(\mathbf{x}))) \odot \text{SSM}(\text{SiLU}(\text{Conv1D}(\text{Linear}(\mathbf{x}))))) MambaBlock(x)=Linear(SiLU(Conv1D(Linear(x)))⊙SSM(SiLU(Conv1D(Linear(x)))))
其中 SiLU(Sigmoid Linear Unit)定义为:
SiLU ( z ) = z ⋅ σ ( z ) = z 1 + e − z \text{SiLU}(z) = z \cdot \sigma(z) = \frac{z}{1 + e^{-z}} SiLU(z)=z⋅σ(z)=1+e−zz
门控的作用是让模型学习哪些信息需要通过 SSM 处理,哪些可以直通。
5.3 参数化细节
对于维度 D D D 的输入,Mamba block 的参数:
| 参数 | 维度 | 说明 |
|---|---|---|
| W i n \mathbf{W}_{in} Win | R D × ( 2 E ) \mathbb{R}^{D \times (2E)} RD×(2E) | 输入投影(分为两路:SSM 输入 + 门控) |
| Conv1D | 核大小 4, E E E 通道 | 因果卷积 |
| A \mathbf{A} A | R E \mathbb{R}^E RE | 对角矩阵,通过 log \log log 参数化确保负值 |
| W B \mathbf{W}_B WB | R D s s × E \mathbb{R}^{D_{ss} \times E} RDss×E | SSM 状态扩展因子 |
| W C \mathbf{W}_C WC | R D s s × E \mathbb{R}^{D_{ss} \times E} RDss×E | 输出投影 |
| W Δ \mathbf{W}_\Delta WΔ | R E × D \mathbb{R}^{E \times D} RE×D | 投影到标量后 softplus |
| W o u t \mathbf{W}_{out} Wout | R E × D \mathbb{R}^{E \times D} RE×D | 输出投影 |
6. 完整可运行实现
6.1 选择性 SSM 核心
python
"""
Mamba: Selective State Space Model --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Optional, Tuple
def selective_scan(
x: torch.Tensor, # (B, L, D) 输入
delta: torch.Tensor, # (B, L, D) 步长参数
A: torch.Tensor, # (D, N) 状态转移矩阵(对角元素)
B: torch.Tensor, # (B, L, N) 输入投影(输入依赖)
C: torch.Tensor, # (B, L, N) 输出投影(输入依赖)
D: torch.Tensor, # (D,) 跳跃连接参数
dt_rank: int,
) -> torch.Tensor:
"""
选择性状态空间模型的前向传播。
数学公式:
Ā_k = exp(Δ_k · A) (离散化状态转移)
B̄_k = Δ_k · B_k (离散化输入投影)
h_k = Ā_k ⊙ h_{k-1} + B̄_k ⊙ x_k (状态更新)
y_k = C_k · h_k (输出)
y_k += D ⊙ x_k (跳跃连接)
"""
batch_size, seq_len, d_model = x.shape
n = A.shape[1] # 状态维度 N
# 步骤 1: 离散化
delta_A = torch.exp(
delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)
) # (B, L, D, N)
delta_B = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, D, N)
# 步骤 2: 串行递推计算状态
h = torch.zeros(batch_size, d_model, n, device=x.device, dtype=x.dtype)
ys = []
for k in range(seq_len):
h = delta_A[:, k] * h + delta_B[:, k] * x[:, k].unsqueeze(-1)
y_k = (h * C[:, k].unsqueeze(1)).sum(dim=-1) # (B, D)
ys.append(y_k)
y = torch.stack(ys, dim=1) # (B, L, D)
# 步骤 3: 添加跳跃连接
y = y + x * D.unsqueeze(0).unsqueeze(0)
return y
def selective_scan_parallel(
x: torch.Tensor, # (B, L, D)
delta: torch.Tensor, # (B, L, D)
A: torch.Tensor, # (D, N)
B: torch.Tensor, # (B, L, N)
C: torch.Tensor, # (B, L, N)
D: torch.Tensor, # (D,)
) -> torch.Tensor:
"""
并行前缀扫描实现选择性 SSM。
核心思想: 将线性递推 h_k = a_k * h_{k-1} + b_k
转化为前缀积问题,利用结合律并行计算。
"""
batch_size, seq_len, d_model = x.shape
n = A.shape[1]
# 离散化
delta_A = torch.exp(
delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)
)
delta_B = delta.unsqueeze(-1) * B.unsqueeze(2)
a = delta_A
b = delta_B * x.unsqueeze(-1)
# 并行前缀扫描 (Hillis-Steele)
levels = int(np.ceil(np.log2(seq_len)))
padded_len = 2 ** levels
if padded_len > seq_len:
pad_size = padded_len - seq_len
a = F.pad(a, (0, 0, 0, 0, 0, pad_size))
b = F.pad(b, (0, 0, 0, 0, 0, pad_size))
for d in range(levels):
stride = 2 ** d
a_shifted = torch.zeros_like(a)
b_shifted = torch.zeros_like(b)
if stride < padded_len:
a_shifted[:, stride:] = a[:, :-stride]
b_shifted[:, stride:] = b[:, :-stride]
a_new = a * a_shifted
b_new = a * b_shifted + b
mask = torch.arange(padded_len, device=a.device) >= stride
a = torch.where(mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1), a_new, a)
b = torch.where(mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1), b_new, b)
a = a[:, :seq_len]
b = b[:, :seq_len]
y = (b * C.unsqueeze(2)).sum(dim=-1)
y = y + x * D.unsqueeze(0).unsqueeze(0)
return y
6.2 Mamba Block 实现
python
@dataclass
class MambaConfig:
"""Mamba 模型配置"""
d_model: int = 256
d_state: int = 16
d_conv: int = 4
expand: int = 2
dt_rank: str = "auto"
dt_min: float = 0.001
dt_max: float = 0.1
dt_init: str = "random"
dt_scale: float = 1.0
bias: bool = False
conv_bias: bool = True
def __post_init__(self):
self.d_inner = self.d_model * self.expand
if self.dt_rank == "auto":
self.dt_rank = math.ceil(self.d_model / 16)
class SelectiveSSM(nn.Module):
"""选择性状态空间模型层"""
def __init__(self, config: MambaConfig):
super().__init__()
self.d_model = config.d_model
self.d_state = config.d_state
self.d_inner = config.d_inner
self.dt_rank = config.dt_rank
A = torch.arange(1, self.d_state + 1, dtype=torch.float32)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.x_proj = nn.Linear(
self.d_inner,
self.dt_rank + self.d_state * 2,
bias=False
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
if config.dt_init == "constant":
nn.init.constant_(
self.dt_proj.weight,
math.log(config.dt_max - config.dt_min)
)
else:
dt_init_std = self.dt_rank ** -0.5 * config.dt_scale
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
dt = torch.exp(
torch.rand(self.d_inner)
* (math.log(config.dt_max) - math.log(config.dt_min))
+ math.log(config.dt_min)
).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, seq_len, dim = x.shape
xz = self.x_proj(x)
dt, B, C = torch.split(
xz,
[self.dt_rank, self.d_state, self.d_state],
dim=-1
)
dt = self.dt_proj(dt)
dt = F.softplus(dt)
A = -torch.exp(self.A_log)
y = selective_scan(x, dt, A, B, C, self.D, self.dt_rank)
return y
class MambaBlock(nn.Module):
"""完整的 Mamba 块"""
def __init__(self, config: MambaConfig):
super().__init__()
self.config = config
self.in_proj = nn.Linear(
config.d_model,
config.d_inner * 2,
bias=config.bias
)
self.conv1d = nn.Conv1d(
in_channels=config.d_inner,
out_channels=config.d_inner,
kernel_size=config.d_conv,
bias=config.conv_bias,
groups=config.d_inner,
padding=config.d_conv - 1,
)
self.ssm = SelectiveSSM(config)
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
self.norm = nn.LayerNorm(config.d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.norm(x)
xz = self.in_proj(x)
x_ssm, z = xz.chunk(2, dim=-1)
x_ssm = x_ssm.transpose(1, 2)
x_ssm = self.conv1d(x_ssm)[:, :, :x.size(1)]
x_ssm = x_ssm.transpose(1, 2)
x_ssm = F.silu(x_ssm)
y = self.ssm(x_ssm)
y = y * F.silu(z)
y = self.out_proj(y)
return y + residual
6.3 实验代码
python
def demonstrate_parallel_scan():
"""演示并行前缀扫描与串行递推的等价性"""
torch.manual_seed(42)
batch_size = 2
seq_len = 32
d_model = 4
d_state = 8
x = torch.randn(batch_size, seq_len, d_model)
delta = F.softplus(torch.randn(batch_size, seq_len, d_model))
A = -torch.exp(torch.randn(d_model, d_state))
B = torch.randn(batch_size, seq_len, d_state) * 0.1
C = torch.randn(batch_size, seq_len, d_state) * 0.1
D_param = torch.ones(d_model)
y_serial = selective_scan(x, delta, A, B, C, D_param, dt_rank=2)
y_parallel = selective_scan_parallel(x, delta, A, B, C, D_param)
max_diff = (y_serial - y_parallel).abs().max().item()
print(f"串行 vs 并行扫描结果比较:")
print(f" 最大绝对误差: {max_diff:.6e}")
print(f" 结果一致: {max_diff < 1e-4}")
7. Mamba 与 Transformer 的理论对比
7.1 计算复杂度
| 操作 | Transformer | Mamba |
|---|---|---|
| 前向传播 | O ( L 2 D + L D 2 ) O(L^2 D + L D^2) O(L2D+LD2) | O ( L D N ) O(L D N) O(LDN) |
| 推理(自回归) | O ( L D ) O(L D) O(LD) 每步(KV cache) | O ( D N ) O(D N) O(DN) 每步(仅状态) |
| 内存(推理) | O ( L D ) O(L D) O(LD) KV cache | O ( D N ) O(D N) O(DN) 状态 |
| 并行深度 | O ( 1 ) O(1) O(1)(矩阵乘法) | O ( log L ) O(\log L) O(logL)(扫描) |
当 L ≫ N L \gg N L≫N 时(长序列),Mamba 的优势显著。
7.2 表达能力分析
Transformer 的优势:
- 全局注意力可以精确检索任意位置的信息
- In-context learning 能力已被广泛验证
Mamba 的优势:
- 选择机制实现了动态的、内容感知的记忆管理
- 状态压缩避免了冗余信息的存储
- 线性复杂度使其能处理超长序列
理论局限:
- 有限状态 N N N 限制了 Mamba 能同时记住的信息量
- 无法像 Transformer 那样进行精确的"复制"操作(需要 O ( L ) O(L) O(L) 状态)
7.3 选择机制与信息论
从信息论角度,选择机制可以理解为信息瓶颈的一种形式:
I ( h k ; x ≤ k ) ≤ N log 2 ( 1 ϵ ) I(\mathbf{h}k; \mathbf{x}{\leq k}) \leq N \log_2 \left(\frac{1}{\epsilon}\right) I(hk;x≤k)≤Nlog2(ϵ1)
其中 ϵ \epsilon ϵ 是状态精度。有限维度 N N N 的状态只能保留有限的信息,选择机制的作用是最大化保留信息的相关性。
8. 扩展与前沿发展
8.1 Mamba-2:状态空间对偶性
Mamba-2(Dao & Gu, 2024)发现了 SSM 与注意力之间的深层联系:
结构化状态空间对偶(SSD):选择性 SSM 可以等价地表示为一种结构化的半可分矩阵乘法,形式上类似于线性注意力。
y = SSM ( x ) ⟺ y = M x \mathbf{y} = \text{SSM}(\mathbf{x}) \iff \mathbf{y} = \mathbf{M} \mathbf{x} y=SSM(x)⟺y=Mx
其中 M \mathbf{M} M 是一个半可分矩阵:
M i j = { C i ( ∏ k = j + 1 i A ˉ k ) B ˉ j if i ≥ j 0 if i < j M_{ij} = \begin{cases} \mathbf{C}i \left(\prod{k=j+1}^{i} \bar{\mathbf{A}}_k\right) \bar{\mathbf{B}}_j & \text{if } i \geq j \\ 0 & \text{if } i < j \end{cases} Mij={Ci(∏k=j+1iAˉk)Bˉj0if i≥jif i<j
这种对偶性使得 Mamba-2 可以利用矩阵乘法的硬件优化,实现比 Mamba-1 更高的计算效率。
8.2 混合架构
实践中,纯 Mamba 模型在某些需要精确检索的任务上(如复制、查找)不如 Transformer。混合架构应运而生:
Jamba = Mamba layers ⊕ Attention layers \text{Jamba} = \text{Mamba layers} \oplus \text{Attention layers} Jamba=Mamba layers⊕Attention layers
例如,每 4 层 Mamba 后插入 1 层注意力,兼顾效率与检索能力。
8.3 多模态扩展
Mamba 已扩展到多个模态:
- Vision Mamba (Vim):将图像分割为 patch,用双向 Mamba 处理
- Video Mamba:利用 Mamba 的线性复杂度处理长视频序列
- Genomic Mamba:处理超长 DNA 序列(百万碱基对)
9. Mamba 数学公式总结
╔══════════════════════════════════════════════════════════════════════════════╗
║ Mamba 数学公式总结 ║
╠══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 连续时间 SSM: ║
║ dh/dt = A·h(t) + B·x(t) ║
║ y(t) = C·h(t) + D·x(t) ║
║ ║
║ 2. ZOH 离散化: ║
║ Ā = exp(Δ·A) ║
║ B̄ = (Δ·A)^{-1}·(exp(Δ·A) - I)·Δ·B ≈ Δ·B ║
║ ║
║ 3. 选择性 SSM (Mamba 的核心创新): ║
║ B_k = Linear_B(x_k) ← 输入依赖的 B ║
║ C_k = Linear_C(x_k) ← 输入依赖的 C ║
║ Δ_k = softplus(Linear_Δ(x_k)) ← 输入依赖的步长 ║
║ ║
║ 4. 状态更新: ║
║ h_k = exp(Δ_k · A) ⊙ h_{k-1} + Δ_k · B_k · x_k ║
║ y_k = C_k · h_k + D · x_k ║
║ ║
║ 5. 并行扫描算子 (满足结合律): ║
║ (a₁,b₁) • (a₂,b₂) = (a₂·a₁, a₂·b₁ + b₂) ║
║ ║
║ 6. 复杂度: ║
║ 串行: O(L·N) 时间, O(N) 空间 ║
║ 并行: O(L·N) 工作, O(log L) 深度 ║
║ 对比 Transformer: O(L²) 时间, O(L) 空间 ║
║ ║
║ 7. 选择机制的直觉: ║
║ Δ 大 → 重置状态, 聚焦当前输入 (写入) ║
║ Δ 小 → 保留状态, 忽略当前输入 (记忆) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
第二篇:Flow Matching 与 Rectified Flows
1. 引言
生成模型的目标是从复杂的数据分布 p data ( x ) p_{\text{data}}(\mathbf{x}) pdata(x) 中采样。近年来,扩散模型(Diffusion Models) 在图像生成领域取得了巨大成功,但其依赖于随机微分方程(SDE)的迭代求解,采样需要数百步去噪,计算代价高昂。
Flow Matching (Lipman et al., 2023)和 Rectified Flows (Liu et al., 2023)提出了一种全新的生成建模范式:通过学习一个确定性的速度场,将简单的噪声分布(如高斯分布)沿直线路径传输到数据分布。这一方法:
- 数学简洁:只需回归一个速度场,无需 SDE、分数函数或噪声调度
- 路径最优:通过最优传输理论,构造接近直线的传输路径,大幅减少采样步数
- 理论优雅:统一了扩散模型、Flow Matching 和最优传输的数学框架
2. 理论基础 --- 连续正规化流
2.1 从离散流到连续流
正规化流(Normalizing Flow) 通过一系列可逆变换将简单分布 p 0 p_0 p0 逐步变形为复杂分布 p T p_T pT:
z 0 ∼ p 0 ( z ) → f 1 z 1 → f 2 ⋯ → f T z T ∼ p T ( x ) \mathbf{z}_0 \sim p_0(\mathbf{z}) \quad \xrightarrow{f_1} \quad \mathbf{z}_1 \quad \xrightarrow{f_2} \quad \cdots \quad \xrightarrow{f_T} \quad \mathbf{z}_T \sim p_T(\mathbf{x}) z0∼p0(z)f1 z1f2 ⋯fT zT∼pT(x)
每一步变换 f k f_k fk 的概率密度通过变量替换公式追踪:
p k ( z k ) = p k − 1 ( z k − 1 ) ∣ det ∂ f k ∂ z k − 1 ∣ − 1 p_{k}(\mathbf{z}k) = p{k-1}(\mathbf{z}{k-1}) \left| \det \frac{\partial f_k}{\partial \mathbf{z}{k-1}} \right|^{-1} pk(zk)=pk−1(zk−1) det∂zk−1∂fk −1
连续正规化流(Continuous Normalizing Flow, CNF) 将离散的 T T T 步推广到连续时间 t ∈ 0 , 1 t \in 0, 1 t∈0,1,用一个常微分方程(ODE)描述变换过程:
d z t d t = v θ ( z t , t ) \frac{d\mathbf{z}t}{dt} = \mathbf{v}\theta(\mathbf{z}_t, t) dtdzt=vθ(zt,t)
其中 v θ : R d × 0 , 1 → R d \mathbf{v}_\theta: \mathbb{R}^d \times 0, 1 \to \mathbb{R}^d vθ:Rd×0,1→Rd 是一个由神经网络参数化的速度场(velocity field)。
2.2 概率密度的演化:连续性方程
CNF 的关键数学工具是连续性方程(Continuity Equation) ,它描述了概率密度 p t ( x ) p_t(\mathbf{x}) pt(x) 如何随时间演化:
∂ p t ( x ) ∂ t + ∇ ⋅ ( p t ( x ) v θ ( x , t ) ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \right) = 0 ∂t∂pt(x)+∇⋅(pt(x)vθ(x,t))=0
其中 ∇ ⋅ \nabla \cdot ∇⋅ 是散度算子:
∇ ⋅ f = ∑ i = 1 d ∂ f i ∂ x i \nabla \cdot \mathbf{f} = \sum_{i=1}^d \frac{\partial f_i}{\partial x_i} ∇⋅f=i=1∑d∂xi∂fi
物理直觉 :这是流体力学中的质量守恒方程。 p t ( x ) p_t(\mathbf{x}) pt(x) 是流体密度, v θ ( x , t ) \mathbf{v}_\theta(\mathbf{x}, t) vθ(x,t) 是流速场。方程保证了概率质量守恒------没有概率凭空产生或消失。
证明(概率守恒):
对于任意区域 Ω ⊂ R d \Omega \subset \mathbb{R}^d Ω⊂Rd,概率质量的变化率为:
d d t ∫ Ω p t ( x ) d x = − ∫ ∂ Ω p t ( x ) v θ ( x , t ) ⋅ n d S \frac{d}{dt} \int_\Omega p_t(\mathbf{x}) \, d\mathbf{x} = -\int_{\partial \Omega} p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \cdot \mathbf{n} \, dS dtd∫Ωpt(x)dx=−∫∂Ωpt(x)vθ(x,t)⋅ndS
由散度定理:
= − ∫ Ω ∇ ⋅ ( p t ( x ) v θ ( x , t ) ) d x = -\int_\Omega \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \right) d\mathbf{x} =−∫Ω∇⋅(pt(x)vθ(x,t))dx
由于 Ω \Omega Ω 任意,被积函数必须为零,即得连续性方程。 ■ \blacksquare ■
2.3 对数概率的演化
对连续性方程两边除以 p t p_t pt,可得对数概率的演化方程:
∂ log p t ( x ) ∂ t = − ∇ ⋅ v θ ( x , t ) − v θ ( x , t ) ⋅ ∇ log p t ( x ) \frac{\partial \log p_t(\mathbf{x})}{\partial t} = -\nabla \cdot \mathbf{v}\theta(\mathbf{x}, t) - \mathbf{v}\theta(\mathbf{x}, t) \cdot \nabla \log p_t(\mathbf{x}) ∂t∂logpt(x)=−∇⋅vθ(x,t)−vθ(x,t)⋅∇logpt(x)
这个公式在后续推导 Flow Matching 损失函数时至关重要。
3. Flow Matching --- 条件速度场框架
3.1 从仿真损失到条件匹配
训练 CNF 的直接目标是最大化数据的对数似然。利用瞬时变量替换公式(Instantaneous Change of Variables):
log p 1 ( x 1 ) = log p 0 ( x 0 ) − ∫ 0 1 ∇ ⋅ v θ ( x t , t ) d t \log p_1(\mathbf{x}_1) = \log p_0(\mathbf{x}0) - \int_0^1 \nabla \cdot \mathbf{v}\theta(\mathbf{x}_t, t) \, dt logp1(x1)=logp0(x0)−∫01∇⋅vθ(xt,t)dt
其中 x t \mathbf{x}_t xt 是从 x 0 \mathbf{x}0 x0 沿速度场 v θ \mathbf{v}\theta vθ 积分得到的轨迹。
仿真损失(Simulation Loss):
L sim ( θ ) = E t ∼ U ( 0 , 1 ) , x 0 ∼ p 0 ∥ v θ ( x t , t ) − v target ( x t , t ) ∥ 2 \mathcal{L}{\text{sim}}(\theta) = \mathbb{E}{t \sim \mathcal{U}(0,1), \, \mathbf{x}_0 \sim p_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{v}_{\\text{target}}(\\mathbf{x}_t, t) \\\|\^2 \\right Lsim(θ)=Et∼U(0,1),x0∼p0∥vθ(xt,t)−vtarget(xt,t)∥2
但直接计算 v target \mathbf{v}_{\text{target}} vtarget 需要知道边缘分布 p t ( x ) p_t(\mathbf{x}) pt(x) 的梯度,这在高维空间中是不可行的。
Flow Matching 的核心洞察 :虽然边缘速度场 v target \mathbf{v}_{\text{target}} vtarget 难以计算,但条件速度场 v t ( x ∣ x 1 ) \mathbf{v}_t(\mathbf{x} | \mathbf{x}_1) vt(x∣x1) 可以精确构造!
3.2 条件概率路径
给定一个数据样本 x 1 ∼ p data ( x ) \mathbf{x}1 \sim p{\text{data}}(\mathbf{x}) x1∼pdata(x),我们构造一个从噪声 x 0 ∼ p 0 ( x ) \mathbf{x}_0 \sim p_0(\mathbf{x}) x0∼p0(x) 到 x 1 \mathbf{x}_1 x1 的条件概率路径 p t ( x ∣ x 1 ) p_t(\mathbf{x} | \mathbf{x}_1) pt(x∣x1)。
最简单的路径是线性插值:
x t = ( 1 − t ) x 0 + t x 1 , t ∈ 0 , 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1, \quad t \in 0, 1 xt=(1−t)x0+tx1,t∈0,1
对应的条件速度场为:
u t ( x t ∣ x 1 ) = d x t d t = x 1 − x 0 \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{d\mathbf{x}_t}{dt} = \mathbf{x}_1 - \mathbf{x}_0 ut(xt∣x1)=dtdxt=x1−x0
关键性质 :这个速度场是无旋的(irrotational) ,即存在一个标量势函数 ϕ \phi ϕ 使得 u = ∇ ϕ \mathbf{u} = \nabla \phi u=∇ϕ。这与最优传输有深刻联系。
3.3 边缘速度场与条件速度场的关系
边缘概率路径定义为:
p t ( x ) = ∫ p t ( x ∣ x 1 ) p data ( x 1 ) d x 1 p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 pt(x)=∫pt(x∣x1)pdata(x1)dx1
核心定理(Lipman et al., 2023):边缘速度场等于条件速度场的期望:
v t ∗ ( x ) = E p data ( x 1 ) u t ( x ∣ x 1 ) ∣ x t = x \mathbf{v}t^*(\mathbf{x}) = \mathbb{E}{p_{\text{data}}(\mathbf{x}_1)} \left \\mathbf{u}_t(\\mathbf{x} \| \\mathbf{x}_1) \\, \\Big\| \\, \\mathbf{x}_t = \\mathbf{x} \\right vt∗(x)=Epdata(x1)ut(x∣x1) xt=x
即:
v t ∗ ( x ) = ∫ u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) p data ( x 1 ) p t ( x ) d x 1 \mathbf{v}_t^*(\mathbf{x}) = \int \mathbf{u}_t(\mathbf{x} | \mathbf{x}_1) \, \frac{p_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1)}{p_t(\mathbf{x})} \, d\mathbf{x}_1 vt∗(x)=∫ut(x∣x1)pt(x)pt(x∣x1)pdata(x1)dx1
证明:条件概率路径满足连续性方程:
∂ p t ( x ∣ x 1 ) ∂ t + ∇ ⋅ ( p t ( x ∣ x 1 ) u t ( x ∣ x 1 ) ) = 0 \frac{\partial p_t(\mathbf{x} | \mathbf{x}_1)}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}_1) \right) = 0 ∂t∂pt(x∣x1)+∇⋅(pt(x∣x1)ut(x∣x1))=0
对 x 1 \mathbf{x}_1 x1 积分:
∂ p t ( x ) ∂ t + ∇ ⋅ ( ∫ p t ( x ∣ x 1 ) u t ( x ∣ x 1 ) p data ( x 1 ) d x 1 ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( \int p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 \right) = 0 ∂t∂pt(x)+∇⋅(∫pt(x∣x1)ut(x∣x1)pdata(x1)dx1)=0
与边缘连续性方程对比:
∂ p t ( x ) ∂ t + ∇ ⋅ ( p t ( x ) v t ∗ ( x ) ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_t^*(\mathbf{x}) \right) = 0 ∂t∂pt(x)+∇⋅(pt(x)vt∗(x))=0
可得:
p t ( x ) v t ∗ ( x ) = ∫ p t ( x ∣ x 1 ) u t ( x ∣ x 1 ) p data ( x 1 ) d x 1 p_t(\mathbf{x}) \, \mathbf{v}_t^*(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 pt(x)vt∗(x)=∫pt(x∣x1)ut(x∣x1)pdata(x1)dx1
两边除以 p t ( x ) p_t(\mathbf{x}) pt(x) 即得结果。 ■ \blacksquare ■
3.4 条件 Flow Matching 损失
基于上述定理,Flow Matching 的训练损失定义为:
L CFM ( θ ) = E t ∼ U ( 0 , 1 ) , x 1 ∼ p data , x 0 ∼ p 0 ∥ v θ ( x t , t ) − u t ( x t ∣ x 1 ) ∥ 2 \mathcal{L}{\text{CFM}}(\theta) = \mathbb{E}{t \sim \mathcal{U}(0,1), \, \mathbf{x}1 \sim p{\text{data}}, \, \mathbf{x}_0 \sim p_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\\|\^2 \\right LCFM(θ)=Et∼U(0,1),x1∼pdata,x0∼p0∥vθ(xt,t)−ut(xt∣x1)∥2
其中 x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1-t)\mathbf{x}_0 + t\mathbf{x}_1 xt=(1−t)x0+tx1。
等价性定理:
L CFM ( θ ) = L FM ( θ ) + const \mathcal{L}{\text{CFM}}(\theta) = \mathcal{L}{\text{FM}}(\theta) + \text{const} LCFM(θ)=LFM(θ)+const
其中 L FM \mathcal{L}_{\text{FM}} LFM 是边缘 Flow Matching 损失。这意味着优化条件损失等价于优化边缘损失 ,常数项不依赖于 θ \theta θ。
证明:
L CFM = E t , x 1 , x 0 ∥ v θ ( x t , t ) − u t ( x t ∣ x 1 ) ∥ 2 \mathcal{L}{\text{CFM}} = \mathbb{E}{t, \mathbf{x}_1, \mathbf{x}_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\\|\^2 \\right LCFM=Et,x1,x0∥vθ(xt,t)−ut(xt∣x1)∥2
展开平方:
= E ∥ v θ ∥ 2 − 2 v θ ⋅ u t + ∥ u t ∥ 2 = \mathbb{E} \left \\\| \\mathbf{v}_\\theta \\\|\^2 - 2 \\mathbf{v}_\\theta \\cdot \\mathbf{u}_t + \\\| \\mathbf{u}_t \\\|\^2 \\right =E∥vθ∥2−2vθ⋅ut+∥ut∥2
中间项:
E x 1 , x 0 ∣ x t v θ ( x t , t ) ⋅ u t ( x t ∣ x 1 ) = v θ ( x t , t ) ⋅ v t ∗ ( x t ) \mathbb{E}_{\mathbf{x}_1, \mathbf{x}_0 | \mathbf{x}t} \left \\mathbf{v}_\\theta(\\mathbf{x}_t, t) \\cdot \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\right = \mathbf{v}\theta(\mathbf{x}_t, t) \cdot \mathbf{v}_t^*(\mathbf{x}_t) Ex1,x0∣xtvθ(xt,t)⋅ut(xt∣x1)=vθ(xt,t)⋅vt∗(xt)
最后一项 E ∥ u t ∥ 2 \mathbb{E}\\\| \\mathbf{u}_t \\\|\^2 E∥ut∥2 不依赖 θ \theta θ,因此:
L CFM = E t , x t ∥ v θ − v t ∗ ∥ 2 + const = L FM + const \mathcal{L}{\text{CFM}} = \mathbb{E}{t, \mathbf{x}t} \left \\\| \\mathbf{v}_\\theta - \\mathbf{v}_t\^\* \\\|\^2 \\right + \text{const} = \mathcal{L}{\text{FM}} + \text{const} LCFM=Et,xt∥vθ−vt∗∥2+const=LFM+const
■ \blacksquare ■
4. 最优传输与 Rectified Flows
4.1 最优传输基础
最优传输(Optimal Transport, OT) 问题:给定两个分布 μ \mu μ 和 ν \nu ν,找到一个传输映射 T : R d → R d T: \mathbb{R}^d \to \mathbb{R}^d T:Rd→Rd,使得:
T # μ = ν ( 即 T 将 μ 推前为 ν ) T_\# \mu = \nu \quad (\text{即 } T \text{ 将 } \mu \text{ 推前为 } \nu) T#μ=ν(即 T 将 μ 推前为 ν)
且传输代价最小:
min T ∫ ∥ x − T ( x ) ∥ 2 μ ( x ) d x \min_T \int \| \mathbf{x} - T(\mathbf{x}) \|^2 \, \mu(\mathbf{x}) \, d\mathbf{x} Tmin∫∥x−T(x)∥2μ(x)dx
Brenier 定理:在欧氏空间中,最优传输映射存在且唯一(在适当条件下),并且它是某个凸函数的梯度:
T ∗ ( x ) = ∇ ϕ ( x ) T^*(\mathbf{x}) = \nabla \phi(\mathbf{x}) T∗(x)=∇ϕ(x)
其中 ϕ \phi ϕ 是凸函数。
4.2 OT 与直线路径
最优传输的一个关键性质是:OT 映射诱导的路径是直线。
定义从 μ \mu μ 到 ν \nu ν 的 OT 路径:
x t = ( 1 − t ) x 0 + t T ∗ ( x 0 ) , t ∈ 0 , 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t T^*(\mathbf{x}_0), \quad t \in 0, 1 xt=(1−t)x0+tT∗(x0),t∈0,1
定理 (McCann 插值):这是 Wasserstein 距离 W 2 ( μ , ν ) W_2(\mu, \nu) W2(μ,ν) 的唯一测地线(geodesic)。
直觉解释:OT 映射将每个点直接送到目的地,不走弯路。因此每个粒子的轨迹都是直线,只是不同粒子的速度不同。
4.3 为什么标准 Flow Matching 的路径不是直线
在标准 Flow Matching 中,我们使用独立耦合:
x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1
其中 x 0 \mathbf{x}_0 x0 和 x 1 \mathbf{x}_1 x1 是独立采样 的。这意味着 x 0 \mathbf{x}_0 x0 不一定被映射到与之"最优配对"的 x 1 \mathbf{x}_1 x1。
考虑一个简单例子: x 0 ∼ N ( 0 , 1 ) \mathbf{x}_0 \sim \mathcal{N}(0, 1) x0∼N(0,1), x 1 ∼ N ( 0 , 1 ) \mathbf{x}_1 \sim \mathcal{N}(0, 1) x1∼N(0,1)。
- 独立耦合:随机配对,路径可能交叉,轨迹弯曲
- OT 耦合: T ( x ) = x T(x) = x T(x)=x(当两个分布相同时),路径为直线(静止不动)
路径弯曲的代价:
- 速度场更复杂,更难学习
- 采样时需要更多步数(因为轨迹弯曲,小步长才能精确积分)
4.4 Rectified Flows:拉直路径
Rectified Flow (Liu et al., 2023)的核心思想:先训练一个 Flow Matching 模型,然后用它生成的数据构造 OT 近似的耦合,再重新训练。
步骤 1 :训练初始 Flow Matching 模型 v θ ( 0 ) \mathbf{v}_\theta^{(0)} vθ(0)
步骤 2 :用 v θ ( 0 ) \mathbf{v}_\theta^{(0)} vθ(0) 生成配对数据
x 0 ∼ p 0 , x 1 = ODESolve ( v θ ( 0 ) , x 0 , 0 , 1 ) \mathbf{x}_0 \sim p_0, \quad \mathbf{x}1 = \text{ODESolve}(\mathbf{v}\theta^{(0)}, \mathbf{x}_0, 0, 1) x0∼p0,x1=ODESolve(vθ(0),x0,0,1)
这构造了一个确定性耦合 ( x 0 , x 1 ) (\mathbf{x}_0, \mathbf{x}_1) (x0,x1)。
步骤 3:用新的配对数据重新训练
x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1
收敛定理(Liu et al., 2023):重复步骤 2-3,路径的"弯曲度"单调递减,最终收敛到直线路径。
4.5 路径直度的量化
定义路径直度(Straightness):
Straightness = E x 0 , x 1 ∫ 0 1 ∥ v θ ( x t , t ) − ( x 1 − x 0 ) ∥ 2 d t \text{Straightness} = \mathbb{E}_{\mathbf{x}_0, \mathbf{x}_1} \left \\int_0\^1 \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - (\\mathbf{x}_1 - \\mathbf{x}_0) \\\|\^2 \\, dt \\right Straightness=Ex0,x1∫01∥vθ(xt,t)−(x1−x0)∥2dt
当路径为直线时, v θ ( x t , t ) = x 1 − x 0 \mathbf{v}_\theta(\mathbf{x}_t, t) = \mathbf{x}_1 - \mathbf{x}_0 vθ(xt,t)=x1−x0(常数),Straightness = 0。
为什么直线路径更好?
直线路径意味着速度场沿路径为常数,因此:
- 神经网络只需学习一个"简单"的函数
- ODE 积分可以用更大的步长(因为曲率小)
- 理论上只需 1 步就能精确采样
5. Flow Matching 损失函数的完整推导
5.1 线性插值路径下的条件速度场
给定 x 0 ∼ N ( 0 , I ) \mathbf{x}_0 \sim \mathcal{N}(0, \mathbf{I}) x0∼N(0,I), x 1 ∼ p data \mathbf{x}1 \sim p{\text{data}} x1∼pdata,线性插值:
x t = ( 1 − t ) x 0 + t x 1 = x 0 + t ( x 1 − x 0 ) \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 = \mathbf{x}_0 + t (\mathbf{x}_1 - \mathbf{x}_0) xt=(1−t)x0+tx1=x0+t(x1−x0)
条件速度场:
u t ( x t ∣ x 1 ) = d x t d t = x 1 − x 0 \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{d\mathbf{x}_t}{dt} = \mathbf{x}_1 - \mathbf{x}_0 ut(xt∣x1)=dtdxt=x1−x0
注意:这个速度场不依赖于 t t t,沿整条路径为常数。
5.2 重新参数化
实践中,我们不直接预测速度 x 1 − x 0 \mathbf{x}_1 - \mathbf{x}_0 x1−x0,而是预测目标 x 1 \mathbf{x}_1 x1 (或噪声 x 0 \mathbf{x}_0 x0)。
速度参数化:
v θ ( x t , t ) ≈ x 1 − x 0 \mathbf{v}_\theta(\mathbf{x}_t, t) \approx \mathbf{x}_1 - \mathbf{x}_0 vθ(xt,t)≈x1−x0
目标参数化:
x ^ θ ( x t , t ) ≈ x 1 \hat{\mathbf{x}}_\theta(\mathbf{x}_t, t) \approx \mathbf{x}_1 x^θ(xt,t)≈x1
两者的关系:
v θ ( x t , t ) = x ^ θ ( x t , t ) − x t 1 − t \mathbf{v}_\theta(\mathbf{x}t, t) = \frac{\hat{\mathbf{x}}\theta(\mathbf{x}_t, t) - \mathbf{x}_t}{1 - t} vθ(xt,t)=1−tx^θ(xt,t)−xt
最简洁的做法是直接预测速度:
L ( θ ) = E t , x 0 , x 1 ∥ v θ ( x t , t ) − ( x 1 − x 0 ) ∥ 2 \mathcal{L}(\theta) = \mathbb{E}_{t, \mathbf{x}_0, \mathbf{x}_1} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - (\\mathbf{x}_1 - \\mathbf{x}_0) \\\|\^2 \\right L(θ)=Et,x0,x1∥vθ(xt,t)−(x1−x0)∥2
5.3 与扩散模型的统一
Flow Matching 提供了一个统一框架,将扩散模型视为特例:
DDPM :前向过程为 x t = α ˉ t x 1 + 1 − α ˉ t ϵ \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_1 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon} xt=αˉt x1+1−αˉt ϵ
这等价于 Flow Matching 中使用方差保持(VP)插值:
x t = α ˉ t x 1 + 1 − α ˉ t x 0 \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_1 + \sqrt{1 - \bar{\alpha}_t} \mathbf{x}_0 xt=αˉt x1+1−αˉt x0
对应的条件速度场:
u t ( x t ∣ x 1 ) = α ˉ ˙ t 2 α ˉ t x t + α ˉ ˙ t α ˉ t 2 ( 1 − α ˉ t ) ( x t − α ˉ t x 1 ) \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{\dot{\bar{\alpha}}_t}{2\bar{\alpha}_t} \mathbf{x}_t + \frac{\dot{\bar{\alpha}}_t \bar{\alpha}_t}{2(1 - \bar{\alpha}_t)} (\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_1) ut(xt∣x1)=2αˉtαˉ˙txt+2(1−αˉt)αˉ˙tαˉt(xt−αˉt x1)
这个速度场依赖于 t t t,且路径弯曲,这正是扩散模型需要多步采样的根本原因。
Flow Matching(线性插值):
x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1
速度场为常数 x 1 − x 0 \mathbf{x}_1 - \mathbf{x}_0 x1−x0,路径为直线。
不同插值方法对比:
| 方法 | α t \alpha_t αt | σ t \sigma_t σt | 路径形状 |
|---|---|---|---|
| DDPM (VP) | α ˉ t \sqrt{\bar{\alpha}_t} αˉt | 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1−αˉt | 弯曲 |
| DDIM (VE) | 1 | 1 − t \sqrt{1-t} 1−t | 弯曲 |
| Flow Matching | t t t | 1 − t 1-t 1−t | 直线 |
| Optimal Transport | t t t | 1 − t 1-t 1−t (OT 耦合) | 直线 |
6. 完整可运行实现
6.1 核心工具函数
python
"""
Flow Matching & Rectified Flow --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib, sklearn
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from typing import Tuple, Optional, List
import math
from tqdm import tqdm
def ode_solve(
velocity_fn,
x0: torch.Tensor,
t_span: Tuple[float, float] = (0.0, 1.0),
num_steps: int = 100,
method: str = "euler",
) -> torch.Tensor:
"""
求解 ODE: dx/dt = v(x, t)
支持的求解器:
- euler: 欧拉法, 一阶
- rk4: 四阶 Runge-Kutta
"""
dt = (t_span[1] - t_span[0]) / num_steps
x = x0.clone()
t = t_span[0]
trajectory = [x.clone()]
for step in range(num_steps):
if method == "euler":
v = velocity_fn(x, torch.full((x.shape[0],), t, device=x.device))
x = x + dt * v
elif method == "rk4":
t_batch = torch.full((x.shape[0],), t, device=x.device)
k1 = velocity_fn(x, t_batch)
k2 = velocity_fn(x + 0.5 * dt * k1, t_batch + 0.5 * dt)
k3 = velocity_fn(x + 0.5 * dt * k2, t_batch + 0.5 * dt)
k4 = velocity_fn(x + dt * k3, t_batch + dt)
x = x + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
t += dt
trajectory.append(x.clone())
return x, trajectory
6.2 神经网络速度场
python
class SinusoidalTimeEmbedding(nn.Module):
"""正弦时间嵌入"""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, t: torch.Tensor) -> torch.Tensor:
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return emb
class VelocityNetwork2D(nn.Module):
"""2D 数据的速度场网络"""
def __init__(self, hidden_dim: int = 256, time_dim: int = 64):
super().__init__()
self.time_embed = SinusoidalTimeEmbedding(time_dim)
self.input_proj = nn.Linear(2 + time_dim, hidden_dim)
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
) for _ in range(4)
])
self.output_proj = nn.Linear(hidden_dim, 2)
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
t_emb = self.time_embed(t)
h = torch.cat([x, t_emb], dim=-1)
h = self.input_proj(h)
for block in self.blocks:
h = h + block(h)
return self.output_proj(h)
6.3 Flow Matching 训练器
python
class FlowMatchingTrainer:
"""Flow Matching 训练器"""
def __init__(self, model: nn.Module, sigma_min: float = 1e-4):
self.model = model
self.sigma_min = sigma_min
def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.rand(batch_size, device=device)
def compute_loss(self, x1: torch.Tensor) -> torch.Tensor:
batch_size = x1.shape[0]
device = x1.device
x0 = torch.randn_like(x1)
t = self.sample_time(batch_size, device)
if x1.dim() == 2:
t_expand = t.unsqueeze(-1)
elif x1.dim() == 4:
t_expand = t.view(-1, 1, 1, 1)
xt = (1 - t_expand) * x0 + t_expand * x1
v_target = x1 - x0
v_pred = self.model(xt, t)
return F.mse_loss(v_pred, v_target)
def train_step(self, x1: torch.Tensor, optimizer) -> float:
optimizer.zero_grad()
loss = self.compute_loss(x1)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
return loss.item()
@torch.no_grad()
def sample(
self,
num_samples: int,
shape: Tuple[int, ...],
device: torch.device,
num_steps: int = 50,
method: str = "euler",
) -> torch.Tensor:
x0 = torch.randn(num_samples, *shape, device=device)
def velocity_fn(x, t):
return self.model(x, t)
x1, trajectory = ode_solve(
velocity_fn, x0,
t_span=(0.0, 1.0),
num_steps=num_steps,
method=method
)
return x1, trajectory
6.4 Rectified Flow
python
class RectifiedFlow:
"""Rectified Flow: 通过 reflow 操作拉直传输路径"""
def __init__(self, model_class, model_kwargs, device: torch.device):
self.device = device
self.model_class = model_class
self.model_kwargs = model_kwargs
self.models = []
def train_round(
self,
data_samples: torch.Tensor,
num_steps: int = 5000,
batch_size: int = 256,
lr: float = 1e-3,
prev_model: Optional[nn.Module] = None,
num_ode_steps: int = 50,
) -> nn.Module:
model = self.model_class(**self.model_kwargs).to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
trainer = FlowMatchingTrainer(model)
dataset_size = data_samples.shape[0]
losses = []
for step in tqdm(range(num_steps), desc="Training"):
idx = torch.randint(0, dataset_size, (batch_size,))
x1_batch = data_samples[idx].to(self.device)
if prev_model is not None:
x0_batch = torch.randn_like(x1_batch)
def prev_velocity(x, t):
return prev_model(x, t)
with torch.no_grad():
x1_generated, _ = ode_solve(
prev_velocity, x0_batch,
t_span=(0.0, 1.0),
num_steps=num_ode_steps,
method="rk4"
)
loss = trainer.compute_loss(x1_generated)
else:
loss = trainer.compute_loss(x1_batch)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
losses.append(loss.item())
self.models.append(model)
return model, losses
def compute_path_straightness(
self,
model: nn.Module,
num_samples: int = 100,
num_timesteps: int = 50,
) -> Tuple[float, List[torch.Tensor]]:
model.eval()
x0 = torch.randn(num_samples, 2, device=self.device)
def velocity_fn(x, t):
return model(x, t)
_, trajectory = ode_solve(
velocity_fn, x0,
t_span=(0.0, 1.0),
num_steps=num_timesteps,
method="euler"
)
x1 = trajectory[-1]
straightnesses = []
ts = np.linspace(0, 1, num_timesteps + 1)
for i, (t_val, xt) in enumerate(zip(ts, trajectory)):
with torch.no_grad():
v_pred = model(xt, torch.full((num_samples,), t_val, device=self.device))
v_target = x1 - x0
straightness = ((v_pred - v_target) ** 2).mean().item()
straightnesses.append(straightness)
return np.mean(straightnesses), trajectory
6.5 实验代码
python
def experiment_2d_flow_matching():
"""实验: 2D 数据上的 Flow Matching"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 生成数据
data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
data = (data - data.mean(axis=0)) / data.std(axis=0)
data = torch.tensor(data, dtype=torch.float32)
# 创建模型
model = VelocityNetwork2D(hidden_dim=256, time_dim=64).to(device)
trainer = FlowMatchingTrainer(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
# 训练
losses = []
for step in range(3000):
idx = torch.randint(0, data.shape[0], (512,))
x1 = data[idx].to(device)
loss = trainer.train_step(x1, optimizer)
losses.append(loss)
if (step + 1) % 500 == 0:
print(f"Step {step+1} | Loss: {np.mean(losses[-100:]):.6f}")
# 生成样本
model.eval()
generated, trajectory = trainer.sample(
num_samples=2000, shape=(2,), device=device,
num_steps=100, method="rk4"
)
return model, trainer, generated.cpu(), losses
def experiment_2d_rectified_flow():
"""实验: Rectified Flow (Reflow)"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
data = (data - data.mean(axis=0)) / data.std(axis=0)
data = torch.tensor(data, dtype=torch.float32)
def create_model():
return VelocityNetwork2D(hidden_dim=256, time_dim=64)
rf = RectifiedFlow(create_model, {}, device)
prev_model = None
for round_idx in range(3):
print(f"\n--- Round {round_idx + 1}/3 ---")
model, losses = rf.train_round(
data, num_steps=3000, batch_size=512,
lr=1e-3, prev_model=prev_model, num_ode_steps=50
)
straightness, _ = rf.compute_path_straightness(model)
print(f" 路径直度: {straightness:.6f}")
prev_model = model
return rf
7. 理论深度分析
7.1 概率流 ODE 与随机 SDE
扩散模型有两种等价表述:
概率流 ODE(PF-ODE):
d x t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log p t ( x t ) d t d\mathbf{x}_t = \left \\mathbf{f}(\\mathbf{x}_t, t) - \\frac{1}{2} g(t)\^2 \\nabla \\log p_t(\\mathbf{x}_t) \\right dt dxt=f(xt,t)−21g(t)2∇logpt(xt)dt
反向 SDE:
d x t = f ( x t , t ) − g ( t ) 2 ∇ log p t ( x t ) d t + g ( t ) d w ˉ t d\mathbf{x}_t = \left \\mathbf{f}(\\mathbf{x}_t, t) - g(t)\^2 \\nabla \\log p_t(\\mathbf{x}_t) \\right dt + g(t) d\bar{\mathbf{w}}_t dxt=f(xt,t)−g(t)2∇logpt(xt)dt+g(t)dwˉt
Flow Matching 直接学习 PF-ODE 的速度场,避免了分数函数 ∇ log p t \nabla \log p_t ∇logpt 的估计。
7.2 最优传输的理论保证
定理 (Benamou-Brenier 公式):两个分布 μ \mu μ 和 ν \nu ν 之间的 Wasserstein-2 距离为:
W 2 2 ( μ , ν ) = inf v , p ∫ 0 1 ∫ R d ∥ v t ( x ) ∥ 2 p t ( x ) d x d t W_2^2(\mu, \nu) = \inf_{\mathbf{v}, p} \int_0^1 \int_{\mathbb{R}^d} \| \mathbf{v}_t(\mathbf{x}) \|^2 p_t(\mathbf{x}) \, d\mathbf{x} \, dt W22(μ,ν)=v,pinf∫01∫Rd∥vt(x)∥2pt(x)dxdt
其中下确界取遍所有满足连续性方程的速度场 v \mathbf{v} v 和概率路径 p t p_t pt,且 p 0 = μ , p 1 = ν p_0 = \mu, p_1 = \nu p0=μ,p1=ν。
推论:OT 路径最小化了"动能",因此是最直的路径。Rectified Flow 的 reflow 操作正是在数值上逼近这个最优解。
7.3 单步生成的理论极限
理论上,如果速度场完美学习,只需 1 步 ODE 积分即可生成:
x 1 = x 0 + v θ ( x 0 , 0 ) \mathbf{x}_1 = \mathbf{x}0 + \mathbf{v}\theta(\mathbf{x}_0, 0) x1=x0+vθ(x0,0)
但这要求 v θ ( x , 0 ) \mathbf{v}_\theta(\mathbf{x}, 0) vθ(x,0) 精确等于 OT 映射 T ∗ ( x ) − x T^*(\mathbf{x}) - \mathbf{x} T∗(x)−x。实践中,Rectified Flow 通过 reflow 逐步逼近这一极限。
8. Flow Matching 数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════╗
║ Flow Matching & Rectified Flow 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 连续正规化流 (CNF): ║
║ dx/dt = v_θ(x, t), t ∈ [0, 1] ║
║ ║
║ 2. 连续性方程 (概率守恒): ║
║ ∂p_t(x)/∂t + ∇·(p_t(x) · v_θ(x, t)) = 0 ║
║ ║
║ 3. 线性插值路径: ║
║ x_t = (1-t)·x₀ + t·x₁ (x₀ ~ N(0,I), x₁ ~ p_data) ║
║ ║
║ 4. 条件速度场: ║
║ u_t(x_t | x₁) = dx_t/dt = x₁ - x₀ ║
║ ║
║ 5. 条件 Flow Matching 损失: ║
║ L_CFM(θ) = E_{t,x₀,x₁} [ ‖v_θ(x_t, t) - (x₁ - x₀)‖² ] ║
║ ║
║ 6. 边缘速度场 = 条件期望: ║
║ v*_t(x) = E_{p(x₁|x_t=x)} [ u_t(x | x₁) ] ║
║ ║
║ 7. 损失等价性: ║
║ L_CFM(θ) = L_FM(θ) + const (优化条件损失 ≡ 优化边缘损失) ║
║ ║
║ 8. 最优传输 (OT) 路径: ║
║ x_t = (1-t)·x₀ + t·T*(x₀) (T* = OT 映射, 路径为直线) ║
║ ║
║ 9. Rectified Flow 收敛: ║
║ 重复 reflow → 路径直度单调递减 → 收敛到 OT 直线 ║
║ ║
║ 10. 采样: ║
║ x₁ = x₀ + ∫₀¹ v_θ(x_t, t) dt (ODE 积分, 路径越直步数越少) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════╝
参考文献
Mamba
- Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022.
- Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
- Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. ICML 2024.
- Gu, A., Dao, T., et al. (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020.
Flow Matching
- Lipman, Y., Chen, R.T., Ben-Hamu, H., Nickel, M., & Le, M. (2023). Flow Matching for Generative Modeling. ICLR 2023.
- Liu, X., Gong, C., & Liu, Q. (2023). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR 2023.
- Esser, P., Kulal, S., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. ICML 2024.
- Chen, R.T., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural Ordinary Differential Equations. NeurIPS 2018.
- Song, Y., Sohl-Dickstein, J., et al. (2023). Consistency Models. ICML 2023.
GRPO
- Shao, Z., Wang, P., et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.
- DeepSeek-AI. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.
- Schulman, J., Wolski, F., et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.
- Schulman, J., Moritz, P., et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation. ICLR 2016.
- Sutton, R.S., McAllester, D., et al. (1999). Policy Gradient Methods for Reinforcement Learning with Function Approximation. NeurIPS 1999.
- Ouyang, L., Wu, J., et al. (2022). Training Language Models to Follow Instructions with Human Feedback. NeurIPS 2022.
- Guo, D., Yang, D., et al. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.