接上篇:带噪学习 | Ambient Diffusion (NeurIPS 2023) 上篇
采样过程
上述介绍的是训练损失的推导,而接下来开始考虑推理过程中的采样操作。
扩散模型的核心是通过逆转扩散过程采样分布,需依赖分数函数 ∇ x t log p t ( x t ) \nabla_{x_t} \log p_t(x_t) ∇xtlogpt(xt)。
通过 Tweedie 公式,分数函数可表示为
∇ x t log p t ( x t ) = E [ x 0 ∣ x t ] − x t σ t \nabla_{x_t} \log p_t(x_t) = \frac{\mathbb{E}[x_0 | x_t] - x_t}{\sigma_t} ∇xtlogpt(xt)=σtE[x0∣xt]−xt
Tweedie公式
这里需要从扩散模型前向过程的递推关系出发,结合条件分布的概率密度 和分数函数的定义 逐步推导。Ambient Diffusion采用的是简化的线性高斯前向过程 (连续/离散时间均可,这里以离散时间为例)。从干净图像 x 0 x_0 x0 开始,逐步向图像中添加高斯噪声 ,得到 t t t 时刻的含噪图像 x t x_t xt,递推关系为: x t = x t − 1 + σ t η t , η t ∼ N ( 0 , I ) x_t = x_{t-1} + \sigma_t \eta_t, \eta_t \sim \mathcal{N}(0, I) xt=xt−1+σtηt,ηt∼N(0,I)。其中 x t − 1 x_{t-1} xt−1表示 t − 1 t-1 t−1时刻的含噪图像; σ t \sigma_t σt表示 t t t时刻的噪声标准差( t t t 越大, σ t \sigma_t σt 越大,图像噪声越强); η t \eta_t ηt表示独立的标准高斯噪声(与 x t − 1 x_{t-1} xt−1无关)。
通过递推式累加,可直接得到 x t = x 0 + ∑ k = 1 t σ k η k x_t = x_0 + \sum_{k=1}^t \sigma_k \eta_k xt=x0+∑k=1tσkηk。由于高斯噪声的线性组合仍为高斯噪声,则递推关系可简化为更紧凑的形式
x t = x 0 + σ t η , η ∼ N ( 0 , I ) x_t = x_0 + \sigma_t \eta, \eta \sim \mathcal{N}(0, I) xt=x0+σtη,η∼N(0,I)这就是Ambient Diffusion中前向过程的核心关系, t t t时刻的含噪图像是干净图像+当前尺度的高斯噪声。
由前向过程的关系 x t = x 0 + σ t η x_t = x_0 + \sigma_t \eta xt=x0+σtη,给定 x 0 x_0 x0 时, x t x_t xt 服从均值为 x 0 x_0 x0、方差为 σ t 2 I \sigma_t^2 I σt2I 的高斯分布,即
p ( x t ∣ x 0 ) = N ( x t ∣ x 0 , σ t 2 I ) p(x_t \mid x_0) = \mathcal{N}(x_t \mid x_0, \sigma_t^2 I) p(xt∣x0)=N(xt∣x0,σt2I)其概率密度函数为
p ( x t ∣ x 0 ) = 1 ( 2 π σ t 2 ) n / 2 exp ( − 1 2 σ t 2 ∥ x t − x 0 ∥ 2 ) p(x_t \mid x_0) = \frac{1}{(2\pi \sigma_t^2)^{n/2}} \exp\left( -\frac{1}{2\sigma_t^2} \| x_t - x_0 \|^2 \right) p(xt∣x0)=(2πσt2)n/21exp(−2σt21∥xt−x0∥2)其中 n n n 是向量 x t x_t xt 的维度。
分数函数 ∇ x t log p t ( x t ) \nabla_{x_t} \log p_t(x_t) ∇xtlogpt(xt)的核心,是边缘分布 p t ( x t ) p_t(x_t) pt(xt) 的对数梯度,但需先计算条件分布 p ( x t ∣ x 0 ) p(x_t \mid x_0) p(xt∣x0) 的对数梯度。
对 log p ( x t ∣ x 0 ) \log p(x_t \mid x_0) logp(xt∣x0) 关于 x t x_t xt 求梯度。整体的计算过程如下:
log p ( x t ∣ x 0 ) = − n 2 log ( 2 π σ t 2 ) − 1 2 σ t 2 ∥ x t − x 0 ∥ 2 ∇ x t log p ( x t ∣ x 0 ) = − 1 2 σ t 2 ⋅ 2 ( x t − x 0 ) = x 0 − x t σ t 2 ∇ x t log p t ( x t ) = 1 p t ( x t ) ∇ x t p t ( x t ) ∇ x t p t ( x t ) = ∇ x t ∫ p ( x t ∣ x 0 ) p 0 ( x 0 ) d x 0 边缘分布 p t ( x t ) 是对所有 x 0 积分后的分布 = ∫ ∇ x t p ( x t ∣ x 0 ) p 0 ( x 0 ) d x 0 = ∫ p ( x t ∣ x 0 ) ⋅ ∇ x t log p ( x t ∣ x 0 ) ⋅ p 0 ( x 0 ) d x 0 ∇ x t log p t ( x t ) = 1 p t ( x t ) ∇ x t p t ( x t ) = 1 p t ( x t ) ∫ p ( x t ∣ x 0 ) ⋅ x 0 − x t σ t 2 ⋅ p 0 ( x 0 ) d x 0 带入 ∇ x t log p ( x t ∣ x 0 ) = x 0 − x t σ t 2 = 1 σ t 2 ⋅ 1 p t ( x t ) ∫ ( x 0 − x t ) p ( x t ∣ x 0 ) p 0 ( x 0 ) d x 0 = 1 σ t 2 ⋅ 1 p t ( x t ) ∫ x 0 p ( x t ∣ x 0 ) p 0 ( x 0 ) d x 0 − x t σ t 2 = 1 σ t 2 ⋅ ∫ x 0 p ( x t ∣ x 0 ) p 0 ( x 0 ) p t ( x t ) d x 0 − x t σ t 2 = 1 σ t 2 ⋅ ∫ x 0 p ( x 0 , x t ) p t ( x t ) d x 0 − x t σ t 2 = 1 σ t 2 ⋅ ∫ x 0 p ( x 0 ∣ x t ) d x 0 − x t σ t 2 = E [ x 0 ∣ x t ] σ t 2 − x t σ t 2 = E [ x 0 ∣ x t ] − x t σ t 2 \begin{aligned} \log p(x_t \mid x_0) & = -\frac{n}{2} \log(2\pi \sigma_t^2) - \frac{1}{2\sigma_t^2} \| x_t - x_0 \|^2 \\ \nabla_{x_t} \log p(x_t \mid x_0) & = -\frac{1}{2\sigma_t^2} \cdot 2(x_t - x_0) = \frac{x_0 - x_t}{\sigma_t^2} \nabla_{x_t} \log p_t(x_t) & = \frac{1}{p_t(x_t)} \nabla_{x_t} p_t(x_t) \\ \nabla_{x_t} p_t(x_t) &= \nabla_{x_t} \int p(x_t \mid x_0) p_0(x_0) dx_0 & \text{边缘分布 p_t(x_t) 是对所有 x_0 积分后的分布} \\ &= \int \nabla_{x_t} p(x_t \mid x_0) p_0(x_0) dx_0 \\ &= \int p(x_t \mid x_0) \cdot \nabla_{x_t} \log p(x_t \mid x_0) \cdot p_0(x_0) dx_0 \\ \nabla_{x_t} \log p_t(x_t) & = \frac{1}{p_t(x_t)} \nabla_{x_t} p_t(x_t) \\ & = \frac{1}{p_t(x_t)} \int p(x_t \mid x_0) \cdot \frac{x_0 - x_t}{\sigma_t^2} \cdot p_0(x_0) dx_0 & \text{带入} \nabla_{x_t} \log p(x_t \mid x_0) = \frac{x_0 - x_t}{\sigma_t^2} \\ &= \frac{1}{\sigma_t^2} \cdot \frac{1}{p_t(x_t)} \int (x_0 - x_t) p(x_t \mid x_0) p_0(x_0) dx_0 \\ &= \frac{1}{\sigma_t^2} \cdot \frac{1}{p_t(x_t)} \int x_0 p(x_t \mid x_0) p_0(x_0) dx_0 - \frac{x_t}{\sigma_t^2} \\ & = \frac{1}{\sigma_t^2} \cdot \int x_0 \frac{ p(x_t \mid x_0) p_0(x_0) }{p_t(x_t)} dx_0 - \frac{x_t}{\sigma_t^2} \\ & = \frac{1}{\sigma_t^2} \cdot \int x_0 \frac{ p(x_0, x_t)}{p_t(x_t)} dx_0 - \frac{x_t}{\sigma_t^2} \\ & = \frac{1}{\sigma_t^2} \cdot \int x_0 p(x_0 \mid x_t) dx_0 - \frac{x_t}{\sigma_t^2} \\ & = \frac{\mathbb{E}[x_0 \mid x_t]}{\sigma_t^2} - \frac{x_t}{\sigma_t^2} \\ & = \frac{\mathbb{E}[x_0 \mid x_t] - x_t}{\sigma_t^2} \end{aligned} logp(xt∣x0)∇xtlogp(xt∣x0)∇xtpt(xt)∇xtlogpt(xt)=−2nlog(2πσt2)−2σt21∥xt−x0∥2=−2σt21⋅2(xt−x0)=σt2x0−xt∇xtlogpt(xt)=∇xt∫p(xt∣x0)p0(x0)dx0=∫∇xtp(xt∣x0)p0(x0)dx0=∫p(xt∣x0)⋅∇xtlogp(xt∣x0)⋅p0(x0)dx0=pt(xt)1∇xtpt(xt)=pt(xt)1∫p(xt∣x0)⋅σt2x0−xt⋅p0(x0)dx0=σt21⋅pt(xt)1∫(x0−xt)p(xt∣x0)p0(x0)dx0=σt21⋅pt(xt)1∫x0p(xt∣x0)p0(x0)dx0−σt2xt=σt21⋅∫x0pt(xt)p(xt∣x0)p0(x0)dx0−σt2xt=σt21⋅∫x0pt(xt)p(x0,xt)dx0−σt2xt=σt21⋅∫x0p(x0∣xt)dx0−σt2xt=σt2E[x0∣xt]−σt2xt=σt2E[x0∣xt]−xt=pt(xt)1∇xtpt(xt)边缘分布 pt(xt) 是对所有 x0 积分后的分布带入∇xtlogp(xt∣x0)=σt2x0−xt
这里需要注意的是,原文中的符号是 E [ x 0 ∣ x t ] − x t σ t \frac{\mathbb{E}[x_0 \mid x_t] - x_t}{\sigma_t} σtE[x0∣xt]−xt,这是噪声尺度的符号简化,在一些扩散模型的设定中,会将前向过程的噪声标准差定义为 σ t \sqrt{\sigma_t} σt (而非 σ t \sigma_t σt),此时条件分布的方差变为 σ t \sigma_t σt,梯度推导后形式会简化为 n a b l a x t log p t ( x t ) = E [ x 0 ∣ x t ] − x t σ t nabla_{x_t} \log p_t(x_t) = \frac{\mathbb{E}[x_0 \mid x_t] - x_t}{\sigma_t} nablaxtlogpt(xt)=σtE[x0∣xt]−xt。这本质是噪声尺度的参数定义差异 ,核心关系不变,即分数函数是 x t x_t xt 到 x 0 x_0 x0 的条件期望与 x t x_t xt 的差值,除以当前时刻的噪声尺度。
通过前面的损失设计,模型最终预测趋向于 E [ x 0 ∣ A ~ x t , A ~ ] \mathbb{E}[x_0 | \tilde{A}x_t, \tilde{A}] E[x0∣A~xt,A~],本文用它代替常规基于得分的扩散模型中使用的目标 E [ x 0 ∣ x t ] \mathbb{E}[x_0 | x_t] E[x0∣xt]。
这里为什么用 E [ x 0 ∣ A ~ x t , A ~ ] \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] E[x0∣A~xt,A~] 替代 E [ x 0 ∣ x t ] \mathbb{E}[x_0 \mid x_t] E[x0∣xt]?
Ambient Diffusion的核心训练目标是让模型学到 E [ x 0 ∣ A ~ x t , A ~ ] \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] E[x0∣A~xt,A~](干净数据关于额外损坏含噪图像 A ~ x t \tilde{A}x_t A~xt和额外损坏矩阵 A ~ \tilde{A} A~的条件期望)。
模型学到的期望包含 x 0 x_0 x0的全局分布信息,Ambient Diffusion通过额外损坏制造像素不确定性,强制模型学习 x 0 x_0 x0的全图统计特征,而非局部损坏像素的重构。因此 E [ x 0 ∣ A ~ x t , A ~ ] \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] E[x0∣A~xt,A~]是 x 0 x_0 x0的全局条件期望,与 E [ x 0 ∣ x t ] \mathbb{E}[x_0 \mid x_t] E[x0∣xt](关于干净含噪图像的期望)的核心信息一致,二者都反映了 x 0 x_0 x0的分布特征。
模型在训练中已学会从 A ~ x t \tilde{A}x_t A~xt和 A ~ \tilde{A} A~中恢复改进样本的全局期望,因此可以用这个期望近似从干净含噪图像 x t x_t xt到 x 0 x_0 x0的期望。
扩散模型的采样是逆转前向过程,采样公式的设计基于反向过程进行推导可以得到。根据 Song et al. (ICLR 2021) 的理论,对于前向过程 x t = x 0 + σ t ϵ x_t = x_0 + \sigma_t \epsilon xt=x0+σtϵ,其对应的反向确定性概率流 ODE 形式如下:
d x t d t = − 1 2 d ( σ t 2 ) d t ⏟ 方差变化率 ⋅ ∇ x t log p t ( x t ) ⏟ 得分函数 \frac{d x_t}{d t} = - \frac{1}{2} \underbrace{\frac{d(\sigma_t^2)}{dt}}{\text{方差变化率}} \cdot \underbrace{\nabla{x_t} \log p_t(x_t)}_{\text{得分函数}} dtdxt=−21方差变化率 dtd(σt2)⋅得分函数 ∇xtlogpt(xt)
这个方程描述了像素点 x t x_t xt 随时间 t t t 变化的速度。速度的大小和方向由"噪声的变化快慢"和"概率密度的梯度"共同决定。利用 Tweedie 公式,将抽象的"得分函数"替换为具象的"期望"( ∇ x t log p t ( x t ) = x ^ 0 − x t σ t 2 \nabla_{x_t} \log p_t(x_t) = \frac{\hat{x}_0 - x_t}{\sigma_t^2} ∇xtlogpt(xt)=σt2x^0−xt),其中 x ^ 0 \hat{x}_0 x^0 是当前时刻模型预测的 E [ x 0 ∣ ... ] \mathbb{E}[x_0 \mid \dots] E[x0∣...]:
d x t d t = − 1 2 d ( σ t 2 ) d t ( x ^ 0 − x t σ t 2 ) = − 1 2 ( 2 σ t d σ t d t ) x ^ 0 − x t σ t 2 = − 1 σ t d σ t d t ( x ^ 0 − x t ) = 1 σ t d σ t d t ( x t − x ^ 0 ) = d ( ln σ t ) d t ( x t − x ^ 0 ) \begin{aligned} \frac{d x_t}{d t} & = - \frac{1}{2} \frac{d(\sigma_t^2)}{dt} \left( \frac{\hat{x}_0 - x_t}{\sigma_t^2} \right) \\ &= - \frac{1}{2} (2\sigma_t \frac{d \sigma_t}{dt}) \frac{\hat{x}_0 - x_t}{\sigma_t^2} \\ &= - \frac{1}{\sigma_t} \frac{d \sigma_t}{dt} (\hat{x}_0 - x_t) \\ &= \frac{1}{\sigma_t} \frac{d \sigma_t}{dt} ( x_t - \hat{x}_0 ) \\ & = \frac{d (\ln \sigma_t)}{dt} (x_t - \hat{x}_0) \\ \end{aligned} dtdxt=−21dtd(σt2)(σt2x^0−xt)=−21(2σtdtdσt)σt2x^0−xt=−σt1dtdσt(x^0−xt)=σt1dtdσt(xt−x^0)=dtd(lnσt)(xt−x^0)
开始计算从当前时刻 t t t 走到下一时刻 s s s(即 t − Δ t t-\Delta t t−Δt)后的状态 x s x_s xs。将方程变形,把含 x x x 的项移到左边,含 t t t(即 σ \sigma σ)的项移到右边。
d x t x t − x ^ 0 = d σ t σ t \frac{d x_t}{x_t - \hat{x}_0} = \frac{d \sigma_t}{\sigma_t} xt−x^0dxt=σtdσt
假设在极短的时间步内,模型预测的 x ^ 0 \hat{x}0 x^0 保持不变(这是一个常数)。对两边进行定积分,积分区间从 t t t 到 s s s:
∫ x t x s 1 x − x ^ 0 d x = ∫ σ t σ s 1 σ d σ \int{x_t}^{x_s} \frac{1}{x - \hat{x}0} dx = \int{\sigma_t}^{\sigma_s} \frac{1}{\sigma} d\sigma ∫xtxsx−x^01dx=∫σtσsσ1dσ
- 计算积分结果: ln ( x − x ^ 0 ) ∣ x t x s = ln ( σ ) ∣ σ t σ s \ln(x - \hat{x}0) \Big|{x_t}^{x_s} = \ln(\sigma) \Big|_{\sigma_t}^{\sigma_s} ln(x−x^0) xtxs=ln(σ) σtσs
- 代入上下限: ln ( x s − x ^ 0 ) − ln ( x t − x ^ 0 ) = ln ( σ s ) − ln ( σ t ) \ln(x_s - \hat{x}_0) - \ln(x_t - \hat{x}_0) = \ln(\sigma_s) - \ln(\sigma_t) ln(xs−x^0)−ln(xt−x^0)=ln(σs)−ln(σt)
- 利用对数性质合并: ln ( x s − x ^ 0 x t − x ^ 0 ) = ln ( σ s σ t ) \ln \left( \frac{x_s - \hat{x}_0}{x_t - \hat{x}_0} \right) = \ln \left( \frac{\sigma_s}{\sigma_t} \right) ln(xt−x^0xs−x^0)=ln(σtσs)
- 去对数(Exponentiate): x s − x ^ 0 x t − x ^ 0 = σ s σ t \frac{x_s - \hat{x}_0}{x_t - \hat{x}_0} = \frac{\sigma_s}{\sigma_t} xt−x^0xs−x^0=σtσs,这里可以看到,图像当前状态与目标状态( x ^ 0 \hat{x}_0 x^0)的距离,是严格按照噪声标准差 σ \sigma σ 的比例进行缩放的。如果 σ \sigma σ 减半,那么图像离目标的距离也减半。
- 乘过去: x s − x ^ 0 = σ s σ t ( x t − x ^ 0 ) x_s - \hat{x}_0 = \frac{\sigma_s}{\sigma_t} (x_t - \hat{x}_0) xs−x^0=σtσs(xt−x^0)
- 移项: x s = x ^ 0 + σ s σ t x t − σ s σ t x ^ 0 x_s = \hat{x}_0 + \frac{\sigma_s}{\sigma_t} x_t - \frac{\sigma_s}{\sigma_t} \hat{x}_0 xs=x^0+σtσsxt−σtσsx^0
- 合并同类项( x ^ 0 \hat{x}_0 x^0 的系数): x s = σ s σ t x t + ( 1 − σ s σ t ) x ^ 0 x_s = \frac{\sigma_s}{\sigma_t} x_t + \left( 1 - \frac{\sigma_s}{\sigma_t} \right) \hat{x}_0 xs=σtσsxt+(1−σtσs)x^0
- 通分: 1 − σ s σ t = σ t − σ s σ t 1 - \frac{\sigma_s}{\sigma_t} = \frac{\sigma_t - \sigma_s}{\sigma_t} 1−σtσs=σtσt−σs
- 最终结果: x s = σ s σ t x t + σ t − σ s σ t x ^ 0 x_s = \frac{\sigma_s}{\sigma_t} x_t + \frac{\sigma_t - \sigma_s}{\sigma_t} \hat{x}_0 xs=σtσsxt+σtσt−σsx^0
将 x ^ 0 \hat{x}0 x^0 替换回 E [ x 0 ∣ A ~ x t , A ~ ] \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] E[x0∣A~xt,A~],这正是我们想要的采样公式:
x t − Δ t = σ t − Δ t σ t x t + σ t − σ t − Δ t σ t E [ x 0 ∣ A ~ x t , A ~ ] x{t-\Delta t} = \frac{\sigma_{t-\Delta t}}{\sigma_t} x_t + \frac{\sigma_t - \sigma_{t-\Delta t}}{\sigma_t} \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] xt−Δt=σtσt−Δtxt+σtσt−σt−ΔtE[x0∣A~xt,A~]
上述推导得到的公式被称为固定掩码采样器(Fixed Mask Sampler)。虽然理论上可行,但在实际应用中,尤其是在高噪声( t → 0 t \to 0 t→0)或遮挡严重的情况下,该基础采样器存在两个核心问题:
- 平均化效应(Averaging Effects):由于遮挡严重,模型对 E [ x 0 ] \mathbb{E}[x_0] E[x0] 的预测本质上是所有可能结果的加权平均。这导致生成的图像在被遮挡区域出现明显的模糊,丢失高频细节。
- 不一致性(Inconsistency):模型从未见过某些被遮挡的像素。如果我们换一个不同的掩码 A ~ ′ \tilde{A}' A~′ 去评估同一个 x t x_t xt,模型预测出的去噪图像 E [ x 0 ∣ A ~ ′ x t , A ~ ′ ] \mathbb{E}[x_0 \mid \tilde{A}'x_t, \tilde{A}'] E[x0∣A~′xt,A~′] 可能会与使用原始掩码 A ~ \tilde{A} A~ 的预测完全不同。这种对掩码选择的敏感性是不理想的。
为了解决这些问题,论文提出了重构引导(Reconstruction Guidance)。其核心思想是引入一个额外的梯度项,强迫模型生成的图像在不同掩码下的预测保持一致。在采样更新步骤中加入以下引导项:
− w t ∇ x t E A ~ ′ ∥ E [ x 0 ∣ A ~ x t , A ~ ] − E [ x 0 ∣ A ~ ′ x t , A ~ ′ ] ∥ 2 - w_t \nabla_{x_t} \mathbb{E}_{\tilde{A}'} \left\| \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] - \mathbb{E}[x_0 \mid \tilde{A}'x_t, \tilde{A}'] \right\|^2 −wt∇xtEA~′ E[x0∣A~xt,A~]−E[x0∣A~′xt,A~′] 2
其中 w t w_t wt 是随时间变化的权重系数。这个项通过最小化不同掩码下预测结果的差异(方差),引导采样轨迹走向那些"无论怎么遮挡,看起来都一致"的清晰图像区域,从而显著提升了生成质量并抑制了模糊。
最终的修正采样公式变为:
x t − Δ t = σ t − Δ t σ t x t + σ t − σ t − Δ t σ t E [ x 0 ∣ A ~ x t , A ~ ] ⏟ 基础采样项 − w t ∇ x t E A ~ ′ ∥ E [ x 0 ∣ A ~ x t , A ~ ] − E [ x 0 ∣ A ~ ′ x t , A ~ ′ ] ∥ 2 x_{t-\Delta t} = \underbrace{\frac{\sigma_{t-\Delta t}}{\sigma_t} x_t + \frac{\sigma_t - \sigma_{t-\Delta t}}{\sigma_t} \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}]}{\text{基础采样项}} - w_t \nabla{x_t} \mathbb{E}_{\tilde{A}'} \left\| \mathbb{E}[x_0 \mid \tilde{A}x_t, \tilde{A}] - \mathbb{E}[x_0 \mid \tilde{A}'x_t, \tilde{A}'] \right\|^2 xt−Δt=基础采样项 σtσt−Δtxt+σtσt−σt−ΔtE[x0∣A~xt,A~]−wt∇xtEA~′ E[x0∣A~xt,A~]−E[x0∣A~′xt,A~′] 2
通常我们在训练时,是固定输入 x x x,对模型参数 θ \theta θ 求导(loss.backward() 更新权重)。而在推理时的这个引导(Guidance)过程中,我们是固定模型参数 θ \theta θ,对输入图像 x t x_t xt 求导。简单来说,这就像是在生成过程的每一步里,做了一次微小的"训练"反向传播。
扩展讨论:隐私保护
基础生成图像的需求之外,本文还关注了隐私需求:即生成模型不能记忆训练样本(如医疗影像不能泄露患者隐私)。实际上,Ambient Diffusion的方案天然满足这一点:
- 模型训练时从未接触干净样本,仅用高度损坏的样本;
- 额外损坏制造的像素不确定性,让模型无法锁定训练样本的具体像素;
通过实验验证,生成样本与训练集的相似度显著降低,无近复制样本。
总结
从用损坏数据学干净分布的基础需求出发,每一步都是为了填补传统方法的缺陷:
基础需求:损坏数据→干净分布 现有方案缺陷:只拟合可见像素 解决思路:额外损坏制造像素不确定性 损失函数:用 A 约束可见+额外损坏像素 数学验证:解是干净数据的条件期望 采样方案:用条件期望近似分数函数 延伸价值:抑制记忆
这个过程中,额外损坏和基于 A A A 的损失之所以不直接,是因为直接用损坏数据训练传统模型的路走不通,必须通过制造不确定性→设计针对性损失的间接方式,强制模型学习全局特征。