深度拆解 DiT:扩散模型与 Transformer 的巅峰结合

21-DiT详解:扩散模型遇上Transformer的图像生成革命

引言

DiT(Diffusion Transformer)是Meta AI在2023年提出的突破性工作,它用纯Transformer架构实现扩散模型,在ImageNet 256×256生成任务上达到了FID 2.27的业界最佳水平,并首次在图像生成模型中展现出清晰的scaling law特性。

本文目标:深入理解DiT的四个核心组件(Patchify、向量化、位置编码、AdaLN-Zero)、推理机制和训练过程。

适合人群:了解Transformer基础和扩散模型原理的读者。


第一部分:扩散模型的数学基础

前向过程:从图像到噪声

扩散模型的前向过程是一个马尔可夫链,逐步向图像添加高斯噪声:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) </math>q(xt∣xt−1)=N(xt;1−βt xt−1,βtI)

关键性质是可以直接从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 跳到任意 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \varepsilon </math>xt=αˉt x0+1−αˉt ε

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t = ∏ i = 1 t ( 1 − β i ) \bar{\alpha}t = \prod{i=1}^{t}(1-\beta_i) </math>αˉt=∏i=1t(1−βi), <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) </math>ε∼N(0,I)。

反向过程:学习去噪

训练目标是学习一个神经网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t ) \varepsilon_\theta(x_t, t) </math>εθ(xt,t) 预测噪声:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , ε [ ∥ ε − ε θ ( x t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}{t, x_0, \varepsilon}\left[\|\varepsilon - \varepsilon\theta(x_t, t)\|^2\right] </math>L=Et,x0,ε[∥ε−εθ(xt,t)∥2]

DiT就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ \varepsilon_\theta </math>εθ 的具体实现。


第二部分:DiT的四大核心组件

DiT的核心思想是将图像视为token序列,用Transformer处理。整个架构包含四个关键设计:


组件一:Patchify(切块) - 从2D到1D的转换

Patchify的本质

Patchify是将2D图像转换为1D token序列的过程。这是将Transformer应用于图像的前提。

给定图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} </math>x∈RH×W×C,选择patch大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p(通常16或8),将图像切分成 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = H W p 2 N = \frac{HW}{p^2} </math>N=p2HW 个不重叠的patch。

每个patch是一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> p × p × C p \times p \times C </math>p×p×C 的立方体,flatten后得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> p 2 C p^2C </math>p2C 维向量。所有patch排列成序列:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Patches ∈ R N × ( p 2 C ) \text{Patches} \in \mathbb{R}^{N \times (p^2C)} </math>Patches∈RN×(p2C)

为什么Patchify是合理的?

局部性原理 :自然图像具有强局部相关性。一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 16 × 16 16 \times 16 </math>16×16 的patch(256像素)通常包含一个完整的局部语义单元。

计算效率的权衡

  • 逐像素处理: <math xmlns="http://www.w3.org/1998/Math/MathML"> 256 × 256 256\times256 </math>256×256 图像有65536个token,自注意力复杂度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) = O ( 4.3 × 1 0 9 ) O(N^2) = O(4.3 \times 10^9) </math>O(N2)=O(4.3×109)
  • patch大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 16 p=16 </math>p=16:只有256个token,复杂度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 6.5 × 1 0 4 ) O(6.5 \times 10^4) </math>O(6.5×104),降低了6.6万倍

信息无损:切块是可逆操作,不丢失任何像素信息。

Patch排列的顺序

DiT采用光栅扫描顺序(raster-scan order):从左到右、从上到下依次排列。

虽然Transformer的自注意力是位置不变的(打乱patch顺序输出也会相应打乱),但通过位置编码可以让模型理解patch的空间位置关系。

Patchify的深层意义

Patchify不仅是技术手段,更是认知范式的转变

  • CNN的视角:图像是2D网格,通过卷积核滑动提取局部特征
  • Transformer的视角:图像是patch的集合,每个patch通过全局注意力与其他patch交互

这种转变使得模型可以直接建模长距离依赖,而不受卷积感受野的限制。


组件二:Linear Projection(向量化) - 从像素到语义

Embedding的数学定义

