(Arxiv-2024)SnapGen:通过高效的架构和训练,为移动设备打造高分辨率文本转图像模型

SnapGen:通过高效的架构和训练,为移动设备打造高分辨率文本转图像模型

Paper是Snap Inc.发表在Arxiv 2024的工作
Paper Title:SnapGen: Taming High-Resolution Text-to-Image Models for Mobile Devices with Efficient Architectures and Training
Code:地址

图 1. 各种文本转图像模型在模型大小、移动设备兼容性和视觉输出质量方面的比较。我们的模型只有 379M 个参数,在兼容移动设备的同时,还展现出具有竞争力的视觉质量。输入文本提示显示在每个图像网格上方;所有图像均以 10242 分辨率生成-放大以查看详细信息。

Abstract

现有的文本到图像 (T2I) 扩散模型面临一些限制,包括模型尺寸大、运行时间慢以及在移动设备上的生成质量低。本文旨在通过开发一个非常小而快速的 T2I 模型来解决所有这些挑战,该模型可以在移动平台上生成高分辨率和高质量的图像。我们提出了几种技术来实现这一目标。首先,我们系统地检查网络架构的设计选择,以减少模型参数和延迟,同时确保高质量的生成。其次,为了进一步提高生成质量,我们从更大的模型中采用跨架构知识蒸馏,使用多层次方法从头开始指导我们的模型训练。第三,我们通过将对抗性指导与知识蒸馏相结合实现了几步生成。我们的模型 SnapGen 首次展示了在移动设备上大约 1.4 秒内生成 1024 2 ^2 2 像素图像。在 ImageNet-1K 上,我们的模型只有 372M 个参数,在 2562 px 生成中实现了 2.06 的 FID。在 T2I 基准测试(即 GenEval 和 DPG-Bench)上,我们的模型只有 379M 个参数,以明显较小的尺寸超越了具有数十亿个参数的大规模模型(例如,比 SDXL 小 7 倍,比 IF-XL 小 14 倍)。

1. Introduction

大规模文本到图像 (T2I) 传播模型 [13、14、16、54、56、60--62] 在内容生成方面取得了显著成功,为图像编辑 [51、66、74、87] 和视频创建 [53、57、80] 等众多应用提供了支持。 然而,T2I 模型通常具有较大的模型大小和较慢的运行时间,并且将它们部署在云端会引发与数据安全和高成本相关的担忧 [67]。

为了应对这些挑战,人们越来越有兴趣通过模型压缩(例如,修剪和量化)[43, 69, 88]、通过蒸馏减少步骤 [79, 82] 和减轻二次复杂度的有效注意机制 [49, 76] 等技术来开发更小、更快的 T2I 模型。然而,当前的工作仍然遇到限制,例如移动设备上的低分辨率生成,这限制了它们的广泛应用。

最重要的是,一个关键问题仍未得到探索:我们如何从头开始训练 T2I 模型以在移动设备上生成高质量、高分辨率的图像?这样的模型将在速度、紧凑性、成本效益和安全部署方面提供巨大优势。为了构建此模型,我们引入了多项创新:

  • 高效的网络架构:我们对网络架构进行了深入研究,包括去噪 UNet 和自动编码器 (AE),以在资源使用和性能之间取得最佳平衡。与优化和压缩预训练扩散模型的先前研究 [10, 35, 85] 不同,我们直接关注宏观和微观层面的设计选择,以实现一种新颖的架构,大大减少模型大小和计算复杂度,同时保持高质量的生成。

  • 改进的训练技术:我们引入了几项改进,从头开始训练紧凑的 T2I 模型。 我们利用流匹配 [47, 50] 作为目标,与 SD3 [19] 和 SD3.5 [3] 等较大的模型保持一致。这种设计可以实现有效的知识和步骤蒸馏,将丰富的表示从大规模扩散模型转移到我们小得多的模型中。此外,我们提出了一种具有时间步感知缩放的多级知识蒸馏,结合了多个训练目标。 我们不是像以前的研究 [35, 49] 那样通过线性组合来加权目标,而是考虑流匹配中各个时间步长的目标预测难度(即学生与老师之间的差异)。

  • 高级步骤蒸馏:我们结合对抗性训练和知识蒸馏,使用少步教师模型(即 SD3.5-Large-Turbo [5])对我们的模型进行步骤蒸馏,仅需 4 或 8 个步骤即可实现超快速高质量生成。

