Score Based diffusion model 数学推导

基于分数的生成模型:分数匹配完整推导

1. 问题根源:生成模型的核心挑战

我们有一组来自未知分布 p data ( x ) p_{\text{data}}(x) pdata(x) 的样本 { x 1 , x 2 , ... , x N } \{x_1, x_2, \dots, x_N\} {x1,x2,...,xN},想要学习一个参数化分布 p θ ( x ) p_\theta(x) pθ(x) 来近似 p data ( x ) p_{\text{data}}(x) pdata(x),并能从中生成新样本。

1.1 传统最大似然估计的障碍

最直接的方法是最大似然估计(MLE)
max ⁡ θ 1 N ∑ i = 1 N log ⁡ p θ ( x i ) \max_\theta \frac{1}{N} \sum_{i=1}^N \log p_\theta(x_i) θmaxN1i=1∑Nlogpθ(xi)

为了确保 p θ ( x ) p_\theta(x) pθ(x) 是有效的概率分布,必须满足归一化条件。通常我们使用能量模型
p θ ( x ) = e − f θ ( x ) Z θ , Z θ = ∫ e − f θ ( x ) d x p_\theta(x) = \frac{e^{-f_\theta(x)}}{Z_\theta}, \quad Z_\theta = \int e^{-f_\theta(x)} dx pθ(x)=Zθe−fθ(x),Zθ=∫e−fθ(x)dx

致命问题 :计算配分函数 Z θ Z_\theta Zθ 需要对整个高维空间积分,计算复杂度随维度指数增长,完全不可行。

2. 核心洞察:从绝对概率到相对方向

2.1 关键观察

要生成新数据,我们其实不需要知道"一张图片有多像猫"的绝对概率值 ,只需要知道如何修改一张图片让它变得更像猫

2.2 分数的定义

定义分数(Score) 为对数概率的梯度:
s θ ( x ) = ∇ x log ⁡ p θ ( x ) s_\theta(x) = \nabla_x \log p_\theta(x) sθ(x)=∇xlogpθ(x)

这个向量告诉我们:在点 x x x 处,应该往哪个方向移动,能最快地增加 log ⁡ p θ ( x ) \log p_\theta(x) logpθ(x)。

2.3 数学奇迹:配分函数自动消失

计算这个分数:
s θ ( x ) = ∇ x log ⁡ p θ ( x ) = ∇ x log ⁡ ( e − f θ ( x ) Z θ ) = ∇ x [ − f θ ( x ) − log ⁡ Z θ ] = − ∇ x f θ ( x ) (因为 ∇ x log ⁡ Z θ = 0 ) \begin{aligned} s_\theta(x) &= \nabla_x \log p_\theta(x) \\ &= \nabla_x \log\left(\frac{e^{-f_\theta(x)}}{Z_\theta}\right) \\ &= \nabla_x\left[-f_\theta(x) - \log Z_\theta\right] \\ &= -\nabla_x f_\theta(x) \quad \text{(因为 } \nabla_x \log Z_\theta = 0 \text{)} \end{aligned} sθ(x)=∇xlogpθ(x)=∇xlog(Zθe−fθ(x))=∇x[−fθ(x)−logZθ]=−∇xfθ(x)(因为 ∇xlogZθ=0)

配分函数 Z θ Z_\theta Zθ 消失了! 我们只需要学习 f θ ( x ) f_\theta(x) fθ(x) 的梯度。

3. 目标设定:匹配真实分布的分数

我们希望模型的分数 s θ ( x ) s_\theta(x) sθ(x) 接近真实分布的分数:
s data ( x ) = ∇ x log ⁡ p data ( x ) s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x) sdata(x)=∇xlogpdata(x)

自然的想法是最小化它们的平方差:
J ideal ( θ ) = 1 2 E p data ( x ) ∥ s θ ( x ) − s data ( x ) ∥ 2 J_{\text{ideal}}(\theta) = \frac{1}{2} \mathbb{E}{p{\text{data}}(x)} \| s_\theta(x) - s_{\text{data}}(x) \|^2 Jideal(θ)=21Epdata(x)∥sθ(x)−sdata(x)∥2

问题 :我们不知道 p data ( x ) p_{\text{data}}(x) pdata(x),所以不知道 s data ( x ) s_{\text{data}}(x) sdata(x)。

4. 数学推导:绕过未知的真实分数

4.1 展开目标函数

