前沿算法深度解析(一)

目录

  • [第一篇:Mamba --- 选择性状态空间模型](#第一篇:Mamba — 选择性状态空间模型)
  • [第二篇:Flow Matching 与 Rectified Flows](#第二篇:Flow Matching 与 Rectified Flows)
  • [第三篇:GRPO --- 群体相对策略优化](#第三篇:GRPO — 群体相对策略优化)
  • 参考文献

第一篇:Mamba --- 选择性状态空间模型

1. 引言

Transformer 架构自 2017 年问世以来统治了深度学习领域,但其核心的自注意力机制存在 O ( L 2 ) O(L^2) O(L2) 的计算复杂度和 O ( L ) O(L) O(L) 的显存占用( L L L 为序列长度),在处理长序列时成为瓶颈。

Mamba (Gu & Dao, 2023)提出了一种全新的序列建模范式------选择性状态空间模型(Selective State Space Model) ,通过将经典控制论中的状态空间模型与数据依赖的选择机制相结合,实现了 O ( L ) O(L) O(L) 的线性复杂度,同时在语言、基因组学、音频等模态上达到了与 Transformer 相当甚至更优的性能。


2. 理论基础 --- 连续时间状态空间模型

2.1 线性时不变系统(LTI System)

状态空间模型(SSM)源自控制论,描述一个 N N N 维隐状态 h ( t ) ∈ R N \mathbf{h}(t) \in \mathbb{R}^N h(t)∈RN 如何随时间演化。连续时间形式为:

d h ( t ) d t = A h ( t ) + B x ( t ) \frac{d\mathbf{h}(t)}{dt} = \mathbf{A}\mathbf{h}(t) + \mathbf{B}x(t) dtdh(t)=Ah(t)+Bx(t)

y ( t ) = C h ( t ) + D ⋅ x ( t ) y(t) = \mathbf{C}\mathbf{h}(t) + D \cdot x(t) y(t)=Ch(t)+D⋅x(t)

其中:

符号 维度 含义
A \mathbf{A} A R N × N \mathbb{R}^{N \times N} RN×N 状态转移矩阵,控制隐状态的演化动力学
B \mathbf{B} B R N × 1 \mathbb{R}^{N \times 1} RN×1 输入投影矩阵,决定输入如何注入状态
C \mathbf{C} C R 1 × N \mathbb{R}^{1 \times N} R1×N 输出投影矩阵,决定如何从状态读取输出
D D D R \mathbb{R} R 直通项(通常忽略或设为残差连接)
x ( t ) x(t) x(t) R \mathbb{R} R 标量输入信号
y ( t ) y(t) y(t) R \mathbb{R} R 标量输出信号

物理直觉 :想象一个弹簧-阻尼系统, A \mathbf{A} A 决定了系统的固有动力学(如弹簧常数、阻尼系数), B \mathbf{B} B 是外部施加的力, C \mathbf{C} C 是我们观察系统的方式。

2.2 离散化:从连续到离散

为了在数字计算机上处理离散序列 x = ( x 1 , x 2 , ... , x L ) \mathbf{x} = (x_1, x_2, \ldots, x_L) x=(x1,x2,...,xL),我们需要对连续系统进行零阶保持(Zero-Order Hold, ZOH) 离散化。

给定步长 Δ > 0 \Delta > 0 Δ>0(可理解为采样间隔),ZOH 离散化假设输入在一个采样间隔内保持不变: x ( t ) = x k , t ∈ [ k Δ , ( k + 1 ) Δ ) x(t) = x_k, \quad t \in [k\Delta, (k+1)\Delta) x(t)=xk,t∈[kΔ,(k+1)Δ)。

离散化后的系统变为:

h k = A ˉ h k − 1 + B ˉ x k \mathbf{h}k = \bar{\mathbf{A}} \mathbf{h}{k-1} + \bar{\mathbf{B}} x_k hk=Aˉhk−1+Bˉxk

y k = C h k y_k = \mathbf{C} \mathbf{h}_k yk=Chk

其中离散化参数为:

A ˉ = exp ⁡ ( Δ A ) \bar{\mathbf{A}} = \exp(\Delta \mathbf{A}) Aˉ=exp(ΔA)

B ˉ = ( Δ A ) − 1 ( exp ⁡ ( Δ A ) − I ) ⋅ Δ B ≈ Δ B \bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B} \approx \Delta \mathbf{B} Bˉ=(ΔA)−1(exp(ΔA)−I)⋅ΔB≈ΔB

A ˉ \bar{\mathbf{A}} Aˉ 的矩阵指数解释 : exp ⁡ ( Δ A ) = ∑ k = 0 ∞ ( Δ A ) k k ! \exp(\Delta \mathbf{A}) = \sum_{k=0}^{\infty} \frac{(\Delta \mathbf{A})^k}{k!} exp(ΔA)=∑k=0∞k!(ΔA)k 是状态转移矩阵的自然推广。它精确地描述了在 Δ \Delta Δ 时间内状态如何从 h k − 1 \mathbf{h}_{k-1} hk−1 演化到 h k \mathbf{h}_k hk。

2.3 全局卷积视角

离散 SSM 的一个关键洞察是:线性递推等价于全局卷积

展开递推关系:

y k = C A ˉ k B ˉ x 0 + C A ˉ k − 1 B ˉ x 1 + ⋯ + C B ˉ x k y_k = \mathbf{C}\bar{\mathbf{A}}^k \bar{\mathbf{B}} x_0 + \mathbf{C}\bar{\mathbf{A}}^{k-1} \bar{\mathbf{B}} x_1 + \cdots + \mathbf{C}\bar{\mathbf{B}} x_k yk=CAˉkBˉx0+CAˉk−1Bˉx1+⋯+CBˉxk

定义卷积核:

K ˉ = ( C B ˉ , C A ˉ B ˉ , C A ˉ 2 B ˉ , ... , C A ˉ L − 1 B ˉ ) ∈ R L \bar{\mathbf{K}} = (\mathbf{C}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}, \ldots, \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}}) \in \mathbb{R}^L Kˉ=(CBˉ,CAˉBˉ,CAˉ2Bˉ,...,CAˉL−1Bˉ)∈RL

则输出可表示为:

y = K ˉ ∗ x \mathbf{y} = \bar{\mathbf{K}} * \mathbf{x} y=Kˉ∗x

其中 ∗ * ∗ 为卷积操作。通过 FFT,卷积可在 O ( L log ⁡ L ) O(L \log L) O(LlogL) 时间内计算。

2.4 HiPPO 初始化:解决长程依赖

A \mathbf{A} A 矩阵的初始化至关重要。HiPPO(High-order Polynomial Projection Operators) 框架(Gu et al., 2020)提供了理论上最优的初始化方式。

HiPPO 矩阵的数学定义:

A n k = − { ( 2 n + 1 ) 1 / 2 ( 2 k + 1 ) 1 / 2 if n > k n + 1 if n = k 0 if n < k \mathbf{A}_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} Ank=−⎩ ⎨ ⎧(2n+1)1/2(2k+1)1/2n+10if n>kif n=kif n<k

直觉解释 :HiPPO 矩阵将输入信号在线上投影到正交多项式基上,使得隐状态 h ( t ) \mathbf{h}(t) h(t) 始终保存了过去输入的最优低秩近似。这赋予了 SSM 理论上的长程记忆能力。

2.5 固定参数 SSM 的根本局限

在 S4 等经典 SSM 中, A , B , C \mathbf{A}, \mathbf{B}, \mathbf{C} A,B,C 对所有输入都是固定的(LTI 性质)。这带来了两个根本问题:

问题 1:无法进行内容感知推理

固定参数的 SSM 本质上是一个时不变滤波器。它无法根据输入内容选择性地记住或遗忘信息。例如:

复制代码
输入: "The capital of France is Paris. The capital of Germany is ..."

一个内容无关的 SSM 对所有 token 施加相同的衰减模式,无法根据语义选择性地保留 "France → Paris" 的关联以便回答后续问题。

问题 2:无法实现选择性机制

在语言模型中,关键能力之一是根据上下文选择性地处理信息------有时需要长程依赖(指代消解),有时需要短程关注(局部语法)。固定 SSM 的卷积核 K ˉ \bar{\mathbf{K}} Kˉ 是预计算的,无法适应不同的输入。


3. Mamba 的核心创新 --- 选择性机制

3.1 从固定到动态:输入依赖参数

Mamba 的核心思想极其简洁:让 SSM 的参数依赖于输入

具体地,Mamba 将离散 SSM 的参数 B , C , Δ \mathbf{B}, \mathbf{C}, \Delta B,C,Δ 从固定常数变为输入的函数:

B k = Linear B ( x k ) ∈ R N \mathbf{B}_k = \text{Linear}_B(x_k) \in \mathbb{R}^N Bk=LinearB(xk)∈RN

C k = Linear C ( x k ) ∈ R N \mathbf{C}_k = \text{Linear}_C(x_k) \in \mathbb{R}^N Ck=LinearC(xk)∈RN

Δ k = softplus ( Linear Δ ( x k ) ) ∈ R > 0 \Delta_k = \text{softplus}(\text{Linear}\Delta(x_k)) \in \mathbb{R}{>0} Δk=softplus(LinearΔ(xk))∈R>0

其中 softplus ( z ) = ln ⁡ ( 1 + e z ) \text{softplus}(z) = \ln(1 + e^z) softplus(z)=ln(1+ez) 确保 Δ k > 0 \Delta_k > 0 Δk>0。

关键变化

  • B k \mathbf{B}_k Bk 现在是 N N N 维向量(而非标量),每个状态维度有独立的输入投影
  • C k \mathbf{C}_k Ck 同样是 N N N 维向量
  • Δ k \Delta_k Δk 是标量,控制对当前输入的"关注程度"

