深度拆解 DiT:扩散模型与 Transformer 的巅峰结合

21-DiT详解:扩散模型遇上Transformer的图像生成革命

引言

DiT(Diffusion Transformer)是Meta AI在2023年提出的突破性工作,它用纯Transformer架构实现扩散模型,在ImageNet 256×256生成任务上达到了FID 2.27的业界最佳水平,并首次在图像生成模型中展现出清晰的scaling law特性。

本文目标:深入理解DiT的四个核心组件(Patchify、向量化、位置编码、AdaLN-Zero)、推理机制和训练过程。

适合人群:了解Transformer基础和扩散模型原理的读者。


第一部分:扩散模型的数学基础

前向过程:从图像到噪声

扩散模型的前向过程是一个马尔可夫链,逐步向图像添加高斯噪声:
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) q(xt∣xt−1)=N(xt;1−βt xt−1,βtI)

关键性质是可以直接从 x 0 x_0 x0 跳到任意 x t x_t xt:
x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \varepsilon xt=αˉt x0+1−αˉt ε

其中 α ˉ t = ∏ i = 1 t ( 1 − β i ) \bar{\alpha}t = \prod{i=1}^{t}(1-\beta_i) αˉt=∏i=1t(1−βi), ε ∼ N ( 0 , I ) \varepsilon \sim \mathcal{N}(0, I) ε∼N(0,I)。

反向过程:学习去噪

训练目标是学习一个神经网络 ε θ ( x t , t ) \varepsilon_\theta(x_t, t) εθ(xt,t) 预测噪声:
L = E t , x 0 , ε ∥ ε − ε θ ( x t , t ) ∥ 2 \mathcal{L} = \mathbb{E}_{t, x_0, \varepsilon}\left\\\|\\varepsilon - \\varepsilon_\\theta(x_t, t)\\\|\^2\\right L=Et,x0,ε∥ε−εθ(xt,t)∥2

DiT就是 ε θ \varepsilon_\theta εθ 的具体实现。


第二部分:DiT的四大核心组件

DiT的核心思想是将图像视为token序列,用Transformer处理。整个架构包含四个关键设计:


组件一:Patchify(切块) - 从2D到1D的转换

Patchify的本质

Patchify是将2D图像转换为1D token序列的过程。这是将Transformer应用于图像的前提。

给定图像 x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} x∈RH×W×C,选择patch大小 p p p(通常16或8),将图像切分成 N = H W p 2 N = \frac{HW}{p^2} N=p2HW 个不重叠的patch。

每个patch是一个 p × p × C p \times p \times C p×p×C 的立方体,flatten后得到 p 2 C p^2C p2C 维向量。所有patch排列成序列:
Patches ∈ R N × ( p 2 C ) \text{Patches} \in \mathbb{R}^{N \times (p^2C)} Patches∈RN×(p2C)

为什么Patchify是合理的?

局部性原理 :自然图像具有强局部相关性。一个 16 × 16 16 \times 16 16×16 的patch(256像素)通常包含一个完整的局部语义单元。

计算效率的权衡

  • 逐像素处理: 256 × 256 256\times256 256×256 图像有65536个token,自注意力复杂度 O ( N 2 ) = O ( 4.3 × 1 0 9 ) O(N^2) = O(4.3 \times 10^9) O(N2)=O(4.3×109)
  • patch大小 p = 16 p=16 p=16:只有256个token,复杂度 O ( 6.5 × 1 0 4 ) O(6.5 \times 10^4) O(6.5×104),降低了6.6万倍

信息无损:切块是可逆操作,不丢失任何像素信息。

Patch排列的顺序

DiT采用光栅扫描顺序(raster-scan order):从左到右、从上到下依次排列。

虽然Transformer的自注意力是位置不变的(打乱patch顺序输出也会相应打乱),但通过位置编码可以让模型理解patch的空间位置关系。

Patchify的深层意义

Patchify不仅是技术手段,更是认知范式的转变

  • CNN的视角:图像是2D网格,通过卷积核滑动提取局部特征
  • Transformer的视角:图像是patch的集合,每个patch通过全局注意力与其他patch交互

这种转变使得模型可以直接建模长距离依赖,而不受卷积感受野的限制。


组件二:Linear Projection(向量化) - 从像素到语义

Embedding的数学定义