我们通过大量实验证明了我们的方法和模型的卓越优势:

  • 在 ImageNet-1K [17] 类条件图像生成任务上,我们的模型实现了与现有工作相当的 FID,同时显著减少了模型大小和计算量,即与 SiT-XL[52] 相比,模型大小减少了一半,计算资源减少了三分之一,如表 1 所示。
  • 对于大规模 T2I 生成,我们的 UNet 模型只有 379M 个参数,与十亿参数模型 [45, 56, 89] 相比,其生成质量更佳,例如,基准数据集(表 3)和人工评估(图 8)上的指标有所改进。
  • 我们从头开始训练的压缩解码器与常用模型 [19, 56] 相比具有竞争力的重建质量,尺寸缩小了 36 倍以上,可实现移动部署。
  • 值得注意的是,我们首次展示了 T2I 模型在移动设备(例如 iPhone 16 Pro-Max)上实现高分辨率生成(例如1024 2 ^2 2像素)的时间约为 1.4 秒。

高分辨率文本到图像模型已经出现,它们具有先进的架构和多阶段方法,旨在增强视觉保真度和用户定制。SDXL [56] 是该领域的一项开创性工作,它采用 UNet 主干的精细级联方法来生成高细节图像,从而产生保持锐度和清晰度的照片级真实感输出。以下研究探索了不同的技术,如更先进的文本编码器、更好的图像细化或改进的数据集准备,以获得更好的文本-图像对齐或更高质量的生成 [6、7、16、21、32、39-41、44、48、51、71]。然而,这些模型中的大多数都包含数十亿个参数,这使得它们非常慢,并且无法在资源丰富的硬件(如移动设备)上运行。在这项工作中,我们的目标是构建一个小型快速的模型,即使在移动平台上也可以执行高分辨率生成。

高效的扩散模型解决了模型体积大和运行时间长的挑战。人们一直在努力探索架构优化,以消除大型模型中的冗余,并展示了在几秒钟内即可在设备上生成的模型 [11, 43, 69, 88]。然而,这些模型仅限于低分辨率输出,即 512 2 ^2 2像素。为了实现高效的高分辨率生成,SANA [76] 和 LinFusion [49] 结合了线性注意力 [9, 15, 34],以在笔记本电脑 GPU 上实现 1K 生成。相比之下,我们的目标是更广泛的平台,支持直接在移动设备上进行高分辨率生成(例如 1K)。

扩散模型中的知识蒸馏。在扩散模型的背景下,以前的研究重点是将大型、高容量的教师模型蒸馏为同质架构中更紧凑、更高效的学生模型 [35, 49]。它们通过删除某些组件(如注意力 [72] 或残差块 [24])来降低模型的复杂性,同时保持架构结构。然而,我们的方法与这一趋势不同,我们利用异构架构进行更积极、更高效但更具挑战性的蒸馏。

对抗性步骤蒸馏使用对抗性训练 [23] 中的技术来减少扩散步骤的数量,同时保持高图像质量 [63, 64]。例如,UFOGen [79] 采用扩散 GAN 公式 [73, 75, 78] 来显着减少推理时间,同时保持竞争性性能。DMD2 [81] 通过使用对抗性损失的分布匹配,在先前的蒸馏方法的基础上建立。与现有工作不同,我们在非常紧凑的模型上进行分步蒸馏,并在知识蒸馏的同时训练模型。

3. Method

在本节中,我们将介绍如何设计和训练用于高分辨率生成的高效 T2I 模型。具体来说,从潜在扩散模型 [61] 中设计的架构开始,我们优化了去噪主干(第 3.1 节,图 2 和图 3)和自动编码器(第 3.2 节和图 4),使它们即使在移动设备上也能紧凑快速。然后,我们提出了改进的训练方法和知识蒸馏(第 3.3 节和图 5),为高性能 T2I 模型提供支持。最后,我们引入了分步蒸馏,以显着减少去噪步骤的数量,从而实现更快的 T2I 模型(第 3.4 节和图 7)。

3.1. Efficient UNet Architecture


