Rectified Flow 原理简介与示例代码解读

Rectified Flow 原理简介与示例代码解读

Rectified Flow 是最近非常火热的图像生成模型框架,最新的 SD3、Flux 等模型都是基于该框架。本文对其原理进行简单直观地介绍,并通过分析官方示例代码来加深理解。

论文原作者的中文解读:[ICLR2023] 扩散生成模型新方法:极度简化,一步生成。写得非常简单易懂,推荐阅读。

Flow

图像生成模型中,我们一般是训练模型将一个尽可能简单的先验分布(比如高斯分布)转换为数据分布。Rectified Flow 则考虑了一个更一般的传输映射问题,将任意分布转换为任意分布。对于来自不同分布的样本 X 0 ∼ π 0 , X 1 ∼ π 1 X_0\sim\pi_0,X_1\sim\pi_1 X0∼π0,X1∼π1,通过一个转换函数 T T T,来实现分布的转换 X 1 = T ( X 0 ) X_1=T(X_0) X1=T(X0)。这样就不仅可以实现将先验分布转换为图像数据分布这样的生成式建模,还可以实现将人脸图像分布转换为猫脸图像分布这样的域(风格)迁移,或者文本语义分布转换为图像语义分布这样的语义跨模态转换等等。

具体来说,Rectified Flow 模型通过一个时间 t ∈ [ 0 , 1 ] t\in[0,1] t∈[0,1] 上的 ODE
d X t = v ( X t , t ) d t (1) dX_t=v(X_t,t)dt \tag{1} dXt=v(Xt,t)dt(1)

来将 X 0 ∼ π 0 X_0\sim\pi_0 X0∼π0 转换到 X 1 ∼ π 1 X_1\sim\pi_1 X1∼π1。其中 v : R d → R d v:\mathbb{R}^d\rightarrow\mathbb{R}^d v:Rd→Rd 是一个向量场(可以理解为一个速度场)。可以形象地理解为 X 0 X_0 X0 是 π 0 \pi_0 π0 中的一个粒子,从 t = 0 t=0 t=0 时刻开始运动,在 t t t 时刻的速度为 v ( X t , t ) v(X_t,t) v(Xt,t),在 t = 1 t=1 t=1 时刻得到 X 1 X_1 X1,我们期望他能服从目标分布 X 1 ∼ π 1 X_1\sim\pi_1 X1∼π1。这里我们生成式建模的目标就是用一个神经网络估计过程粒子的速度 v ( X t , t ) v(X_t,t) v(Xt,t),从而保证结束时有 X 1 ∼ π 1 X_1\sim\pi_1 X1∼π1。

现在首先的问题就是,我们应该选择一个什么样的 ODE,什么形式的 ODE 更好?考虑到我们在训练好这个速度场之后,要通过对 ODE 的求解,来实现学习到的分布转换,以最常见的 Euler 法为例,即:
X t + ϵ = X t + ϵ v ( X t , t ) (2) X_{t+\epsilon}=X_t+\epsilon v(X_t,t) \notag \tag{2} Xt+ϵ=Xt+ϵv(Xt,t)(2)

其中 ϵ \epsilon ϵ 是离散化步长。 ϵ \epsilon ϵ 越小,近似越精确,但同时会导致采样步数增加,采样速度变慢,反之采样速度快但是近似误差大。因此,我们一般需要调整 ϵ \epsilon ϵ 来在采样速度和近似精度之间进行权衡。那么,什么样的 ODE 形式,或者说什么样的粒子轨迹,能够让我们更好地兼得这两点呢?答案其实很直觉:尽可能走直线 。如果轨迹特别弯曲,那我们就必须使用尽量多的步数进行尽可能高的近似精度。而如果轨迹足够"直",粒子初次移动的方向就是指向终点的,那么我们甚至可以一步实现的高精度近似。

为了得到 "直线" 的轨迹,作者考虑将这个粒子的移动轨迹直接建模为起终点的插值:
X t = t X 1 + ( 1 − t ) X 0 , t ∈ [ 0 , 1 ] (3) X_t=tX_1+(1-t)X_0,\quad t\in[0,1] \tag{3} Xt=tX1+(1−t)X0,t∈[0,1](3)

