本文介绍了一种名为「嫁接」的技术,用于在小计算预算下通过编辑预训练 Diffusion Transformers(简称 DiTs)来探索新的模型架构设计。这种方法允许研究者在不从头开始训练模型的情况下,通过替换模型中的某些算子(如 MLP)来创建新的混合架构,从而在保持模型质量的同时减少计算量。
模型架构设计在机器学习中扮演着核心角色,与数据、算法、算力和基准测试一样重要。它定义了模型函数、算子选择(如注意力机制、卷积)和配置设定(如模型深度、宽度)等等模型要素。
尽管如此,由于从头训练模型的成本过高 ------ 尤其人们难以获得关于架构设计的深刻洞见(即哪些方案有效、哪些无效)。因此,研究新架构仍是一项挑战,对生成模型而言尤为如此。
在本文中,来自斯坦福大学、 Liquid AI 等机构的研究者探索了这一问题,即对预训练模型进行架构编辑来研究新架构。

-
论文标题: Exploring Diffusion Transformer Designs via Grafting
具体而言,该研究提出了一种编辑预训练扩散 transformer(DiT)的简单方法,即 Grafting(嫁接),该方法可以在较小的计算预算下实现新的架构。
嫁接过程如下:
(i)激活蒸馏:此阶段通过回归目标(regression objective)蒸馏原始算子的激活特征,将其功能迁移至新算子。该阶段核心在于实现算子间的功能传递。
(ii)轻量级调优:此阶段通过使用有限的数据进行调优,减轻了由于集成多个新算子而导致的误差传播。
此外,架构编辑还涵盖多种策略,如添加、删除和替换算子。

本文还基于 DiT-XL/2 构建了一个测试平台,以研究嫁接对模型质量的影响。
利用该测试平台,本文通过嫁接技术开发了一系列混合设计:用门控卷积、局部注意力和线性注意力取代 Softmax 注意力,用可变扩展率和卷积变体取代 MLP。
值得注意的是,许多混合设计使用不到 2% 的预训练计算资源就实现了良好的质量(FID:2.38--2.64,而 DiT-XL/2 为 2.27)。然后,本文嫁接了一个文本转图像模型 (PixArt-Σ),实现了 1.43 倍的加速,而 GenEval 分数下降不到 2%。
最后,本文展示了一个案例研究,该研究通过嫁接技术将每对序列 Transformer 模块转换为并行模块,从而重构了 DiT-XL/2。这将模型深度减少到原来一半,并获得了比其他同等深度模型更高的质量(FID:2.77)。
总而言之,该研究展示了可以通过预训练 DiT 来探索新的扩散模型设计,其修改范围涵盖从算子替换到架构重构。
嫁接扩散 Transformer
两阶段嫁接方法
嫁接旨在通过编辑预训练模型的计算图来实现新架构。由于该研究专注于用替代方案替换现有算子,这引出了两个问题:
问题 1:在将新算子集成到计算图之前,应该如何初始化?
对应第一阶段:通过激活蒸馏进行初始化。由于 DiT 的激活是连续且平滑的,这可以被视为一个回归问题:

问题 2:当多个算子集成到计算图时,如何减轻误差传播?
对应第二阶段:轻量级调优。随着更多算子被替换,初始化误差会不断传播,导致与预训练模型的行为出现偏差。
本文采用端到端微调来缓解阶段 1 的累积误差。微调目标函数如公式 1 所示。
实践中,本文发现,即使替换 DiT-XL/2 中的所有 MHA 或 MLP 层,仅使用 10% 的训练数据也能恢复竞争性能。