Linear Projection将每个patch从原始像素空间映射到高维语义空间。
z i = E ⋅ vec ( patch i ) + b \mathbf{z}_i = \mathbf{E} \cdot \text{vec}(\text{patch}_i) + \mathbf{b} zi=E⋅vec(patchi)+b

其中:

  • E ∈ R d × ( p 2 C ) \mathbf{E} \in \mathbb{R}^{d \times (p^2C)} E∈Rd×(p2C) 是投影矩阵(可学习)
  • d d d 是Transformer的隐藏维度(如768、1024)
  • vec \text{vec} vec 表示将patch展平成向量

所有patch embedding组成序列:
Z = z 1 , z 2 , ... , z N ∈ R N × d \mathbf{Z} = \\mathbf{z}_1, \\mathbf{z}_2, \\ldots, \\mathbf{z}_N \in \mathbb{R}^{N \times d} Z=z1,z2,...,zN∈RN×d

为什么需要Projection?

1. 维度标准化

不同的patch大小导致不同的输入维度:

  • p = 8 p=8 p=8: 8 2 × 3 = 192 8^2 \times 3 = 192 82×3=192 维
  • p = 16 p=16 p=16: 1 6 2 × 3 = 768 16^2 \times 3 = 768 162×3=768 维

投影到统一的 d d d 维,使得模型架构与patch大小解耦,提供了架构的灵活性

2. 语义提升

原始像素值(如RGB=125, 200, 89)是低层次的信号,投影矩阵 E \mathbf{E} E 学习将其映射到高层次语义空间。

类比:Word Embedding将离散的词ID映射到连续的语义向量空间,Patch Embedding做的是类似的事情。

3. 计算效率

实践中,Linear Projection通常用卷积层实现:
Conv2D ( k = p , s = p , in = C , out = d ) \text{Conv2D}(k=p, s=p, \text{in}=C, \text{out}=d) Conv2D(k=p,s=p,in=C,out=d)

这等价于对每个patch做矩阵乘法,但利用了卷积的并行计算优势。

Projection的初始化

投影矩阵的初始化对训练至关重要。DiT使用Xavier初始化
E ∼ U ( − 6 p 2 C + d , 6 p 2 C + d ) \mathbf{E} \sim \mathcal{U}\left(-\sqrt{\frac{6}{p^2C + d}}, \sqrt{\frac{6}{p^2C + d}}\right) E∼U(−p2C+d6 ,p2C+d6 )

这保证了初始时每层的激活值方差相近,避免梯度消失/爆炸。


组件三:Positional Encoding(位置编码) - 告诉模型"哪里"

为什么Transformer必须有位置编码?

Transformer的自注意力机制是置换等变的(permutation equivariant):
Attention ( shuffle ( X ) ) = shuffle ( Attention ( X ) ) \text{Attention}(\text{shuffle}(X)) = \text{shuffle}(\text{Attention}(X)) Attention(shuffle(X))=shuffle(Attention(X))

这意味着如果打乱输入顺序,输出也会相应打乱。Transformer本身无法区分patch的位置

但图像任务中,位置信息极其关键:

  • 天空通常在上方,草地在下方
  • 物体的空间关系("猫在沙发上")依赖于位置理解

因此必须显式注入位置信息。

DiT的2D正弦位置编码

DiT采用固定的2D正弦位置编码(inherited from ViT)。

对于位置 ( i , j ) (i, j) (i,j) 的patch(第 i i i行,第 j j j列),其位置编码是:
PE ( i , j ) = PE x ( i ) , PE y ( j ) \text{PE}(i, j) = \\text{PE}_x(i), \\text{PE}_y(j) PE(i,j)=PEx(i),PEy(j)

其中x和y坐标分别编码为:
PE x ( i , 2 k ) = sin ⁡ ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k) = \sin\left(\frac{i}{10000^{2k/d}}\right) PEx(i,2k)=sin(100002k/di)
PE x ( i , 2 k + 1 ) = cos ⁡ ( i 1000 0 2 k / d ) \text{PE}_x(i, 2k+1) = \cos\left(\frac{i}{10000^{2k/d}}\right) PEx(i,2k+1)=cos(100002k/di)

最终的2D位置编码是x和y编码的拼接:
PE 2 D ( i , j ) ∈ R d \text{PE}_{2D}(i,j) \in \mathbb{R}^d PE2D(i,j)∈Rd

