51-36 DiT,视频生成模型Sora背后的核心技术

今天跟进的是UC Berkeley Wiliam Peebles和New York University Saining Xie的联合工作Scalable Diffusion Models with Transformers。2022年12月发布,Wiliam Peebles在Meta AI实习期间完成该论文,现就职于OpenAI。

扩散模型是一种深度学习生成模型,可生成各种各样的高分辨率图像或视频,主要办法是模拟数据逐步去噪过程来生成新样本(采样)。Diffusion Transformers,DiT是一种新型扩散模型,结合去噪扩散概率模型DDPMs和Transformer架构,研究了对于模型复杂度 (GFLOPs) 和样本质量 (FID) 的缩放性。DiT核心思想是使用Transforme取代U-Net主干作为扩散模型骨干网络,以处理图像潜在表示。

近期伴随OpenAl视频生成模型Sora的大热,DiT被视为Sora背后的核心技术之一而广受关注。

图 1,DiT扩散模型实现了最先进的图像质量,以上为512×512和256×256 ImageNet上训练DiT-XL/2模型生成示例。

Abstract

我们基于transformer架构探索了一类新的扩散模型。训练潜在扩散模型,用对潜在patch进行操作的transformer替换常用的U-Net主干。我们通过Gflops测量前向传递复杂度,分析DiT可扩展性。我们发现具有较高Gflop的DiTs------通过增加transformer深度/宽度或输入token数量------始终具有较低的 FID。除了具有良好的可扩展性属性外,我们最大的DiT-XL/2模型在类条件512×512和256×256 ImageNet基准测试中优于所有先前的扩散模型,实现了 2.27 的最新FID。

1. Introduction

机器学习正在经历由transformer驱动的复兴。在过去的五年中,自然语言处理BERT/GPT、视觉ViT和其他几个领域的神经架构在很大程度上被transformer所包含。然而,许多类别的图像生成模型仍然坚持趋势------虽然transformer在自回归模型中得到了广泛的应用[GPT3,i-GPT,GPT2,DALLE]------它们在其他生成建模框架中较少采用。例如,扩散模型一直是图像生成模型最新前沿,然而,ADM/DALLE2等都采用卷积U-Net架构作为骨干网络。

DDPM的开创性工作首次引入了扩散模型的U-Net主干。最初在像素级自回归模型和条件GANs中取得了成功,U-Net继承自PixelCNN++,但有一些更改。该模型是卷积的,主要由 ResNet 块组成。ADM论文做了 UNet 几种架构选择的消融实验,例如使用自适应归一化层为卷积层注入条件信息和通道数量。然而,DDPM的UNet高级设计在很大程度上保持不变。

在这项工作中,我们旨在揭开扩散模型中架构选择的重要性,并为未来生成建模研究提供经验基线。我们表明 U-Net 归纳偏差对扩散模型的性能并不重要,并且它们可以很容易地替换为transformers标准设计。因此,扩散模型可以很好地从最近的架构统一趋势中受益------例如,通过继承其他领域的最佳实践和训练方法,并保留可扩展性、鲁棒性和效率等有利属性。标准化架构还将为跨域研究开辟了新的可能性。

在本文中,我们专注于基于transformer新型扩散模型。我们称它们为diffusion transformer,DiTs。DiTs遵循ViTs最佳实践,已被证明比传统卷积网络如ResNet更有效地扩展到视觉识别。

更具体地,我们研究了transformer相对于网络复杂度与样本质量的缩放行为。我们表明,通过在潜在扩散模型 (LDM) 框架下构建和基准化 DiT 设计空间,其中扩散模型在 VAE 的潜在空间内进行训练,我们可以成功地将 U-Net 主干替换为transformer。我们进一步表明,DiTs 是扩散模型的可扩展架构:网络复杂度(由 Gflops 测量)与样本质量(由 FID 测量)之间存在很强的相关性。通过简单地缩放DiT并使用高容量主干(118.6 Gflops)训练LDM,我们能够在类条件256 × 256 ImageNet基准上获得2.27 FID的最佳结果。

