扩散模型(Diffusion Model)是继GAN、VAE后的一种生成式模型,而目前在文生图领域比较流行的工具,如DALL-E2、Imagen、Stable Diffusion等,均是以上述扩散模型为基础,不断进行算法优化、迭代,取得了令人惊艳的效果。
DDPM
扩散模型于2015年在论文《Deep unsupervised learning using nonequilibrium thermodynamics》中被提出,并于2020年在论文《Denoising diffusion probabilistic models》中被改进、用于图片生成。《Denoising Diffusion Probabilistic Models》中提出的扩散模型被称为DDPM。
正向扩散过程
令原始图片样本为 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0,其满足分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 ∼ q ( x 0 ) x_0 \sim q(x_0) </math>x0∼q(x0)。定义前向扩散过程,在 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T步内,每步给样本增加一个小的满足高斯分布的噪声,从而产生 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T个带噪声的样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 1 , . . . , x T x_1,...,x_T </math>x1,...,xT,整个过程为一个一阶马尔可夫过程, <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt只与 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1有关,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t\mathbf{I}) </math>q(xt∣xt−1)=N(xt;1−βt xt−1,βtI)
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) </math>q(xt∣xt−1)表示给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1时, <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_{t} </math>xt的条件概率,即均值为 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 − β t x t − 1 \sqrt{1-\beta_t}x_{t-1} </math>1−βt xt−1、方差为 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t I \beta_t\mathbf{I} </math>βtI的高斯分布,集合 <math xmlns="http://www.w3.org/1998/Math/MathML"> { β t ∈ ( 0 , 1 ) } t = 1 T \{\beta_t \in (0,1)\}{t=1}^{T} </math>{βt∈(0,1)}t=1T用于控制每步的噪声大小。进一步给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0时,整个马尔科夫过程的条件概率为各步条件概率的连乘,可用以下公式表示: <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(x{1:T}|x_0)=\prod_{t=1}^{T}{q(x_t|x_{t-1})} </math>q(x1:T∣x0)=∏t=1Tq(xt∣xt−1) 正向扩散过程可由图1从右到左的过程表示,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0为原始图片,随着每步增加噪声,图片逐渐变得模糊。