∥ s θ − s data ∥ 2 = ( s θ − s data ) ⋅ ( s θ − s data ) = ∥ s θ ∥ 2 − 2 s θ ⋅ s data + ∥ s data ∥ 2 \begin{aligned} \| s_\theta - s_{\text{data}} \|^2 &= (s_\theta - s_{\text{data}}) \cdot (s_\theta - s_{\text{data}}) \\ &= \| s_\theta \|^2 - 2 s_\theta \cdot s_{\text{data}} + \| s_{\text{data}} \|^2 \end{aligned} ∥sθ−sdata∥2=(sθ−sdata)⋅(sθ−sdata)=∥sθ∥2−2sθ⋅sdata+∥sdata∥2

4.2 取期望并简化

J ideal ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data + 1 2 ∥ s data ∥ 2 ] J_{\text{ideal}}(\theta) = \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} + \frac{1}{2} \| s_{\text{data}} \|^2 \right] Jideal(θ)=Epdata[21∥sθ∥2−sθ⋅sdata+21∥sdata∥2]

第三项与 θ \theta θ 无关,优化时可忽略:
J ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data ] J(\theta) = \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} \right] J(θ)=Epdata[21∥sθ∥2−sθ⋅sdata]

4.3 关键步骤:处理 s θ ⋅ s data s_\theta \cdot s_{\text{data}} sθ⋅sdata 项

将期望写为积分形式:
E p data [ s θ ⋅ s data ] = ∫ p data ( x ) [ s θ ( x ) ⋅ s data ( x ) ] d x \mathbb{E}{p{\text{data}}}[s_\theta \cdot s_{\text{data}}] = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot s_{\text{data}}(x) \right] dx Epdata[sθ⋅sdata]=∫pdata(x)[sθ(x)⋅sdata(x)]dx

代入 s data ( x ) = ∇ x log ⁡ p data ( x ) s_{\text{data}}(x) = \nabla_x \log p_{\text{data}}(x) sdata(x)=∇xlogpdata(x):
= ∫ p data ( x ) [ s θ ( x ) ⋅ ∇ x log ⁡ p data ( x ) ] d x = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot \nabla_x \log p_{\text{data}}(x) \right] dx =∫pdata(x)[sθ(x)⋅∇xlogpdata(x)]dx

4.4 利用对数梯度性质

根据微积分: ∇ x log ⁡ p data ( x ) = ∇ x p data ( x ) p data ( x ) \nabla_x \log p_{\text{data}}(x) = \frac{\nabla_x p_{\text{data}}(x)}{p_{\text{data}}(x)} ∇xlogpdata(x)=pdata(x)∇xpdata(x)

= ∫ p data ( x ) [ s θ ( x ) ⋅ ∇ x p data ( x ) p data ( x ) ] d x = \int p_{\text{data}}(x) \left[ s_\theta(x) \cdot \frac{\nabla_x p_{\text{data}}(x)}{p_{\text{data}}(x)} \right] dx =∫pdata(x)[sθ(x)⋅pdata(x)∇xpdata(x)]dx

约去 p data ( x ) p_{\text{data}}(x) pdata(x):
= ∫ s θ ( x ) ⋅ ∇ x p data ( x ) d x = \int s_\theta(x) \cdot \nabla_x p_{\text{data}}(x) dx =∫sθ(x)⋅∇xpdata(x)dx

4.5 分部积分(散度定理)

假设当 ∥ x ∥ → ∞ \|x\| \to \infty ∥x∥→∞ 时 p data ( x ) → 0 p_{\text{data}}(x) \to 0 pdata(x)→0(边界项为0):
∫ s θ ( x ) ⋅ ∇ x p data ( x ) d x = − ∫ p data ( x ) [ ∇ x ⋅ s θ ( x ) ] d x \int s_\theta(x) \cdot \nabla_x p_{\text{data}}(x) dx = -\int p_{\text{data}}(x) \left[ \nabla_x \cdot s_\theta(x) \right] dx ∫sθ(x)⋅∇xpdata(x)dx=−∫pdata(x)[∇x⋅sθ(x)]dx

其中 ∇ x ⋅ s θ ( x ) = ∑ i = 1 D ∂ s θ , i ( x ) ∂ x i \nabla_x \cdot s_\theta(x) = \sum_{i=1}^D \frac{\partial s_{\theta,i}(x)}{\partial x_i} ∇x⋅sθ(x)=∑i=1D∂xi∂sθ,i(x) 是散度。

4.6 回到期望形式

E p data [ s θ ⋅ s data ] = − E p data [ ∇ x ⋅ s θ ( x ) ] \mathbb{E}{p{\text{data}}}[s_\theta \cdot s_{\text{data}}] = -\mathbb{E}{p{\text{data}}}[\nabla_x \cdot s_\theta(x)] Epdata[sθ⋅sdata]=−Epdata[∇x⋅sθ(x)]