注意这里 X 0 X_0 X0 和 X 1 X_1 X1 是随机配对的,这里二者的配对关系是本文提出的 Reflow 方法要改进的点,后面我们详细介绍。现在我们对 X t X_t Xt 进行求导,就得到了对应的 ODE:
d d t X t = X 1 − X 0 (4) \frac{d}{dt}X_t=X_1-X_0 \tag{4} \\ dtdXt=X1−X0(4)

这看起来很合理,要将 X 0 X_0 X0 移动到 X 1 X_1 X1,沿着 X 1 − X 0 X_1-X_0 X1−X0 的方向走就好了。但问题是, X 1 X_1 X1 我们是未知的(实际上 X 1 X_1 X1 就是我们生成的最终目标)。遇到了未知的估计问题,就该神经网络登场了。我们引入一个神经网络,来估计这个速度场 v v v。训练目标就是最简单的 MSE 回归损失。
min ⁡ v ∫ 0 1 E X 0 ∼ π 0 , X 1 ∼ π 1 [ ∣ ∣ ( X 1 − X 0 ) − v ( X t , t ) ∣ ∣ 2 ] d t (5) \min_v\int_0^1\mathbb{E}_{X_0\sim\pi_0,X_1\sim\pi_1}\left[||(X_1-X_0)-v(X_t,t)||^2\right]dt \tag{5} vmin∫01EX0∼π0,X1∼π1[∣∣(X1−X0)−v(Xt,t)∣∣2]dt(5)

到这里,我们就可以对 Rectified Flow 模型进行训练(式 4)和采样(式 2)了。同时期也有 Flow Matching 等工作提出了类似的从最优传输和插值的角度构建生成模型。而 Rectified Flow 这篇工作实际上还进一步提出了将采样轨迹进一步拉直的方法,称为 "Reflow"。

Reflow

对于两个复杂的分布,随机配对并插值得到的轨迹几乎会出现交叉的情况。我们通过前面介绍的 Rectified Flow 模型的训练,得到了一个因果的速度场估计模型,避免轨迹交叉。但是,此时的采样轨迹仍然可能是弯曲的。为了进一步拉直轨迹,作者提出了 Reflow 方法,使用已经训练好的模型给出的样本配对,来训练一个新的模型,称为 "2-Rectified Flow"。此时训练目标可写为:
min ⁡ v ∫ 0 1 E X 0 ∼ π 0 , X 1 ∼ Flow 1 ( X 0 ) [ ∣ ∣ ( X 1 − X 0 ) − v ( X t , t ) ∣ ∣ 2 ] d t (6) \min_v\int_0^1\mathbb{E}_{X_0\sim\pi_0,X_1\sim\text{Flow}_1(X_0)}\left[||(X_1-X_0)-v(X_t,t)||^2\right]dt \tag{6} vmin∫01EX0∼π0,X1∼Flow1(X0)[∣∣(X1−X0)−v(Xt,t)∣∣2]dt(6)

相比式 5 只有 X 1 X_1 X1 的采样方式变了,变成根据已经训练好的模型对 X 0 X_0 X0 进行分布转换的结果。

在训练 "1-Rectified Flow" 时,我们用的是来自两个分布中随机采样的样本 X 0 ∼ π 0 , X 1 ∼ π 1 X_0\sim\pi_0,X_1\sim\pi_1 X0∼π0,X1∼π1,它们之间的配对关系完全是随机的,这会导致上面提到的轨迹交叉和弯曲的问题。在训练完 "1-Rectified Flow" 之后,我们可以根据模型的预测,得到配对的样本 X 0 ∼ π 0 , X 1 ∼ Flow 1 ( X 0 ) X_0\sim\pi_0,X_1\sim\text{Flow}_1(X_0) X0∼π0,X1∼Flow1(X0),我们用这种配对的样本来训练 "2-Rectified Flow",就可以改善这个问题,得到更 "直" 的轨迹。