Linear Projection将每个patch从原始像素空间映射到高维语义空间。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> z i = E ⋅ vec ( patch i ) + b \mathbf{z}_i = \mathbf{E} \cdot \text{vec}(\text{patch}_i) + \mathbf{b} </math>zi=E⋅vec(patchi)+b

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> E ∈ R d × ( p 2 C ) \mathbf{E} \in \mathbb{R}^{d \times (p^2C)} </math>E∈Rd×(p2C) 是投影矩阵(可学习)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是Transformer的隐藏维度(如768、1024)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> vec \text{vec} </math>vec 表示将patch展平成向量

所有patch embedding组成序列:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z = [ z 1 , z 2 , ... , z N ] ∈ R N × d \mathbf{Z} = [\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_N] \in \mathbb{R}^{N \times d} </math>Z=[z1,z2,...,zN]∈RN×d

为什么需要Projection?

1. 维度标准化

不同的patch大小导致不同的输入维度:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 8 p=8 </math>p=8: <math xmlns="http://www.w3.org/1998/Math/MathML"> 8 2 × 3 = 192 8^2 \times 3 = 192 </math>82×3=192 维
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 16 p=16 </math>p=16: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 6 2 × 3 = 768 16^2 \times 3 = 768 </math>162×3=768 维

投影到统一的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 维,使得模型架构与patch大小解耦,提供了架构的灵活性

2. 语义提升

原始像素值(如RGB=[125, 200, 89])是低层次的信号,投影矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> E \mathbf{E} </math>E 学习将其映射到高层次语义空间。

类比:Word Embedding将离散的词ID映射到连续的语义向量空间,Patch Embedding做的是类似的事情。

3. 计算效率

实践中,Linear Projection通常用卷积层实现:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Conv2D ( k = p , s = p , in = C , out = d ) \text{Conv2D}(k=p, s=p, \text{in}=C, \text{out}=d) </math>Conv2D(k=p,s=p,in=C,out=d)

这等价于对每个patch做矩阵乘法,但利用了卷积的并行计算优势。

Projection的初始化

投影矩阵的初始化对训练至关重要。DiT使用Xavier初始化
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E ∼ U ( − 6 p 2 C + d , 6 p 2 C + d ) \mathbf{E} \sim \mathcal{U}\left(-\sqrt{\frac{6}{p^2C + d}}, \sqrt{\frac{6}{p^2C + d}}\right) </math>E∼U(−p2C+d6 ,p2C+d6 )

这保证了初始时每层的激活值方差相近,避免梯度消失/爆炸。


组件三:Positional Encoding(位置编码) - 告诉模型"哪里"

为什么Transformer必须有位置编码?

Transformer的自注意力机制是置换等变的(permutation equivariant):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( shuffle ( X ) ) = shuffle ( Attention ( X ) ) \text{Attention}(\text{shuffle}(X)) = \text{shuffle}(\text{Attention}(X)) </math>Attention(shuffle(X))=shuffle(Attention(X))

这意味着如果打乱输入顺序,输出也会相应打乱。Transformer本身无法区分patch的位置

但图像任务中,位置信息极其关键:

  • 天空通常在上方,草地在下方
  • 物体的空间关系("猫在沙发上")依赖于位置理解

因此必须显式注入位置信息。

DiT的2D正弦位置编码

DiT采用固定的2D正弦位置编码(inherited from ViT)。

对于位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , j ) (i, j) </math>(i,j) 的patch(第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i行,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j列),其位置编码是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE ( i , j ) = [ PE x ( i ) , PE y ( j ) ] \text{PE}(i, j) = [\text{PE}_x(i), \text{PE}_y(j)] </math>PE(i,j)=[PEx(i),PEy(j)]

其中x和y坐标分别编码为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE x ( i , 2 k ) = sin ⁡ ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k) = \sin\left(\frac{i}{10000^{2k/d}}\right) </math>PEx(i,2k)=sin(100002k/di)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE x ( i , 2 k + 1 ) = cos ⁡ ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k+1) = \cos\left(\frac{i}{10000^{2k/d}}\right) </math>PEx(i,2k+1)=cos(100002k/di)

