diffusion原理和代码延伸笔记1------扩散桥,GOUB,UniDB
引言
扩散模型包含前向过程和反向过程,前向过程把数据分布映射成高斯分布,反向过程想复原之,不过这种映射是一整个分布和一整个分布之间的,学习的是如何从噪声中创造出新的数据。要是图像修复,去雨,超分,分子设计等需要点对点的任务,不怎么需要"创造"能力的任务,比如在药物设计中,我们可能需要生成一个起始构象和目标构象是固定的分子,在图像修复,去雨等问题上,可以理解为一对一的点任务。这一篇笔记介绍GOUB和更为广泛且纳入GOUB,VE,VP等为特殊情况的UniDB。
扩散桥
扩散桥,Diffusion Bridge,基于SDE,它要连接两个已知的端点,作为一个约束的限制。
与DDPM等基于马尔科夫链(离散化SDE)不同,扩散桥不再是关心 p ( x 0 ) p(\mathbf{x}_0) p(x0),而是关心在已知 X s = x s \mathbf{X}_s = \mathbf{x}_s Xs=xs 和 X T = x T \mathbf{X}_T = \mathbf{x}_T XT=xT 的情况下,中间状态 X t , s < t < T \mathbf{X}_t, s<t<T Xt,s<t<T 的条件概率分布 p ( x t ∣ x s , x T ) p(\mathbf{x}_t | \mathbf{x}_s, \mathbf{x}_T) p(xt∣xs,xT)。
从动态视角看, p ( x t ∣ x 0 , x T ) p(\mathbf{x}_t | \mathbf{x}_0, \mathbf{x}_T) p(xt∣x0,xT)其实是一个随机过程,形象点说,描述了一个粒子从 t = 0 t=0 t=0的状态开始随机游走,最终被拉到 x T \mathbf{x_T} xT这个终点。每一次实验下,路径都是随机的,这个粒子的位置也是不确定的。如果剖析一个时间点的话,就可以发现对于一个固定的 x t \mathbf{x_t} xt,其概率分布是完全确定的,后文可以发现确定的是 p ( x t ∣ x 0 , x T ) ∼ N ( m e a n , v a r i a n c e ) p(\mathbf{x}_t | \mathbf{x}_0, \mathbf{x}_T) \sim \mathcal{N} (mean, variance) p(xt∣x0,xT)∼N(mean,variance),这里先不给出平均值和方差的具体形式。
好,为了满足这个双端点约束,SDE方程需要做一些修改。
一个标准的SDE如下:
d X t = f ( X t , t ) d t + g t d W t , x 0 ∼ p ( x 0 ) d\mathbf{X}_t = \mathbf{f}(\mathbf{X}_t, t) dt + g_t d\mathbf{W}_t, \quad \mathbf{x_0} \sim p(\mathbf{x_0}) dXt=f(Xt,t)dt+gtdWt,x0∼p(x0)
其中 f ( X t , t ) \mathbf{f}(\mathbf{X}_t, t) f(Xt,t) 是漂移项, g ( t ) d W t g(t) d\mathbf{W}_t g(t)dWt 是扩散项。
接下来记录一种扩散桥的实现方式。
Doob's h-transform
接下来是Generalized Ornstein-Uhlenbeck,一个基于OU方程扩展的sde,这是一个 t t t趋于无穷会保持结果平稳的高斯-马尔可夫过程(线性上有wiener过程,线性组合是高斯分布,马尔可夫性质),任意时刻 t t t的边际概率分布随时间 t t t的增加会逐渐趋近于一个稳定的均值和方差:
d X t = θ t ( μ − X t ) d t + g t d W t \begin{equation} d\mathbf{X}_t = \theta_t(\boldsymbol{\mu} - \mathbf{X}_t) dt + g_t d\mathbf{W}_t \end{equation} dXt=θt(μ−Xt)dt+gtdWt
其中 μ \boldsymbol{\mu} μ是给定的状态向量, θ t \theta_t θt表示标量漂移系数, g t g_t gt表示扩散系数。假定 θ t \theta_t θt、 g t g_t gt满足指定关系 2 λ 2 = g t 2 / θ t 2\lambda^2 = g_t^2 / \theta_t 2λ2=gt2/θt,其中 λ 2 \lambda^2 λ2是给定的常量标量。因此,其转移概率具有一个封闭形式的解析解:
p ( x t ∣ x s ) = N ( m ˉ s : t , σ ˉ s : t 2 I ) = N ( μ + ( x s − μ ) e − θ ˉ s : t , g t 2 2 θ t ( 1 − e − 2 θ ˉ s : t ) I ) , θ ˉ s : t = ∫ s t θ z d z \begin{align} p (x_t | x_s) &= \mathcal{N}(\bar{m}{s:t}, \bar{\sigma}^2{s:t}I) \\ &= \mathcal{N}\left(\mu+ (x_s -\mu) e^{-\bar{\theta}{s:t}} , \frac{g_t^2}{2\theta_t} \left(1 - e^{-2\bar{\theta}{s:t}}\right) I \right), \quad \\ \bar{\theta}_{s:t} &= \int_s^t \theta_zdz \end{align} p(xt∣xs)θˉs:t=N(mˉs:t,σˉs:t2I)=N(μ+(xs−μ)e−θˉs:t,2θtgt2(1−e−2θˉs:t)I),=∫stθzdz
随着时间 t t t的推移,整个 X t \mathbf{X_t} Xt会收敛于 N ( μ , λ 2 ) \mathcal{N} (\boldsymbol{\mu}, \lambda^2) N(μ,λ2).
这个是怎么推导的?Yue等人在Appendix C给出了推导,但没有对寻找的辅助函数进行说明。这里给一个证明。
一般地对于随机过程若有:
d X t = ( A t X t + B t ) d t + ( C t X t + D t ) d W t dX_t = (A_t X_t + B_t)dt + (C_t X_t + D_t) dW_t dXt=(AtXt+Bt)dt+(CtXt+Dt)dWt
有一个本身符合Ito过程的 I t = μ I ( t ) d t + σ I ( t ) d W t I_t = \mu_I(t) dt + \sigma_I(t) dW_t It=μI(t)dt+σI(t)dWt,对于 d Y t = d X t I t = I t d W t + W t d I t + d < I , X > t dY_t = dX_tI_t = I_t dW_t + W_t dI_t + d<I,X>_t dYt=dXtIt=ItdWt+WtdIt+d<I,X>t, d < I , X > t d<I,X>_t d<I,X>t为一个二次协变差,其为 ( σ I ( t ) d W t ) ( C t X t d W t ) = σ I ( t ) C t X t d t (\sigma_I(t) dW_t )(C_t X_t dW_t) = \sigma_I(t) C_t X_t dt (σI(t)dWt)(CtXtdWt)=σI(t)CtXtdt,代入到 d X t I t dX_tI_t dXtIt之中,消除让积分变得困难的随机过程 X t X_t Xt,则可以获得对应的 I t I_t It。
将我们找到的 I ( t ) I(t) I(t)代回到 d Y t dY_t dYt的表达式中。这里你甚至可以不用直接赵 I ( t ) I(t) I(t),而是找到一个表达式: ( I ′ ( t ) − I ( t ) θ t ) x t = 0 (I'(t) - I(t)\theta_t)x_t = 0 (I′(t)−I(t)θt)xt=0,可以看到方程大大简化了:
d Y t = ( I ( t ) θ t μ ) d t + I ( t ) g t d w t dY_t = (I(t)\theta_t\mu) dt + I(t)g_t dw_t dYt=(I(t)θtμ)dt+I(t)gtdwt
现在 d Y t dY_t dYt的漂移项和扩散项都只依赖于时间 t t t,不依赖于随机过程本身。我们可以直接对两边从 s s s到 t t t进行积分:
∫ s t d Y z = ∫ s t I ( z ) θ z μ d z + ∫ s t I ( z ) g z d w z \int_s^t dY_z = \int_s^t I(z)\theta_z\mu dz + \int_s^t I(z)g_z dw_z ∫stdYz=∫stI(z)θzμdz+∫stI(z)gzdwz
左边等于 Y t − Y s Y_t - Y_s Yt−Ys。根据 Y t Y_t Yt的定义:
Y t = I ( t ) x t = e θ ˉ t x t Y_t = I(t)\mathbf{x}t = e^{\bar{\theta}{t}}\mathbf{x}_t Yt=I(t)xt=eθˉtxt
Y s = I ( s ) x s = e θ ˉ s x s = x s Y_s = I(s)\mathbf{x}s = e^{\bar{\theta}{s}}\mathbf{x}_s = \mathbf{x}_s Ys=I(s)xs=eθˉsxs=xs
所以,积分后的方程为:
e θ ˉ t x t − e θ ˉ s x s = μ ∫ s t e θ ˉ z θ z d z + ∫ s t e θ ˉ z g z d w z e^{\bar{\theta}{t}}\mathbf{x}t - e^{\bar{\theta}{s}} \mathbf{x}s = \boldsymbol{\mu} \int_s^t e^{\bar{\theta}{z}}\theta_z dz + \int_s^t e^{\bar{\theta}{z}}g_z d\mathbf{w}_z eθˉtxt−eθˉsxs=μ∫steθˉzθzdz+∫steθˉzgzdwz
这两个积分里,前者可以直接积分,后者需要使用Ito Isometry,简单来说就是其方差可以直接放在积分里面,推导需要条件期望公式和全方差公式。别忘记 d w z d\mathbf{w}z dwz本身是一个标准高斯分布,于是:
∫ s t e θ ˉ z g z d w z = N ( 0 , ∫ s t e 2 θ ˉ z g z 2 d z I ) = N ( 0 , λ 2 ∫ s t e 2 θ ˉ z 2 θ z d z I ) = N ( 0 , λ 2 ( e 2 θ ˉ t − e 2 θ ˉ s ) I ) \begin{align} \int_s^t e^{\bar{\theta}{z}}g_z d\mathbf{w}z &= \mathcal{N}(0, \int_s^t e^{2\bar{\theta}{z}} g_z^2 dz I) \\ &= \mathcal{N}(0, \lambda^2 \int_s^t e^{2\bar{\theta}{z}} 2 \theta_z dz I) \\ &= \mathcal{N}(0, \lambda^2 (e^{2\bar{\theta}{t}} - e^{2\bar{\theta}_{s}})I) \end{align} ∫steθˉzgzdwz=N(0,∫ste2θˉzgz2dzI)=N(0,λ2∫ste2θˉz2θzdzI)=N(0,λ2(e2θˉt−e2θˉs)I)
(5)到(6)是因为用了 2 λ 2 = g t 2 / θ t 2\lambda^2 = g_t^2 / \theta_t 2λ2=gt2/θt,最后放在一起就可以推出(2)-(4)了。
Doob's h-transform标准形式如下:
d X t = ( f ( X t , t ) + g t 2 h ( X t , t , X T , T ) d t + g t d W t , x 0 ∼ p ( x 0 ∣ x T ) d\mathbf{X}_t = (\mathbf{f}(\mathbf{X}_t, t) + g_t^2 \mathbf{h}(\mathbf{X}_t, t, \mathbf{X}_T, T) dt + g_t d\mathbf{W}_t, \quad \mathbf{x_0} \sim p(\mathbf{x_0}|\mathbf{x_T}) dXt=(f(Xt,t)+gt2h(Xt,t,XT,T)dt+gtdWt,x0∼p(x0∣xT)
Doob变换是一种随机过程里的数学技术。它通过将特定的 h函数纳入随机微分方程(SDE)的漂移项来变换原始过程,使该过程能够通过预定的终点。在漂移项额外加入 h ( X t , t , X T , T ) = ∇ x T log p ( x T ∣ x t ) \mathbf{h}(\mathbf{X}_t, t, \mathbf{X}T, T) = \nabla{\mathbf{x_T}} \log p(\mathbf{x_T}|\mathbf{x_t}) h(Xt,t,XT,T)=∇xTlogp(xT∣xt),当 t = T t=T t=T时, p ( x t ∣ x 0 , x T ) = 1 p(\mathbf{x}_t | \mathbf{x}_0, \mathbf{x}_T) = 1 p(xt∣x0,xT)=1
GOUB
Yue等人发现,GOU过程(1)具有均值回归特性,即如果我们将初始状态 x 0 x_0 x0视为高质量图像,将对应的低质量图像 x T = μ x_T = \mu xT=μ作为最终条件,那么高质量图像将逐渐收敛于一个以低质量图像为均值、方差稳定为 λ 2 \lambda^2 λ2的高斯分布。然而,逆向过程的初始状态需要人为地向低质量图像中添加噪声,这会导致一定的信息损失,从而影响性能。巧合的是,Doob's h-transform可以修改随机微分方程,使其在终端时间 T时通过指定的 x T x_T xT。因此,需要着重指出的是,将 h -变换应用于GOU过程能有效消除终端噪声的影响,直接在高质量图像和低质量图像之间建立点对点的关系。
前向和反向过程
利用Doob's h-transform和(2)-(4),基于GOU这个方程,得到前向过程。
前向过程如下:
d x t = ( θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ) ( x T − x t ) d t + g t d w t . \begin{equation} d\mathbf{x_t} = \left(\theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}{t:T}^2}\right) (\mathbf{x}_T - \mathbf{x}_t)dt + g_td\mathbf{w_t}. \end{equation} dxt=(θt+gt2σˉt:T2e−2θˉt:T)(xT−xt)dt+gtdwt.
其中 σ ˉ t : T 2 = λ 2 ( 1 − e − 2 θ ˉ t : T ) \bar{\sigma}{t:T}^2 = \lambda^2 (1 - e^{-2\bar{\theta}{t:T}}) σˉt:T2=λ2(1−e−2θˉt:T)。具体推导比较长,在Yue等人的Appendix A.1里,大概过程是从(2)-(4)写出 p ( x t ∣ x s ) p(\mathbf{x}_t|\mathbf{x}s) p(xt∣xs)的具体分布,依据 ∇ x T log p ( x T ∣ x t ) \nabla{\mathbf{x_T}} \log p(\mathbf{x_T}|\mathbf{x_t}) ∇xTlogp(xT∣xt)推导 h h h,即可推导(8), h ( X t , t , X T , T ) = g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ( x T − x t ) \mathbf{h}(\mathbf{X}t, t, \mathbf{X}T, T) = g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}{t:T}^2} (\mathbf{x}_T - \mathbf{x}_t) h(Xt,t,XT,T)=gt2σˉt:T2e−2θˉt:T(xT−xt)
p ( x t ∣ x 0 , x T ) p(\mathbf{x}_t|\mathbf{x}_0, \mathbf{x}_T) p(xt∣x0,xT)的推导可以从贝叶斯公式推导。
p ( x t ∣ x 0 , x T ) = N ( m ˉ t ′ , σ ˉ t ′ 2 I ) m ˉ t ′ = e − θ ˉ t σ ˉ t : T 2 σ ˉ T 2 x 0 + ( 1 − e − θ ˉ t σ ˉ t : T 2 σ ˉ T 2 + e − 2 θ ˉ t : T σ ˉ t 2 σ ˉ T 2 ) x T σ ˉ t ′ 2 = σ ˉ t 2 σ ˉ t : T 2 σ ˉ 2 \begin{align} p(\mathbf{x}t | \mathbf{x}0, \mathbf{x}T) &= \mathcal{N}(\mathbf{\bar{m}}'t, \bar{\sigma}'^2_t \mathbf{I}) \\ \mathbf{\bar{m}}'t = e^{-\bar{\theta}t} \frac{\bar{\sigma}^2{t:T}}{\bar{\sigma}^2{T}} \mathbf{x}0 &+ \left(1 - e^{-\bar{\theta}t} \frac{\bar{\sigma}^2{t:T}}{\bar{\sigma}^2{T}} + e^{-2\bar{\theta}{t:T}} \frac{\bar{\sigma}^2{t}}{\bar{\sigma}^2{T}}\right) \mathbf{x}T \\ \bar{\sigma}'^2_t &= \frac{\bar{\sigma}^2{t} \bar{\sigma}^2{t:T}}{\bar{\sigma}^2} \end{align} p(xt∣x0,xT)mˉt′=e−θˉtσˉT2σˉt:T2x0σˉt′2=N(mˉt′,σˉt′2I)+(1−e−θˉtσˉT2σˉt:T2+e−2θˉt:TσˉT2σˉt2)xT=σˉ2σˉt2σˉt:T2
有了SDE,我们就不用一步一步推导,而是直接一步到位,从 x 0 , x T \mathbf{x}_0, \mathbf{x}_T x0,xT直接到 x t \mathbf{x}_t xt,这是训练的第一步。训练的第二步是让模型学习到从 x t \mathbf{x}t xt到 x t − 1 \mathbf{x}{t-1} xt−1的演化。
反向SDE如下,有着 p ( x t ∣ x T ) p(\mathbf{x_t}|\mathbf{x_T}) p(xt∣xT)的边际分布:
d x t = [ ( θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ) ( x T − x t ) − g t 2 ∇ x t log p ( x t ∣ x T ) ] d t + g t d w t d\mathbf{x}t = \left[(\theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}_{t:T}^2}) (\mathbf{x}_T - \mathbf{x}t) - g_t^2 \nabla{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}_T) \right] dt + g_t d\mathbf{w}_t dxt=[(θt+gt2σˉt:T2e−2θˉt:T)(xT−xt)−gt2∇xtlogp(xt∣xT)]dt+gtdwt
并且存在一个概率流常微分方程:
d x t = [ ( θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ) ( x T − x t ) − 1 2 g t 2 ∇ x t log p ( x t ∣ x T ) ] d t d\mathbf{x}t = \left[(\theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}_{t:T}^2}) (\mathbf{x}_T - \mathbf{x}t) - \frac{1}{2} g_t^2 \nabla{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}_T) \right] dt dxt=[(θt+gt2σˉt:T2e−2θˉt:T)(xT−xt)−21gt2∇xtlogp(xt∣xT)]dt
至于为什么这里ODE变成了 1 2 \frac 1 2 21,有一个性质,为保持边际概率密度不变,这一项就得恰好减半。
损失函数
先回顾一下,依照Score based diffusion model,利用conditional score matching,损失函数如下:
L = 1 2 ∫ 0 T E x t [ λ ( t ) ∥ ∇ x t log p ( x t ) − s θ ( x t , t ) ∥ 2 ] d t ∝ 1 2 ∫ 0 T E x 0 , x t [ λ ( t ) ∥ ∇ x t log p ( x t ∣ x 0 ) − s θ ( x t , t ) ∥ 2 ] d t L = \frac{1}{2} \int_{0}^{T} \mathbb{E}{x_t} \left[ \lambda (t) \left\lVert \nabla{x_t} \log p (x_t) - s_{\theta} (x_t, t) \right\rVert^2 \right] dt \propto \frac{1}{2} \int_{0}^{T} \mathbb{E}{x_0,x_t} \left[ \lambda (t) \left\lVert \nabla{x_t} \log p (x_t | x_0) - s_{\theta} (x_t, t) \right\rVert^2 \right] dt L=21∫0TExt[λ(t)∥∇xtlogp(xt)−sθ(xt,t)∥2]dt∝21∫0TEx0,xt[λ(t)∥∇xtlogp(xt∣x0)−sθ(xt,t)∥2]dt
其中 λ ( t ) \lambda(t) λ(t)作为加权函数,若将其选为 g 2 ( t ) g^2(t) g2(t),则能在负对数似然上得到更优的上界(Song等,2021a)。而正比一行实际上是最常用的,因为条件概率 p ( x t ∣ x 0 ) p(x_t | x_0) p(xt∣x0)通常是可获取的。最终,可以从先验分布 p ( x T ) ≈ p prior ( x ) p(x_T) \approx p_{\text{prior}}(x) p(xT)≈pprior(x)中采样得到 x T x_T xT,并通过迭代步骤对公式 (2)进行数值求解来得到 x 0 x_0 x0,从而完成生成过程。
相应地,在GOUB里,得分项 ∇ x t log p ( x t ∣ x T ) \nabla_{x_t} \log p(\mathbf{x}_t | \mathbf{x}T) ∇xtlogp(xt∣xT)可以由神经网络 s θ ( x t , x T , t ) s{\theta}(\mathbf{x}_t, \mathbf{x}_T, t) sθ(xt,xT,t)进行参数化,并且可以使用上述score matching的损失函数进行估计。不幸的是,对随机微分方程的得分函数进行训练通常是一项重大挑战。作者没有明说,但我想有一个原因很重要,SDE是连续的,训练神经网络是离散的,为此,Yue等人是通过用反向SDE再用Euler Sampling得到的反向离散化的方程。而先前因为GOUB的解析解是知道的,作者于是推出了一个更为稳定的,使用ELBO的损失函数并推导证明之。
假设 x T x_T xT是满足GOU方程的一个有限随机变量,对于固定的 x_T,对数似然函数 E p ( x 0 ) [ log p θ ( x 0 ∣ x T ) ] E_{p(x_0)}[\log p_{\theta}(x_0 | x_T)] Ep(x0)[logpθ(x0∣xT)]具有一个ELBO:
ELBO = E p ( x 0 ) { E p ( x 1 ∣ x 0 ) [ log p θ ( x 0 ∣ x 1 , x T ) ] − ∑ t = 2 T E p ( x t ∣ x 0 ) [ KL ( p ( x t − 1 ∣ x 0 , x t , x T ) ∣ ∣ p θ ( x t − 1 ∣ x t , x T ) ) ] } \text{ELBO} = \mathbb{E}_{p(\mathbf{x}0)} \left\{ \mathbb{E}{p(\mathbf{x}_1|\mathbf{x}0)} [\log p{\boldsymbol{\theta}} (\mathbf{x}_0 | \mathbf{x}1, \mathbf{x}T)] - \sum{t=2}^{T} \mathbb{E}{p(\mathbf{x}_t|\mathbf{x}0)}[\text{KL} (p (\mathbf{x}{t -1} | \mathbf{x}_0, \mathbf{x}t, \mathbf{x}T) || p{\boldsymbol{\theta}} (\mathbf{x}{t -1} | \mathbf{x}_t, \mathbf{x}_T))] \right\} ELBO=Ep(x0){Ep(x1∣x0)[logpθ(x0∣x1,xT)]−t=2∑TEp(xt∣x0)[KL(p(xt−1∣x0,xt,xT)∣∣pθ(xt−1∣xt,xT))]}
假设 p θ ( x t − 1 ∣ x t , x T ) p_{\boldsymbol{\theta}}(\mathbf{x}{t -1} | \mathbf{x}t, \mathbf{x}T) pθ(xt−1∣xt,xT) 是一个具有恒定方差的高斯分布 N ( μ θ , t − 1 , σ θ , t − 1 2 I ) \mathcal{N}(\boldsymbol{\mu}{\boldsymbol{\theta},t -1}, \sigma^2{\boldsymbol{\theta},t -1}\mathbf{I}) N(μθ,t−1,σθ,t−12I),最大化ELBO等价于最小化:
L = E t , x 0 , x t , x T [ 1 2 σ θ , t − 1 2 ∥ μ t − 1 − μ θ , t − 1 ∥ 2 ] L = \mathbb{E}{t,\mathbf{x}0,\mathbf{x}t,\mathbf{x}T} \left[ \frac{1}{2\sigma^2{\boldsymbol{\theta},t -1}} \|\boldsymbol{\mu}{t -1} - \boldsymbol{\mu}{\boldsymbol{\theta},t -1}\|^2 \right] L=Et,x0,xt,xT[2σθ,t−121∥μt−1−μθ,t−1∥2]
其中, μ t − 1 \boldsymbol{\mu}{t -1} μt−1 表示 p ( x t − 1 ∣ x 0 , x t , x T ) p(\mathbf{x}{t -1} | \mathbf{x}0, \mathbf{x}t, \mathbf{x}T) p(xt−1∣x0,xt,xT) 的均值:
μ t − 1 = 1 σ ˉ t ′ 2 [ σ ˉ t − 1 ′ 2 ( x t − b x T ) a + ( σ ˉ t ′ 2 − σ ˉ t − 1 ′ 2 a 2 ) m ˉ t ′ ] \mu{t -1} = \frac{1}{\bar{\sigma}'^2_t} \left[ \bar{\sigma}'^2{t -1}(x_t - bx_T)a + (\bar{\sigma}'^2_t - \bar{\sigma}'^2{t -1}a^2) \bar{m}'_t \right] μt−1=σˉt′21[σˉt−1′2(xt−bxT)a+(σˉt′2−σˉt−1′2a2)mˉt′]
其中,
a = e − θ ˉ t − 1 : t σ ˉ t : T 2 σ ˉ t − 1 : T 2 a = e^{-\bar{\theta}{t -1:t}} \frac{\bar{\sigma}^2{t:T}}{\bar{\sigma}^2_{t -1:T}} a=e−θˉt−1:tσˉt−1:T2σˉt:T2
b = 1 σ ˉ T 2 { ( 1 − e − θ ˉ t ) σ ˉ t : T 2 + e − 2 θ ˉ t : T σ ˉ t 2 − [ ( 1 − e − θ ˉ t − 1 ) σ ˉ t − 1 : T 2 + e − 2 θ ˉ t − 1 : T σ ˉ t − 1 2 ] a } b = \frac{1}{\bar{\sigma}^2_T} \left\{ (1 - e^{-\bar{\theta}t})\bar{\sigma}^2{t:T} + e^{-2\bar{\theta}{t:T}} \bar{\sigma}^2_t - \left[ (1 - e^{-\bar{\theta}{t -1}})\bar{\sigma}^2_{t -1:T} + e^{-2\bar{\theta}{t -1:T}} \bar{\sigma}^2{t -1} \right] a \right\} b=σˉT21{(1−e−θˉt)σˉt:T2+e−2θˉt:Tσˉt2−[(1−e−θˉt−1)σˉt−1:T2+e−2θˉt−1:Tσˉt−12]a}
这个证明就比较多,感兴趣的可以参考参考文献第二篇。
根据反向SDE方程,离散化:
x t − 1 = x t − ( θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ) ( x T − x t ) + g t 2 ∇ x t log p ( x t ∣ x T ) − g t ϵ t \mathbf{x}{t-1} = \mathbf{x}t - \left( \theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}{t:T}^2} \right) (\mathbf{x}_T - \mathbf{x}t) + g_t^2 \nabla{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}_T) - g_t \boldsymbol{\epsilon}_t xt−1=xt−(θt+gt2σˉt:T2e−2θˉt:T)(xT−xt)+gt2∇xtlogp(xt∣xT)−gtϵt
其中, ϵ t ∼ N ( 0 , d t I ) \boldsymbol{\epsilon}_t \sim \mathcal{N}(\mathbf{0}, d_t\mathbf{I}) ϵt∼N(0,dtI)。
因此:
μ θ , t − 1 = x t − ( θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ) ( x T − x t ) + g t 2 ∇ x t log p θ ( x t ∣ x T ) \boldsymbol{\mu}{\theta,t-1} = \mathbf{x}t - \left(\theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}{t:T}^2}\right) (\mathbf{x}_T - \mathbf{x}t) + g_t^2 \nabla{\mathbf{x}t} \log p{\theta}(\mathbf{x}_t | \mathbf{x}_T) μθ,t−1=xt−(θt+gt2σˉt:T2e−2θˉt:T)(xT−xt)+gt2∇xtlogpθ(xt∣xT)
标准差就是: σ θ , t − 1 = g t \sigma_{\theta,t-1} = g_t σθ,t−1=gt.
作者发现,L1范数损失在图像重构的结果上效果更好,故而采用L1范数,最后的损失函数结果太长就不写了,代入上面的结果即可。最后,如果我们得到最优的 ϵ θ ∗ ( x t , x T , t ) \boldsymbol{\epsilon}^*_{\boldsymbol{\theta}}(\mathbf{x}_t, \mathbf{x}T, t) ϵθ∗(xt,xT,t),就可以计算反向过程的得分 ∇ x t log p ( x t ∣ x T ) ≈ − ϵ θ ∗ ( x t , x T , t ) σ ˉ t ′ \nabla{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}T) \approx \frac {-\boldsymbol{\epsilon}^*{\boldsymbol{\theta}}(\mathbf{x}_t, \mathbf{x}_T, t)} {\bar{\sigma}'_t} ∇xtlogp(xt∣xT)≈σˉt′−ϵθ∗(xt,xT,t),直接代入即可。
Mean-ODE
与普通的扩散模型不同,作者表示,对均值 μ θ , t − 1 \mu_{\theta,t -1} μθ,t−1 的参数化是从随机微分方程的微分推导而来的,这有效地结合了离散扩散模型和基于连续分数的生成模型的特点。在反向过程中,每个采样步骤的值在训练期间会逼近真实均值。因此,作者提出了一个Mean - ODE模型,该模型省略了布朗漂移项,也就是直接在反向SDE上,从经验和实验结果的表现证明上直接删除了 d W t d\mathbf{W_t} dWt:
d x t = [ θ t + g t 2 e − 2 θ ˉ t : T σ ˉ t : T 2 ( x T − x t ) − g t 2 ∇ x t log p ( x t ∣ x T ) ] d t ( 9 ) d\mathbf{x}t = \left[ \theta_t + g_t^2 \frac{e^{-2\bar{\theta}{t:T}}}{\bar{\sigma}_{t:T}^2} (\mathbf{x}_T - \mathbf{x}t) - g_t^2 \nabla{\mathbf{x}_t} \log p(\mathbf{x}_t | \mathbf{x}_T) \,\right] dt \quad (9) dxt=[θt+gt2σˉt:T2e−2θˉt:T(xT−xt)−gt2∇xtlogp(xt∣xT)]dt(9)
作者在实验中也发现,Mean-ODE的表现比Score-ODE好。
UniDB
UniDB仅在GOUB代码的基础上做了极少的修改,而且也利用stochastic optimal control在数学上提供了diffusion bridge via Doob's h-transform的情况的理解,也证明了这是UniDB的 γ → ∞ \gamma \to \infty γ→∞的一种特殊情况,在超分辨率(DIV2K)、图像修复(CelebA - HQ)和去雨(Rain100H)上都表现到了SOTA级别。
UniDB指出,GOUB其核心技术Doob's h-transform是一种次优解。而且GOUB虽然性能好,但也有内在的细节模糊或扭曲问题,并通过理论实验帮助阐释了这一点。不过UniDB和GOUB都有着采样慢的通病,但我认为两者的采样速度不会有多少差距,UniDB是可以做到即插即用的,只在GOUB上做了极少量的修改,起到了统一和更深的insight。
注意到图上右边s.t.的部分, f t x t f_t \mathbf{x_t} ftxt是drift项没错,不过多了一个 h t m h_t \mathbf{m} htm, m \mathbf{m} m是一个given state,一个给定的状态,比如 x t − 1 \mathbf{x_{t-1}} xt−1或者别的。'
把之前GOUB的漂移项展开:
θ t ( μ − x t ) = θ t μ − θ t x t \theta_t (\boldsymbol{\mu} - x_t) = \theta_t \boldsymbol{\mu} - \theta_t x_t θt(μ−xt)=θtμ−θtxt
然后再看UniDB的通用漂移项:
f t x t + h t m f_t x_t + h_t \mathbf{m} ftxt+htm
只要进行如下的参数代换,两者就完全等价了:
- 令 f t = − θ t f_t = -\theta_t ft=−θt
- 令 h t = θ t h_t = \theta_t ht=θt
- 令 m = μ \mathbf{m} = \boldsymbol{\mu} m=μ
这也是原文中提到的,还有VE,VP等。
UniDB中的一些proposition和GOUB很相似,这里阐述一下不同的部分。
- 依据Theorem 4.1,可以得到一个从 x 0 x_0 x0连接到终端 x T x_T xT邻域的最优控制正向随机微分方程,这个 u t , γ ∗ \mathbf{u}_{t,\gamma}^* ut,γ∗也是可以计算的,与 m , x t , x T \mathbf{m},\mathbf{x_t},\mathbf{x_T} m,xt,xT有关,正向过程中 x t x_t xt的转移情况也可以推出。
- 对于SOC问题,当 γ → ∞ \gamma \to \infty γ→∞时,最优控制器变为 u t , ∞ ∗ = g t ∇ x t log p ( x T ∣ x t ) u^*{t,\infty} = g_t\nabla{\mathbf{x}_t} \log p(\mathbf{x}_T | \mathbf{x}_t) ut,∞∗=gt∇xtlogp(xT∣xt),并且对应于线性随机微分方程形式的前向和后向随机微分方程与 Doob 的 h h h-变换相同。
- 记 J ( u t , γ , γ ) ≜ ∫ 0 T 1 2 ∥ u t , γ ∥ 2 2 d t + γ 2 ∥ x T u − x T ∥ 2 2 \mathcal{J}(\mathbf{u}{t,\gamma}, \gamma) \triangleq \int_0^T \frac{1}{2} \|\mathbf{u}{t,\gamma}\|2^2 \mathrm{d}t + \frac{\gamma}{2} \|\mathbf{x}T^u - x_T\|2^2 J(ut,γ,γ)≜∫0T21∥ut,γ∥22dt+2γ∥xTu−xT∥22为系统的总成本, u t , γ ∗ u{t,\gamma}^* ut,γ∗为最优控制器,则有 J ( u t , γ ∗ , γ ) ≤ J ( u t , ∞ ∗ , ∞ ) \mathcal{J}(\mathbf{u}^*{t,\gamma}, \gamma) \le \mathcal{J}(\mathbf{u}^*{t,\infty}, \infty) J(ut,γ∗,γ)≤J(ut,∞∗,∞),这说明 γ → ∞ \gamma \to \infty γ→∞的情况并非是最优解,后面作者根据实验发现 γ \gamma γ的取值是随着不同的具体任务而有变化的。
- 记初始状态分布为 x 0 x_0 x0,由控制器产生的终端分布为 x T u \mathbf{x}_T^u xTu,以及预先定义的终端分布为 x T x_T xT,则
∥ x T u − x T ∥ 2 2 = e − 2 θ ˉ T ( 1 + γ λ 2 ( 1 − e − 2 θ ˉ T ) ) 2 ∥ x T − x 0 ∥ 2 2 \|\mathbf{x}_T^u - x_T\|_2^2 = \frac{e^{-2\bar{\theta}_T}}{(1 + \gamma\lambda^2(1 - e^{-2\bar{\theta}_T}))^2} \|x_T - x_0\|_2^2 ∥xTu−xT∥22=(1+γλ2(1−e−2θˉT))2e−2θˉT∥xT−x0∥22
这说明控制的终点和实际的终点是受到 γ \gamma γ的调控的,如下图,红色区域是作者推荐的关注区域,蓝色点竖线是作者在后面的消融实验中的选取方式,在四倍超分,图像修复,去雨三个任务上,从PSNR,SSIM,LPIPS,FIDS四个指标看, γ \gamma γ的不同,分数也不同,而且同一个任务里也可能并非一个gamma能得到四个指标都有良好的结果。
与之前GOUB的类似,反向过程SDE和Mean-ODE如下:
d x t = [ f t x t + h t m + g t u t , γ ∗ − g t 2 ∇ x t log p ( x t ∣ x T ) ] d t + g t d w ~ t \mathrm{d}\mathbf{x}t = [f_t\mathbf{x}t + h_t\mathbf{m} + g_t\mathbf{u}^*{t,\gamma} - g_t^2\nabla{\mathbf{x}_t}\log p(\mathbf{x}_t | x_T)]\mathrm{d}t + g_t\mathrm{d}\tilde{\mathbf{w}}_t dxt=[ftxt+htm+gtut,γ∗−gt2∇xtlogp(xt∣xT)]dt+gtdw~t
d x t = [ f t x t + h t m + g t u t , γ ∗ − g t 2 ∇ x t log p ( x t ∣ x T ) ] d t \mathrm{d}\mathbf{x}t = [f_t\mathbf{x}t + h_t\mathbf{m} + g_t\mathbf{u}^*{t,\gamma} - g_t^2\nabla{\mathbf{x}_t}\log p(\mathbf{x}_t | x_T)]\mathrm{d}t dxt=[ftxt+htm+gtut,γ∗−gt2∇xtlogp(xt∣xT)]dt
整个训练中,以GOU为例子,项的差距与GOUB差距很小,可以看到几乎只是 γ − 1 \gamma^{-1} γ−1的引入:
e − θ ˉ t σ ˉ t : T 2 σ ˉ T 2 ⇒ e − θ ˉ t γ − 1 + σ ˉ t : T 2 γ − 1 + σ ˉ T 2 e^{-\bar{\theta}t} \frac{\bar{\sigma}{t:T}^2}{\bar{\sigma}T^2} \Rightarrow e^{-\bar{\theta}t} \frac{\gamma^{-1} + \bar{\sigma}{t:T}^2}{\gamma^{-1} + \bar{\sigma}T^2} e−θˉtσˉT2σˉt:T2⇒e−θˉtγ−1+σˉT2γ−1+σˉt:T2
g t h = g t e − 2 θ ˉ t : T ( x T − x t ) σ ˉ t : T 2 ⏟ GOUB ⇒ u t , γ ∗ = g t e − 2 θ ˉ t : T ( x T − x t ) γ − 1 + σ ˉ t : T 2 ⏟ UniDB-GOU \underbrace{g_t \mathbf{h} = \frac{g_t e^{-2\bar{\theta}{t:T}}(x_T - \mathbf{x}t)}{\bar{\sigma}{t:T}^2}}{\text{GOUB}} \Rightarrow \underbrace{\mathbf{u}^*{t,\gamma} = \frac{g_t e^{-2\bar{\theta}{t:T}}(x_T - \mathbf{x}t)}{\gamma^{-1} + \bar{\sigma}{t:T}^2}}_{\text{UniDB-GOU}} GOUB gth=σˉt:T2gte−2θˉt:T(xT−xt)⇒UniDB-GOU ut,γ∗=γ−1+σˉt:T2gte−2θˉt:T(xT−xt)
下面是两个算法。算法一的思路就是,先随机抽取一对图像 ( x 0 , x t ) (\mathbf{x_0},\mathbf{x_t}) (x0,xt),然后在 U n i f o r m { 1 , ... , T } Uniform\{1,\dots,T\} Uniform{1,...,T}抽取一个 t t t,计算 μ ˉ t , γ , γ , σ ˉ t ′ 2 \boldsymbol{\bar{\mu}}{t,\gamma},\gamma,\bar{\sigma}t'^2 μˉt,γ,γ,σˉt′2。为训练稳定,不直接预测分数,而是依据分数匹配理论,分数函数可以被参数化为:
∇ x t log p ( x t ∣ x T ) ≈ − ϵ θ ( x t , x T , t ) σ ˉ t ′ \nabla{x_t} \log p(x_t|x_T) \approx -\frac{\epsilon\theta(x_t, x_T, t)}{\bar{\sigma}'_t} ∇xtlogp(xt∣xT)≈−σˉt′ϵθ(xt,xT,t)
通过计算 μ ˉ t − 1 , θ \boldsymbol{\bar{\mu}}{t-1,\theta} μˉt−1,θ,再计算 μ ˉ t − 1 , γ \boldsymbol{\bar{\mu}}{t-1,\gamma} μˉt−1,γ,计算损失函数梯度,让算法收敛。整个算法其实和之前计算GOUB的 μ θ , t − 1 \boldsymbol{\mu}{\theta,t-1} μθ,t−1和 μ t − 1 \boldsymbol{\mu}{t-1} μt−1挺像的。
算法二就是采样啦,解决新问题。
写到这里
参考文献
Zhu K, Pan M, Ma Y, et al. UniDB: A Unified Diffusion Bridge Framework via Stochastic Optimal Control[J]. arXiv preprint arXiv:2502.05749, 2025.
Yue C, Peng Z, Ma J, et al. Image restoration through generalized ornstein-uhlenbeck bridge[J]. arXiv preprint arXiv:2312.10299, 2023.