下图提供了一个直观的理解。图中紫色点和红色点分别是分布 π 0 , π 1 \pi_0,\pi_1 π0,π1,虚线和实线分别表示 "训练时给定的配对关系下的插值结果" 和 "训练后的采样结果"。图 (a) 中,我们随机给定两个分布中样本点的配对关系,可以看到存在交叉的情形。这是因为此时是非因果的,粒子站在交叉点不知道往哪边走;图 (b) 是我们训练得到ODE 的采样结果,他是因果的,因此不存在交叉的情况,但是由于训练样本中给定的配对是随机的,可以看到还是存在弯曲的轨迹;图 © 中,我们根据已经训练好的 "1-Rectified Flow" 模型,给出样本配对关系,避免了轨迹交叉和弯曲的情况;图 (d) 中,根据这种配对关系训练出的 "2-Rectified Flow" 模型,得到了 "直" 的采样轨迹。

以上就是 Reflow 方法的一次拉直的情况,如果你愿意,也可以继续做 "3-Rectified Flow"、"4-Rectified Flow" ...。作者证明了这样继续下去可以单调地减少传输代价,得到更 "直" 的模型。但是注意,由于训练不可能是完美的,根据 "1-Rectified Flow" 模型得到的 X 1 ∼ Flow 1 ( X 0 ) X_1\sim\text{Flow}_1(X_0) X1∼Flow1(X0) 未必能很好地服从 π 1 \pi_1 π1,因此每次 Reflow 都会积累误差。所以,这里会存在一个 "服从 π 1 \pi_1 π1 的积累误差" 与 "更优的配对关系" 之间存在一个权衡关系,由 Reflow 的次数调控。作者指出,做一次 Reflow 就能得到比较好的结果了。

Reflow 与 Distillation

将已经训练好模型的输入输出配对作为样本,来训练新的模型,这听起来非常像是一种蒸馏。论文作者在其中文解读中解释了 Reflow 和蒸馏的区别,笔者这里也简单谈下自己的理解,可能有错误的地方,以原作者的解释为准。

我们考虑两种蒸馏的形式,一种是直接单步蒸馏,另一种是 Consistency Models 中的一致性蒸馏(注意 CM 比 RF 要晚,所以原作者没有讨论 CM)。

  • 直接单步蒸馏是指只训练从 X 0 X_0 X0 一步到 X 1 X_1 X1 的过程,相当于将 Reflow 目标函数(式 6)中的对 t t t 从 0 到 1 的积分拿掉,只取 t = 0 t=0 t=0 的情况。这种情况作者有讨论其与 Reflow 的区别。此时由于监督信号只有起终点 X 0 , X 1 X_0,X_1 X0,X1,而中间 t ∈ ( 0 , 1 ] t\in(0,1] t∈(0,1] 的情况都没有监督,因此模型只能 "死板地" 学习起终点的直接映射关系。而在 Reflow 中,中间过程 X t X_t Xt 的边缘分布也需要学习(实际上,学习边缘分布是最重要的),因此,Reflow 过程中模型可以改善 X 0 , X 1 X_0,X_1 X0,X1 的配对关系。这一点是单步蒸馏无法做到的。
  • 一致性蒸馏是要求从教师模型轨迹上任意一点能够直接一步回到 X 0 X_0 X0,在这种情况下,学生模型还是在教师模型的轨迹上进行学习。而在 Reflow 中,教师模型(k-1 Rectified Flow)仅产生起终点的样本配对,中间的轨迹还是由式 3 插值出来的 "直线"。

当然了,作者也提到在 Reflow 过后的最后一步,样本配对已经足够好时,再结合蒸馏,能够进一步提高模型在低采样步数时的生成效果。

个人感觉 Reflow 也可以认为是一种蒸馏,只是形式不同。

示例代码

论文作者给出了示例代码 Colab 来帮助我们理解。该示例是一个二维情况下 Rectified Flow 训练和采样的过程,并通过可视化展示了 Reflow 前后,模型单步采样的结果分布和轨迹。我们简单过一下其中的关键部分。

首先,作者定义了两个简单的二维分布 π 0 , π 1 \pi_0,\pi_1 π0,π1,其可视化结果如下图所示。其中蓝色、橙色点分别表示 π 0 , π 1 \pi_0,\pi_1 π0,π1 的采样结果,可以看到,两个分布彼此交错,如果随机配对的话,很可能会出现轨迹弯曲或者交叉的情况。在实际的图像生成中,分布肯定比这复杂得多,出现弯曲或者交叉的情况几乎是必然的。

作者随后定义了一个简单的 MLP 网络,就是接收当前的 X t X_t Xt 和 t t t,返回预测的速度场 v = model(x_t, t),这里略过。