在这里,我们描述了去噪UNet的设计选择。基准架构。我们选择了SDXL中的UNet [56] 作为基准(图2(a)),因为它比纯Transformer模型 [13, 14] 更高效且收敛速度更快 [42]。我们对UNet进行了调整,成为一个更薄且更短的模型(即,将三个阶段的Transformer块数从 [ 0 , 2 , 10 ] [0,2,10] [0,2,10]减少到 [ 0 , 2 , 4 ] [0,2,4] [0,2,4],并将通道维度从[320, 640, 1280]减少到 [ 256 , 512 , 896 ] [256,512,896] [256,512,896]),并在此基础上进行了设计迭代。评估指标。我们在ImageNet1K [17] 数据集上进行类条件生成训练,训练120个epochs,除非另有说明,并报告FID分数 [26],生成图像大小为 25 6 2 256^2 2562 px。与现有工作 [32] 类似,我们通过文本模板 "a photo of " 注入类条件,并使用轻量级文本编码器进行编码,以使管道与T2I生成对齐。我们还计算了不同模型的参数数量、浮点运算量(FLOPs)(在 128 × 128 128 \times 128 128×128 的潜在大小上测量,相当于解码后为 1024 × 1024 1024 \times 1024 1024×1024的图像),以及在移动设备上的运行时间(在iPhone 15 Pro上测试)。详细的训练和评估设置可以在补充材料中找到。接下来,我们介绍了改进模型的关键架构变化。

去除高分辨率阶段的自注意力。自注意力(SA)层受到其二次计算复杂度的限制,在处理高分辨率输入时会产生较大的计算成本和内存消耗。因此,我们只在最低分辨率阶段保留SA层,而将其他高分辨率阶段的SA层去除,即图2(b)。这导致了 17 % 17 \% 17% 的FLOPs减少和 24 % 24 \% 24% 的延迟减少,如图3所示。有趣的是,我们甚至观察到性能有所提升,即FID从3.76降至3.12。我们假设,在高分辨率阶段使用SA的模型收敛速度较慢。

用扩展可分离卷积替换卷积。常规卷积(Conv)在参数和计算上都是多余的。为了解决这个问题,我们用可分离卷积(SepConv)[30]替换所有卷积层,可分离卷积由深度卷积(DW)和点卷积(PW)组成,如图 2(c)所示。这种替换将参数减少了 24%,延迟减少了 62%,但也导致性能下降(FID 从 3.12 增加到 3.38)。为了解决这个问题,我们扩展了中间通道。具体来说,第一个 PW 层之后的通道数以扩展率增加,第二个 PW 层之后的通道数减少回原始数量。扩展率设置为 2,以平衡性能、延迟和模型参数之间的权衡。这样的设计使我们的残差块与最近提出的通用倒置瓶颈(UIB)块 [58] 保持一致。结果,我们的模型实现了参数减少 15%、计算量减少 27% 和加速 2.4 倍,同时获得更低的 FID。

图 2. 高效的 UNet。从 SDXL 的 UNet 的更薄更短版本(如 (a) 所示)开始,我们探索了一系列架构变化,即 (b)--(f),以开发更小更快的模型,同时保留高质量的生成性能,如图 3 所示。

图 3. 高效 UNet 各种设计选择的性能和效率比较。使用在 ImageNet-1K 上计算的 256 2 ^2 2像素生成的 FID 来评估生成质量。效率指标包括模型参数、延迟和 FLOP。FLOP 和延迟(在 iPhone 15 Pro 上)是使用 128 × 128 128 \times 128 128×128延迟(相当于 1024 × 1024 1024 \times 1024 1024×1024解码图像)进行一次前向传递来测量的。我们展示了架构增强功能,它可以改善任何指标而不会损害其他指标。

修剪 FFN 层。对于前馈网络 (FFN) 中的层,隐藏通道扩展率默认设置为 4,并使用门控单元进一步加倍。 这大大增加了模型参数、计算和内存使用量。按照 MobileDiffusion [88],我们检查了简单降低扩展率的效果,如图 2 (d) 所示。我们表明,将扩展率降低到 3 可以保持相当的 FID 性能,同时将参数和 FLOPs 都降低 12%。

用 MQA 替换 MHSA。多头自注意力 (MHSA) 要求每个注意力头有多组键和值。相比之下,多查询注意力 (MQA) [65] 通过在所有头之间共享一组键和值而效率更高。用 MQA 替换 MHSA 可将参数减少 16%,延迟减少 9%,对性能的影响极小。有趣的是,9% 的延迟节省超过了 6% 的 FLOP 减少,因为减少的内存访问可以实现更高的计算强度。