2.1 Transformers

Transformer已经取代了跨语言、视觉ViT、强化学习和元学习特定域的架构。在语言域,作为通用自回归模型和ViTs,它们在增加模型大小、训练计算和数据等方面表现出显著的缩放特性。除了语言之外,transformer经过训练以自回归预测像素。它们也在离散码本上作为自回归模型和掩码生成模型进行了训练;前者显示了高达20B参数的出色缩放性能。最后,探讨了transformer在DDPM合成非空间数据,例如,在DALLE2中生成CLIP图像嵌入。在本文中,我们研究了transfromer作为图像扩散模型主干网络的缩放特性。

2.2 Denoising diffusion probabilistic models (DDPMs)

扩散模型和基于分数的生成模型尤其成功,因为图像生成模型 [GLIDE,DALLE,LDM,Imagen]在许多情况下优于以前最先进的生成对抗网络 (GAN) 。在过去两年里,DDPM 的进化主要是由改进的采样技术推动的,如最著名的是无分类器指导、重新制定扩散模型以预测噪声而不是像素[DDPM]、使用级联DDPM管道[CDM]、低分辨率基础扩散模型与上采样器并行训练[ADM]等等。对于上面列出的所有扩散模型,卷积 U-Nets 是主干架构的事实选择。同期工作 [Scalable Adaptive Computation for lterative Generation] 引入了一种基于 DDPM 注意力的新颖、高效架构;我们探索纯transfromer。

2.3 Architecture complexity

在评估文献中图像生成架构复杂性时,使用parameter counts是相当常见的做法。一般来说,这可能是图像模型复杂性的不良指标,因为它们不考虑显著影响性能的图像分辨率。相反,本文中的大部分模型复杂性分析都是通过理论Gflops视角进行的。这使我们与架构设计文献保持一致,其中Gflops被广泛用于衡量复杂性。在实践中,复杂性度量仍然存在争议,因为它经常取决于特定的应用场景。ADM,IDDPM改进扩散模型的开创性工作与我们最相关------在那里,他们分析了 U-Net 架构类的可扩展性和 Gflop 属性。在本文中,我们专注于 Transformer 类。

3. Diffusion Transformers

3.1 Preliminaries

3.1.1 Diffusion formulation

在介绍我们的体系结构之前,我们简要回顾了理解扩散模型DDPM所需的一些基本概念。

3.1.2 Classifier-free guidance

众所周知,与一般抽样技术相比,无分类器指导可以显著提高样本质量,这种趋势适用于我们的 DiT 模型。

3.1.3 Latent diffusion models

直接在高分辨率像素空间中训练扩散模型,在计算上可能是令人望而却步的。潜在扩散模型LDMs通过两阶段方法解决了这个问题:(1)感知压缩阶段,训练一种自动编码器(LDM实际上用VQ-GAN),将图像输入 x 压缩成更小的潜在空间表示 z;(2) 扩散模型阶段,在潜空间中学习预测在 z 上添加的噪声,而不是图像输入 x 上。在潜在扩散模型采样得到图像特征z'。