5. 分数匹配目标函数

代入原目标函数:
J ( θ ) = E p data [ 1 2 ∥ s θ ∥ 2 − s θ ⋅ s data ] = E p data [ 1 2 ∥ s θ ∥ 2 + ∇ x ⋅ s θ ( x ) ] \begin{aligned} J(\theta) &= \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 - s_\theta \cdot s_{\text{data}} \right] \\ &= \mathbb{E}{p{\text{data}}} \left[ \frac{1}{2} \| s_\theta \|^2 + \nabla_x \cdot s_\theta(x) \right] \end{aligned} J(θ)=Epdata[21∥sθ∥2−sθ⋅sdata]=Epdata[21∥sθ∥2+∇x⋅sθ(x)]

这就是分数匹配(Score Matching) 目标函数:
J SM ( θ ) = E p data ( x ) [ 1 2 ∥ s θ ( x ) ∥ 2 + ∇ x ⋅ s θ ( x ) ] J_{\text{SM}}(\theta) = \mathbb{E}{p{\text{data}}(x)} \left[ \frac{1}{2} \| s_\theta(x) \|^2 + \nabla_x \cdot s_\theta(x) \right] JSM(θ)=Epdata(x)[21∥sθ(x)∥2+∇x⋅sθ(x)]

优点

  1. 只需要从 p data p_{\text{data}} pdata 中采样数据点 x x x
  2. 不需要知道 p data ( x ) p_{\text{data}}(x) pdata(x) 或 s data ( x ) s_{\text{data}}(x) sdata(x) 的具体形式

实际问题:显式分数匹配的双重困境

虽然显式分数匹配提供了一种绕过未知真实分数 s data ( x ) s_{\text{data}}(x) sdata(x) 的方法,但在处理高维数据(如图像)时仍然面临两个相互关联的根本性挑战:

6.1 计算复杂度挑战

对于 D D D 维数据,计算目标函数中的散度项 ∇ x ⋅ s θ ( x ) \nabla_x \cdot s_\theta(x) ∇x⋅sθ(x) 需要:

∇ x ⋅ s θ ( x ) = ∑ i = 1 D ∂ s θ , i ( x ) ∂ x i \nabla_x \cdot s_\theta(x) = \sum_{i=1}^D \frac{\partial s_{\theta,i}(x)}{\partial x_i} ∇x⋅sθ(x)=i=1∑D∂xi∂sθ,i(x)

这涉及计算 D D D 个偏导数,每个偏导数又需要对神经网络进行反向传播,总体计算量达到 O ( D 2 ) O(D^2) O(D2)。对于图像数据( D > 1 0 5 D > 10^5 D>105),这种计算开销在实践中是完全不可接受的。

6.2 流形假设的理论挑战

更本质的问题是,根据流形假设(manifold hypothesis),真实数据往往分布在高维空间中的一个低维流形上。这意味着:

  1. 数据分布高度集中 :在高维空间中,数据密度 p data ( x ) p_{\text{data}}(x) pdata(x) 仅在低维流形上显著大于零,而在流形之外的广阔区域趋近于零。

  2. 分数定义的病态性

    • 在流形之外的低密度区域, p data ( x ) ≈ 0 p_{\text{data}}(x) \approx 0 pdata(x)≈0 导致分数 ∇ x log ⁡ p data ( x ) \nabla_x \log p_{\text{data}}(x) ∇xlogpdata(x) 的计算成为数值不稳定的"0/0"形式。
    • 即使我们使用显式分数匹配得到 s θ ( x ) s_\theta(x) sθ(x),模型在这些区域也缺乏有效的训练信号。
  3. 生成过程的不稳定性

    • 朗之万动力学等采样方法从随机噪声开始,其采样路径必然穿越这些低密度区域
    • 如果模型在这些区域的分数估计不准确,生成过程就会发散或陷入次优解,导致生成样本质量低下或失败。

7. 解决方案:去噪分数匹配

7.1 核心思想:制造已知的监督信号

既然不知道真实分数,就自己创造一个已知分数的分布

具体方法 :对每个真实数据点 x x x(来自 p data p_{\text{data}} pdata),添加高斯噪声:
x ~ = x + σ ϵ , ϵ ∼ N ( 0 , I ) \tilde{x} = x + \sigma \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) x~=x+σϵ,ϵ∼N(0,I)

7.2 条件分布是已知的高斯分布

