一篇文章入门随机微分方程SDE

文章目录

随机微分方程SDE

笔者建议,学完DDPM再来看SDE的作用和推导过程

标准布朗运动

在学习随机微分方程之前,我们先来看一下什么是标准布朗运动

假设有一个一维的直线,有个小人从原点出发,每次随机地选择向左走1格或者向右走1格,且向左走和向右走的两个选项,被选择的概率相等 → \rightarrow →用 S t S_t St代表小人离原点的距离, t t t代表代表选择的次数,如果选择的次数越多,那么 S t S_t St将会逐渐服从一个均值为0、方差为 t t t的正态分布
布朗运动 W ( t ) W(t) W(t)是期望为0、方差为 t t t的正态分布 ⇔ \Leftrightarrow ⇔ W t ∼ N ( 0 , t ) W_t\sim \mathcal{N}(0,t) Wt∼N(0,t) ⇒ \Rightarrow ⇒ W t + Δ t − W t ∼ N ( 0 , Δ t ) W_{t+\Delta t}-W_t\sim \mathcal{N}(0,\Delta t) Wt+Δt−Wt∼N(0,Δt),当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, d w = d t ε dw=\sqrt{dt}\varepsilon dw=dt ε(重参数技巧)

SDE加噪

在DDPM中,扩散过程被划分为固定的T步 ⇒ \Rightarrow ⇒DDPM=拆楼+建楼 ⇒ \Rightarrow ⇒"拆楼"和"建楼"都被事先划分为了T步,这个划分有着相当大的人为性。事实上,真实的"拆"、"建"过程应该是没有刻意划分的步骤 ⇒ \Rightarrow ⇒可以将它们理解为一个在时间上连续的变换过程,可以用随机微分方程(Stochastic Differential Equation,SDE)来描述,即 d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt,其中 f t ( x t ) f_t(x_t) ft(xt)是漂移项,描述数据的确定性演化 ; g t g_t gt是扩散项,描述的是噪声的扩散程度 ; d w t dw_t dwt是维纳运动(布朗运动)的微小增量,表示随机波动

随机微分方程: d x = dx= dx=确定的变化 + + +随机的变化,其中随机的变化代表着随机性

随机微分方程描述了系统从 t t t时刻到 t + Δ t t+\Delta t t+Δt时刻的变化

我们可以将随机微分方程看成是 x t + Δ t − x t = f t ( x t ) Δ t + g t Δ t ε , ε ∼ N ( 0 , I ) \boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t=\boldsymbol{f}_t(\boldsymbol{x}_t)\Delta t+g_t\sqrt{\Delta t}\boldsymbol{\varepsilon},\quad\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I}) xt+Δt−xt=ft(xt)Δt+gtΔt ε,ε∼N(0,I)在 Δ t → 0 \Delta t\rightarrow 0 Δt→0时的极限 ⇒ \Rightarrow ⇒如果拆楼要1天,那么拆楼就是 x x x从 t = 0 t=0 t=0到 t = 1 t=1 t=1时刻的变化 Δ t → 0 \Delta t\rightarrow 0 Δt→0越小的 Δ t \Delta t Δt意味着对原始噪声越好的近似。如果 Δ t = 0.001 \Delta t=0.001 Δt=0.001,对应着 T = 1000 T=1000 T=1000;如果 Δ t = 0.01 \Delta t=0.01 Δt=0.01,则对应 T = 100 T=100 T=100 ⇒ \Rightarrow ⇒引入SDE的本质好处是"将理论分析和代码实现分离开来"
DDPM的加噪过程本质上是一个SDE,而SDE本质上描述的是微小时间变化下系统状态的变化

  • DDPM的加噪: x t + 1 = 1 − β t x t + β t ϵ x_{t+1}=\sqrt{1-\beta_t}x_t+\sqrt{\beta_t}\epsilon xt+1=1−βt xt+βt ϵ
  • SDE的加噪: d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt

在这里,笔者介绍一下将DDPM加噪公式映射到SDE加噪公式的推导过程:

  1. 重写DDPM加噪公式: x t + 1 − x t = ( 1 − β t − 1 ) x t + β t ϵ x_{t+1}-x_t=(\sqrt{1-\beta_t}-1)x_t+\sqrt{\beta_t}\epsilon xt+1−xt=(1−βt −1)xt+βt ϵ ⇒ \Rightarrow ⇒ 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1−βt ≈1−2βt ⇒ \Rightarrow ⇒将DDPM加噪公式重新表示为一个确定项和随机噪声项的和: x t + 1 − x t ≈ − β t 2 x t + β t ϵ x_{t+1}-x_{t}\approx-\frac{\beta_{t}}{2}x_{t}+\sqrt{\beta_{t}}\epsilon xt+1−xt≈−2βtxt+βt ϵ
    在这里,使用泰勒展开得到 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1−βt ≈1−2βt
    先来介绍一下泰勒展开:如果 f ( x ) f(x) f(x)在 x = a x=a x=a处是可微的,则它的泰勒展开可以写为 f ( x ) ≈ f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + ... f(x)\approx f(a)+f'(a)(x-a)+\frac{f''(a)}{2!}(x-a)^2+\ldots f(x)≈f(a)+f′(a)(x−a)+2!f′′(a)(x−a)2+...,其中 f ′ ( a ) f'(a) f′(a)和 f ′ ′ ( a ) f''(a) f′′(a)分别是 f ( x ) f(x) f(x)在 a a a处的一阶导数和二阶导数;在泰勒展开中,若函数依赖多个变量,需要对每个变量分别进行展开
    f ( β t ) = 1 − β t f(\beta_t)=\sqrt{1-\beta_t} f(βt)=1−βt 在 β t = 0 \beta_t=0 βt=0处展开 ⇒ \Rightarrow ⇒零阶项: f ( 0 ) = 1 − 0 = 1 f(0)=\sqrt{1-0}=1 f(0)=1−0 =1;一阶导数: f ′ ( β t ) = d d β t 1 − β t = − 1 2 1 − β t f'(\beta_t)=\frac{d}{d\beta_t}\sqrt{1-\beta_t}=\frac{-1}{2\sqrt{1-\beta_t}} f′(βt)=dβtd1−βt =21−βt −1,在 β t = 0 \beta_t=0 βt=0处 f ′ ( 0 ) = − 1 2 1 − 0 = − 1 2 f'(0)=\frac{-1}{2\sqrt{1-0}}=-\frac{1}{2} f′(0)=21−0 −1=−21 ⇒ \Rightarrow ⇒ f ( β t ) ≈ f ( 0 ) + f ′ ( 0 ) β t f(\beta_t)\approx f(0)+f'(0)\beta_t f(βt)≈f(0)+f′(0)βt ⇒ \Rightarrow ⇒ 1 − β t ≈ 1 − 1 2 β t \sqrt{1-\beta_t}\approx1-\frac{1}{2}\beta_t 1−βt ≈1−21βt
  2. 引入 Δ t \Delta t Δt: Δ t \Delta t Δt只是在数学上引入的时间增量,而 β t \beta_t βt在离散模型中的定义是独立于 Δ t \Delta t Δt的,将DDPM离散的加噪过程转换为连续时间的随机微分方程描述: x t + Δ t − x t ≈ − β t 2 x t Δ t + β t Δ t ϵ x_{t+\Delta t}-x_t\approx-\frac{\beta_t}{2}x_t\Delta t+\sqrt{\beta_t\Delta t}\epsilon xt+Δt−xt≈−2βtxtΔt+βtΔt ϵ ⇒ \Rightarrow ⇒ d x = − 1 2 β t x t d t + β t d w dx=-\frac{1}{2}\beta_tx_tdt+\sqrt{\beta_t}dw dx=−21βtxtdt+βt dw
  3. SDE的形式:漂移项 f t ( x t ) = − β t 2 x t f_t(x_t)=-\frac{\beta_t}2x_t ft(xt)=−2βtxt,扩散系数 g t = β t g_{t}=\sqrt{\beta_{t}} gt=βt
    左侧是数据分布,右侧是正态分布,t是连续时间

SDE去噪

SDE去噪的目标是求 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xt∣xt+Δt)

已知: x t + Δ t x_{t+\Delta t} xt+Δt和前向SDE过程 p ( x t + Δ t ∣ x t ) p(x_{t+\Delta t}|x_t) p(xt+Δt∣xt) ⇒ \Rightarrow ⇒贝叶斯公式: p ( x t ∣ x t + Δ t ) = p ( x t + Δ t ∣ x t ) p ( x t ) p ( x t + Δ t ) p(x_t|x_{t+\Delta t})=\frac{p(x_{t+\Delta t}|x_t)p(x_t)}{p(x_{t+\Delta t})} p(xt∣xt+Δt)=p(xt+Δt)p(xt+Δt∣xt)p(xt)