最后用解码器x=D(z')将其解码,生成新的图像。

如图 2 所示,LDM在使用一小部分Gflops时获得了良好的性能。

备注,作者在这里描述得让人容易混淆,故增删改了一点内容,LDM论文解读参见博文。

51-33 LDM 潜在扩散模型论文精读 + DDPM 扩散模型代码实现_ldm论文复现-CSDN博客

图 2。ImageNet generation with Diffusion Transformers (DiTs)。泡状面积表示扩散模型的Flops。左:在 400K 训练迭代中, DiT 模型 FID-50K(越低越好)。随着模型Flops增加,FID 的性能稳步提高。右图:DiT-XL/2 是计算效率最高的模型,并且优于所有先前基于 U-Net 扩散模型,如 ADM 和 LDM。

由于我们关注计算效率,这使得它们成为架构探索一个有吸引力的起点。在本文中,我们在DiT模型中应用潜在空间表示,尽管它们也可以不加修改地应用像素空间。这使得我们的图像生成管道是一种基于混合的方法;使用现成的卷积VAE和基于transformer的DDPM。

3.2 Diffusion Transformer Design Space

我们介绍了Diffusion transformer (DiTs),一种新的扩散模型结构。我们的目标是尽可能忠实于标准transformer架构,以保持其缩放特性。由于我们的重点是训练图像的去噪扩散概率模型(特别是图像的空间表示),因此DiT基于视觉转换Vision Transformer(ViT)架构,该架构对patch序列进行操作。DiT保留了许多ViT的最佳实践。图3显示了完整DiT体系结构的概述。在本节中,我们将描述DiT 管道过程以及DiT组件。

图 3。The Diffusion Transformer (DiT) architecture。左:我们训练条件潜在 DiT 模型。输入潜在空间表示被分解为patch并由几个 DiT 块处理。右图:DiT 块的详细信息。我们对包含调节的标准transformer块的变体进行了实验,这些变体包含adaptive layer norm, cross-attention and extra input tokens,自适应层范数效果最好。

3.2.1 Patchify

参考ViT,DiT的输入是一个空间表示z(对于256 × 256 × 3的图像,z的形状为32 × 32 × 4)。DiT的第一层是"patchify",它通过线性嵌入每个patch,将空间输入转换为T个tokens的序列,维度为d。在patchify化之后,我们将标准的基于 ViT frequency-based 位置嵌入(正弦余弦版本)应用于所有输入token。patchify创建的token数量T由patch超参数p决定。如图4所示,将p减半将使T翻四倍,因此至少使总transformer Gflops翻四倍。尽管它对Gflops有显著的影响,但注意,改变p对下游参数总数没有显著影响。我们将p = 2,4,8添加到DiT设计空间中。

图 4。DiT的输入规范。给定patch大小p × p, 转变成 I × I × C 空间表示(来自stable diffusion VAE的噪声潜在空间),被"patchified"成一个长度为 T = (I/p)^2 的序列,隐藏维数为d。patch p越小,序列长度越长,Gflops需求越大。

3.2.2 DiT block design

原始ViT中,在patchify之后,输入tokens会直接由一系列Transformer块处理。但DiT的输入除了带noise图像外,有时还会处理额外的条件信息,如noise时间步 t、类标签 c、自然语言等。为了处理条件输入,我们探索了transformer块的四种变体。这些设计对标准ViT块设计进行了小而重要的修改。各模块的设计如图 3 所示。

  • In-context conditioning:我们简单地将 t 和 c 向量embedding作为两个额外token添加到输入序列中,和图像token一样对待。这类似于ViT中的CLS token,它允许我们使用标准的ViT块而无需修改。在最后一个块之后,我们从序列中删除条件token。这种方法给模型引入的Gflops可以忽略不计。
  • Cross-attention block:我们将 t 和 c 嵌入拼接到一个长度为2的序列中,与图像token序列分开。通过修改transformer块,在多头自注意块之后增加了一个多头交叉注意层,类似于transformer原始设计,也类似于LDM用于调节类别标签的设计。交叉注意会给模型增加最多的Gflops,大约是15%的开销。(备注,多头自注意力输出作为query,条件embeddings作为key和value来引入条件)
  • Adaptive layer norm (adaLN) block:随着自适应归一化层在GAN和UNet骨干扩散模型中的广泛使用,我们探索用自适应归一化层(adaLN)取代transformer块中的标准归一化层。我们不是直接学习维度尺度和位移参数 γ 和 β,而是从 t 和 c 嵌入向量和中回归它们。在我们探索的三种块设计中,adaLN增加的Gflops最少,因此计算效率最高。它也是唯一一种受限于将同一函数应用于所有token的条件调节机制。
  • adaLN-Zero block:在ResNets上先前工作发现,将每个残差块初始化为恒等函数是有益的。例如,Goyal等人发现,在监督学习环境下,对每个块中的最终批量归一化因子 γ 进行零初始化可以加速大规模训练[Accurate, Large Minibatch SGD: Training lmageNet in 1 Hour]。扩散U-Net模型使用类似的初始化策略,在任何残差连接之前对每个块中的最终卷积层进行零初始化。我们探索了adaLN DiT块的修改,它具有相同的功能。除了回归 γ 和 β 外,我们还回归了在DiT块内任何残差连接之前立即应用的维度缩放参数 α。我们初始化MLP以输出所有 α 的零向量;这将把整个DiT块初始化为恒等函数。与普通的adaLN块一样,adaLNZero为模型Gflops可以忽略不计。

DiT设计空间包括in-context, cross-attention, adaptive layer norm and adaLN-Zero blocks四种变体。

3.2.3 Model size

我们应用一个N个DiT块的序列,每个块隐藏维度大小为d。在ViT之后,我们使用标准transformer配置,包括 N、d、attention heads 参数。具体来说,我们使用四种配置:DiT-S、DiT-B、DiT-L和DiT-XL。它们涵盖了广泛的模型大小和Flop分配,从0.3到118.6 Gflops,允许我们衡量缩放性能。表1给出了配置的详细信息。我们将B, S, L和XL配置添加到DiT设计空间中。

3.2.4 Transformer decoder

在最后的DiT块之后,我们需要将图像token序列解码输出为噪声预测和对角协方差预测。这两个输出的形状都等于原始空间输入。我们使用标准的线性解码器来做到这一点;我们在最后一层应用线性归一化(如果使用adaLN则自适应),并将每个token线性解码为p×p×2C张量,其中C是DiT空间输入中的通道数。最后,我们将解码后的token重新排列到它们原来的空间布局中,得到预测的噪声和协方差。

我们探索完整的DiT设计空间是patch大小、transformer块架构和模型大小。

4. Experimental Setup

我们探索了DiT设计空间,并研究了模型类的缩放特性。我们的模型根据它们的配置和潜在patch大小p来命名;例如,DiT-XL/2表示XLarge配置,p = 2。

4.1 Training

我们在256 × 256和512 × 512图像分辨率的ImageNet数据集上训练 class-conditional DiT 模型,ImageNet数据集是一个高度竞争的生成建模基准。除了最后一层的线性层用零初始化,其余使用ViT的初始化技术。我们用AdamW优化器。

我们使用恒定学习率1 × 10−4,没有权重衰减,批大小为256。我们使用的唯一数据增强是水平翻转。与之前ViT工作不同,我们没有发现学习率预热和正则化,对于训练DiT达到高性能是必要的。即使没有这些技术,训练在所有模型配置中都是高度稳定的,并且我们没有观察到在训练transformer时常见的任何损失峰值。遵循生成建模文献中的常见做法,我们在训练过程中保持DiT权重的指数移动平均(EMA),衰减为0.9999。报告的所有结果均使用EMA模型。我们几乎完全借用了ADM,在所有DiT模型大小和patch大小上使用相同的训练超参数。我们没有调整learning rates, decay/warm-up schedules, Adam β1/β2 or weight decays。

4.2 Diffusion

我们使用Stable Diffusion中现成的预训练变分自编码器(VAE)模型。VAE编码器的下采样因子为8,给定RGB图像 x 形状为256 × 256 × 3,则z = E(x)的形状为32 × 32 × 4。

在本节的所有实验中,我们的扩散模型都在这个z空间中运行。在从DiT中采样一个新的潜信号后,我们使用VAE解码器x = D(z)将其解码为像素。我们借用了来自ADM中的扩散超参数;具体来说,我们使用Tmax = 1000线性方差时间表(范围从1×10−4到2 ×10−2)、协方差Σθ参数化、输入时间步长嵌入方法、标签方法。

4.3 Evaluation metrics

我们使用FID来衡量缩放性能,这是评估图像生成模型的标准度量。

与先前工作进行比较,我们遵循惯例,并使用250 DDPM采样步骤报告FID-50K。众所周知,FID对小的实现细节很敏感;为了保证比较的准确性,本文报告的所有值都是通过导出样本并使用ADM的TensorFlow评估套件获得的。除非另有说明,本节中报告的FID编号不使用无分类器指导。我们还报告了Inception Score、sFID和Precision/Recall次要指标。

4.4 Compute

我们在JAX中实现所有模型,并使用TPU-v3 pod对它们进行训练。DiT-XL/2是Gflops最大的模型,以256全局批量大小在TPU v3-256 pod上,可以达到大约5.7次/秒训练速度。

5. Experiments

5.1 Ablation Study

5.1.1 DiT block design

训练了四个最高的 Gflop DiT-XL/2 模型,每个模型都使用不同的块设计,包括

in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops) or adaLN-zero (118.6 Gflops)