3.2 选择机制的数学分析

Δ k \Delta_k Δk 的作用可以精确分析。离散化后:

A ˉ k = exp ⁡ ( Δ k A ) \bar{\mathbf{A}}_k = \exp(\Delta_k \mathbf{A}) Aˉk=exp(ΔkA)

B ˉ k = Δ k B k \bar{\mathbf{B}}_k = \Delta_k \mathbf{B}_k Bˉk=ΔkBk

考虑 A \mathbf{A} A 为对角矩阵 A = diag ( a 1 , a 2 , ... , a N ) \mathbf{A} = \text{diag}(a_1, a_2, \ldots, a_N) A=diag(a1,a2,...,aN)(Mamba 的实现选择),则:

A ˉ k , n n = exp ⁡ ( Δ k ⋅ a n ) \bar{A}_{k,nn} = \exp(\Delta_k \cdot a_n) Aˉk,nn=exp(Δk⋅an)

当 Δ k \Delta_k Δk 很大时

  • A ˉ k → 0 \bar{\mathbf{A}}_k \to \mathbf{0} Aˉk→0(状态被重置)
  • B ˉ k \bar{\mathbf{B}}_k Bˉk 很大(输入被强烈注入)
  • 效果 :模型选择性地忽略 历史状态,聚焦于当前输入

当 Δ k \Delta_k Δk 很小时

  • A ˉ k → I \bar{\mathbf{A}}_k \to \mathbf{I} Aˉk→I(状态几乎不变)
  • B ˉ k → 0 \bar{\mathbf{B}}_k \to \mathbf{0} Bˉk→0(输入几乎不注入)
  • 效果 :模型选择性地保留 历史状态,忽略当前输入

这正是选择机制的本质 : Δ k \Delta_k Δk 作为"门控",动态决定每个时间步是记忆还是遗忘。

3.3 选择性 SSM 的递推公式

综合以上,Mamba 的完整递推公式为:

A ˉ k = exp ⁡ ( Δ k ⋅ diag ( a 1 , ... , a N ) ) \bar{\mathbf{A}}_k = \exp(\Delta_k \cdot \text{diag}(a_1, \ldots, a_N)) Aˉk=exp(Δk⋅diag(a1,...,aN))

B ˉ k = Δ k ⋅ B k \bar{\mathbf{B}}_k = \Delta_k \cdot \mathbf{B}_k Bˉk=Δk⋅Bk

h k = A ˉ k ⊙ h k − 1 + B ˉ k ⊙ x k \mathbf{h}_k = \bar{\mathbf{A}}k \odot \mathbf{h}{k-1} + \bar{\mathbf{B}}_k \odot x_k hk=Aˉk⊙hk−1+Bˉk⊙xk

y k = C k ⋅ h k y_k = \mathbf{C}_k \cdot \mathbf{h}_k yk=Ck⋅hk

其中 ⊙ \odot ⊙ 表示逐元素乘法(因为 A \mathbf{A} A 是对角的)。

复杂度分析

指标 复杂度
每步计算量 O ( N ) O(N) O(N)(向量逐元素乘法和点积)
总计算量 O ( L N ) O(LN) O(LN)
空间复杂度 O ( N ) O(N) O(N)(只需保存当前状态)

3.4 为什么选择性破坏了卷积效率

在经典 SSM 中,参数固定 → 卷积核 K ˉ \bar{\mathbf{K}} Kˉ 可预计算 → 可用 FFT 加速。

但在 Mamba 中,参数依赖输入 → 每个时间步的 A ˉ k , B ˉ k \bar{\mathbf{A}}_k, \bar{\mathbf{B}}_k Aˉk,Bˉk 不同 → 无法预计算卷积核 → 卷积加速失效。

这看似是一个严重的效率倒退,但 Mamba 通过硬件感知的并行扫描算法巧妙地解决了这个问题。


4. 硬件感知并行扫描

4.1 问题定义

选择性 SSM 的递推:

h k = A ˉ k ⊙ h k − 1 + B ˉ k ⊙ x k \mathbf{h}_k = \bar{\mathbf{A}}k \odot \mathbf{h}{k-1} + \bar{\mathbf{B}}_k \odot x_k hk=Aˉk⊙hk−1+Bˉk⊙xk

这是一个线性递推(Linear Recurrence),形式为:

h k = a k ⊙ h k − 1 + b k \mathbf{h}_k = \mathbf{a}k \odot \mathbf{h}{k-1} + \mathbf{b}_k hk=ak⊙hk−1+bk

其中 a k = A ˉ k ∈ R N \mathbf{a}_k = \bar{\mathbf{A}}_k \in \mathbb{R}^N ak=Aˉk∈RN, b k = B ˉ k ⊙ x k ∈ R N \mathbf{b}_k = \bar{\mathbf{B}}_k \odot x_k \in \mathbb{R}^N bk=Bˉk⊙xk∈RN。

串行计算需要 O ( L ) O(L) O(L) 步,无法利用 GPU 的大规模并行性。

4.2 并行前缀扫描(Parallel Prefix Scan)

并行扫描是解决线性递推的经典并行算法。定义扫描算子 ∙ \bullet ∙:

( a 1 , b 1 ) ∙ ( a 2 , b 2 ) = ( a 2 ⊙ a 1 , a 2 ⊙ b 1 + b 2 ) (\mathbf{a}_1, \mathbf{b}_1) \bullet (\mathbf{a}_2, \mathbf{b}_2) = (\mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{b}_2) (a1,b1)∙(a2,b2)=(a2⊙a1,a2⊙b1+b2)

关键性质 :该算子满足结合律

证明 :设三个连续元素 ( a i , b i ) (\mathbf{a}_i, \mathbf{b}_i) (ai,bi),验证结合律:

( a 1 , b 1 ) ∙ ( a 2 , b 2 ) ∙ ( a 3 , b 3 ) (\\mathbf{a}_1, \\mathbf{b}_1) \\bullet (\\mathbf{a}_2, \\mathbf{b}_2) \bullet (\mathbf{a}_3, \mathbf{b}_3) (a1,b1)∙(a2,b2)∙(a3,b3)

= ( a 2 ⊙ a 1 , a 2 ⊙ b 1 + b 2 ) ∙ ( a 3 , b 3 ) = (\mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{b}_2) \bullet (\mathbf{a}_3, \mathbf{b}_3) =(a2⊙a1,a2⊙b1+b2)∙(a3,b3)

= ( a 3 ⊙ a 2 ⊙ a 1 , a 3 ⊙ a 2 ⊙ b 1 + a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a3⊙a2⊙a1,a3⊙a2⊙b1+a3⊙b2+b3)

类似地:

( a 1 , b 1 ) ∙ ( a 2 , b 2 ) ∙ ( a 3 , b 3 ) (\mathbf{a}_1, \mathbf{b}_1) \bullet (\\mathbf{a}_2, \\mathbf{b}_2) \\bullet (\\mathbf{a}_3, \\mathbf{b}_3) (a1,b1)∙(a2,b2)∙(a3,b3)

= ( a 1 , b 1 ) ∙ ( a 3 ⊙ a 2 , a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_1, \mathbf{b}_1) \bullet (\mathbf{a}_3 \odot \mathbf{a}_2, \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a1,b1)∙(a3⊙a2,a3⊙b2+b3)

= ( a 3 ⊙ a 2 ⊙ a 1 , a 3 ⊙ a 2 ⊙ b 1 + a 3 ⊙ b 2 + b 3 ) = (\mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{a}_1, \mathbf{a}_3 \odot \mathbf{a}_2 \odot \mathbf{b}_1 + \mathbf{a}_3 \odot \mathbf{b}_2 + \mathbf{b}_3) =(a3⊙a2⊙a1,a3⊙a2⊙b1+a3⊙b2+b3)

两者相等,结合律成立。 ■ \blacksquare ■

4.3 Blelloch 并行扫描算法

利用结合律,我们可以将 L L L 步串行递推转化为 O ( log ⁡ L ) O(\log L) O(logL) 深度的并行计算:

阶段 1:上扫(Up-Sweep) --- 自底向上归约

复制代码
层级 0: (a₁,b₁)  (a₂,b₂)  (a₃,b₃)  (a₄,b₄)  (a₅,b₅)  (a₆,b₆)  (a₇,b₇)  (a₈,b₈)
层级 1:           ●                      ●                      ●                      ●
层级 2:                               ●                                              ●
层级 3:                                                       ●

每一层将相邻的两个元素通过扫描算子合并,直到根节点。

阶段 2:下扫(Down-Sweep) --- 自顶向下分发

从根节点开始,将累积结果分发到每个叶子节点,使得位置 k k k 的结果为前 k k k 个元素的累积积。

指标 复杂度
总工作量 O ( L ) O(L) O(L)(与串行相同)
并行深度 O ( log ⁡ L ) O(\log L) O(logL)
加速比 O ( L / log ⁡ L ) O(L / \log L) O(L/logL)

4.4 硬件感知优化:分块扫描

在实际 GPU 实现中,纯粹的并行扫描面临两个问题:

  1. 内存带宽瓶颈:全局内存访问延迟高
  2. GPU 层级结构:线程 → Warp (32) → Block → Grid

Mamba 的硬件感知实现采用分块策略

