FLOW MATCHING FOR GENERATIVE MODELING 阅读笔记

论文提出了一种新的生成模型。论文的目的是给定一个目标分布,有目标分布的一定量的样本,但是不知道目标分布的概率密度函数,学习一个模型能生成服从目标分布的新样本。

Flow Matching (FM)是一种训练连续标准化流Continuous Normalizing Flow (CNF)的方法。

FM是一种通用的方法。FM可以用于训练扩散路径,用FM训练扩散路径更稳定。FM也可以用于训练其他路径,一个例子是训练最优传输(OT)位移插值定义的条件概率路径,这些路径比扩散路径更有效,提供更快的训练和采样,从而获得更好的泛化效果。

核心的思想是把无条件估计问题的转换为有条件的问题的来学习。作者说是从denoised score matching得到的启发:

We first show that we can construct such target vector fields through per-example (i.e., conditional) formulations. Then, inspired by denoising score matching, we show that a per-example training objective, termed Conditional Flow Matching (CFM), provides equivalent gradients and does not require explicit knowledge of the intractable target vector field.

连续标准化流

数据点 x ∈ R d \pmb x \in \mathbb R^d x∈Rd,时变概率密度路径 p : [ 0 , 1 ] × R d → R > 0 p:[0,1] \times \mathbb R^d \rightarrow \mathbb R_{>0} p:[0,1]×Rd→R>0,时变向量场 v t : [ 0 , 1 ] × R d → R d v_t:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d vt:[0,1]×Rd→Rd。

流flow把一个分布映射成另一个分布,可以通过常微分方程用 v t v_t vt构建flow ϕ : [ 0 , 1 ] × R d → R d \phi:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d ϕ:[0,1]×Rd→Rd:
d ϕ t ( x ) d t = v t ( ϕ t ( x ) ) ϕ 0 ( x ) = x (1) \frac{d\phi_t(\pmb x)}{dt}=v_t(\phi_t(\pmb x)) \tag{1} \\ \phi_0(\pmb x)=\pmb x dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x(1)时变向量场可以用神经网络 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)来建模,这样构建的flow ϕ t \phi_t ϕt叫做连续标准化流(Continuous Normalizing Flow,CNF)。CNF通常用于把一个简单的分布 p 0 p_0 p0变成一个复杂的分布 p 1 p_1 p1,其符合push-forward方程:
p t ( x ) = [ ϕ t ] ⋆ p 0 ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ⁡ [ ∂ ϕ t − 1 ∂ x ( x ) ] p_t(x)=[\phi_t]_\star p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial \phi_t^{-1}}{\partial x}(x)] pt(x)=[ϕt]⋆p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1(x)]我们的目标是采样服从复杂目标分布的样本,方法是首先随机采样服从简单分布的噪声样本 x ∼ N ( 0 , I ) \pmb x \sim \mathcal N (\pmb 0, \pmb I) x∼N(0,I),然后使用ODE求解器在区间 t ∈ [ 0 , 1 ] t \in [0, 1] t∈[0,1]上使用训练得到的向量场 v t v_t vt求解方程(1)得到服从目标分布的样本 ϕ 1 ( x ) \phi_1(\pmb x) ϕ1(x)。所以主要的问题是如何学习 v t ( x ; θ ) v_t(\pmb x; \theta) vt(x;θ)。

Flow Matching(FM)

用 x 1 \pmb x_1 x1表示服从未知的目标分布 q ( x 1 ) q(\pmb x_1) q(x1)的随机变量,我们不知道 q ( x 1 ) q(\pmb x_1) q(x1)的密度函数,但可以获得服从 q ( x 1 ) q(\pmb x_1) q(x1)的样本。用 p t p_t pt表示概率密度路径, p 0 p_0 p0服从标准高斯分布, p 1 p_1 p1近似 q q q。

Flow Matching的训练目标是学习 v t v_t vt,损失函数是 L F M ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(\pmb x)}\|v_t(\pmb x; \theta)-u_t(\pmb x)\|^2 LFM(θ)=Et,pt(x)∥vt(x;θ)−ut(x)∥2流匹配的损失函数很简单,但在实践中没法使用,因为我们不知道如何定义合适的 p t p_t pt和 u t u_t ut。