从图 5 可以看出,adaLN-Zero 块产生的 FID 低于交叉注意力和上下文条件,同时是最有效的。在 400K 训练迭代中,adaLN-Zero 模型实现的 FID 几乎是上下文模型的一半,这表明条件机制严重影响了模型质量。

初始化也很重要------adaLNZero,它将每个 DiT 块初始化为恒等函数,显着优于 vanilla adaLN。对于本文的其余部分,所有模型都使用 adaLN-Zero DiT 块。

5.1.2 Scaling model size and patch size

我们训练了 12 个 DiT 模型,model configs (S, B, L, XL) 和patch sizes (8, 4, 2)。

Figure 2(左)给出了每个模型的Gflops及其在400K训练迭代时的FID。观察到,在所有情况下,增加模型大小和减少patch size都会显著改进DiT的性能。

请注意,与其他配置相比,DiT-L 和 DiT-XL 在相对 Gflop 方面彼此明显更接近。图 2(左)概述了每个模型 Gflops 及其在 400K 训练迭代中的 FID。在所有情况下,我们发现增加模型大小和减少patch大小会产生显著改善的扩散模型。

图 6(顶部)展示了在patch size不变的情况下,FID如何随着模型大小的增加而变化?在所有四种配置中,更深更广的Transformer都会显著改善所有训练阶段的FID。