给定原始数据点 x x x,噪声数据 x ~ \tilde{x} x~ 的条件分布:
q σ ( x ~ ∣ x ) = N ( x ~ ∣ x , σ 2 I ) q_\sigma(\tilde{x} | x) = \mathcal{N}(\tilde{x} | x, \sigma^2 I) qσ(x~∣x)=N(x~∣x,σ2I)

7.3 关键:这个条件分布的分数已知

高斯分布的分数有解析表达式:
∇ x ~ log ⁡ q σ ( x ~ ∣ x ) = − x ~ − x σ 2 \nabla_{\tilde{x}} \log q_\sigma(\tilde{x} | x) = -\frac{\tilde{x} - x}{\sigma^2} ∇x~logqσ(x~∣x)=−σ2x~−x

7.4 去噪分数匹配目标

训练网络 s θ ( x ~ , σ ) s_\theta(\tilde{x}, \sigma) sθ(x~,σ) 来匹配这个已知的分数:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E x ~ ∼ q σ ( x ~ ∣ x ) ∥ s θ ( x ~ , σ ) − ∇ x ~ log ⁡ q σ ( x ~ ∣ x ) ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\tilde{x} \sim q\sigma(\tilde{x}|x)} \| s_\theta(\tilde{x}, \sigma) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x) \|^2 JDSM(θ;σ)=21Ex∼pdataEx~∼qσ(x~∣x)∥sθ(x~,σ)−∇x~logqσ(x~∣x)∥2