最终的2D位置编码是x和y编码的拼接:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE 2 D ( i , j ) ∈ R d \text{PE}_{2D}(i,j) \in \mathbb{R}^d </math>PE2D(i,j)∈Rd

前 <math xmlns="http://www.w3.org/1998/Math/MathML"> d / 2 d/2 </math>d/2 维编码x坐标,后 <math xmlns="http://www.w3.org/1998/Math/MathML"> d / 2 d/2 </math>d/2 维编码y坐标。

正弦位置编码的数学优势

1. 周期性与连续性

正弦函数是连续平滑的,相邻位置的编码向量相近,这符合图像的空间连续性假设

2. 相对位置的可表达性

通过三角恒等式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> sin ⁡ ( α + β ) = sin ⁡ α cos ⁡ β + cos ⁡ α sin ⁡ β \sin(\alpha + \beta) = \sin\alpha\cos\beta + \cos\alpha\sin\beta </math>sin(α+β)=sinαcosβ+cosαsinβ

模型可以从绝对位置编码中推导出相对位置关系。例如,位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i + 1 , j ) (i+1, j) </math>(i+1,j) 的编码可以通过位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , j ) (i, j) </math>(i,j) 的编码线性表示。

3. 外推能力

理论上,正弦编码可以泛化到训练时未见过的更大图像尺寸。虽然实践中效果有限,但这是可学习位置编码不具备的特性。

4. 参数效率

位置编码是固定的(不参与训练),节省了 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × d N \times d </math>N×d 个参数。

位置编码的注入:加法 vs 拼接

DiT使用加法注入:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z with_pos = Z + P E \mathbf{Z}_{\text{with\_pos}} = \mathbf{Z} + \mathbf{PE} </math>Zwith_pos=Z+PE

为什么不用拼接?

  • 加法 : <math xmlns="http://www.w3.org/1998/Math/MathML"> R N × d + R N × d = R N × d \mathbb{R}^{N \times d} + \mathbb{R}^{N \times d} = \mathbb{R}^{N \times d} </math>RN×d+RN×d=RN×d,维度不变
  • 拼接 : <math xmlns="http://www.w3.org/1998/Math/MathML"> [ R N × d ; R N × d ] = R N × 2 d [\mathbb{R}^{N \times d}; \mathbb{R}^{N \times d}] = \mathbb{R}^{N \times 2d} </math>[RN×d;RN×d]=RN×2d,计算量翻倍

理论上,如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 足够大,加法空间就足以让模型将"内容"和"位置"信息解耦。

实际上,这是一个线性子空间分解的假设:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z + P E = Z content + Z position \mathbf{Z} + \mathbf{PE} = \mathbf{Z}{\text{content}} + \mathbf{Z}{\text{position}} </math>Z+PE=Zcontent+Zposition

模型通过学习将混合的信息分离到不同的子空间。


组件四:AdaLN-Zero - 条件注入的核心创新

AdaLN-Zero是DiT最重要的创新,解决了"如何将时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t和类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c注入Transformer"这一核心问题。

扩散模型的条件注入难题

扩散模型需要接收两类信息:

  1. 内容信息 :噪声图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt
  2. 条件信息
    • 时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:当前处于扩散过程的哪个阶段(关键!)
    • 类别标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c:生成什么类别的图像

传统方法:

  • 加法注入 : <math xmlns="http://www.w3.org/1998/Math/MathML"> x + f ( t , c ) \mathbf{x} + f(t, c) </math>x+f(t,c) ------ 太简单,条件易被覆盖
  • 拼接注入 : <math xmlns="http://www.w3.org/1998/Math/MathML"> [ x ; f ( t , c ) ] [\mathbf{x}; f(t, c)] </math>[x;f(t,c)] ------ 增加序列长度,计算量增大
  • Cross-Attention :将条件作为Key/Value ------ 复杂度高 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N × M ) O(N \times M) </math>O(N×M)

DiT提出了AdaLN(Adaptive Layer Normalization),一种高效且表达力强的方案。

Adaptive Layer Normalization的数学原理

标准的Layer Normalization:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> LN ( x ) = γ ⊙ x − μ σ + β \text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta </math>LN(x)=γ⊙σx−μ+β

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 是固定的可学习参数。

