文章目录
随机微分方程SDE
笔者建议,学完DDPM再来看SDE的作用和推导过程
标准布朗运动
在学习随机微分方程之前,我们先来看一下什么是标准布朗运动
假设有一个一维的直线,有个小人从原点出发,每次随机地选择向左走1格或者向右走1格,且向左走和向右走的两个选项,被选择的概率相等 → \rightarrow →用 S t S_t St代表小人离原点的距离, t t t代表代表选择的次数,如果选择的次数越多,那么 S t S_t St将会逐渐服从一个均值为0、方差为 t t t的正态分布
布朗运动 W ( t ) W(t) W(t)是期望为0、方差为 t t t的正态分布 ⇔ \Leftrightarrow ⇔ W t ∼ N ( 0 , t ) W_t\sim \mathcal{N}(0,t) Wt∼N(0,t) ⇒ \Rightarrow ⇒ W t + Δ t − W t ∼ N ( 0 , Δ t ) W_{t+\Delta t}-W_t\sim \mathcal{N}(0,\Delta t) Wt+Δt−Wt∼N(0,Δt),当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, d w = d t ε dw=\sqrt{dt}\varepsilon dw=dt ε(重参数技巧)
SDE加噪
在DDPM中,扩散过程被划分为固定的T步 ⇒ \Rightarrow ⇒DDPM=拆楼+建楼 ⇒ \Rightarrow ⇒"拆楼"和"建楼"都被事先划分为了T步,这个划分有着相当大的人为性。事实上,真实的"拆"、"建"过程应该是没有刻意划分的步骤 ⇒ \Rightarrow ⇒可以将它们理解为一个在时间上连续的变换过程,可以用随机微分方程(Stochastic Differential Equation,SDE)来描述,即 d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt,其中 f t ( x t ) f_t(x_t) ft(xt)是漂移项,描述数据的确定性演化 ; g t g_t gt是扩散项,描述的是噪声的扩散程度 ; d w t dw_t dwt是维纳运动(布朗运动)的微小增量,表示随机波动
随机微分方程: d x = dx= dx=确定的变化 + + +随机的变化,其中随机的变化代表着随机性
随机微分方程描述了系统从 t t t时刻到 t + Δ t t+\Delta t t+Δt时刻的变化
我们可以将随机微分方程看成是 x t + Δ t − x t = f t ( x t ) Δ t + g t Δ t ε , ε ∼ N ( 0 , I ) \boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t=\boldsymbol{f}_t(\boldsymbol{x}_t)\Delta t+g_t\sqrt{\Delta t}\boldsymbol{\varepsilon},\quad\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I}) xt+Δt−xt=ft(xt)Δt+gtΔt ε,ε∼N(0,I)在 Δ t → 0 \Delta t\rightarrow 0 Δt→0时的极限 ⇒ \Rightarrow ⇒如果拆楼要1天,那么拆楼就是 x x x从 t = 0 t=0 t=0到 t = 1 t=1 t=1时刻的变化 Δ t → 0 \Delta t\rightarrow 0 Δt→0越小的 Δ t \Delta t Δt意味着对原始噪声越好的近似。如果 Δ t = 0.001 \Delta t=0.001 Δt=0.001,对应着 T = 1000 T=1000 T=1000;如果 Δ t = 0.01 \Delta t=0.01 Δt=0.01,则对应 T = 100 T=100 T=100 ⇒ \Rightarrow ⇒引入SDE的本质好处是"将理论分析和代码实现分离开来"
DDPM的加噪过程本质上是一个SDE,而SDE本质上描述的是微小时间变化下系统状态的变化
- DDPM的加噪: x t + 1 = 1 − β t x t + β t ϵ x_{t+1}=\sqrt{1-\beta_t}x_t+\sqrt{\beta_t}\epsilon xt+1=1−βt xt+βt ϵ
- SDE的加噪: d x = f t ( x ) d t + g t d w t d\boldsymbol{x}=\boldsymbol{f}_t(\boldsymbol{x})dt+g_td\boldsymbol{w_t} dx=ft(x)dt+gtdwt
在这里,笔者介绍一下将DDPM加噪公式映射到SDE加噪公式的推导过程:
- 重写DDPM加噪公式: x t + 1 − x t = ( 1 − β t − 1 ) x t + β t ϵ x_{t+1}-x_t=(\sqrt{1-\beta_t}-1)x_t+\sqrt{\beta_t}\epsilon xt+1−xt=(1−βt −1)xt+βt ϵ ⇒ \Rightarrow ⇒ 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1−βt ≈1−2βt ⇒ \Rightarrow ⇒将DDPM加噪公式重新表示为一个确定项和随机噪声项的和: x t + 1 − x t ≈ − β t 2 x t + β t ϵ x_{t+1}-x_{t}\approx-\frac{\beta_{t}}{2}x_{t}+\sqrt{\beta_{t}}\epsilon xt+1−xt≈−2βtxt+βt ϵ
在这里,使用泰勒展开得到 1 − β t ≈ 1 − β t 2 \sqrt{1-\beta_t}\approx1-\frac{\beta_t}2 1−βt ≈1−2βt
先来介绍一下泰勒展开:如果 f ( x ) f(x) f(x)在 x = a x=a x=a处是可微的,则它的泰勒展开可以写为 f ( x ) ≈ f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + ... f(x)\approx f(a)+f'(a)(x-a)+\frac{f''(a)}{2!}(x-a)^2+\ldots f(x)≈f(a)+f′(a)(x−a)+2!f′′(a)(x−a)2+...,其中 f ′ ( a ) f'(a) f′(a)和 f ′ ′ ( a ) f''(a) f′′(a)分别是 f ( x ) f(x) f(x)在 a a a处的一阶导数和二阶导数;在泰勒展开中,若函数依赖多个变量,需要对每个变量分别进行展开
f ( β t ) = 1 − β t f(\beta_t)=\sqrt{1-\beta_t} f(βt)=1−βt 在 β t = 0 \beta_t=0 βt=0处展开 ⇒ \Rightarrow ⇒零阶项: f ( 0 ) = 1 − 0 = 1 f(0)=\sqrt{1-0}=1 f(0)=1−0 =1;一阶导数: f ′ ( β t ) = d d β t 1 − β t = − 1 2 1 − β t f'(\beta_t)=\frac{d}{d\beta_t}\sqrt{1-\beta_t}=\frac{-1}{2\sqrt{1-\beta_t}} f′(βt)=dβtd1−βt =21−βt −1,在 β t = 0 \beta_t=0 βt=0处 f ′ ( 0 ) = − 1 2 1 − 0 = − 1 2 f'(0)=\frac{-1}{2\sqrt{1-0}}=-\frac{1}{2} f′(0)=21−0 −1=−21 ⇒ \Rightarrow ⇒ f ( β t ) ≈ f ( 0 ) + f ′ ( 0 ) β t f(\beta_t)\approx f(0)+f'(0)\beta_t f(βt)≈f(0)+f′(0)βt ⇒ \Rightarrow ⇒ 1 − β t ≈ 1 − 1 2 β t \sqrt{1-\beta_t}\approx1-\frac{1}{2}\beta_t 1−βt ≈1−21βt - 引入 Δ t \Delta t Δt: Δ t \Delta t Δt只是在数学上引入的时间增量,而 β t \beta_t βt在离散模型中的定义是独立于 Δ t \Delta t Δt的,将DDPM离散的加噪过程转换为连续时间的随机微分方程描述: x t + Δ t − x t ≈ − β t 2 x t Δ t + β t Δ t ϵ x_{t+\Delta t}-x_t\approx-\frac{\beta_t}{2}x_t\Delta t+\sqrt{\beta_t\Delta t}\epsilon xt+Δt−xt≈−2βtxtΔt+βtΔt ϵ ⇒ \Rightarrow ⇒ d x = − 1 2 β t x t d t + β t d w dx=-\frac{1}{2}\beta_tx_tdt+\sqrt{\beta_t}dw dx=−21βtxtdt+βt dw
- SDE的形式:漂移项 f t ( x t ) = − β t 2 x t f_t(x_t)=-\frac{\beta_t}2x_t ft(xt)=−2βtxt,扩散系数 g t = β t g_{t}=\sqrt{\beta_{t}} gt=βt
左侧是数据分布,右侧是正态分布,t是连续时间
SDE去噪
SDE去噪的目标是求 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xt∣xt+Δt)
已知: x t + Δ t x_{t+\Delta t} xt+Δt和前向SDE过程 p ( x t + Δ t ∣ x t ) p(x_{t+\Delta t}|x_t) p(xt+Δt∣xt) ⇒ \Rightarrow ⇒贝叶斯公式: p ( x t ∣ x t + Δ t ) = p ( x t + Δ t ∣ x t ) p ( x t ) p ( x t + Δ t ) p(x_t|x_{t+\Delta t})=\frac{p(x_{t+\Delta t}|x_t)p(x_t)}{p(x_{t+\Delta t})} p(xt∣xt+Δt)=p(xt+Δt)p(xt+Δt∣xt)p(xt)
为了简化问题,尽可能使 p ( x t ∣ x t + Δ t ) p(x_t|x_{t+\Delta t}) p(xt∣xt+Δt)的分布满足正太分布
- x t + Δ t = x t + f t ( x t ) Δ t + g t Δ t ϵ x_{t+\Delta t}=x_t+f_t(x_t)\Delta t+g_t\sqrt{\Delta t}\epsilon xt+Δt=xt+ft(xt)Δt+gtΔt ϵ ⇒ \Rightarrow ⇒根据重参数可得: x t + Δ t ∼ N ( x t + f t ( x t ) Δ t , g t 2 Δ t ) x_{t+\Delta t}\sim\mathcal{N}(x_t+f_t(x_t)\Delta t,g_t^2\Delta t) xt+Δt∼N(xt+ft(xt)Δt,gt2Δt)
- 正态分布的概率密度函数: f ( x ) = 1 σ 2 π e x p ( − ( x − μ ) 2 2 σ 2 ) f(x)=\frac1{\sigma\sqrt{2\pi}}exp(-\frac{(x-\mu)^2}{2\sigma^2}) f(x)=σ2π 1exp(−2σ2(x−μ)2)
- p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − f t ( x t ) Δ t ) 2 2 g t 2 Δ t + l o g p ( x t ) − l o g p ( x t + Δ t ) ) ( 1 ) \begin{aligned} p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-f_t(x_t)\Delta t)^2}{2g_t^2\Delta t}+logp(x_t)-logp(x_{t+\Delta t}))&&(1)\\ \end{aligned} p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−ft(xt)Δt)2+logp(xt)−logp(xt+Δt))(1)
- 在 x t x_t xt处泰勒展开 l o g p ( x t + Δ t ) logp(x_{t+\Delta t}) logp(xt+Δt): l o g p ( x t + Δ t ) ≈ l o g p ( x t ) + ( x t + Δ t − x t ) ∇ x l o g p ( x t ) + ( t + Δ t − t ) ∇ t l o g p ( x t ) \begin{aligned} logp(x_{t+\Delta t})\approx logp(x_t)+(x_{t+\Delta t}-x_t)\nabla_xlogp(x_t)+(t+\Delta t-t)\nabla_tlogp(x_t) \end{aligned} logp(xt+Δt)≈logp(xt)+(xt+Δt−xt)∇xlogp(xt)+(t+Δt−t)∇tlogp(xt),其中 ∇ x log p ( x t ) \nabla_{x}\log p(x_{t}) ∇xlogp(xt)表示状态变量变化对 log p ( x t ) \log p(x_{t}) logp(xt)的影响, ∇ t log p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt)表示时间变化对 log p ( x t ) \log p(x_{t}) logp(xt)的影响
在这里笔者介绍一下为什么会多出一项 ∇ t log p ( x t ) \nabla_{t}\log p(x_{t}) ∇tlogp(xt): p ( x t ) p(x_t) p(xt)实际上是" t t t时刻随机变量等于 x t x_t xt的概率密度", p ( x t + Δ t ) p(x_{t+\Delta t}) p(xt+Δt)实际上是" t + Δ t t+\Delta t t+Δt时刻随机变量等于 x t + Δ t x_{t+\Delta t} xt+Δt的概率密度",即 p ( x t ) p(x_t) p(xt)实际上同时是 t t t和 x t x_t xt的函数,概率密度不会会因为状态的变化而变化,还会因为时间的推移而变换,所以要多一项 t t t的偏导数,考虑状态和时间的变化对概率密度的双重影响 ⇒ \Rightarrow ⇒泰勒展开中的零阶项反映的是静态的信息,不涉及任何随时间变化的因素;一阶项捕捉了概率密度函数在时间 t t t和状态 x x x处的变化趋势 - 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ 2 t = 0 \Delta^2 t=0 Δ2t=0
- ( 1 ) = p ( x t ∣ x t + Δ t ) = e x p ( − ( x t + Δ t − x t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) 2 2 g t 2 Δ t ) = e x p ( − ( x t − ( x t + Δ t − ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) Δ t ) ) 2 2 g t 2 Δ t ) ( 2 ) \begin{aligned} (1)=p(x_t|x_{t+\Delta t}) &=exp(-\frac{(x_{t+\Delta t}-x_t-(f_t(x_t)-g_t^2\nabla xlogp(x_t))\Delta t)^2}{2g_t^2\Delta t})\\ &=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t}(x_t)-g_t^2\nabla xlogp(x_t))\Delta t))^2}{2g_t^2\Delta t})&&(2) \end{aligned} (1)=p(xt∣xt+Δt)=exp(−2gt2Δt(xt+Δt−xt−(ft(xt)−gt2∇xlogp(xt))Δt)2)=exp(−2gt2Δt(xt−(xt+Δt−(ft(xt)−gt2∇xlogp(xt))Δt))2)(2)
- 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, t + Δ t → t t+\Delta t\rightarrow t t+Δt→t
- ( 2 ) = e x p ( − ( x t − ( x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t ) ) 2 2 g t + Δ t 2 Δ t ) \begin{aligned} (2)=exp(-\frac{(x_t-(x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t))^2}{2g_{t+\Delta t}^2\Delta t}) \end{aligned} (2)=exp(−2gt+Δt2Δt(xt−(xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt))2)
- x t = x t + Δ t − ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_t=x_{t+\Delta t}-(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt=xt+Δt−(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
- 当 Δ t → 0 \Delta t\rightarrow 0 Δt→0时, Δ t → d t \Delta t\rightarrow dt Δt→dt
- x t + Δ t − x t = ( f t + Δ t ( x t + Δ t ) − g t + Δ t 2 ∇ x l o g p ( x t + Δ t ) ) Δ t + g t + Δ t Δ t ϵ x_{t+\Delta t}-x_t=(f_{t+\Delta t}(x_{t+\Delta t})-g_{t+\Delta t}^2\nabla xlogp(x_{t+\Delta t}))\Delta t+g_{t+\Delta t}\sqrt{\Delta t}\epsilon xt+Δt−xt=(ft+Δt(xt+Δt)−gt+Δt2∇xlogp(xt+Δt))Δt+gt+ΔtΔt ϵ
12.12. d x t = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d t ϵ = ( f t ( x t ) − g t 2 ∇ x l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}\sqrt{dt}\epsilon=(f_{t}(x_{t})-g_{t}^2\nabla xlogp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)−gt2∇xlogp(xt))dt+gtdt ϵ=(ft(xt)−gt2∇xlogp(xt))dt+gtdwˉ
Score Matching
在上述SDE去噪过程中,我们得到 d x t = ( f t ( x t ) − g t 2 ∇ x t l o g p ( x t ) ) d t + g t d w ˉ \begin{aligned} dx_t=(f_{t}(x_{t})-g_{t}^2\nabla_{x_t}logp(x_{t}))dt+g_{t}d\bar{w} \end{aligned} dxt=(ft(xt)−gt2∇xtlogp(xt))dt+gtdwˉ,那么 ∇ x t log p ( x t ) \nabla_{x_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)这一项是什么?如果不知道这一项的话,似乎我们也无法得到 d x t dx_t dxt的值
我们先来看一下 ∇ x t log p ( x t ) \nabla_{x_t}\log p(\boldsymbol{x}t) ∇xtlogp(xt)的含义:对数概率密度函数 log p ( x t ) \log p(x{t}) logp(xt)关于 x t x_t xt的梯度 → \rightarrow →梯度指向概率密度变化最快的方向, ∇ x t log p ( x t ) \nabla_{\boldsymbol{x}t}\log p(\boldsymbol{x}t) ∇xtlogp(xt)给出移动方向和移动速度以找到概率更大的区域 ⇒ \Rightarrow ⇒使用一个 θ \theta θ参数化的概率分布 p θ p{\theta} pθ模拟 p p p,通过学习参数 θ \theta θ使 p θ p{\theta} pθ接近 p p p
我们可以将 p θ p_{\theta} pθ看成是由两部分组成的,分别是表示密度的函数 p θ ~ \tilde{p_{\theta}} pθ~、归一化因子 Z θ Z_{\theta} Zθ ⇒ \Rightarrow ⇒ p θ ( x ) = p ~ θ ( x ) Z θ = p ~ θ ( x ) ∫ x ∈ X p ~ θ ( x ) d x p_\theta(x)=\frac{\tilde{p}\theta(x)}{Z\theta}=\frac{\tilde{p}\theta(x)}{\int{x\in X}\tilde{p}\theta(x)dx} pθ(x)=Zθp~θ(x)=∫x∈Xp~θ(x)dxp~θ(x),其中未归一化的概率密度函数 p θ ~ \tilde{p{\theta}} pθ~给出某个数据点 x x x相对于其他数据点的可能性大小,但并不能给出直接用于表示 x x x发生的真实概率
目前,使用极大似然估计求解 θ \theta θ的问题:不知道归一化因子 Z θ Z_{\theta} Zθ的值
解决方法:
- 引入得分函数(score function):概率密度函数的梯度 ∇ x log p θ ( x ) \nabla_x\log p_\theta(x) ∇xlogpθ(x)
- 将 p θ ( x ) p_{\theta}(x) pθ(x)通过 l o g log log拆分成两项 ∇ x log p ~ θ ( x ) − ∇ x log Z θ \nabla_{x}\log\tilde{p}{\theta}(x)-\nabla{x}\log Z_{\theta} ∇xlogp~θ(x)−∇xlogZθ ⇒ \Rightarrow ⇒由于求解的是 x x x的梯度,所以可以直接消掉 ∇ x log Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ,因为 ∇ x log Z θ \nabla_{x}\log Z_{\theta} ∇xlogZθ与 x x x无关;同时 p θ ~ \tilde{p_{\theta}} pθ~不受"概率分布"的约束,可以使用神经网络作为 p θ ~ \tilde{p_{\theta}} pθ~,因为 p θ ~ \tilde{p_{\theta}} pθ~本身就不是概率密度函数, p θ ~ \tilde{p_{\theta}} pθ~只是密度函数
- 目标:选择一个loss让 ∇ x log p θ ( x ) \nabla_x\log p_\theta(x) ∇xlogpθ(x)尽可能接近 ∇ x t log p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)
新的问题:不知道数据分布的 score function ∇ x t log p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)
为了简化公式,下面公式中的 ∇ x log p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x)等同于 ∇ x t log p ( x t ) \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) ∇xtlogp(xt)
解决方法:Score Matching
- Score Matching:用于估计概率密度函数的梯度(得分函数 score ( x ) = ∇ x t log p ( x t ) \operatorname{score}(x)=\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) score(x)=∇xtlogp(xt)),而无需知道密度函数的归一化常数
- Score Matching的目标:学习一个模型 q ( x ; θ ) q(x;\theta) q(x;θ),使得模型得分函数 ∇ x log q ( x ; θ ) \nabla_x\log q(x;\theta) ∇xlogq(x;θ)真实分布 p ( x ) p(x) p(x)的得分函数尽可能接近
- Score Matching的损失函数: L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2],其中的期望差异可以帮助模型更全面地学习到真实分布的特征
接下来,对Score Matching的损失函数 L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right] L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]进行推导:
- 展开欧几里得范数的平方项: ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 = ∥ ∇ x log q ( x ; θ ) ∥ 2 − 2 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) + ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\|^2=\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2 ∥∇xlogq(x;θ)−∇xlogp(x)∥2=∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2
- 将上式代入原始损失函数中可得 L ( θ ) = E x ∼ p ( x ) [ 1 2 ( ∥ ∇ x log q ( x ; θ ) ∥ 2 − 2 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) + ∥ ∇ x log p ( x ) ∥ 2 ) ] L(\theta)=\mathbb{E}_{x\sim p(x)}\left[\frac12\left(\|\nabla_x\log q(x;\theta)\|^2-2\nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x)+\|\nabla_x\log p(x)\|^2\right)\right] L(θ)=Ex∼p(x)[21(∥∇xlogq(x;θ)∥2−2∇xlogq(x;θ)⋅∇xlogp(x)+∥∇xlogp(x)∥2)]
- 消除不可计算的项:由于不知道真实分布的 ∇ x log p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x),我们无法直接计算 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) \nabla_x\log q(x;\theta)\cdot\nabla_x\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)
接下来,笔者给出如何消除不可计算项的过程:
- 由于 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2和 θ \theta θ无关,它仅仅依赖于真实数据分布 p ( x ) p(x) p(x),所以可以直接消掉 ∥ ∇ x log p ( x ) ∥ 2 \|\nabla_x\log p(x)\|^2 ∥∇xlogp(x)∥2
- 对损失函数中的项 ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) \nabla_{x}\log q(x;\theta)\cdot\nabla_{x}\log p(x) ∇xlogq(x;θ)⋅∇xlogp(x)进行分部积分 ⇒ \Rightarrow ⇒原始的积分形式为 ∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx,应用分部积分,将其中的 ∇ x log p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x)转移到 p ( x ) p(x) p(x)上,设 f ( x ) = p ( x ) f(\boldsymbol{x})=p(\boldsymbol{x}) f(x)=p(x)、 g ( x ) = ∇ x log q ( x ; θ ) g(\boldsymbol{x})=\nabla_x\log q(\boldsymbol{x};\theta) g(x)=∇xlogq(x;θ) ⇒ \Rightarrow ⇒ ∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx
∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x log p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int p(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\cdot\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅∇xlogp(x)dx=−∫p(x)∇x2logq(x;θ)dx的推导过程:
- ∫ p ( x ) ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) p ( x ) d x = ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x \int p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\frac{\nabla_{\boldsymbol{x}}p(\boldsymbol{x})}{p(\boldsymbol{x})} d\boldsymbol{x}=\int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x} ∫p(x)∇xlogq(x;θ)⋅p(x)∇xp(x)dx=∫∇xlogq(x;θ)⋅∇xp(x)dx
- 分部积分: ∫ A ⋅ ∇ x B d x = ∫ ∇ x ⋅ ( A B ) d x − ∫ B ∇ x ⋅ A d x \int\mathbf{A}\cdot\nabla_x\mathbf{B} d\boldsymbol{x}=\int\nabla_x\cdot\left(\mathbf{A}\mathbf{B}\right)d\boldsymbol{x}-\int\mathbf{B}\nabla_x\cdot\mathbf{A} d\boldsymbol{x} ∫A⋅∇xBdx=∫∇x⋅(AB)dx−∫B∇x⋅Adx,其中 A \mathbf{A} A和 B \mathbf{B} B是向量场(向量场:空间中的每个点,都有一个向量与之对应), A = ∇ x log q ( x ; θ ) \mathbf{A}=\nabla_{x}\log q(\boldsymbol{x};\theta) A=∇xlogq(x;θ)、 B = p ( x ) \mathbf{B}=p(\boldsymbol{x}) B=p(x)
- ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x = ∫ ∇ x ⋅ ( p ( x ) ∇ x log q ( x ; θ ) ) d x − ∫ p ( x ) ∇ x ⋅ ∇ x log q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x}) d\boldsymbol{x}=\int\nabla_x\cdot(p(\boldsymbol{x})\nabla_x\log q(\boldsymbol{x};\theta)) d\boldsymbol{x}-\int p(\boldsymbol{x})\nabla_x\cdot\nabla_x\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=∫∇x⋅(p(x)∇xlogq(x;θ))dx−∫p(x)∇x⋅∇xlogq(x;θ)dx
- 由于 ∇ x ⋅ ∇ x log q ( x ; θ ) = ∇ x 2 log q ( x ; θ ) \nabla_x\cdot\nabla_x\log q(\boldsymbol{x};\theta)=\nabla_x^2\log q(\boldsymbol{x};\theta) ∇x⋅∇xlogq(x;θ)=∇x2logq(x;θ)可知 ⇒ \Rightarrow ⇒ ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x = ∫ ∇ x ⋅ ( p ( x ) ∇ x log q ( x ; θ ) ) d x − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int\nabla_{x}\log q(\boldsymbol{x};\theta)\cdot\nabla_{x}p(\boldsymbol{x}) d\boldsymbol{x}=\int\nabla_{x}\cdot\left(p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\right)d\boldsymbol{x}-\int p(\boldsymbol{x})\nabla_{x}^{2}\log q(\boldsymbol{x};\theta) d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=∫∇x⋅(p(x)∇xlogq(x;θ))dx−∫p(x)∇x2logq(x;θ)dx
- 假设该边界项 ∫ ∇ x ⋅ ( p ( x ) ∇ x log q ( x ; θ ) ) d x \int\nabla_{x}\cdot\left(p(\boldsymbol{x})\nabla_{x}\log q(\boldsymbol{x};\theta)\right)d\boldsymbol{x} ∫∇x⋅(p(x)∇xlogq(x;θ))dx在适当的条件下为零,则 ∫ ∇ x log q ( x ; θ ) ⋅ ∇ x p ( x ) d x = − ∫ p ( x ) ∇ x 2 log q ( x ; θ ) d x \int\nabla_x\log q(\boldsymbol{x};\theta)\cdot\nabla_xp(\boldsymbol{x})d\boldsymbol{x}=-\int p(\boldsymbol{x})\nabla_x^2\log q(\boldsymbol{x};\theta)d\boldsymbol{x} ∫∇xlogq(x;θ)⋅∇xp(x)dx=−∫p(x)∇x2logq(x;θ)dx
- L ( θ ) = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) − ∇ x log p ( x ) ∥ 2 ] = E x ∼ p ( x ) [ 1 2 ∥ ∇ x log q ( x ; θ ) ∥ 2 + ∇ x 2 log q ( x ; θ ) ] \begin{aligned} L(\theta) &=\mathbb{E}{x\sim p(x)}\left[\frac12\left\|\nabla_x\log q(x;\theta)-\nabla_x\log p(x)\right\|^2\right]\\ &=\mathbb{E}{\boldsymbol{x}\sim p(\boldsymbol{x})}\left[\frac{1}{2}\|\nabla_{\boldsymbol{x}}\log q(\boldsymbol{x};\theta)\|^2+\nabla_{\boldsymbol{x}}^2\log q(\boldsymbol{x};\theta)\right] \end{aligned} L(θ)=Ex∼p(x)[21∥∇xlogq(x;θ)−∇xlogp(x)∥2]=Ex∼p(x)[21∥∇xlogq(x;θ)∥2+∇x2logq(x;θ)]
至此,我们可以通过损失函数 L ( θ ) L(\theta) L(θ)使 ∇ x log q ( x ; θ ) \nabla_x\log q(x;\theta) ∇xlogq(x;θ)接近 ∇ x log p ( x ) \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) ∇xlogp(x),进而求出SDE去噪过程中的 d x dx dx
笔者也是刚刚接触SDE,如果文中出现错误,请各位读者指正
参考文献
2、SDE公式推导
3、SDE的底层原理