DALL·E模型及其论文详解

详细介绍DALL·E的模型架构和训练过程,详细解读其论文Zero-Shot Text-to-Image Generation。

🌺DALL·E系列文章列表🌺

万字长文解读深度学习------dVAE(DALL·E的核心部件)

文章目录

建议阅读

架构说明

训练架构

由于论文中并没有给出详细的DALL·E架构,这几天,在网上找DALL·E 的架构图,始终没有找到合适的,最终找到了下面几个架构图,但是多多少少都会有点问题:

  1. 架构1
    这个图像(来源)的问题(红框)在于,最终在生成时,并没有将最后的潜空间特征传入Image Decoder,用来生成图像。

通过 Transformer 学习文本和图像 latent tokens 的联合分布 会传递给 Image Decoder(图像解码器),解码器会基于这些 latent tokens 生成最终的图像。

  1. 架构2

这个图像(来源)的问题(红框)在于,由CNN Encoder 生成的离散 latent token,再由Gumbel-Softmax采样即可获得 新的概率分布,然后将其作为权重,对相应的 codebook 向量进行累积就可以获得 latent vector(近似连续的 latent token)。最终的近似连续的 latent token 是需要传到 Transformer 中与 Text token进行融合,这里并没有箭头进行传递,是初始化了一系列的 latent token 进行融合。

个人认为结构应该如下:

图像数据经过 dVAE 的 Image Encoder 生成的latent token与文本通过BPE Encoder得到的Text token进行拼接,再经过 Transformer 注意力机制后,输入 dVAE 的 Image Decoder 重建图像。

推理架构

推理过程从文本输入开始,经过 BPE 编码器、Transformer 和图像解码器,最终生成图像。

概述

DALL·E 是 OpenAI 于 2021 年推出的一种革命性模型,通过无缝连接文本描述和图像合成,革新了生成式 AI 领域。在此,我们将深入探讨 DALL·E 的工作原理、训练细节以及支持其创造能力的数据集。

DALL·E 使用 Transformer 将文本和图像的 Token 建模为一个数据流。与这种方法相关的两个问题是:

  1. 像素作为图像 Token 会占用过多内存。
  2. 概率目标会优先考虑像素之间的短程依赖。

为了解决这些问题,DALL·E 采用了两阶段的训练过程:

  1. 学习视觉 Codebook。
  2. 学习先验分布。

训练过程

阶段 1

我们训练离散变分自动编码器(dVAE) ,将每个 256 × 256 的 RGB 图像压缩为一个 32 × 32 的图像 token 网格,其中每个元素可以取 8192 个可能值。这种方式将 Transformer 的上下文大小减少了 192 倍( 192 = ( 256 ÷ 32 ) × ( 256 ÷ 32 ) × 3 192=(256\div 32) \times (256\div 32) \times 3 192=(256÷32)×(256÷32)×3),同时不会显著降低视觉质量。

We train a discrete variational autoencoder (dVAE) to compress each 256 × 256 RGB image into a 32 × 32 grid of image tokens, each element of which can assume 8192 possible values. This reduces the context size of the transformer by a factor of 192 without a large degradation in visual quality.


阶段 2

我们将最多 256 个(当文本 Token 不足 256 个时会进行 Padding,下文会解释。)经过 BPE 编码的文本 token 与 32 × 32 = 1024 的图像 token 拼接起来,并训练一个自回归 Transformer 来建模文本和图像 token 的联合分布。

We concatenate up to 256 BPE-encoded text tokens with the 32 × 32 = 1024 image tokens, and train an autoregressive transformer to model the joint distribution over the text and image tokens.


整体过程

这一整体过程可以被视为在图像 x x x、文本 y y y 和 RGB 编码图像的 token z z z 的联合似然分布上,最大化证据下界(ELB,Evidence Lower Bound)的优化问题。我们使用以下分解方式来建模这种分布:

The overall procedure can be viewed as maximizing the evidence lower bound (ELB) on the joint likelihood of the model distribution over images x x x, captions y y y, and the tokens z z z for the encoded RGB image. We model this distribution using the factorization p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) p_{\theta,\psi}(x, y, z) = p_\theta(x | y, z)p_\psi(y, z) pθ,ψ(x,y,z)=pθ(x∣y,z)pψ(y,z), which yields the lower bound:

p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) , p_{\theta,\psi}(x, y, z) = p_\theta(x | y, z)p_\psi(y, z), pθ,ψ(x,y,z)=pθ(x∣y,z)pψ(y,z),

由此可以得出以下下界:

ln ⁡ p θ , ψ ( x , y ) ≥ E z ∼ q ϕ ( z ∣ x ) [ ln ⁡ p θ ( x ∣ y , z ) ] − β D K L ( q ϕ ( y , z ∣ x ) , p ψ ( y , z ) ) , \ln p_{\theta,\psi}(x, y) \geq \mathbb{E}{z \sim q\phi(z | x)} \left[ \ln p_\theta(x | y, z) \right] - \beta D_{KL}(q_\phi(y, z | x), p_\psi(y, z)), lnpθ,ψ(x,y)≥Ez∼qϕ(z∣x)[lnpθ(x∣y,z)]−βDKL(qϕ(y,z∣x),pψ(y,z)),

其中:

  • q ϕ q_\phi qϕ 表示由 dVAE 编码器生成的 32 × 32 图像 token 在给定 RGB 图像 x x x 的分布;
  • p θ p_\theta pθ 表示由 dVAE 解码器在给定图像 token 的条件下生成 RGB 图像的分布;
  • p ψ p_\psi pψ 表示由 Transformer 模型建模的文本和图像 token 的联合分布。

需要注意的是,上述下界仅在 β = 1 \beta = 1 β=1 的情况下成立。然而,在实践中,我们发现使用更大的 β \beta β 值会更有帮助(参考 Higgins 等,2016 年)。

论文中提高 KL 权重 β \beta β 至 6.6 6.6 6.6(大于 1)有助于更好地利用 Codebook,从而最终提升训练效果。

接下来的小节将更详细地描述两个阶段的过程。

训练细节

第一阶段:Learning the Visual Codebook

这一阶段的目标是通过 dVAE 学习一个高效的视觉 Codebook 表示,同时克服离散分布的优化挑战。通过使用 Gumbel-Softmax 技术、温度退火机制、卷积调整等方法,确保模型稳定训练和泛化性能的提升。

1. 目标:通过证据下界(ELBO)优化视觉 Codebook

在训练的第一阶段,我们通过优化 ELB(证据下界)来训练 dVAE 的参数 ϕ \phi ϕ 和 θ \theta θ,这相当于仅基于图像进行训练。

由于没有文本,所以把之前的 y y y 去掉。证据下界是一个下界,用于评估我们近似的后验分布( q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x))和真实后验分布( p θ ( z ∣ x ) p_\theta(z|x) pθ(z∣x))之间的接近程度。如果我们的近似分布可以很好地拟合真实后验分布,那么证据下界将会比较高。

通常使用 "Evidence Lower Bound" 作为损失函数,用于训练概率生成模型,如变分自编码器(VAE)和概率图模型。通过最大化证据下界,我们可以使生成模型更好地拟合数据,并学习到数据背后的潜在结构。ELBO = 重构损失 + KL 散度

  1. 重构损失:衡量模型如何重建输入图像,通常希望最小化这一项。
  2. KL 散度:衡量近似后验分布和先验分布之间的差异,通常希望最小化 KL散度,以使模型的潜在表示遵循先验分布。

证据下界(Evidence Lower Bound, ELBO)用于训练变分自编码器(dVAE)时优化模型参数。它为一个下界,能够估算数据的对数似然,并且通过优化该下界来训练模型,使得近似的后验分布尽可能接近真实的后验分布。

ln ⁡ p ( x ) ≥ E z ∼ q ϕ ( z ∣ x ) [ ln ⁡ p θ ( x ∣ z ) ] − D KL ( q ϕ ( z ∣ x ) ∥ p ψ ( z ) ) \ln p(x) \geq \mathbb{E}{z \sim q\phi(z|x)} [\ln p_\theta(x|z)] - D_{\text{KL}}(q_\phi(z|x) \| p_\psi(z)) lnp(x)≥Ez∼qϕ(z∣x)[lnpθ(x∣z)]−DKL(qϕ(z∣x)∥pψ(z))