再接下来定义了 Rectified Flow 类和训练的过程,这里是重点。以下是 Rectified Flow 类的定义,我们先看采样的部分,即 sample_ode 方法。该方法接收一个起始样本点 z0 和采样步数 N,通过最常见的欧拉法来进行采样生成。在每一步中,dt 相当于是步长 ϵ \epsilon ϵ,根据模型估计的速度场 pred,根据式 2 进行采样生成 z = z + pred * dt

python 复制代码
class RectifiedFlow():
  def __init__(self, model=None, num_steps=1000):
    self.model = model
    self.N = num_steps
  
  def get_train_tuple(self, z0=None, z1=None):
    t = torch.rand((z1.shape[0], 1))
    z_t =  t * z1 + (1.-t) * z0
    target = z1 - z0 
        
    return z_t, t, target

  @torch.no_grad()
  def sample_ode(self, z0=None, N=None):
    ### NOTE: Use Euler method to sample from the learned flow
    if N is None:
      N = self.N    
    dt = 1./N
    traj = [] # to store the trajectory
    z = z0.detach().clone()
    batchsize = z.shape[0]
    
    traj.append(z.detach().clone())
    for i in range(N):
      t = torch.ones((batchsize,1)) * i / N
      pred = self.model(z, t)
      z = z.detach().clone() + pred * dt
      
      traj.append(z.detach().clone())

    return traj

然后看训练部分,这里用到了上面 Rectified Flow 类的 get_train_tuple 函数,该函数是均匀采样一个时间 t,然后根据式 3、4 分别计算出 z_ttarget。在训练循环中,就是将 z_tt 传入模型计算出 pred,最小化其与 target 的 MSE(式 5)。

python 复制代码
from tqdm import tqdm
def train_rectified_flow(rectified_flow, optimizer, pairs, batchsize, inner_iters):
  loss_curve = []
  for i in tqdm(range(inner_iters+1)):
    optimizer.zero_grad()
    indices = torch.randperm(len(pairs))[:batchsize]
    batch = pairs[indices]
    z0 = batch[:, 0].detach().clone()
    z1 = batch[:, 1].detach().clone()
    z_t, t, target = rectified_flow.get_train_tuple(z0=z0, z1=z1)

    pred = rectified_flow.model(z_t, t)
    loss = (target - pred).view(pred.shape[0], -1).abs().pow(2).sum(dim=1)
    loss = loss.mean()
    loss.backward()
    
    optimizer.step()
    loss_curve.append(np.log(loss.item())) ## to store the loss curve

  return rectified_flow, loss_curve

接下来定义了可视化函数,用于可视化采样结果的分布与轨迹,代码细节略过。

然后就开始进行训练了。一开始,我们只能对 x_0x_1 进行随机配对,就是下面用 torch.randperm 实现的。在定义好超参数后开始训练,代码中会展示训练的损失曲线,这里图就不贴了。

python 复制代码
x_0 = samples_0.detach().clone()[torch.randperm(len(samples_0))]
x_1 = samples_1.detach().clone()[torch.randperm(len(samples_1))]
x_pairs = torch.stack([x_0, x_1], dim=1)

iterations = 10000
batchsize = 2048
input_dim = 2

rectified_flow_1 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
optimizer = torch.optim.Adam(rectified_flow_1.model.parameters(), lr=5e-3)

rectified_flow_1, loss_curve = train_rectified_flow(rectified_flow_1, optimizer, x_pairs, batchsize, iterations)
plt.plot(np.linspace(0, iterations, iterations+1), loss_curve[:(iterations+1)])
plt.title('Training Loss Curve')

训练结束后,我们进行采样,并可视化其结果分布和轨迹。首先我们将采样步数设置为 100,结果如下。左图中绿色点是采样结果的分布,可以看到,效果还是不错的,采样结果基本上都分布在目标 π 1 \pi_1 π1 附近。但是,看右图中的采样轨迹,可以看到一部分是比较理想的直线轨迹,连接相邻的 π 0 , π 1 \pi_0,\pi_1 π0,π1,但是也有很多 ">" 状的弯曲轨迹,先朝向中心,在拐到目标分布附近。这就是前面说的 "1-Rectified Flow" 模型存在的弯曲的情况。在这种情况下,如果采样步数较少,是很难达到 X 1 ∼ π 1 X_1\sim\pi_1 X1∼π1 的。

