22-扩散模型完全指南:从直觉到数学的深度解析
引言
扩散模型(Diffusion Models)是当前最先进的生成模型之一,在图像生成领域取得了超越GAN的效果。从DALL-E 2到Stable Diffusion,从Midjourney到Imagen,这些震撼世界的AI艺术工具背后,都是扩散模型的身影。
本文目标:从零开始,用最详细的方式介绍扩散模型的完整原理------从最初的直觉理解,到严格的数学推导,再到实际的训练和采样过程。
适合人群:具备基础的概率论知识(高斯分布、条件概率)和深度学习基础的读者。
第一部分:直觉理解 - 扩散模型在做什么?
从一滴墨水说起
想象一个实验:在一杯清水中滴入一滴墨水。
观察到的现象:
- 0秒:墨水是一个清晰的黑点
- 10秒:墨水开始扩散,边缘模糊
- 30秒:墨水继续扩散,形状变得不规则
- 5分钟:墨水充分扩散,整杯水变成均匀的淡灰色
这个过程就是扩散(Diffusion):有序的结构逐渐变成无序的均匀分布。
物理本质:墨水分子受到随机的布朗运动,从高浓度区域向低浓度区域移动,最终达到熵最大的状态------均匀分布。
逆转时间:从噪声到图像
扩散模型的核心想法:如果我们能逆转扩散过程,就能从噪声生成有序的结构。
类比:
- 正向扩散:清晰图像 → 加噪声 → 加噪声 → ... → 纯噪声
- 反向扩散:纯噪声 → 去噪 → 去噪 → ... → 清晰图像
但问题是:物理上的扩散是不可逆的(熵增原理)。你不可能把那杯灰水变回一滴墨水。
扩散模型的突破 :虽然物理上不可逆,但我们可以训练一个神经网络学习逆过程。
为什么要这么做?
生成模型的目标:从随机噪声生成真实的数据(图像、音频、文本等)。
传统GAN的思路:
- 生成器:噪声 → 图像(一步到位)
- 判别器:判断图像真假
- 问题:训练不稳定,模式崩溃
扩散模型的思路:
- 不一步到位,而是逐步去噪
- 每一步只需要预测"噪声的一小部分"
- 分解了复杂问题,训练更稳定
类比:
- GAN:要求你一笔画出蒙娜丽莎(难!)
- 扩散模型:给你一张模糊的蒙娜丽莎,让你一点点清晰化(容易得多!)
扩散模型的三个核心问题
理解扩散模型,需要回答三个问题:
问题1:如何定义"逐步加噪"的过程?
- 答案:设计一个马尔可夫链,每步加一点高斯噪声
问题2:如何学习"逐步去噪"的过程?
- 答案:训练神经网络预测每一步添加的噪声
问题3:如何从噪声生成图像?
- 答案:从纯噪声开始,逐步运行去噪网络
接下来,我们将用严格的数学语言回答这三个问题。
第二部分:前向扩散过程 - 从图像到噪声的数学描述
马尔可夫链:离散的扩散过程
扩散模型将连续的扩散过程离散化成 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 步(通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1000 T=1000 </math>T=1000)。
定义 :给定原始数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0(例如一张图像),前向过程定义为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x 0 → x 1 → x 2 → ⋯ → x T x_0 \to x_1 \to x_2 \to \cdots \to x_T </math>x0→x1→x2→⋯→xT
其中每一步都是一个马尔可夫过程 : <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 只依赖于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1,与更早的状态无关。
单步转移:加一点高斯噪声
核心公式 :从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1 到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 的转移分布定义为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) </math>q(xt∣xt−1)=N(xt;1−βt xt−1,βtI)
符号说明:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) </math>N(μ,σ2):均值为 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ \mu </math>μ,方差为 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ 2 \sigma^2 </math>σ2 的高斯分布
- <math xmlns="http://www.w3.org/1998/Math/MathML"> β t ∈ ( 0 , 1 ) \beta_t \in (0, 1) </math>βt∈(0,1):第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 步的噪声强度(可学习或预定义)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I:单位矩阵(各维度独立加噪)
直观理解:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 以 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − β t x t − 1 \sqrt{1-\beta_t} x_{t-1} </math>1−βt xt−1 为中心(保留部分原始信号)
- 同时加上方差为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 的随机噪声
采样公式(重参数化):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t = 1 − β t x t − 1 + β t ε t − 1 x_t = \sqrt{1-\beta_t} x_{t-1} + \sqrt{\beta_t} \varepsilon_{t-1} </math>xt=1−βt xt−1+βt εt−1
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε t − 1 ∼ N ( 0 , I ) \varepsilon_{t-1} \sim \mathcal{N}(0, I) </math>εt−1∼N(0,I)。
为什么要乘 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − β t \sqrt{1-\beta_t} </math>1−βt ?
这是为了保持方差稳定 。如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 的方差为1,这个系数保证所有 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 的方差也为1,避免数值爆炸或消失。
噪声调度: <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 的设计
<math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 控制每一步的噪声强度。常见的调度策略:
1. 线性调度(Linear Schedule)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β t = β min + t − 1 T − 1 ( β max − β min ) \beta_t = \beta_{\min} + \frac{t-1}{T-1}(\beta_{\max} - \beta_{\min}) </math>βt=βmin+T−1t−1(βmax−βmin)
典型值: <math xmlns="http://www.w3.org/1998/Math/MathML"> β min = 0.0001 , β max = 0.02 \beta_{\min} = 0.0001, \beta_{\max} = 0.02 </math>βmin=0.0001,βmax=0.02。
特点:
- 早期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t小): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 很小,加噪缓慢
- 后期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t大): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 接近 <math xmlns="http://www.w3.org/1998/Math/MathML"> β max \beta_{\max} </math>βmax,加噪快速
2. 余弦调度(Cosine Schedule)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> α ˉ t = f ( t ) f ( 0 ) , f ( t ) = cos ( t / T + s 1 + s ⋅ π 2 ) 2 \bar{\alpha}_t = \frac{f(t)}{f(0)}, \quad f(t) = \cos\left(\frac{t/T + s}{1+s} \cdot \frac{\pi}{2}\right)^2 </math>αˉt=f(0)f(t),f(t)=cos(1+st/T+s⋅2π)2
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> s = 0.008 s=0.008 </math>s=0.008 是偏移量。
优势:避免了线性调度在开始和结束时的突变,训练更稳定。
累积效应:从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 直接跳到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt
逐步加噪可以合并成一步。定义累积参数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> α t = 1 − β t , α ˉ t = ∏ i = 1 t α i = ∏ i = 1 t ( 1 − β i ) \alpha_t = 1 - \beta_t, \quad \bar{\alpha}t = \prod{i=1}^{t} \alpha_i = \prod_{i=1}^{t}(1-\beta_i) </math>αt=1−βt,αˉt=i=1∏tαi=i=1∏t(1−βi)
核心定理 (可直接从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I) </math>q(xt∣x0)=N(xt;αˉt x0,(1−αˉt)I)
重参数化形式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t = α ˉ t x 0 + 1 − α ˉ t ε , ε ∼ N ( 0 , I ) x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I) </math>xt=αˉt x0+1−αˉt ε,ε∼N(0,I)
意义 :训练时可以一步生成任意时刻的 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt,无需逐步迭代1000次。
终点:纯噪声
当 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = T t = T </math>t=T 时(通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1000 T=1000 </math>T=1000), <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ T ≈ 0 \bar{\alpha}_T \approx 0 </math>αˉT≈0,因此:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x T ≈ N ( 0 , I ) x_T \approx \mathcal{N}(0, I) </math>xT≈N(0,I)
即 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T x_T </math>xT 近似为标准高斯噪声,原始图像信息完全被破坏。
第三部分:反向扩散过程 - 从噪声到图像的逆转
反向过程的目标
我们想要逆转前向过程 ,即从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T x_T </math>xT 恢复 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x T → x T − 1 → x T − 2 → ⋯ → x 0 x_T \to x_{T-1} \to x_{T-2} \to \cdots \to x_0 </math>xT→xT−1→xT−2→⋯→x0
关键问题 :如何定义 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) </math>q(xt−1∣xt)(反向转移分布)?
贝叶斯公式的尝试
根据贝叶斯定理:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ) q ( x t ) q(x_{t-1}|x_t) = \frac{q(x_t|x_{t-1})q(x_{t-1})}{q(x_t)} </math>q(xt−1∣xt)=q(xt)q(xt∣xt−1)q(xt−1)
困难:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) </math>q(xt∣xt−1) 已知(前向过程)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ) q(x_{t-1}) </math>q(xt−1) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ) q(x_t) </math>q(xt) 是数据的边缘分布,未知且难以计算
结论 :无法直接解析地计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) </math>q(xt−1∣xt)。
条件反向分布:已知 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 的情况
如果我们知道原始数据 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0,则可以计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) </math>q(xt−1∣xt,x0)
根据贝叶斯定理:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(x_{t-1}|x_t, x_0) = \frac{q(x_t|x_{t-1}, x_0)q(x_{t-1}|x_0)}{q(x_t|x_0)} </math>q(xt−1∣xt,x0)=q(xt∣x0)q(xt∣xt−1,x0)q(xt−1∣x0)
由于马尔可夫性, <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ∣ x t − 1 , x 0 ) = q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}, x_0) = q(x_t|x_{t-1}) </math>q(xt∣xt−1,x0)=q(xt∣xt−1):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(x_{t-1}|x_t, x_0) = \frac{q(x_t|x_{t-1})q(x_{t-1}|x_0)}{q(x_t|x_0)} </math>q(xt−1∣xt,x0)=q(xt∣x0)q(xt∣xt−1)q(xt−1∣x0)
已知的分布:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t I ) q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t}x_{t-1}, \beta_t I) </math>q(xt∣xt−1)=N(xt;αt xt−1,βtI)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 , ( 1 − α ˉ t − 1 ) I ) q(x_{t-1}|x_0) = \mathcal{N}(x_{t-1}; \sqrt{\bar{\alpha}{t-1}}x_0, (1-\bar{\alpha}{t-1})I) </math>q(xt−1∣x0)=N(xt−1;αˉt−1 x0,(1−αˉt−1)I)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I) </math>q(xt∣x0)=N(xt;αˉt x0,(1−αˉt)I)
后验分布的推导
将三个高斯分布代入贝叶斯公式,利用高斯分布的乘积仍是高斯分布,可以推导出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) , β ~ t I ) q(x_{t-1}|x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) </math>q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)
后验均值:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ~ t ( x t , x 0 ) = α ˉ t − 1 β t 1 − α ˉ t x 0 + α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t \tilde{\mu}t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}{t-1}}\beta_t}{1-\bar{\alpha}t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}{t-1})}{1-\bar{\alpha}_t}x_t </math>μ~t(xt,x0)=1−αˉtαˉt−1 βtx0+1−αˉtαt (1−αˉt−1)xt
后验方差:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}t = \frac{1-\bar{\alpha}{t-1}}{1-\bar{\alpha}_t}\beta_t </math>β~t=1−αˉt1−αˉt−1βt
推导思路:将三个高斯分布的对数相加减,配方得到新的高斯分布的均值和方差(详细推导见DDPM论文附录)。
用 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε \varepsilon </math>ε 重参数化
从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon </math>xt=αˉt x0+1−αˉt ε 可以解出:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x 0 = x t − 1 − α ˉ t ε α ˉ t x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\varepsilon}{\sqrt{\bar{\alpha}_t}} </math>x0=αˉt xt−1−αˉt ε
代入 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ ~ t \tilde{\mu}_t </math>μ~t 公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ~ t ( x t , x 0 ) = 1 α t ( x t − β t 1 − α ˉ t ε ) \tilde{\mu}_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon\right) </math>μ~t(xt,x0)=αt 1(xt−1−αˉt βtε)
关键洞察 :如果我们能预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε \varepsilon </math>ε,就能计算后验均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ ~ t \tilde{\mu}t </math>μ~t,从而采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x{t-1} </math>xt−1。
第四部分:训练目标 - 学习逆过程
定义神经网络逼近反向过程
由于无法直接计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) </math>q(xt−1∣xt),我们用神经网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ p_\theta </math>pθ 来逼近:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) </math>pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ 是网络参数。
简化假设(DDPM的选择):方差固定为后验方差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Σ θ ( x t , t ) = β ~ t I \Sigma_\theta(x_t, t) = \tilde{\beta}_t I </math>Σθ(xt,t)=β~tI
因此只需学习均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t, t) </math>μθ(xt,t)。
变分下界(ELBO)
训练目标是最大化对数似然 <math xmlns="http://www.w3.org/1998/Math/MathML"> log p θ ( x 0 ) \log p_\theta(x_0) </math>logpθ(x0)。通过变分推断推导出ELBO(Evidence Lower Bound):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L ELBO = E q [ D K L ( q ( x T ∣ x 0 ) ∥ p ( x T ) ) + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) − log p θ ( x 0 ∣ x 1 ) ] \mathcal{L}{\text{ELBO}} = \mathbb{E}q\left[D{KL}(q(x_T|x_0) \| p(x_T)) + \sum{t=2}^{T}D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t)) - \log p_\theta(x_0|x_1)\right] </math>LELBO=Eq[DKL(q(xT∣x0)∥p(xT))+t=2∑TDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))−logpθ(x0∣x1)]
核心项 : <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) D_{KL}(q(x_{t-1}|x_t,x_0) \| p_\theta(x_{t-1}|x_t)) </math>DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt)) ------ 真实后验与模型后验的KL散度。
由于两者都是高斯分布且方差相同,KL散度简化为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L t − 1 = E q [ 1 2 β ~ t ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] L_{t-1} = \mathbb{E}_q\left[\frac{1}{2\tilde{\beta}_t}\|\tilde{\mu}t(x_t, x_0) - \mu\theta(x_t, t)\|^2\right] </math>Lt−1=Eq[2β~t1∥μ~t(xt,x0)−μθ(xt,t)∥2]
目标 :让网络预测的均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ \mu_\theta </math>μθ 接近真实后验均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ ~ t \tilde{\mu}_t </math>μ~t。
从预测均值到预测噪声
回顾重参数化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ~ t ( x t , x 0 ) = 1 α t ( x t − β t 1 − α ˉ t ε ) \tilde{\mu}_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\varepsilon\right) </math>μ~t(xt,x0)=αt 1(xt−1−αˉt βtε)
我们参数化 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ \mu_\theta </math>μθ 为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ε θ ( x t , t ) ) \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}t}}\varepsilon\theta(x_t, t)\right) </math>μθ(xt,t)=αt 1(xt−1−αˉt βtεθ(xt,t))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ \varepsilon_\theta </math>εθ 是神经网络,预测噪声。
代入损失函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L t − 1 = E x 0 , ε [ β t 2 2 β ~ t α t ( 1 − α ˉ t ) ∥ ε − ε θ ( x t , t ) ∥ 2 ] L_{t-1} = \mathbb{E}_{x_0, \varepsilon}\left[\frac{\beta_t^2}{2\tilde{\beta}_t\alpha_t(1-\bar{\alpha}t)}\|\varepsilon - \varepsilon\theta(x_t, t)\|^2\right] </math>Lt−1=Ex0,ε[2β~tαt(1−αˉt)βt2∥ε−εθ(xt,t)∥2]
简化损失函数
DDPM论文发现,去掉权重系数后训练效果更好:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L simple = E t , x 0 , ε [ ∥ ε − ε θ ( x t , t ) ∥ 2 ] L_{\text{simple}} = \mathbb{E}{t, x_0, \varepsilon}\left[\|\varepsilon - \varepsilon\theta(x_t, t)\|^2\right] </math>Lsimple=Et,x0,ε[∥ε−εθ(xt,t)∥2]
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> t ∼ Uniform ( 1 , T ) t \sim \text{Uniform}(1, T) </math>t∼Uniform(1,T)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 ∼ q ( x 0 ) x_0 \sim q(x_0) </math>x0∼q(x0)(数据分布)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) </math>ε∼N(0,I)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon </math>xt=αˉt x0+1−αˉt ε
最终训练目标 :让神经网络预测前向过程中添加的噪声。
直观理解:
- 给网络一个加噪图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t
- 网络预测这张图像的噪声成分 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t ) \varepsilon_\theta(x_t, t) </math>εθ(xt,t)
- 与真实噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε \varepsilon </math>ε 对比,计算均方误差
第五部分:噪声预测网络的架构 - U-Net详解
前面我们知道了扩散模型需要一个神经网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t ) \varepsilon_\theta(x_t, t) </math>εθ(xt,t) 来预测噪声。但这个网络到底长什么样?由哪些层组成?信息如何流动?
这部分详细拆解扩散模型最核心的网络架构:U-Net。
U-Net的整体结构:编码器-解码器架构
U-Net是扩散模型(包括DDPM、Stable Diffusion等)的标准backbone。它的名字来源于其U形结构。
整体架构:
c
输入: 加噪图像 x_t (256×256×3) + 时间步 t
x_t
↓
┌───────────────────────┐
│ Initial Conv │ → 256×256×128
└───────────────────────┘
↓
┌───────────────────────┐
┌───│ Encoder Block 1 │ → 128×128×256 ← 下采样
│ └───────────────────────┘
│ ↓
│ ┌───────────────────────┐
├───│ Encoder Block 2 │ → 64×64×512 ← 下采样
│ └───────────────────────┘
│ ↓
│ ┌───────────────────────┐
├───│ Encoder Block 3 │ → 32×32×1024 ← 下采样
│ └───────────────────────┘
│ ↓
│ ┌───────────────────────┐
│ │ Bottleneck (中间层) │ → 32×32×1024
│ └───────────────────────┘
│ ↓
│ ┌───────────────────────┐
└──→│ Decoder Block 1 │ → 64×64×512 ← 上采样 + 跳跃连接
└───────────────────────┘
↓
┌───────────────────────┐
┌──→│ Decoder Block 2 │ → 128×128×256 ← 上采样 + 跳跃连接
│ └───────────────────────┘
│ ↓
│ ┌───────────────────────┐
└──→│ Decoder Block 3 │ → 256×256×128 ← 上采样 + 跳跃连接
└───────────────────────┘
↓
┌───────────────────────┐
│ Output Conv │ → 256×256×3
└───────────────────────┘
↓
预测的噪声 ε̂
关键特征:
- 对称结构:编码器逐步降低分辨率,解码器逐步恢复分辨率
- 跳跃连接(Skip Connections):编码器的特征直接连到对应的解码器层
- 时间嵌入 :时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 通过嵌入向量注入到每一层
组件一:ResNet Block(残差块)
U-Net的基本构建单元是ResNet Block,每个Encoder/Decoder Block包含多个ResNet Block。
单个ResNet Block的结构:
ini
输入特征 h (H×W×C)
↓
┌─────────────────────────┐
│ GroupNorm + SiLU │ ← 归一化 + 激活
└─────────────────────────┘
↓
┌─────────────────────────┐
│ Conv 3×3 │ ← 卷积
└─────────────────────────┘
↓
┌─────────────────────────┐
│ 时间嵌入注入: │
│ h = h + Linear(t_emb) │ ← 加入时间信息
└─────────────────────────┘
↓
┌─────────────────────────┐
│ GroupNorm + SiLU │
└─────────────────────────┘
↓
┌─────────────────────────┐
│ Conv 3×3 │
└─────────────────────────┘
↓
+ ← 残差连接(与输入相加)
↓
输出特征 h'
数学表达:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 = Conv ( SiLU ( GN ( h ) ) ) + Linear ( t emb ) h_1 = \text{Conv}(\text{SiLU}(\text{GN}(h))) + \text{Linear}(t_{\text{emb}}) </math>h1=Conv(SiLU(GN(h)))+Linear(temb)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 2 = Conv ( SiLU ( GN ( h 1 ) ) ) h_2 = \text{Conv}(\text{SiLU}(\text{GN}(h_1))) </math>h2=Conv(SiLU(GN(h1)))
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> output = h + h 2 (残差连接) \text{output} = h + h_2 \quad \text{(残差连接)} </math>output=h+h2(残差连接)
为什么用ResNet Block?
- 梯度流畅:残差连接使得梯度可以直接反向传播,避免梯度消失
- 深度网络:可以堆叠很多层(U-Net通常有20-30层)
组件二:时间嵌入(Time Embedding)
时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 是扩散模型的关键输入,网络需要知道"当前是第几步"才能预测对应的噪声。
时间嵌入的流程:
scss
时间步 t (标量, 如 t=500)
↓
┌─────────────────────────────┐
│ Sinusoidal Position Encoding│ ← 将t编码成高维向量
│ 类似Transformer的位置编码 │
└─────────────────────────────┘
↓
时间向量 t_vec (128维)
↓
┌─────────────────────────────┐
│ Linear (128 → 512) │ ← 两层MLP扩展维度
│ SiLU │
│ Linear (512 → 512) │
└─────────────────────────────┘
↓
时间嵌入 t_emb (512维)
↓
注入到每个ResNet Block
正弦位置编码的公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE ( t , 2 i ) = sin ( t 1000 0 2 i / d ) \text{PE}(t, 2i) = \sin\left(\frac{t}{10000^{2i/d}}\right) </math>PE(t,2i)=sin(100002i/dt)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE ( t , 2 i + 1 ) = cos ( t 1000 0 2 i / d ) \text{PE}(t, 2i+1) = \cos\left(\frac{t}{10000^{2i/d}}\right) </math>PE(t,2i+1)=cos(100002i/dt)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 是维度索引, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是总维度(如128)。
为什么用正弦编码?
- 连续性:相邻时间步的编码相近
- 周期性:能表达时间的周期模式
- 泛化性:没有可学习参数,避免过拟合
时间嵌入的注入方式:
在每个ResNet Block中,时间嵌入通过加法注入到特征中:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h = h + Linear ( t emb ) h = h + \text{Linear}(t_{\text{emb}}) </math>h=h+Linear(temb)
这里的Linear层将512维的 <math xmlns="http://www.w3.org/1998/Math/MathML"> t emb t_{\text{emb}} </math>temb 投影到当前特征的通道数 <math xmlns="http://www.w3.org/1998/Math/MathML"> C C </math>C。
组件三:下采样(Downsampling)
编码器通过下采样逐步减小空间分辨率,增加通道数。
两种下采样方式:
方式1:步长为2的卷积
ini
输入: 128×128×256
↓
┌────────────────────┐
│ Conv 3×3, stride=2 │ ← 卷积核大小3×3,步长2
└────────────────────┘
↓
输出: 64×64×512
方式2:平均池化
makefile
输入: 128×128×256
↓
┌────────────────────┐
│ AvgPool 2×2 │ ← 2×2平均池化
└────────────────────┘
↓
输出: 64×64×256
↓
┌────────────────────┐
│ Conv 1×1 │ ← 1×1卷积调整通道数
└────────────────────┘
↓
输出: 64×64×512
为什么下采样?
- 增大感受野:低分辨率特征包含更全局的信息
- 计算效率:低分辨率特征计算量小
- 多尺度表示:不同分辨率捕获不同尺度的模式
组件四:上采样(Upsampling)
解码器通过上采样逐步恢复空间分辨率。
两种上采样方式:
方式1:转置卷积(Transposed Convolution)
ini
输入: 64×64×512
↓
┌────────────────────────┐
│ ConvTranspose2d │ ← 反卷积
│ kernel=3, stride=2 │
└────────────────────────┘
↓
输出: 128×128×256
方式2:最近邻插值 + 卷积(更常用)
makefile
输入: 64×64×512
↓
┌────────────────────────┐
│ Upsample (x2) │ ← 最近邻或双线性插值
│ 64×64 → 128×128 │
└────────────────────────┘
↓
┌────────────────────────┐
│ Conv 3×3 │ ← 卷积平滑插值产生的伪影
└────────────────────────┘
↓
输出: 128×128×256
为什么用插值+卷积?
- 转置卷积容易产生棋盘伪影(checkerboard artifacts)
- 插值+卷积效果更平滑
组件五:跳跃连接(Skip Connections)
这是U-Net的核心创新,也是"U"形名称的来源。
跳跃连接的作用:
编码器的特征直接传递给解码器的对应层:
scss
Encoder Block 1 (128×128×256) ─────┐
│
├→ Concatenate
│
Decoder Block 3 (128×128×256) ←────┘
具体操作:
解码器在上采样后,将编码器的对应特征拼接到当前特征:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h decoder = Concat ( h upsampled , h encoder ) h_{\text{decoder}} = \text{Concat}(h_{\text{upsampled}}, h_{\text{encoder}}) </math>hdecoder=Concat(hupsampled,hencoder)
例如:
- 上采样后: <math xmlns="http://www.w3.org/1998/Math/MathML"> 128 × 128 × 256 128 \times 128 \times 256 </math>128×128×256
- 编码器特征: <math xmlns="http://www.w3.org/1998/Math/MathML"> 128 × 128 × 256 128 \times 128 \times 256 </math>128×128×256
- 拼接后: <math xmlns="http://www.w3.org/1998/Math/MathML"> 128 × 128 × 512 128 \times 128 \times 512 </math>128×128×512(通道数翻倍)
为什么需要跳跃连接?
- 保留细节:编码器的高分辨率特征包含细粒度信息(边缘、纹理)
- 梯度流动:提供额外的梯度路径,缓解梯度消失
- 多尺度融合:结合低层次(细节)和高层次(语义)特征
组件六:注意力层(Attention Layer)
在较低分辨率的层(如32×32、16×16),U-Net会插入自注意力层。
自注意力的位置:
scss
Encoder Block 3 (32×32×1024)
↓
┌────────────────────────┐
│ ResNet Block │
└────────────────────────┘
↓
┌────────────────────────┐
│ Self-Attention │ ← 注意力层
└────────────────────────┘
↓
┌────────────────────────┐
│ ResNet Block │
└────────────────────────┘
Self-Attention的结构:
ini
输入特征 h (32×32×1024)
↓
展平成序列: (1024, 1024) ← 1024个token,每个1024维
↓
┌────────────────────────┐
│ Linear投影: Q, K, V │
│ Q = W_Q · h │
│ K = W_K · h │
│ V = W_V · h │
└────────────────────────┘
↓
┌────────────────────────┐
│ Attention(Q, K, V) │
│ = softmax(QK^T/√d)·V │
└────────────────────────┘
↓
┌────────────────────────┐
│ 输出投影 │
└────────────────────────┘
↓
+ ← 残差连接
↓
输出特征 h'
数学表达:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V </math>Attention(Q,K,V)=softmax(d QKT)V
为什么要加注意力层?
- 全局依赖:卷积只能捕获局部依赖(受感受野限制),注意力可以建模全局关系
- 长距离交互:图像左上角的猫耳朵可以"看到"右下角的猫尾巴
- 只在低分辨率使用 :因为注意力复杂度是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2),高分辨率计算量太大
条件注入:如何加入类别或文本?
扩散模型需要支持条件生成(如"生成一只猫")。条件信息有多种注入方式:
方式1:时间嵌入融合(类别条件)
ini
类别标签 c (如 c=281 表示"猫")
↓
┌────────────────────────┐
│ Embedding Layer │ ← 类别ID → 类别向量
└────────────────────────┘
↓
类别嵌入 c_emb (512维)
↓
┌────────────────────────┐
│ 与时间嵌入相加: │
│ emb = t_emb + c_emb │
└────────────────────────┘
↓
注入到ResNet Block
方式2:Cross-Attention(文本条件)
css
文本 "a photo of a cat"
↓
┌────────────────────────┐
│ CLIP Text Encoder │ ← 预训练的文本编码器
└────────────────────────┘
↓
文本嵌入序列 c (77×768) ← 77个token,每个768维
↓
在U-Net的每层插入Cross-Attention:
图像特征 h (32×32×1024)
↓
展平: (1024, 1024)
↓
┌────────────────────────┐
│ Q = W_Q · h │ ← Query来自图像
│ K = W_K · c │ ← Key来自文本
│ V = W_V · c │ ← Value来自文本
└────────────────────────┘
↓
┌────────────────────────┐
│ CrossAttention(Q,K,V) │ ← 图像"查询"文本信息
│ = softmax(QK^T/√d)·V │
└────────────────────────┘
↓
输出特征(融合了文本信息)
Cross-Attention vs Self-Attention:
| 类型 | Query来源 | Key/Value来源 | 作用 |
|---|---|---|---|
| Self-Attention | 图像特征 | 图像特征 | 图像内部不同位置交互 |
| Cross-Attention | 图像特征 | 文本嵌入 | 图像根据文本调整特征 |
完整的U-Net前向传播流程
把所有组件串起来,完整走一遍:
输入:
- 加噪图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt: <math xmlns="http://www.w3.org/1998/Math/MathML"> 256 × 256 × 3 256 \times 256 \times 3 </math>256×256×3
- 时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t: 标量(如500)
- 条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c: 类别或文本嵌入
第1步:时间嵌入
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> t → Sinusoidal t vec → MLP t emb ∈ R 512 t \xrightarrow{\text{Sinusoidal}} t_{\text{vec}} \xrightarrow{\text{MLP}} t_{\text{emb}} \in \mathbb{R}^{512} </math>tSinusoidal tvecMLP temb∈R512
第2步:Initial Conv
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t → Conv 3×3 h 0 ∈ R 256 × 256 × 128 x_t \xrightarrow{\text{Conv 3×3}} h_0 \in \mathbb{R}^{256 \times 256 \times 128} </math>xtConv 3×3 h0∈R256×256×128
第3步:编码器(下采样路径)
scss
h_0 (256×256×128)
↓ ResNet Block × 2 + Attention
h_1 (256×256×128) ─────────┐ (保存用于跳跃连接)
↓ Downsample │
h_2 (128×128×256) │
↓ ResNet Block × 2 │
h_3 (128×128×256) ─────────┼─┐
↓ Downsample │ │
h_4 (64×64×512) │ │
↓ ResNet Block × 2 │ │
h_5 (64×64×512) ───────────┼─┼─┐
↓ Downsample │ │ │
h_6 (32×32×1024) │ │ │
第4步:瓶颈层(Bottleneck)
scss
h_6 (32×32×1024)
↓ ResNet Block × 2
↓ Self-Attention
h_7 (32×32×1024)
第5步:解码器(上采样路径)
scss
h_7 (32×32×1024)
↓ Upsample
h_8 (64×64×1024)
↓ Concat(h_8, h_5) ← 跳跃连接
h_9 (64×64×1536)
↓ ResNet Block × 3
h_10 (64×64×512)
↓ Upsample
h_11 (128×128×512)
↓ Concat(h_11, h_3) ← 跳跃连接
h_12 (128×128×768)
↓ ResNet Block × 3
h_13 (128×128×256)
↓ Upsample
h_14 (256×256×256)
↓ Concat(h_14, h_1) ← 跳跃连接
h_15 (256×256×384)
↓ ResNet Block × 3
h_16 (256×256×128)
第6步:输出投影
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 16 → GroupNorm → SiLU → Conv 3×3 ε ^ ∈ R 256 × 256 × 3 h_{16} \xrightarrow{\text{GroupNorm}} \xrightarrow{\text{SiLU}} \xrightarrow{\text{Conv 3×3}} \hat{\varepsilon} \in \mathbb{R}^{256 \times 256 \times 3} </math>h16GroupNorm SiLU Conv 3×3 ε^∈R256×256×3
输出 :预测的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,与输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 同样大小。
U-Net的参数量
一个典型的扩散U-Net(如Stable Diffusion的U-Net)参数量:
| 组件 | 参数量 |
|---|---|
| Initial Conv | 0.1M |
| Encoder Blocks | 300M |
| Bottleneck | 50M |
| Decoder Blocks | 300M |
| Attention Layers | 200M |
| 总计 | ~860M |
相比之下:
- ResNet-50: 25M参数
- ViT-Base: 86M参数
- GPT-2: 1.5B参数
U-Net的参数量介于CV和NLP模型之间。
为什么U-Net适合扩散模型?
1. 对称的编码-解码结构
- 编码器提取多尺度特征
- 解码器逐步恢复细节
- 非常适合"去噪"任务(从噪声恢复细节)
2. 跳跃连接保留细节
- 去噪需要精确恢复像素级细节
- 跳跃连接直接传递高分辨率特征
3. 多尺度处理
- 噪声在不同尺度上的表现不同
- U-Net自然地在多个尺度上操作
4. 已有的成功经验
- U-Net最初为医学图像分割设计
- 分割任务(像素级预测)与去噪任务(像素级预测)很相似
小结:U-Net的组成板块
扩散模型的噪声预测网络(U-Net)由以下板块组成:
- ResNet Block:基本构建单元,带残差连接
- 时间嵌入:将时间步编码并注入每一层
- 下采样模块:降低分辨率,增加通道数
- 上采样模块:恢复分辨率,减少通道数
- 跳跃连接:编码器特征直接传递给解码器
- 注意力层:在低分辨率层建模全局依赖
- 条件注入:通过时间嵌入融合或Cross-Attention加入条件
整体数据流:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> 输入 : x t , t , c 时间嵌入 : t → t emb 编码器 : x t → h 1 → h 2 → ⋯ → h 6 (逐步下采样) 瓶颈层 : h 6 → h 7 (最低分辨率) 解码器 : h 7 → h 8 → ⋯ → h 16 (逐步上采样 + 跳跃连接) 输出 : h 16 → ε ^ \begin{align} \text{输入} &: x_t, t, c \\ \text{时间嵌入} &: t \to t_{\text{emb}} \\ \text{编码器} &: x_t \to h_1 \to h_2 \to \cdots \to h_6 \quad \text{(逐步下采样)} \\ \text{瓶颈层} &: h_6 \to h_7 \quad \text{(最低分辨率)} \\ \text{解码器} &: h_7 \to h_8 \to \cdots \to h_{16} \quad \text{(逐步上采样 + 跳跃连接)} \\ \text{输出} &: h_{16} \to \hat{\varepsilon} \end{align} </math>输入时间嵌入编码器瓶颈层解码器输出:xt,t,c:t→temb:xt→h1→h2→⋯→h6(逐步下采样):h6→h7(最低分辨率):h7→h8→⋯→h16(逐步上采样 + 跳跃连接):h16→ε^
现在你应该清楚扩散模型的网络"长什么样"了------它是一个U形的编码器-解码器结构,通过ResNet块、注意力层、跳跃连接和时间嵌入共同完成噪声预测任务。
第六部分:采样算法 - 从噪声生成图像
DDPM采样:严格的概率过程
训练完成后,我们有了噪声预测网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t ) \varepsilon_\theta(x_t, t) </math>εθ(xt,t)。如何用它生成图像?
反向采样过程:
从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) </math>xT∼N(0,I) 开始,逐步采样:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 ∼ p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , β ~ t I ) x_{t-1} \sim p_\theta(x_{t-1}|x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \tilde{\beta}_t I\right) </math>xt−1∼pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),β~tI)
采样公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = 1 α t ( x t − β t 1 − α ˉ t ε θ ( x t , t ) ) + β ~ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}t}}\varepsilon\theta(x_t, t)\right) + \sqrt{\tilde{\beta}_t}z </math>xt−1=αt 1(xt−1−αˉt βtεθ(xt,t))+β~t z
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) </math>z∼N(0,I)( <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 1 t=1 </math>t=1 时 <math xmlns="http://www.w3.org/1998/Math/MathML"> z = 0 z=0 </math>z=0)。
完整算法:
arduino
输入: 噪声预测网络 ε_θ
输出: 生成的图像 x_0
1. 采样 x_T ~ N(0, I)
2. for t = T, T-1, ..., 1 do
3. z ~ N(0, I) if t > 1 else z = 0
4. ε̂ = ε_θ(x_t, t)
5. x_{t-1} = 1/√α_t · (x_t - β_t/√(1-ᾱ_t) · ε̂) + √β̃_t · z
6. end for
7. return x_0
特点:
- 慢 :需要 <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1000 T=1000 </math>T=1000 步,每步都要前向传播网络
- 随机性 :每次采样噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,生成结果不同
- 质量高:严格遵循训练时的概率过程
方差的选择
后验方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ~ t \tilde{\beta}_t </math>β~t 有两个合理的选择:
选择1:后验方差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}t = \frac{1-\bar{\alpha}{t-1}}{1-\bar{\alpha}_t}\beta_t </math>β~t=1−αˉt1−αˉt−1βt
选择2:前向方差
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β ~ t = β t \tilde{\beta}_t = \beta_t </math>β~t=βt
DDPM使用选择1,IDDPM(Improved DDPM)发现学习方差插值效果更好。
DDIM采样:确定性加速
DDPM的问题:1000步太慢,能否加速?
DDIM的核心思想 :构造一个非马尔可夫的前向过程,使得反向过程可以跳步。
DDIM前向过程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q σ ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ⋅ x t − α ˉ t x 0 1 − α ˉ t , σ t 2 I ) q_\sigma(x_{t-1}|x_t, x_0) = \mathcal{N}\left(x_{t-1}; \sqrt{\bar{\alpha}{t-1}}x_0 + \sqrt{1-\bar{\alpha}{t-1}-\sigma_t^2} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t}x_0}{\sqrt{1-\bar{\alpha}_t}}, \sigma_t^2 I\right) </math>qσ(xt−1∣xt,x0)=N(xt−1;αˉt−1 x0+1−αˉt−1−σt2 ⋅1−αˉt xt−αˉt x0,σt2I)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t \sigma_t </math>σt 控制随机性:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} </math>σt=β~t :退化为DDPM
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t = 0 \sigma_t = 0 </math>σt=0:完全确定性
DDIM采样公式:
用神经网络预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ 0 = x t − 1 − α ˉ t ε θ ( x t , t ) α ˉ t \hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}t}\varepsilon\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} </math>x^0=αˉt xt−1−αˉt εθ(xt,t)
采样 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x ^ 0 + 1 − α ˉ t − 1 − σ t 2 ⋅ ε θ ( x t , t ) + σ t z x_{t-1} = \sqrt{\bar{\alpha}{t-1}}\hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}-\sigma_t^2} \cdot \varepsilon\theta(x_t, t) + \sigma_t z </math>xt−1=αˉt−1 x^0+1−αˉt−1−σt2 ⋅εθ(xt,t)+σtz
确定性情况 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t = 0 \sigma_t = 0 </math>σt=0):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x ^ 0 + 1 − α ˉ t − 1 ⋅ ε θ ( x t , t ) x_{t-1} = \sqrt{\bar{\alpha}{t-1}}\hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}} \cdot \varepsilon\theta(x_t, t) </math>xt−1=αˉt−1 x^0+1−αˉt−1 ⋅εθ(xt,t)
跳步采样:
由于确定性,可以只在子序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> { τ 1 , τ 2 , ... , τ S } ⊂ { 1 , 2 , ... , T } \{\tau_1, \tau_2, \ldots, \tau_S\} \subset \{1, 2, \ldots, T\} </math>{τ1,τ2,...,τS}⊂{1,2,...,T} 上采样。
例如: <math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1000 , S = 50 T=1000, S=50 </math>T=1000,S=50,选择 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ = [ 20 , 40 , 60 , ... , 1000 ] \tau = [20, 40, 60, \ldots, 1000] </math>τ=[20,40,60,...,1000]。
算法:
ini
输入: 噪声预测网络 ε_θ, 时间步序列 τ = [τ_1, ..., τ_S]
输出: 生成的图像 x_0
1. 采样 x_{τ_S} ~ N(0, I)
2. for i = S, S-1, ..., 1 do
3. t = τ_i
4. t_prev = τ_{i-1} if i > 1 else 0
5. ε̂ = ε_θ(x_t, t)
6. x̂_0 = (x_t - √(1-ᾱ_t)·ε̂) / √ᾱ_t
7. x_{t_prev} = √ᾱ_{t_prev}·x̂_0 + √(1-ᾱ_{t_prev})·ε̂
8. end for
9. return x_0
特点:
- 快:50步即可,速度提升20倍
- 确定性:相同初始噪声产生相同结果
- 质量略降:FID略高于DDPM,但肉眼难以区分
为什么DDIM可以跳步?
直观理解:
DDPM的随机性累积,跳步会导致误差放大。DDIM去除随机性,每步都是确定性的轨迹,因此可以只在关键点采样。
数学角度:
DDIM的更新公式可以看作ODE(常微分方程)的离散化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d x d t = f ( x , t ) \frac{dx}{dt} = f(x, t) </math>dtdx=f(x,t)
而ODE的解对初值连续依赖,因此可以用更大的步长(跳步)求解。
采样质量与速度的权衡
| 方法 | 步数 | 时间 | FID (ImageNet 256×256) | 特点 |
|---|---|---|---|---|
| DDPM | 1000 | 60s | 2.50 | 严格概率过程,质量最高 |
| DDIM | 50 | 3s | 2.80 | 确定性,速度快 |
| DDIM | 10 | 0.6s | 4.50 | 极速,质量下降 |
第六部分:条件生成与引导
类别条件生成
实际应用中,我们希望控制生成内容,例如"生成一只猫"。
类别条件扩散模型:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 是类别标签(如ImageNet的1000类)。
训练 :数据集包含 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x 0 , c ) (x_0, c) </math>(x0,c) 对,损失函数变为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , c , ε [ ∥ ε − ε θ ( x t , t , c ) ∥ 2 ] L = \mathbb{E}{t, x_0, c, \varepsilon}\left[\|\varepsilon - \varepsilon\theta(x_t, t, c)\|^2\right] </math>L=Et,x0,c,ε[∥ε−εθ(xt,t,c)∥2]
Classifier Guidance
问题:标准条件生成的类别一致性不够强。
解决方案:利用分类器梯度引导采样。
假设有一个分类器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ϕ ( c ∣ x t , t ) p_\phi(c|x_t, t) </math>pϕ(c∣xt,t),在采样时调整均值:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ~ t = μ θ ( x t , t , c ) + s ⋅ Σ θ ∇ x t log p ϕ ( c ∣ x t , t ) \tilde{\mu}t = \mu\theta(x_t, t, c) + s \cdot \Sigma_\theta \nabla_{x_t}\log p_\phi(c|x_t, t) </math>μ~t=μθ(xt,t,c)+s⋅Σθ∇xtlogpϕ(c∣xt,t)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s 是引导强度。
直观理解 :沿着"更像类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c"的方向调整。
缺点:需要额外训练分类器,且需要在噪声数据上训练。
Classifier-Free Guidance(CFG)
核心思想:不用分类器,而是训练时同时学习条件和无条件模型。
训练 :以概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p (通常10%) 将条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 置为空:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c ′ = { ∅ 概率 p c 概率 1 − p c' = \begin{cases} \emptyset & \text{概率 } p \\ c & \text{概率 } 1-p \end{cases} </math>c′={∅c概率 p概率 1−p
这样网络学会两种预测:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c):条件预测
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, \emptyset) </math>εθ(xt,t,∅):无条件预测
采样时的引导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε ~ θ ( x t , t , c ) = ε θ ( x t , t , ∅ ) + w ⋅ ( ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) ) \tilde{\varepsilon}\theta(x_t, t, c) = \varepsilon\theta(x_t, t, \emptyset) + w \cdot (\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset)) </math>ε~θ(xt,t,c)=εθ(xt,t,∅)+w⋅(εθ(xt,t,c)−εθ(xt,t,∅))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ≥ 1 w \geq 1 </math>w≥1 是引导尺度(guidance scale)。
改写为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε ~ θ = ( 1 − w ) ε θ ( x t , t , ∅ ) + w ⋅ ε θ ( x t , t , c ) \tilde{\varepsilon}\theta = (1-w)\varepsilon\theta(x_t, t, \emptyset) + w \cdot \varepsilon_\theta(x_t, t, c) </math>ε~θ=(1−w)εθ(xt,t,∅)+w⋅εθ(xt,t,c)
数学解释:
定义"条件方向"为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Δ = ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) \Delta = \varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset) </math>Δ=εθ(xt,t,c)−εθ(xt,t,∅)
则:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε ~ θ = ε θ ( x t , t , ∅ ) + w Δ \tilde{\varepsilon}\theta = \varepsilon\theta(x_t, t, \emptyset) + w\Delta </math>ε~θ=εθ(xt,t,∅)+wΔ
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 0 w = 0 </math>w=0:无条件生成
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 1 w = 1 </math>w=1:标准条件生成
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w > 1 w > 1 </math>w>1:放大条件影响
效果:
| <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w | 类别一致性 | 多样性 | 图像质量 |
|---|---|---|---|
| 1.0 | 低 | 高 | 中 |
| 3.0 | 中 | 中 | 好 |
| 7.5 | 高 | 低 | 最好 |
| 15.0 | 过高 | 过低 | 过饱和 |
代价:推理时需要两次前向传播(条件+无条件),速度减半。
文本条件生成
对于文本到图像(Text-to-Image),条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c 是文本嵌入。
CLIP嵌入:使用预训练的CLIP模型将文本编码为向量。
Cross-Attention注入:在U-Net的每层添加交叉注意力:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V </math>Attention(Q,K,V)=softmax(d QKT)V
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Q = W Q ⋅ h Q = W_Q \cdot h </math>Q=WQ⋅h(图像特征)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> K = W K ⋅ c K = W_K \cdot c </math>K=WK⋅c(文本嵌入)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> V = W V ⋅ c V = W_V \cdot c </math>V=WV⋅c
这使得图像特征可以"查询"文本信息。
第七部分:扩散模型的理论基础
与score-based模型的联系
扩散模型与score-based生成模型本质相同。
Score函数:数据分布的对数梯度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s ( x ) = ∇ x log p ( x ) s(x) = \nabla_x \log p(x) </math>s(x)=∇xlogp(x)
Tweedie's Formula :给定噪声观测 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = x 0 + σ ε x_t = x_0 + \sigma\varepsilon </math>xt=x0+σε,则:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E [ x 0 ∣ x t ] = x t + σ 2 ∇ x t log p ( x t ) \mathbb{E}[x_0|x_t] = x_t + \sigma^2 \nabla_{x_t}\log p(x_t) </math>E[x0∣xt]=xt+σ2∇xtlogp(xt)
扩散模型的score:
从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}t}\varepsilon </math>xt=αˉt x0+1−αˉt ε 可知:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∇ x t log p ( x t ) = − ε 1 − α ˉ t \nabla{x_t}\log p(x_t) = -\frac{\varepsilon}{\sqrt{1-\bar{\alpha}_t}} </math>∇xtlogp(xt)=−1−αˉt ε
因此预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε \varepsilon </math>ε 等价于预测score:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε θ ( x t , t ) = − 1 − α ˉ t ⋅ s θ ( x t , t ) \varepsilon_\theta(x_t, t) = -\sqrt{1-\bar{\alpha}t} \cdot s\theta(x_t, t) </math>εθ(xt,t)=−1−αˉt ⋅sθ(xt,t)
随机微分方程(SDE)视角
扩散过程可以表示为SDE:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d x = f ( x , t ) d t + g ( t ) d w dx = f(x, t)dt + g(t)dw </math>dx=f(x,t)dt+g(t)dw
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x , t ) f(x, t) </math>f(x,t):漂移系数
- <math xmlns="http://www.w3.org/1998/Math/MathML"> g ( t ) g(t) </math>g(t):扩散系数
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w:布朗运动
前向SDE(方差保持,VP-SDE):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d x = − 1 2 β ( t ) x d t + β ( t ) d w dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw </math>dx=−21β(t)xdt+β(t) dw
反向SDE:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d x = [ f ( x , t ) − g 2 ( t ) ∇ x log p t ( x ) ] d t + g ( t ) d w ˉ dx = \left[f(x, t) - g^2(t)\nabla_x\log p_t(x)\right]dt + g(t)d\bar{w} </math>dx=[f(x,t)−g2(t)∇xlogpt(x)]dt+g(t)dwˉ
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> w ˉ \bar{w} </math>wˉ 是反向布朗运动。
数值求解:
用神经网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> s θ s_\theta </math>sθ 逼近 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ x log p t ( x ) \nabla_x\log p_t(x) </math>∇xlogpt(x),然后用欧拉-丸山方法求解反向SDE,即可生成样本。
优势:
- 统一框架(DDPM、DDIM、score-based都是特例)
- 可以使用高阶ODE求解器加速(DPM-Solver)
- 理论保证(SDE理论)
概率流ODE(Probability Flow ODE)
对于任意SDE,存在唯一的ODE,其边缘分布与SDE相同:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d x = [ f ( x , t ) − 1 2 g 2 ( t ) ∇ x log p t ( x ) ] d t dx = \left[f(x, t) - \frac{1}{2}g^2(t)\nabla_x\log p_t(x)\right]dt </math>dx=[f(x,t)−21g2(t)∇xlogpt(x)]dt
意义:
- ODE是确定性的,无随机性
- 可以精确计算似然(通过瞬时变化公式)
- 可以进行隐空间插值
DDIM就是概率流ODE的离散化。
第八部分:扩散模型的变体与改进
Latent Diffusion Models(LDM)
核心思想:在低维隐空间做扩散,而非原始像素空间。
架构:
-
自编码器 : <math xmlns="http://www.w3.org/1998/Math/MathML"> x → E z → D x ^ x \xrightarrow{E} z \xrightarrow{D} \hat{x} </math>xE zD x^
- 编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> E E </math>E:图像 → 隐表示(如 <math xmlns="http://www.w3.org/1998/Math/MathML"> 256 × 256 × 3 → 32 × 32 × 4 256 \times 256 \times 3 \to 32 \times 32 \times 4 </math>256×256×3→32×32×4)
- 解码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> D D </math>D:隐表示 → 图像
-
扩散过程 :在 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z 空间做扩散
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> z t = α ˉ t z 0 + 1 − α ˉ t ε z_t = \sqrt{\bar{\alpha}_t}z_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon </math>zt=αˉt z0+1−αˉt ε
- 噪声预测网络 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( z t , t , c ) \varepsilon_\theta(z_t, t, c) </math>εθ(zt,t,c)
优势:
- 效率 :隐空间维度远小于像素空间( <math xmlns="http://www.w3.org/1998/Math/MathML"> 3 2 2 × 4 ≪ 25 6 2 × 3 32^2 \times 4 \ll 256^2 \times 3 </math>322×4≪2562×3),计算量减少64倍
- 语义:隐空间更具语义性,更易于条件控制
- 质量:解码器补充高频细节
代表:Stable Diffusion、Imagen Video。
Cascaded Diffusion Models
思想:分阶段生成,从低分辨率到高分辨率。
流程:
- 基础模型:生成 <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 × 64 64 \times 64 </math>64×64 图像
- 超分辨率模型1: <math xmlns="http://www.w3.org/1998/Math/MathML"> 64 × 64 → 256 × 256 64 \times 64 \to 256 \times 256 </math>64×64→256×256
- 超分辨率模型2: <math xmlns="http://www.w3.org/1998/Math/MathML"> 256 × 256 → 1024 × 1024 256 \times 256 \to 1024 \times 1024 </math>256×256→1024×1024
每个阶段都是独立训练的扩散模型,条件包括低分辨率图像。
优势:
- 分解复杂度,每个模型专注于特定尺度
- 高分辨率生成(DALL-E 2使用此方法)
Consistency Models
目标:一步生成,去除迭代采样。
核心思想 :训练一个网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> f θ ( x t , t ) f_\theta(x_t, t) </math>fθ(xt,t) 满足自一致性:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f θ ( x t , t ) = f θ ( x t ′ , t ′ ) = x 0 ∀ t , t ′ f_\theta(x_t, t) = f_\theta(x_{t'}, t') = x_0 \quad \forall t, t' </math>fθ(xt,t)=fθ(xt′,t′)=x0∀t,t′
即所有噪声状态映射到同一个干净图像。
训练方法:
- 蒸馏:从预训练的扩散模型蒸馏
- 直接训练:设计自一致性损失
效果:
- 速度:一步生成(0.1秒/图)
- 质量:FID约5(比DDIM 50步的2.8差,但可接受)
Diffusion Transformers(DiT)
改进:用Transformer替换U-Net作为噪声预测网络。
优势:
- Scaling law:模型越大效果越好
- 统一架构:与NLP模型一致
第九部分:扩散模型的应用
图像生成
无条件生成:学习数据分布,生成多样化样本(CelebA-HQ、FFHQ)。
条件生成:
- 类别条件:ImageNet生成
- 文本条件:DALL-E 2、Stable Diffusion、Midjourney
- 布局条件:根据语义图生成(COCO-Stuff)
图像编辑
SDEdit:给定粗略编辑,添加噪声后去噪,保持语义同时提升质量。
Inpainting:修复图像缺失区域,条件为mask和已知区域。
Image-to-Image:根据引导图像生成新图像(风格迁移、超分辨率)。
视频生成
时序扩散模型 :将视频视为4D张量 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( T , H , W , C ) (T, H, W, C) </math>(T,H,W,C),在时空维度做扩散。
代表:Imagen Video、Make-A-Video、Runway Gen-2。
3D生成
DreamFusion:用预训练的2D扩散模型指导3D模型优化(NeRF)。
方法:从多视角渲染3D模型,用扩散模型的score函数作为损失。
音频与语音
语音合成:条件为文本,生成语音波形(Grad-TTS、DiffWave)。
音乐生成:Riffusion(在频谱图上做扩散)。
分子设计与科学
蛋白质结构预测:RFDiffusion(David Baker组)。
分子生成:生成具有特定性质的药物分子。
第十部分:扩散模型的优缺点与未来
优势
1. 生成质量高
- FID优于GAN
- 多样性好,无模式崩溃
2. 训练稳定
- 简单的MSE损失
- 无对抗训练的不稳定性
3. 似然可计算
- ELBO提供似然下界
- 概率流ODE可精确计算
4. 灵活的条件控制
- 支持多种条件(文本、类别、图像)
- CFG提供可调的条件强度
劣势
1. 采样速度慢
- DDPM需要1000步
- 即使DDIM也需50步
- 实时应用困难
2. 训练成本高
- 大规模数据集(数百万图像)
- 长时间训练(数周到数月)
- 高计算资源(数百GPU)
3. 理论不完善
- 损失函数的权重设计缺乏理论指导
- 采样步数与质量的关系不清晰
未来方向
1. 采样加速
- 更好的ODE求解器(DPM-Solver++)
- 蒸馏方法(Consistency Models、Progressive Distillation)
- 一步生成(Adversarial Diffusion Distillation)
2. 架构改进
- Transformer替代U-Net(DiT、U-ViT)
- 更高效的注意力机制(Flash Attention)
3. 条件控制
- 更细粒度的控制(ControlNet、T2I-Adapter)
- 多模态条件(文本+图像+布局)
4. 应用拓展
- 视频生成(时序一致性)
- 3D生成(多视角一致性)
- 科学领域(药物设计、材料科学)
5. 理论深化
- 更好的score估计方法
- 采样轨迹的优化
- 与其他生成模型的统一
总结
扩散模型是深度生成模型的重大突破,其核心思想简洁而优雅:
前向过程:逐步加噪,将数据转化为纯噪声
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x 0 → x 1 → ⋯ → x T ≈ N ( 0 , I ) x_0 \to x_1 \to \cdots \to x_T \approx \mathcal{N}(0, I) </math>x0→x1→⋯→xT≈N(0,I)
反向过程:训练神经网络逐步去噪,从噪声恢复数据
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x T → x T − 1 → ⋯ → x 0 x_T \to x_{T-1} \to \cdots \to x_0 </math>xT→xT−1→⋯→x0
训练目标:预测每一步添加的噪声
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , ε [ ∥ ε − ε θ ( x t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}{t, x_0, \varepsilon}\left[\|\varepsilon - \varepsilon\theta(x_t, t)\|^2\right] </math>L=Et,x0,ε[∥ε−εθ(xt,t)∥2]
采样过程:从纯噪声开始,迭代去噪生成样本
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = 1 α t ( x t − β t 1 − α ˉ t ε θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}t}}\varepsilon\theta(x_t, t)\right) + \sigma_t z </math>xt−1=αt 1(xt−1−αˉt βtεθ(xt,t))+σtz
从物理的扩散现象,到严格的概率推导,再到实用的采样算法,扩散模型展现了理论与实践的完美结合。
虽然采样速度仍是瓶颈,但随着算法和硬件的进步,扩散模型正在重塑AI生成内容的格局,从艺术创作到科学研究,从娱乐产业到工业设计,扩散模型的应用前景无限广阔。
扩散模型告诉我们:有时候,通往目标的最佳路径,不是直线冲刺,而是小步迭代------从噪声到秩序,从混沌到结构,一步一步,终将抵达。
参考文献
- Sohl-Dickstein, J., et al. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. ICML 2015.
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
- Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
- Nichol, A., & Dhariwal, P. (2021). Improved Denoising Diffusion Probabilistic Models. ICML 2021.
- Dhariwal, P., & Nichol, A. (2021). Diffusion Models Beat GANs on Image Synthesis. NeurIPS 2021.
- Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
- Song, Y., et al. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021.
- Rombach, R., et al. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR 2022.
- Ramesh, A., et al. (2022). Hierarchical Text-Conditional Image Generation with CLIP Latents. arXiv:2204.06125.
- Song, Y., et al. (2023). Consistency Models. ICML 2023.