从DDPM到DALL-E2和Stable Diffusion——扩散模型相关论文阅读(3)

Classifier Guidance

Improved DDPM虽然对DDPM进行了改进,但在一些大数据集上(如ImageNet 256×256)生成图片的实验效果(FID)仍是低于GAN。因此,OpenAI继续对DDPM进行改进,在2021年随后又发表了论文《Diffusion Models Beat Gans on Image Synthesis》,在模型结构上进一步优化,同时引入Classifier Guidance技术,通过图片分类器的梯度调节反向扩散过程,在尽量保持图片生成多样性的前提下,提升准确性,从而在多个数据集的实验效果(FID)上超过了GAN,实现了SOTA。论文将改进后的模型称为ADM(Ablated Diffusion Model)。

改进

网络结构

DDPM和Improved DDPM中的模型均使用U-Net,ADM在其网络结构的基础上,进一步增加以下数项改进:

  • 增加网络结构的宽度和深度;
  • 在注意力机制上,DDPM原先只在16×16这一层增加单头注意力层(缩放点积注意力),而ADM在32×32、16×16、8×8各层均增加了多头注意力层;
  • 在上下采样上,DDPM原先在下采样使用池化或卷积层、在上采样使用插值或卷积层,而ADM使用残差卷积块;

ADM的代码开源,代码地址是:github.com/openai/guid...,其是在Improved DDPM的代码基础上进行修改。网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中,例如,ADM使用残差块卷积进行下采样的代码如下:

python 复制代码
            if level != len(channel_mult) - 1:
                # 除编码器最后一层外的其他层,需要进行下采样输出到下一层
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        # 若标记resblock_updown为True,则使用残差卷积块进行下采样
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            # 设置残差卷积块中需进行下采样
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

自适应组归一化

图1 各种归一化方式的示意

ADM还使用了自适应组归一化(Adaptive Group Normalization,AdaGN),组归一化如图1最右侧所示,即对一个图片样本的所有像素,按通道分组进行归一化,而自适应归一化可表示为以下公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> AdaGN ( h , y ) = y s GroupNorm(h) + y b \text{AdaGN}(h,y)=y_s\text{GroupNorm(h)}+y_b </math>AdaGN(h,y)=ysGroupNorm(h)+yb 其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h是残差卷积块中第一个卷积层的输出, <math xmlns="http://www.w3.org/1998/Math/MathML"> y s y_s </math>ys、 <math xmlns="http://www.w3.org/1998/Math/MathML"> y b y_b </math>yb分别是步数和图片分类的Embedding向量经过线性层后的投影。自适应归一化的代码如下所示:

python 复制代码
            # 经过第一个卷积层的输出
            h = self.in_layers(x)
            ......
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # 取y_s和y_b
            scale, shift = th.chunk(emb_out, 2, dim=1)
            # 按y_s * GroupNorm(h) + y_b进行自适应组归一化
            h = out_norm(h) * (1 + scale) + shift
            # 经过第二个卷积层输出
            h = out_rest(h)

论文通过实验发现,使用自适应组归一化能够进一步优化FID。

Classifier Guidance

除了在网络结构上精心设计和优化外,GAN还在有条件(已知图片类别)的图片生成中大量使用了图片类别信息。基于此,ADM一方面在自适应组归一化中引入图片类别的Embedding向量作为模型输入,另一方面设计了Classifier Guidance机制,通过引入一个分类器指导反向扩散过程:预先使用带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt训练分类器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ϕ ( y ∣ x t ) p_{\phi}(y|x_t) </math>pϕ(y∣xt)实现对类别的预测;在逐步反向扩散生成图片时,DDPM在每一步基于扩散模型预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt)和方差 <math xmlns="http://www.w3.org/1998/Math/MathML"> Σ θ ( x t ) \Sigma_\theta(x_t) </math>Σθ(xt),并由公式 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t ) ) \mu_\theta(x_t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}}\epsilon\theta(x_t)\right) </math>μθ(xt)=αt 1(xt−1−αˉt 1−αtϵθ(xt))计算得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ) \mu_\theta(x_t) </math>μθ(xt),即得到了 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1高斯分布的均值和方差,在此基础上,ADM使用分类器 <math xmlns="http://www.w3.org/1998/Math/MathML"> p ϕ ( y ∣ x t ) p_{\phi}(y|x_t) </math>pϕ(y∣xt)输出对数的梯度对均值进行调整,并使用调整均值后的高斯分布进行采样得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x_{t-1} </math>xt−1,均值调整公式如下所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ^ θ ( x t ∣ y ) = μ θ ( x t ∣ y ) + s ⋅ Σ θ ( x t ∣ y ) ∇ x t log ⁡ p ϕ ( y ∣ x t ) \hat{\mu}\theta(x_t|y)=\mu\theta(x_t|y)+s\cdot\Sigma_\theta(x_t|y)\nabla_{x_{t}}\log{p_{\phi}{(y|x_t)}} </math>μ^θ(xt∣y)=μθ(xt∣y)+s⋅Σθ(xt∣y)∇xtlogpϕ(y∣xt)