d / 2 d/2 d/2 维编码x坐标,后 d / 2 d/2 d/2 维编码y坐标。

正弦位置编码的数学优势

1. 周期性与连续性

正弦函数是连续平滑的,相邻位置的编码向量相近,这符合图像的空间连续性假设

2. 相对位置的可表达性

通过三角恒等式:
sin ⁡ ( α + β ) = sin ⁡ α cos ⁡ β + cos ⁡ α sin ⁡ β \sin(\alpha + \beta) = \sin\alpha\cos\beta + \cos\alpha\sin\beta sin(α+β)=sinαcosβ+cosαsinβ

模型可以从绝对位置编码中推导出相对位置关系。例如,位置 ( i + 1 , j ) (i+1, j) (i+1,j) 的编码可以通过位置 ( i , j ) (i, j) (i,j) 的编码线性表示。

3. 外推能力

理论上,正弦编码可以泛化到训练时未见过的更大图像尺寸。虽然实践中效果有限,但这是可学习位置编码不具备的特性。

4. 参数效率

位置编码是固定的(不参与训练),节省了 N × d N \times d N×d 个参数。

位置编码的注入:加法 vs 拼接

DiT使用加法注入:
Z with_pos = Z + P E \mathbf{Z}_{\text{with\_pos}} = \mathbf{Z} + \mathbf{PE} Zwith_pos=Z+PE

为什么不用拼接?

  • 加法 R N × d + R N × d = R N × d \mathbb{R}^{N \times d} + \mathbb{R}^{N \times d} = \mathbb{R}^{N \times d} RN×d+RN×d=RN×d,维度不变
  • 拼接 R N × d ; R N × d = R N × 2 d \\mathbb{R}\^{N \\times d}; \\mathbb{R}\^{N \\times d} = \mathbb{R}^{N \times 2d} RN×d;RN×d=RN×2d,计算量翻倍

理论上,如果 d d d 足够大,加法空间就足以让模型将"内容"和"位置"信息解耦。

实际上,这是一个线性子空间分解的假设:
Z + P E = Z content + Z position \mathbf{Z} + \mathbf{PE} = \mathbf{Z}{\text{content}} + \mathbf{Z}{\text{position}} Z+PE=Zcontent+Zposition

模型通过学习将混合的信息分离到不同的子空间。


组件四:AdaLN-Zero - 条件注入的核心创新

AdaLN-Zero是DiT最重要的创新,解决了"如何将时间步 t t t和类别 c c c注入Transformer"这一核心问题。

扩散模型的条件注入难题

扩散模型需要接收两类信息:

  1. 内容信息 :噪声图像 x t x_t xt
  2. 条件信息
    • 时间步 t t t:当前处于扩散过程的哪个阶段(关键!)
    • 类别标签 c c c:生成什么类别的图像

传统方法:

  • 加法注入 x + f ( t , c ) \mathbf{x} + f(t, c) x+f(t,c) ------ 太简单,条件易被覆盖
  • 拼接注入 x ; f ( t , c ) \\mathbf{x}; f(t, c) x;f(t,c) ------ 增加序列长度,计算量增大
  • Cross-Attention :将条件作为Key/Value ------ 复杂度高 O ( N × M ) O(N \times M) O(N×M)

DiT提出了AdaLN(Adaptive Layer Normalization),一种高效且表达力强的方案。

Adaptive Layer Normalization的数学原理