解释各项符号

  • x x x:输入图像或数据。
  • z z z :潜在变量,由编码器 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) 生成。
  • p θ ( x ∣ z ) p_\theta(x|z) pθ(x∣z) :解码器生成的图像的重建概率。通过这个概率,模型能够从潜在变量 z z z 中重建输入数据 x x x,从而实现图像的生成。
  • q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) :编码器的近似后验分布。它通过神经网络(编码器)将输入图像 x x x 映射到潜在空间 z z z,并生成潜在变量的概率分布。
  • p ψ ( z ) p_\psi(z) pψ(z) :潜在变量 z z z 的先验分布。通常是一个简单的均匀分布,用于初始化潜在空间的分布。

作用

  • 重建误差 :优化解码器的重建概率 p θ ( x ∣ z ) p_\theta(x|z) pθ(x∣z),即通过最大化解码器的输出概率来提高从潜在变量 z z z 重建图像 x x x 的能力。优化目标是使得解码器生成的图像尽可能与输入图像一致。
  • KL 散度 :KL 散度项 D KL ( q ϕ ( z ∣ x ) ∥ p ψ ( z ) ) D_{\text{KL}}(q_\phi(z|x) \| p_\psi(z)) DKL(qϕ(z∣x)∥pψ(z)) 通过最小化编码器生成的近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) 与先验分布 p ψ ( z ) p_\psi(z) pψ(z) 之间的差异,来约束潜在空间的分布。这个过程帮助模型使近似后验分布尽可能接近真实的后验分布,从而提高模型的生成能力。

1. 初始先验分布 p ψ ( z ) p_\psi(z) pψ(z)

  • 定义 :初始的先验分布 p ψ ( z ) p_\psi(z) pψ(z) 通常设定为均匀分布,表示潜在空间的初始分布没有偏向任何特定的区域。
  • Codebook 的作用 :Codebook 用于离散化潜在空间,将潜在变量 z z z 映射为离散的嵌入向量。Codebook 中的向量数量 K = 8192 K=8192 K=8192 确定了潜在空间的离散级别。

2. 近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x)

  • 定义 :近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) 是编码器根据输入图像 x x x 生成的潜在变量 z z z 的概率分布。它将连续的输入图像 x x x 映射到离散的潜在空间 z z z。
  • 具体过程 :编码器生成的输出是一个 32 × 32 32 \times 32 32×32 的网格,每个位置表示一个离散的概率分布。每个位置的分类分布由 8192 8192 8192 个 logits 参数化,这些 logits 被通过 Softmax 或 Gumbel-Softmax 归一化成离散分布。

3. 真实后验分布 p θ ( z ∣ x ) p_\theta(z|x) pθ(z∣x)

  • 定义 :真实后验分布 p θ ( z ∣ x ) p_\theta(z|x) pθ(z∣x) 表示在给定输入 x x x 后,潜在变量 z z z 的真实条件概率分布。
  • 挑战 :由于计算真实后验分布是不可行的,因此我们使用 变分推断 来近似它。
  • 优化目标 :通过最大化 ELBO,模型能够优化其参数,使得生成的近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) 更接近真实后验分布 p θ ( z ∣ x ) p_\theta(z|x) pθ(z∣x)。

总结

在 dVAE 的训练过程中,目标是通过优化证据下界(ELBO)来训练编码器和解码器,使得模型能够生成高质量的图像。在这一过程中:

  1. 初始先验 :先验分布 p ψ ( z ) p_\psi(z) pψ(z) 是 Codebook 中的均匀分布,用于初始化潜在空间。
  2. 近似后验分布 :编码器 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(z∣x) 将输入图像映射到潜在空间,并输出一个离散分布,用于生成潜在变量。
  3. 重建误差和 KL 散度:通过最小化 KL 散度项,并优化重建误差,模型能够学习如何从潜在空间中重建图像,并使近似后验分布尽可能接近真实的后验分布。

这种方法通过优化 ELBO,使得模型能够更好地捕捉数据的潜在结构,并生成真实、精确的图像或数据。

2. 优化挑战与解决方法

