基于扩散模型生成图片的算法DDPM于2020年被提出。2021年OpenAI发表的论文《Improved Denoising Diffusion Probabilistic Models》,对DDPM算法进行改进。
Improved DDPM
改进
噪声Schedule采用余弦函数
原始DDPM算法,使用公式 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon </math>xt=αˉt x0+1−αˉt ϵ计算第 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t步正向扩散后带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt,公式中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t = ∏ i = 1 t α i \bar{\alpha}t=\prod{i=1}^{t}{\alpha_i} </math>αˉt=∏i=1tαi, <math xmlns="http://www.w3.org/1998/Math/MathML"> β t = 1 − α t \beta_t=1-\alpha_t </math>βt=1−αt,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t = 1 − α ˉ t α ˉ t − 1 \beta_t=1-\frac{\bar{\alpha}t}{\bar{\alpha}{t-1}} </math>βt=1−αˉt−1αˉt, <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt表示每步噪声的大小,原始DDPM算法,令 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt随 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t线性增长,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> β 1 = 1 0 − 4 \beta_1=10^{-4} </math>β1=10−4增长到 <math xmlns="http://www.w3.org/1998/Math/MathML"> β T = 0.02 \beta_T=0.02 </math>βT=0.02,改进的DDPM算法,令 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t \bar{\alpha}_t </math>αˉt随 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的变化采用余弦函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> f ( t ) = cos ( t / T + s 1 + s ⋅ π 2 ) 2 α ˉ t = f ( t ) f ( 0 ) β t = clip ( 1 − α ˉ t α ˉ t − 1 , 0.999 ) \begin{aligned} f(t)&=\cos{\left(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2}\right)^2} \\ \bar{\alpha}_t&=\frac{f(t)}{f(0)}\\ \beta_t&=\text{clip}(1-\frac{\bar{\alpha}t}{\bar{\alpha}{t-1}},0.999) \end{aligned} </math>f(t)αˉtβt=cos(1+st/T+s⋅2π)2=f(0)f(t)=clip(1−αˉt−1αˉt,0.999)
两种噪声Schedule下, <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t \bar{\alpha}_t </math>αˉt随 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t的变化曲线如图1所示,相比线性函数,余弦函数的 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t \bar{\alpha}_t </math>αˉt下降相对较平缓,因而 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt相对较小,加噪相对较慢,不会过快地对原始图片加入过多的噪声。
Improved DDPM的代码开源,代码地址是:github.com/openai/impr...,其深度学习框架采用PyTorch。训练和采样分别执行以下脚本:
shell
# 训练脚本
python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
# 采样脚本
python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS
其中,MODEL_FLAGS、DIFFUSION_FLAGS、TRAIN_FLAGS分别表示模型结构(U-Net)、扩散过程和训练的配置,而基线模型(DDPM)的配置如下:
shell
MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
如果采用余弦函数作为噪声Schedule,可以将DIFFUSION_FLAGS中的noise_schedule设置为cosine,而相应的计算噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt的代码在improved_diffusion/gaussian_diffusion.py中,如下:
python
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
# 入参alpha_bar即公式推导中的f(t)
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
对方差进行学习
原始DDPM算法,满足高斯分布的概率密度函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1}|x_t) </math>pθ(xt−1∣xt)中的方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)被固定为常量 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t 2 I \sigma_t^2\mathbf{I} </math>σt2I,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t 2 \sigma_t^2 </math>σt2直接取值 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt。改进的DDPM算法通过模型对 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)也进行了学习,这样可以使用更少的步数、且生成更高质量的图片。
具体如何学习 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)呢?DDPM的论文已推导 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)的取值在 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ~ t \tilde{\beta}_t </math>β~t之间,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ~ t : = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde{\beta}t:=\frac{1-\bar{\alpha}{t-1}}{1-\bar{\alpha}_t}\beta_t </math>β~t:=1−αˉt1−αˉt−1βt,图2表示了 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ~ t / β t \tilde{\beta}t/\beta_t </math>β~t/βt和 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t之间的关系,从中可见,除了 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = 0 t=0 </math>t=0外,其他 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t取值下, <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt和 <math xmlns="http://www.w3.org/1998/Math/MathML"> β ~ t \tilde{\beta}t </math>β~t近似相等,所以原始DDPM算法将方差直接取值 <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt,而改进的DDPM算法设计了中间向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v,由模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v,并将 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma\theta(x_t,t) </math>Σθ(xt,t)表示为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Σ θ ( x t , t ) = exp ( v log β t + ( 1 − v ) log β ~ t ) \Sigma\theta(x_t,t)=\exp(v\log{\beta_t}+(1-v)\log{\tilde{\beta}_t}) </math>Σθ(xt,t)=exp(vlogβt+(1−v)logβ~t)
原始DDPM算法的损失函数为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L simple : = E t ∼ [ 1 , T ] , x 0 ∼ q ( x 0 ) , ϵ ∼ N ( 0 , I ) [ ∥ ϵ − ϵ θ ( x t , t ) ∥ 2 ] L_{\text{simple}}:=E_{t\sim[1,T],x_0\sim q(x_0),\epsilon\sim\mathcal{N}(0,\mathbf{I})}[\parallel\epsilon-\epsilon_\theta(x_t,t)\parallel^2] </math>Lsimple:=Et∼[1,T],x0∼q(x0),ϵ∼N(0,I)[∥ϵ−ϵθ(xt,t)∥2]
其中并不包含 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t),因此改进的DDPM算法设计了新的损失函数为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L hybrid = L simple + λ L vlb L_\text{hybrid}=L_\text{simple}+\lambda L_\text{vlb} </math>Lhybrid=Lsimple+λLvlb
而 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb的定义如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L vlb : = L 0 + L 1 + ⋯ + L T − 1 + L T where L 0 : = − log p θ ( x 0 ∣ x 1 ) L t − 1 : = D KL ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) for 2 ≤ t ≤ T L T : = D KL ( q ( x T ∣ x 0 ) ∥ p ( x T ) ) \begin{aligned} L_{\text{vlb}}&:=L_0+L_1+\cdots+L_{T-1}+L_T \\ \text{where }L_0&:=-\log{p_\theta(x_0|x_1)} \\ L_{t-1}&:=D_{\text{KL}}(q(x_{t-1}|x_{t},x_0)\parallel p_\theta(x_{t-1}|x_{t}))\text{ for } 2\le t\le T \\ L_T&:=D_{\text{KL}}(q(x_T|x_0)\parallel p(x_T)) \end{aligned} </math>Lvlbwhere L0Lt−1LT:=L0+L1+⋯+LT−1+LT:=−logpθ(x0∣x1):=DKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt)) for 2≤t≤T:=DKL(q(xT∣x0)∥p(xT))
损失函数中 <math xmlns="http://www.w3.org/1998/Math/MathML"> λ \lambda </math>λ被设置为0.001,以减少 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb对 <math xmlns="http://www.w3.org/1998/Math/MathML"> L simple L_\text{simple} </math>Lsimple的影响。另外梯度更新时, <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb部分不更新涉及 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t , t ) \mu_\theta(x_t,t) </math>μθ(xt,t)的参数,只更新涉及 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t , t ) \Sigma_\theta(x_t,t) </math>Σθ(xt,t)的参数。 如果需要对方差进行学习,可以将训练脚本参数MODEL_FLAGS中的learn_sigma设置为True,这样,模型结构(U-Net)的输出维度增加,增加的部分作为 <math xmlns="http://www.w3.org/1998/Math/MathML"> v v </math>v,相关代码在improved_diffusion/script_util.py的create_model方法中,如下:
python
return UNetModel(
in_channels=3,
model_channels=num_channels,
# 模型结构(U-Net)的输出维度增加,增加的部分作为v
out_channels=(3 if not learn_sigma else 6),
num_res_blocks=num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
)
同时,反向扩散生成图片时,由模型预测方差的代码在improved_diffusion/gaussian_diffusion.py的p_mean_variance方法中,如下:
python
# 模型预测
model_output = model(x, self._scale_timesteps(t), **model_kwargs)
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
# 按列拆分模型输出,后半部分作为v
model_output, model_var_values = th.split(model_output, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
else:
# 按公式exp(v·log(β_t)+(1-v)·log(β_t))计算方差
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
而代码如何调整损失函数的计算,在下一节介绍。
训练时采用Importance Sampling
论文进一步发现,将损失函数替换为 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb或 <math xmlns="http://www.w3.org/1998/Math/MathML"> L hybrid L_\text{hybrid} </math>Lhybrid后,损失函数取值随着训练迭代变化的曲线比较波动,不易收敛,如图3所示。 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb包含多项,每项对应一个步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,且每项的取值量纲差别较大,如图4所示,而原始DDPM算法训练时随机采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,因此论文认为每次训练迭代随机采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t、并进而计算量纲差别大的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt导致 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb或 <math xmlns="http://www.w3.org/1998/Math/MathML"> L hybrid L_\text{hybrid} </math>Lhybrid的波动。论文通过训练时采用Importance Sampling来解决上述波动问题。Importance Sampling中, <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb可表示为以下公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L vlb = E t ∼ p t [ L t p t ] , where p t ∝ E [ L t 2 ] and ∑ p t = 1 L_\text{vlb}=E_{t\sim p_t}\left[\frac{L_t}{p_t}\right],\text{where }p_t\propto\sqrt{E\left[L_t^2\right]}\text{ and }\sum{p_t}=1 </math>Lvlb=Et∼pt[ptLt],where pt∝E[Lt2] and ∑pt=1
<math xmlns="http://www.w3.org/1998/Math/MathML"> E [ L t 2 ] E\left[L_t^2\right] </math>E[Lt2]无法提前求解,且在训练过程中会变化,因此,论文对 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb的每一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt保留最新的10个取值,并在训练过程中动态更新。训练初期,仍是随机采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t,直至所有的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L t L_t </math>Lt均有10个取值,再采用Importance Sampling。从图3可以看出,经过Importance Sampling后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb随着训练迭代变化的曲线比较平滑,且损失最小。 如果需要使用Importance Sampling后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb作为损失函数,可以将训练脚本参数DIFFUSION_FLAGS中的use_kl设置为True、TRAIN_FLAGS中的schedule_sampler设置为loss-second-moment。Importance Sampling的相关代码在improved_diffusion/resample.py的LossSecondMomentResampler类中,如下:
python
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros(
[diffusion.num_timesteps, history_per_term], dtype=np.float64
)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
def weights(self):
# 训练初期,仍是随机采样步数
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
# 根据L_t的历史值计算p_t
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb作为损失函数的相关代码在improved_diffusion/gaussian_diffusion.py的_vb_terms_bpd方法中,如下:
python
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
)
assert decoder_nll.shape == x_start.shape
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
# At the first timestep return the decoder NLL,
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
output = th.where((t == 0), decoder_nll, kl)
return {"output": output, "pred_xstart": out["pred_xstart"]}
效果
针对上述改进,论文在ImageNet 64×64和CIFAR-10这两个数据集上分别进行消融实验以验证各改进的有效性,如图5和图6所示。
其中,有效性指标使用了NLL和FID。NLL(Negative Log Likelihood)等价于损失函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb,NLL越小,说明生成图像与真实图像的分布越接近。FID(Fréchet Inception Distance)是另一种用于图像生成质量评估的指标,它可以评估生成图像与真实图像之间的相似度。FID指标的计算方法是使用Inception-v3模型对生成图像和真实图像进行特征提取,并计算两个特征分布之间的Fréchet距离。FID越小,说明生成图像与真实图像越相似。从实验结果上看,噪声Schedule采用余弦函数、对方差进行学习并且训练时损失函数采用Importance Sampling后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> L vlb L_\text{vlb} </math>Lvlb,NLL最低,但FID较高,而噪声Schedule采用余弦函数、对方差进行学习并且训练时损失函数采用 <math xmlns="http://www.w3.org/1998/Math/MathML"> L hybrid L_\text{hybrid} </math>Lhybrid,在NLL、FID上都能取得较小的值。
另外,论文还和其他基于似然预估的模型进行了对比实验,如图7所示。优化后的DDPM虽然在NLL和FID上还不是SOTA,但相对也是较优的效果,仅次于基于Transformer的网络结构。
加速
DDPM在生成图片时需要从完全噪声开始执行多步降噪操作,而每步操作均需要将当前步带噪声的图片作为输入由模型预测噪声,导致生成图片需要较多的步骤和计算量。论文也采用了《Denoising Diffusion Implicit Models》提出的采样方法------DDIM,减少步数。