自嫁接基准
在研究新的架构设计之前,该研究引入了自嫁接(self-grafting),这是一种简单的对照设置:将现有算子(如 MHA、MLP)替换为相同类型但权重随机初始化的算子。这样可以保持计算图的结构 ------ 包括算子类型和参数数量 ------ 但改变了具体的计算过程。自嫁接有三方面作用:(1)评估在不改变架构的情况下嫁接流程本身的效果;(2)为比较不同的替换方案提供一个性能基准;(3)研究影响性能的因素,如数据规模、回归目标和超参数。
激活行为分析以及自嫁接结果
本文首先分析了 DiT-XL/2 层中的 MHA 和 MLP 算子激活行为。在这两种情况下,本文观察到激活值存在较大差异,尤其是在较深的层中(表 1 (i, ii))。

经过分析,本文得出通过选择特定于算子的回归目标,可以实现高质量的初始化。
如表 1 (iii,iv) 所示,回归目标的选择会影响性能。对于 MHA,L1 实现了最佳 FID(2.51),其次是 Huber(2.55)和 L2(2.58)。对于 MLP,L2 表现最佳(2.33),而 L1 表现不佳(2.83);值得注意的是,MLP 的参数量是 MHA 的 2 倍。
这表明高质量的初始化需要量身定制的、激活感知的策略。
研究还发现,使用 10% 的数据进行完全自嫁接可实现接近基线的性能。表明在适度的数据和计算预算下完全自嫁接是可行的。

实验
实验 I:通过嫁接实现混合架构
本节实验围绕这个问题进行:当现有算子被高效的替代方案取代时,我们能否保持模型质量?
为了探究这个问题,本文研究了以下嫁接过程:
-
待替换算子的类型 ------MHA 或 MLP;
-
替换算子的类型 ------ 例如卷积;
-
层选择策略 ------ 替换所有层中的算子或使用启发式选择;
-
替换率 ------ 全部替换或部分替换。
为了实验,该研究构建了一个测试平台,并提出两种层选择策略:完全替换和交错替换。测试平台详见表 3。

此外,该研究还引入了 Hyena-X 和 Hyena-Y 两种新的高效门控卷积算子,并设计为 MHA 的直接替代品。Figure 3 展示了它们的结构。

MHA 结果。通过嫁接替换 DiT-XL/2 中的 MHA 算子,获得了良好的质量 - 效率权衡。主要发现如下:
在交错嫁接下,较小的感受野表现出惊人的效果。实验发现,在 50% 交错替换比例下,滑动窗口注意力(SWA)、Hyena-X/Y 和 Mamba-2 等替代方案均能保持 FID 分数与基线(2.27)差距在 0.5 以内。尤其值得注意的是,尽管 SWA 和 Hyena 变体的感受野有限(卷积核 K=4 / 窗口 w=4),其 FID 下降幅度却极小。
替换策略:交错替换 vs. 完全替换。将交错替换比例从 50% 提升至 75% 时,性能通常下降,但 SWA 在 75% 交错替换下仍有效(FID=3.09)。100% 替换时,性能急剧恶化(所有 FID > 75),这与局部性分析一致,表明只有部分层是局部且适合嫁接的。
数据规模和层选择的消融实验结果。

MLP 结果显示通过嫁接的方式替换 MLP 算子是有效的。
经过实验,得出要点 1:嫁接对于在较小的计算预算下构建具有良好生成质量的高效混合架构非常有效。交错设计尤其有效。
实验 II:通过嫁接改进文本到图像的扩散 Transformers
结果。嫁接模型在实时计算速度(wall-clock time)上实现了 1.43 倍的提升,同时生成评估分数(GenEval)仅出现小幅下降(47.78 vs. 49.75)。特定属性的指标(Attribute-specific metrics)基本保持可比,并且定性样本也展现出良好的对齐度和质量。在一些纹理区域观察到了局部性的失真(artifacts),这可能是由于 LoRA 的适应能力以及所使用的合成数据质量不高所致(失败案例详见图 D.3,D.4)


要点 2:在文生图 DiTs 中成功应用嫁接技术,构建的混合架构在实现显著加速的同时,生成质量损失极小。
了解更多内容,请参考原论文。