类似地,图 6(底部)显示了在模型大小保持不变的情况下,FID如何随着patch size的减小而改变,通过简单地扩展DiT处理的token数量,观察到在整个训练过程中FID有了相当大的改进。

5.1.3 DiT Gflops are critical to improving performance

这些结果表明,缩放模型 Gflopsis 实际上是提高性能的关键。

5.1.4 Larger DiT models are more compute-efficient

实验结果表明,更大training steps的小DiT模型比更小 training steps的较大的DiT模型计算效率低。

其次,Gflops相同的时候,不同patch size相同配置模型会产生不同性能,如XL/4在大约10^10Gflops后的表现优于XL/2。

5.1.5 Visualizing scaling

上图可视化了扩展缩放对样本质量的影响。在 400K 训练步骤中,我们使用identical starting noise xtmax , sampling noise and class labels,从12个DiT模型中采样结果直观地看到,配置如何影响DiT样本质量?

事实上,增加模型大小和token数量(减少patch size)可以显著提高视觉质量。

5.2 State-of-the-Art Diffusion Models

实验效果如下:

5.3 Scaling Model vs. Sampling Compute

我们在 400K 训练步骤后计算所有 12 个 DiT 模型的 FID,每张图像使用 [16, 32, 64, 128, 256, 1000] 个采样步骤。主要结果如上图所示。考虑使用1000个采样步骤的DiT-L/2和使用128步的DiT-XL/2。L/2消耗了80.7Tflops,XL/2消耗了15.2Tflops。结果表明XL/2有更好的FID-10K(23.7 vs 25.9),证明扩大采样计算量不能弥补模型的不足。

5.4 Appendix

5.4.1 A. Additional Implementation Details

Details of all DiT models

Training loss curves for all DiT models

Gflop counts for baseline diffusion models that use UNet backbones

5.4.2 DiT model details