AdaLN的核心思想 :让 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 依赖于条件信息
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ ( c ) , β ( c ) = MLP ( c ) \gamma(\mathbf{c}), \beta(\mathbf{c}) = \text{MLP}(\mathbf{c}) </math>γ(c),β(c)=MLP(c)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> AdaLN ( x , c ) = γ ( c ) ⊙ x − μ σ + β ( c ) \text{AdaLN}(\mathbf{x}, \mathbf{c}) = \gamma(\mathbf{c}) \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta(\mathbf{c}) </math>AdaLN(x,c)=γ(c)⊙σx−μ+β(c)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> c = f ( t , c ) \mathbf{c} = f(t, c) </math>c=f(t,c) 是时间步和类别的嵌入向量。

直观理解:调制(Modulation)

AdaLN本质上是用条件信息调制特征的分布

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> γ ( c ) \gamma(\mathbf{c}) </math>γ(c):控制特征的尺度(scale)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> β ( c ) \beta(\mathbf{c}) </math>β(c):控制特征的偏移(shift)

不同的条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> c \mathbf{c} </math>c 产生不同的 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β,从而引导网络产生不同的输出。

类比 :想象一个收音机,条件信息是调频旋钮, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 是调制信号,特征 <math xmlns="http://www.w3.org/1998/Math/MathML"> x \mathbf{x} </math>x 是被调制的载波。

AdaLN-Zero:Zero Initialization的关键改进

DiT在AdaLN基础上加入了Zero Initialization,这是训练稳定性的核心。

标准的DiT Block结构:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 = x + α 1 ( c ) ⊙ Attention ( AdaLN ( x , c ) ) \mathbf{h}_1 = \mathbf{x} + \alpha_1(\mathbf{c}) \odot \text{Attention}(\text{AdaLN}(\mathbf{x}, \mathbf{c})) </math>h1=x+α1(c)⊙Attention(AdaLN(x,c))
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 2 = h 1 + α 2 ( c ) ⊙ MLP ( AdaLN ( h 1 , c ) ) \mathbf{h}_2 = \mathbf{h}_1 + \alpha_2(\mathbf{c}) \odot \text{MLP}(\text{AdaLN}(\mathbf{h}_1, \mathbf{c})) </math>h2=h1+α2(c)⊙MLP(AdaLN(h1,c))

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> α 1 , α 2 \alpha_1, \alpha_2 </math>α1,α2 是门控参数,也由条件生成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> [ γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 ] = MLP modulation ( c ) [\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2] = \text{MLP}_{\text{modulation}}(\mathbf{c}) </math>[γ1,β1,α1,γ2,β2,α2]=MLPmodulation(c)

Zero Initialization的定义
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MLP modulation = W 2 ⋅ SiLU ( W 1 c + b 1 ) + b 2 \text{MLP}_{\text{modulation}} = W_2 \cdot \text{SiLU}(W_1 \mathbf{c} + b_1) + b_2 </math>MLPmodulation=W2⋅SiLU(W1c+b1)+b2

初始化时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W 2 = 0 , b 2 = 0 W_2 = \mathbf{0}, \quad b_2 = \mathbf{0} </math>W2=0,b2=0

这保证了训练初始时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ 1 = γ 2 = 1 , β 1 = β 2 = 0 , α 1 = α 2 = 0 \gamma_1 = \gamma_2 = 1, \quad \beta_1 = \beta_2 = 0, \quad \alpha_1 = \alpha_2 = 0 </math>γ1=γ2=1,β1=β2=0,α1=α2=0

因此:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 = x + 0 ⋅ Attention ( ⋯   ) = x \mathbf{h}_1 = \mathbf{x} + 0 \cdot \text{Attention}(\cdots) = \mathbf{x} </math>h1=x+0⋅Attention(⋯)=x
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 2 = x + 0 ⋅ MLP ( ⋯   ) = x \mathbf{h}_2 = \mathbf{x} + 0 \cdot \text{MLP}(\cdots) = \mathbf{x} </math>h2=x+0⋅MLP(⋯)=x