代入已知表达式:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E x ~ ∼ q σ ( x ~ ∣ x ) ∥ s θ ( x ~ , σ ) + x ~ − x σ 2 ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\tilde{x} \sim q\sigma(\tilde{x}|x)} \left\| s_\theta(\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \right\|^2 JDSM(θ;σ)=21Ex∼pdataEx~∼qσ(x~∣x) sθ(x~,σ)+σ2x~−x 2

7.5 重参数化技巧

令 ϵ = x ~ − x σ ∼ N ( 0 , I ) \epsilon = \frac{\tilde{x} - x}{\sigma} \sim \mathcal{N}(0, I) ϵ=σx~−x∼N(0,I),则 x ~ = x + σ ϵ \tilde{x} = x + \sigma \epsilon x~=x+σϵ:
J DSM ( θ ; σ ) = 1 2 E x ∼ p data E ϵ ∼ N ( 0 , I ) ∥ s θ ( x + σ ϵ , σ ) + ϵ σ ∥ 2 J_{\text{DSM}}(\theta; \sigma) = \frac{1}{2} \mathbb{E}{x \sim p{\text{data}}} \mathbb{E}{\epsilon \sim \mathcal{N}(0,I)} \left\| s\theta(x + \sigma \epsilon, \sigma) + \frac{\epsilon}{\sigma} \right\|^2 JDSM(θ;σ)=21Ex∼pdataEϵ∼N(0,I) sθ(x+σϵ,σ)+σϵ 2

8. 多尺度噪声训练

8.1 为什么需要多个噪声尺度?

  • 大 σ \sigma σ:覆盖范围广,帮助探索低概率区域,但估计粗糙
  • 小 σ \sigma σ:估计精确,接近真实分布,但只覆盖数据附近区域

8.2 噪声尺度序列

使用递减的噪声尺度: σ 1 > σ 2 > ⋯ > σ L \sigma_1 > \sigma_2 > \cdots > \sigma_L σ1>σ2>⋯>σL

8.3 最终目标函数

L ( θ ) = 1 L ∑ i = 1 L λ ( σ i ) J DSM ( θ ; σ i ) L(\theta) = \frac{1}{L} \sum_{i=1}^L \lambda(\sigma_i) J_{\text{DSM}}(\theta; \sigma_i) L(θ)=L1i=1∑Lλ(σi)JDSM(θ;σi)

其中权重 λ ( σ i ) = σ i 2 \lambda(\sigma_i) = \sigma_i^2 λ(σi)=σi2 用于平衡不同尺度。

9. 训练伪代码

python 复制代码
def train_step(batch_size):
    # 1. 从真实数据采样
    x_clean = sample_data(batch_size)  # 来自 p_data
    
    # 2. 随机选择噪声尺度
    sigma = random.choice([σ1, σ2, ..., σL])
    
    # 3. 添加噪声
    epsilon = torch.randn_like(x_clean)
    x_noisy = x_clean + sigma * epsilon
    
    # 4. 计算目标分数(已知的真值)
    target_score = -epsilon / sigma  # = -(x_noisy - x_clean)/σ²
    
    # 5. 网络预测
    predicted_score = score_network(x_noisy, sigma)
    
    # 6. 计算损失(加权)
    loss = torch.mean((predicted_score - target_score)**2) * sigma**2
    
    # 7. 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

10. 生成样本:退火朗之万动力学

10.1 朗之万动力学更新

有了训练好的分数网络 s θ ( x , σ ) s_\theta(x, \sigma) sθ(x,σ),可以从随机噪声开始生成样本:
x t + 1 = x t + η ⋅ s θ ( x t , σ ) + 2 η ⋅ z t , z t ∼ N ( 0 , I ) x_{t+1} = x_t + \eta \cdot s_\theta(x_t, \sigma) + \sqrt{2\eta} \cdot z_t, \quad z_t \sim \mathcal{N}(0, I) xt+1=xt+η⋅sθ(xt,σ)+2η ⋅zt,zt∼N(0,I)

10.2 退火策略

  1. 从最大的噪声尺度 σ 1 \sigma_1 σ1 开始
  2. 执行若干步朗之万更新
  3. 切换到下一个更小的噪声尺度 σ 2 \sigma_2 σ2
  4. 重复直到最小的噪声尺度 σ L \sigma_L σL
python 复制代码
def annealed_langevin_dynamics(score_net, sigmas):
    x = torch.randn(shape)  # 从噪声开始
    
    for sigma in sigmas:  # 从大到小
        step_size = 0.00002 * (sigma / sigmas[-1])**2
        
        for t in range(num_steps):
            # 获取当前噪声尺度下的分数
            score = score_net(x, sigma)
            
            # 随机噪声
            noise = torch.randn_like(x)
            
            # 朗之万更新
            x = x + step_size * score + np.sqrt(2*step_size) * noise
    
    return x  # 生成的样本

总结:分数匹配的逻辑链

  1. 发现问题 :传统最大似然估计需要计算不可处理的配分函数 Z θ Z_\theta Zθ
  2. 转换思路 :不建模绝对概率 p ( x ) p(x) p(x),转而建模其梯度 ∇ x log ⁡ p ( x ) \nabla_x \log p(x) ∇xlogp(x)(分数)
  3. 数学奇迹 :分数的计算中,配分函数 Z θ Z_\theta Zθ 自动消去
  4. 训练挑战 :不知道真实分数 s data ( x ) s_{\text{data}}(x) sdata(x)
  5. 第一次突破 :通过分部积分,得到不依赖 s data ( x ) s_{\text{data}}(x) sdata(x) 的目标函数(但需计算散度)
  6. 计算难题 :计算散度 ∇ x ⋅ s θ ( x ) \nabla_x \cdot s_\theta(x) ∇x⋅sθ(x) 需要二阶导数,计算量过大
  7. 第二次突破:去噪分数匹配------添加已知的高斯噪声,利用条件分布的已知分数作为监督信号
  8. 实践技巧:使用多尺度噪声训练,平衡覆盖范围与估计精度
  9. 生成过程:通过退火朗之万动力学,从噪声中逐步"雕刻"出高质量样本

最终哲学 :放弃计算"这个东西有多好"(绝对概率),转而学习"如何让这个东西变得更好"(梯度方向)。这种转变巧妙地绕过了生成模型的最大障碍,成为现代扩散模型的理论基础。

参考:扩散模型 | 1.Score-based model精讲

相关推荐
声声codeGrandMaster2 小时前
AI之模型提升
人工智能·pytorch·python·算法·ai
黄金小码农2 小时前
工具坐标系
算法
小南家的青蛙3 小时前
LeetCode第1261题 - 在受污染的二叉树中查找元素
算法·leetcode·职场和发展
君义_noip3 小时前
信息学奥赛一本通 1453:移动玩具 | 洛谷 P4289 [HAOI2008] 移动玩具
c++·算法·信息学奥赛·csp-s
玖剹3 小时前
记忆化搜索题目(二)
c语言·c++·算法·leetcode·深度优先·剪枝·深度优先遍历
Xy-unu4 小时前
[LLM]AIM: Adaptive Inference of Multi-Modal LLMs via Token Merging and Pruning
论文阅读·人工智能·算法·机器学习·transformer·论文笔记·剪枝
Hcoco_me4 小时前
算法选型 + 调参避坑指南
算法
Jul1en_4 小时前
【算法】分治-归并类题目
java·算法·leetcode·排序算法
kangk124 小时前
统计学基础之概率(生物信息方向)
人工智能·算法·机器学习