ELB 的优化难点:

  • 由于 q ϕ q_\phi qϕ 是离散分布,因此不能使用重参数化梯度来优化。
  • 重新参数化梯度是一种常用于训练变分自编码器(VAE)等生成模型的技术。

优化:

针对离散分布的潜在变量优化这个难点,VQ-VAE使用停止梯度(Stop Gradient)解决,而dVAE使用Gumbel-Softmax解决(更优解),这也是VQ-VAE和dVAE的主要区别。详情请参考:万字长文解读深度学习------dVAE(DALL·E的核心部件),在DALL·E中使用的是dVAE。

Gumbel-Softmax Relaxation:

为了解决上述难点,我们使用了 Gumbel-Softmax 技术。实现方式是将 q ϕ q_\phi qϕ 替换为一个经过 Gumbel-Softmax 松弛(relaxation)的版本,其中温度参数 τ → 0 \tau \to 0 τ→0 时松弛效果收紧,逐渐逼近真实离散分布。

DALL-E中的Gumbel-Softmax:向Softmax中引入超参数 τ \tau τ使得ArgMax可导。(设置 τ = 1 16 \tau=\frac{1}{16} τ=161)

超参数 τ \tau τ在深度学习中有一个专业术语叫温度,可以通过调整Softmax曲线的平滑程度来实现不同的功能:

  • 当 τ \tau τ的值大于1时,我们可以得到更加平滑的Softmax曲线,这种方式可以得到更加平滑的置信度分布;
  • 当 τ \tau τ的值小于1时,得到的Softmax曲线更加陡峭;
  • 当 τ \tau τ的值趋近于0时,可以得到近似ArgMax的效果,但是这时Softmax还是可导的。

训练起始的温度系数 τ \tau τ 很高,在训练的过程中,逐渐降低 τ \tau τ,以便其逐渐逼近 ArgMax。在推理阶段就不再需要 Gumbel Softmax,直接使用 ArgMax 即可。


3. 其他优化

在论文的附录A:dVAE的细节中还详细的描述了dVAE的其他优化点。

DALL-E的dVAE的编码器和解码器都是基于残差网络构建的,DALL-E保持了残差网络的基础结构,但也有其针对性的调整,核心修改如下:

  1. 编码器的输入层的卷积核的大小是7x7;
  2. 编码器的最后一个卷积的卷积核的大小是1x1,用于产生大小是32x32x8192的Feature Map;
  3. 编码器使用最大池化而非原来的平均池化进行降采样;
  4. 解码器的的第一个卷积核最后一个卷积的卷积核的大小均为1x1;
  5. 解码器使用了最近邻的方式进行上采样;

第二阶段:Learning Prior Distribution

在这一阶段,模型将文本表示与图像 latent 表示结合,通过 Transformer的 Encoder 捕捉两者之间的联合分布,之后通过Transformer的 Decoder自回归的方式生成图像的潜在表示(latent tokens),这些潜在表示被传入 dVAE 解码器进行图像重建。

在 DALL·E 的 Transformer 模型中,使用了一个 12B 参数量的稀疏 Transformer,包含 64 层,每层有 62 个注意力头,每个头的隐藏层大小为 64。文本输入通过 BPE(Byte Pair Encoding)编码为 Token,最大长度限制为 256 个 Token,词表大小为 16,384。图像的词表对应 Codebook,大小为 8,192。图像的 Token 是通过 dVAE 的编码器加 ArgMax 采样获得,未添加 Gumbel 噪声。需要注意的是,训练和推理的代码尚未开源。

作者将文本输入固定为 256 个 Token,因此当文本 Token 不足 256 个时会进行 Padding,如下图所示,同时也会给 Image Token 添加行索引 embedding 和列索引 embedding。还会有一个特殊的 Token 来标识无文本输入的情况。此外,因为输入中既包含文本 Token,又包含图像 Token,而学习的主要目标是生成图像 Token,因此训练中文本相关的交叉熵损失权重为 1/8,而图像相关的交叉熵损失权重为 7/8。

文本表示与图像token的嵌入方案