复制代码
序列:  [块1: 64元素] [块2: 64元素] [块3: 64元素] ...
       ─────────────  ─────────────  ─────────────
       内部: 串行扫描   内部: 串行扫描   内部: 串行扫描
       (利用 SRAM)     (利用 SRAM)     (利用 SRAM)
              ↓               ↓               ↓
       块间: 并行扫描 (通过 HBM 传递状态)
  • 块内:在 SRAM(共享内存)中执行串行扫描,避免 HBM 访问
  • 块间:执行粗粒度并行扫描,仅传递块边界状态
  • IO 复杂度 : O ( L N / B ) O(LN/B) O(LN/B),其中 B B B 为 SRAM 容量

5. Mamba 架构设计

5.1 整体架构

Mamba 采用了简化的架构,没有使用传统的注意力机制:

复制代码
输入 x ∈ ℝ^(B×L×D)
    ↓
[线性投影] → x' ∈ ℝ^(B×L×E)    (扩展维度 E = 2D 或 4D)
    ↓
[Conv1D 1×1] → 局部特征提取
    ↓
[SiLU 激活]
    ↓
[选择性 SSM] → 核心计算
    ↓
[门控乘法] ← 从 x' 分出的门控分支
    ↓
[线性投影] → y ∈ ℝ^(B×L×D)
    ↓
残差连接 + LayerNorm

5.2 门控机制

Mamba 借鉴了 GLU(Gated Linear Unit)的思想:

MambaBlock ( x ) = Linear ( SiLU ( Conv1D ( Linear ( x ) ) ) ⊙ SSM ( SiLU ( Conv1D ( Linear ( x ) ) ) ) ) \text{MambaBlock}(\mathbf{x}) = \text{Linear}(\text{SiLU}(\text{Conv1D}(\text{Linear}(\mathbf{x}))) \odot \text{SSM}(\text{SiLU}(\text{Conv1D}(\text{Linear}(\mathbf{x}))))) MambaBlock(x)=Linear(SiLU(Conv1D(Linear(x)))⊙SSM(SiLU(Conv1D(Linear(x)))))

其中 SiLU(Sigmoid Linear Unit)定义为:

SiLU ( z ) = z ⋅ σ ( z ) = z 1 + e − z \text{SiLU}(z) = z \cdot \sigma(z) = \frac{z}{1 + e^{-z}} SiLU(z)=z⋅σ(z)=1+e−zz

门控的作用是让模型学习哪些信息需要通过 SSM 处理,哪些可以直通

5.3 参数化细节

对于维度 D D D 的输入,Mamba block 的参数:

参数 维度 说明
W i n \mathbf{W}_{in} Win R D × ( 2 E ) \mathbb{R}^{D \times (2E)} RD×(2E) 输入投影(分为两路:SSM 输入 + 门控)
Conv1D 核大小 4, E E E 通道 因果卷积
A \mathbf{A} A R E \mathbb{R}^E RE 对角矩阵,通过 log ⁡ \log log 参数化确保负值
W B \mathbf{W}_B WB R D s s × E \mathbb{R}^{D_{ss} \times E} RDss×E SSM 状态扩展因子
W C \mathbf{W}_C WC R D s s × E \mathbb{R}^{D_{ss} \times E} RDss×E 输出投影
W Δ \mathbf{W}_\Delta WΔ R E × D \mathbb{R}^{E \times D} RE×D 投影到标量后 softplus
W o u t \mathbf{W}_{out} Wout R E × D \mathbb{R}^{E \times D} RE×D 输出投影

6. 完整可运行实现

6.1 选择性 SSM 核心