整个网络初始时是恒等映射 : <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x f(\mathbf{x}) = \mathbf{x} </math>f(x)=x。

为什么Zero Initialization如此重要?

1. 梯度流动的畅通性

深度网络训练的核心挑战是梯度消失/爆炸

在恒等映射下,梯度可以无损地反向传播:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ x = ∂ L ∂ h 2 ⋅ ∂ h 2 ∂ x = ∂ L ∂ h 2 ⋅ I \frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot \frac{\partial \mathbf{h}_2}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot I </math>∂x∂L=∂h2∂L⋅∂x∂h2=∂h2∂L⋅I

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I 是单位矩阵,梯度直接传递,不会衰减。

2. 从简单到复杂的学习路径

随着训练进行,门控参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> α 1 , α 2 \alpha_1, \alpha_2 </math>α1,α2 从0逐渐增大,模型逐步学习利用注意力和MLP的输出。

这是一种curriculum learning(课程学习)策略:先学简单的(恒等映射),再学复杂的(注意力模式)。

3. 残差连接的极致体现

残差连接(ResNet)的核心公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h = x + F ( x ) \mathbf{h} = \mathbf{x} + F(\mathbf{x}) </math>h=x+F(x)

当 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ( x ) = 0 F(\mathbf{x}) = 0 </math>F(x)=0 时,网络退化为恒等映射,保证了至少不会比浅层网络差。

AdaLN-Zero通过zero initialization,强制初始时 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ( x ) = 0 F(\mathbf{x}) = 0 </math>F(x)=0,这是残差思想的最彻底实践。

AdaLN vs 其他条件注入方式

方法 计算复杂度 表达能力 训练稳定性
加法注入 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)
拼接注入 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N ) O(N) </math>O(N)
Cross-Attention <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N ⋅ M ) O(N \cdot M) </math>O(N⋅M)
AdaLN-Zero <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1)

AdaLN-Zero的优势:

  • 零额外计算:不增加序列长度,不增加注意力计算
  • 强表达力:通过调制归一化参数,影响每一层的特征分布
  • 训练稳定:zero initialization保证梯度流畅通

第三部分:DiT的推理过程

推理就是从纯噪声逐步去噪,生成清晰图像

DDPM采样:严格的概率过程

DDPM(Denoising Diffusion Probabilistic Models)是最原始的采样算法,严格遵循扩散模型的概率推导。

单步去噪公式
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ε θ ( x t , t , c ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}} \varepsilon\theta(x_t, t, c) \right) + \sigma_t z </math>xt−1=αt 1(xt−1−αˉt 1−αtεθ(xt,t,c))+σtz

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c) 是DiT预测的噪声
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) </math>z∼N(0,I) 是新采样的随机噪声
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} </math>σt=β~t 是后验方差

完整流程

  1. 初始化 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) </math>xT∼N(0,I)(纯高斯噪声)
  2. 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = T , T − 1 , ... , 1 t = T, T-1, \ldots, 1 </math>t=T,T−1,...,1:
    • 前向传播DiT: <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ = ε θ ( x t , t , c ) \hat{\varepsilon} = \varepsilon_\theta(x_t, t, c) </math>ε^=εθ(xt,t,c)
    • 计算均值: <math xmlns="http://www.w3.org/1998/Math/MathML"> μ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ε ^ ) \mu_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\varepsilon}\right) </math>μt=αt 1(xt−1−αˉt 1−αtε^)
    • 采样噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) </math>z∼N(0,I)
    • 更新: <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 = μ t + σ t z x_{t-1} = \mu_t + \sigma_t z </math>xt−1=μt+σtz
  3. 返回 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0

特点

  • :需要1000步,每步都要前向传播DiT
  • 质量高:每步添加适量随机性,生成多样性好
  • 理论清晰 :严格遵循后验分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) </math>q(xt−1∣xt,x0)

DDIM采样:确定性加速

DDIM(Denoising Diffusion Implicit Models)通过确定性过程实现加速。

核心思想:不采样新噪声,而是走确定性的"直线路径"。

DDIM公式
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ε ^ α ˉ t ⏟ predicted x 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}t}}}{\text{predicted } x_0} + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} </math>xt−1=αˉt−1 predicted x0 αˉt xt−1−αˉt ε^+1−αˉt−1 ε^

