Diffusion Transformer (DiT) 主干网络原理讲解与代码注释
1. 概述
Diffusion Transformer(DiT)是 Diffusion Planner 的解码器核心主干网络 ,负责在扩散过程中逐步去噪,生成自车与周围车辆的未来轨迹。它将传统的 U-Net 扩散主干替换为 Transformer 架构 ,并引入 adaLN-Zero(自适应层归一化)条件注入机制,使网络能够高效地融合时间步、场景上下文与路由信息。
DiT 的整体设计理念来源于图像生成领域的 DiT(Diffusion Transformer),其核心思想是:
- 将轨迹数据视为 token 序列
- 利用 Transformer 的自注意力机制建模多智能体间的交互
- 通过自适应归一化将扩散时间步与条件信息注入网络
2. 核心原理
2.1 扩散过程回顾
扩散模型包含两个过程:
- 前向过程(Forward Process):逐步向数据添加高斯噪声,直到数据变为纯噪声
- 反向过程(Reverse Process):学习从噪声中逐步恢复原始数据
在 Diffusion Planner 中,使用 VP-SDE(Variance Preserving Stochastic Differential Equation) 描述扩散过程:
dx=−β(t)2x dt+β(t) dWt d x = -\frac{\beta(t)}{2} x \, dt + \sqrt{\beta(t)} \, dW_t dx=−2β(t)xdt+β(t) dWt
其中:
- β(t)=(βmax−βmin)⋅t+βmin\beta(t) = (\beta_{max} - \beta_{min}) \cdot t + \beta_{min}β(t)=(βmax−βmin)⋅t+βmin,t∈0,1t \in 0, 1t∈0,1
- βmin=0.1\beta_{min} = 0.1βmin=0.1,βmax=20.0\beta_{max} = 20.0βmax=20.0
- WtW_tWt 是标准维纳过程(布朗运动)
2.1.1 加噪过程详解
DiT 的输入 xxx 是通过 VP-SDE 的边缘分布采样得到的加噪轨迹。具体流程如下:
python
# 1. 准备数据:当前状态 + 未来轨迹
current_states = torch.cat([ego_current[:, None], neighbors_current], dim=1) # [B, P, 4]
gt_future = torch.cat([ego_future[:, None, :, :], neighbors_future[..., :]], dim=1) # [B, P, T, 4]
all_gt = torch.cat([current_states[:, :, None, :], norm(gt_future)], dim=2) # [B, P, 1+T, 4]
# 2. 随机采样时间步 t ∈ [eps, 1-eps]
t = torch.rand(B, device=gt_future.device) * (1 - eps) + eps # [B]
# 3. 生成标准高斯噪声 z
z = torch.randn_like(gt_future, device=gt_future.device) # [B, P, T, 4]
# 4. 使用VP-SDE边缘分布计算加噪轨迹
mean, std = marginal_prob(all_gt[..., 1:, :], t) # 仅对未来轨迹加噪,保持当前状态不变
xT = mean + std * z # 核心加噪公式
xT = torch.cat([all_gt[:, :, :1, :], xT], dim=2) # 拼接当前状态
数学原理
VP-SDE 的边缘分布满足以下公式:
pt(xt∣x0)=N(xt;μ(t)⋅x0,σ(t)2⋅I) p_t(x_t | x_0) = \mathcal{N}(x_t; \mu(t) \cdot x_0, \sigma(t)^2 \cdot I) pt(xt∣x0)=N(xt;μ(t)⋅x0,σ(t)2⋅I)
其中:
- 均值 : μ(t)=exp(−14(βmax−βmin)t2−12βmint)\mu(t) = \exp\left(-\frac{1}{4}(\beta_{\text{max}} - \beta_{\text{min}})t^2 - \frac{1}{2}\beta_{\text{min}}t\right)μ(t)=exp(−41(βmax−βmin)t2−21βmint)
- 标准差 : σ(t)=1−exp(−12(βmax−βmin)t2−βmint)\sigma(t) = \sqrt{1 - \exp\left(-\frac{1}{2}(\beta_{\text{max}} - \beta_{\text{min}})t^2 - \beta_{\text{min}}t\right)}σ(t)=1−exp(−21(βmax−βmin)t2−βmint)
- β(t)=(βmax−βmin)⋅t+βmin\beta(t) = (\beta_{\text{max}} - \beta_{\text{min}}) \cdot t + \beta_{\text{min}}β(t)=(βmax−βmin)⋅t+βmin,线性噪声调度
关键实现细节
| 步骤 | 代码位置 | 作用 |
|---|---|---|
| 时间步采样 | t = torch.rand(B) * (1 - eps) + eps |
避免t=0(无噪声)和t=1(完全噪声)的边界情况 |
| 噪声生成 | z = torch.randn_like(gt_future) |
生成与轨迹同形状的标准高斯噪声 |
| 边缘分布计算 | mean, std = marginal_prob(all_gt[..., 1:, :], t) |
根据时间步计算均值和标准差 |
| 加噪公式 | xT = mean + std * z |
从边缘分布采样 |
| 当前状态保持 | torch.cat([all_gt[:, :, :1, :], xT], dim=2) |
当前状态(t=0)不参与加噪 |
加噪特性
-
时间步控制噪声强度:
- t→0:噪声很小,xT≈x0x_T \approx x_0xT≈x0(接近原始轨迹)
- t→1:噪声很大,xTx_TxT 接近纯随机噪声
-
当前状态约束:
- 代码中仅对未来轨迹
all_gt[..., 1:, :]加噪 - 当前状态
all_gt[:, :, :1, :]直接保留,不添加噪声 - 这保证了扩散过程从已知的当前状态开始
- 代码中仅对未来轨迹
-
VP-SDE 参数:
- βmin=0.1\beta_{\text{min}} = 0.1βmin=0.1,βmax=20.0\beta_{\text{max}} = 20.0βmax=20.0
- 线性噪声调度,噪声强度随时间步线性增长
完整加噪示意图
原始轨迹 x_0: [当前状态] + [未来轨迹T1, T2, ..., Tn]
| |
| v
| + 高斯噪声 z ~ N(0,1)
| |
| v
| x_t = mean + std * z
| |
+---------> [当前状态] + [加噪未来轨迹]
= x_T (DiT输入)
与DiT前向传播的关联
在训练时,xT 作为 sampled_trajectories 输入到 DiT:
python
merged_inputs = {
**inputs,
"sampled_trajectories": xT, # 加噪轨迹
"diffusion_time": t, # 对应时间步
}
_, decoder_output = model(merged_inputs)
DiT 学习在给定加噪轨迹 xTx_TxT 和时间步 ttt 的情况下,预测去噪后的轨迹或分数。
2.2 DiT 架构设计
DiT 的整体架构包含以下核心组件:
输入轨迹 token (x) ──→ Pre-projection ──→ 嵌入层 ──┐
├──→ DiT Blocks × N ──→ Final Layer ──→ 输出
时间步 t ──→ Timestep Embedder ──┐ │
├──→ 条件融合 (y) ──┘
路由信息 ──→ RouteEncoder ────────┘
↑
场景上下文 ──→ Cross-Attention ──────┘
2.2.1 时间步嵌入(Timestep Embedding)
扩散模型需要在每个时间步 ttt 进行去噪,因此需要将标量时间步编码为向量表示。DiT 采用 正弦位置编码 结合 MLP:
freqi=exp(−log(max_period)⋅idim/2),i=0,1,...,dim/2−1 \text{freq}_i = \exp\left(-\log(max\_period) \cdot \frac{i}{dim/2}\right), \quad i = 0, 1, ..., dim/2 - 1 freqi=exp(−log(max_period)⋅dim/2i),i=0,1,...,dim/2−1
emb=cos(t⋅freq),sin(t⋅freq) \text{emb} = \\cos(t \\cdot \\text{freq}), \\sin(t \\cdot \\text{freq}) emb=cos(t⋅freq),sin(t⋅freq)
这种编码方式能够捕捉时间步的周期性和连续性特征。
2.2.2 自适应层归一化(adaLN-Zero)
adaLN-Zero 是 DiT 的核心创新,它将条件信息(时间步 + 路由)通过一个小型 MLP 生成六个缩放/偏移/门控参数:
shiftmsa,scalemsa,gatemsa,shiftmlp,scalemlp,gatemlp=adaLN_modulation(y).chunk(6) \text{shift}{msa}, \text{scale}{msa}, \text{gate}{msa}, \text{shift}{mlp}, \text{scale}{mlp}, \text{gate}{mlp} = \text{adaLN\_modulation}(y).\text{chunk}(6) shiftmsa,scalemsa,gatemsa,shiftmlp,scalemlp,gatemlp=adaLN_modulation(y).chunk(6)
在每个 DiT Block 中:
- 自注意力前 :对输入进行调制 x=x⋅(1+scale)+shiftx = x \cdot (1 + \text{scale}) + \text{shift}x=x⋅(1+scale)+shift
- 自注意力后 :通过门控 x=x+gate⋅Attention(modulated_x)x = x + \text{gate} \cdot \text{Attention}(\text{modulated\_x})x=x+gate⋅Attention(modulated_x)
- MLP 前:再次调制
- MLP 后:再次门控
这种设计使得条件信息能够精细地控制每一层的特征变换,且初始化时将 adaLN 输出置零,保证训练初期的稳定性。
2.2.3 交叉注意力(Cross-Attention)
DiT Block 在自注意力和 MLP 之后,还包含一个 交叉注意力层,用于融合编码器输出的场景上下文信息:
CrossAttn(Q=norm(x),K=cross_c,V=cross_c) \text{CrossAttn}(Q=\text{norm}(x), K=\text{cross\_c}, V=\text{cross\_c}) CrossAttn(Q=norm(x),K=cross_c,V=cross_c)
其中 cross_c 是编码器对 agents、静态物体和车道线的统一编码,使 DiT 能够感知完整的驾驶场景。
2.2.4 智能体类型嵌入
DiT 通过可学习的嵌入区分自车和周围车辆:
- 自车(ego):使用
agent_embedding.weight[0] - 周围车(neighbor):使用
agent_embedding.weight[1]
这使得网络能够针对自车和周围车辆学习不同的轨迹生成模式。
3. 代码详解与注释
3.1 调制函数(Modulation Functions)
python
import math
import torch
import torch.nn as nn
from timm.models.layers import Mlp
def modulate(x, shift, scale, only_first=False):
"""
自适应层归一化调制函数:对输入 x 进行缩放和平移
公式: x = x * (1 + scale) + shift
Args:
x: 输入特征 [B, P, D] 或 [B, D]
shift: 平移参数 [B, D]
scale: 缩放参数 [B, D]
only_first: 是否只调制第一个 token(用于特定条件注入)
Returns:
调制后的特征
"""
if only_first:
# 仅对第一个 token(自车)进行调制,保持其他 token 不变
x_first, x_rest = x[:, :1], x[:, 1:]
x = torch.cat([x_first * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), x_rest], dim=1)
else:
# 对所有 token 进行调制
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return x
def scale(x, scale, only_first=False):
"""
仅缩放调制(无平移)
公式: x = x * (1 + scale)
用于 FinalLayer 中的输出调制
"""
if only_first:
x_first, x_rest = x[:, :1], x[:, 1:]
x = torch.cat([x_first * (1 + scale.unsqueeze(1)), x_rest], dim=1)
else:
x = x * (1 + scale.unsqueeze(1))
return x
3.2 时间步嵌入器(TimestepEmbedder)
python
class TimestepEmbedder(nn.Module):
"""
时间步嵌入器:将标量时间步 t ∈ [0, 1] 映射为向量表示
原理:
1. 使用正弦/余弦位置编码捕捉时间步的周期性特征
2. 通过 MLP 将位置编码映射到隐藏维度
这种编码方式类似于 Transformer 的位置编码,但用于连续的时间步而非离散的序列位置
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
# 第一层:将频率编码映射到隐藏维度
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(), # Swish 激活函数,平滑且非单调
# 第二层:进一步变换
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
创建正弦时间步嵌入(Sinusoidal Timestep Embeddings)
数学原理:
对于维度 i ∈ [0, dim/2):
- 频率: freq_i = exp(-log(max_period) * i / (dim/2))
- 编码: [cos(t * freq_i), sin(t * freq_i)]
这种编码的优势:
1. 能够表示任意连续时间步
2. 不同维度具有不同频率,形成多尺度表示
3. 对于相近的时间步,编码也相近(局部平滑性)
Args:
t: 时间步张量 [B],取值范围通常为 [0, 1]
dim: 输出编码维度
max_period: 控制最低频率,越大则频率范围越广
Returns:
时间步编码 [B, dim]
"""
half = dim // 2
# 计算频率:从高频到低频的对数均匀分布
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
# 计算角度:t * freq
args = t[:, None].float() * freqs[None]
# 拼接正弦和余弦编码
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
# 如果维度为奇数,补零
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
"""
前向传播:将时间步 t 转换为嵌入向量
Args:
t: [B] 批次时间步
Returns:
t_emb: [B, hidden_size] 时间步嵌入
"""
# 1. 生成正弦位置编码
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
# 2. 通过 MLP 变换到目标维度
t_emb = self.mlp(t_freq)
return t_emb
3.3 DiT 基础块(DiTBlock)
python
class DiTBlock(nn.Module):
"""
DiT 基础块:结合自注意力、交叉注意力和 adaLN-Zero 条件注入
结构(按执行顺序):
1. 自适应层归一化 + 自注意力(Self-Attention)
2. 自适应层归一化 + MLP
3. 交叉注意力(Cross-Attention)融合场景上下文
4. MLP 进一步变换
条件注入机制(adaLN-Zero):
- 通过条件向量 y 生成 6 个参数:shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
- shift/scale 用于调制输入特征
- gate 用于控制残差连接的强度(零初始化保证训练稳定)
"""
def __init__(self, dim=192, heads=6, dropout=0.1, mlp_ratio=4.0):
super().__init__()
# ---- 自注意力分支 ----
self.norm1 = nn.LayerNorm(dim) # 预归一化
self.attn = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
# ---- MLP 分支 ----
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio) # MLP 隐藏层维度 = dim * 4
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=approx_gelu, drop=0)
# ---- adaLN-Zero 调制网络 ----
# 输入:条件向量 y [B, dim]
# 输出:6 * dim 参数(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True)
)
# ---- 交叉注意力分支(融合场景上下文)----
self.norm3 = nn.LayerNorm(dim)
self.cross_attn = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
self.norm4 = nn.LayerNorm(dim)
self.mlp2 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=approx_gelu, drop=0)
def forward(self, x, cross_c, y, attn_mask):
"""
DiTBlock 前向传播
Args:
x: 输入轨迹 token [B, P, D],P = 1 + predicted_neighbor_num
cross_c: 场景上下文编码 [B, N, D],来自 Encoder
y: 条件向量 [B, D],融合时间步和路由信息
attn_mask: 邻居车辆掩码 [B, P],用于屏蔽不存在的车辆
Returns:
x: 变换后的特征 [B, P, D]
"""
# ========== Step 1: adaLN-Zero 参数生成 ==========
# 从条件 y 生成 6 个调制参数,每个参数 [B, D]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(y).chunk(6, dim=1)
# ========== Step 2: 自注意力(Self-Attention)==========
# 2.1 自适应层归一化调制
modulated_x = modulate(self.norm1(x), shift_msa, scale_msa)
# 2.2 自注意力计算 + 门控残差连接
# attn_mask 用于屏蔽无效邻居车辆,保证自车始终参与计算
x = x + gate_msa.unsqueeze(1) * self.attn(
modulated_x, modulated_x, modulated_x,
key_padding_mask=attn_mask
)[0]
# ========== Step 3: MLP 变换 ==========
# 3.1 再次调制
modulated_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
# 3.2 MLP + 门控残差
x = x + gate_mlp.unsqueeze(1) * self.mlp1(modulated_x)
# ========== Step 4: 交叉注意力(Cross-Attention)==========
# 将轨迹特征与场景上下文融合,Q 来自轨迹,K/V 来自场景
x = self.cross_attn(self.norm3(x), cross_c, cross_c)[0]
# ========== Step 5: 最终 MLP ==========
x = self.mlp2(self.norm4(x))
return x
3.4 输出层(FinalLayer)
python
class FinalLayer(nn.Module):
"""
DiT 最终输出层:将隐藏特征映射到轨迹输出空间
结构:
1. adaLN 调制
2. LayerNorm + Linear + GELU + LayerNorm + Linear
输出维度:output_size = (future_len + 1) * 4
- future_len: 预测的未来时间步数
- 4: (x, y, cos_heading, sin_heading)
- +1: 包含当前状态
"""
def __init__(self, hidden_size, output_size):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size)
# 投影网络:hidden_size → hidden_size*4 → output_size
self.proj = nn.Sequential(
nn.LayerNorm(hidden_size),
nn.Linear(hidden_size, hidden_size * 4, bias=True),
nn.GELU(approximate="tanh"),
nn.LayerNorm(hidden_size * 4),
nn.Linear(hidden_size * 4, output_size, bias=True)
)
# adaLN 调制(只生成 shift 和 scale,无 gate)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, y):
"""
Args:
x: 隐藏特征 [B, P, hidden_size]
y: 条件向量 [B, hidden_size]
Returns:
轨迹预测 [B, P, output_size]
"""
B, P, _ = x.shape
# 生成 shift 和 scale 参数
shift, scale = self.adaLN_modulation(y).chunk(2, dim=1)
# 调制输入特征
x = modulate(self.norm_final(x), shift, scale)
# 投影到输出空间
x = self.proj(x)
return x
3.5 DiT 主网络(DiT)
python
class DiT(nn.Module):
"""
Diffusion Transformer (DiT) 主网络
职责:在扩散过程中,根据当前噪声轨迹、时间步和场景条件,预测去噪后的轨迹
支持两种预测模式:
1. "x_start": 直接预测原始轨迹 x_0
2. "score": 预测分数(score function),需要除以 marginal_prob_std
输入:
- x: 当前噪声轨迹 [B, P, (1+V_future)*4],加噪过程采用 VP-SDE(Variance Preserving SDE) 的边缘分布采样方法。
- t: 扩散时间步 [B]
- cross_c: 场景上下文编码 [B, N, D]
- route_lanes: 路由车道线信息
- neighbor_current_mask: 邻居车辆存在掩码 [B, P]
输出:
- 预测轨迹 [B, P, (1+V_future)*4]
"""
def __init__(self, sde: SDE, route_encoder: nn.Module, depth, output_dim,
hidden_dim=192, heads=6, dropout=0.1, mlp_ratio=4.0,
model_type="x_start"):
super().__init__()
assert model_type in ["score", "x_start"], f"Unknown model type: {model_type}"
self._model_type = model_type
# ---- 路由编码器 ----
self.route_encoder = route_encoder
# ---- 智能体类型嵌入 ----
# 0: 自车(ego), 1: 周围车(neighbor)
self.agent_embedding = nn.Embedding(2, hidden_dim)
# ---- 输入投影层 ----
# 将轨迹维度投影到隐藏维度
self.preproj = Mlp(in_features=output_dim, hidden_features=512,
out_features=hidden_dim, act_layer=nn.GELU, drop=0.)
# ---- 时间步嵌入器 ----
self.t_embedder = TimestepEmbedder(hidden_dim)
# ---- DiT Blocks 堆叠 ----
self.blocks = nn.ModuleList([
DiTBlock(hidden_dim, heads, dropout, mlp_ratio)
for i in range(depth)
])
# ---- 最终输出层 ----
self.final_layer = FinalLayer(hidden_dim, output_dim)
# ---- SDE 相关 ----
self._sde = sde
self.marginal_prob_std = self._sde.marginal_prob_std
@property
def model_type(self):
return self._model_type
def forward(self, x, t, cross_c, route_lanes, neighbor_current_mask):
"""
DiT 前向传播
Args:
x: (B, P, output_dim) 当前轨迹状态(训练时加噪,推理时从噪声开始)
t: (B,) 扩散时间步
cross_c: (B, N, D) 场景上下文编码(来自 Encoder)
route_lanes: 路由车道线信息
neighbor_current_mask: (B, P) 邻居掩码,True 表示该位置无车辆
Returns:
预测结果 [B, P, output_dim]
"""
B, P, _ = x.shape
# ========== Step 1: 输入投影 ==========
# 将轨迹从原始维度投影到隐藏维度
x = self.preproj(x)
# ========== Step 2: 添加智能体类型嵌入 ==========
# 自车使用 embedding[0],所有周围车使用 embedding[1]
x_embedding = torch.cat([
self.agent_embedding.weight[0][None, :], # [1, D] 自车
self.agent_embedding.weight[1][None, :].expand(P - 1, -1) # [P-1, D] 周围车
], dim=0) # (P, D)
x_embedding = x_embedding[None, :, :].expand(B, -1, -1) # (B, P, D)
x = x + x_embedding # 添加类型信息
# ========== Step 3: 条件向量生成 ==========
# 3.1 编码路由信息
route_encoding = self.route_encoder(route_lanes)
y = route_encoding
# 3.2 融合时间步嵌入
y = y + self.t_embedder(t)
# ========== Step 4: 注意力掩码准备 ==========
# 确保自车(第0位)始终参与计算,邻居根据 mask 决定
attn_mask = torch.zeros((B, P), dtype=torch.bool, device=x.device)
attn_mask[:, 1:] = neighbor_current_mask
# ========== Step 5: DiT Blocks 逐层处理 ==========
for block in self.blocks:
x = block(x, cross_c, y, attn_mask)
# ========== Step 6: 最终输出投影 ==========
x = self.final_layer(x, y)
# ========== Step 7: 根据模型类型调整输出 ==========
if self._model_type == "score":
# Score 模型:输出需要除以 marginal_prob_std(t)
# 这是因为在 VP-SDE 中,score = ∇_x log p_t(x) = - (x - mean) / std^2
return x / (self.marginal_prob_std(t)[:, None, None] + 1e-6)
elif self._model_type == "x_start":
# x_start 模型:直接预测原始数据
return x
else:
raise ValueError(f"Unknown model type: {self._model_type}")
3.6 路由编码器(RouteEncoder)
python
class RouteEncoder(nn.Module):
"""
路由编码器:将规划路径(route lanes)编码为条件向量
使用 MLP-Mixer 架构处理车道线序列:
1. Channel-wise MLP:处理每个车道点的特征维度
2. Token-wise MLP:处理车道点之间的交互
3. Mixer Block:交替进行通道混合和 token 混合
这种设计能够捕捉车道线的空间结构和拓扑关系
"""
def __init__(self, route_num, lane_len, drop_path_rate=0.3, hidden_dim=192,
tokens_mlp_dim=32, channels_mlp_dim=64):
super().__init__()
self._channel = channels_mlp_dim
# 预处理投影层
self.channel_pre_project = Mlp(in_features=4, hidden_features=channels_mlp_dim,
out_features=channels_mlp_dim, act_layer=nn.GELU, drop=0.)
self.token_pre_project = Mlp(in_features=route_num * lane_len,
hidden_features=tokens_mlp_dim, out_features=tokens_mlp_dim,
act_layer=nn.GELU, drop=0.)
# MLP-Mixer 块
self.Mixer = MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate)
# 后处理
self.norm = nn.LayerNorm(channels_mlp_dim)
self.emb_project = Mlp(in_features=channels_mlp_dim, hidden_features=hidden_dim,
out_features=hidden_dim, act_layer=nn.GELU, drop=drop_path_rate)
def forward(self, x):
'''
Args:
x: 路由车道线 [B, P, V, D]
其中 D 包含:x, y, x'-x, y'-y(只取前4维)
Returns:
路由编码 [B, hidden_dim]
'''
# 只取前4维(位置信息)
x = x[..., :4]
B, P, V, _ = x.shape
# 创建掩码:标记无效车道点
mask_v = torch.sum(torch.ne(x[..., :4], 0), dim=-1).to(x.device) == 0
mask_p = torch.sum(~mask_v, dim=-1) == 0 # 无效路径
mask_b = torch.sum(~mask_p, dim=-1) == 0 # 无效批次
x = x.view(B, P * V, -1)
# 过滤无效批次
valid_indices = ~mask_b.view(-1)
x = x[valid_indices]
# MLP-Mixer 处理
x = self.channel_pre_project(x)
x = x.permute(0, 2, 1) # [B, C, T]
x = self.token_pre_project(x)
x = x.permute(0, 2, 1) # [B, T, C]
x = self.Mixer(x)
# 全局平均池化
x = torch.mean(x, dim=1)
# 投影到隐藏维度
x = self.emb_project(self.norm(x))
# 恢复批次维度
x_result = torch.zeros((B, x.shape[-1]), device=x.device)
x_result[valid_indices] = x
return x_result.view(B, -1)
3.6.1 路由编码器输入详解
路由编码器的输入 x 是 route_lanes,它在数据预处理阶段从地图中提取并处理而来。
输入数据结构
| 参数 | 维度 | 说明 | 默认值 |
|---|---|---|---|
| Batch size | B |
批次大小 | - |
| Route lanes num | P |
路由车道数量(冗余维度,实际不使用) | 25 |
| Route points num | V |
每条路由车道的点数量 | 20 |
| Feature dim | D |
每个点的特征维度(原始12维,只使用前4维) | 12 |
原始输入特征(12维)
python
# _lane_polyline_process 函数中,每个点的特征组成:
[
# 前4维(RouteEncoder实际使用)
x, y, # 点坐标(局部坐标系,以自车当前位置为原点)
dx, dy, # 当前点到下一个点的向量(切线方向)
# 后8维(RouteEncoder中被忽略)
x_left-x, y_left-y, # 点到左边界的向量
x_right-x, y_right-y, # 点到右边界的向量
traffic_light[4], # 红绿灯状态(one-hot编码,4维)
]
数据来源和处理流程
路由车道数据的处理流程如下:
1. 从地图API提取所有车道和边界
↓
2. 根据路由ID(route_roadblock_ids)筛选属于规划路径的车道
↓
3. 转换为自车局部坐标系
↓
4. 插值或裁剪到固定点数(默认20个点)
↓
5. 计算每个点的方向向量、边界向量、红绿灯状态
↓
6. 最终生成 route_lanes: [25, 20, 12]
RouteEncoder 内部处理维度变化
python
# 输入: route_lanes [B, P, V, 12]
# ↓ 取前4维
x = x[..., :4] # [B, P*V, 4]
# ↓ Channel-wise 投影
x = channel_pre_project(x) # [B, P*V, 64]
# ↓ 转置为 [B, C, T] 格式
x = x.permute(0, 2, 1) # [B, 64, P*V]
# ↓ Token-wise 投影
x = token_pre_project(x) # [B, 64, 32]
# ↓ 转置回 [B, T, C] 格式
x = x.permute(0, 2, 1) # [B, 32, 64]
# ↓ MixerBlock
x = Mixer(x) # [B, 32, 64]
# ↓ 全局平均池化
x = torch.mean(x, dim=1) # [B, 64]
# ↓ 投影到隐藏维度
x = emb_project(norm(x)) # [B, 192]
# 输出: 路由编码 [B, 192]
关键点说明
| 关键点 | 说明 |
|---|---|
| 冗余维度P | 输入中 P=25 是路由车道数量,但 RouteEncoder 会将其展平为 P*V=500,实际不区分不同车道 |
| 只使用前4维 | 原始输入有12维,但 RouteEncoder 只使用前4维的位置和方向信息,忽略边界和红绿灯 |
| 过滤无效批次 | 会检测掩码并过滤掉完全无效的批次,这些批次返回零向量 |
| 全局平均池化 | 对所有车道点取平均,生成一个全局的路由编码向量 |
| 与时间步融合 | 路由编码最终与时间步编码相加,作为 DiT 中自适应层归一化的条件向量 |
配置参数(默认值)
python
# train_predictor.py 中定义的默认参数
route_len = 20 # 每条路由车道的点数
route_num = 25 # 路由车道数量
lane_len = 20 # 普通车道点数
lane_num = 70 # 普通车道数量
输入输出示意图
输入: route_lanes
[Batch, 25, 20, 12]
↓ 只取前4维
[Batch, 25*20, 4] = [Batch, 500, 4]
↓ 投影 → Mixer → 池化
↓
输出: route_encoding
[Batch, 192]
3.7 解码器(Decoder)中的 DiT 使用
python
class Decoder(nn.Module):
"""
Diffusion Planner 解码器
职责:
- 训练时:对加噪轨迹进行去噪预测
- 推理时:使用 DPM-Solver 从纯噪声迭代生成轨迹
包含:
- DiT 主干网络
- RouteEncoder 路由编码
- 状态归一化/反归一化
- 可选的 guidance 函数(碰撞避免等)
"""
def __init__(self, config):
super().__init__()
dpr = config.decoder_drop_path_rate
self._predicted_neighbor_num = config.predicted_neighbor_num
self._future_len = config.future_len
self._sde = VPSDE_linear() # 使用 VP-SDE
# 初始化 DiT 网络
self.dit = DiT(
sde=self._sde,
route_encoder=RouteEncoder(
config.route_num, config.lane_len,
drop_path_rate=config.encoder_drop_path_rate,
hidden_dim=config.hidden_dim
),
depth=config.decoder_depth,
output_dim=(config.future_len + 1) * 4, # x, y, cos, sin
hidden_dim=config.hidden_dim,
heads=config.num_heads,
dropout=dpr,
model_type=config.diffusion_model_type
)
# 归一化器
self._state_normalizer = config.state_normalizer
self._observation_normalizer = config.observation_normalizer
# Guidance 函数(用于推理时的约束)
self._guidance_fn = config.guidance_fn
@property
def sde(self):
return self._sde
def forward(self, encoder_outputs, inputs):
"""
扩散解码器前向传播
训练流程:
1. 提取当前状态
2. 获取加噪轨迹和时间步
3. DiT 预测去噪结果
推理流程:
1. 从纯噪声初始化
2. 使用 DPM-Solver 快速采样
3. 施加初始状态约束
4. 可选:使用 guidance 进行碰撞避免
"""
# 提取自车和周围车当前状态 [B, P, 4]
ego_current = inputs['ego_current_state'][:, None, :4]
neighbors_current = inputs["neighbor_agents_past"][:, :self._predicted_neighbor_num, -1, :4]
neighbor_current_mask = torch.sum(torch.ne(neighbors_current[..., :4], 0), dim=-1) == 0
inputs["neighbor_current_mask"] = neighbor_current_mask
current_states = torch.cat([ego_current, neighbors_current], dim=1)
B, P, _ = current_states.shape
# 提取场景编码和路由
ego_neighbor_encoding = encoder_outputs['encoding']
route_lanes = inputs['route_lanes']
if self.training:
# ========== 训练模式 ==========
# 获取加噪轨迹 [B, P, (1+V_future)*4]
sampled_trajectories = inputs['sampled_trajectories'].reshape(B, P, -1)
diffusion_time = inputs['diffusion_time'] # [B]
# DiT 预测去噪
return {
"score": self.dit(
sampled_trajectories,
diffusion_time,
ego_neighbor_encoding,
route_lanes,
neighbor_current_mask
).reshape(B, P, -1, 4)
}
else:
# ========== 推理模式 ==========
# 从噪声初始化:当前状态 + 随机噪声
xT = torch.cat([
current_states[:, :, None], # 当前状态
torch.randn(B, P, self._future_len, 4).to(current_states.device) * 0.5 # 噪声
], dim=2).reshape(B, P, -1)
# 初始状态约束:强制第一个时间步为当前状态
def initial_state_constraint(xt, t, step):
xt = xt.reshape(B, P, -1, 4)
xt[:, :, 0, :] = current_states # 约束当前状态
return xt.reshape(B, P, -1)
# 使用 DPM-Solver 进行快速采样
x0 = dpm_sampler(
self.dit,
xT,
other_model_params={
"cross_c": ego_neighbor_encoding,
"route_lanes": route_lanes,
"neighbor_current_mask": neighbor_current_mask
},
dpm_solver_params={
"correcting_xt_fn": initial_state_constraint,
},
model_wrapper_params={
"classifier_fn": self._guidance_fn, # 可选的 guidance
"classifier_kwargs": {...},
"guidance_scale": 0.5,
"guidance_type": "classifier" if self._guidance_fn is not None else "uncond"
},
)
# 反归一化并去除当前状态,只保留未来轨迹
x0 = self._state_normalizer.inverse(x0.reshape(B, P, -1, 4))[:, :, 1:]
return {"prediction": x0}
4. 关键设计亮点
4.1 零初始化策略(Zero Initialization)
在 Diffusion_Planner_Decoder.initialize_weights() 中:
python
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.decoder.dit.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].bias, 0)
原理:训练初期,adaLN 的输出为零,gate 参数也为零,DiT Block 相当于恒等映射(Identity)。这保证了:
- 网络在训练初期行为稳定
- 残差连接能够有效传递梯度
- 条件信息随着训练逐渐"注入"网络
4.2 多智能体联合建模
DiT 同时预测自车和周围车辆的轨迹:
- 输入维度 P = 1(自车)+ predicted_neighbor_num(周围车)
- 自注意力机制天然建模多车交互
- 交叉注意力融合统一场景上下文
4.3 条件信息融合
DiT 融合三类条件信息:
- 时间步 t:通过 TimestepEmbedder 编码,控制去噪进度
- 路由信息:通过 RouteEncoder 编码,提供导航指引
- 场景上下文:通过 Cross-Attention 注入,提供环境感知
4.4 DPM-Solver 快速采样
推理时使用 DPM-Solver(扩散概率模型求解器),相比传统的 DDPM 采样:
- 步数大幅减少(从 1000 步降至 10-20 步)
- 保持生成质量的同时满足实时性要求
- 支持初始状态约束和 classifier guidance
5. 总结
Diffusion Planner 中的 DiT 主干网络将扩散模型与 Transformer 架构深度结合,通过以下创新实现了高效的多智能体轨迹生成:
- adaLN-Zero 条件注入:精细控制每层特征变换,零初始化保证训练稳定
- Cross-Attention 场景融合:将编码器提取的丰富场景上下文注入解码器
- 多智能体联合建模:同时预测自车和周围车辆,通过自注意力建模交互
- 时间步与路由条件:融合扩散进度和导航信息,生成符合规划的轨迹
- DPM-Solver 快速推理:满足自动驾驶实时性要求
这种设计使得 DiT 能够在复杂的城市场景中生成多模态、交互感知且符合动力学约束的未来轨迹。
6. 核心内容总结
6.1 架构总览
┌─────────────────────────────────────────────────────────────────────────┐
│ Diffusion Transformer (DiT) │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 输入层 │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ x: [B, P, T×4] 加噪轨迹 (自车+周围车) │ │
│ │ t: [B] 扩散时间步 │ │
│ │ cross_c: [B, N, D] 场景上下文编码 │ │
│ │ route_lanes: [B, 25, 20, 12] 路由车道信息 │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 预处理层 │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ preproj: [B,P,T×4]→[B,P,D] 输入投影 │ │
│ │ agent_embedding: 添加智能体类型嵌入 │ │
│ │ TimestepEmbedder(t): [B]→[B,D] 时间步编码 │ │
│ │ RouteEncoder(route): [B,25,20,12]→[B,D] 路由编码 │ │
│ │ y = route + t_embed 条件向量 │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ DiT Blocks × N │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ for each block: │ │
│ │ adaLN_modulation(y) → γ, β, gate │ │
│ │ x = γ * LN(x) + β │ │
│ │ x = MHSA(x, x, x, attn_mask) + x # 自注意力 │ │
│ │ x = γ * LN(x) + β │ │
│ │ x = MLP(x) + x # 前馈网络 │ │
│ │ x = gate * x # 自适应门控 │ │
│ │ x = MHCA(x, cross_c, cross_c) + x # 交叉注意力 │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 输出层 │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ FinalLayer: adaLN + proj → [B,P,D]→[B,P,T×4] │ │
│ │ if model_type=="score": output /= marginal_prob_std(t) │ │
│ │ if model_type=="x_start": output = x │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 输出: [B, P, T, 4] 去噪轨迹预测 │
│ │
└─────────────────────────────────────────────────────────────────────────┘
6.2 核心组件总结
| 组件 | 类型 | 作用 | 关键特点 |
|---|---|---|---|
| preproj | Linear | 输入投影 | 将轨迹特征映射到隐藏维度 |
| agent_embedding | Embedding | 智能体类型编码 | 区分自车和周围车 |
| TimestepEmbedder | Sinusoidal | 时间步编码 | 周期性位置编码,支持外推 |
| RouteEncoder | MLP-Mixer | 路由信息编码 | 将车道线序列压缩为条件向量 |
| DiT Block | Transformer | 核心编码单元 | adaLN + MHSA + MHCA + MLP |
| adaLN_modulation | MLP | 自适应层归一化 | 根据条件向量动态调整 γ, β, gate |
| MHSA | Multi-Head | 自注意力 | 建模多智能体交互关系 |
| MHCA | Multi-Head | 交叉注意力 | 融合场景上下文信息 |
| FinalLayer | Linear | 输出投影 | 零初始化保证训练稳定 |
6.3 关键数学公式
前向扩散过程(VP-SDE)
dx = -0.5 * beta_t * x dt + sqrt(beta_t) dW
mean = exp(-0.25 * (beta_max-beta_min) * t^2 - 0.5 * beta_min * t) * x
std = sqrt(1 - exp(2 * mean_log_coeff))
adaLN-Zero 调制
h = linear(y) # 条件向量映射
gamma, beta, gate = split(h) # 分割为三个参数
x = gamma * LN(x) + beta # 自适应归一化
x = gate * x # 自适应门控
Score Matching 损失
L = E_{t,x,ε} [||ε - σ_t * score(x_t, t, y)||^2]
其中 x_t = mean * x + std * ε
x_start Prediction 损失
L = E_{t,x,ε} [||x - model(x_t, t, y)||^2]
6.4 训练与推理流程
训练流程
1. 准备数据: ego_future, neighbors_future, route_lanes, scene_encoding
2. 采样时间步 t ~ Uniform(0, 1)
3. 前向扩散: x_t = mean * x_0 + std * ε
4. DiT 预测: pred = dit(x_t, t, cross_c, route_lanes)
5. 计算损失: MSE(pred, x_0) 或 MSE(pred * std, ε)
6. 反向传播优化
推理流程
1. 初始化: x_T = current_state + noise
2. DPM-Solver 迭代 (10-20步):
x_{t-1} = x_t + dt * (drift + guidance)
3. 约束当前状态: x[:, :, 0, :] = current_state
4. 输出未来轨迹: x_0[:, :, 1:]
6.5 关键设计决策
| 设计决策 | 技术实现 | 目的 |
|---|---|---|
| 零初始化 | adaLN和输出层初始化为0 | 训练初期网络近似恒等映射,保证稳定 |
| 条件注入 | adaLN调制每个Block | 精细控制条件信息的注入位置和强度 |
| 双注意力机制 | MHSA + MHCA | 自注意力建模交互,交叉注意力融合上下文 |
| 两种训练模式 | score / x_start | score模式理论基础强,x_start模式更直观 |
| DPM-Solver | 快速采样算法 | 10-20步完成采样,满足实时性 |
| 初始状态约束 | correcting_xt_fn | 强制第一个时间步为当前观测状态 |
6.6 维度变化速览
输入:
x: [B, P, T×4] # 轨迹序列
t: [B] # 时间步
cross_c: [B, N, D] # 场景编码
route_lanes: [B, 25, 20, 12]
预处理:
x → preproj → [B, P, D]
y = RouteEncoder(route) + TimestepEmbedder(t) → [B, D]
DiT Blocks:
[B, P, D] → Block → [B, P, D] (重复 N 次)
输出:
FinalLayer → [B, P, T×4] → reshape → [B, P, T, 4]
6.7 与其他模块的关系
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Encoder │ │ DiT │ │ Loss Module │
│ (场景编码) │────▶│ (主干网络) │────▶│ (损失计算) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ SDE Module │◀─────────────┘
│ │ (扩散过程) │
│ └─────────────────┘
│ │
│ ▼
│ ┌─────────────────┐
│ │ Guidance Module │
│ │ (碰撞避免) │
└─────────────▶│ │
└─────────────────┘
6.8 性能与优化
| 优化策略 | 实现方式 | 效果 |
|---|---|---|
| 掩码计算 | 只处理有效token | 减少无效计算,支持动态场景 |
| DPM-Solver | 快速采样算法 | 采样步数从1000降至10-20 |
| 零初始化 | 稳定训练初期 | 加速收敛,避免梯度爆炸 |
| 残差连接 | 每个Block都有残差 | 缓解梯度消失,支持深层网络 |
| LayerNorm | adaLN自适应归一化 | 加速训练,提高稳定性 |
精炼核心技术总结
一句话定位
DiT 是基于 Transformer 的扩散主干网络,通过 adaLN-Zero 实现条件注入和多智能体轨迹预测。
核心技术点
- adaLN-Zero:自适应层归一化实现精细条件调制
- 双注意力机制:MHSA+MHCA 分别建模自交互和上下文融合
- 零初始化策略:训练初期网络近似恒等映射,保证稳定
- DPM-Solver 集成:10-20步快速采样,满足实时性
- x_start/Score 双模式:灵活选择训练目标
关键公式速查
adaLN-Zero 调制:
y=s⋅LayerNorm(x)+b y = s \cdot \text{LayerNorm}(x) + b y=s⋅LayerNorm(x)+b
其中 s,bs, bs,b 由条件向量通过线性层生成。
DiT Block 输出:
h′=MHSA(LN(h))+MHCA(LN(h),y)+h \text{h}' = \text{MHSA}(\text{LN}(h)) + \text{MHCA}(\text{LN}(h), y) + h h′=MHSA(LN(h))+MHCA(LN(h),y)+h
模块间一句话关联
DiT 接收 Encoder 的场景编码和 SDE 的时间步,输出预测轨迹到 Loss 模块计算损失,推理时与 Guidance 模块协同生成安全轨迹。