将条件注入第一阶段。交叉注意力 (CA) 将条件信息(例如,纹理描述)与空间特征融合在一起,以生成符合条件的图像。然而,SDXL 的 UNet 仅从第二阶段开始在 Transformer 块中应用 CA,导致第一阶段缺少条件指导。作为回应,我们建议从第一阶段开始引入条件嵌入,如图 2(e) 所示。具体来说,我们将残差块替换为包含 CA 和 FFN 但没有 SA 层的 Transformer 块。这种调整使模型更小、更快、更高效,同时改善了 FID。

采用 QK RMSNorm 和 RoPE 位置嵌入。我们扩展了两种最初为语言模型开发的先进技术,即使用 RMSNorm [84] 的查询键 (QK) 规范化 [25] 和旋转位置嵌入 (RoPE) [68],以增强模型 (图 2 (f))。在注意机制中,在查询键投影之后应用的 RMSNorm 降低了 softmax 饱和的风险,同时又不牺牲模型表达能力,并稳定训练以加快收敛速度​​。此外,我们将 RoPE 从一维调整为二维,以更好地支持更高的分辨率,因为它可以显著减轻重复对象等伪影。RMSNorm 和 RoPE 结合起来,引入了可忽略不计的计算和参数开销,同时在 FID 性能方面提供了可衡量的提升。

讨论。上述优化产生了一个高效而强大的扩散主干,能够在移动设备上生成高分辨率图像。在进行大规模 T2I 训练之前,我们将模型的容量与 ImageNet1K 上的现有工作进行了比较。我们按照先前工作 [61] 的设置对模型进行了 1,000 个时期的训练。我们在不同的推理时间步中使用不同的 CFG [27, 37] 来评估模型。 如表 1 所示,我们高效的 UNet 实现了与 SiT-XL [52] 相当的 FID,同时体积却小了近 45%。

表 1. 使用 CFG 在 ImageNet 256 × 256 256 \times 256 256×256 上进行类条件图像生成。计算一次前向传递的 FLOPs。

3.2. Tiny and Fast Decoder


除了去噪模型之外,解码器也占了总运行时间的很大一部分,尤其是在设备上部署时 [43, 88]。这里我们引入了一种新的解码器架构(图 4),以实现高效的高分辨率生成。

基准解码器。我们使用来自SD3 [19] 的自编码器(AE)作为我们的基准模型(即,使用SD3 AE中的相同编码器),因为它具有优越的重建质量。AE将图像 X ∈ R H × W × 3 X \in \mathbb{R}^{H \times W \times 3} X∈RH×W×3映射到一个低维潜在变量 x ∈ R H f × W f × c x \in \mathbb{R}^{\frac{H}{f} \times \frac{W}{f} \times c} x∈RfH×fW×c(其中 f , c f, c f,c在SD3中分别为8和16)。编码后的潜在变量 x x x随后通过解码器解码回图像。对于高分辨率生成,我们观察到SD3中的解码器在移动设备上非常慢。具体来说,当在iPhone 15 Pro的ANE处理器和移动GPU上生成 102 4 2 1024^2 10242 px图像时,解码器会遇到内存溢出(OOM)错误(见表2)。为了克服延迟问题,我们提出了一种更小且更快的解码器。

高效解码器。我们进行了一系列实验来确定高效解码器,与基线架构相比,其主要变化如下:

  1. 我们移除了注意力层,从而大大减少了峰值内存,同时对解码质量没有明显影响。 2. 我们保留最少量的 GroupNorm (GN),以在延迟和性能之间找到平衡(即减轻色移)。 3. 我们使解码器更薄(即通道更少或宽度更窄),并用 SepConvs 替换 Conv。 4. 我们在高分辨率阶段使用更少的残差块。 5. 我们移除了残差块中的 Conv 跳连方式,并使用上采样层进行通道转换。

图 4. (a) SDXL/SD3 解码器与 (b) 我们的微型解码器之间的解码器架构比较。

解码器的训练。我们使用均方误差(MSE)损失、LPIPS损失 [86]、对抗损失 [23] 来训练解码器,并且丢弃了KL项 [36],因为编码器是固定的。解码器在 25 6 2 256^2 2562图像块上训练,批大小为256,训练1M次迭代。正如表2所示,我们的小型解码器在重建任务中达到了具有竞争力的PSNR分数,同时在高分辨率生成方面,相比于传统解码器(例如SDXL和SD3中的解码器),其体积小了 35.9 × 35.9 \times 35.9×,速度提高了 54.4 × 54.4 \times 54.4×,特别是在移动设备上。