标准的Layer Normalization:
LN ( x ) = γ ⊙ x − μ σ + β \text{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta LN(x)=γ⊙σx−μ+β

其中 γ , β \gamma, \beta γ,β 是固定的可学习参数。

AdaLN的核心思想 :让 γ , β \gamma, \beta γ,β 依赖于条件信息
γ ( c ) , β ( c ) = MLP ( c ) \gamma(\mathbf{c}), \beta(\mathbf{c}) = \text{MLP}(\mathbf{c}) γ(c),β(c)=MLP(c)
AdaLN ( x , c ) = γ ( c ) ⊙ x − μ σ + β ( c ) \text{AdaLN}(\mathbf{x}, \mathbf{c}) = \gamma(\mathbf{c}) \odot \frac{\mathbf{x} - \mu}{\sigma} + \beta(\mathbf{c}) AdaLN(x,c)=γ(c)⊙σx−μ+β(c)

其中 c = f ( t , c ) \mathbf{c} = f(t, c) c=f(t,c) 是时间步和类别的嵌入向量。

直观理解:调制(Modulation)

AdaLN本质上是用条件信息调制特征的分布

  • γ ( c ) \gamma(\mathbf{c}) γ(c):控制特征的尺度(scale)
  • β ( c ) \beta(\mathbf{c}) β(c):控制特征的偏移(shift)

不同的条件 c \mathbf{c} c 产生不同的 γ , β \gamma, \beta γ,β,从而引导网络产生不同的输出。

类比 :想象一个收音机,条件信息是调频旋钮, γ , β \gamma, \beta γ,β 是调制信号,特征 x \mathbf{x} x 是被调制的载波。

AdaLN-Zero:Zero Initialization的关键改进

DiT在AdaLN基础上加入了Zero Initialization,这是训练稳定性的核心。

标准的DiT Block结构:
h 1 = x + α 1 ( c ) ⊙ Attention ( AdaLN ( x , c ) ) \mathbf{h}_1 = \mathbf{x} + \alpha_1(\mathbf{c}) \odot \text{Attention}(\text{AdaLN}(\mathbf{x}, \mathbf{c})) h1=x+α1(c)⊙Attention(AdaLN(x,c))
h 2 = h 1 + α 2 ( c ) ⊙ MLP ( AdaLN ( h 1 , c ) ) \mathbf{h}_2 = \mathbf{h}_1 + \alpha_2(\mathbf{c}) \odot \text{MLP}(\text{AdaLN}(\mathbf{h}_1, \mathbf{c})) h2=h1+α2(c)⊙MLP(AdaLN(h1,c))

其中 α 1 , α 2 \alpha_1, \alpha_2 α1,α2 是门控参数,也由条件生成:
γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 = MLP modulation ( c ) \\gamma_1, \\beta_1, \\alpha_1, \\gamma_2, \\beta_2, \\alpha_2 = \text{MLP}_{\text{modulation}}(\mathbf{c}) γ1,β1,α1,γ2,β2,α2=MLPmodulation(c)

Zero Initialization的定义
MLP modulation = W 2 ⋅ SiLU ( W 1 c + b 1 ) + b 2 \text{MLP}_{\text{modulation}} = W_2 \cdot \text{SiLU}(W_1 \mathbf{c} + b_1) + b_2 MLPmodulation=W2⋅SiLU(W1c+b1)+b2

初始化时:
W 2 = 0 , b 2 = 0 W_2 = \mathbf{0}, \quad b_2 = \mathbf{0} W2=0,b2=0

这保证了训练初始时:
γ 1 = γ 2 = 1 , β 1 = β 2 = 0 , α 1 = α 2 = 0 \gamma_1 = \gamma_2 = 1, \quad \beta_1 = \beta_2 = 0, \quad \alpha_1 = \alpha_2 = 0 γ1=γ2=1,β1=β2=0,α1=α2=0

因此:
h 1 = x + 0 ⋅ Attention ( ⋯   ) = x \mathbf{h}_1 = \mathbf{x} + 0 \cdot \text{Attention}(\cdots) = \mathbf{x} h1=x+0⋅Attention(⋯)=x
h 2 = x + 0 ⋅ MLP ( ⋯   ) = x \mathbf{h}_2 = \mathbf{x} + 0 \cdot \text{MLP}(\cdots) = \mathbf{x} h2=x+0⋅MLP(⋯)=x

整个网络初始时是恒等映射 f ( x ) = x f(\mathbf{x}) = \mathbf{x} f(x)=x。

为什么Zero Initialization如此重要?

1. 梯度流动的畅通性

深度网络训练的核心挑战是梯度消失/爆炸

在恒等映射下,梯度可以无损地反向传播:
∂ L ∂ x = ∂ L ∂ h 2 ⋅ ∂ h 2 ∂ x = ∂ L ∂ h 2 ⋅ I \frac{\partial \mathcal{L}}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot \frac{\partial \mathbf{h}_2}{\partial \mathbf{x}} = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_2} \cdot I ∂x∂L=∂h2∂L⋅∂x∂h2=∂h2∂L⋅I

其中 I I I 是单位矩阵,梯度直接传递,不会衰减。

2. 从简单到复杂的学习路径

