给定一个来自数据分布的样本 x 0 ∼ q ( x 0 ) x_0 \sim q(x_0) x0∼q(x0),我们通过逐步向样本中添加高斯噪声来产生一个由潜变量 x 1 , . . . , x T x_1, ..., x_T x1,...,xT组成的马尔科夫链:
q ( x t ∣ x t − 1 ) : = N ( x t ; α t x t − 1 , ( 1 − α t ) I ) q(x_t \mid x_{t-1}) := \mathcal{N}(x_t; \sqrt{\alpha_t} x_{t-1}, (1-\alpha_t)\mathcal{I}) q(xt∣xt−1):=N(xt;αt xt−1,(1−αt)I)
参数解释:
x t x_t xt:扩散过程中的潜变量(在第 t t t步的样本)。这是在步骤 t t t时的样本,逐步加入噪声。
x t − 1 x_{t-1} xt−1:上一时刻的潜变量。
α t \alpha_t αt:控制噪声添加幅度的系数,它在每个时间步骤中控制噪声的大小,通常 α t \alpha_t αt 是随时间递减的,意味着初期的样本保留更多信息,而到后期则加入更多噪声。
N ( x t ; μ , Σ ) \mathcal{N}(x_t; \mu, \Sigma) N(xt;μ,Σ):表示高斯分布,其中 μ \mu μ 和 Σ \Sigma Σ 是均值和协方差。对于每个步骤,这里给出的高斯分布的均值为 α t x t − 1 \sqrt{\alpha_t} x_{t-1} αt xt−1,协方差为 ( 1 − α t ) I (1 - \alpha_t) \mathcal{I} (1−αt)I,即一个逐渐增加的噪声项。
如果每一步加入的噪声幅度 1 − α t 1 - \alpha_t 1−αt足够小,那么后验 q ( x t − 1 ∣ x t ) q(x_{t-1} \mid x_t) q(xt−1∣xt)可以被对角高斯很好地近似。进一步地,如果在整个链中添加的噪声幅度 1 − α 1 , . . . , α T 1 - \alpha_1, ..., \alpha_T 1−α1,...,αT足够大,那么 x T x_T xT可以被 N ( 0 , I ) \mathcal{N}(0, \mathcal{I}) N(0,I)很好地近似。这些性质表明可以学习一个模型 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1} \mid x_t) pθ(xt−1∣xt)来逼近真实后验:
p θ ( x t − 1 ∣ x t ) : = N ( μ θ ( x t ) , Σ θ ( x t ) ) p_\theta(x_{t-1} \mid x_t) := \mathcal{N}(\mu_\theta(x_t), \Sigma_\theta(x_t)) pθ(xt−1∣xt):=N(μθ(xt),Σθ(xt))
参数解释:
p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1} \mid x_t) pθ(xt−1∣xt):由模型学习的后验分布,用于根据当前的噪声样本 x t x_t xt 预测上一时刻的潜变量 x t − 1 x_{t-1} xt−1。
μ θ ( x t ) \mu_\theta(x_t) μθ(xt):该模型的预测均值,通常是通过神经网络预测的结果,它是恢复原始样本的核心。
Σ θ ( x t ) \Sigma_\theta(x_t) Σθ(xt):该模型的预测协方差,通常在训练中被固定为常数,表示噪声的分布。后来的一些工作通过学习 Σ θ \Sigma_\theta Σθ 进一步改进了模型。
其可以被用来产生样本 x 0 ∼ p θ ( x 0 ) x_0 \sim p_\theta(x_0) x0∼pθ(x0),通过从高斯噪声 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, \mathcal{I}) xT∼N(0,I)开始并以一系列步骤 x T − 1 , x T − 2 , . . . , x 0 x_{T-1}, x_{T-2}, ..., x_0 xT−1,xT−2,...,x0逐渐减少噪声。
虽然 log p θ ( x 0 ) \log p_\theta(x_0) logpθ(x0)存在一个易处理的变分下界(VLB),但通过优化一个重新加权VLB项的代理目标可以得到更好的结果。为了计算这个代理目标,我们通过对 x 0 x_0 x0施加高斯噪声 ϵ \epsilon ϵ来生成样本 x t ∼ q ( x t ∣ x 0 ) x_t \sim q(x_t \mid x_0) xt∼q(xt∣x0),然后使用标准均方误差损失训练一个模型 ϵ θ \epsilon_\theta ϵθ来预测添加的噪声:
L simple : = E t ∼ [ 1 , T ] , x 0 ∼ q ( x 0 ) , ϵ ∼ N ( 0 , I ) [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] L_{\text{simple}} := \mathbb{E}{t \sim [1, T], x_0 \sim q(x_0), \epsilon \sim \mathcal{N}(0, \mathcal{I})}[\|\epsilon - \epsilon\theta(x_t, t)\|^2] Lsimple:=Et∼[1,T],x0∼q(x0),ϵ∼N(0,I)[∥ϵ−ϵθ(xt,t)∥2]
参数解释:
t ∼ [ 1 , T ] t \sim [1, T] t∼[1,T]:在训练中,时间步 t t t 是随机选择的,表示在扩散过程中采样的每个时间步。
x 0 ∼ q ( x 0 ) x_0 \sim q(x_0) x0∼q(x0):从数据分布 q ( x 0 ) q(x_0) q(x0) 中采样原始样本 x 0 x_0 x0,通常是从真实数据集中采样。
ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, \mathcal{I}) ϵ∼N(0,I):标准高斯噪声,表示向 x 0 x_0 x0 添加的噪声。
ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t):由模型预测的噪声。目标是使模型能够精确地预测出添加到 x t x_t xt 中的噪声 ϵ \epsilon ϵ。
Ho等人展示了如何从 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t)推导出 μ θ ( x t ) \mu_\theta(x_t) μθ(xt),并将 Σ θ \Sigma_\theta Σθ固定为常数。他们还证明了其与之前基于分数匹配的去噪模型的等价性,其中分数函数为 ∇ x t log p ( x t ) ∝ ϵ θ ( x t , t ) \nabla_{x_t} \log p(x_t) \propto \epsilon_\theta(x_t, t) ∇xtlogp(xt)∝ϵθ(xt,t)。在后续工作中,Nichol和Dhariwal提出了一种学习 Σ θ \Sigma_\theta Σθ的策略,使得模型能够以更少的扩散步骤产生高质量的样本。我们在训练本文的模型时采用了这种技术。
扩散模型还成功应用到了图像超分辨率。根据扩散的标准公式,高分辨率图像 y 0 y_0 y0在一系列步骤中被逐步地施加噪声。然而, p θ ( y t − 1 ∣ y t , x ) p_\theta(y_{t-1} \mid y_t, x) pθ(yt−1∣yt,x)额外地以下采样输入 x x x为条件,通过沿通道维度拼接 x x x(双三次上采样后)提供给模型。这些模型的结果在FID、IS和人工比较的分数上都优于以前的方法。
引导扩散(Guided Diffusion)
Dhariwal和Nicho在论文:Diffusion Models Beat GANs on Image Synthesis中,发现了类别条件扩散模型的样本往往可以使用分类器引导来改善,其中,均值为 μ θ ( x t ∣ y ) \mu_\theta(x_t \mid y) μθ(xt∣y),方差为 Σ θ ( x t ∣ y ) \Sigma_\theta(x_t \mid y) Σθ(xt∣y)的类别条件扩散模型由分类器预测的目标类别 y y y的对数概率 log p ϕ ( y ∣ x t ) \log p_\phi(y \mid x_t) logpϕ(y∣xt)附加地扰动。得到的新的扰动后的均值 μ ^ θ ( x t ∣ y ) \hat{\mu}\theta(x_t \mid y) μ^θ(xt∣y)为:
μ ^ θ ( x t ∣ y ) = μ θ ( x t ∣ y ) + s ⋅ Σ θ ( x t ∣ y ) ∇ x t log p ϕ ( y ∣ x t ) \hat{\mu}\theta(x_t \mid y) = \mu_\theta(x_t \mid y) + s \cdot \Sigma_\theta(x_t \mid y) \nabla_{x_t} \log p_\phi(y \mid x_t) μ^θ(xt∣y)=μθ(xt∣y)+s⋅Σθ(xt∣y)∇xtlogpϕ(y∣xt)
系数 s s s称为引导尺度,Dhariwal和Nichol发现增加 s s s会以牺牲多样性为代价提高样本质量。
Ho和Salimans在论文:CLASSIFIER-FREE DIFFUSION GUIDANCE中提出了classifier-free引导,一种用于引导扩散模型且不需要训练一个单独的分类器 模型的技术。对于classifier-free引导,在训练过程中以固定概率将类别条件扩散模型 ϵ θ ( x t ∣ y ) \epsilon_\theta(x_t \mid y) ϵθ(xt∣y)中的标签 y y y替换为空标签 ∅ \emptyset ∅ 。在采样过程中,模型的输出沿 ϵ θ ( x t ∣ y ) \epsilon_\theta(x_t \mid y) ϵθ(xt∣y)方向进一步推理,并远离 ϵ θ ( x t ∣ ∅ ) \epsilon_\theta(x_t \mid \emptyset) ϵθ(xt∣∅),如下所示:
ϵ ^ θ ( x t ∣ y ) = ϵ θ ( x t ∣ ∅ ) + s ⋅ ( ϵ θ ( x t ∣ y ) − ϵ θ ( x t ∣ ∅ ) ) \hat{\epsilon}\theta(x_t \mid y) = \epsilon\theta(x_t \mid \emptyset) + s \cdot (\epsilon_\theta(x_t \mid y) - \epsilon_\theta(x_t \mid \emptyset)) ϵ^θ(xt∣y)=ϵθ(xt∣∅)+s⋅(ϵθ(xt∣y)−ϵθ(xt∣∅))
其中, s ≥ 1 s \geq 1 s≥1为引导尺度。这种函数形式是受隐式分类器的启发:
p i ( y ∣ x t ) ∝ p ( x t ∣ y ) p ( x t ) p_i(y \mid x_t) \propto \frac{p(x_t \mid y)}{p(x_t)} pi(y∣xt)∝p(xt)p(xt∣y)
其梯度可以用真实分数 ϵ ∗ \epsilon^* ϵ∗来表示:
∇ x t log p i ( y ∣ x t ) ∝ ∇ x t log p ( x t ∣ y ) − ∇ x t log p ( x t ) ∝ ϵ ∗ ( x t ∣ y ) − ϵ ∗ ( x t ) \nabla_{x_t} \log p_i(y \mid x_t) \propto \nabla_{x_t} \log p(x_t \mid y) - \nabla_{x_t} \log p(x_t) \propto \epsilon^*(x_t \mid y) - \epsilon^*(x_t) ∇xtlogpi(y∣xt)∝∇xtlogp(xt∣y)−∇xtlogp(xt)∝ϵ∗(xt∣y)−ϵ∗(xt)
为了使用通用文本提示实现classifier-free引导,我们有时在训练过程中用空序列(我们也称其为 ∅ \emptyset ∅)替换文本描述。然后我们使用修正的预测 ϵ ^ \hat{\epsilon} ϵ^向描述 c c c引导:
ϵ ^ θ ( x t ∣ c ) = ϵ θ ( x t ∣ ∅ ) + s ⋅ ( ϵ θ ( x t ∣ c ) − ϵ θ ( x t ∣ ∅ ) ) \hat{\epsilon}\theta(x_t \mid c) = \epsilon\theta(x_t \mid \emptyset) + s \cdot (\epsilon_\theta(x_t \mid c) - \epsilon_\theta(x_t \mid \emptyset)) ϵ^θ(xt∣c)=ϵθ(xt∣∅)+s⋅(ϵθ(xt∣c)−ϵθ(xt∣∅))
公式:
μ ^ θ ( x t ∣ c ) = μ θ ( x t ∣ c ) + s ⋅ Σ θ ( x t ∣ c ) ∇ x t ( f ( x t ) ⋅ g ( c ) ) \hat{\mu}\theta(x_t \mid c) = \mu\theta(x_t \mid c) + s \cdot \Sigma_\theta(x_t \mid c) \nabla_{x_t} (f(x_t) \cdot g(c)) μ^θ(xt∣c)=μθ(xt∣c)+s⋅Σθ(xt∣c)∇xt(f(xt)⋅g(c))
其中:
f ( x t ) f(x_t) f(xt) 和 g ( c ) g(c) g(c) 分别是图像和文本的嵌入。
∇ x t ( f ( x t ) ⋅ g ( c ) ) \nabla_{x_t} (f(x_t) \cdot g(c)) ∇xt(f(xt)⋅g(c)) 是图像和描述点积的梯度,表示引导方向。
以往使用扩散模型进行图像修复的工作通常并没有显式地对模型进行针对修复任务的训练。传统方法通常通过从扩散模型中采样,并在每个采样步骤后,用从 q ( x t ∣ x 0 ) q(x_t | x_0) q(xt∣x0) 生成的样本替换图像的已知区域。然而,这种方法有一个缺点:模型无法看到全局上下文,只能依赖噪声版本的图像,可能导致边缘伪影。