论文提出了一种新的生成模型。论文的目的是给定一个目标分布,有目标分布的一定量的样本,但是不知道目标分布的概率密度函数,学习一个模型能生成服从目标分布的新样本。
Flow Matching (FM)是一种训练连续标准化流Continuous Normalizing Flow (CNF)的方法。
FM是一种通用的方法。FM可以用于训练扩散路径,用FM训练扩散路径更稳定。FM也可以用于训练其他路径,一个例子是训练最优传输(OT)位移插值定义的条件概率路径,这些路径比扩散路径更有效,提供更快的训练和采样,从而获得更好的泛化效果。
核心的思想是把无条件估计问题的转换为有条件的问题的来学习。作者说是从denoised score matching得到的启发:
We first show that we can construct such target vector fields through per-example (i.e., conditional) formulations. Then, inspired by denoising score matching, we show that a per-example training objective, termed Conditional Flow Matching (CFM), provides equivalent gradients and does not require explicit knowledge of the intractable target vector field.
连续标准化流
数据点 x ∈ R d \pmb x \in \mathbb R^d x∈Rd,时变概率密度路径 p : [ 0 , 1 ] × R d → R > 0 p:[0,1] \times \mathbb R^d \rightarrow \mathbb R_{>0} p:[0,1]×Rd→R>0,时变向量场 v t : [ 0 , 1 ] × R d → R d v_t:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d vt:[0,1]×Rd→Rd。
流flow把一个分布映射成另一个分布,可以通过常微分方程用 v t v_t vt构建flow ϕ : [ 0 , 1 ] × R d → R d \phi:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d ϕ:[0,1]×Rd→Rd:
d ϕ t ( x ) d t = v t ( ϕ t ( x ) ) ϕ 0 ( x ) = x (1) \frac{d\phi_t(\pmb x)}{dt}=v_t(\phi_t(\pmb x)) \tag{1} \\ \phi_0(\pmb x)=\pmb x dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x(1)时变向量场可以用神经网络 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)来建模,这样构建的flow ϕ t \phi_t ϕt叫做连续标准化流(Continuous Normalizing Flow,CNF)。CNF通常用于把一个简单的分布 p 0 p_0 p0变成一个复杂的分布 p 1 p_1 p1,其符合push-forward方程:
p t ( x ) = [ ϕ t ] ⋆ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det [ ∂ ϕ t − 1 ∂ x ( x ) ] p_t(x)=[\phi_t]_\star p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial \phi_t^{-1}}{\partial x}(x)] pt(x)=[ϕt]⋆p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1(x)]我们的目标是采样服从复杂目标分布的样本,方法是首先随机采样服从简单分布的噪声样本 x ∼ N ( 0 , I ) \pmb x \sim \mathcal N (\pmb 0, \pmb I) x∼N(0,I),然后使用ODE求解器在区间 t ∈ [ 0 , 1 ] t \in [0, 1] t∈[0,1]上使用训练得到的向量场 v t v_t vt求解方程(1)得到服从目标分布的样本 ϕ 1 ( x ) \phi_1(\pmb x) ϕ1(x)。所以主要的问题是如何学习 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)。
Flow Matching(FM)
用 x 1 \pmb x_1 x1表示服从未知的目标分布 q ( x 1 ) q(\pmb x_1) q(x1)的随机变量,我们不知道 q ( x 1 ) q(\pmb x_1) q(x1)的密度函数,但可以获得服从 q ( x 1 ) q(\pmb x_1) q(x1)的样本。用 p t p_t pt表示概率密度路径, p 0 p_0 p0服从标准高斯分布, p 1 p_1 p1近似 q q q。
Flow Matching的训练目标是学习 v t v_t vt,损失函数是 L F M ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(\pmb x)}\|v_t(\pmb x; \theta)-u_t(\pmb x)\|^2 LFM(θ)=Et,pt(x)∥vt(x;θ)−ut(x)∥2流匹配的损失函数很简单,但在实践中没法使用,因为我们不知道如何定义合适的 p t p_t pt和 u t u_t ut。
Conditional Flow Matching(CFM)
为了解决上面的问题,考虑条件流匹配。条件流匹配的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p_t(\pmb x|\pmb x_1)}\|v_t(\pmb x; \theta)-u_t(\pmb x|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x;θ)−ut(x∣x1)∥2与流匹配的目标不同,条件流匹配的目标允许我们轻松地对无偏估计进行采样,只要我们可以从 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(x∣x1) 有效地采样并计算 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(x∣x1),这两者都可以很容易地完成,因为它们是对每个样本定义的。
论文中证明了优化CFM目标等价于优化FM目标(从期望的角度)。所以,剩下的问题是如何设计合适的条件概率路径 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(x∣x1)和向量场 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(x∣x1)。
条件概率路径和条件向量场
上面的讨论是通用的,并没有规定条件概率路径和条件向量场的形式。为了简单,作者讨论的是高斯条件概率路径:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(\pmb x|\pmb x_1)=\mathcal N(\pmb x| \mu_t(\pmb x_1), \sigma_t(\pmb x_1)^2\pmb I) pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)其中 μ 0 ( x 1 ) = 0 \mu_0(\pmb x_1)=0 μ0(x1)=0, σ 0 ( x 1 ) = 1 \sigma_0(\pmb x_1)=1 σ0(x1)=1, μ 1 ( x 1 ) = x 1 \mu_1(\pmb x_1)=\pmb x_1 μ1(x1)=x1, σ 1 ( x 1 ) = σ min \sigma_1(\pmb x_1)=\sigma_{\min} σ1(x1)=σmin。
有无数的向量场可以产生给定的概率路径,这里作者讨论的是最简单的典型变换。
考虑条件flow:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(\pmb x)= \sigma_t(\pmb x_1)\pmb x + \mu_t(\pmb x_1) ψt(x)=σt(x1)x+μt(x1)对应的条件向量场可以通过求解方程得到,并有封闭解:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) u_t(\pmb x|\pmb x_1)=\frac{\sigma't(\pmb x_1)}{\sigma_t(\pmb x_1)}(x-\mu_t(\pmb x_1))+\mu't(\pmb x_1) ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1)优化的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ; θ ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal L{CFM}(\theta)=\mathbb E{t,q(\pmb x_1),p(\pmb x_0)}\|v_t(\psi_t(\pmb x_0); \theta)-u_t(\psi_t(\pmb x_0)|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0);θ)−ut(ψt(x0)∣x1)∥2