21-DiT详解:扩散模型遇上Transformer的图像生成革命
引言
DiT(Diffusion Transformer)是Meta AI在2023年提出的突破性工作,它用纯Transformer架构实现扩散模型,在ImageNet 256×256生成任务上达到了FID 2.27的业界最佳水平,并首次在图像生成模型中展现出清晰的scaling law特性。
本文目标:深入理解DiT的四个核心组件(Patchify、向量化、位置编码、AdaLN-Zero)、推理机制和训练过程。
适合人群:了解Transformer基础和扩散模型原理的读者。
第一部分:扩散模型的数学基础
前向过程:从图像到噪声
扩散模型的前向过程是一个马尔可夫链,逐步向图像添加高斯噪声:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) </math>q(xt∣xt−1)=N(xt;1−βt xt−1,βtI)
关键性质是可以直接从 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0 跳到任意 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \varepsilon </math>xt=αˉt x0+1−αˉt ε
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> α ˉ t = ∏ i = 1 t ( 1 − β i ) \bar{\alpha}t = \prod{i=1}^{t}(1-\beta_i) </math>αˉt=∏i=1t(1−βi), <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) </math>ε∼N(0,I)。
反向过程:学习去噪
训练目标是学习一个神经网络 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t ) \varepsilon_\theta(x_t, t) </math>εθ(xt,t) 预测噪声:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , ε [ ∥ ε − ε θ ( x t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}{t, x_0, \varepsilon}\left[\|\varepsilon - \varepsilon\theta(x_t, t)\|^2\right] </math>L=Et,x0,ε[∥ε−εθ(xt,t)∥2]
DiT就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ \varepsilon_\theta </math>εθ 的具体实现。
第二部分:DiT的四大核心组件
DiT的核心思想是将图像视为token序列,用Transformer处理。整个架构包含四个关键设计:
组件一:Patchify(切块) - 从2D到1D的转换
Patchify的本质
Patchify是将2D图像转换为1D token序列的过程。这是将Transformer应用于图像的前提。
给定图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} </math>x∈RH×W×C,选择patch大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> p p </math>p(通常16或8),将图像切分成 <math xmlns="http://www.w3.org/1998/Math/MathML"> N = H W p 2 N = \frac{HW}{p^2} </math>N=p2HW 个不重叠的patch。
每个patch是一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> p × p × C p \times p \times C </math>p×p×C 的立方体,flatten后得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> p 2 C p^2C </math>p2C 维向量。所有patch排列成序列:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Patches ∈ R N × ( p 2 C ) \text{Patches} \in \mathbb{R}^{N \times (p^2C)} </math>Patches∈RN×(p2C)
为什么Patchify是合理的?
局部性原理 :自然图像具有强局部相关性。一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> 16 × 16 16 \times 16 </math>16×16 的patch(256像素)通常包含一个完整的局部语义单元。
计算效率的权衡:
- 逐像素处理: <math xmlns="http://www.w3.org/1998/Math/MathML"> 256 × 256 256\times256 </math>256×256 图像有65536个token,自注意力复杂度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) = O ( 4.3 × 1 0 9 ) O(N^2) = O(4.3 \times 10^9) </math>O(N2)=O(4.3×109)
- patch大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 16 p=16 </math>p=16:只有256个token,复杂度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 6.5 × 1 0 4 ) O(6.5 \times 10^4) </math>O(6.5×104),降低了6.6万倍
信息无损:切块是可逆操作,不丢失任何像素信息。
Patch排列的顺序
DiT采用光栅扫描顺序(raster-scan order):从左到右、从上到下依次排列。
虽然Transformer的自注意力是位置不变的(打乱patch顺序输出也会相应打乱),但通过位置编码可以让模型理解patch的空间位置关系。
Patchify的深层意义
Patchify不仅是技术手段,更是认知范式的转变:
- CNN的视角:图像是2D网格,通过卷积核滑动提取局部特征
- Transformer的视角:图像是patch的集合,每个patch通过全局注意力与其他patch交互
这种转变使得模型可以直接建模长距离依赖,而不受卷积感受野的限制。
组件二:Linear Projection(向量化) - 从像素到语义
Embedding的数学定义
Linear Projection将每个patch从原始像素空间映射到高维语义空间。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> z i = E ⋅ vec ( patch i ) + b \mathbf{z}_i = \mathbf{E} \cdot \text{vec}(\text{patch}_i) + \mathbf{b} </math>zi=E⋅vec(patchi)+b
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> E ∈ R d × ( p 2 C ) \mathbf{E} \in \mathbb{R}^{d \times (p^2C)} </math>E∈Rd×(p2C) 是投影矩阵(可学习)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是Transformer的隐藏维度(如768、1024)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> vec \text{vec} </math>vec 表示将patch展平成向量
所有patch embedding组成序列:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z = [ z 1 , z 2 , ... , z N ] ∈ R N × d \mathbf{Z} = [\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_N] \in \mathbb{R}^{N \times d} </math>Z=[z1,z2,...,zN]∈RN×d
为什么需要Projection?
1. 维度标准化
不同的patch大小导致不同的输入维度:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 8 p=8 </math>p=8: <math xmlns="http://www.w3.org/1998/Math/MathML"> 8 2 × 3 = 192 8^2 \times 3 = 192 </math>82×3=192 维
- <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 16 p=16 </math>p=16: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 6 2 × 3 = 768 16^2 \times 3 = 768 </math>162×3=768 维
投影到统一的 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 维,使得模型架构与patch大小解耦,提供了架构的灵活性。
2. 语义提升
原始像素值(如RGB=[125, 200, 89])是低层次的信号,投影矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> E \mathbf{E} </math>E 学习将其映射到高层次语义空间。
类比:Word Embedding将离散的词ID映射到连续的语义向量空间,Patch Embedding做的是类似的事情。
3. 计算效率
实践中,Linear Projection通常用卷积层实现:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Conv2D ( k = p , s = p , in = C , out = d ) \text{Conv2D}(k=p, s=p, \text{in}=C, \text{out}=d) </math>Conv2D(k=p,s=p,in=C,out=d)
这等价于对每个patch做矩阵乘法,但利用了卷积的并行计算优势。
Projection的初始化
投影矩阵的初始化对训练至关重要。DiT使用Xavier初始化:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> E ∼ U ( − 6 p 2 C + d , 6 p 2 C + d ) \mathbf{E} \sim \mathcal{U}\left(-\sqrt{\frac{6}{p^2C + d}}, \sqrt{\frac{6}{p^2C + d}}\right) </math>E∼U(−p2C+d6 ,p2C+d6 )
这保证了初始时每层的激活值方差相近,避免梯度消失/爆炸。
组件三:Positional Encoding(位置编码) - 告诉模型"哪里"
为什么Transformer必须有位置编码?
Transformer的自注意力机制是置换等变的(permutation equivariant):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( shuffle ( X ) ) = shuffle ( Attention ( X ) ) \text{Attention}(\text{shuffle}(X)) = \text{shuffle}(\text{Attention}(X)) </math>Attention(shuffle(X))=shuffle(Attention(X))
这意味着如果打乱输入顺序,输出也会相应打乱。Transformer本身无法区分patch的位置。
但图像任务中,位置信息极其关键:
- 天空通常在上方,草地在下方
- 物体的空间关系("猫在沙发上")依赖于位置理解
因此必须显式注入位置信息。
DiT的2D正弦位置编码
DiT采用固定的2D正弦位置编码(inherited from ViT)。
对于位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , j ) (i, j) </math>(i,j) 的patch(第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i行,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> j j </math>j列),其位置编码是:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE ( i , j ) = [ PE x ( i ) , PE y ( j ) ] \text{PE}(i, j) = [\text{PE}_x(i), \text{PE}_y(j)] </math>PE(i,j)=[PEx(i),PEy(j)]
其中x和y坐标分别编码为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE x ( i , 2 k ) = sin ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k) = \sin\left(\frac{i}{10000^{2k/d}}\right) </math>PEx(i,2k)=sin(100002k/di)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE x ( i , 2 k + 1 ) = cos ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k+1) = \cos\left(\frac{i}{10000^{2k/d}}\right) </math>PEx(i,2k+1)=cos(100002k/di)
最终的2D位置编码是x和y编码的拼接:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> PE 2 D ( i , j ) ∈ R d \text{PE}_{2D}(i,j) \in \mathbb{R}^d </math>PE2D(i,j)∈Rd
前 <math xmlns="http://www.w3.org/1998/Math/MathML"> d / 2 d/2 </math>d/2 维编码x坐标,后 <math xmlns="http://www.w3.org/1998/Math/MathML"> d / 2 d/2 </math>d/2 维编码y坐标。
正弦位置编码的数学优势
1. 周期性与连续性
正弦函数是连续平滑的,相邻位置的编码向量相近,这符合图像的空间连续性假设。
2. 相对位置的可表达性
通过三角恒等式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> sin ( α + β ) = sin α cos β + cos α sin β \sin(\alpha + \beta) = \sin\alpha\cos\beta + \cos\alpha\sin\beta </math>sin(α+β)=sinαcosβ+cosαsinβ
模型可以从绝对位置编码中推导出相对位置关系。例如,位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i + 1 , j ) (i+1, j) </math>(i+1,j) 的编码可以通过位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( i , j ) (i, j) </math>(i,j) 的编码线性表示。
3. 外推能力
理论上,正弦编码可以泛化到训练时未见过的更大图像尺寸。虽然实践中效果有限,但这是可学习位置编码不具备的特性。
4. 参数效率
位置编码是固定的(不参与训练),节省了 <math xmlns="http://www.w3.org/1998/Math/MathML"> N × d N \times d </math>N×d 个参数。
位置编码的注入:加法 vs 拼接
DiT使用加法注入:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z with_pos = Z + P E \mathbf{Z}_{\text{with\_pos}} = \mathbf{Z} + \mathbf{PE} </math>Zwith_pos=Z+PE
为什么不用拼接?
- 加法 : <math xmlns="http://www.w3.org/1998/Math/MathML"> R N × d + R N × d = R N × d \mathbb{R}^{N \times d} + \mathbb{R}^{N \times d} = \mathbb{R}^{N \times d} </math>RN×d+RN×d=RN×d,维度不变
- 拼接 : <math xmlns="http://www.w3.org/1998/Math/MathML"> [ R N × d ; R N × d ] = R N × 2 d [\mathbb{R}^{N \times d}; \mathbb{R}^{N \times d}] = \mathbb{R}^{N \times 2d} </math>[RN×d;RN×d]=RN×2d,计算量翻倍
理论上,如果 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 足够大,加法空间就足以让模型将"内容"和"位置"信息解耦。
实际上,这是一个线性子空间分解的假设:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Z + P E = Z content + Z position \mathbf{Z} + \mathbf{PE} = \mathbf{Z}{\text{content}} + \mathbf{Z}{\text{position}} </math>Z+PE=Zcontent+Zposition
模型通过学习将混合的信息分离到不同的子空间。
组件四:AdaLN-Zero - 条件注入的核心创新
AdaLN-Zero是DiT最重要的创新,解决了"如何将时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t和类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c注入Transformer"这一核心问题。
扩散模型的条件注入难题
扩散模型需要接收两类信息:
- 内容信息 :噪声图像 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt
- 条件信息 :
- 时间步 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t:当前处于扩散过程的哪个阶段(关键!)
- 类别标签 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c:生成什么类别的图像
传统方法:
- 加法注入 : <math xmlns="http://www.w3.org/1998/Math/MathML"> x + f ( t , c ) \mathbf{x} + f(t, c) </math>x+f(t,c) ------ 太简单,条件易被覆盖
- 拼接注入 : <math xmlns="http://www.w3.org/1998/Math/MathML"> [ x ; f ( t , c ) ] [\mathbf{x}; f(t, c)] </math>[x;f(t,c)] ------ 增加序列长度,计算量增大
- Cross-Attention :将条件作为Key/Value ------ 复杂度高 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N × M ) O(N \times M) </math>O(N×M)
DiT提出了AdaLN(Adaptive Layer Normalization),一种高效且表达力强的方案。
Adaptive Layer Normalization的数学原理
标准的Layer Normalization:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> LN ( x ) = γ ⊙ x − μ σ + β \text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta </math>LN(x)=γ⊙σx−μ+β
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 是固定的可学习参数。
AdaLN的核心思想 :让 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 依赖于条件信息:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ ( c ) , β ( c ) = MLP ( c ) \gamma(\mathbf{c}), \beta(\mathbf{c}) = \text{MLP}(\mathbf{c}) </math>γ(c),β(c)=MLP(c)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> AdaLN ( x , c ) = γ ( c ) ⊙ x − μ σ + β ( c ) \text{AdaLN}(\mathbf{x}, \mathbf{c}) = \gamma(\mathbf{c}) \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta(\mathbf{c}) </math>AdaLN(x,c)=γ(c)⊙σx−μ+β(c)
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> c = f ( t , c ) \mathbf{c} = f(t, c) </math>c=f(t,c) 是时间步和类别的嵌入向量。
直观理解:调制(Modulation)
AdaLN本质上是用条件信息调制特征的分布。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> γ ( c ) \gamma(\mathbf{c}) </math>γ(c):控制特征的尺度(scale)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> β ( c ) \beta(\mathbf{c}) </math>β(c):控制特征的偏移(shift)
不同的条件 <math xmlns="http://www.w3.org/1998/Math/MathML"> c \mathbf{c} </math>c 产生不同的 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β,从而引导网络产生不同的输出。
类比 :想象一个收音机,条件信息是调频旋钮, <math xmlns="http://www.w3.org/1998/Math/MathML"> γ , β \gamma, \beta </math>γ,β 是调制信号,特征 <math xmlns="http://www.w3.org/1998/Math/MathML"> x \mathbf{x} </math>x 是被调制的载波。
AdaLN-Zero:Zero Initialization的关键改进
DiT在AdaLN基础上加入了Zero Initialization,这是训练稳定性的核心。
标准的DiT Block结构:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 = x + α 1 ( c ) ⊙ Attention ( AdaLN ( x , c ) ) \mathbf{h}_1 = \mathbf{x} + \alpha_1(\mathbf{c}) \odot \text{Attention}(\text{AdaLN}(\mathbf{x}, \mathbf{c})) </math>h1=x+α1(c)⊙Attention(AdaLN(x,c))
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 2 = h 1 + α 2 ( c ) ⊙ MLP ( AdaLN ( h 1 , c ) ) \mathbf{h}_2 = \mathbf{h}_1 + \alpha_2(\mathbf{c}) \odot \text{MLP}(\text{AdaLN}(\mathbf{h}_1, \mathbf{c})) </math>h2=h1+α2(c)⊙MLP(AdaLN(h1,c))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> α 1 , α 2 \alpha_1, \alpha_2 </math>α1,α2 是门控参数,也由条件生成:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> [ γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 ] = MLP modulation ( c ) [\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2] = \text{MLP}_{\text{modulation}}(\mathbf{c}) </math>[γ1,β1,α1,γ2,β2,α2]=MLPmodulation(c)
Zero Initialization的定义:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> MLP modulation = W 2 ⋅ SiLU ( W 1 c + b 1 ) + b 2 \text{MLP}_{\text{modulation}} = W_2 \cdot \text{SiLU}(W_1 \mathbf{c} + b_1) + b_2 </math>MLPmodulation=W2⋅SiLU(W1c+b1)+b2
初始化时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> W 2 = 0 , b 2 = 0 W_2 = \mathbf{0}, \quad b_2 = \mathbf{0} </math>W2=0,b2=0
这保证了训练初始时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> γ 1 = γ 2 = 1 , β 1 = β 2 = 0 , α 1 = α 2 = 0 \gamma_1 = \gamma_2 = 1, \quad \beta_1 = \beta_2 = 0, \quad \alpha_1 = \alpha_2 = 0 </math>γ1=γ2=1,β1=β2=0,α1=α2=0
因此:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 = x + 0 ⋅ Attention ( ⋯ ) = x \mathbf{h}_1 = \mathbf{x} + 0 \cdot \text{Attention}(\cdots) = \mathbf{x} </math>h1=x+0⋅Attention(⋯)=x
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 2 = x + 0 ⋅ MLP ( ⋯ ) = x \mathbf{h}_2 = \mathbf{x} + 0 \cdot \text{MLP}(\cdots) = \mathbf{x} </math>h2=x+0⋅MLP(⋯)=x
整个网络初始时是恒等映射 : <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( x ) = x f(\mathbf{x}) = \mathbf{x} </math>f(x)=x。
为什么Zero Initialization如此重要?
1. 梯度流动的畅通性
深度网络训练的核心挑战是梯度消失/爆炸。
在恒等映射下,梯度可以无损地反向传播:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ x = ∂ L ∂ h 2 ⋅ ∂ h 2 ∂ x = ∂ L ∂ h 2 ⋅ I \frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot \frac{\partial \mathbf{h}_2}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot I </math>∂x∂L=∂h2∂L⋅∂x∂h2=∂h2∂L⋅I
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> I I </math>I 是单位矩阵,梯度直接传递,不会衰减。
2. 从简单到复杂的学习路径
随着训练进行,门控参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> α 1 , α 2 \alpha_1, \alpha_2 </math>α1,α2 从0逐渐增大,模型逐步学习利用注意力和MLP的输出。
这是一种curriculum learning(课程学习)策略:先学简单的(恒等映射),再学复杂的(注意力模式)。
3. 残差连接的极致体现
残差连接(ResNet)的核心公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h = x + F ( x ) \mathbf{h} = \mathbf{x} + F(\mathbf{x}) </math>h=x+F(x)
当 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ( x ) = 0 F(\mathbf{x}) = 0 </math>F(x)=0 时,网络退化为恒等映射,保证了至少不会比浅层网络差。
AdaLN-Zero通过zero initialization,强制初始时 <math xmlns="http://www.w3.org/1998/Math/MathML"> F ( x ) = 0 F(\mathbf{x}) = 0 </math>F(x)=0,这是残差思想的最彻底实践。
AdaLN vs 其他条件注入方式
| 方法 | 计算复杂度 | 表达能力 | 训练稳定性 |
|---|---|---|---|
| 加法注入 | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1) | 弱 | 中 |
| 拼接注入 | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N ) O(N) </math>O(N) | 中 | 中 |
| Cross-Attention | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N ⋅ M ) O(N \cdot M) </math>O(N⋅M) | 强 | 中 |
| AdaLN-Zero | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) O(1) </math>O(1) | 强 | 优 |
AdaLN-Zero的优势:
- 零额外计算:不增加序列长度,不增加注意力计算
- 强表达力:通过调制归一化参数,影响每一层的特征分布
- 训练稳定:zero initialization保证梯度流畅通
第三部分:DiT的推理过程
推理就是从纯噪声逐步去噪,生成清晰图像。
DDPM采样:严格的概率过程
DDPM(Denoising Diffusion Probabilistic Models)是最原始的采样算法,严格遵循扩散模型的概率推导。
单步去噪公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ε θ ( x t , t , c ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}} \varepsilon\theta(x_t, t, c) \right) + \sigma_t z </math>xt−1=αt 1(xt−1−αˉt 1−αtεθ(xt,t,c))+σtz
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c) 是DiT预测的噪声
- <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) </math>z∼N(0,I) 是新采样的随机噪声
- <math xmlns="http://www.w3.org/1998/Math/MathML"> σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} </math>σt=β~t 是后验方差
完整流程:
- 初始化 <math xmlns="http://www.w3.org/1998/Math/MathML"> x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) </math>xT∼N(0,I)(纯高斯噪声)
- 对于 <math xmlns="http://www.w3.org/1998/Math/MathML"> t = T , T − 1 , ... , 1 t = T, T-1, \ldots, 1 </math>t=T,T−1,...,1:
- 前向传播DiT: <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ = ε θ ( x t , t , c ) \hat{\varepsilon} = \varepsilon_\theta(x_t, t, c) </math>ε^=εθ(xt,t,c)
- 计算均值: <math xmlns="http://www.w3.org/1998/Math/MathML"> μ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ε ^ ) \mu_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\varepsilon}\right) </math>μt=αt 1(xt−1−αˉt 1−αtε^)
- 采样噪声: <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) </math>z∼N(0,I)
- 更新: <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 = μ t + σ t z x_{t-1} = \mu_t + \sigma_t z </math>xt−1=μt+σtz
- 返回 <math xmlns="http://www.w3.org/1998/Math/MathML"> x 0 x_0 </math>x0
特点:
- 慢:需要1000步,每步都要前向传播DiT
- 质量高:每步添加适量随机性,生成多样性好
- 理论清晰 :严格遵循后验分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) </math>q(xt−1∣xt,x0)
DDIM采样:确定性加速
DDIM(Denoising Diffusion Implicit Models)通过确定性过程实现加速。
核心思想:不采样新噪声,而是走确定性的"直线路径"。
DDIM公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ε ^ α ˉ t ⏟ predicted x 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}t}}}{\text{predicted } x_0} + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} </math>xt−1=αˉt−1 predicted x0 αˉt xt−1−αˉt ε^+1−αˉt−1 ε^
这个公式的直观理解:
- 用当前 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t x_t </math>xt 和预测噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,估计干净图像:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x ^ 0 = x t − 1 − α ˉ t ε ^ α ˉ t \hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}_t}} </math>x^0=αˉt xt−1−αˉt ε^
- 用 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ 0 \hat{x}0 </math>x^0 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,重新组合成 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t − 1 x{t-1} </math>xt−1:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> x t − 1 = α ˉ t − 1 x ^ 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} </math>xt−1=αˉt−1 x^0+1−αˉt−1 ε^
关键区别:
- DDPM:每步采样新噪声 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z,引入随机性
- DDIM:重复使用同一个噪声估计 <math xmlns="http://www.w3.org/1998/Math/MathML"> ε ^ \hat{\varepsilon} </math>ε^,是确定性过程
加速原理 :由于是确定性的,可以跳步采样。例如:
- DDPM: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1000 → 999 → 998 → ⋯ → 1 → 0 1000 \to 999 \to 998 \to \cdots \to 1 \to 0 </math>1000→999→998→⋯→1→0(1000步)
- DDIM: <math xmlns="http://www.w3.org/1998/Math/MathML"> 1000 → 950 → 900 → ⋯ → 50 → 0 1000 \to 950 \to 900 \to \cdots \to 50 \to 0 </math>1000→950→900→⋯→50→0(20步)
特点:
- 快:50步达到DDPM 1000步的效果,速度提升20倍
- 确定性:相同初始噪声和条件,生成完全相同的图像
- 质量略降:FID略高于DDPM,但肉眼难以区分
Classifier-Free Guidance:提升条件遵循度
CFG(Classifier-Free Guidance)是提升生成质量的关键技术。
问题:标准条件生成可能"不够听话"。指定生成"猫",模型可能生成模糊的猫,或混合其他动物特征。
解决方案:训练时同时学习条件生成和无条件生成,推理时"放大"条件影响。
CFG训练
训练时,以概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> p = 0.1 p=0.1 </math>p=0.1 将类别标签置空:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c ′ = { ∅ 概率 0.1 c 概率 0.9 c' = \begin{cases} \emptyset & \text{概率 } 0.1 \\ c & \text{概率 } 0.9 \end{cases} </math>c′={∅c概率 0.1概率 0.9
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∅ \emptyset </math>∅ 用特殊token表示(如类别ID=1000)。
这样模型学会了两种模式:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) </math>εθ(xt,t,c):给定类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c的条件生成
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, \emptyset) </math>εθ(xt,t,∅):无条件生成
CFG推理
推理时,将两者线性组合:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ε ~ = ε θ ( x t , t , ∅ ) + w ⋅ ( ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) ) \tilde{\varepsilon} = \varepsilon_\theta(x_t, t, \emptyset) + w \cdot \left(\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset)\right) </math>ε~=εθ(xt,t,∅)+w⋅(εθ(xt,t,c)−εθ(xt,t,∅))
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w 是guidance scale(通常 <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 7.5 w=7.5 </math>w=7.5)。
数学直观:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset) </math>εθ(xt,t,c)−εθ(xt,t,∅):条件相对于无条件的"差异方向"
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w > 1 w > 1 </math>w>1:沿着这个方向走得更远,放大条件影响
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 1 w = 1 </math>w=1:标准条件生成
- <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 0 w = 0 </math>w=0:无条件生成
效果:
| Guidance scale <math xmlns="http://www.w3.org/1998/Math/MathML"> w w </math>w | 类别一致性 | 图像多样性 | 图像质量 |
|---|---|---|---|
| 1.0 | 低 | 高 | 一般 |
| 3.0-5.0 | 中 | 中 | 好 |
| 7.5 (推荐) | 高 | 中低 | 最好 |
| 15.0+ | 过高 | 低 | 过饱和、失真 |
代价:CFG需要推理两次(条件+无条件),推理时间翻倍。但效果提升显著,是工业标准。
完整推理流程
结合DDIM和CFG:
输入:
- 类别 <math xmlns="http://www.w3.org/1998/Math/MathML"> c c </math>c(如"猫"的ID=281)
- 采样步数 <math xmlns="http://www.w3.org/1998/Math/MathML"> S = 50 S=50 </math>S=50
- Guidance scale <math xmlns="http://www.w3.org/1998/Math/MathML"> w = 7.5 w=7.5 </math>w=7.5
算法:
markdown
1. 确定时间步序列:τ = [1000, 950, 900, ..., 50, 0](均匀采样S步)
2. 初始化:x ← N(0, I)
3. For t in τ[:-1]:
t_next ← τ中t的下一个时间步
# 条件预测
ε_cond ← DiT(x, t, c)
# 无条件预测
ε_uncond ← DiT(x, t, ∅)
# CFG组合
ε̂ ← ε_uncond + w * (ε_cond - ε_uncond)
# 估计x₀
x̂₀ ← (x - √(1-ᾱₜ)·ε̂) / √ᾱₜ
# DDIM更新
x ← √ᾱₜ_ₙₑₓₜ · x̂₀ + √(1-ᾱₜ_ₙₑₓₜ) · ε̂
4. Return x
时间成本(DiT-XL,A100 GPU):
- DDPM 1000步:约60秒/图
- DDIM 50步 + CFG:约6秒/图
第四部分:DiT的训练过程
训练目标
DiT的训练是简单的噪声预测任务:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = E t , x 0 , ε , c [ ∥ ε − ε θ ( x t , t , c ) ∥ 2 ] \mathcal{L} = \mathbb{E}{t, x_0, \varepsilon, c}\left[\|\varepsilon - \varepsilon\theta(x_t, t, c)\|^2\right] </math>L=Et,x0,ε,c[∥ε−εθ(xt,t,c)∥2]
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon </math>xt=αˉt x0+1−αˉt ε。
训练算法
单个训练step:
scss
1. 采样一批数据:(x₀, c) ~ 数据集
2. 采样时间步:t ~ Uniform(1, T)
3. 采样噪声:ε ~ N(0, I)
4. 前向加噪:xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε
5. 预测噪声:ε̂ = DiT(xₜ, t, c)
6. 计算损失:L = ‖ε̂ - ε‖²
7. 反向传播:更新参数θ
关键训练细节
1. 噪声调度(Noise Schedule)
<math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt 的设计影响训练效果。DiT使用线性调度:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> β t = β min + t − 1 T − 1 ( β max − β min ) \beta_t = \beta_{\min} + \frac{t-1}{T-1}(\beta_{\max} - \beta_{\min}) </math>βt=βmin+T−1t−1(βmax−βmin)
典型值: <math xmlns="http://www.w3.org/1998/Math/MathML"> β min = 0.0001 , β max = 0.02 , T = 1000 \beta_{\min} = 0.0001, \beta_{\max} = 0.02, T=1000 </math>βmin=0.0001,βmax=0.02,T=1000。
这意味着:
- 早期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t小): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt很小,加噪缓慢,图像几乎不变
- 后期( <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t大): <math xmlns="http://www.w3.org/1998/Math/MathML"> β t \beta_t </math>βt接近0.02,加噪快速,图像迅速变成纯噪声
2. Classifier-Free Guidance训练
如前所述,训练时10%概率drop类别:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c ′ = { null token p = 0.1 c p = 0.9 c' = \begin{cases} \text{null token} & p=0.1 \\ c & p=0.9 \end{cases} </math>c′={null tokencp=0.1p=0.9
这让模型同时学会两种生成模式。
3. 学习率调度
DiT使用warmup + cosine decay:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> η t = η min + 1 2 ( η max − η min ) ( 1 + cos ( π t − t w T − t w ) ) \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \frac{t - t_w}{T - t_w}\right)\right) </math>ηt=ηmin+21(ηmax−ηmin)(1+cos(πT−twt−tw))
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> t w = 10000 t_w = 10000 </math>tw=10000:warmup步数
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η max = 1 0 − 4 \eta_{\max} = 10^{-4} </math>ηmax=10−4:峰值学习率
- <math xmlns="http://www.w3.org/1998/Math/MathML"> η min = 1 0 − 5 \eta_{\min} = 10^{-5} </math>ηmin=10−5:最小学习率
前10000步线性增长到 <math xmlns="http://www.w3.org/1998/Math/MathML"> η max \eta_{\max} </math>ηmax,之后按余弦函数衰减。
原理:warmup避免初期梯度过大导致发散;cosine decay比step decay更平滑。
4. EMA(Exponential Moving Average)
维护参数的指数移动平均:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> θ EMA ← μ θ EMA + ( 1 − μ ) θ \theta_{\text{EMA}} \leftarrow \mu \theta_{\text{EMA}} + (1-\mu)\theta </math>θEMA←μθEMA+(1−μ)θ
其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> μ = 0.9999 \mu = 0.9999 </math>μ=0.9999。
推理时使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ EMA \theta_{\text{EMA}} </math>θEMA 而非 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ \theta </math>θ。
原理:EMA相当于对训练轨迹上的多个checkpoint做平滑,减少单个模型的抖动,提升生成质量和稳定性。
5. 混合精度训练
使用FP16计算,同时维护FP32主权重:
- 前向传播、梯度计算:FP16
- 参数更新:FP32
收益:
- 训练速度提升1.5-2倍
- 显存占用减半
- 精度损失可忽略
训练规模与成本
DiT-XL的训练配置:
| 项目 | 数值 |
|---|---|
| 参数量 | 675M |
| 数据集 | ImageNet(130万图像,1000类) |
| Batch size | 256(8卡 × 32/卡) |
| 训练步数 | 7M steps |
| 训练时长 | 约1个月(8×A100 80GB) |
| 总计算量 | 约10M GPU-hours |
| FID(256×256) | 2.27 |
Scaling Law:DiT的惊人发现
DiT首次在图像生成模型中展现出清晰的scaling law:
| 模型 | 参数量 | 深度 | 宽度 | FID ↓ |
|---|---|---|---|---|
| DiT-S | 33M | 12层 | 384 | 9.62 |
| DiT-B | 130M | 12层 | 768 | 5.31 |
| DiT-L | 458M | 24层 | 1024 | 3.04 |
| DiT-XL | 675M | 28层 | 1152 | 2.27 |
关键观察:
- 性能持续提升:从DiT-S到DiT-XL,FID持续下降,没有饱和迹象
- 对数线性关系:FID与log(参数量)近似线性关系
- 类似LLM:这与语言模型的scaling law特性一致
意义:
- 更大的模型 → 更好的生成质量(确定性规律)
- 为投资更大模型提供了理论依据
- 预示着10B+参数的扩散模型可能带来质的飞跃
总结:DiT的意义与启示
核心贡献
1. 架构统一
证明了Transformer可以作为扩散模型的通用backbone,图像生成不再需要特定领域的架构设计。
2. AdaLN-Zero
提出了优雅的条件注入机制,在零额外计算成本下实现强大的表达能力和训练稳定性。
3. Scaling Law
首次在图像生成中展现scaling特性,为"训练更大模型"提供了理论支持。
4. 性能突破
FID 2.27(256×256 ImageNet),超越所有基于卷积的方法。
DiT的局限
1. 计算复杂度 :自注意力是 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2),分辨率越高越慢
2. 推理时间:即使DDIM,仍需50步,比单次前向慢50倍
3. 数据需求:需要大规模数据(百万级)才能充分发挥scaling优势
4. 条件类型:目前主要支持类别标签,对长文本支持有限
未来方向
1. 更高效的注意力:Sparse Attention、Linear Attention、Flash Attention
2. 更快的采样:Consistency Models(一步生成)、Latent Diffusion(低维空间扩散)
3. 更大的模型:DiT-XXL(10B参数)在更大数据集上训练
4. 多模态扩展:文本到图像、视频生成、3D生成
关键启示
- Transformer的通用性:不仅NLP,CV也适用
- Scaling的威力:更大的模型带来更好的效果
- 架构细节的重要性:AdaLN-Zero这样的创新带来质的提升
- 条件注入的本质:如何注入比注入什么更重要
DiT代表了扩散模型从CNN到Transformer的范式转变,这与NLP从RNN到Transformer的转变如出一辙。
Transformer + Diffusion = 图像生成的未来。
参考文献
- Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
- Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
- Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
- Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.