对于上述正向扩散过程,可进一步令 <math xmlns="http://www.w3.org/1998/Math/MathML"> α t = 1 − β t \alpha_t=1-\beta_t </math>αt=1−βt,且 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t = ∏ i = 1 t α i \bar{\alpha}t=\prod{i=1}^{t}{\alpha_i} </math>αˉt=∏i=1tαi,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t = α t x t − 1 + 1 − α t ϵ t − 1 ; where ϵ t − 1 , ϵ t − 2 , . . . ∼ N ( 0 , I ) = α t α t − 1 x t − 2 + 1 − α t α t − 1 ϵ ˉ t − 2 ; where ϵ ˉ t − 2 merge two Gaussians ( ∗ ) . = . . . = α ˉ t x 0 + 1 − α ˉ t ϵ \begin{aligned} x_t&=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_{t-1} &;\text{where }\epsilon_{t-1},\epsilon_{t-2},...\sim\mathcal{N}(0,\mathbf{I})\\ &=\sqrt{\alpha_t \alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar{\epsilon}{t-2} &;\text{where }\bar{\epsilon}{t-2}\text{ merge two Gaussians }(*). \\ &=... \\ &=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon \end{aligned} </math>xt=αt xt−1+1−αt ϵt−1=αtαt−1 xt−2+1−αtαt−1 ϵˉt−2=...=αˉt x0+1−αˉt ϵ;where ϵt−1,ϵt−2,...∼N(0,I);where ϵˉt−2 merge two Gaussians (∗).
即 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt是在 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1的基础上,增加一个满足高斯分布的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ t − 1 \epsilon_{t-1} </math>ϵt−1,循环递归,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt可进一步推导为在 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0的基础上,增加一个满足高斯分布的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ。这里使用了高斯分布的一个特性,即两个高斯分布合并后仍是一个高斯分布,例如分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , σ 1 2 I ) \mathcal{N}(0,\sigma_1^2\mathbf{I}) </math>N(0,σ12I)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , σ 2 2 I ) \mathcal{N}(0,\sigma_2^2\mathbf{I}) </math>N(0,σ22I),合并后的分布为 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , ( σ 1 2 + σ 2 2 ) I ) \mathcal{N}(0,(\sigma_1^2+\sigma_2^2)\mathbf{I}) </math>N(0,(σ12+σ22)I)。
反向扩散过程
以上介绍了正向扩散过程,即图1从右到左,对原始图片逐步增加噪声,如果将过程逆向,即图1从左到右,那么就能从满足高斯分布的噪音 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0,\mathbf{I}) </math>xT∼N(0,I)逐步还原原始图片样本,这就是基于扩散模型生成图片的基本思想,即从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T x_T </math>xT到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0的每一步,在给定 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt时,根据条件概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) </math>q(xt−1∣xt)采样求解 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1,直至最终得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0。 而当正向扩散过程每步增加的噪声很小时,反向扩散过程的条件概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) </math>q(xt−1∣xt)也可以认为满足高斯分布,但实际上,我们不能直接求解该条件概率,因为直接求解需要整体数据集合。除直接求解外,另一个方法是训练一个模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ p_\theta </math>pθ近似预估上述条件概率,可用以下公式表示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t)) </math>pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{0:T})=p(x_T)\prod_{t=1}^{T}{p_\theta(x_{t-1}|x_t)} </math>pθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt)
从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T x_T </math>xT到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0的每一步,通过模型 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ,输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1的高斯分布的均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t,t) </math>μθ(xt,t)和方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t),基于预测值,可以从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1的高斯分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) </math>pθ(xt−1∣xt)中进行采样,从而得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1的一个可能取值,如此循环,直至最终得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0的一个可能取值。通过上述反向扩散过程,即可以实现从一个满足高斯分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , I ) \mathcal{N}(0,\mathcal{I}) </math>N(0,I)的随机噪声,生成一张图片。而由于每次预测均是从一个概率密度函数中进行采样,因此,可以保证生成图片的多样性。 更进一步,论文进一步将模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) </math>pθ(xt−1∣xt)的均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t,t) </math>μθ(xt,t)和方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)转化为预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t),并推导出 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t,t) </math>μθ(xt,t)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t)的关系:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t,t)\right) </math>μθ(xt,t)=αt 1(xt−1−αˉt 1−αtϵθ(xt,t))
因此, <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) </math>pθ(xt−1∣xt)可表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t,t)\right),\Sigma_\theta(x_t,t)) </math>pθ(xt−1∣xt)=N(xt−1;αt 1(xt−1−αˉt 1−αtϵθ(xt,t)),Σθ(xt,t))
论文将 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)固定为常量,通过模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t),并使用上述公式的概率密度函数进行采样,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt降噪得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1。
模型结构
DDPM中预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t)的模型基于OpenAI于2017年发布的一个U-Net形式的网络结构PixelCNN++。U-Net于2015年在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中发布,起初主要用于医学图像的切割,目前作为常用的去噪结构,广泛应用于扩散模型中。

U-Net的网络结构如图2所示,因整体结构形似字母U而得名,U型左侧是4层编码器层,对图片进行降维,U型右侧是4层解码器层,对图片进行升维。编码器的每一层,先是连续的两个卷积层(卷积核维度为3×3)和ReLU层,再接一个池化层进行下采样,然后输入下一层,卷积层的通道数逐层加倍,例如,第一层的输入是572×572的单通道图片,经过两个卷积核维度为3×3、通道为64、无padding的卷积层,输出张量维度分别为570×570×64、568×568×64,再经过一个维度为2×2的最大池化层后,输出张量维度为284×284×64,如此循环,最后一层输出的张量维度为32×32×512。在编码器层和解码器层之间的中间层,经过两个卷积层(卷积核维度为3×3、通道为1024)和ReLU层,输出的张量维度为28×28×1024。解码器的每一层,和编码器类似,也先是连续的两个卷积层(卷积核维度为3×3)和ReLU层,和编码器不同的是,解码器从下层到上层,通过一个上卷积层(卷积核维度为3×3)进行上采样(长、宽维度加倍,但通道缩小),同时,解码器每层的输入除上一层上采样的输出外,还包括同层编码器输出的裁剪。例如,解码器第一层的输入,包括中间层上采样的输出,张量维度为56×56×512,和同层编码器输出的裁剪,张量维度为56×56×512,合并后的张量维度为56×56×1024,经过两个卷积核维度为3×3、通道为512的卷积层,输出张量维度分别为54×54×512、52×52×512。解码器最后一层的输出,再通过一个1×1的卷积层,将原先的64通道映射为指定的通道,因为原始U-Net用于图像切割,即对图像每个像素做分类,所以有多少个分类,即有多少个最终的通道。