其中,系数 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s被称为Guidance Scale,论文通过实验发现随着 <math xmlns="http://www.w3.org/1998/Math/MathML"> s s </math>s的增加,生成图片的质量会提升,但多样性会减少。引入Classifier Guidance后的采样算法步骤如图2所示。

图2 引入Classifier Guidance后的采样算法

ADM中使用的分类器网络结构和扩散模型网络结构近似,均采用U-Net,但只使用编码器层和中间层,而没有解码器层,另外,由于分类器的目标是预测类别,因此类别没有作为输入。分类器网络结构定义的相关代码在guided-diffusion/unet.py的UNetModel类中。 在分类器和扩散模型训练完成后,便可使用其进行图片采样。采样时使用分类器输出梯度对均值进行调整的代码在guided-diffusion/gaussian_diffusion.py的condition_mean方法中,如下所示:

python 复制代码
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        # 使用分类器输出分类器梯度
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        # 根据分类器梯度、均值和方差计算新均值
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

而其中使用分类器输出分类器梯度的cond_fn方法代码如下所示:

python 复制代码
    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            # 分类器输出分类结果
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            # 计算梯度并返回
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

效果

图3 对比实验结果

论文在多个数据集上对ADM(仅做网络结构优化)、ADM-G(同时引入Classifier Guidance机制)和其他模型进行了对比实验,从FID指标上,ADM、特别是ADM-G超过了GAN,实现了SOTA。

Classifier-Free Guidance

在Classifier Guidance机制被提出后,紧接着Google于2021年发表了论文《Classifier-Free Diffusion Guidance》。这篇论文指出Classifier Guidance仍存在以下不足:一是Classifier Guidance需要额外训练分类器,二是Classifier Guidance会导致基于梯度的对抗攻击,欺骗FID、IS这类基于分类器的评估指标。因此,这篇论文提出了一种不需要训练分类器、但仍可以基于类别信息指导反向扩散过程的机制------Classifier-Free Guidance。 之前ADM等扩散模型可使用类别信息进行有条件的图片生成,由模型基于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y),或是不使用类别信息进行无条件的图片生成,由模型仅基于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),但这两种情况需要分别训练模型,而Classifier-Free Guidance的思想是在模型训练时,按一定比例丢弃类别信息,使得模型能够同时学习有条件的图片生成和无条件的图片生成,这样在采样生成图片时,由同一个模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),并使用两者的差值等价替换分类器输出的梯度对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)进行调整,调整公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ϵ ^ θ ( x t , y ) = ϵ θ ( x t ) + s ⋅ ( ϵ θ ( x t , y ) − ϵ θ ( x t ) ) \hat{\epsilon}\theta(x_t,y)=\epsilon\theta(x_t)+s\cdot(\epsilon_\theta(x_t,y)-\epsilon_\theta(x_t)) </math>ϵ^θ(xt,y)=ϵθ(xt)+s⋅(ϵθ(xt,y)−ϵθ(xt))

再基于调整后的 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , y ) \hat{\epsilon}\theta(x_t,y) </math>ϵ^θ(xt,y)计算均值,并从高斯分布中采样得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x{t-1} </math>xt−1。

CLIP Guidance

Classifier Guidance使用类别信息指导反向扩散过程,那是否可以使用除类别外的其他信息指导反向扩散过程?2022年发表的论文《More Control for Free! Image Synthesis with Semantic Diffusion Guidance》就尝试使用了其他信息,其中包括在多模态领域应用比较广泛的CLIP模型。 CLIP模型包括两部分,图片编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) f(x) </math>f(x)和文本编码器 <math xmlns="http://www.w3.org/1998/Math/MathML"> g ( c ) g(c) </math>g(c),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x为图片, <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c为文本。训练阶段,采用对比学习,使得正确图片、文本对 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( x , c ) (x,c) </math>(x,c)的点积 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) ⋅ g ( c ) f(x)\cdot g(c) </math>f(x)⋅g(c)尽可能大,错误图片、文本对的点积尽可能小。因此,在推理阶段,可以进行文本和图片相关性的比较。关于CLIP模型的详细介绍,可以阅读原论文《Learning Transferable Visual Models From Natural Language Supervision》《AIGC系列-CLIP论文阅读笔记》。 在Classifier Guidance中可以使用CLIP模型替换分类器,对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt,使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x t ) ⋅ g ( c ) f(x_t)\cdot g(c) </math>f(xt)⋅g(c)的梯度调整 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ θ ( x t ∣ c ) \mu_\theta(x_t|c) </math>μθ(xt∣c),公式如下所示:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> μ ^ θ ( x t ∣ c ) = μ θ ( x t ∣ c ) + s ⋅ Σ θ ( x t ∣ c ) ∇ x t ( f ( x t ) ⋅ g ( c ) ) \hat{\mu}\theta(x_t|c)=\mu\theta(x_t|c)+s\cdot\Sigma_\theta(x_t|c)\nabla_{x_t}(f(x_t)\cdot g(c)) </math>μ^θ(xt∣c)=μθ(xt∣c)+s⋅Σθ(xt∣c)∇xt(f(xt)⋅g(c))

