基于分数的生成模型:分数匹配完整推导
1. 问题根源:生成模型的核心挑战
我们有一组来自未知分布 p data ( x ) p_{\text{data}}(x) pdata(x) 的样本 { x 1 , x 2 , ... , x N } \{x_1, x_2, \dots, x_N\} {x1,x2,...,xN},想要学习一个参数化分布 p θ ( x ) p_\theta(x) pθ(x) 来近似 p data ( x ) p_{\text{data}}(x) pdata(x),并能从中生成新样本。
1.1 传统最大似然估计的障碍
最直接的方法是最大似然估计(MLE) :
max θ 1 N ∑ i = 1 N log p θ ( x i ) \max_\theta \frac{1}{N} \sum_{i=1}^N \log p_\theta(x_i) θmaxN1i=1∑Nlogpθ(xi)
为了确保 p θ ( x ) p_\theta(x) pθ(x) 是有效的概率分布,必须满足归一化条件。通常我们使用能量模型 :
p θ ( x ) = e − f θ ( x ) Z θ , Z θ = ∫ e − f θ ( x ) d x p_\theta(x) = \frac{e^{-f_\theta(x)}}{Z_\theta}, \quad Z_\theta = \int e^{-f_\theta(x)} dx pθ(x)=Zθe−fθ(x),Zθ=∫e−fθ(x)dx
致命问题 :计算配分函数 Z θ Z_\theta Zθ 需要对整个高维空间积分,计算复杂度随维度指数增长,完全不可行。
2. 核心洞察:从绝对概率到相对方向
2.1 关键观察
要生成新数据,我们其实不需要知道"一张图片有多像猫"的绝对概率值 ,只需要知道如何修改一张图片让它变得更像猫。
2.2 分数的定义
定义分数(Score) 为对数概率的梯度:
s θ ( x ) = ∇ x log p θ ( x ) s_\theta(x) = \nabla_x \log p_\theta(x) sθ(x)=∇xlogpθ(x)
这个向量告诉我们:在点 x x x 处,应该往哪个方向移动,能最快地增加 log p θ ( x ) \log p_\theta(x) logpθ(x)。
2.3 数学奇迹:配分函数自动消失
计算这个分数:
s θ ( x ) = ∇ x log p θ ( x ) = ∇ x log ( e − f θ ( x ) Z θ ) = ∇ x [ − f θ ( x ) − log Z θ ] = − ∇ x f θ ( x ) (因为 ∇ x log Z θ = 0 ) \begin{aligned} s_\theta(x) &= \nabla_x \log p_\theta(x) \\ &= \nabla_x \log\left(\frac{e^{-f_\theta(x)}}{Z_\theta}\right) \\ &= \nabla_x\left[-f_\theta(x) - \log Z_\theta\right] \\ &= -\nabla_x f_\theta(x) \quad \text{(因为 } \nabla_x \log Z_\theta = 0 \text{)} \end{aligned} sθ(x)=∇xlogpθ(x)=∇xlog(Zθe−fθ(x))=∇x[−fθ(x)−logZθ]=−∇xfθ(x)(因为 ∇xlogZθ=0)
配分函数 Z θ Z_\theta Zθ 消失了! 我们只需要学习 f θ ( x ) f_\theta(x) fθ(x) 的梯度。
3. 目标设定:匹配真实分布的分数
我们希望模型的分数 s θ ( x ) s_\theta(x) sθ(x) 接近真实分布的分数:
s data ( x ) = ∇ x log p data ( x ) s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x) sdata(x)=∇xlogpdata(x)
自然的想法是最小化它们的平方差:
J ideal ( θ ) = 1 2 E p data ( x ) ∥ s θ ( x ) − s data ( x ) ∥ 2 J_{\text{ideal}}(\theta) = \frac{1}{2} \mathbb{E}{p{\text{data}}(x)} \| s_\theta(x) - s_{\text{data}}(x) \|^2 Jideal(θ)=21Epdata(x)∥sθ(x)−sdata(x)∥2
问题 :我们不知道 p data ( x ) p_{\text{data}}(x) pdata(x),所以不知道 s data ( x ) s_{\text{data}}(x) sdata(x)。
4. 数学推导:绕过未知的真实分数
4.1 展开目标函数
∥ s θ − s data ∥ 2 = ( s θ − s data ) ⋅ ( s θ − s data ) = ∥ s θ ∥ 2 − 2 s θ ⋅ s data + ∥ s data ∥ 2 \begin{aligned} \| s_\theta - s_{\text{data}} \|^2 &= (s_\theta - s_{\text{data}}) \cdot (s_\theta - s_{\text{data}}) \\ &= \| s_\theta \|^2 - 2 s_\theta \cdot s_{\text{data}} + \| s_{\text{data}} \|^2 \end{aligned} ∥sθ−sdata∥2=(sθ−sdata)⋅(sθ−sdata)=∥sθ∥2−2sθ⋅sdata+∥sdata∥2
4.2 取期望并简化
J ideal ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data + 1 2 ∥ s data ∥ 2 ] J_{\text{ideal}}(\theta) = \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} + \frac{1}{2} \| s_{\text{data}} \|^2 \right] Jideal(θ)=Epdata[21∥sθ∥2−sθ⋅sdata+21∥sdata∥2]
第三项与 θ \theta θ 无关,优化时可忽略:
J ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data ] J(\theta) = \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} \right] J(θ)=Epdata[21∥sθ∥2−sθ⋅sdata]
4.3 关键步骤:处理 s θ ⋅ s data s_\theta \cdot s_{\text{data}} sθ⋅sdata 项
将期望写为积分形式:
E p data [ s θ ⋅ s data ] = ∫ p data ( x ) [ s θ ( x ) ⋅ s data ( x ) ] d x \mathbb{E}{p{\text{data}}}[s_\theta \cdot s_{\text{data}}] = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot s_{\text{data}}(x) \right] dx Epdata[sθ⋅sdata]=∫pdata(x)[sθ(x)⋅sdata(x)]dx
代入 s data ( x ) = ∇ x log p data ( x ) s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x) sdata(x)=∇xlogpdata(x):
= ∫ p data ( x ) [ s θ ( x ) ⋅ ∇ x log p data ( x ) ] d x = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot \nabla_x \log p_{\text{data}}(x) \right] dx =∫pdata(x)[sθ(x)⋅∇xlogpdata(x)]dx
4.4 利用对数梯度性质
根据微积分: ∇ x log p data ( x ) = ∇ x p data ( x ) p data ( x ) \nabla_x \log p_{\text{data}}(x) = \frac{\nabla_x p_{\text{data}}(x)}{p_{\text{data}}(x)} ∇xlogpdata(x)=pdata(x)∇xpdata(x)
= ∫ p data ( x ) [ s θ ( x ) ⋅ ∇ x p data ( x ) p data ( x ) ] d x = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot \frac{\nabla_x p_{\text{data}}(x)}{p_{\text{data}}(x)} \right] dx =∫pdata(x)[sθ(x)⋅pdata(x)∇xpdata(x)]dx
约去 p data ( x ) p_{\text{data}}(x) pdata(x):
= ∫ s θ ( x ) ⋅ ∇ x p data ( x ) d x = \int s_\theta(x) \cdot \nabla_x p_{\text{data}}(x) dx =∫sθ(x)⋅∇xpdata(x)dx
4.5 分部积分(散度定理)
假设当 ∥ x ∥ → ∞ \|x\| \to \infty ∥x∥→∞ 时 p data ( x ) → 0 p_{\text{data}}(x) \to 0 pdata(x)→0(边界项为0):
∫ s θ ( x ) ⋅ ∇ x p data ( x ) d x = − ∫ p data ( x ) [ ∇ x ⋅ s θ ( x ) ] d x \int s_\theta(x) \cdot \nabla_x p_{\text{data}}(x) dx = -\int p_{\text{data}}(x) \left[ \nabla_x \cdot s_\theta(x) \right] dx ∫sθ(x)⋅∇xpdata(x)dx=−∫pdata(x)[∇x⋅sθ(x)]dx
其中 ∇ x ⋅ s θ ( x ) = ∑ i = 1 D ∂ s θ , i ( x ) ∂ x i \nabla_x \cdot s_\theta(x) = \sum_{i=1}^D \frac{\partial s_{\theta,i}(x)}{\partial x_i} ∇x⋅sθ(x)=∑i=1D∂xi∂sθ,i(x) 是散度。
4.6 回到期望形式
E p data [ s θ ⋅ s data ] = − E p data [ ∇ x ⋅ s θ ( x ) ] \mathbb{E}{p{\text{data}}}[s_\theta \cdot s_{\text{data}}] = -\mathbb{E}{p{\text{data}}}[\nabla_x \cdot s_\theta(x)] Epdata[sθ⋅sdata]=−Epdata[∇x⋅sθ(x)]
5. 分数匹配目标函数
代入原目标函数:
J ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data ] = E p data [ 1 2 ∥ s θ ∥ 2 + ∇ x ⋅ s θ ( x ) ] \begin{aligned} J(\theta) &= \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} \right] \\ &= \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 + \nabla_x \cdot s_\theta(x) \right] \end{aligned} J(θ)=Epdata[21∥sθ∥2−sθ⋅sdata]=Epdata[21∥sθ∥2+∇x⋅sθ(x)]
这就是分数匹配(Score Matching) 目标函数:
J SM ( θ ) = E p data ( x ) [ 1 2 ∥ s θ ( x ) ∥ 2 + ∇ x ⋅ s θ ( x ) ] J_{\text{SM}}(\theta) = \mathbb{E}{p{\text{data}}(x)} \left[ \frac{1}{2} \| s_\theta(x) \|^2 + \nabla_x \cdot s_\theta(x) \right] JSM(θ)=Epdata(x)[21∥sθ(x)∥2+∇x⋅sθ(x)]
优点:
- 只需要从 p data p_{\text{data}} pdata 中采样数据点 x x x
- 不需要知道 p data ( x ) p_{\text{data}}(x) pdata(x) 或 s data ( x ) s_{\text{data}}(x) sdata(x) 的具体形式
实际问题:显式分数匹配的双重困境
虽然显式分数匹配提供了一种绕过未知真实分数 s data ( x ) s_{\text{data}}(x) sdata(x) 的方法,但在处理高维数据(如图像)时仍然面临两个相互关联的根本性挑战:
6.1 计算复杂度挑战
对于 D D D 维数据,计算目标函数中的散度项 ∇ x ⋅ s θ ( x ) \nabla_x \cdot s_\theta(x) ∇x⋅sθ(x) 需要:
∇ x ⋅ s θ ( x ) = ∑ i = 1 D ∂ s θ , i ( x ) ∂ x i \nabla_x \cdot s_\theta(x) = \sum_{i=1}^D \frac{\partial s_{\theta,i}(x)}{\partial x_i} ∇x⋅sθ(x)=i=1∑D∂xi∂sθ,i(x)
这涉及计算 D D D 个偏导数,每个偏导数又需要对神经网络进行反向传播,总体计算量达到 O ( D 2 ) O(D^2) O(D2)。对于图像数据( D > 1 0 5 D > 10^5 D>105),这种计算开销在实践中是完全不可接受的。
6.2 流形假设的理论挑战
更本质的问题是,根据流形假设(manifold hypothesis),真实数据往往分布在高维空间中的一个低维流形上。这意味着:
-
数据分布高度集中 :在高维空间中,数据密度 p data ( x ) p_{\text{data}}(x) pdata(x) 仅在低维流形上显著大于零,而在流形之外的广阔区域趋近于零。
-
分数定义的病态性:
- 在流形之外的低密度区域, p data ( x ) ≈ 0 p_{\text{data}}(x) \approx 0 pdata(x)≈0 导致分数 ∇ x log p data ( x ) \nabla_x \log p_{\text{data}}(x) ∇xlogpdata(x) 的计算成为数值不稳定的"0/0"形式。
- 即使我们使用显式分数匹配得到 s θ ( x ) s_\theta(x) sθ(x),模型在这些区域也缺乏有效的训练信号。
-
生成过程的不稳定性:
- 朗之万动力学等采样方法从随机噪声开始,其采样路径必然穿越这些低密度区域。
- 如果模型在这些区域的分数估计不准确,生成过程就会发散或陷入次优解,导致生成样本质量低下或失败。
7. 解决方案:去噪分数匹配
7.1 核心思想:制造已知的监督信号
既然不知道真实分数,就自己创造一个已知分数的分布。
具体方法 :对每个真实数据点 x x x(来自 p data p_{\text{data}} pdata),添加高斯噪声:
x ~ = x + σ ϵ , ϵ ∼ N ( 0 , I ) \tilde{x} = x + \sigma \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) x~=x+σϵ,ϵ∼N(0,I)
7.2 条件分布是已知的高斯分布
给定原始数据点 x x x,噪声数据 x ~ \tilde{x} x~ 的条件分布:
q σ ( x ~ ∣ x ) = N ( x ~ ∣ x , σ 2 I ) q_\sigma(\tilde{x} | x) = \mathcal{N}(\tilde{x} | x, \sigma^2 I) qσ(x~∣x)=N(x~∣x,σ2I)
7.3 关键:这个条件分布的分数已知
高斯分布的分数有解析表达式:
∇ x ~ log q σ ( x ~ ∣ x ) = − x ~ − x σ 2 \nabla_{\tilde{x}} \log q_\sigma(\tilde{x} | x) = -\frac{\tilde{x} - x}{\sigma^2} ∇x~logqσ(x~∣x)=−σ2x~−x
7.4 去噪分数匹配目标
训练网络 s θ ( x ~ , σ ) s_\theta(\tilde{x}, \sigma) sθ(x~,σ) 来匹配这个已知的分数:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E x ~ ∼ q σ ( x ~ ∣ x ) ∥ s θ ( x ~ , σ ) − ∇ x ~ log q σ ( x ~ ∣ x ) ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\tilde{x} \sim q\sigma(\tilde{x}|x)} \| s_\theta(\tilde{x}, \sigma) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x) \|^2 JDSM(θ;σ)=21Ex∼pdataEx~∼qσ(x~∣x)∥sθ(x~,σ)−∇x~logqσ(x~∣x)∥2
代入已知表达式:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E x ~ ∼ q σ ( x ~ ∣ x ) ∥ s θ ( x ~ , σ ) + x ~ − x σ 2 ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\tilde{x} \sim q\sigma(\tilde{x}|x)} \left\| s_\theta(\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \right\|^2 JDSM(θ;σ)=21Ex∼pdataEx~∼qσ(x~∣x) sθ(x~,σ)+σ2x~−x 2
7.5 重参数化技巧
令 ϵ = x ~ − x σ ∼ N ( 0 , I ) \epsilon = \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I) ϵ=σx~−x∼N(0,I),则 x ~ = x + σ ϵ \tilde{x} = x + \sigma \epsilon x~=x+σϵ:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E ϵ ∼ N ( 0 , I ) ∥ s θ ( x + σ ϵ , σ ) + ϵ σ ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\epsilon \sim \mathcal{N}(0,I)} \left\| s\theta(x + \sigma \epsilon, \sigma) + \frac{\epsilon}{\sigma} \right\|^2 JDSM(θ;σ)=21Ex∼pdataEϵ∼N(0,I) sθ(x+σϵ,σ)+σϵ 2
8. 多尺度噪声训练
8.1 为什么需要多个噪声尺度?
- 大 σ \sigma σ:覆盖范围广,帮助探索低概率区域,但估计粗糙
- 小 σ \sigma σ:估计精确,接近真实分布,但只覆盖数据附近区域
8.2 噪声尺度序列
使用递减的噪声尺度: σ 1 > σ 2 > ⋯ > σ L \sigma_1 > \sigma_2 > \cdots > \sigma_L σ1>σ2>⋯>σL
8.3 最终目标函数
L ( θ ) = 1 L ∑ i = 1 L λ ( σ i ) J DSM ( θ ; σ i ) L(\theta) = \frac{1}{L} \sum_{i=1}^L \lambda(\sigma_i) J_{\text{DSM}}(\theta; \sigma_i) L(θ)=L1i=1∑Lλ(σi)JDSM(θ;σi)
其中权重 λ ( σ i ) = σ i 2 \lambda(\sigma_i) = \sigma_i^2 λ(σi)=σi2 用于平衡不同尺度。
9. 训练伪代码
python
def train_step(batch_size):
# 1. 从真实数据采样
x_clean = sample_data(batch_size) # 来自 p_data
# 2. 随机选择噪声尺度
sigma = random.choice([σ1, σ2, ..., σL])
# 3. 添加噪声
epsilon = torch.randn_like(x_clean)
x_noisy = x_clean + sigma * epsilon
# 4. 计算目标分数(已知的真值)
target_score = -epsilon / sigma # = -(x_noisy - x_clean)/σ²
# 5. 网络预测
predicted_score = score_network(x_noisy, sigma)
# 6. 计算损失(加权)
loss = torch.mean((predicted_score - target_score)**2) * sigma**2
# 7. 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
10. 生成样本:退火朗之万动力学
10.1 朗之万动力学更新
有了训练好的分数网络 s θ ( x , σ ) s_\theta(x, \sigma) sθ(x,σ),可以从随机噪声开始生成样本:
x t + 1 = x t + η ⋅ s θ ( x t , σ ) + 2 η ⋅ z t , z t ∼ N ( 0 , I ) x_{t+1} = x_t + \eta \cdot s_\theta(x_t, \sigma) + \sqrt{2\eta} \cdot z_t, \quad z_t \sim \mathcal{N}(0, I) xt+1=xt+η⋅sθ(xt,σ)+2η ⋅zt,zt∼N(0,I)
10.2 退火策略
- 从最大的噪声尺度 σ 1 \sigma_1 σ1 开始
- 执行若干步朗之万更新
- 切换到下一个更小的噪声尺度 σ 2 \sigma_2 σ2
- 重复直到最小的噪声尺度 σ L \sigma_L σL
python
def annealed_langevin_dynamics(score_net, sigmas):
x = torch.randn(shape) # 从噪声开始
for sigma in sigmas: # 从大到小
step_size = 0.00002 * (sigma / sigmas[-1])**2
for t in range(num_steps):
# 获取当前噪声尺度下的分数
score = score_net(x, sigma)
# 随机噪声
noise = torch.randn_like(x)
# 朗之万更新
x = x + step_size * score + np.sqrt(2*step_size) * noise
return x # 生成的样本
总结:分数匹配的逻辑链
- 发现问题 :传统最大似然估计需要计算不可处理的配分函数 Z θ Z_\theta Zθ
- 转换思路 :不建模绝对概率 p ( x ) p(x) p(x),转而建模其梯度 ∇ x log p ( x ) \nabla_x \log p(x) ∇xlogp(x)(分数)
- 数学奇迹 :分数的计算中,配分函数 Z θ Z_\theta Zθ 自动消去
- 训练挑战 :不知道真实分数 s data ( x ) s_{\text{data}}(x) sdata(x)
- 第一次突破 :通过分部积分,得到不依赖 s data ( x ) s_{\text{data}}(x) sdata(x) 的目标函数(但需计算散度)
- 计算难题 :计算散度 ∇ x ⋅ s θ ( x ) \nabla_x \cdot s_\theta(x) ∇x⋅sθ(x) 需要二阶导数,计算量过大
- 第二次突破:去噪分数匹配------添加已知的高斯噪声,利用条件分布的已知分数作为监督信号
- 实践技巧:使用多尺度噪声训练,平衡覆盖范围与估计精度
- 生成过程:通过退火朗之万动力学,从噪声中逐步"雕刻"出高质量样本
最终哲学 :放弃计算"这个东西有多好"(绝对概率),转而学习"如何让这个东西变得更好"(梯度方向)。这种转变巧妙地绕过了生成模型的最大障碍,成为现代扩散模型的理论基础。