为了简化问题,尽可能使 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xt∣xt+Δt)的分布满足正太分布

  1. x t + Δ t = x t + f t ( x t ) Δ t + g t Δ t ϵ x_{t+\Delta t}=x_t+f_t(x_t)\Delta t+g_t\sqrt{\Delta t}\epsilon xt+Δt=xt+ft(xt)Δt+gtΔt ϵ ⇒ \Rightarrow ⇒根据重参数可得: x t + Δ t ∼ N ( x t + f t ( x t ) Δ t , g t 2 Δ t ) x_{t+\Delta t}\sim\mathcal{N}(x_t+f_t(x_t)\Delta t,g_t^2\Delta t) xt+Δt∼N(xt+ft(xt)Δt,gt2Δt)
  2. 正态分布的概率密度函数: f ( x ) = 1 σ 2 π e x p ( − ( x − μ ) 2 2 σ 2 ) f(x)=\frac1{\sigma\sqrt{2\pi}}exp(-\frac{(x-\mu)^2}{2\sigma^2}) f(x)=σ2π 1exp(−2σ2(x−μ)2)
  3. p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) ) ( 1 ) \begin{aligned} p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}))&&(1)\\ \end{aligned} p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−ft(xt)Δt)2+logp(xt)−logp(xt+Δt))(1)
  4. 在 x t x_t xt处泰勒展开 l o g p ( x t + Δ t ) logp(x_{t+\Delta t}) logp(xt+Δt): l o g p ( x t + Δ t ) ≈ l o g p ( x t ) + ( x t + Δ t − x t ) ∇ x l o g p ( x t ) + ( t + Δ t − t ) ∇ t l o g p ( x t ) \begin{aligned} logp(x_{t+\Delta t})\approx logp(x_t)+(x_{t+\Delta t}-x_t)\nabla_xlogp(x_t)+(t+\Delta t-t)\nabla_tlogp(x_t) \end{aligned} logp(xt+Δt)≈logp(xt)+(xt+Δt−xt)∇xlogp(xt)+(t+Δt−t)∇tlogp(xt),其中 ∇ x log ⁡ p ( x t ) \nabla_{x}\log p(x_{t}) ∇xlogp(xt)表示状态变量变化对 log ⁡ p ( x t ) \log p(x_{t}) logp(xt)的影响, ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt)表示时间变化对 log ⁡ p ( x t ) \log p(x_{t}) logp(xt)的影响
    在这里笔者介绍一下为什么会多出一项 ∇ t log ⁡ p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt): p ( x t ) p(x_t) p(xt)实际上是" t t t时刻随机变量等于 x t x_t xt的概率密度", p ( x t + Δ t ) p(x_{t+\Delta t}) p(xt+Δt)实际上是" t + Δ t t+\Delta t t+Δt时刻随机变量等于 x t + Δ t x_{t+\Delta t} xt+Δt的概率密度",即 p ( x t ) p(x_t) p(xt)实际上同时是 t t t和 x t x_t xt的函数,概率密度不会会因为状态的变化而变化,还会因为时间的推移而变换,所以要多一项 t t t的偏导数,考虑状态和时间的变化对概率密度的双重影响 ⇒ \Rightarrow ⇒泰勒展开中的零阶项反映的是静态的信息,不涉及任何随时间变化的因素;一阶项捕捉了概率密度函数在时间 t t t和状态 x x x处的变化趋势
  5. 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ 2 t = 0 \Delta^2 t=0 Δ2t=0
  6. ( 1 ) = p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) 2 2 g t 2 Δ t ) = e x p ( − ( x t − ( x t + Δ t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) ) 2 2 g t 2 Δ t ) ( 2 ) \begin{aligned} (1)=p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t})\\ &=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t}(x_t)-g_t^2\nabla xlogp(x_t))\Delta t))^2}{2g_t^2\Delta t})&&(2) \end{aligned} (1)=p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−(ft(xt)−gt2∇xlogp(xt))Δt)2)=exp(−2gt2Δt(xt−(xt+Δt−(ft(xt)−gt2∇xlogp(xt))Δt))2)(2)
  7. 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, t + Δ t → t t+\Delta t\rightarrow t t+Δt→t
  8. ( 2 ) = e x p ( − ( x t − ( x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t ) ) 2 2 g t + Δ t 2 Δ t ) \begin{aligned} (2)=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t))^2}{2g_{t+\Delta t}^2\Delta t}) \end{aligned} (2)=exp(−2gt+Δt2Δt(xt−(xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt))2)
  9. x t = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_t=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt=xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
  10. 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ t → d t \Delta t\rightarrow dt Δt→dt
  11. x t + Δ t − x t = ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_{t+\Delta t}-x_t=(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt+Δt−xt=(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
    12.12. d x t = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d t ϵ = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}\sqrt{dt}\epsilon=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)−gt2∇xlogp(xt))dt+gtdt ϵ=(ft(xt)−gt2∇xlogp(xt))dt+gtdwˉ