这个公式的直观理解:

  1. 用当前 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,估计干净图像:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ 0 = x t − 1 − α ˉ t ε ^ α ˉ t \hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}_t}} </math>x^0=αˉt xt−1−αˉt ε^

  1. 用 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ 0 \hat{x}0 </math>x^0 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,重新组合成 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x{t-1} </math>xt−1:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x ^ 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} </math>xt−1=αˉt−1 x^0+1−αˉt−1 ε^

关键区别

  • DDPM:每步采样新噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,引入随机性
  • DDIM:重复使用同一个噪声估计 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,是确定性过程

加速原理 :由于是确定性的,可以跳步采样。例如:

  • DDPM: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1000 → 999 → 998 → ⋯ → 1 → 0 1000 \to 999 \to 998 \to \cdots \to 1 \to 0 </math>1000→999→998→⋯→1→0(1000步)
  • DDIM: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1000 → 950 → 900 → ⋯ → 50 → 0 1000 \to 950 \to 900 \to \cdots \to 50 \to 0 </math>1000→950→900→⋯→50→0(20步)

特点

  • :50步达到DDPM 1000步的效果,速度提升20倍
  • 确定性:相同初始噪声和条件,生成完全相同的图像
  • 质量略降:FID略高于DDPM,但肉眼难以区分

Classifier-Free Guidance:提升条件遵循度

CFG(Classifier-Free Guidance)是提升生成质量的关键技术。

问题:标准条件生成可能"不够听话"。指定生成"猫",模型可能生成模糊的猫,或混合其他动物特征。

解决方案:训练时同时学习条件生成和无条件生成,推理时"放大"条件影响。

CFG训练

训练时,以概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 0.1 p=0.1 </math>p=0.1 将类别标签置空:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c ′ = { ∅ 概率 0.1 c 概率 0.9 c' = \begin{cases} \emptyset & \text{概率 } 0.1 \\ c & \text{概率 } 0.9 \end{cases} </math>c′={∅c概率 0.1概率 0.9

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∅ \emptyset </math>∅ 用特殊token表示(如类别ID=1000)。

这样模型学会了两种模式:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c):给定类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c的条件生成
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, \emptyset) </math>εθ(xt,t,∅):无条件生成
CFG推理

推理时,将两者线性组合:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε ~ = ε θ ( x t , t , ∅ ) + w ⋅ ( ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) ) \tilde{\varepsilon} = \varepsilon_\theta(x_t, t, \emptyset) + w \cdot \left(\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset)\right) </math>ε~=εθ(xt,t,∅)+w⋅(εθ(xt,t,c)−εθ(xt,t,∅))

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w 是guidance scale(通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 7.5 w=7.5 </math>w=7.5)。

数学直观

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset) </math>εθ(xt,t,c)−εθ(xt,t,∅):条件相对于无条件的"差异方向"
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> w > 1 w > 1 </math>w>1:沿着这个方向走得更远,放大条件影响
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 1 w = 1 </math>w=1:标准条件生成
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 0 w = 0 </math>w=0:无条件生成

效果

Guidance scale <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w 类别一致性 图像多样性 图像质量
1.0 一般
3.0-5.0
7.5 (推荐) 中低 最好
15.0+ 过高 过饱和、失真

代价:CFG需要推理两次(条件+无条件),推理时间翻倍。但效果提升显著,是工业标准。

完整推理流程

结合DDIM和CFG:

输入

  • 类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c(如"猫"的ID=281)
  • 采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> S = 50 S=50 </math>S=50
  • Guidance scale <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 7.5 w=7.5 </math>w=7.5

算法

markdown 复制代码
1. 确定时间步序列:τ = [1000, 950, 900, ..., 50, 0](均匀采样S步)
2. 初始化:x ← N(0, I)
3. For t in τ[:-1]:
     t_next ← τ中t的下一个时间步

     # 条件预测
     ε_cond ← DiT(x, t, c)

     # 无条件预测
     ε_uncond ← DiT(x, t, ∅)

     # CFG组合
     ε̂ ← ε_uncond + w * (ε_cond - ε_uncond)

     # 估计x₀
     x̂₀ ← (x - √(1-ᾱₜ)·ε̂) / √ᾱₜ

     # DDIM更新
     x ← √ᾱₜ_ₙₑₓₜ · x̂₀ + √(1-ᾱₜ_ₙₑₓₜ) · ε̂