python 复制代码
draw_plot(rectified_flow_1, z0=initial_model.sample([2000]), z1=samples_1.detach().clone(), N=100)

接下来我们看当采样步数只有 1 步时的可视化结果。可以看到,正如我们前面分析的那样,在少采样步数的情况下,由于轨迹不够 "直",没法直接指向目标分布,采样结果就非常不准了。

python 复制代码
draw_plot(rectified_flow_1, z0=initial_model.sample([2000]), z1=samples_1.detach().clone(), N=1)

接下来,我们进行 Reflow 训练。利用已经训练好的 "1-Rectified Flow" 模型,构造样本配对 z10z11。根据配对好的样本进行训练,其他的都不用变,完全复用相同的训练过程即可。训练结束,得到 "2-Rectified Flow" 模型。

python 复制代码
z10 = samples_0.detach().clone()
traj = rectified_flow_1.sample_ode(z0=z10.detach().clone(), N=100)
z11 = traj[-1].detach().clone()
z_pairs = torch.stack([z10, z11], dim=1)

reflow_iterations = 50000

rectified_flow_2 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
import copy 
rectified_flow_2.net = copy.deepcopy(rectified_flow_1) # we fine-tune the model from 1-Rectified Flow for faster training.
optimizer = torch.optim.Adam(rectified_flow_2.model.parameters(), lr=5e-3)

rectified_flow_2, loss_curve = train_rectified_flow(rectified_flow_2, optimizer, z_pairs, batchsize, reflow_iterations)
plt.plot(np.linspace(0, reflow_iterations, reflow_iterations+1), loss_curve[:(reflow_iterations+1)])

然后我们看 "2-Rectified Flow" 模型在 100 步和 1 步下的采样可视化结果。这次可以看到,采样的轨迹非常 "直",直接走向目标分布。这样,即使在 1 步时,也能比较准的走到目标分布上。

python 复制代码
draw_plot(rectified_flow_2, z0=initial_model.sample([1000]), z1=samples_1.detach().clone(), N=1)
draw_plot(rectified_flow_2, z0=initial_model.sample([1000]), z1=samples_1.detach().clone(), N=1)

总结

Rectified Flow 直接在两分布间插值构建了一种新的扩散模型形式,在简化形式的同时期望实现 "走直线" 的采样轨迹,从而减少采样步数,加快生图速度,这种形式已经在 SD3、Flux 等最新的图像生成中得到了验证和应用。本文还进一步提出了 Reflow 方法,通过训练好的模型的预测结果来为新模型的训练构造样本配对,从而能够进一步拉直采样轨迹,提高低步数下的采样质量。并为理解扩散模型蒸馏提供了一种新的角度。

相关推荐
MinIO官方账号13 分钟前
使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器实现可迭代式数据集
人工智能·pytorch·python
四口鲸鱼爱吃盐15 分钟前
Pytorch | 利用IE-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉
四口鲸鱼爱吃盐17 分钟前
Pytorch | 利用EMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
章章小鱼44 分钟前
LLM预训练recipe — 摘要版
人工智能
红色的山茶花1 小时前
YOLOv9-0.1部分代码阅读笔记-loss_tal_dual.py
笔记·深度学习·yolo
算家云1 小时前
Stability AI 新一代AI绘画模型:StableCascade 本地部署教程
人工智能·ai作画·stable diffusion·模型构建·算家云·算力租赁·stablecascade
RacheV+TNY2642781 小时前
深度解析:电商平台API接口的安全挑战与应对策略
人工智能·python·自动化·api
学术会议1 小时前
“智能控制的新纪元:2025年机器学习与控制工程国际会议引领变革
大数据·人工智能·科技·计算机网络·机器学习·区块链
呆头鹅AI工作室2 小时前
基于特征工程(pca分析)、小波去噪以及数据增强,同时采用基于注意力机制的BiLSTM、随机森林、ARIMA模型进行序列数据预测
人工智能·深度学习·神经网络·算法·随机森林·回归
huhuhu15322 小时前
第P4周:猴痘病识别
图像处理·python·深度学习·cnn