python 复制代码
"""
Mamba: Selective State Space Model --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Optional, Tuple


def selective_scan(
    x: torch.Tensor,       # (B, L, D) 输入
    delta: torch.Tensor,    # (B, L, D) 步长参数
    A: torch.Tensor,        # (D, N) 状态转移矩阵(对角元素)
    B: torch.Tensor,        # (B, L, N) 输入投影(输入依赖)
    C: torch.Tensor,        # (B, L, N) 输出投影(输入依赖)
    D: torch.Tensor,        # (D,) 跳跃连接参数
    dt_rank: int,
) -> torch.Tensor:
    """
    选择性状态空间模型的前向传播。

    数学公式:
        Ā_k = exp(Δ_k · A)        (离散化状态转移)
        B̄_k = Δ_k · B_k           (离散化输入投影)
        h_k = Ā_k ⊙ h_{k-1} + B̄_k ⊙ x_k   (状态更新)
        y_k = C_k · h_k            (输出)
        y_k += D ⊙ x_k             (跳跃连接)
    """
    batch_size, seq_len, d_model = x.shape
    n = A.shape[1]  # 状态维度 N

    # 步骤 1: 离散化
    delta_A = torch.exp(
        delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)
    )  # (B, L, D, N)

    delta_B = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, D, N)

    # 步骤 2: 串行递推计算状态
    h = torch.zeros(batch_size, d_model, n, device=x.device, dtype=x.dtype)
    ys = []

    for k in range(seq_len):
        h = delta_A[:, k] * h + delta_B[:, k] * x[:, k].unsqueeze(-1)
        y_k = (h * C[:, k].unsqueeze(1)).sum(dim=-1)  # (B, D)
        ys.append(y_k)

    y = torch.stack(ys, dim=1)  # (B, L, D)

    # 步骤 3: 添加跳跃连接
    y = y + x * D.unsqueeze(0).unsqueeze(0)

    return y


def selective_scan_parallel(
    x: torch.Tensor,       # (B, L, D)
    delta: torch.Tensor,    # (B, L, D)
    A: torch.Tensor,        # (D, N)
    B: torch.Tensor,        # (B, L, N)
    C: torch.Tensor,        # (B, L, N)
    D: torch.Tensor,        # (D,)
) -> torch.Tensor:
    """
    并行前缀扫描实现选择性 SSM。

    核心思想: 将线性递推 h_k = a_k * h_{k-1} + b_k
    转化为前缀积问题,利用结合律并行计算。
    """
    batch_size, seq_len, d_model = x.shape
    n = A.shape[1]

    # 离散化
    delta_A = torch.exp(
        delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)
    )
    delta_B = delta.unsqueeze(-1) * B.unsqueeze(2)

    a = delta_A
    b = delta_B * x.unsqueeze(-1)

    # 并行前缀扫描 (Hillis-Steele)
    levels = int(np.ceil(np.log2(seq_len)))
    padded_len = 2 ** levels
    if padded_len > seq_len:
        pad_size = padded_len - seq_len
        a = F.pad(a, (0, 0, 0, 0, 0, pad_size))
        b = F.pad(b, (0, 0, 0, 0, 0, pad_size))

    for d in range(levels):
        stride = 2 ** d
        a_shifted = torch.zeros_like(a)
        b_shifted = torch.zeros_like(b)
        if stride < padded_len:
            a_shifted[:, stride:] = a[:, :-stride]
            b_shifted[:, stride:] = b[:, :-stride]

        a_new = a * a_shifted
        b_new = a * b_shifted + b

        mask = torch.arange(padded_len, device=a.device) >= stride
        a = torch.where(mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1), a_new, a)
        b = torch.where(mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1), b_new, b)

    a = a[:, :seq_len]
    b = b[:, :seq_len]

    y = (b * C.unsqueeze(2)).sum(dim=-1)
    y = y + x * D.unsqueeze(0).unsqueeze(0)

    return y

6.2 Mamba Block 实现

python 复制代码
@dataclass
class MambaConfig:
    """Mamba 模型配置"""
    d_model: int = 256
    d_state: int = 16
    d_conv: int = 4
    expand: int = 2
    dt_rank: str = "auto"
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    bias: bool = False
    conv_bias: bool = True

    def __post_init__(self):
        self.d_inner = self.d_model * self.expand
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)


class SelectiveSSM(nn.Module):
    """选择性状态空间模型层"""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.d_model = config.d_model
        self.d_state = config.d_state
        self.d_inner = config.d_inner
        self.dt_rank = config.dt_rank

        A = torch.arange(1, self.d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))

        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.x_proj = nn.Linear(
            self.d_inner,
            self.dt_rank + self.d_state * 2,
            bias=False
        )

        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        if config.dt_init == "constant":
            nn.init.constant_(
                self.dt_proj.weight,
                math.log(config.dt_max - config.dt_min)
            )
        else:
            dt_init_std = self.dt_rank ** -0.5 * config.dt_scale
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)

        dt = torch.exp(
            torch.rand(self.d_inner)
            * (math.log(config.dt_max) - math.log(config.dt_min))
            + math.log(config.dt_min)
        ).clamp(min=1e-4)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, dim = x.shape
        xz = self.x_proj(x)
        dt, B, C = torch.split(
            xz,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1
        )
        dt = self.dt_proj(dt)
        dt = F.softplus(dt)
        A = -torch.exp(self.A_log)
        y = selective_scan(x, dt, A, B, C, self.D, self.dt_rank)
        return y


class MambaBlock(nn.Module):
    """完整的 Mamba 块"""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config

        self.in_proj = nn.Linear(
            config.d_model,
            config.d_inner * 2,
            bias=config.bias
        )

        self.conv1d = nn.Conv1d(
            in_channels=config.d_inner,
            out_channels=config.d_inner,
            kernel_size=config.d_conv,
            bias=config.conv_bias,
            groups=config.d_inner,
            padding=config.d_conv - 1,
        )

        self.ssm = SelectiveSSM(config)
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm(x)

        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)

        x_ssm = x_ssm.transpose(1, 2)
        x_ssm = self.conv1d(x_ssm)[:, :, :x.size(1)]
        x_ssm = x_ssm.transpose(1, 2)

        x_ssm = F.silu(x_ssm)
        y = self.ssm(x_ssm)
        y = y * F.silu(z)

        y = self.out_proj(y)
        return y + residual

6.3 实验代码

python 复制代码
def demonstrate_parallel_scan():
    """演示并行前缀扫描与串行递推的等价性"""
    torch.manual_seed(42)

    batch_size = 2
    seq_len = 32
    d_model = 4
    d_state = 8

    x = torch.randn(batch_size, seq_len, d_model)
    delta = F.softplus(torch.randn(batch_size, seq_len, d_model))
    A = -torch.exp(torch.randn(d_model, d_state))
    B = torch.randn(batch_size, seq_len, d_state) * 0.1
    C = torch.randn(batch_size, seq_len, d_state) * 0.1
    D_param = torch.ones(d_model)

    y_serial = selective_scan(x, delta, A, B, C, D_param, dt_rank=2)
    y_parallel = selective_scan_parallel(x, delta, A, B, C, D_param)

    max_diff = (y_serial - y_parallel).abs().max().item()
    print(f"串行 vs 并行扫描结果比较:")
    print(f"  最大绝对误差: {max_diff:.6e}")
    print(f"  结果一致: {max_diff < 1e-4}")

7. Mamba 与 Transformer 的理论对比

7.1 计算复杂度

操作 Transformer Mamba
前向传播 O ( L 2 D + L D 2 ) O(L^2 D + L D^2) O(L2D+LD2) O ( L D N ) O(L D N) O(LDN)
推理(自回归) O ( L D ) O(L D) O(LD) 每步(KV cache) O ( D N ) O(D N) O(DN) 每步(仅状态)
内存(推理) O ( L D ) O(L D) O(LD) KV cache O ( D N ) O(D N) O(DN) 状态
并行深度 O ( 1 ) O(1) O(1)(矩阵乘法) O ( log ⁡ L ) O(\log L) O(logL)(扫描)

当 L ≫ N L \gg N L≫N 时(长序列),Mamba 的优势显著。

7.2 表达能力分析

Transformer 的优势

  • 全局注意力可以精确检索任意位置的信息
  • In-context learning 能力已被广泛验证

Mamba 的优势

  • 选择机制实现了动态的、内容感知的记忆管理
  • 状态压缩避免了冗余信息的存储
  • 线性复杂度使其能处理超长序列

理论局限

  • 有限状态 N N N 限制了 Mamba 能同时记住的信息量
  • 无法像 Transformer 那样进行精确的"复制"操作(需要 O ( L ) O(L) O(L) 状态)

7.3 选择机制与信息论

从信息论角度,选择机制可以理解为信息瓶颈的一种形式:

I ( h k ; x ≤ k ) ≤ N log ⁡ 2 ( 1 ϵ ) I(\mathbf{h}k; \mathbf{x}{\leq k}) \leq N \log_2 \left(\frac{1}{\epsilon}\right) I(hk;x≤k)≤Nlog2(ϵ1)

其中 ϵ \epsilon ϵ 是状态精度。有限维度 N N N 的状态只能保留有限的信息,选择机制的作用是最大化保留信息的相关性


8. 扩展与前沿发展

8.1 Mamba-2:状态空间对偶性

Mamba-2(Dao & Gu, 2024)发现了 SSM 与注意力之间的深层联系:

结构化状态空间对偶(SSD):选择性 SSM 可以等价地表示为一种结构化的半可分矩阵乘法,形式上类似于线性注意力。

y = SSM ( x )    ⟺    y = M x \mathbf{y} = \text{SSM}(\mathbf{x}) \iff \mathbf{y} = \mathbf{M} \mathbf{x} y=SSM(x)⟺y=Mx

其中 M \mathbf{M} M 是一个半可分矩阵:

M i j = { C i ( ∏ k = j + 1 i A ˉ k ) B ˉ j if i ≥ j 0 if i < j M_{ij} = \begin{cases} \mathbf{C}i \left(\prod{k=j+1}^{i} \bar{\mathbf{A}}_k\right) \bar{\mathbf{B}}_j & \text{if } i \geq j \\ 0 & \text{if } i < j \end{cases} Mij={Ci(∏k=j+1iAˉk)Bˉj0if i≥jif i<j

这种对偶性使得 Mamba-2 可以利用矩阵乘法的硬件优化,实现比 Mamba-1 更高的计算效率。

8.2 混合架构

实践中,纯 Mamba 模型在某些需要精确检索的任务上(如复制、查找)不如 Transformer。混合架构应运而生:

Jamba = Mamba layers ⊕ Attention layers \text{Jamba} = \text{Mamba layers} \oplus \text{Attention layers} Jamba=Mamba layers⊕Attention layers

例如,每 4 层 Mamba 后插入 1 层注意力,兼顾效率与检索能力。

8.3 多模态扩展

Mamba 已扩展到多个模态:

  • Vision Mamba (Vim):将图像分割为 patch,用双向 Mamba 处理
  • Video Mamba:利用 Mamba 的线性复杂度处理长视频序列
  • Genomic Mamba:处理超长 DNA 序列(百万碱基对)

9. Mamba 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════╗
║                    Mamba 数学公式总结                                        ║
╠══════════════════════════════════════════════════════════════════════════════╣
║                                                                            ║
║  1. 连续时间 SSM:                                                          ║
║     dh/dt = A·h(t) + B·x(t)                                              ║
║     y(t)  = C·h(t) + D·x(t)                                              ║
║                                                                            ║
║  2. ZOH 离散化:                                                            ║
║     Ā = exp(Δ·A)                                                          ║
║     B̄ = (Δ·A)^{-1}·(exp(Δ·A) - I)·Δ·B ≈ Δ·B                             ║
║                                                                            ║
║  3. 选择性 SSM (Mamba 的核心创新):                                         ║
║     B_k = Linear_B(x_k)       ← 输入依赖的 B                              ║
║     C_k = Linear_C(x_k)       ← 输入依赖的 C                              ║
║     Δ_k = softplus(Linear_Δ(x_k))  ← 输入依赖的步长                       ║
║                                                                            ║
║  4. 状态更新:                                                              ║
║     h_k = exp(Δ_k · A) ⊙ h_{k-1} + Δ_k · B_k · x_k                      ║
║     y_k = C_k · h_k + D · x_k                                             ║
║                                                                            ║
║  5. 并行扫描算子 (满足结合律):                                             ║
║     (a₁,b₁) • (a₂,b₂) = (a₂·a₁, a₂·b₁ + b₂)                            ║
║                                                                            ║
║  6. 复杂度:                                                                ║
║     串行: O(L·N) 时间, O(N) 空间                                          ║
║     并行: O(L·N) 工作, O(log L) 深度                                      ║
║     对比 Transformer: O(L²) 时间, O(L) 空间                                ║
║                                                                            ║
║  7. 选择机制的直觉:                                                        ║
║     Δ 大 → 重置状态, 聚焦当前输入 (写入)                                   ║
║     Δ 小 → 保留状态, 忽略当前输入 (记忆)                                   ║
║                                                                            ║
╚══════════════════════════════════════════════════════════════════════════════╝

第二篇:Flow Matching 与 Rectified Flows

1. 引言

生成模型的目标是从复杂的数据分布 p data ( x ) p_{\text{data}}(\mathbf{x}) pdata(x) 中采样。近年来,扩散模型(Diffusion Models) 在图像生成领域取得了巨大成功,但其依赖于随机微分方程(SDE)的迭代求解,采样需要数百步去噪,计算代价高昂。

Flow Matching (Lipman et al., 2023)和 Rectified Flows (Liu et al., 2023)提出了一种全新的生成建模范式:通过学习一个确定性的速度场,将简单的噪声分布(如高斯分布)沿直线路径传输到数据分布。这一方法:

  • 数学简洁:只需回归一个速度场,无需 SDE、分数函数或噪声调度
  • 路径最优:通过最优传输理论,构造接近直线的传输路径,大幅减少采样步数
  • 理论优雅:统一了扩散模型、Flow Matching 和最优传输的数学框架

2. 理论基础 --- 连续正规化流

2.1 从离散流到连续流

正规化流(Normalizing Flow) 通过一系列可逆变换将简单分布 p 0 p_0 p0 逐步变形为复杂分布 p T p_T pT:

z 0 ∼ p 0 ( z ) → f 1 z 1 → f 2 ⋯ → f T z T ∼ p T ( x ) \mathbf{z}_0 \sim p_0(\mathbf{z}) \quad \xrightarrow{f_1} \quad \mathbf{z}_1 \quad \xrightarrow{f_2} \quad \cdots \quad \xrightarrow{f_T} \quad \mathbf{z}_T \sim p_T(\mathbf{x}) z0∼p0(z)f1 z1f2 ⋯fT zT∼pT(x)

每一步变换 f k f_k fk 的概率密度通过变量替换公式追踪:

p k ( z k ) = p k − 1 ( z k − 1 ) ∣ det ⁡ ∂ f k ∂ z k − 1 ∣ − 1 p_{k}(\mathbf{z}k) = p{k-1}(\mathbf{z}{k-1}) \left| \det \frac{\partial f_k}{\partial \mathbf{z}{k-1}} \right|^{-1} pk(zk)=pk−1(zk−1) det∂zk−1∂fk −1

连续正规化流(Continuous Normalizing Flow, CNF) 将离散的 T T T 步推广到连续时间 t ∈ 0 , 1 t \in 0, 1 t∈0,1,用一个常微分方程(ODE)描述变换过程:

d z t d t = v θ ( z t , t ) \frac{d\mathbf{z}t}{dt} = \mathbf{v}\theta(\mathbf{z}_t, t) dtdzt=vθ(zt,t)

其中 v θ : R d × 0 , 1 → R d \mathbf{v}_\theta: \mathbb{R}^d \times 0, 1 \to \mathbb{R}^d vθ:Rd×0,1→Rd 是一个由神经网络参数化的速度场(velocity field)

2.2 概率密度的演化:连续性方程

CNF 的关键数学工具是连续性方程(Continuity Equation) ,它描述了概率密度 p t ( x ) p_t(\mathbf{x}) pt(x) 如何随时间演化:

∂ p t ( x ) ∂ t + ∇ ⋅ ( p t ( x )   v θ ( x , t ) ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \right) = 0 ∂t∂pt(x)+∇⋅(pt(x)vθ(x,t))=0

其中 ∇ ⋅ \nabla \cdot ∇⋅ 是散度算子:

∇ ⋅ f = ∑ i = 1 d ∂ f i ∂ x i \nabla \cdot \mathbf{f} = \sum_{i=1}^d \frac{\partial f_i}{\partial x_i} ∇⋅f=i=1∑d∂xi∂fi

物理直觉 :这是流体力学中的质量守恒方程。 p t ( x ) p_t(\mathbf{x}) pt(x) 是流体密度, v θ ( x , t ) \mathbf{v}_\theta(\mathbf{x}, t) vθ(x,t) 是流速场。方程保证了概率质量守恒------没有概率凭空产生或消失。

证明(概率守恒):

对于任意区域 Ω ⊂ R d \Omega \subset \mathbb{R}^d Ω⊂Rd,概率质量的变化率为:

d d t ∫ Ω p t ( x )   d x = − ∫ ∂ Ω p t ( x )   v θ ( x , t ) ⋅ n   d S \frac{d}{dt} \int_\Omega p_t(\mathbf{x}) \, d\mathbf{x} = -\int_{\partial \Omega} p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \cdot \mathbf{n} \, dS dtd∫Ωpt(x)dx=−∫∂Ωpt(x)vθ(x,t)⋅ndS

由散度定理:

= − ∫ Ω ∇ ⋅ ( p t ( x )   v θ ( x , t ) ) d x = -\int_\Omega \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_\theta(\mathbf{x}, t) \right) d\mathbf{x} =−∫Ω∇⋅(pt(x)vθ(x,t))dx

由于 Ω \Omega Ω 任意,被积函数必须为零,即得连续性方程。 ■ \blacksquare ■

2.3 对数概率的演化

对连续性方程两边除以 p t p_t pt,可得对数概率的演化方程:

∂ log ⁡ p t ( x ) ∂ t = − ∇ ⋅ v θ ( x , t ) − v θ ( x , t ) ⋅ ∇ log ⁡ p t ( x ) \frac{\partial \log p_t(\mathbf{x})}{\partial t} = -\nabla \cdot \mathbf{v}\theta(\mathbf{x}, t) - \mathbf{v}\theta(\mathbf{x}, t) \cdot \nabla \log p_t(\mathbf{x}) ∂t∂logpt(x)=−∇⋅vθ(x,t)−vθ(x,t)⋅∇logpt(x)

这个公式在后续推导 Flow Matching 损失函数时至关重要。


3. Flow Matching --- 条件速度场框架

3.1 从仿真损失到条件匹配

训练 CNF 的直接目标是最大化数据的对数似然。利用瞬时变量替换公式(Instantaneous Change of Variables):

log ⁡ p 1 ( x 1 ) = log ⁡ p 0 ( x 0 ) − ∫ 0 1 ∇ ⋅ v θ ( x t , t )   d t \log p_1(\mathbf{x}_1) = \log p_0(\mathbf{x}0) - \int_0^1 \nabla \cdot \mathbf{v}\theta(\mathbf{x}_t, t) \, dt logp1(x1)=logp0(x0)−∫01∇⋅vθ(xt,t)dt

其中 x t \mathbf{x}_t xt 是从 x 0 \mathbf{x}0 x0 沿速度场 v θ \mathbf{v}\theta vθ 积分得到的轨迹。

仿真损失(Simulation Loss)

L sim ( θ ) = E t ∼ U ( 0 , 1 ) ,   x 0 ∼ p 0 ∥ v θ ( x t , t ) − v target ( x t , t ) ∥ 2 \mathcal{L}{\text{sim}}(\theta) = \mathbb{E}{t \sim \mathcal{U}(0,1), \, \mathbf{x}_0 \sim p_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{v}_{\\text{target}}(\\mathbf{x}_t, t) \\\|\^2 \\right Lsim(θ)=Et∼U(0,1),x0∼p0∥vθ(xt,t)−vtarget(xt,t)∥2

但直接计算 v target \mathbf{v}_{\text{target}} vtarget 需要知道边缘分布 p t ( x ) p_t(\mathbf{x}) pt(x) 的梯度,这在高维空间中是不可行的。

Flow Matching 的核心洞察 :虽然边缘速度场 v target \mathbf{v}_{\text{target}} vtarget 难以计算,但条件速度场 v t ( x ∣ x 1 ) \mathbf{v}_t(\mathbf{x} | \mathbf{x}_1) vt(x∣x1) 可以精确构造!

3.2 条件概率路径

给定一个数据样本 x 1 ∼ p data ( x ) \mathbf{x}1 \sim p{\text{data}}(\mathbf{x}) x1∼pdata(x),我们构造一个从噪声 x 0 ∼ p 0 ( x ) \mathbf{x}_0 \sim p_0(\mathbf{x}) x0∼p0(x) 到 x 1 \mathbf{x}_1 x1 的条件概率路径 p t ( x ∣ x 1 ) p_t(\mathbf{x} | \mathbf{x}_1) pt(x∣x1)。

最简单的路径是线性插值

x t = ( 1 − t ) x 0 + t x 1 , t ∈ 0 , 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1, \quad t \in 0, 1 xt=(1−t)x0+tx1,t∈0,1

对应的条件速度场为:

u t ( x t ∣ x 1 ) = d x t d t = x 1 − x 0 \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{d\mathbf{x}_t}{dt} = \mathbf{x}_1 - \mathbf{x}_0 ut(xt∣x1)=dtdxt=x1−x0

关键性质 :这个速度场是无旋的(irrotational) ,即存在一个标量势函数 ϕ \phi ϕ 使得 u = ∇ ϕ \mathbf{u} = \nabla \phi u=∇ϕ。这与最优传输有深刻联系。

3.3 边缘速度场与条件速度场的关系

边缘概率路径定义为:

p t ( x ) = ∫ p t ( x ∣ x 1 )   p data ( x 1 )   d x 1 p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 pt(x)=∫pt(x∣x1)pdata(x1)dx1

核心定理(Lipman et al., 2023):边缘速度场等于条件速度场的期望:

v t ∗ ( x ) = E p data ( x 1 ) u t ( x ∣ x 1 )   ∣   x t = x \mathbf{v}t^*(\mathbf{x}) = \mathbb{E}{p_{\text{data}}(\mathbf{x}_1)} \left \\mathbf{u}_t(\\mathbf{x} \| \\mathbf{x}_1) \\, \\Big\| \\, \\mathbf{x}_t = \\mathbf{x} \\right vt∗(x)=Epdata(x1)ut(x∣x1) xt=x

即:

v t ∗ ( x ) = ∫ u t ( x ∣ x 1 )   p t ( x ∣ x 1 )   p data ( x 1 ) p t ( x )   d x 1 \mathbf{v}_t^*(\mathbf{x}) = \int \mathbf{u}_t(\mathbf{x} | \mathbf{x}_1) \, \frac{p_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1)}{p_t(\mathbf{x})} \, d\mathbf{x}_1 vt∗(x)=∫ut(x∣x1)pt(x)pt(x∣x1)pdata(x1)dx1

证明:条件概率路径满足连续性方程:

∂ p t ( x ∣ x 1 ) ∂ t + ∇ ⋅ ( p t ( x ∣ x 1 )   u t ( x ∣ x 1 ) ) = 0 \frac{\partial p_t(\mathbf{x} | \mathbf{x}_1)}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}_1) \right) = 0 ∂t∂pt(x∣x1)+∇⋅(pt(x∣x1)ut(x∣x1))=0

对 x 1 \mathbf{x}_1 x1 积分:

∂ p t ( x ) ∂ t + ∇ ⋅ ( ∫ p t ( x ∣ x 1 )   u t ( x ∣ x 1 )   p data ( x 1 )   d x 1 ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( \int p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 \right) = 0 ∂t∂pt(x)+∇⋅(∫pt(x∣x1)ut(x∣x1)pdata(x1)dx1)=0

与边缘连续性方程对比:

∂ p t ( x ) ∂ t + ∇ ⋅ ( p t ( x )   v t ∗ ( x ) ) = 0 \frac{\partial p_t(\mathbf{x})}{\partial t} + \nabla \cdot \left( p_t(\mathbf{x}) \, \mathbf{v}_t^*(\mathbf{x}) \right) = 0 ∂t∂pt(x)+∇⋅(pt(x)vt∗(x))=0

可得:

p t ( x )   v t ∗ ( x ) = ∫ p t ( x ∣ x 1 )   u t ( x ∣ x 1 )   p data ( x 1 )   d x 1 p_t(\mathbf{x}) \, \mathbf{v}_t^*(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{x}_1) \, \mathbf{u}_t(\mathbf{x} | \mathbf{x}1) \, p{\text{data}}(\mathbf{x}_1) \, d\mathbf{x}_1 pt(x)vt∗(x)=∫pt(x∣x1)ut(x∣x1)pdata(x1)dx1

两边除以 p t ( x ) p_t(\mathbf{x}) pt(x) 即得结果。 ■ \blacksquare ■

3.4 条件 Flow Matching 损失

基于上述定理,Flow Matching 的训练损失定义为:

L CFM ( θ ) = E t ∼ U ( 0 , 1 ) ,   x 1 ∼ p data ,   x 0 ∼ p 0 ∥ v θ ( x t , t ) − u t ( x t ∣ x 1 ) ∥ 2 \mathcal{L}{\text{CFM}}(\theta) = \mathbb{E}{t \sim \mathcal{U}(0,1), \, \mathbf{x}1 \sim p{\text{data}}, \, \mathbf{x}_0 \sim p_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\\|\^2 \\right LCFM(θ)=Et∼U(0,1),x1∼pdata,x0∼p0∥vθ(xt,t)−ut(xt∣x1)∥2

其中 x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1-t)\mathbf{x}_0 + t\mathbf{x}_1 xt=(1−t)x0+tx1。

等价性定理

L CFM ( θ ) = L FM ( θ ) + const \mathcal{L}{\text{CFM}}(\theta) = \mathcal{L}{\text{FM}}(\theta) + \text{const} LCFM(θ)=LFM(θ)+const

其中 L FM \mathcal{L}_{\text{FM}} LFM 是边缘 Flow Matching 损失。这意味着优化条件损失等价于优化边缘损失 ,常数项不依赖于 θ \theta θ。

证明

L CFM = E t , x 1 , x 0 ∥ v θ ( x t , t ) − u t ( x t ∣ x 1 ) ∥ 2 \mathcal{L}{\text{CFM}} = \mathbb{E}{t, \mathbf{x}_1, \mathbf{x}_0} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\\|\^2 \\right LCFM=Et,x1,x0∥vθ(xt,t)−ut(xt∣x1)∥2

展开平方:

= E ∥ v θ ∥ 2 − 2 v θ ⋅ u t + ∥ u t ∥ 2 = \mathbb{E} \left \\\| \\mathbf{v}_\\theta \\\|\^2 - 2 \\mathbf{v}_\\theta \\cdot \\mathbf{u}_t + \\\| \\mathbf{u}_t \\\|\^2 \\right =E∥vθ∥2−2vθ⋅ut+∥ut∥2

中间项:

E x 1 , x 0 ∣ x t v θ ( x t , t ) ⋅ u t ( x t ∣ x 1 ) = v θ ( x t , t ) ⋅ v t ∗ ( x t ) \mathbb{E}_{\mathbf{x}_1, \mathbf{x}_0 | \mathbf{x}t} \left \\mathbf{v}_\\theta(\\mathbf{x}_t, t) \\cdot \\mathbf{u}_t(\\mathbf{x}_t \| \\mathbf{x}_1) \\right = \mathbf{v}\theta(\mathbf{x}_t, t) \cdot \mathbf{v}_t^*(\mathbf{x}_t) Ex1,x0∣xtvθ(xt,t)⋅ut(xt∣x1)=vθ(xt,t)⋅vt∗(xt)

最后一项 E ∥ u t ∥ 2 \mathbb{E}\\\| \\mathbf{u}_t \\\|\^2 E∥ut∥2 不依赖 θ \theta θ,因此:

L CFM = E t , x t ∥ v θ − v t ∗ ∥ 2 + const = L FM + const \mathcal{L}{\text{CFM}} = \mathbb{E}{t, \mathbf{x}t} \left \\\| \\mathbf{v}_\\theta - \\mathbf{v}_t\^\* \\\|\^2 \\right + \text{const} = \mathcal{L}{\text{FM}} + \text{const} LCFM=Et,xt∥vθ−vt∗∥2+const=LFM+const

■ \blacksquare ■


4. 最优传输与 Rectified Flows

4.1 最优传输基础

最优传输(Optimal Transport, OT) 问题:给定两个分布 μ \mu μ 和 ν \nu ν,找到一个传输映射 T : R d → R d T: \mathbb{R}^d \to \mathbb{R}^d T:Rd→Rd,使得:

T # μ = ν ( 即 T 将 μ 推前为 ν ) T_\# \mu = \nu \quad (\text{即 } T \text{ 将 } \mu \text{ 推前为 } \nu) T#μ=ν(即 T 将 μ 推前为 ν)

且传输代价最小:

min ⁡ T ∫ ∥ x − T ( x ) ∥ 2   μ ( x )   d x \min_T \int \| \mathbf{x} - T(\mathbf{x}) \|^2 \, \mu(\mathbf{x}) \, d\mathbf{x} Tmin∫∥x−T(x)∥2μ(x)dx

Brenier 定理:在欧氏空间中,最优传输映射存在且唯一(在适当条件下),并且它是某个凸函数的梯度:

T ∗ ( x ) = ∇ ϕ ( x ) T^*(\mathbf{x}) = \nabla \phi(\mathbf{x}) T∗(x)=∇ϕ(x)

其中 ϕ \phi ϕ 是凸函数。

4.2 OT 与直线路径

最优传输的一个关键性质是:OT 映射诱导的路径是直线

定义从 μ \mu μ 到 ν \nu ν 的 OT 路径:

x t = ( 1 − t ) x 0 + t T ∗ ( x 0 ) , t ∈ 0 , 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t T^*(\mathbf{x}_0), \quad t \in 0, 1 xt=(1−t)x0+tT∗(x0),t∈0,1

定理 (McCann 插值):这是 Wasserstein 距离 W 2 ( μ , ν ) W_2(\mu, \nu) W2(μ,ν) 的唯一测地线(geodesic)。

直觉解释:OT 映射将每个点直接送到目的地,不走弯路。因此每个粒子的轨迹都是直线,只是不同粒子的速度不同。

4.3 为什么标准 Flow Matching 的路径不是直线

在标准 Flow Matching 中,我们使用独立耦合

x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1

其中 x 0 \mathbf{x}_0 x0 和 x 1 \mathbf{x}_1 x1 是独立采样 的。这意味着 x 0 \mathbf{x}_0 x0 不一定被映射到与之"最优配对"的 x 1 \mathbf{x}_1 x1。

考虑一个简单例子: x 0 ∼ N ( 0 , 1 ) \mathbf{x}_0 \sim \mathcal{N}(0, 1) x0∼N(0,1), x 1 ∼ N ( 0 , 1 ) \mathbf{x}_1 \sim \mathcal{N}(0, 1) x1∼N(0,1)。

  • 独立耦合:随机配对,路径可能交叉,轨迹弯曲
  • OT 耦合: T ( x ) = x T(x) = x T(x)=x(当两个分布相同时),路径为直线(静止不动)

路径弯曲的代价

  • 速度场更复杂,更难学习
  • 采样时需要更多步数(因为轨迹弯曲,小步长才能精确积分)

4.4 Rectified Flows:拉直路径

Rectified Flow (Liu et al., 2023)的核心思想:先训练一个 Flow Matching 模型,然后用它生成的数据构造 OT 近似的耦合,再重新训练

步骤 1 :训练初始 Flow Matching 模型 v θ ( 0 ) \mathbf{v}_\theta^{(0)} vθ(0)

步骤 2 :用 v θ ( 0 ) \mathbf{v}_\theta^{(0)} vθ(0) 生成配对数据

x 0 ∼ p 0 , x 1 = ODESolve ( v θ ( 0 ) , x 0 , 0 , 1 ) \mathbf{x}_0 \sim p_0, \quad \mathbf{x}1 = \text{ODESolve}(\mathbf{v}\theta^{(0)}, \mathbf{x}_0, 0, 1) x0∼p0,x1=ODESolve(vθ(0),x0,0,1)

这构造了一个确定性耦合 ( x 0 , x 1 ) (\mathbf{x}_0, \mathbf{x}_1) (x0,x1)。

步骤 3:用新的配对数据重新训练

x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1

收敛定理(Liu et al., 2023):重复步骤 2-3,路径的"弯曲度"单调递减,最终收敛到直线路径。

4.5 路径直度的量化

定义路径直度(Straightness)

Straightness = E x 0 , x 1 ∫ 0 1 ∥ v θ ( x t , t ) − ( x 1 − x 0 ) ∥ 2   d t \text{Straightness} = \mathbb{E}_{\mathbf{x}_0, \mathbf{x}_1} \left \\int_0\^1 \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - (\\mathbf{x}_1 - \\mathbf{x}_0) \\\|\^2 \\, dt \\right Straightness=Ex0,x1∫01∥vθ(xt,t)−(x1−x0)∥2dt

当路径为直线时, v θ ( x t , t ) = x 1 − x 0 \mathbf{v}_\theta(\mathbf{x}_t, t) = \mathbf{x}_1 - \mathbf{x}_0 vθ(xt,t)=x1−x0(常数),Straightness = 0。

为什么直线路径更好?

直线路径意味着速度场沿路径为常数,因此:

  • 神经网络只需学习一个"简单"的函数
  • ODE 积分可以用更大的步长(因为曲率小)
  • 理论上只需 1 步就能精确采样

5. Flow Matching 损失函数的完整推导

5.1 线性插值路径下的条件速度场

给定 x 0 ∼ N ( 0 , I ) \mathbf{x}_0 \sim \mathcal{N}(0, \mathbf{I}) x0∼N(0,I), x 1 ∼ p data \mathbf{x}1 \sim p{\text{data}} x1∼pdata,线性插值:

x t = ( 1 − t ) x 0 + t x 1 = x 0 + t ( x 1 − x 0 ) \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 = \mathbf{x}_0 + t (\mathbf{x}_1 - \mathbf{x}_0) xt=(1−t)x0+tx1=x0+t(x1−x0)

条件速度场:

u t ( x t ∣ x 1 ) = d x t d t = x 1 − x 0 \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{d\mathbf{x}_t}{dt} = \mathbf{x}_1 - \mathbf{x}_0 ut(xt∣x1)=dtdxt=x1−x0

注意:这个速度场不依赖于 t t t,沿整条路径为常数。

5.2 重新参数化

实践中,我们不直接预测速度 x 1 − x 0 \mathbf{x}_1 - \mathbf{x}_0 x1−x0,而是预测目标 x 1 \mathbf{x}_1 x1 (或噪声 x 0 \mathbf{x}_0 x0)。

速度参数化

v θ ( x t , t ) ≈ x 1 − x 0 \mathbf{v}_\theta(\mathbf{x}_t, t) \approx \mathbf{x}_1 - \mathbf{x}_0 vθ(xt,t)≈x1−x0

目标参数化

x ^ θ ( x t , t ) ≈ x 1 \hat{\mathbf{x}}_\theta(\mathbf{x}_t, t) \approx \mathbf{x}_1 x^θ(xt,t)≈x1

两者的关系:

v θ ( x t , t ) = x ^ θ ( x t , t ) − x t 1 − t \mathbf{v}_\theta(\mathbf{x}t, t) = \frac{\hat{\mathbf{x}}\theta(\mathbf{x}_t, t) - \mathbf{x}_t}{1 - t} vθ(xt,t)=1−tx^θ(xt,t)−xt

最简洁的做法是直接预测速度:

L ( θ ) = E t , x 0 , x 1 ∥ v θ ( x t , t ) − ( x 1 − x 0 ) ∥ 2 \mathcal{L}(\theta) = \mathbb{E}_{t, \mathbf{x}_0, \mathbf{x}_1} \left \\\| \\mathbf{v}_\\theta(\\mathbf{x}_t, t) - (\\mathbf{x}_1 - \\mathbf{x}_0) \\\|\^2 \\right L(θ)=Et,x0,x1∥vθ(xt,t)−(x1−x0)∥2

5.3 与扩散模型的统一

Flow Matching 提供了一个统一框架,将扩散模型视为特例:

DDPM :前向过程为 x t = α ˉ t x 1 + 1 − α ˉ t ϵ \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_1 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon} xt=αˉt x1+1−αˉt ϵ

这等价于 Flow Matching 中使用方差保持(VP)插值

x t = α ˉ t x 1 + 1 − α ˉ t x 0 \mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_1 + \sqrt{1 - \bar{\alpha}_t} \mathbf{x}_0 xt=αˉt x1+1−αˉt x0

对应的条件速度场:

u t ( x t ∣ x 1 ) = α ˉ ˙ t 2 α ˉ t x t + α ˉ ˙ t α ˉ t 2 ( 1 − α ˉ t ) ( x t − α ˉ t x 1 ) \mathbf{u}_t(\mathbf{x}_t | \mathbf{x}_1) = \frac{\dot{\bar{\alpha}}_t}{2\bar{\alpha}_t} \mathbf{x}_t + \frac{\dot{\bar{\alpha}}_t \bar{\alpha}_t}{2(1 - \bar{\alpha}_t)} (\mathbf{x}_t - \sqrt{\bar{\alpha}_t} \mathbf{x}_1) ut(xt∣x1)=2αˉtαˉ˙txt+2(1−αˉt)αˉ˙tαˉt(xt−αˉt x1)

这个速度场依赖于 t t t,且路径弯曲,这正是扩散模型需要多步采样的根本原因。

Flow Matching(线性插值)

x t = ( 1 − t ) x 0 + t x 1 \mathbf{x}_t = (1 - t) \mathbf{x}_0 + t \mathbf{x}_1 xt=(1−t)x0+tx1

速度场为常数 x 1 − x 0 \mathbf{x}_1 - \mathbf{x}_0 x1−x0,路径为直线。

不同插值方法对比

方法 α t \alpha_t αt σ t \sigma_t σt 路径形状
DDPM (VP) α ˉ t \sqrt{\bar{\alpha}_t} αˉt 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1−αˉt 弯曲
DDIM (VE) 1 1 − t \sqrt{1-t} 1−t 弯曲
Flow Matching t t t 1 − t 1-t 1−t 直线
Optimal Transport t t t 1 − t 1-t 1−t (OT 耦合) 直线

6. 完整可运行实现

6.1 核心工具函数

python 复制代码
"""
Flow Matching & Rectified Flow --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib, sklearn
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from typing import Tuple, Optional, List
import math
from tqdm import tqdm