Conditional Flow Matching(CFM)

为了解决上面的问题,考虑条件流匹配。条件流匹配的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p_t(\pmb x|\pmb x_1)}\|v_t(\pmb x; \theta)-u_t(\pmb x|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x;θ)−ut(x∣x1)∥2与流匹配的目标不同,条件流匹配的目标允许我们轻松地对无偏估计进行采样,只要我们可以从 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(x∣x1) 有效地采样并计算 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(x∣x1),这两者都可以很容易地完成,因为它们是对每个样本定义的。

论文中证明了优化CFM目标等价于优化FM目标(从期望的角度)。所以,剩下的问题是如何设计合适的条件概率路径 p t ( x ∣ x 1 ) p_t(\pmb x|\pmb x_1) pt(x∣x1)和向量场 u t ( x ∣ x 1 ) u_t(\pmb x|\pmb x_1) ut(x∣x1)。

条件概率路径和条件向量场

上面的讨论是通用的,并没有规定条件概率路径和条件向量场的形式。为了简单,作者讨论的是高斯条件概率路径:
p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(\pmb x|\pmb x_1)=\mathcal N(\pmb x| \mu_t(\pmb x_1), \sigma_t(\pmb x_1)^2\pmb I) pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)其中 μ 0 ( x 1 ) = 0 \mu_0(\pmb x_1)=0 μ0(x1)=0, σ 0 ( x 1 ) = 1 \sigma_0(\pmb x_1)=1 σ0(x1)=1, μ 1 ( x 1 ) = x 1 \mu_1(\pmb x_1)=\pmb x_1 μ1(x1)=x1, σ 1 ( x 1 ) = σ min ⁡ \sigma_1(\pmb x_1)=\sigma_{\min} σ1(x1)=σmin。

有无数的向量场可以产生给定的概率路径,这里作者讨论的是最简单的典型变换。

考虑条件flow:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(\pmb x)= \sigma_t(\pmb x_1)\pmb x + \mu_t(\pmb x_1) ψt(x)=σt(x1)x+μt(x1)对应的条件向量场可以通过求解方程得到,并有封闭解:
u t ( x ∣ x 1 ) = σ t ′ ( x 1 ) σ t ( x 1 ) ( x − μ t ( x 1 ) ) + μ t ′ ( x 1 ) u_t(\pmb x|\pmb x_1)=\frac{\sigma't(\pmb x_1)}{\sigma_t(\pmb x_1)}(x-\mu_t(\pmb x_1))+\mu't(\pmb x_1) ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1)优化的损失函数是 L C F M ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ; θ ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal L{CFM}(\theta)=\mathbb E{t,q(\pmb x_1),p(\pmb x_0)}\|v_t(\psi_t(\pmb x_0); \theta)-u_t(\psi_t(\pmb x_0)|\pmb x_1)\|^2 LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0);θ)−ut(ψt(x0)∣x1)∥2

相关推荐
Pandaconda1 小时前
【C++ 面试 - 新特性】每日 3 题(六)
开发语言·c++·经验分享·笔记·后端·面试·职场和发展
手打猪大屁1 小时前
STM32——串口通信(发送/接收数据与中断函数应用)
经验分享·笔记·stm32·单片机·嵌入式硬件
阿拉伯的劳伦斯2921 小时前
LeetCode第一题(梦开始的地方)
数据结构·算法·leetcode
Mr_Xuhhh1 小时前
C语言深度剖析--不定期更新的第六弹
c语言·开发语言·数据结构·算法
贾saisai2 小时前
Xilinx系FPGA学习笔记(四)VIO、ISSP(Altera)及串口学习
笔记·学习·fpga开发
月夕花晨3742 小时前
C++学习笔记(13)
c++·笔记·学习
吵闹的人群保持笑容多冷静2 小时前
2024CCPC网络预选赛 I. 找行李 【DP】
算法
probably1212 小时前
学习记录之Java学习笔记3
java·笔记·学习
桃酥4032 小时前
算法day22|组合总和 (含剪枝)、40.组合总和II、131.分割回文串
数据结构·c++·算法·leetcode·剪枝
山脚ice2 小时前
【Hot100】LeetCode—55. 跳跃游戏
算法·leetcode