Classifier Guidance
Improved DDPM虽然对DDPM进行了改进,但在一些大数据集上(如ImageNet 256×256)生成图片的实验效果(FID)仍是低于GAN。因此,OpenAI继续对DDPM进行改进,在2021年随后又发表了论文《Diffusion Models Beat Gans on Image Synthesis》,在模型结构上进一步优化,同时引入Classifier Guidance技术,通过图片分类器的梯度调节反向扩散过程,在尽量保持图片生成多样性的前提下,提升准确性,从而在多个数据集的实验效果(FID)上超过了GAN,实现了SOTA。论文将改进后的模型称为ADM(Ablated Diffusion Model)。
改进
网络结构
DDPM和Improved DDPM中的模型均使用U-Net,ADM在其网络结构的基础上,进一步增加以下数项改进:
- 增加网络结构的宽度和深度;
- 在注意力机制上,DDPM原先只在16×16这一层增加单头注意力层(缩放点积注意力),而ADM在32×32、16×16、8×8各层均增加了多头注意力层;
- 在上下采样上,DDPM原先在下采样使用池化或卷积层、在上采样使用插值或卷积层,而ADM使用残差卷积块;
ADM的代码开源,代码地址是:github.com/openai/guid...,其是在Improved DDPM的代码基础上进行修改。网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中,例如,ADM使用残差块卷积进行下采样的代码如下:
python
if level != len(channel_mult) - 1:
# 除编码器最后一层外的其他层,需要进行下采样输出到下一层
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
# 若标记resblock_updown为True,则使用残差卷积块进行下采样
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
# 设置残差卷积块中需进行下采样
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
自适应组归一化
ADM还使用了自适应组归一化(Adaptive Group Normalization,AdaGN),组归一化如图1最右侧所示,即对一个图片样本的所有像素,按通道分组进行归一化,而自适应归一化可表示为以下公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> AdaGN ( h , y ) = y s GroupNorm(h) + y b \text{AdaGN}(h,y)=y_s\text{GroupNorm(h)}+y_b </math>AdaGN(h,y)=ysGroupNorm(h)+yb 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h是残差卷积块中第一个卷积层的输出, <math xmlns="http://www.w3.org/1998/Math/MathML"> y s y_s </math>ys、 <math xmlns="http://www.w3.org/1998/Math/MathML"> y b y_b </math>yb分别是步数和图片分类的Embedding向量经过线性层后的投影。自适应归一化的代码如下所示:
python
# 经过第一个卷积层的输出
h = self.in_layers(x)
......
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# 取y_s和y_b
scale, shift = th.chunk(emb_out, 2, dim=1)
# 按y_s * GroupNorm(h) + y_b进行自适应组归一化
h = out_norm(h) * (1 + scale) + shift
# 经过第二个卷积层输出
h = out_rest(h)
论文通过实验发现,使用自适应组归一化能够进一步优化FID。
Classifier Guidance
除了在网络结构上精心设计和优化外,GAN还在有条件(已知图片类别)的图片生成中大量使用了图片类别信息。基于此,ADM一方面在自适应组归一化中引入图片类别的Embedding向量作为模型输入,另一方面设计了Classifier Guidance机制,通过引入一个分类器指导反向扩散过程:预先使用带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt训练分类器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ϕ ( y ∣ x t ) p_{\phi}(y|x_t) </math>pϕ(y∣xt)实现对类别的预测;在逐步反向扩散生成图片时,DDPM在每一步基于扩散模型预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt)和方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t ) \Sigma_\theta(x_t) </math>Σθ(xt),并由公式 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t ) ) \mu_\theta(x_t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t)\right) </math>μθ(xt)=αt 1(xt−1−αˉt 1−αtϵθ(xt))计算得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ) \mu_\theta(x_t) </math>μθ(xt),即得到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1高斯分布的均值和方差,在此基础上,ADM使用分类器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ϕ ( y ∣ x t ) p_{\phi}(y|x_t) </math>pϕ(y∣xt)输出对数的梯度对均值进行调整,并使用调整均值后的高斯分布进行采样得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1,均值调整公式如下所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ^ θ ( x t ∣ y ) = μ θ ( x t ∣ y ) + s ⋅ Σ θ ( x t ∣ y ) ∇ x t log p ϕ ( y ∣ x t ) \hat{\mu}\theta(x_t|y)=\mu\theta(x_t|y)+s\cdot\Sigma_\theta(x_t|y)\nabla_{x_{t}}\log{p_{\phi}{(y|x_t)}} </math>μ^θ(xt∣y)=μθ(xt∣y)+s⋅Σθ(xt∣y)∇xtlogpϕ(y∣xt)
其中,系数 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s被称为Guidance Scale,论文通过实验发现随着 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s的增加,生成图片的质量会提升,但多样性会减少。引入Classifier Guidance后的采样算法步骤如图2所示。
ADM中使用的分类器网络结构和扩散模型网络结构近似,均采用U-Net,但只使用编码器层和中间层,而没有解码器层,另外,由于分类器的目标是预测类别,因此类别没有作为输入。分类器网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中。 在分类器和扩散模型训练完成后,便可使用其进行图片采样。采样时使用分类器输出梯度对均值进行调整的代码在guided-diffusion/gaussian_diffusion.py的condition_mean方法中,如下所示:
python
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
# 使用分类器输出分类器梯度
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
# 根据分类器梯度、均值和方差计算新均值
new_mean = (
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
return new_mean
而其中使用分类器输出分类器梯度的cond_fn方法代码如下所示:
python
def cond_fn(x, t, y=None):
assert y is not None
with th.enable_grad():
x_in = x.detach().requires_grad_(True)
# 分类器输出分类结果
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
# 计算梯度并返回
return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
效果
论文在多个数据集上对ADM(仅做网络结构优化)、ADM-G(同时引入Classifier Guidance机制)和其他模型进行了对比实验,从FID指标上,ADM、特别是ADM-G超过了GAN,实现了SOTA。
Classifier-Free Guidance
在Classifier Guidance机制被提出后,紧接着Google于2021年发表了论文《Classifier-Free Diffusion Guidance》。这篇论文指出Classifier Guidance仍存在以下不足:一是Classifier Guidance需要额外训练分类器,二是Classifier Guidance会导致基于梯度的对抗攻击,欺骗FID、IS这类基于分类器的评估指标。因此,这篇论文提出了一种不需要训练分类器、但仍可以基于类别信息指导反向扩散过程的机制------Classifier-Free Guidance。 之前ADM等扩散模型可使用类别信息进行有条件的图片生成,由模型基于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y),或是不使用类别信息进行无条件的图片生成,由模型仅基于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),但这两种情况需要分别训练模型,而Classifier-Free Guidance的思想是在模型训练时,按一定比例丢弃类别信息,使得模型能够同时学习有条件的图片生成和无条件的图片生成,这样在采样生成图片时,由同一个模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),并使用两者的差值等价替换分类器输出的梯度对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)进行调整,调整公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ϵ ^ θ ( x t , y ) = ϵ θ ( x t ) + s ⋅ ( ϵ θ ( x t , y ) − ϵ θ ( x t ) ) \hat{\epsilon}\theta(x_t,y)=\epsilon\theta(x_t)+s\cdot(\epsilon_\theta(x_t,y)-\epsilon_\theta(x_t)) </math>ϵ^θ(xt,y)=ϵθ(xt)+s⋅(ϵθ(xt,y)−ϵθ(xt))
再基于调整后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , y ) \hat{\epsilon}\theta(x_t,y) </math>ϵ^θ(xt,y)计算均值,并从高斯分布中采样得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x{t-1} </math>xt−1。
CLIP Guidance
Classifier Guidance使用类别信息指导反向扩散过程,那是否可以使用除类别外的其他信息指导反向扩散过程?2022年发表的论文《More Control for Free! Image Synthesis with Semantic Diffusion Guidance》就尝试使用了其他信息,其中包括在多模态领域应用比较广泛的CLIP模型。 CLIP模型包括两部分,图片编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) f(x) </math>f(x)和文本编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> g ( c ) g(c) </math>g(c),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x为图片, <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c为文本。训练阶段,采用对比学习,使得正确图片、文本对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , c ) (x,c) </math>(x,c)的点积 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) ⋅ g ( c ) f(x)\cdot g(c) </math>f(x)⋅g(c)尽可能大,错误图片、文本对的点积尽可能小。因此,在推理阶段,可以进行文本和图片相关性的比较。关于CLIP模型的详细介绍,可以阅读原论文《Learning Transferable Visual Models From Natural Language Supervision》或《AIGC系列-CLIP论文阅读笔记》。 在Classifier Guidance中可以使用CLIP模型替换分类器,对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt,使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x t ) ⋅ g ( c ) f(x_t)\cdot g(c) </math>f(xt)⋅g(c)的梯度调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ∣ c ) \mu_\theta(x_t|c) </math>μθ(xt∣c),公式如下所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ^ θ ( x t ∣ c ) = μ θ ( x t ∣ c ) + s ⋅ Σ θ ( x t ∣ c ) ∇ x t ( f ( x t ) ⋅ g ( c ) ) \hat{\mu}\theta(x_t|c)=\mu\theta(x_t|c)+s\cdot\Sigma_\theta(x_t|c)\nabla_{x_t}(f(x_t)\cdot g(c)) </math>μ^θ(xt∣c)=μθ(xt∣c)+s⋅Σθ(xt∣c)∇xt(f(xt)⋅g(c))
和Classifier Guidance中的分类器类似,需使用带噪声的图片和文本对训练CLIP模型以获得正确的梯度。
GLIDE
在上述工作的基础上,OpenAI于2022年发表了论文《GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models》,其中发布了GLIDE(Guided Language to Image Diffusion for Generation and Editing)模型,用于基于文本的图片生成。图4是使用GLIDE模型基于文本生成的图片。
基于文本的图片生成
一般的扩散模型从随机采样的高斯噪声开始,不能生成特定的图片,而GLIDE在已有扩散模型的基础上,使用文本信息指导扩散过程,对于带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和文本 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,能够通过模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t , c ) p_\theta(x_{t-1}|x_t,c) </math>pθ(xt−1∣xt,c),从而逐步降噪,实现了基于文本的图片生成。 具体实现上,GLIDE基于ADM模型,但模型参数和训练数据规模更大,模型参数达到35亿。GLIDE先将文本 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c转化为长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K的token序列,再通过Transformer输出文本的Embedding向量,最后使用文本Embedding向量替换原ADM模型输入中的类别Embedding向量。另外,文本Embedding向量还会经过投影与注意力层中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V拼接在一起,通过注意力机制指导扩散过程。 GLIDE的代码开源,代码地址是:github.com/openai/guid...。通过Transformer输出文本的Embedding向量作为模型输入的代码在glide_text2im/text2im_model.py中,如下所示:
python
def forward(self, x, timesteps, tokens=None, mask=None):
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.xf_width:
text_outputs = self.get_text_emb(tokens, mask)
xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
emb = emb + xf_proj.to(emb)
else:
xf_out = None
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, xf_out)
hs.append(h)
h = self.middle_block(h, emb, xf_out)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, xf_out)
h = h.type(x.dtype)
h = self.out(h)
return h
注意力层拼接文本Embedding向量的代码在glide_text2im/unet.py中,如下所示:
python
class QKVAttention(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, encoder_kv=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_kv is not None:
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
# 拼接文本Embedding向量到k
k = th.cat([ek, k], dim=-1)
# 拼接文本Embedding向量到v
v = th.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
Classifier-Free Guidance
GLIDE也使用了Classifier-Free Guidance机制,只是将类别替换为文本,因此,对模型所预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)进行调整的公式如下: <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , c ) = ϵ θ ( x t ) + s ⋅ ( ϵ θ ( x t , c ) − ϵ θ ( x t ) ) \hat{\epsilon}\theta(x_t,c)=\epsilon\theta(x_t)+s\cdot(\epsilon_\theta(x_t,c)-\epsilon_\theta(x_t)) </math>ϵ^θ(xt,c)=ϵθ(xt)+s⋅(ϵθ(xt,c)−ϵθ(xt)) GLIDE在上一节已训练得到基于文本的模型、可预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , c ) \epsilon_\theta(x_t,c) </math>ϵθ(xt,c)的基础上,对模型进行微调,将20%的文本Token序列替换成空序列,从而使得模型在具备预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , c ) \epsilon_\theta(x_t,c) </math>ϵθ(xt,c)的基础上,能够进一步预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),从而在采样时,能够基于调整后的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , c ) \hat{\epsilon}_\theta(x_t,c) </math>ϵ^θ(xt,c)降噪生成图片。
CLIP Guidance
GLIDE也使用了CLIP Guidance机制,但从实验结果上,其效果不如Classifier-Free Guidance。