设备上总延迟的讨论。最后,我们在iPhone 16 Pro-Max上测量了T2I模型生成 102 4 2 1024^2 10242 px图像的延迟。解码器的延迟为119毫秒,UNet的每步延迟为274毫秒。这导致4到8步生成的运行时间为 1.2 ∼ 2.3 1.2 \sim 2.3 1.2∼2.3秒。需要注意的是,与其他组件相比,文本编码器的运行时间可以忽略不计 [43]。

3.3. Training Recipe and Multi-Level Distillation


为了提高我们高效扩散模型的生成质量,我们提出了一系列训练技术。

基于流的训练和推理。Rectified Flows (RFs) [47, 50] 将前向过程定义为连接数据分布和标准正态分布的直线路径,即:

x t = ( 1 − σ t ) x 0 + σ t ϵ , x_t=\left(1-\sigma_t\right) x_0+\sigma_t \epsilon, xt=(1−σt)x0+σtϵ,

其中 x 0 x_0 x0是干净的(潜在的)图像, t t t是时间步, σ t \sigma_t σt是与时间步相关的因子, ϵ \epsilon ϵ是从 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)中采样的随机噪声。去噪UNet被设计为预测一个速度场,目标函数为:

L task = E ϵ ∼ N ( 0 , I ) , t [ ∥ ( ϵ − x 0 ) − v θ ( x t , t ) ∥ 2 2 ] , \mathcal{L}{\text {task }}=\mathbb{E}{\epsilon \sim \mathcal{N}(0, I), t}\left[\left\|\left(\epsilon-x_0\right)-v_\theta\left(x_t, t\right)\right\|_2^2\right], Ltask =Eϵ∼N(0,I),t[∥(ϵ−x0)−vθ(xt,t)∥22],

其中 v θ ( x t , t ) v_\theta\left(x_t, t\right) vθ(xt,t)是由UNet预测的速度,参数化为 θ \theta θ。为了进一步增强训练稳定性,我们在训练过程中使用logit-normal采样 [19] 来为时间步分配更多样本,特别是将更多样本分配给中间时间步。在推理阶段,我们使用Flow-Euler采样器 [20],根据速度预测下一个样本,即:

x t − 1 = x t + ( σ t − 1 − σ t ) ⋅ v θ ( x t , t ) . x_{t-1}=x_t+\left(\sigma_{t-1}-\sigma_t\right) \cdot v_\theta\left(x_t, t\right) . xt−1=xt+(σt−1−σt)⋅vθ(xt,t).

为了在高分辨率(即 102 4 2 1024^2 10242 px)图像上实现更低的信噪比,我们应用类似SD3 [19] 的时间步偏移,在训练和推理过程中调整调度因子 σ t \sigma_t σt。

多级知识蒸馏。为了提高紧凑模型的生成质量,以前的一种常见做法是应用知识蒸馏来模仿扩大规模的教师模型的预测 [35]。 得益于对齐的流匹配目标和(AE)潜在空间,强大的 SD3.5-Large 模型 [4] 可用作输出蒸馏的老师。然而,我们仍然面临挑战,因为 1)U-Net 和 DiT 之间的异构架构,2)蒸馏损失和任务损失之间的规模差异,以及 3)不同时间步的预测难度不同。为了解决这些问题,我们提出了一种新颖的多级蒸馏损失,并应用时间步感知缩放来稳定和加速蒸馏的收敛。我们的知识蒸馏方案的概览如图 5 所示,详细技术阐述如下。

图 5.多级知识蒸馏概述,其中我们执行输出蒸馏和特征蒸馏。

除了在公式2中定义的任务损失外,知识蒸馏的主要目标是直接用教师模型 θ T \theta_T θT的输出监督我们的模型 θ \theta θ,这一点可以表示为:

L k d = E [ ∥ v θ T ( x t , t ) − v θ ( x t , t ) ∥ 2 2 ] . \mathcal{L}{\mathrm{kd}}=\mathbb{E}\left[\left\|v{\theta_T}\left(x_t, t\right)-v_\theta\left(x_t, t\right)\right\|_2^2\right] . Lkd=E[∥vθT(xt,t)−vθ(xt,t)∥22].