而PixelCNN++于2017年在论文《PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications》中发布,其网络结构如图3所示。图中,矩形区块对应于U-Net中的编码器或解码器层,共3个编码器层、3个解码器层。在每个编码器或解码器中,PixelCNN++在原U-Net两个卷积层的基础上,增加了一个残差连接。DDPM进一步进行网络结构的改进,包括:使用Group Normalization进行归一化;在残差卷积块后增加自注意力层;使用Transformer中Sinusoidal Position Embedding对步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t编码成Embedding向量作为模型输入。 DDPM的代码开源,代码地址是:github.com/hojonathanh...,其深度学习框架采用Tensorflow,计算资源采用Google Cloud TPU v3-8。diffusion_tf/models/unet.py中定义了网络结构,核心代码如下所示(增加了部分注释):
python
with tf.variable_scope(name, reuse=reuse):
# Timestep embedding
# 将步数t编码成Embedding向量
with tf.variable_scope('temb'):
temb = nn.get_timestep_embedding(t, ch)
temb = nn.dense(temb, name='dense0', num_units=ch * 4)
temb = nn.dense(nonlinearity(temb), name='dense1', num_units=ch * 4)
assert temb.shape == [B, ch * 4]
# Downsampling
# 多层编码器层
hs = [nn.conv2d(x, name='conv_in', num_units=ch)]
for i_level in range(num_resolutions):
with tf.variable_scope('down_{}'.format(i_level)):
# Residual blocks for this resolution
# 构造编码器层,残差卷积块+自注意力层,并进行下采样
for i_block in range(num_res_blocks):
h = resnet_block(
hs[-1], name='block_{}'.format(i_block), temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
if h.shape[1] in attn_resolutions:
h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
hs.append(h)
# Downsample
if i_level != num_resolutions - 1:
hs.append(downsample(hs[-1], name='downsample', with_conv=resamp_with_conv))
# Middle
# 中间层,残差卷积块+自注意力层+残差卷积块
with tf.variable_scope('mid'):
h = hs[-1]
h = resnet_block(h, temb=temb, name='block_1', dropout=dropout)
h = attn_block(h, name='attn_1'.format(i_block), temb=temb)
h = resnet_block(h, temb=temb, name='block_2', dropout=dropout)
# Upsampling
# 多层解码器层
for i_level in reversed(range(num_resolutions)):
with tf.variable_scope('up_{}'.format(i_level)):
# Residual blocks for this resolution
# 构造解码器层,残差卷积块+自注意力层,并进行上采样
for i_block in range(num_res_blocks + 1):
h = resnet_block(tf.concat([h, hs.pop()], axis=-1), name='block_{}'.format(i_block),
temb=temb, out_ch=ch * ch_mult[i_level], dropout=dropout)
if h.shape[1] in attn_resolutions:
h = attn_block(h, name='attn_{}'.format(i_block), temb=temb)
# Upsample
if i_level != 0:
h = upsample(h, name='upsample', with_conv=resamp_with_conv)
assert not hs
# End
# 最后再经过一个卷积层输出
h = nonlinearity(normalize(h, temb=temb, name='norm_out'))
h = nn.conv2d(h, name='conv_out', num_units=out_ch, init_scale=0.)
assert h.shape == x.shape[:3] + [out_ch]
return h
其中,残差卷积块的代码如下(增加了部分注释):
python
def resnet_block(x, *, temb, name, out_ch=None, conv_shortcut=False, dropout):
B, H, W, C = x.shape
if out_ch is None:
out_ch = C
with tf.variable_scope(name):
h = x
# 对图片进行归一化和非线性转化
h = nonlinearity(normalize(h, temb=temb, name='norm1'))
# 对图片进行卷积
h = nn.conv2d(h, name='conv1', num_units=out_ch)
# add in timestep embedding
# 对步数t的embedding向量进行非线性转化,并合并至图片
h += nn.dense(nonlinearity(temb), name='temb_proj', num_units=out_ch)[:, None, None, :]
# 对合并图片和步数后的输入再进行归一化和非线性转化,并再进行卷积
h = nonlinearity(normalize(h, temb=temb, name='norm2'))
h = tf.nn.dropout(h, rate=dropout)
h = nn.conv2d(h, name='conv2', num_units=out_ch, init_scale=0.)
# 对两次卷积后的输出和原始输入进行残差连接
if C != out_ch:
if conv_shortcut:
x = nn.conv2d(x, name='conv_shortcut', num_units=out_ch)
else:
x = nn.nin(x, name='nin_shortcut', num_units=out_ch)
assert x.shape == h.shape
print('{}: x={} temb={}'.format(tf.get_default_graph().get_name_scope(), x.shape, temb.shape))
return x + h
自注意力层的代码如下(即Transformer中的缩放点积注意力层):
python
def attn_block(x, *, name, temb):
B, H, W, C = x.shape
with tf.variable_scope(name):
h = normalize(x, temb=temb, name='norm')
q = nn.nin(h, name='q', num_units=C)
k = nn.nin(h, name='k', num_units=C)
v = nn.nin(h, name='v', num_units=C)
w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5))
w = tf.reshape(w, [B, H, W, H * W])
w = tf.nn.softmax(w, -1)
w = tf.reshape(w, [B, H, W, H, W])
h = tf.einsum('bhwHW,bHWc->bhwc', w, v)
h = nn.nin(h, name='proj_out', num_units=C, init_scale=0.)
assert h.shape == x.shape
print(tf.get_default_graph().get_name_scope(), x.shape)
return x + h
训练采样
训练
通过模型预测误差 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t),损失函数采用均方误差(MSE,Mean-Squared Error):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L simple : = E t ∼ [ 1 , T ] , x 0 ∼ q ( x 0 ) , ϵ ∼ N ( 0 , I ) [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] L_{\text{simple}}:=E_{t\sim[1,T],x_0\sim q(x_0),\epsilon\sim\mathcal{N}(0,\mathbf{I})}[\parallel\epsilon-\epsilon_\theta(x_t,t)\parallel^2] </math>Lsimple:=Et∼[1,T],x0∼q(x0),ϵ∼N(0,I)[∥ϵ−ϵθ(xt,t)∥2]
模型训练的目标即最小化上述损失函数,即使模型预测出的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t)和真实噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ尽可能接近。