Score Matching

在上述SDE去噪过程中,我们得到 d x t = ( f t ( x t ) − g t 2 ∇ x t l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla_{x_t}logp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)−gt2∇xtlogp(xt))dt+gtdwˉ,那么 ∇ x t log ⁡ p ( x t ) \nabla_{x_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)这一项是什么?如果不知道这一项的话,似乎我们也无法得到 d x t dx_t dxt的值

我们先来看一下 ∇ x t log ⁡ p ( x t ) \nabla_{x_t}\log p(\boldsymbol{x}t) ∇xtlogp(xt)的含义:对数概率密度函数 log ⁡ p ( x t ) \log p(x{t}) logp(xt)关于 x t x_t xt的梯度 → \rightarrow →梯度指向概率密度变化最快的方向, ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}t}\log p(\boldsymbol{x}t) ∇xtlogp(xt)给出移动方向和移动速度以找到概率更大的区域 ⇒ \Rightarrow ⇒使用一个 θ \theta θ参数化的概率分布 p θ p{\theta} pθ模拟 p p p,通过学习参数 θ \theta θ使 p θ p{\theta} pθ接近 p p p

我们可以将 p θ p_{\theta} pθ看成是由两部分组成的,分别是表示密度的函数 p θ ~ \tilde{p_{\theta}} pθ~、归一化因子 Z θ Z_{\theta} Zθ ⇒ \Rightarrow ⇒ p θ ( x ) = p ~ θ ( x ) Z θ = p ~ θ ( x ) ∫ x ∈ X p ~ θ ( x ) d x p_\theta(x)=\frac{\tilde{p}\theta(x)}{Z\theta}=\frac{\tilde{p}\theta(x)}{\int{x\in X}\tilde{p}\theta(x)dx} pθ(x)=Zθp~θ(x)=∫x∈Xp~θ(x)dxp~θ(x),其中未归一化的概率密度函数 p θ ~ \tilde{p{\theta}} pθ~给出某个数据点 x x x相对于其他数据点的可能性大小,但并不能给出直接用于表示 x x x发生的真实概率

目前,使用极大似然估计求解 θ \theta θ的问题:不知道归一化因子 Z θ Z_{\theta} Zθ的值

解决方法:

  1. 引入得分函数(score function):概率密度函数的梯度 ∇ x log ⁡ p θ ( x ) \nabla_x\log p_\theta(x) ∇xlogpθ(x)
  2. 将 p θ ( x ) p_{\theta}(x) pθ(x)通过 l o g log log拆分成两项 ∇ x log ⁡ p ~ θ ( x ) − ∇ x log ⁡ Z θ \nabla_{x}\log\tilde{p}{\theta}(x)-\nabla{x}\log Z_{\theta} ∇xlogp~θ(x)−∇xlogZθ ⇒ \Rightarrow ⇒由于求解的是 x x x的梯度,所以可以直接消掉 ∇ x log ⁡ Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ,因为 ∇ x log ⁡ Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ与 x x x无关;同时 p θ ~ \tilde{p_{\theta}} pθ~不受"概率分布"的约束,可以使用神经网络作为 p θ ~ \tilde{p_{\theta}} pθ~,因为 p θ ~ \tilde{p_{\theta}} pθ~本身就不是概率密度函数, p θ ~ \tilde{p_{\theta}} pθ~只是密度函数
  3. 目标:选择一个loss让 ∇ x log ⁡ p θ ( x ) \nabla_x\log p_\theta(x) ∇xlogpθ(x)尽可能接近 ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)