4. Return x

时间成本(DiT-XL,A100 GPU):

  • DDPM 1000步:约60秒/图
  • DDIM 50步 + CFG:约6秒/图

第四部分:DiT的训练过程

训练目标

DiT的训练是简单的噪声预测任务
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , ε , c [ ∥ ε − ε θ ( x t , t , c ) ∥ 2 ] \mathcal{L} = \mathbb{E}{t, x_0, \varepsilon, c}\left[\|\varepsilon - \varepsilon\theta(x_t, t, c)\|^2\right] </math>L=Et,x0,ε,c[∥ε−εθ(xt,t,c)∥2]

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon </math>xt=αˉt x0+1−αˉt ε。

训练算法

单个训练step

scss 复制代码
1. 采样一批数据:(x₀, c) ~ 数据集
2. 采样时间步:t ~ Uniform(1, T)
3. 采样噪声:ε ~ N(0, I)
4. 前向加噪:xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε
5. 预测噪声:ε̂ = DiT(xₜ, t, c)
6. 计算损失:L = ‖ε̂ - ε‖²
7. 反向传播:更新参数θ

关键训练细节

1. 噪声调度(Noise Schedule)

<math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 的设计影响训练效果。DiT使用线性调度
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β t = β min ⁡ + t − 1 T − 1 ( β max ⁡ − β min ⁡ ) \beta_t = \beta_{\min} + \frac{t-1}{T-1}(\beta_{\max} - \beta_{\min}) </math>βt=βmin+T−1t−1(βmax−βmin)

典型值: <math xmlns="http://www.w3.org/1998/Math/MathML"> β min ⁡ = 0.0001 , β max ⁡ = 0.02 , T = 1000 \beta_{\min} = 0.0001, \beta_{\max} = 0.02, T=1000 </math>βmin=0.0001,βmax=0.02,T=1000。

这意味着:

  • 早期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t小): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt很小,加噪缓慢,图像几乎不变
  • 后期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t大): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt接近0.02,加噪快速,图像迅速变成纯噪声
2. Classifier-Free Guidance训练

如前所述,训练时10%概率drop类别:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c ′ = { null token p = 0.1 c p = 0.9 c' = \begin{cases} \text{null token} & p=0.1 \\ c & p=0.9 \end{cases} </math>c′={null tokencp=0.1p=0.9

这让模型同时学会两种生成模式。

3. 学习率调度

DiT使用warmup + cosine decay
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> η t = η min ⁡ + 1 2 ( η max ⁡ − η min ⁡ ) ( 1 + cos ⁡ ( π t − t w T − t w ) ) \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \frac{t - t_w}{T - t_w}\right)\right) </math>ηt=ηmin+21(ηmax−ηmin)(1+cos(πT−twt−tw))

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> t w = 10000 t_w = 10000 </math>tw=10000:warmup步数
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> η max ⁡ = 1 0 − 4 \eta_{\max} = 10^{-4} </math>ηmax=10−4:峰值学习率
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> η min ⁡ = 1 0 − 5 \eta_{\min} = 10^{-5} </math>ηmin=10−5:最小学习率

前10000步线性增长到 <math xmlns="http://www.w3.org/1998/Math/MathML"> η max ⁡ \eta_{\max} </math>ηmax,之后按余弦函数衰减。

原理:warmup避免初期梯度过大导致发散;cosine decay比step decay更平滑。

4. EMA(Exponential Moving Average)

维护参数的指数移动平均:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ EMA ← μ θ EMA + ( 1 − μ ) θ \theta_{\text{EMA}} \leftarrow \mu \theta_{\text{EMA}} + (1-\mu)\theta </math>θEMA←μθEMA+(1−μ)θ

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ = 0.9999 \mu = 0.9999 </math>μ=0.9999。