图10 . 展示了一个假设的 Transformer 版本的嵌入方案,该版本的最大文本长度为6个token。每个方框表示一个大小为 d model = 3968 d_{\text{model}} = 3968 dmodel=3968的向量。在此示例中,标题长度为4个token,因此使用了2个填充token(如第2.2节所述)。每个图像词汇嵌入都会与一个行位置编码以及列位置编码相加。

稀疏/掩码Transformer

DALL-E使用的Transformer是稀疏Transformer,它的特点是只关注Top-k个贡献最大的特征的状态,因此比普通的Transformer更能关注重要的特征。DALL-E的Transformer有64个自注意力层,每个层的头数是62,每个注意力头的维度是64。

DALLE共使用了3个不同形式的稀疏自注意力编码:

( a )是行注意力,每个标志只关注top-5的相关标志;

( b )是列注意力;

( c )是列注意力的转置,它能够更好的利用GPU;

( d )是卷积自注意力。

图像生成/推理过程

图像生成过程:

  1. 将输入文本编码成特征向量,然将特征向量送入到自回归的Transformer中生成图像的token
  2. 将图像的token送入到dVAE的解码器中得到生成图像
  3. 通过CLIP对生成样本进行评估,得到最终的生成结果

混合精度训练 (Mixed Precision Training)

  • 目的:为了节省 GPU 内存并提高计算吞吐量,大多数参数、Adam 优化器的矩阵和激活都使用 16 位精度(半精度浮点数)存储。
  • 优势:这种方法能够减少内存占用,加速计算,同时保持训练的数值稳定性,提升训练效率。

分布式运算 (Distributed Computing)

  • 参数分片 (Parameter Sharding):当模型太大,单机显存不足时,采用参数分片技术将模型的参数拆分到不同的 GPU 上进行存储和计算。
  • 多机多卡训练的挑战:不同机器之间的通信带宽远小于单机多卡之间的带宽,因此在进行多机多卡训练时,需要特别注意通信瓶颈的问题。通常采用高效的通信协议,如梯度压缩和通信优化策略,来减轻这些问题。

CLIP (Contrastive Language-Image Pre-training)

在生成阶段,模型从 Transformer 中采样多个候选图像(例如 N = 512)。随后,CLIP 模型根据图像与文本描述的匹配程度为这些图像打分,并选取得分最高的前 k 张图像。

为了保持生成图像的多样性和准确性,所有样本都使用 无温度缩放(即 t = 1) 的采样方式,除非另有说明。

参考

AI绘画原理解析:从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion(含ControlNet详解)
【论文精读】DALLE: Zero-Shot Text-to-Image Generation零样本文本到图像生成
文生图模型演进:AE、VAE、VQ-VAE、VQ-GAN、DALL-E 等 8 模型

参考博文(DALL·E和dVAE的很多国内文章图片都来自下面的博文):
Understanding VQ-VAE (DALL-E Explained Pt. 1)
How is it so good ? (DALL-E Explained Pt. 2)
How OpenAI's DALL-E works?

相关推荐
AIGC大时代29 分钟前
方法建议ChatGPT提示词分享
人工智能·深度学习·chatgpt·aigc·ai写作
糯米导航33 分钟前
ChatGPT Prompt 编写指南
人工智能·chatgpt·prompt
Damon小智35 分钟前
全面评测 DOCA 开发环境下的 DPU:性能表现、机器学习与金融高频交易下的计算能力分析
人工智能·机器学习·金融·边缘计算·nvidia·dpu·doca
赵孝正1 小时前
特征选择(机器学习)
人工智能·机器学习
QQ_7781329741 小时前
Pix2Pix:图像到图像转换的条件生成对抗网络深度解析
人工智能·神经网络
数据馅1 小时前
window系统annaconda中同时安装paddle和pytorch环境
人工智能·pytorch·paddle
高工智能汽车1 小时前
2025年新开局!谁在引领汽车AI风潮?
人工智能·汽车
不爱原创的Yoga1 小时前
自动驾驶汽车目前面临的最大技术挑战是什么?
人工智能·自动驾驶·汽车
罗小罗同学2 小时前
人工智能的出现,给生命科学领域的研究带来全新的视角|行业前沿·25-01-22
人工智能·搜索引擎·生命科学