def ode_solve(
    velocity_fn,
    x0: torch.Tensor,
    t_span: Tuple[float, float] = (0.0, 1.0),
    num_steps: int = 100,
    method: str = "euler",
) -> torch.Tensor:
    """
    求解 ODE: dx/dt = v(x, t)

    支持的求解器:
    - euler: 欧拉法, 一阶
    - rk4:   四阶 Runge-Kutta
    """
    dt = (t_span[1] - t_span[0]) / num_steps
    x = x0.clone()
    t = t_span[0]
    trajectory = [x.clone()]

    for step in range(num_steps):
        if method == "euler":
            v = velocity_fn(x, torch.full((x.shape[0],), t, device=x.device))
            x = x + dt * v
        elif method == "rk4":
            t_batch = torch.full((x.shape[0],), t, device=x.device)
            k1 = velocity_fn(x, t_batch)
            k2 = velocity_fn(x + 0.5 * dt * k1, t_batch + 0.5 * dt)
            k3 = velocity_fn(x + 0.5 * dt * k2, t_batch + 0.5 * dt)
            k4 = velocity_fn(x + dt * k3, t_batch + dt)
            x = x + (dt / 6) * (k1 + 2*k2 + 2*k3 + k4)
        t += dt
        trajectory.append(x.clone())

    return x, trajectory

