论文:Rethinking Cross-Layer Information Routing in Diffusion Transformers(arXiv:2605.20708,2026-05)。
一、它在解决什么问题
Diffusion Transformer(DiT)已经是现代视觉生成的事实标准骨干。这些年大家把 DiT 的几乎每个设计轴都重新审视了一遍------tokenization、attention、conditioning、训练目标......唯独有一个最基础的东西被原封不动从 Transformer 继承 下来:残差流(residual stream),也就是"信息如何跨层累积"这件事。
标准残差就是逐层把子层输出加进主干:
hl=h0+∑i=0l−1fi(hi;t)h_l = h_0 + \sum_{i=0}^{l-1} f_i(h_i; t)hl=h0+i=0∑l−1fi(hi;t)
作者的核心观点是:扩散是一个随时间步 t 变化的过程,不同噪声水平本应混合不同比例的浅层/深层特征。而朴素的"逐层相加"既不区分时间步、也不区分来源,是一种被忽视的、可以被显式重新设计的轴。
二、三个诊断:传统残差到底有什么毛病
作者在 ImageNet 256×256 上训练 SiT-XL/2(600K 迭代),用 4096 个样本做了系统测量,归纳出三个症状:
1. 前向幅度单调膨胀(Forward Magnitude Inflation)
隐藏态的 RMS 从第 1 块的约 15.5 一路涨到第 28 块的约 1576,膨胀近 100 倍。配合归一化,深层块被迫输出更大的原始幅度,才能对残差流维持影响力------这就是大模型里熟知的 "PreNorm 稀释" 现象。
2. 反向梯度急剧衰减(Backward Gradient Decay)
梯度 RMS 在前 5 个块之后快速下降,早期块拿到约 5×10⁻⁷ 的梯度,深层低一个数量级以上,接近零。说明标准残差路径对梯度流的调控能力其实有限。
3. 块间冗余(Block-wise Redundancy)
相邻块输出之间的余弦相似度全深度维持在 0.9 以上,意味着深层各块之间存在显著的表示冗余。
还有一个很关键的时间维度观察:即使基线模型没有任何路由器,它对"信息源"的隐含偏好本身就随时间步 t 系统性变化------高噪声阶段和低噪声阶段偏好的来源明显不同。这正是把 t 升级为路由一级控制信号的动机。
三、方法:Diffusion-Adaptive Routing(DAR)
DAR 是一个即插即用的残差替代件 ,把"逐层相加"换成对历史子层输出做 softmax 加权聚合:
hl=∑i=0l−1αi→l(t) vi,αi→l(t)=exp(ql(t)⊤ki/d)∑j=0l−1exp(ql(t)⊤kj/d)h_l = \sum_{i=0}^{l-1} \alpha_{i\to l}(t)\, v_i, \quad \alpha_{i\to l}(t) = \frac{\exp(q_l(t)^\top k_i/\sqrt{d})}{\sum_{j=0}^{l-1}\exp(q_l(t)^\top k_j/\sqrt{d})}hl=i=0∑l−1αi→l(t)vi,αi→l(t)=∑j=0l−1exp(ql(t)⊤kj/d )exp(ql(t)⊤ki/d )
其中 vi=fi(hi;t)v_i = f_i(h_i; t)vi=fi(hi;t) 是第 i 个子层输出,ki=RMSNorm(vi)k_i = \mathrm{RMSNorm}(v_i)ki=RMSNorm(vi)。直观理解:每一层不再无差别地把前面所有东西加起来,而是用一个注意力式的权重去挑"此刻该听谁的"。三个设计要点:
① 可学习的 query,且动态优于静态
- 静态:ql=wlq_l = w_lql=wl(每层一个可学向量)
- 动态:ql(t)=Wq(l)vl−1q_l(t) = W_q^{(l)} v_{l-1}ql(t)=Wq(l)vl−1(从前一层输出派生)
和 LLM 的结论相反,这里动态变体显著更好。
② 时间步注入
可以隐式(动态 query 继承 vl−1v_{l-1}vl−1 中已编码的时间信息),也可以显式(ql(t)=wl+e(t)q_l(t) = w_l + e(t)ql(t)=wl+e(t),直接复用 DiT 现成的 time embedding,零额外参数)。两种时间感知变体都远超时间无关基线。
③ 分块聚合(chunk)降显存
把 L 个子层切成 N 个大小为 S 的块,内存从 O(Ld)O(Ld)O(Ld) 降到 O((S+N)d)O((S+N)d)O((S+N)d)。作者还用一个 rate-distortion 模型推导了最优块大小:
S∗=L⋅1−α1+α,α∈(0,1)S^* = \sqrt{L \cdot \frac{1-\alpha}{1+\alpha}}, \quad \alpha \in (0,1)S∗=L⋅1+α1−α ,α∈(0,1)
对 L=56 预测 S∗∈3.7,4.9S^* \in 3.7, 4.9S∗∈3.7,4.9,与实验最优值 S=4 吻合。
四、实验结果
基准设置:ImageNet-1K 256×256,SiT-XL/2(675M),batch 1024,lr 1e-4,bf16,评测 FID/sFID/IS/Precision/Recall(5 万样本,250 NFE)。
| 方法 | 迭代数 | FID(no CFG) | FID(CFG) |
|---|---|---|---|
| SiT 基线 | 1.75M | 9.67 | 2.15 |
| DAR 静态 c4 | 600K | 7.56 | 2.08 |
| DAR 动态 c4 | 500K | 8.07 | 2.05 |
核心结论:
- 训练效率 :达到基线相同质量只要 1/8.75 的迭代;最终 FID 7.56 vs 9.67。
- 对比 U-Net 类:DAR 静态 c4 的 FID 2.23 优于 U-DiT-L 的 3.00,且参数更少,同时保留同构 Transformer 栈的可扩展性(不用手工指定 U 形层配对)。
- 与 REPA 正交可叠加:DAR+REPA 在 100K 迭代即达到 REPA 在 200K 的水平,再获约 2× 早期加速。说明"路由级加速"和"表示级加速"是两条独立可组合的路子。
开销方面:静态版几乎零参数增长(每层一个向量,约 2MB);动态版参数从 675M 涨到 752M。作者还写了融合 Triton kernel,前向 11.5×、反向 8.5× 加速,激活显存省约 75%。
五、消融与验证
- 时间步感知是关键:400K 迭代时,时间无关静态版 FID 11.51,动态版 8.10,显式注入时间的静态版 7.97。
- 线性探针 :从聚合隐藏态线性解码 t,前 5 块 R2>0.95R^2>0.95R2>0.95、深层接近 1.0,证实动态 query 的输入确实保留了充足的时间信息。
- 块大小扫描 :S=1(不分块)FID 10.41,S=4 最优 8.39,S=8 过度压缩 11.14,呈 U 形,与理论 S∗S^*S∗ 一致。
此外作者在 Qwen-Image 上用 Distribution Matching Distillation 做了文生图后训练的初步验证,DAR 在 4 步蒸馏里能更好保留高频细节。
六、个人评价
几点感受:
- 诊断扎实。"前向幅度膨胀 + 反向梯度衰减 + 块间冗余"三件套讲得很清楚,把"残差应不应该照搬"这个问题立住了。
- 方法本质 :用一个 timestep-conditioned 的轻量 attention 去替代残差求和。卖点是训练效率而非刷新 SOTA FID,且和 REPA 等现代目标正交------这点工程上很友好。
- 一个重要前提 :DAR 不是 training-free 的。它带可学习参数、改变了前向拓扑,所有结论都来自从头训练。想直接套到 FLUX 这类预训练大模型上"白嫖"是行不通的------预训练权重基于加性残差学出来,换成路由聚合后特征基本作废,需要重训。
- 局限 :主要在 28 层的 SiT-XL/2 上验证;按 S∗S^*S∗ 与深度成正比的推断,更深的 MM-DiT / 视频 DiT 理论上收益更大,但尚待大规模验证。