随着训练进行,门控参数 α 1 , α 2 \alpha_1, \alpha_2 α1,α2 从0逐渐增大,模型逐步学习利用注意力和MLP的输出。

这是一种curriculum learning(课程学习)策略:先学简单的(恒等映射),再学复杂的(注意力模式)。

3. 残差连接的极致体现

残差连接(ResNet)的核心公式:
h = x + F ( x ) \mathbf{h} = \mathbf{x} + F(\mathbf{x}) h=x+F(x)

F ( x ) = 0 F(\mathbf{x}) = 0 F(x)=0 时,网络退化为恒等映射,保证了至少不会比浅层网络差。

AdaLN-Zero通过zero initialization,强制初始时 F ( x ) = 0 F(\mathbf{x}) = 0 F(x)=0,这是残差思想的最彻底实践。

AdaLN vs 其他条件注入方式

方法 计算复杂度 表达能力 训练稳定性
加法注入 O ( 1 ) O(1) O(1)
拼接注入 O ( N ) O(N) O(N)
Cross-Attention O ( N ⋅ M ) O(N \cdot M) O(N⋅M)
AdaLN-Zero O ( 1 ) O(1) O(1)

AdaLN-Zero的优势:

  • 零额外计算:不增加序列长度,不增加注意力计算
  • 强表达力:通过调制归一化参数,影响每一层的特征分布
  • 训练稳定:zero initialization保证梯度流畅通

第三部分:DiT的推理过程

推理就是从纯噪声逐步去噪,生成清晰图像

DDPM采样:严格的概率过程

DDPM(Denoising Diffusion Probabilistic Models)是最原始的采样算法,严格遵循扩散模型的概率推导。

单步去噪公式
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ε θ ( x t , t , c ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}t}} \varepsilon\theta(x_t, t, c) \right) + \sigma_t z xt−1=αt 1(xt−1−αˉt 1−αtεθ(xt,t,c))+σtz

其中:

  • ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) εθ(xt,t,c) 是DiT预测的噪声
  • z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) z∼N(0,I) 是新采样的随机噪声
  • σ t = β ~ t \sigma_t = \sqrt{\tilde{\beta}_t} σt=β~t 是后验方差

完整流程

  1. 初始化 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I)(纯高斯噪声)
  2. 对于 t = T , T − 1 , ... , 1 t = T, T-1, \ldots, 1 t=T,T−1,...,1:
    • 前向传播DiT: ε ^ = ε θ ( x t , t , c ) \hat{\varepsilon} = \varepsilon_\theta(x_t, t, c) ε^=εθ(xt,t,c)
    • 计算均值: μ t = 1 α t ( x t − 1 − α t 1 − α ˉ t ε ^ ) \mu_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\hat{\varepsilon}\right) μt=αt 1(xt−1−αˉt 1−αtε^)
    • 采样噪声: z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) z∼N(0,I)
    • 更新: x t − 1 = μ t + σ t z x_{t-1} = \mu_t + \sigma_t z xt−1=μt+σtz
  3. 返回 x 0 x_0 x0

特点

  • :需要1000步,每步都要前向传播DiT
  • 质量高:每步添加适量随机性,生成多样性好
  • 理论清晰 :严格遵循后验分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt−1∣xt,x0)

DDIM采样:确定性加速

DDIM(Denoising Diffusion Implicit Models)通过确定性过程实现加速。

核心思想:不采样新噪声,而是走确定性的"直线路径"。

DDIM公式
x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ε ^ α ˉ t ⏟ predicted x 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}t}}}{\text{predicted } x_0} + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} xt−1=αˉt−1 predicted x0 αˉt xt−1−αˉt ε^+1−αˉt−1 ε^

这个公式的直观理解:

  1. 用当前 x t x_t xt 和预测噪声 ε ^ \hat{\varepsilon} ε^,估计干净图像:

x ^ 0 = x t − 1 − α ˉ t ε ^ α ˉ t \hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\hat{\varepsilon}}{\sqrt{\bar{\alpha}_t}} x^0=αˉt xt−1−αˉt ε^

  1. x ^ 0 \hat{x}0 x^0 和 ε ^ \hat{\varepsilon} ε^,重新组合成 x t − 1 x{t-1} xt−1:

x t − 1 = α ˉ t − 1 x ^ 0 + 1 − α ˉ t − 1 ε ^ x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}} \hat{\varepsilon} xt−1=αˉt−1 x^0+1−αˉt−1 ε^