6.2 神经网络速度场

python 复制代码
class SinusoidalTimeEmbedding(nn.Module):
    """正弦时间嵌入"""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb


class VelocityNetwork2D(nn.Module):
    """2D 数据的速度场网络"""

    def __init__(self, hidden_dim: int = 256, time_dim: int = 64):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbedding(time_dim)
        self.input_proj = nn.Linear(2 + time_dim, hidden_dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(4)
        ])
        self.output_proj = nn.Linear(hidden_dim, 2)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.time_embed(t)
        h = torch.cat([x, t_emb], dim=-1)
        h = self.input_proj(h)
        for block in self.blocks:
            h = h + block(h)
        return self.output_proj(h)

6.3 Flow Matching 训练器

python 复制代码
class FlowMatchingTrainer:
    """Flow Matching 训练器"""

    def __init__(self, model: nn.Module, sigma_min: float = 1e-4):
        self.model = model
        self.sigma_min = sigma_min

    def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.rand(batch_size, device=device)

    def compute_loss(self, x1: torch.Tensor) -> torch.Tensor:
        batch_size = x1.shape[0]
        device = x1.device

        x0 = torch.randn_like(x1)
        t = self.sample_time(batch_size, device)

        if x1.dim() == 2:
            t_expand = t.unsqueeze(-1)
        elif x1.dim() == 4:
            t_expand = t.view(-1, 1, 1, 1)

        xt = (1 - t_expand) * x0 + t_expand * x1
        v_target = x1 - x0
        v_pred = self.model(xt, t)

        return F.mse_loss(v_pred, v_target)

    def train_step(self, x1: torch.Tensor, optimizer) -> float:
        optimizer.zero_grad()
        loss = self.compute_loss(x1)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        optimizer.step()
        return loss.item()

    @torch.no_grad()
    def sample(
        self,
        num_samples: int,
        shape: Tuple[int, ...],
        device: torch.device,
        num_steps: int = 50,
        method: str = "euler",
    ) -> torch.Tensor:
        x0 = torch.randn(num_samples, *shape, device=device)

        def velocity_fn(x, t):
            return self.model(x, t)

        x1, trajectory = ode_solve(
            velocity_fn, x0,
            t_span=(0.0, 1.0),
            num_steps=num_steps,
            method=method
        )
        return x1, trajectory