考虑到教师模型和我们的模型之间的能力差距,仅仅应用输出级监督会导致不稳定和收敛速度慢。因此,我们进一步引入了跨架构的特征级蒸馏损失,表示为:

L featkd = E [ ∑ ( l T , l ) ∥ f θ T l T ( x t , t ) − ψ ( f θ l ( x t , t ) ) ∥ 2 2 ] , \mathcal{L}{\text {featkd }}=\mathbb{E}\left[\sum{\left(l_T, l\right)}\left\|f_{\theta_T}^{l_T}\left(x_t, t\right)-\psi\left(f_\theta^l\left(x_t, t\right)\right)\right\|_2^2\right], Lfeatkd =E (lT,l)∑ fθTlT(xt,t)−ψ(fθl(xt,t)) 22 ,

其中 f θ T l T ( ⋅ ) f_{\theta_T}^{l_T}(\cdot) fθTlT(⋅)和 f l ( ⋅ ) f^l(\cdot) fl(⋅)分别表示教师模型和学生模型中第 l T l_T lT层和第 l l l层的特征输出。与之前的工作 [35, 49] 不同,我们考虑了从DiT到UNet的跨架构蒸馏。由于Transformer中的信息最丰富的部分通常位于最后一层,我们将蒸馏目标设置为这两种模型的最后一层,并使用一个轻量级的可训练投影器 ψ ( ⋅ ) \psi(\cdot) ψ(⋅)(仅包含两个卷积层)将学生特征映射到与教师特征相匹配的维度。提出的特征级蒸馏损失为学生模型提供了额外的监督,从而加速了与教师模型生成质量的对齐。

时间步感知缩放。在知识蒸馏中,特别是在扩散模型中,权衡多个目标一直是一个主要挑战。先前工作的整体训练目标 [35,49,69] 是多个损失项的简单线性组合,即:

L = L task + λ 1 L kd + λ 2 L featkd , \mathcal{L}=\mathcal{L}{\text {task }}+\lambda_1 \mathcal{L}{\text {kd }}+\lambda_2 \mathcal{L}_{\text {featkd }}, L=Ltask +λ1Lkd +λ2Lfeatkd ,

其中加权系数 λ 1 \lambda_1 λ1和 λ 2 \lambda_2 λ2是通过经验设置为常数。然而,这种基线设置没有考虑到不同时间步的预测难度。我们调查了在模型训练过程中, L task \mathcal{L}{\text {task }} Ltask 和 L kd \mathcal{L}{\text {kd }} Lkd 在不同时间步 t t t上的经验风险幅度分布。图6显示,在中间步骤中,预测难度较低,相较于接近0或1的 t t t。

基于这一重要观察,我们提出了一种时间步感知的目标缩放方法,以弥合不同 t t t值之间损失幅度的差距,并考虑每个时间步的预测难度,如下所示:

S ( L task , L k d ) = E t [ λ ( t ) ⋅ L task t + ( 1 − λ ( t ) ) ∣ L task t ∣ ∣ L k d t ∣ ⋅ L k d t ] , \mathcal{S}\left(\mathcal{L}{\text {task }}, \mathcal{L}{\mathrm{kd}}\right)=\mathbb{E}t\left[\lambda(t) \cdot \mathcal{L}{\text {task }}^t+(1-\lambda(t)) \frac{\left|\mathcal{L}{\text {task }}^t\right|}{\left|\mathcal{L}{\mathrm{kd}}^t\right|} \cdot \mathcal{L}_{\mathrm{kd}}^t\right], S(Ltask ,Lkd)=Et[λ(t)⋅Ltask t+(1−λ(t))∣Lkdt∣∣Ltask t∣⋅Lkdt],

其中 λ ( t ) \lambda(t) λ(t)是标准化的logit-norm密度函数(位置为0,尺度为1), ∣ ⋅ ∣ |\cdot| ∣⋅∣表示幅度。在 S \mathcal{S} S中,我们首先确保任务损失和蒸馏损失在不同 t t t之间具有相同的尺度,然后在预测难度较高的地方(即 t t t接近0或1)施加更多的教师监督,在预测难度较低的地方(即中间时间步)施加更多的真实数据监督。该方案考虑了时间步 t t t的变化,并有助于加速蒸馏训练。最终的多级蒸馏目标 L M D \mathcal{L}_{\mathrm{MD}} LMD定义为:

L M D = S ( L task , L k d ) + S ( L task , L featkd ) . \mathcal{L}{\mathrm{MD}}=\mathcal{S}\left(\mathcal{L}{\text {task }}, \mathcal{L}{\mathrm{kd}}\right)+\mathcal{S}\left(\mathcal{L}{\text {task }}, \mathcal{L}_{\text {featkd }}\right) . LMD=S(Ltask ,Lkd)+S(Ltask ,Lfeatkd ).

3.4. Step Distillation


我们进一步采用基于分布匹配的步蒸馏方案来增强我们模型的采样效率。根据潜在对抗扩散蒸馏(LADD)[63],我们使用扩散-GAN混合结构将我们的模型蒸馏到更少的步骤,优化目标为:

min ⁡ D θ T max ⁡ G θ E [ [ log ⁡ ( D θ T ( x t − 1 , t ) ) ] + [ log ⁡ ( 1 − D θ T ( x t − 1 ′ , t ) ) ] − S ( L task , L k d ) ] \begin{aligned} & \min {D{\theta_T}} \max {G\theta} \mathbb{E}\left[\left[\log \left(D_{\theta_T}\left(x_{t-1}, t\right)\right)\right]\right. \\ & \left.+\left[\log \left(1-D_{\theta_T}\left(x_{t-1}^{\prime}, t\right)\right)\right]-\mathcal{S}\left(\mathcal{L}{\text {task }}, \mathcal{L}{\mathrm{kd}}\right)\right] \end{aligned} DθTminGθmaxE[[log(DθT(xt−1,t))]+[log(1−DθT(xt−1′,t))]−S(Ltask ,Lkd)]

其中 D θ T D_{\theta_T} DθT是判别模型,部分初始化为预训练的少步教师模型 θ T \theta_T θT(SD3.5-LargeTurbo [5])。大规模教师模型仅作为特征提取器,在蒸馏过程中被冻结。我们只在特征提取后训练判别器中的少数线性层。我们从 q ( x t − 1 ∣ x 0 ) q\left(x_{t-1} \mid x_0\right) q(xt−1∣x0)中采样 x t − 1 x_{t-1} xt−1,以及从 q ( x t − 1 ∣ x 0 ′ ) q\left(x_{t-1} \mid x_0^{\prime}\right) q(xt−1∣x0′)中采样 x t − 1 ′ x_{t-1}^{\prime} xt−1′,其中 x 0 ′ x_0^{\prime} x0′是我们去噪生成器 1 G θ ( x t , t ) {}^1 G_\theta\left(x_t, t\right) 1Gθ(xt,t)的预测结果,作为我们的学生模型,而 q ( x ) q(x) q(x)是扩散模型中定义的前向过程(见公式1)。该目标包括一个对抗损失,用于匹配时间步 t − 1 t-1 t−1处的噪声样本,以及应用时间步感知缩放后的输出级蒸馏损失 S ( L task , L kd ) \mathcal{S}\left(\mathcal{L}{\text {task }}, \mathcal{L}{\text {kd }}\right) S(Ltask ,Lkd )。提出的步蒸馏(如图7所示)可以理解为通过对抗细化和知识蒸馏训练扩散模型,其中教师引导作为额外的归纳偏置。这个高级步蒸馏使我们的紧凑模型能够以极少的去噪步骤进行高质量的生成。

相关推荐
Java程序之猿1 小时前
微服务分布式(一、项目初始化)
分布式·微服务·架构
小蜗牛慢慢爬行4 小时前
Hibernate、JPA、Spring DATA JPA、Hibernate 代理和架构
java·架构·hibernate
paixiaoxin5 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
AI视觉网奇5 小时前
人脸生成3d模型 Era3D
人工智能·计算机视觉
编码小哥6 小时前
opencv中的色彩空间
opencv·计算机视觉
思忖小下6 小时前
梳理你的思路(从OOP到架构设计)_简介设计模式
设计模式·架构·eit
吃个糖糖6 小时前
34 Opencv 自定义角点检测
人工智能·opencv·计算机视觉
葡萄爱8 小时前
OpenCV图像分割
人工智能·opencv·计算机视觉
深度学习lover11 小时前
<项目代码>YOLO Visdrone航拍目标识别<目标检测>
python·yolo·目标检测·计算机视觉·visdrone航拍目标识别
一个儒雅随和的男子12 小时前
微服务详细教程之nacos和sentinel实战
微服务·架构·sentinel