关键区别

  • DDPM:每步采样新噪声 z z z,引入随机性
  • DDIM:重复使用同一个噪声估计 ε ^ \hat{\varepsilon} ε^,是确定性过程

加速原理 :由于是确定性的,可以跳步采样。例如:

  • DDPM: 1000 → 999 → 998 → ⋯ → 1 → 0 1000 \to 999 \to 998 \to \cdots \to 1 \to 0 1000→999→998→⋯→1→0(1000步)
  • DDIM: 1000 → 950 → 900 → ⋯ → 50 → 0 1000 \to 950 \to 900 \to \cdots \to 50 \to 0 1000→950→900→⋯→50→0(20步)

特点

  • :50步达到DDPM 1000步的效果,速度提升20倍
  • 确定性:相同初始噪声和条件,生成完全相同的图像
  • 质量略降:FID略高于DDPM,但肉眼难以区分

Classifier-Free Guidance:提升条件遵循度

CFG(Classifier-Free Guidance)是提升生成质量的关键技术。

问题:标准条件生成可能"不够听话"。指定生成"猫",模型可能生成模糊的猫,或混合其他动物特征。

解决方案:训练时同时学习条件生成和无条件生成,推理时"放大"条件影响。

CFG训练

训练时,以概率 p = 0.1 p=0.1 p=0.1 将类别标签置空:
c ′ = { ∅ 概率 0.1 c 概率 0.9 c' = \begin{cases} \emptyset & \text{概率 } 0.1 \\ c & \text{概率 } 0.9 \end{cases} c′={∅c概率 0.1概率 0.9

其中 ∅ \emptyset ∅ 用特殊token表示(如类别ID=1000)。

这样模型学会了两种模式:

  • ε θ ( x t , t , c ) \varepsilon_\theta(x_t, t, c) εθ(xt,t,c):给定类别 c c c的条件生成
  • ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, \emptyset) εθ(xt,t,∅):无条件生成
CFG推理

推理时,将两者线性组合:
ε ~ = ε θ ( x t , t , ∅ ) + w ⋅ ( ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) ) \tilde{\varepsilon} = \varepsilon_\theta(x_t, t, \emptyset) + w \cdot \left(\varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset)\right) ε~=εθ(xt,t,∅)+w⋅(εθ(xt,t,c)−εθ(xt,t,∅))

其中 w w w 是guidance scale(通常 w = 7.5 w=7.5 w=7.5)。

数学直观

  • ε θ ( x t , t , c ) − ε θ ( x t , t , ∅ ) \varepsilon_\theta(x_t, t, c) - \varepsilon_\theta(x_t, t, \emptyset) εθ(xt,t,c)−εθ(xt,t,∅):条件相对于无条件的"差异方向"
  • w > 1 w > 1 w>1:沿着这个方向走得更远,放大条件影响
  • w = 1 w = 1 w=1:标准条件生成
  • w = 0 w = 0 w=0:无条件生成

效果

Guidance scale w w w 类别一致性 图像多样性 图像质量
1.0 一般
3.0-5.0
7.5 (推荐) 中低 最好
15.0+ 过高 过饱和、失真

代价:CFG需要推理两次(条件+无条件),推理时间翻倍。但效果提升显著,是工业标准。

完整推理流程

结合DDIM和CFG:

输入

  • 类别 c c c(如"猫"的ID=281)
  • 采样步数 S = 50 S=50 S=50
  • Guidance scale w = 7.5 w=7.5 w=7.5

算法

markdown 复制代码
1. 确定时间步序列:τ = [1000, 950, 900, ..., 50, 0](均匀采样S步)
2. 初始化:x ← N(0, I)
3. For t in τ[:-1]:
     t_next ← τ中t的下一个时间步

     # 条件预测
     ε_cond ← DiT(x, t, c)

     # 无条件预测
     ε_uncond ← DiT(x, t, ∅)

     # CFG组合
     ε̂ ← ε_uncond + w * (ε_cond - ε_uncond)

     # 估计x₀
     x̂₀ ← (x - √(1-ᾱₜ)·ε̂) / √ᾱₜ

     # DDIM更新
     x ← √ᾱₜ_ₙₑₓₜ · x̂₀ + √(1-ᾱₜ_ₙₑₓₜ) · ε̂