为了嵌入输入timesteps,我们使用 256 维frequency embedding,然后是一个维度等于transformer隐藏大小和 SiLU 激活函数的两层 MLP。每个 adaLN 层将timesteps 和class embeddings总和馈送到 SiLU 非线性和MLP输出层,输出神经元等于 4 × (adaLN) 或 6× (adaLN-Zero) transfomrer隐藏大小。我们在核心transformer中使用GELU非线性(用tanh近似)。

5.4.3 Classifier-free guidance on a subset of channels

在我们使用无分类器指导的实验中,我们仅将指导应用于潜在变量的前三个通道,而不是所有四个通道。在调查后,我们发现当简单地调整guidance scale factor时,三通道引导和四通道引导给出了相似的结果。

5.4.4 D. VAE Decoder Ablations

我们使用了现成预训练好的 VAE。VAE 模型(ft-MSE 和 ft-EMA)是原始 LDM"f8"模型(仅微调解码器权重)的微调版本。故本节消融三种不同的VAE解码器,包括LDM原始解码器、Stable Diffusion两个微调解码器。表 5 显示了结果,当使用 LDM 解码器时,XL/2 继续优于所有先前的扩散模型。

5.4.5 B. Model Samples

我们展示了两个 DiT-XL/2 模型在 512 × 512 和 256 × 256 分辨率下的样本,分别训练了 3M 和 7M 步。图 1 和图 11 显示了两个模型中选择的样本。

图 14 到 33 显示了两个模型在一系列无分类器指导尺度和输入类标签(使用 250 个 DDPM 采样步骤和 ft-EMA VAE 解码器生成)中的未整理样本。与之前使用引导的工作一样,我们观察到更大的尺度增加了视觉保真度并降低样本多样性。

6. Conclusion

我们介绍了 diffusion transformer (DiTs),这是一种简单的基于 transformer 的扩散模型骨干网络架构,它优于先前的 U-Net 骨干,并继承了 transformer 模型类的优良缩放特性。鉴于本文中有希望的扩展结果,未来工作应该继续将 DiT 扩展到更大的模型和 token 数量。

也可以将 DiT 作为 DALLE2 和 Stable Diffusion 等文本到图像模型的直接主干进行探索。

备注:本文虽然只用ImageNet数据集,训练生成一张图片还是蛮耗资源的。大模型、大规模数据集、多GPU这种特耗资源的办法,像当年大型机及其应用一样应该会被淘汰,要么新硬件、要么新方法会诞生。

本文由深圳季连科技有限公司AIgraphX自动驾驶大模型团队编辑。如有错误,欢迎在评论区指正。

reference

William, P. , & Saining, X. . (2023). Scalable Diffusion Models with Transformers.

GitHub - DiT: Scalable Diffusion Models with Transformers

LDM-https://arxiv.org/abs/2112.10752

IDDPM-https://arxiv.org/abs/2102.09672

相关推荐
我喜欢就喜欢12 小时前
基于qt vs下的视频播放
开发语言·qt·音视频
安步当歌12 小时前
【WebRTC】视频采集模块中各个类的简单分析
音视频·webrtc·视频编解码·video-codec
EasyGBS13 小时前
国标GB28181公网直播EasyGBS国标GB28181软件管理解决方案
大数据·网络·音视频·媒体·视频监控·gb28181
Johnstons16 小时前
AnaTraf | 网络性能监控系统保障音视频质量的秘籍
网络·音视频·网络流量监控·网络流量分析·npmd
lrlianmengba16 小时前
推荐一款非常好用的视频编辑软件:Movavi Video Editor Plus
音视频
SZ17011023116 小时前
ffplay 实现视频流中音频的延迟
音视频·延迟
LNTON羚通17 小时前
CPU算法分析LiteAIServer视频智能分析平台视频智能分析:抖动、过亮与过暗检测技术
大数据·目标检测·音视频·视频监控
MediaTea19 小时前
Pr 视频过渡:沉浸式视频 - VR 光线
音视频·vr
几何心凉1 天前
视频自动播放被浏览器阻止及其解决方案
音视频
阿龍17871 天前
流媒体传输,降低延时和保证质量的方法(个人总结)
音视频