新的问题:不知道数据分布的 score function ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)

为了简化公式,下面公式中的 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x)等同于 ∇ x t log ⁡ p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)

解决方法:Score Matching

  1. Score Matching:用于估计概率密度函数的梯度(得分函数 score ⁡ ( x ) = ∇ x t log ⁡ p ( x t ) \operatorname{score}(x)=\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) score(x)=∇xtlogp(xt)),而无需知道密度函数的归一化常数
  2. Score Matching的目标:学习一个模型 q ( x ; θ ) q(x;\theta) q(x;θ),使得模型得分函数 ∇ x log ⁡ q ( x ; θ ) \nabla_x\log q(x;\theta) ∇xlogq(x;θ)真实分布 p ( x ) p(x) p(x)的得分函数尽可能接近
  3. Score Matching的损失函数: L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2],其中的期望差异可以帮助模型更全面地学习到真实分布的特征

接下来,对Score Matching的损失函数 L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]进行推导:

  1. 展开欧几里得范数的平方项: ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 = ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 − 2 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) + ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\|^2=\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2 ∥∇xlogq(x;θ)−∇xlogp(x)∥2=∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2
  2. 将上式代入原始损失函数中可得 L ( θ ) = E x ∼ p ( x ) [ 1 2 ( ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 − 2 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) + ∥ ∇ x log ⁡ p ( x ) ∥ 2 ) ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left(\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2\right)\right] L(θ)=Ex∼p(x)[21(∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2)]
  3. 消除不可计算的项:由于不知道真实分布的 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x),我们无法直接计算 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) \nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)

接下来,笔者给出如何消除不可计算项的过程:

  1. 由于 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 θ \theta θ无关,它仅仅依赖于真实数据分布 p ( x ) p(x) p(x),所以可以直接消掉 ∥ ∇ x log ⁡ p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2
  2. 对损失函数中的项 ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) \nabla_{x}\log q(x;\theta)\cdot\nabla_{x}\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)进行分部积分 ⇒ \Rightarrow ⇒原始的积分形式为 ∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx,应用分部积分,将其中的 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x)转移到 p ( x ) p(x) p(x)上,设 f ( x ) = p ( x ) f(\boldsymbol{x})=p(\boldsymbol{x}) f(x)=p(x)、 g ( x ) = ∇ x log ⁡ q ( x ; θ ) g(\boldsymbol{x})=\nabla_x\log q(\boldsymbol{x};\theta) g(x)=∇xlogq(x;θ) ⇒ \Rightarrow ⇒ ∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx

∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x log ⁡ p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx的推导过程:

  1. ∫ p ( x ) ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) p ( x ) d x = ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x \int p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\frac{\nabla_{\boldsymbol{x}}p(\boldsymbol{x})}{p(\boldsymbol{x})} d\boldsymbol{x}=\int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅p(x)∇xp(x)dx=∫∇xlogq(x;θ)⋅∇xp(x)dx
  2. 分部积分: ∫ A ⋅ ∇ x B d x = ∫ ∇ x ⋅ ( A B ) d x − ∫ B ∇ x ⋅ A d x \int\mathbf{A}\cdot\nabla_x\mathbf{B} d\boldsymbol{x}=\int\nabla_x\cdot\left(\mathbf{A}\mathbf{B}\right)d\boldsymbol{x}-\int\mathbf{B}\nabla_x\cdot\mathbf{A} d\boldsymbol{x} ∫A⋅∇xBdx=∫∇x⋅(AB)dx−∫B∇x⋅Adx,其中 A \mathbf{A} A和 B \mathbf{B} B是向量场(向量场:空间中的每个点,都有一个向量与之对应), A = ∇ x log ⁡ q ( x ; θ ) \mathbf{A}=\nabla_{x}\log q(\boldsymbol{x};\theta) A=∇xlogq(x;θ)、 B = p ( x ) \mathbf{B}=p(\boldsymbol{x}) B=p(x)
  3. ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x = ∫ ∇ x ⋅ ( p ( x ) ∇ x log ⁡ q ( x ; θ ) ) d x − ∫ p ( x ) ∇ x ⋅ ∇ x log ⁡ q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x}) d\boldsymbol{x}=\int\nabla_x\cdot(p(\boldsymbol{x})\nabla_x\log q(\boldsymbol{x};\theta)) d\boldsymbol{x}-\int p(\boldsymbol{x})\nabla_x\cdot\nabla_x\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=∫∇x⋅(p(x)∇xlogq(x;θ))dx−∫p(x)∇x⋅∇xlogq(x;θ)dx
  4. 由于 ∇ x ⋅ ∇ x log ⁡ q ( x ; θ ) = ∇ x 2 log ⁡ q ( x ; θ ) \nabla_x\cdot\nabla_x\log q(\boldsymbol{x};\theta)=\nabla_x^2\log q(\boldsymbol{x};\theta) ∇x⋅∇xlogq(x;θ)=∇x2logq(x;θ)可知 ⇒ \Rightarrow ⇒ ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x = ∫ ∇ x ⋅ ( p ( x ) ∇ x log ⁡ q ( x ; θ ) ) d x − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x}=\int\nabla_{x}\cdot\left(p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\right)d\boldsymbol{x}-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=∫∇x⋅(p(x)∇xlogq(x;θ))dx−∫p(x)∇x2logq(x;θ)dx
  5. 假设该边界项 ∫ ∇ x ⋅ ( p ( x ) ∇ x log ⁡ q ( x ; θ ) ) d x \int\nabla_{x}\cdot\left(p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\right)d\boldsymbol{x} ∫∇x⋅(p(x)∇xlogq(x;θ))dx在适当的条件下为零,则 ∫ ∇ x log ⁡ q ( x ; θ ) ⋅ ∇ x p ( x ) d x = − ∫ p ( x ) ∇ x 2 log ⁡ q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x})d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_x^2\log q(\boldsymbol{x};\theta)d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=−∫p(x)∇x2logq(x;θ)dx
  6. L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) − ∇ x log ⁡ p ( x ) ∥ 2 ] = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log ⁡ q ( x ; θ ) ∥ 2 + ∇ x 2 log ⁡ q ( x ; θ ) ] \begin{aligned} L(\theta) &=\mathbb{E}{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right]\\ &=\mathbb{E}{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\frac{1}{2}\|\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\|^2+\nabla_{\boldsymbol{x}}^2\log q(\boldsymbol{x};\theta)\right] \end{aligned} L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]=Ex∼p(x)[21∥∇xlogq(x;θ)∥2+∇x2logq(x;θ)]

至此,我们可以通过损失函数 L ( θ ) L(\theta) L(θ)使 ∇ x log ⁡ q ( x ; θ ) \nabla_x\log q(x;\theta) ∇xlogq(x;θ)接近 ∇ x log ⁡ p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x),进而求出SDE去噪过程中的 d x dx dx

笔者也是刚刚接触SDE,如果文中出现错误,请各位读者指正

参考文献

1、生成扩散模型漫谈(五):一般框架之SDE篇

2、SDE公式推导

3、SDE的底层原理

4、AIGC: SGM (Score-based Generative Model) 笔记

5、Score Matching(得分匹配)

相关推荐
Aurora-Borealis.5 分钟前
Day27 机器学习流水线
人工智能·机器学习
黑符石2 小时前
【论文研读】Madgwick 姿态滤波算法报告总结
人工智能·算法·机器学习·imu·惯性动捕·madgwick·姿态滤波
JQLvopkk2 小时前
智能AI“学习功能”在程序开发部分的逻辑
人工智能·机器学习·计算机视觉
jiayong233 小时前
model.onnx 深度分析报告(第2篇)
人工智能·机器学习·向量数据库·向量模型
张祥6422889043 小时前
数理统计基础一
人工智能·机器学习·概率论
悟乙己3 小时前
使用TimeGPT进行时间序列预测案例解析
机器学习·大模型·llm·时间序列·预测
云和数据.ChenGuang4 小时前
人工智能实践之基于CNN的街区餐饮图片识别案例实践
人工智能·深度学习·神经网络·机器学习·cnn
人工智能培训5 小时前
什么是马尔可夫决策过程(MDP)?马尔可夫性的核心含义是什么?
人工智能·深度学习·机器学习·cnn·智能体·马尔可夫决策
木头左5 小时前
基于集成学习的多因子特征融合策略在指数期权方向性预测中的应用
人工智能·机器学习·集成学习
星河耀银海5 小时前
人工智能从入门到精通:机器学习基础算法实战与应用
人工智能·算法·机器学习