4. Return x

时间成本(DiT-XL,A100 GPU):

  • DDPM 1000步:约60秒/图
  • DDIM 50步 + CFG:约6秒/图

第四部分:DiT的训练过程

训练目标

DiT的训练是简单的噪声预测任务
L = E t , x 0 , ε , c ∥ ε − ε θ ( x t , t , c ) ∥ 2 \mathcal{L} = \mathbb{E}_{t, x_0, \varepsilon, c}\left\\\|\\varepsilon - \\varepsilon_\\theta(x_t, t, c)\\\|\^2\\right L=Et,x0,ε,c∥ε−εθ(xt,t,c)∥2

其中 x t = α ˉ t x 0 + 1 − α ˉ t ε x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon xt=αˉt x0+1−αˉt ε。

训练算法

单个训练step

scss 复制代码
1. 采样一批数据:(x₀, c) ~ 数据集
2. 采样时间步:t ~ Uniform(1, T)
3. 采样噪声:ε ~ N(0, I)
4. 前向加噪:xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε
5. 预测噪声:ε̂ = DiT(xₜ, t, c)
6. 计算损失:L = ‖ε̂ - ε‖²
7. 反向传播:更新参数θ

关键训练细节

1. 噪声调度(Noise Schedule)

β t \beta_t βt 的设计影响训练效果。DiT使用线性调度
β t = β min ⁡ + t − 1 T − 1 ( β max ⁡ − β min ⁡ ) \beta_t = \beta_{\min} + \frac{t-1}{T-1}(\beta_{\max} - \beta_{\min}) βt=βmin+T−1t−1(βmax−βmin)

典型值: β min ⁡ = 0.0001 , β max ⁡ = 0.02 , T = 1000 \beta_{\min} = 0.0001, \beta_{\max} = 0.02, T=1000 βmin=0.0001,βmax=0.02,T=1000。

这意味着:

  • 早期( t t t小): β t \beta_t βt很小,加噪缓慢,图像几乎不变
  • 后期( t t t大): β t \beta_t βt接近0.02,加噪快速,图像迅速变成纯噪声
2. Classifier-Free Guidance训练

如前所述,训练时10%概率drop类别:
c ′ = { null token p = 0.1 c p = 0.9 c' = \begin{cases} \text{null token} & p=0.1 \\ c & p=0.9 \end{cases} c′={null tokencp=0.1p=0.9

这让模型同时学会两种生成模式。

3. 学习率调度

DiT使用warmup + cosine decay
η t = η min ⁡ + 1 2 ( η max ⁡ − η min ⁡ ) ( 1 + cos ⁡ ( π t − t w T − t w ) ) \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \frac{t - t_w}{T - t_w}\right)\right) ηt=ηmin+21(ηmax−ηmin)(1+cos(πT−twt−tw))

其中:

  • t w = 10000 t_w = 10000 tw=10000:warmup步数
  • η max ⁡ = 1 0 − 4 \eta_{\max} = 10^{-4} ηmax=10−4:峰值学习率
  • η min ⁡ = 1 0 − 5 \eta_{\min} = 10^{-5} ηmin=10−5:最小学习率

前10000步线性增长到 η max ⁡ \eta_{\max} ηmax,之后按余弦函数衰减。

原理:warmup避免初期梯度过大导致发散;cosine decay比step decay更平滑。

4. EMA(Exponential Moving Average)

维护参数的指数移动平均:
θ EMA ← μ θ EMA + ( 1 − μ ) θ \theta_{\text{EMA}} \leftarrow \mu \theta_{\text{EMA}} + (1-\mu)\theta θEMA←μθEMA+(1−μ)θ

其中 μ = 0.9999 \mu = 0.9999 μ=0.9999。

推理时使用 θ EMA \theta_{\text{EMA}} θEMA 而非 θ \theta θ。

原理:EMA相当于对训练轨迹上的多个checkpoint做平滑,减少单个模型的抖动,提升生成质量和稳定性。

5. 混合精度训练

使用FP16计算,同时维护FP32主权重:

  • 前向传播、梯度计算:FP16
  • 参数更新:FP32

收益

  • 训练速度提升1.5-2倍
  • 显存占用减半
  • 精度损失可忽略

训练规模与成本

DiT-XL的训练配置:

项目 数值
参数量 675M
数据集 ImageNet(130万图像,1000类)
Batch size 256(8卡 × 32/卡)
训练步数 7M steps
训练时长 约1个月(8×A100 80GB)
总计算量 约10M GPU-hours
FID(256×256) 2.27

Scaling Law:DiT的惊人发现

DiT首次在图像生成模型中展现出清晰的scaling law:

模型 参数量 深度 宽度 FID ↓
DiT-S 33M 12层 384 9.62
DiT-B 130M 12层 768 5.31
DiT-L 458M 24层 1024 3.04
DiT-XL 675M 28层 1152 2.27

关键观察

  1. 性能持续提升:从DiT-S到DiT-XL,FID持续下降,没有饱和迹象
  2. 对数线性关系:FID与log(参数量)近似线性关系
  3. 类似LLM:这与语言模型的scaling law特性一致

意义

  • 更大的模型 → 更好的生成质量(确定性规律)
  • 为投资更大模型提供了理论依据
  • 预示着10B+参数的扩散模型可能带来质的飞跃

总结:DiT的意义与启示

核心贡献

1. 架构统一

证明了Transformer可以作为扩散模型的通用backbone,图像生成不再需要特定领域的架构设计。

2. AdaLN-Zero

提出了优雅的条件注入机制,在零额外计算成本下实现强大的表达能力和训练稳定性。

3. Scaling Law

首次在图像生成中展现scaling特性,为"训练更大模型"提供了理论支持。

4. 性能突破

FID 2.27(256×256 ImageNet),超越所有基于卷积的方法。

DiT的局限

1. 计算复杂度 :自注意力是 O ( N 2 ) O(N^2) O(N2),分辨率越高越慢

2. 推理时间:即使DDIM,仍需50步,比单次前向慢50倍

3. 数据需求:需要大规模数据(百万级)才能充分发挥scaling优势

4. 条件类型:目前主要支持类别标签,对长文本支持有限

未来方向

1. 更高效的注意力:Sparse Attention、Linear Attention、Flash Attention

2. 更快的采样:Consistency Models(一步生成)、Latent Diffusion(低维空间扩散)

3. 更大的模型:DiT-XXL(10B参数)在更大数据集上训练

4. 多模态扩展:文本到图像、视频生成、3D生成

关键启示

  1. Transformer的通用性:不仅NLP,CV也适用
  2. Scaling的威力:更大的模型带来更好的效果
  3. 架构细节的重要性:AdaLN-Zero这样的创新带来质的提升
  4. 条件注入的本质:如何注入比注入什么更重要

DiT代表了扩散模型从CNN到Transformer的范式转变,这与NLP从RNN到Transformer的转变如出一辙。

Transformer + Diffusion = 图像生成的未来。


参考文献

  1. Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
  2. Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.
  3. Song, J., Meng, C., & Ermon, S. (2020). Denoising Diffusion Implicit Models. ICLR 2021.
  4. Ho, J., & Salimans, T. (2022). Classifier-Free Diffusion Guidance. NeurIPS Workshop 2021.
  5. Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
相关推荐
葫芦和十三1 小时前
图解 MongoDB 11|慢查询排查闭环:从 Profile 到 explain 的分层路径
后端·mongodb·agent
葫芦和十三4 小时前
图解 MongoDB 09|explain 再读:从 queryPlanner 到 executionStats
后端·mongodb·agent
葫芦和十三4 小时前
图解 MongoDB 10|覆盖查询:让索引把活干完,根本不用回表
后端·mongodb·agent
大鸡腿同学6 小时前
从 CoT 思维链到 ReAct:智能 Agent 到底是怎么 “思考” 的?
后端
IT_陈寒8 小时前
Vite的静态资源打包让我熬夜到三点,这坑千万别跳
前端·人工智能·后端
SamDeepThinking9 小时前
高并发场景下,CompletableFuture与ForkJoinPool该如何取舍?
java·后端·面试
Asize9 小时前
多模态生图:从 Vite 工程化到前端调用 Qwen Image
javascript·人工智能·后端
java小白小9 小时前
SpringBoot(09):缓存实战——穿透、雪崩、击穿的解决方案
后端
MobotStone9 小时前
AI项目越多,为什么越容易失控
人工智能·aigc
java小白小9 小时前
SpringBoot(08):Redis 集成——5 分钟给你的项目加上缓存
后端