推理时使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ EMA \theta_{\text{EMA}} </math>θEMA 而非 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ。

原理:EMA相当于对训练轨迹上的多个checkpoint做平滑,减少单个模型的抖动,提升生成质量和稳定性。

5. 混合精度训练

使用FP16计算,同时维护FP32主权重:

  • 前向传播、梯度计算:FP16
  • 参数更新:FP32

收益

  • 训练速度提升1.5-2倍
  • 显存占用减半
  • 精度损失可忽略

训练规模与成本

DiT-XL的训练配置:

项目 数值
参数量 675M
数据集 ImageNet(130万图像,1000类)
Batch size 256(8卡 × 32/卡)
训练步数 7M steps
训练时长 约1个月(8×A100 80GB)
总计算量 约10M GPU-hours
FID(256×256) 2.27

Scaling Law:DiT的惊人发现

DiT首次在图像生成模型中展现出清晰的scaling law:

模型 参数量 深度 宽度 FID ↓
DiT-S 33M 12层 384 9.62
DiT-B 130M 12层 768 5.31
DiT-L 458M 24层 1024 3.04
DiT-XL 675M 28层 1152 2.27

关键观察

  1. 性能持续提升:从DiT-S到DiT-XL,FID持续下降,没有饱和迹象
  2. 对数线性关系:FID与log(参数量)近似线性关系
  3. 类似LLM:这与语言模型的scaling law特性一致

意义

  • 更大的模型 → 更好的生成质量(确定性规律)
  • 为投资更大模型提供了理论依据
  • 预示着10B+参数的扩散模型可能带来质的飞跃

总结:DiT的意义与启示

核心贡献

1. 架构统一

证明了Transformer可以作为扩散模型的通用backbone,图像生成不再需要特定领域的架构设计。

2. AdaLN-Zero

提出了优雅的条件注入机制,在零额外计算成本下实现强大的表达能力和训练稳定性。

3. Scaling Law

首次在图像生成中展现scaling特性,为"训练更大模型"提供了理论支持。

4. 性能突破

FID 2.27(256×256 ImageNet),超越所有基于卷积的方法。

DiT的局限

1. 计算复杂度 :自注意力是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2),分辨率越高越慢

2. 推理时间:即使DDIM,仍需50步,比单次前向慢50倍

3. 数据需求:需要大规模数据(百万级)才能充分发挥scaling优势

4. 条件类型:目前主要支持类别标签,对长文本支持有限

未来方向

1. 更高效的注意力:Sparse Attention、Linear Attention、Flash Attention

2. 更快的采样:Consistency Models(一步生成)、Latent Diffusion(低维空间扩散)

3. 更大的模型:DiT-XXL(10B参数)在更大数据集上训练

4. 多模态扩展:文本到图像、视频生成、3D生成

关键启示

  1. Transformer的通用性:不仅NLP,CV也适用
  2. Scaling的威力:更大的模型带来更好的效果
  3. 架构细节的重要性:AdaLN-Zero这样的创新带来质的提升
  4. 条件注入的本质:如何注入比注入什么更重要

DiT代表了扩散模型从CNN到Transformer的范式转变,这与NLP从RNN到Transformer的转变如出一辙。

Transformer + Diffusion = 图像生成的未来。


参考文献

  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
  2. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
  3. Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
  4. Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
  5. Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
相关推荐
ZhengEnCi2 小时前
08c. 检索算法与策略-混合检索
后端·python·算法
用户7344028193422 小时前
Java 8 Stream 的终极技巧——Collectors 操作
后端
树獭叔叔2 小时前
深度拆解 VAE:生成式 AI 的潜空间大门
后端·aigc·openai
任沫2 小时前
字符串
数据结构·后端
几米哥3 小时前
GPT-5.4 深度解读:为什么说它是 OpenAI 最重要的一次升级
openai
Java编程爱好者4 小时前
2026 大厂 Java 八股文面试题库|附答案(完整版)
后端
Moment5 小时前
腾讯终于对个人开放了,5 分钟在 QQ 里养一只「真能干活」的 AI 😍😍😍
前端·后端·github
可夫小子5 小时前
OpenClaw基础-4-三分钟完成QQ机器人接入
openai·ai编程