训练算法如图4所示,采用梯度下降算法,循环下述过程直至模型收敛:
- 对于样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> { 1 , . . . , T } \{1,...,T\} </math>{1,...,T}中随机采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t;
- 从高斯分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , I ) \mathcal{N}(0,\mathbf{I}) </math>N(0,I)中采样真实噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ;
- 根据样本 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0和真实噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ,使用前面推导出的公式 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon </math>xt=αˉt x0+1−αˉt ϵ计算第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t步正向扩散后带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt;
- 根据带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,使用模型预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t),即 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon,t) </math>ϵθ(αˉt x0+1−αˉt ϵ,t);
- 根据真实噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ \epsilon </math>ϵ和预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \epsilon_\theta(\sqrt{\bar{\alpha}t}x_0+\sqrt{1-\bar{\alpha}t}\epsilon,t) </math>ϵθ(αˉt x0+1−αˉt ϵ,t)计算损失函数的梯度,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∇ θ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ ) ∥ 2 \nabla\theta\parallel\epsilon-\epsilon\theta(\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon)\parallel^2 </math>∇θ∥ϵ−ϵθ(αˉt x0+1−αˉt ϵ)∥2;
- 根据梯度和学习率超参更新模型参数。
采样

采样算法如图5所示,过程如下:
- 从高斯分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( 0 , I ) \mathcal{N}(0,\mathbf{I}) </math>N(0,I)中采样完全噪声图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T x_T </math>xT;
- 循环 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T步,步数从 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T到1,直至计算得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0,生成最终的图片,对于其中的某一步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:
- 根据带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,使用模型预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t);
- 前面已推导出概率密度函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) </math>pθ(xt−1∣xt)满足高斯分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> N ( x t − 1 ; 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) , Σ θ ( x t , t ) ) \mathcal{N}(x_{t-1};\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t,t)\right),\Sigma_\theta(x_t,t)) </math>N(xt−1;αt 1(xt−1−αˉt 1−αtϵθ(xt,t)),Σθ(xt,t)),使用公式 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t,t)\right) </math>μθ(xt,t)=αt 1(xt−1−αˉt 1−αtϵθ(xt,t))由噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) </math>ϵθ(xt,t)计算均值 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t,t) </math>μθ(xt,t);
- 对于上述概率密度函数,指定方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)为常量,根据该分布进行采样,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t,t)\right) +\sigma_tz </math>xt−1=αt 1(xt−1−αˉt 1−αtϵθ(xt,t))+σtz。