和Classifier Guidance中的分类器类似,需使用带噪声的图片和文本对训练CLIP模型以获得正确的梯度。

GLIDE

在上述工作的基础上,OpenAI于2022年发表了论文《GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models》,其中发布了GLIDE(Guided Language to Image Diffusion for Generation and Editing)模型,用于基于文本的图片生成。图4是使用GLIDE模型基于文本生成的图片。

图4 使用GLIDE模型基于文本生成的图片

基于文本的图片生成

一般的扩散模型从随机采样的高斯噪声开始,不能生成特定的图片,而GLIDE在已有扩散模型的基础上,使用文本信息指导扩散过程,对于带噪声的图片 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt和文本 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c,能够通过模型预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> p θ ( x t − 1 ∣ x t , c ) p_\theta(x_{t-1}|x_t,c) </math>pθ(xt−1∣xt,c),从而逐步降噪,实现了基于文本的图片生成。 具体实现上,GLIDE基于ADM模型,但模型参数和训练数据规模更大,模型参数达到35亿。GLIDE先将文本 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c转化为长度为 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K的token序列,再通过Transformer输出文本的Embedding向量,最后使用文本Embedding向量替换原ADM模型输入中的类别Embedding向量。另外,文本Embedding向量还会经过投影与注意力层中的 <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K、 <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V拼接在一起,通过注意力机制指导扩散过程。 GLIDE的代码开源,代码地址是:github.com/openai/guid...。通过Transformer输出文本的Embedding向量作为模型输入的代码在glide_text2im/text2im_model.py中,如下所示:

python 复制代码
    def forward(self, x, timesteps, tokens=None, mask=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.xf_width:
            text_outputs = self.get_text_emb(tokens, mask)
            xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
            emb = emb + xf_proj.to(emb)
        else:
            xf_out = None
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            hs.append(h)
        h = self.middle_block(h, emb, xf_out)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, xf_out)
        h = h.type(x.dtype)
        h = self.out(h)
        return h

注意力层拼接文本Embedding向量的代码在glide_text2im/unet.py中,如下所示:

python 复制代码
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            # 拼接文本Embedding向量到k
            k = th.cat([ek, k], dim=-1)
            # 拼接文本Embedding向量到v
            v = th.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

Classifier-Free Guidance

GLIDE也使用了Classifier-Free Guidance机制,只是将类别替换为文本,因此,对模型所预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , y ) \epsilon_\theta(x_t,y) </math>ϵθ(xt,y)进行调整的公式如下: <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , c ) = ϵ θ ( x t ) + s ⋅ ( ϵ θ ( x t , c ) − ϵ θ ( x t ) ) \hat{\epsilon}\theta(x_t,c)=\epsilon\theta(x_t)+s\cdot(\epsilon_\theta(x_t,c)-\epsilon_\theta(x_t)) </math>ϵ^θ(xt,c)=ϵθ(xt)+s⋅(ϵθ(xt,c)−ϵθ(xt)) GLIDE在上一节已训练得到基于文本的模型、可预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , c ) \epsilon_\theta(x_t,c) </math>ϵθ(xt,c)的基础上,对模型进行微调,将20%的文本Token序列替换成空序列,从而使得模型在具备预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t , c ) \epsilon_\theta(x_t,c) </math>ϵθ(xt,c)的基础上,能够进一步预测 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ θ ( x t ) \epsilon_\theta(x_t) </math>ϵθ(xt),从而在采样时,能够基于调整后的噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ϵ ^ θ ( x t , c ) \hat{\epsilon}_\theta(x_t,c) </math>ϵ^θ(xt,c)降噪生成图片。

CLIP Guidance

GLIDE也使用了CLIP Guidance机制,但从实验结果上,其效果不如Classifier-Free Guidance。

参考文献

相关推荐
学习前端的小z22 分钟前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc
幸运超级加倍~41 分钟前
软件设计师-上午题-16 算法(4-5分)
笔记·算法
yannan201903131 小时前
【算法】(Python)动态规划
python·算法·动态规划
埃菲尔铁塔_CV算法1 小时前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR1 小时前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
linsa_pursuer1 小时前
快乐数算法
算法·leetcode·职场和发展
小芒果_011 小时前
P11229 [CSP-J 2024] 小木棍
c++·算法·信息学奥赛
qq_434085901 小时前
Day 52 || 739. 每日温度 、 496.下一个更大元素 I 、503.下一个更大元素II
算法
Beau_Will1 小时前
ZISUOJ 2024算法基础公选课练习一(2)
算法
打羽毛球吗️1 小时前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习