6.4 Rectified Flow

python 复制代码
class RectifiedFlow:
    """Rectified Flow: 通过 reflow 操作拉直传输路径"""

    def __init__(self, model_class, model_kwargs, device: torch.device):
        self.device = device
        self.model_class = model_class
        self.model_kwargs = model_kwargs
        self.models = []

    def train_round(
        self,
        data_samples: torch.Tensor,
        num_steps: int = 5000,
        batch_size: int = 256,
        lr: float = 1e-3,
        prev_model: Optional[nn.Module] = None,
        num_ode_steps: int = 50,
    ) -> nn.Module:
        model = self.model_class(**self.model_kwargs).to(self.device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
        trainer = FlowMatchingTrainer(model)

        dataset_size = data_samples.shape[0]
        losses = []

        for step in tqdm(range(num_steps), desc="Training"):
            idx = torch.randint(0, dataset_size, (batch_size,))
            x1_batch = data_samples[idx].to(self.device)

            if prev_model is not None:
                x0_batch = torch.randn_like(x1_batch)
                def prev_velocity(x, t):
                    return prev_model(x, t)
                with torch.no_grad():
                    x1_generated, _ = ode_solve(
                        prev_velocity, x0_batch,
                        t_span=(0.0, 1.0),
                        num_steps=num_ode_steps,
                        method="rk4"
                    )
                loss = trainer.compute_loss(x1_generated)
            else:
                loss = trainer.compute_loss(x1_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            losses.append(loss.item())

        self.models.append(model)
        return model, losses

    def compute_path_straightness(
        self,
        model: nn.Module,
        num_samples: int = 100,
        num_timesteps: int = 50,
    ) -> Tuple[float, List[torch.Tensor]]:
        model.eval()
        x0 = torch.randn(num_samples, 2, device=self.device)

        def velocity_fn(x, t):
            return model(x, t)

        _, trajectory = ode_solve(
            velocity_fn, x0,
            t_span=(0.0, 1.0),
            num_steps=num_timesteps,
            method="euler"
        )

        x1 = trajectory[-1]
        straightnesses = []
        ts = np.linspace(0, 1, num_timesteps + 1)

        for i, (t_val, xt) in enumerate(zip(ts, trajectory)):
            with torch.no_grad():
                v_pred = model(xt, torch.full((num_samples,), t_val, device=self.device))
            v_target = x1 - x0
            straightness = ((v_pred - v_target) ** 2).mean().item()
            straightnesses.append(straightness)

        return np.mean(straightnesses), trajectory

6.5 实验代码

python 复制代码
def experiment_2d_flow_matching():
    """实验: 2D 数据上的 Flow Matching"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 生成数据
    data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data = torch.tensor(data, dtype=torch.float32)

    # 创建模型
    model = VelocityNetwork2D(hidden_dim=256, time_dim=64).to(device)
    trainer = FlowMatchingTrainer(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # 训练
    losses = []
    for step in range(3000):
        idx = torch.randint(0, data.shape[0], (512,))
        x1 = data[idx].to(device)
        loss = trainer.train_step(x1, optimizer)
        losses.append(loss)
        if (step + 1) % 500 == 0:
            print(f"Step {step+1} | Loss: {np.mean(losses[-100:]):.6f}")

    # 生成样本
    model.eval()
    generated, trajectory = trainer.sample(
        num_samples=2000, shape=(2,), device=device,
        num_steps=100, method="rk4"
    )

    return model, trainer, generated.cpu(), losses


def experiment_2d_rectified_flow():
    """实验: Rectified Flow (Reflow)"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data, _ = make_moons(n_samples=10000, noise=0.05, random_state=42)
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    data = torch.tensor(data, dtype=torch.float32)

    def create_model():
        return VelocityNetwork2D(hidden_dim=256, time_dim=64)

    rf = RectifiedFlow(create_model, {}, device)

    prev_model = None
    for round_idx in range(3):
        print(f"\n--- Round {round_idx + 1}/3 ---")
        model, losses = rf.train_round(
            data, num_steps=3000, batch_size=512,
            lr=1e-3, prev_model=prev_model, num_ode_steps=50
        )
        straightness, _ = rf.compute_path_straightness(model)
        print(f"  路径直度: {straightness:.6f}")
        prev_model = model

    return rf

7. 理论深度分析

7.1 概率流 ODE 与随机 SDE

扩散模型有两种等价表述:

概率流 ODE(PF-ODE)

d x t = f ( x t , t ) − 1 2 g ( t ) 2 ∇ log ⁡ p t ( x t ) d t d\mathbf{x}_t = \left \\mathbf{f}(\\mathbf{x}_t, t) - \\frac{1}{2} g(t)\^2 \\nabla \\log p_t(\\mathbf{x}_t) \\right dt dxt=f(xt,t)−21g(t)2∇logpt(xt)dt

反向 SDE

d x t = f ( x t , t ) − g ( t ) 2 ∇ log ⁡ p t ( x t ) d t + g ( t ) d w ˉ t d\mathbf{x}_t = \left \\mathbf{f}(\\mathbf{x}_t, t) - g(t)\^2 \\nabla \\log p_t(\\mathbf{x}_t) \\right dt + g(t) d\bar{\mathbf{w}}_t dxt=f(xt,t)−g(t)2∇logpt(xt)dt+g(t)dwˉt

Flow Matching 直接学习 PF-ODE 的速度场,避免了分数函数 ∇ log ⁡ p t \nabla \log p_t ∇logpt 的估计。

7.2 最优传输的理论保证

定理 (Benamou-Brenier 公式):两个分布 μ \mu μ 和 ν \nu ν 之间的 Wasserstein-2 距离为:

W 2 2 ( μ , ν ) = inf ⁡ v , p ∫ 0 1 ∫ R d ∥ v t ( x ) ∥ 2 p t ( x )   d x   d t W_2^2(\mu, \nu) = \inf_{\mathbf{v}, p} \int_0^1 \int_{\mathbb{R}^d} \| \mathbf{v}_t(\mathbf{x}) \|^2 p_t(\mathbf{x}) \, d\mathbf{x} \, dt W22(μ,ν)=v,pinf∫01∫Rd∥vt(x)∥2pt(x)dxdt

其中下确界取遍所有满足连续性方程的速度场 v \mathbf{v} v 和概率路径 p t p_t pt,且 p 0 = μ , p 1 = ν p_0 = \mu, p_1 = \nu p0=μ,p1=ν。

推论:OT 路径最小化了"动能",因此是最直的路径。Rectified Flow 的 reflow 操作正是在数值上逼近这个最优解。

7.3 单步生成的理论极限

理论上,如果速度场完美学习,只需 1 步 ODE 积分即可生成:

x 1 = x 0 + v θ ( x 0 , 0 ) \mathbf{x}_1 = \mathbf{x}0 + \mathbf{v}\theta(\mathbf{x}_0, 0) x1=x0+vθ(x0,0)

但这要求 v θ ( x , 0 ) \mathbf{v}_\theta(\mathbf{x}, 0) vθ(x,0) 精确等于 OT 映射 T ∗ ( x ) − x T^*(\mathbf{x}) - \mathbf{x} T∗(x)−x。实践中,Rectified Flow 通过 reflow 逐步逼近这一极限。


8. Flow Matching 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════╗
║                    Flow Matching & Rectified Flow 数学总结                       ║
╠══════════════════════════════════════════════════════════════════════════════════╣
║                                                                                ║
║  1. 连续正规化流 (CNF):                                                        ║
║     dx/dt = v_θ(x, t),  t ∈ [0, 1]                                           ║
║                                                                                ║
║  2. 连续性方程 (概率守恒):                                                     ║
║     ∂p_t(x)/∂t + ∇·(p_t(x) · v_θ(x, t)) = 0                                 ║
║                                                                                ║
║  3. 线性插值路径:                                                              ║
║     x_t = (1-t)·x₀ + t·x₁     (x₀ ~ N(0,I), x₁ ~ p_data)                   ║
║                                                                                ║
║  4. 条件速度场:                                                                ║
║     u_t(x_t | x₁) = dx_t/dt = x₁ - x₀                                       ║
║                                                                                ║
║  5. 条件 Flow Matching 损失:                                                   ║
║     L_CFM(θ) = E_{t,x₀,x₁} [ ‖v_θ(x_t, t) - (x₁ - x₀)‖² ]                ║
║                                                                                ║
║  6. 边缘速度场 = 条件期望:                                                    ║
║     v*_t(x) = E_{p(x₁|x_t=x)} [ u_t(x | x₁) ]                               ║
║                                                                                ║
║  7. 损失等价性:                                                                ║
║     L_CFM(θ) = L_FM(θ) + const  (优化条件损失 ≡ 优化边缘损失)                ║
║                                                                                ║
║  8. 最优传输 (OT) 路径:                                                        ║
║     x_t = (1-t)·x₀ + t·T*(x₀)   (T* = OT 映射, 路径为直线)                  ║
║                                                                                ║
║  9. Rectified Flow 收敛:                                                       ║
║     重复 reflow → 路径直度单调递减 → 收敛到 OT 直线                           ║
║                                                                                ║
║  10. 采样:                                                                     ║
║      x₁ = x₀ + ∫₀¹ v_θ(x_t, t) dt  (ODE 积分, 路径越直步数越少)            ║
║                                                                                ║
╚══════════════════════════════════════════════════════════════════════════════════╝

参考文献

Mamba

  1. Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022.
  2. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces.
  3. Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. ICML 2024.
  4. Gu, A., Dao, T., et al. (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020.

Flow Matching

  1. Lipman, Y., Chen, R.T., Ben-Hamu, H., Nickel, M., & Le, M. (2023). Flow Matching for Generative Modeling. ICLR 2023.
  2. Liu, X., Gong, C., & Liu, Q. (2023). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR 2023.
  3. Esser, P., Kulal, S., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. ICML 2024.
  4. Chen, R.T., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural Ordinary Differential Equations. NeurIPS 2018.
  5. Song, Y., Sohl-Dickstein, J., et al. (2023). Consistency Models. ICML 2023.

GRPO

  1. Shao, Z., Wang, P., et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.
  2. DeepSeek-AI. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.
  3. Schulman, J., Wolski, F., et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.
  4. Schulman, J., Moritz, P., et al. (2016). High-Dimensional Continuous Control Using Generalized Advantage Estimation. ICLR 2016.
  5. Sutton, R.S., McAllester, D., et al. (1999). Policy Gradient Methods for Reinforcement Learning with Function Approximation. NeurIPS 1999.
  6. Ouyang, L., Wu, J., et al. (2022). Training Language Models to Follow Instructions with Human Feedback. NeurIPS 2022.
  7. Guo, D., Yang, D., et al. (2025). DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning.
相关推荐
小欣加油1 小时前
leetcode1926 迷宫中离入口最近的出口
数据结构·c++·算法·leetcode·职场和发展
happymaker06264 小时前
LeetCodeHot100——42.接雨水
算法
阿正的梦工坊4 小时前
【Rust】07-错误处理:Option、Result 与 ? 运算符
开发语言·算法·rust
八解毒剂6 小时前
数据结构-平衡二叉树——对二叉搜索树的优化
数据结构·c++·算法
运行时记录6 小时前
别再手动写提示词了 — SkillOpt 让技能文档自己进化
算法
啦啦啦啦啦zzzz6 小时前
算法总结(二分查找、双指针)
c++·算法
qq_8573058197 小时前
python语法
开发语言·python·算法
DXM05217 小时前
第9期|从机器学习到深度学习:AI遥感解译的进化逻辑
人工智能·算法·计算机视觉
小蒋学算法7 小时前
算法-阶乘函数后K个零
算法