Diffusion-Planner:一、扩散Transformer(Diffusion Transformer)主干网络

Diffusion Transformer (DiT) 主干网络原理讲解与代码注释

1. 概述

Diffusion Transformer(DiT)是 Diffusion Planner 的解码器核心主干网络 ,负责在扩散过程中逐步去噪,生成自车与周围车辆的未来轨迹。它将传统的 U-Net 扩散主干替换为 Transformer 架构 ,并引入 adaLN-Zero(自适应层归一化)条件注入机制,使网络能够高效地融合时间步、场景上下文与路由信息。

DiT 的整体设计理念来源于图像生成领域的 DiT(Diffusion Transformer),其核心思想是:

  • 将轨迹数据视为 token 序列
  • 利用 Transformer 的自注意力机制建模多智能体间的交互
  • 通过自适应归一化将扩散时间步与条件信息注入网络

2. 核心原理

2.1 扩散过程回顾

扩散模型包含两个过程:

  1. 前向过程(Forward Process):逐步向数据添加高斯噪声,直到数据变为纯噪声
  2. 反向过程(Reverse Process):学习从噪声中逐步恢复原始数据

在 Diffusion Planner 中,使用 VP-SDE(Variance Preserving Stochastic Differential Equation) 描述扩散过程:

dx=−β(t)2x dt+β(t) dWt d x = -\frac{\beta(t)}{2} x \, dt + \sqrt{\beta(t)} \, dW_t dx=−2β(t)xdt+β(t) dWt

其中:

  • β(t)=(βmax−βmin)⋅t+βmin\beta(t) = (\beta_{max} - \beta_{min}) \cdot t + \beta_{min}β(t)=(βmax−βmin)⋅t+βmin,t∈0,1t \in 0, 1t∈0,1
  • βmin=0.1\beta_{min} = 0.1βmin=0.1,βmax=20.0\beta_{max} = 20.0βmax=20.0
  • WtW_tWt 是标准维纳过程(布朗运动)

2.1.1 加噪过程详解

DiT 的输入 xxx 是通过 VP-SDE 的边缘分布采样得到的加噪轨迹。具体流程如下:

python 复制代码
# 1. 准备数据:当前状态 + 未来轨迹
current_states = torch.cat([ego_current[:, None], neighbors_current], dim=1)  # [B, P, 4]
gt_future = torch.cat([ego_future[:, None, :, :], neighbors_future[..., :]], dim=1)  # [B, P, T, 4]
all_gt = torch.cat([current_states[:, :, None, :], norm(gt_future)], dim=2)  # [B, P, 1+T, 4]

# 2. 随机采样时间步 t ∈ [eps, 1-eps]
t = torch.rand(B, device=gt_future.device) * (1 - eps) + eps  # [B]

# 3. 生成标准高斯噪声 z
z = torch.randn_like(gt_future, device=gt_future.device)  # [B, P, T, 4]

# 4. 使用VP-SDE边缘分布计算加噪轨迹
mean, std = marginal_prob(all_gt[..., 1:, :], t)  # 仅对未来轨迹加噪,保持当前状态不变
xT = mean + std * z  # 核心加噪公式
xT = torch.cat([all_gt[:, :, :1, :], xT], dim=2)  # 拼接当前状态
数学原理

VP-SDE 的边缘分布满足以下公式:

pt(xt∣x0)=N(xt;μ(t)⋅x0,σ(t)2⋅I) p_t(x_t | x_0) = \mathcal{N}(x_t; \mu(t) \cdot x_0, \sigma(t)^2 \cdot I) pt(xt∣x0)=N(xt;μ(t)⋅x0,σ(t)2⋅I)

其中:

  • 均值 : μ(t)=exp⁡(−14(βmax−βmin)t2−12βmint)\mu(t) = \exp\left(-\frac{1}{4}(\beta_{\text{max}} - \beta_{\text{min}})t^2 - \frac{1}{2}\beta_{\text{min}}t\right)μ(t)=exp(−41(βmax−βmin)t2−21βmint)
  • 标准差 : σ(t)=1−exp⁡(−12(βmax−βmin)t2−βmint)\sigma(t) = \sqrt{1 - \exp\left(-\frac{1}{2}(\beta_{\text{max}} - \beta_{\text{min}})t^2 - \beta_{\text{min}}t\right)}σ(t)=1−exp(−21(βmax−βmin)t2−βmint)
  • β(t)=(βmax−βmin)⋅t+βmin\beta(t) = (\beta_{\text{max}} - \beta_{\text{min}}) \cdot t + \beta_{\text{min}}β(t)=(βmax−βmin)⋅t+βmin,线性噪声调度
关键实现细节
步骤 代码位置 作用
时间步采样 t = torch.rand(B) * (1 - eps) + eps 避免t=0(无噪声)和t=1(完全噪声)的边界情况
噪声生成 z = torch.randn_like(gt_future) 生成与轨迹同形状的标准高斯噪声
边缘分布计算 mean, std = marginal_prob(all_gt[..., 1:, :], t) 根据时间步计算均值和标准差
加噪公式 xT = mean + std * z 从边缘分布采样
当前状态保持 torch.cat([all_gt[:, :, :1, :], xT], dim=2) 当前状态(t=0)不参与加噪
加噪特性
  1. 时间步控制噪声强度

    • t→0:噪声很小,xT≈x0x_T \approx x_0xT≈x0(接近原始轨迹)
    • t→1:噪声很大,xTx_TxT 接近纯随机噪声
  2. 当前状态约束

    • 代码中仅对未来轨迹 all_gt[..., 1:, :] 加噪
    • 当前状态 all_gt[:, :, :1, :] 直接保留,不添加噪声
    • 这保证了扩散过程从已知的当前状态开始
  3. VP-SDE 参数

    • βmin=0.1\beta_{\text{min}} = 0.1βmin=0.1,βmax=20.0\beta_{\text{max}} = 20.0βmax=20.0
    • 线性噪声调度,噪声强度随时间步线性增长
完整加噪示意图
复制代码
原始轨迹 x_0:      [当前状态] + [未来轨迹T1, T2, ..., Tn]
                          |              |
                          |              v
                          |     + 高斯噪声 z ~ N(0,1)
                          |              |
                          |              v
                          |     x_t = mean + std * z
                          |              |
                          +---------> [当前状态] + [加噪未来轨迹]
                                      = x_T (DiT输入)
与DiT前向传播的关联

在训练时,xT 作为 sampled_trajectories 输入到 DiT:

python 复制代码
merged_inputs = {
    **inputs,
    "sampled_trajectories": xT,  # 加噪轨迹
    "diffusion_time": t,          # 对应时间步
}
_, decoder_output = model(merged_inputs)

DiT 学习在给定加噪轨迹 xTx_TxT 和时间步 ttt 的情况下,预测去噪后的轨迹或分数。

2.2 DiT 架构设计

DiT 的整体架构包含以下核心组件:

复制代码
输入轨迹 token (x) ──→ Pre-projection ──→ 嵌入层 ──┐
                                                    ├──→ DiT Blocks × N ──→ Final Layer ──→ 输出
时间步 t ──→ Timestep Embedder ──┐                 │
                                 ├──→ 条件融合 (y) ──┘
路由信息 ──→ RouteEncoder ────────┘
                                    ↑
场景上下文 ──→ Cross-Attention ──────┘
2.2.1 时间步嵌入(Timestep Embedding)

扩散模型需要在每个时间步 ttt 进行去噪,因此需要将标量时间步编码为向量表示。DiT 采用 正弦位置编码 结合 MLP:

freqi=exp⁡(−log⁡(max_period)⋅idim/2),i=0,1,...,dim/2−1 \text{freq}_i = \exp\left(-\log(max\_period) \cdot \frac{i}{dim/2}\right), \quad i = 0, 1, ..., dim/2 - 1 freqi=exp(−log(max_period)⋅dim/2i),i=0,1,...,dim/2−1

emb=cos⁡(t⋅freq),sin⁡(t⋅freq) \text{emb} = \\cos(t \\cdot \\text{freq}), \\sin(t \\cdot \\text{freq}) emb=cos(t⋅freq),sin(t⋅freq)

这种编码方式能够捕捉时间步的周期性和连续性特征。

2.2.2 自适应层归一化(adaLN-Zero)

adaLN-Zero 是 DiT 的核心创新,它将条件信息(时间步 + 路由)通过一个小型 MLP 生成六个缩放/偏移/门控参数:

shiftmsa,scalemsa,gatemsa,shiftmlp,scalemlp,gatemlp=adaLN_modulation(y).chunk(6) \text{shift}{msa}, \text{scale}{msa}, \text{gate}{msa}, \text{shift}{mlp}, \text{scale}{mlp}, \text{gate}{mlp} = \text{adaLN\_modulation}(y).\text{chunk}(6) shiftmsa,scalemsa,gatemsa,shiftmlp,scalemlp,gatemlp=adaLN_modulation(y).chunk(6)

在每个 DiT Block 中:

  1. 自注意力前 :对输入进行调制 x=x⋅(1+scale)+shiftx = x \cdot (1 + \text{scale}) + \text{shift}x=x⋅(1+scale)+shift
  2. 自注意力后 :通过门控 x=x+gate⋅Attention(modulated_x)x = x + \text{gate} \cdot \text{Attention}(\text{modulated\_x})x=x+gate⋅Attention(modulated_x)
  3. MLP 前:再次调制
  4. MLP 后:再次门控

这种设计使得条件信息能够精细地控制每一层的特征变换,且初始化时将 adaLN 输出置零,保证训练初期的稳定性。

2.2.3 交叉注意力(Cross-Attention)

DiT Block 在自注意力和 MLP 之后,还包含一个 交叉注意力层,用于融合编码器输出的场景上下文信息:

CrossAttn(Q=norm(x),K=cross_c,V=cross_c) \text{CrossAttn}(Q=\text{norm}(x), K=\text{cross\_c}, V=\text{cross\_c}) CrossAttn(Q=norm(x),K=cross_c,V=cross_c)

其中 cross_c 是编码器对 agents、静态物体和车道线的统一编码,使 DiT 能够感知完整的驾驶场景。

2.2.4 智能体类型嵌入

DiT 通过可学习的嵌入区分自车和周围车辆:

  • 自车(ego):使用 agent_embedding.weight[0]
  • 周围车(neighbor):使用 agent_embedding.weight[1]

这使得网络能够针对自车和周围车辆学习不同的轨迹生成模式。


3. 代码详解与注释

3.1 调制函数(Modulation Functions)

python 复制代码
import math
import torch
import torch.nn as nn
from timm.models.layers import Mlp


def modulate(x, shift, scale, only_first=False):
    """
    自适应层归一化调制函数:对输入 x 进行缩放和平移
    
    公式: x = x * (1 + scale) + shift
    
    Args:
        x: 输入特征 [B, P, D] 或 [B, D]
        shift: 平移参数 [B, D]
        scale: 缩放参数 [B, D]
        only_first: 是否只调制第一个 token(用于特定条件注入)
    
    Returns:
        调制后的特征
    """
    if only_first:
        # 仅对第一个 token(自车)进行调制,保持其他 token 不变
        x_first, x_rest = x[:, :1], x[:, 1:]
        x = torch.cat([x_first * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), x_rest], dim=1)
    else:
        # 对所有 token 进行调制
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    return x


def scale(x, scale, only_first=False):
    """
    仅缩放调制(无平移)
    
    公式: x = x * (1 + scale)
    
    用于 FinalLayer 中的输出调制
    """
    if only_first:
        x_first, x_rest = x[:, :1], x[:, 1:]
        x = torch.cat([x_first * (1 + scale.unsqueeze(1)), x_rest], dim=1)
    else:
        x = x * (1 + scale.unsqueeze(1))

    return x

3.2 时间步嵌入器(TimestepEmbedder)

python 复制代码
class TimestepEmbedder(nn.Module):
    """
    时间步嵌入器:将标量时间步 t ∈ [0, 1] 映射为向量表示
    
    原理:
    1. 使用正弦/余弦位置编码捕捉时间步的周期性特征
    2. 通过 MLP 将位置编码映射到隐藏维度
    
    这种编码方式类似于 Transformer 的位置编码,但用于连续的时间步而非离散的序列位置
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            # 第一层:将频率编码映射到隐藏维度
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),  # Swish 激活函数,平滑且非单调
            # 第二层:进一步变换
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        创建正弦时间步嵌入(Sinusoidal Timestep Embeddings)
        
        数学原理:
        对于维度 i ∈ [0, dim/2):
        - 频率: freq_i = exp(-log(max_period) * i / (dim/2))
        - 编码: [cos(t * freq_i), sin(t * freq_i)]
        
        这种编码的优势:
        1. 能够表示任意连续时间步
        2. 不同维度具有不同频率,形成多尺度表示
        3. 对于相近的时间步,编码也相近(局部平滑性)
        
        Args:
            t: 时间步张量 [B],取值范围通常为 [0, 1]
            dim: 输出编码维度
            max_period: 控制最低频率,越大则频率范围越广
        
        Returns:
            时间步编码 [B, dim]
        """
        half = dim // 2
        # 计算频率:从高频到低频的对数均匀分布
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        
        # 计算角度:t * freq
        args = t[:, None].float() * freqs[None]
        
        # 拼接正弦和余弦编码
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        
        # 如果维度为奇数,补零
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """
        前向传播:将时间步 t 转换为嵌入向量
        
        Args:
            t: [B] 批次时间步
        
        Returns:
            t_emb: [B, hidden_size] 时间步嵌入
        """
        # 1. 生成正弦位置编码
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # 2. 通过 MLP 变换到目标维度
        t_emb = self.mlp(t_freq)
        return t_emb

3.3 DiT 基础块(DiTBlock)

python 复制代码
class DiTBlock(nn.Module):
    """
    DiT 基础块:结合自注意力、交叉注意力和 adaLN-Zero 条件注入
    
    结构(按执行顺序):
    1. 自适应层归一化 + 自注意力(Self-Attention)
    2. 自适应层归一化 + MLP
    3. 交叉注意力(Cross-Attention)融合场景上下文
    4. MLP 进一步变换
    
    条件注入机制(adaLN-Zero):
    - 通过条件向量 y 生成 6 个参数:shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
    - shift/scale 用于调制输入特征
    - gate 用于控制残差连接的强度(零初始化保证训练稳定)
    """
    def __init__(self, dim=192, heads=6, dropout=0.1, mlp_ratio=4.0):
        super().__init__()
        # ---- 自注意力分支 ----
        self.norm1 = nn.LayerNorm(dim)  # 预归一化
        self.attn = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
        
        # ---- MLP 分支 ----
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)  # MLP 隐藏层维度 = dim * 4
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp1 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
                        act_layer=approx_gelu, drop=0)
        
        # ---- adaLN-Zero 调制网络 ----
        # 输入:条件向量 y [B, dim]
        # 输出:6 * dim 参数(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )
        
        # ---- 交叉注意力分支(融合场景上下文)----
        self.norm3 = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
        self.norm4 = nn.LayerNorm(dim)
        self.mlp2 = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 
                        act_layer=approx_gelu, drop=0)

    def forward(self, x, cross_c, y, attn_mask):
        """
        DiTBlock 前向传播
        
        Args:
            x: 输入轨迹 token [B, P, D],P = 1 + predicted_neighbor_num
            cross_c: 场景上下文编码 [B, N, D],来自 Encoder
            y: 条件向量 [B, D],融合时间步和路由信息
            attn_mask: 邻居车辆掩码 [B, P],用于屏蔽不存在的车辆
        
        Returns:
            x: 变换后的特征 [B, P, D]
        """
        # ========== Step 1: adaLN-Zero 参数生成 ==========
        # 从条件 y 生成 6 个调制参数,每个参数 [B, D]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(y).chunk(6, dim=1)
        
        # ========== Step 2: 自注意力(Self-Attention)==========
        # 2.1 自适应层归一化调制
        modulated_x = modulate(self.norm1(x), shift_msa, scale_msa)
        # 2.2 自注意力计算 + 门控残差连接
        # attn_mask 用于屏蔽无效邻居车辆,保证自车始终参与计算
        x = x + gate_msa.unsqueeze(1) * self.attn(
            modulated_x, modulated_x, modulated_x, 
            key_padding_mask=attn_mask
        )[0]
        
        # ========== Step 3: MLP 变换 ==========
        # 3.1 再次调制
        modulated_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        # 3.2 MLP + 门控残差
        x = x + gate_mlp.unsqueeze(1) * self.mlp1(modulated_x)
        
        # ========== Step 4: 交叉注意力(Cross-Attention)==========
        # 将轨迹特征与场景上下文融合,Q 来自轨迹,K/V 来自场景
        x = self.cross_attn(self.norm3(x), cross_c, cross_c)[0]
        
        # ========== Step 5: 最终 MLP ==========
        x = self.mlp2(self.norm4(x))
        
        return x

3.4 输出层(FinalLayer)

python 复制代码
class FinalLayer(nn.Module):
    """
    DiT 最终输出层:将隐藏特征映射到轨迹输出空间
    
    结构:
    1. adaLN 调制
    2. LayerNorm + Linear + GELU + LayerNorm + Linear
    
    输出维度:output_size = (future_len + 1) * 4
    - future_len: 预测的未来时间步数
    - 4: (x, y, cos_heading, sin_heading)
    - +1: 包含当前状态
    """
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size)
        
        # 投影网络:hidden_size → hidden_size*4 → output_size
        self.proj = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size * 4, bias=True),
            nn.GELU(approximate="tanh"),
            nn.LayerNorm(hidden_size * 4),
            nn.Linear(hidden_size * 4, output_size, bias=True)
        )
        
        # adaLN 调制(只生成 shift 和 scale,无 gate)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, y):
        """
        Args:
            x: 隐藏特征 [B, P, hidden_size]
            y: 条件向量 [B, hidden_size]
        
        Returns:
            轨迹预测 [B, P, output_size]
        """
        B, P, _ = x.shape
        
        # 生成 shift 和 scale 参数
        shift, scale = self.adaLN_modulation(y).chunk(2, dim=1)
        
        # 调制输入特征
        x = modulate(self.norm_final(x), shift, scale)
        
        # 投影到输出空间
        x = self.proj(x)
        return x

3.5 DiT 主网络(DiT)

python 复制代码
class DiT(nn.Module):
    """
    Diffusion Transformer (DiT) 主网络
    
    职责:在扩散过程中,根据当前噪声轨迹、时间步和场景条件,预测去噪后的轨迹
    
    支持两种预测模式:
    1. "x_start": 直接预测原始轨迹 x_0
    2. "score": 预测分数(score function),需要除以 marginal_prob_std
    
    输入:
    - x: 当前噪声轨迹 [B, P, (1+V_future)*4],加噪过程采用 VP-SDE(Variance Preserving SDE) 的边缘分布采样方法。
    - t: 扩散时间步 [B]
    - cross_c: 场景上下文编码 [B, N, D]
    - route_lanes: 路由车道线信息
    - neighbor_current_mask: 邻居车辆存在掩码 [B, P]
    
    输出:
    - 预测轨迹 [B, P, (1+V_future)*4]
    """
    def __init__(self, sde: SDE, route_encoder: nn.Module, depth, output_dim, 
                 hidden_dim=192, heads=6, dropout=0.1, mlp_ratio=4.0, 
                 model_type="x_start"):
        super().__init__()
        
        assert model_type in ["score", "x_start"], f"Unknown model type: {model_type}"
        self._model_type = model_type
        
        # ---- 路由编码器 ----
        self.route_encoder = route_encoder
        
        # ---- 智能体类型嵌入 ----
        # 0: 自车(ego), 1: 周围车(neighbor)
        self.agent_embedding = nn.Embedding(2, hidden_dim)
        
        # ---- 输入投影层 ----
        # 将轨迹维度投影到隐藏维度
        self.preproj = Mlp(in_features=output_dim, hidden_features=512, 
                           out_features=hidden_dim, act_layer=nn.GELU, drop=0.)
        
        # ---- 时间步嵌入器 ----
        self.t_embedder = TimestepEmbedder(hidden_dim)
        
        # ---- DiT Blocks 堆叠 ----
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, heads, dropout, mlp_ratio) 
            for i in range(depth)
        ])
        
        # ---- 最终输出层 ----
        self.final_layer = FinalLayer(hidden_dim, output_dim)
        
        # ---- SDE 相关 ----
        self._sde = sde
        self.marginal_prob_std = self._sde.marginal_prob_std
               
    @property
    def model_type(self):
        return self._model_type

    def forward(self, x, t, cross_c, route_lanes, neighbor_current_mask):
        """
        DiT 前向传播
        
        Args:
            x: (B, P, output_dim) 当前轨迹状态(训练时加噪,推理时从噪声开始)
            t: (B,) 扩散时间步
            cross_c: (B, N, D) 场景上下文编码(来自 Encoder)
            route_lanes: 路由车道线信息
            neighbor_current_mask: (B, P) 邻居掩码,True 表示该位置无车辆
        
        Returns:
            预测结果 [B, P, output_dim]
        """
        B, P, _ = x.shape
        
        # ========== Step 1: 输入投影 ==========
        # 将轨迹从原始维度投影到隐藏维度
        x = self.preproj(x)

        # ========== Step 2: 添加智能体类型嵌入 ==========
        # 自车使用 embedding[0],所有周围车使用 embedding[1]
        x_embedding = torch.cat([
            self.agent_embedding.weight[0][None, :],           # [1, D] 自车
            self.agent_embedding.weight[1][None, :].expand(P - 1, -1)  # [P-1, D] 周围车
        ], dim=0)  # (P, D)
        x_embedding = x_embedding[None, :, :].expand(B, -1, -1)  # (B, P, D)
        x = x + x_embedding  # 添加类型信息

        # ========== Step 3: 条件向量生成 ==========
        # 3.1 编码路由信息
        route_encoding = self.route_encoder(route_lanes)
        y = route_encoding
        # 3.2 融合时间步嵌入
        y = y + self.t_embedder(t)      

        # ========== Step 4: 注意力掩码准备 ==========
        # 确保自车(第0位)始终参与计算,邻居根据 mask 决定
        attn_mask = torch.zeros((B, P), dtype=torch.bool, device=x.device)
        attn_mask[:, 1:] = neighbor_current_mask
        
        # ========== Step 5: DiT Blocks 逐层处理 ==========
        for block in self.blocks:
            x = block(x, cross_c, y, attn_mask)  
            
        # ========== Step 6: 最终输出投影 ==========
        x = self.final_layer(x, y)
        
        # ========== Step 7: 根据模型类型调整输出 ==========
        if self._model_type == "score":
            # Score 模型:输出需要除以 marginal_prob_std(t)
            # 这是因为在 VP-SDE 中,score = ∇_x log p_t(x) = - (x - mean) / std^2
            return x / (self.marginal_prob_std(t)[:, None, None] + 1e-6)
        elif self._model_type == "x_start":
            # x_start 模型:直接预测原始数据
            return x
        else:
            raise ValueError(f"Unknown model type: {self._model_type}")

3.6 路由编码器(RouteEncoder)

python 复制代码
class RouteEncoder(nn.Module):
    """
    路由编码器:将规划路径(route lanes)编码为条件向量
    
    使用 MLP-Mixer 架构处理车道线序列:
    1. Channel-wise MLP:处理每个车道点的特征维度
    2. Token-wise MLP:处理车道点之间的交互
    3. Mixer Block:交替进行通道混合和 token 混合
    
    这种设计能够捕捉车道线的空间结构和拓扑关系
    """
    def __init__(self, route_num, lane_len, drop_path_rate=0.3, hidden_dim=192, 
                 tokens_mlp_dim=32, channels_mlp_dim=64):
        super().__init__()

        self._channel = channels_mlp_dim

        # 预处理投影层
        self.channel_pre_project = Mlp(in_features=4, hidden_features=channels_mlp_dim, 
                                       out_features=channels_mlp_dim, act_layer=nn.GELU, drop=0.)
        self.token_pre_project = Mlp(in_features=route_num * lane_len, 
                                     hidden_features=tokens_mlp_dim, out_features=tokens_mlp_dim, 
                                     act_layer=nn.GELU, drop=0.)

        # MLP-Mixer 块
        self.Mixer = MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate)

        # 后处理
        self.norm = nn.LayerNorm(channels_mlp_dim)
        self.emb_project = Mlp(in_features=channels_mlp_dim, hidden_features=hidden_dim, 
                               out_features=hidden_dim, act_layer=nn.GELU, drop=drop_path_rate)

    def forward(self, x):
        '''
        Args:
            x: 路由车道线 [B, P, V, D]
               其中 D 包含:x, y, x'-x, y'-y(只取前4维)
        Returns:
            路由编码 [B, hidden_dim]
        '''
        # 只取前4维(位置信息)
        x = x[..., :4]

        B, P, V, _ = x.shape
        
        # 创建掩码:标记无效车道点
        mask_v = torch.sum(torch.ne(x[..., :4], 0), dim=-1).to(x.device) == 0
        mask_p = torch.sum(~mask_v, dim=-1) == 0  # 无效路径
        mask_b = torch.sum(~mask_p, dim=-1) == 0  # 无效批次
        
        x = x.view(B, P * V, -1)

        # 过滤无效批次
        valid_indices = ~mask_b.view(-1) 
        x = x[valid_indices] 

        # MLP-Mixer 处理
        x = self.channel_pre_project(x)
        x = x.permute(0, 2, 1)  # [B, C, T]
        x = self.token_pre_project(x)
        x = x.permute(0, 2, 1)  # [B, T, C]
        x = self.Mixer(x)

        # 全局平均池化
        x = torch.mean(x, dim=1)

        # 投影到隐藏维度
        x = self.emb_project(self.norm(x))

        # 恢复批次维度
        x_result = torch.zeros((B, x.shape[-1]), device=x.device)
        x_result[valid_indices] = x
        
        return x_result.view(B, -1)

3.6.1 路由编码器输入详解

路由编码器的输入 xroute_lanes,它在数据预处理阶段从地图中提取并处理而来。

输入数据结构
参数 维度 说明 默认值
Batch size B 批次大小 -
Route lanes num P 路由车道数量(冗余维度,实际不使用) 25
Route points num V 每条路由车道的点数量 20
Feature dim D 每个点的特征维度(原始12维,只使用前4维) 12
原始输入特征(12维)
python 复制代码
# _lane_polyline_process 函数中,每个点的特征组成:
[
    # 前4维(RouteEncoder实际使用)
    x, y,  # 点坐标(局部坐标系,以自车当前位置为原点)
    dx, dy,  # 当前点到下一个点的向量(切线方向)
    
    # 后8维(RouteEncoder中被忽略)
    x_left-x, y_left-y,  # 点到左边界的向量
    x_right-x, y_right-y,  # 点到右边界的向量
    traffic_light[4],  # 红绿灯状态(one-hot编码,4维)
]
数据来源和处理流程

路由车道数据的处理流程如下:

复制代码
1. 从地图API提取所有车道和边界
   ↓
2. 根据路由ID(route_roadblock_ids)筛选属于规划路径的车道
   ↓
3. 转换为自车局部坐标系
   ↓
4. 插值或裁剪到固定点数(默认20个点)
   ↓
5. 计算每个点的方向向量、边界向量、红绿灯状态
   ↓
6. 最终生成 route_lanes: [25, 20, 12]
RouteEncoder 内部处理维度变化
python 复制代码
# 输入: route_lanes [B, P, V, 12]
# ↓ 取前4维
x = x[..., :4]  # [B, P*V, 4]
# ↓ Channel-wise 投影
x = channel_pre_project(x)  # [B, P*V, 64]
# ↓ 转置为 [B, C, T] 格式
x = x.permute(0, 2, 1)  # [B, 64, P*V]
# ↓ Token-wise 投影
x = token_pre_project(x)  # [B, 64, 32]
# ↓ 转置回 [B, T, C] 格式
x = x.permute(0, 2, 1)  # [B, 32, 64]
# ↓ MixerBlock
x = Mixer(x)  # [B, 32, 64]
# ↓ 全局平均池化
x = torch.mean(x, dim=1)  # [B, 64]
# ↓ 投影到隐藏维度
x = emb_project(norm(x))  # [B, 192]
# 输出: 路由编码 [B, 192]
关键点说明
关键点 说明
冗余维度P 输入中 P=25 是路由车道数量,但 RouteEncoder 会将其展平为 P*V=500,实际不区分不同车道
只使用前4维 原始输入有12维,但 RouteEncoder 只使用前4维的位置和方向信息,忽略边界和红绿灯
过滤无效批次 会检测掩码并过滤掉完全无效的批次,这些批次返回零向量
全局平均池化 对所有车道点取平均,生成一个全局的路由编码向量
与时间步融合 路由编码最终与时间步编码相加,作为 DiT 中自适应层归一化的条件向量
配置参数(默认值)
python 复制代码
# train_predictor.py 中定义的默认参数
route_len = 20          # 每条路由车道的点数
route_num = 25          # 路由车道数量
lane_len = 20           # 普通车道点数
lane_num = 70           # 普通车道数量
输入输出示意图
复制代码
输入: route_lanes
[Batch, 25, 20, 12]
  ↓ 只取前4维
[Batch, 25*20, 4] = [Batch, 500, 4]
  ↓ 投影 → Mixer → 池化
  ↓
输出: route_encoding
[Batch, 192]

3.7 解码器(Decoder)中的 DiT 使用

python 复制代码
class Decoder(nn.Module):
    """
    Diffusion Planner 解码器
    
    职责:
    - 训练时:对加噪轨迹进行去噪预测
    - 推理时:使用 DPM-Solver 从纯噪声迭代生成轨迹
    
    包含:
    - DiT 主干网络
    - RouteEncoder 路由编码
    - 状态归一化/反归一化
    - 可选的 guidance 函数(碰撞避免等)
    """
    def __init__(self, config):
        super().__init__()

        dpr = config.decoder_drop_path_rate
        self._predicted_neighbor_num = config.predicted_neighbor_num
        self._future_len = config.future_len
        self._sde = VPSDE_linear()  # 使用 VP-SDE

        # 初始化 DiT 网络
        self.dit = DiT(
            sde=self._sde, 
            route_encoder=RouteEncoder(
                config.route_num, config.lane_len, 
                drop_path_rate=config.encoder_drop_path_rate, 
                hidden_dim=config.hidden_dim
            ),
            depth=config.decoder_depth, 
            output_dim=(config.future_len + 1) * 4,  # x, y, cos, sin
            hidden_dim=config.hidden_dim, 
            heads=config.num_heads, 
            dropout=dpr,
            model_type=config.diffusion_model_type
        )
        
        # 归一化器
        self._state_normalizer = config.state_normalizer
        self._observation_normalizer = config.observation_normalizer
        
        # Guidance 函数(用于推理时的约束)
        self._guidance_fn = config.guidance_fn
    
    @property
    def sde(self):
        return self._sde
    
    def forward(self, encoder_outputs, inputs):
        """
        扩散解码器前向传播
        
        训练流程:
        1. 提取当前状态
        2. 获取加噪轨迹和时间步
        3. DiT 预测去噪结果
        
        推理流程:
        1. 从纯噪声初始化
        2. 使用 DPM-Solver 快速采样
        3. 施加初始状态约束
        4. 可选:使用 guidance 进行碰撞避免
        """
        # 提取自车和周围车当前状态 [B, P, 4]
        ego_current = inputs['ego_current_state'][:, None, :4]
        neighbors_current = inputs["neighbor_agents_past"][:, :self._predicted_neighbor_num, -1, :4]
        neighbor_current_mask = torch.sum(torch.ne(neighbors_current[..., :4], 0), dim=-1) == 0
        inputs["neighbor_current_mask"] = neighbor_current_mask

        current_states = torch.cat([ego_current, neighbors_current], dim=1)
        B, P, _ = current_states.shape

        # 提取场景编码和路由
        ego_neighbor_encoding = encoder_outputs['encoding']
        route_lanes = inputs['route_lanes']

        if self.training:
            # ========== 训练模式 ==========
            # 获取加噪轨迹 [B, P, (1+V_future)*4]
            sampled_trajectories = inputs['sampled_trajectories'].reshape(B, P, -1)
            diffusion_time = inputs['diffusion_time']  # [B]

            # DiT 预测去噪
            return {
                "score": self.dit(
                    sampled_trajectories, 
                    diffusion_time,
                    ego_neighbor_encoding,
                    route_lanes,
                    neighbor_current_mask
                ).reshape(B, P, -1, 4)
            }
        else:
            # ========== 推理模式 ==========
            # 从噪声初始化:当前状态 + 随机噪声
            xT = torch.cat([
                current_states[:, :, None],  # 当前状态
                torch.randn(B, P, self._future_len, 4).to(current_states.device) * 0.5  # 噪声
            ], dim=2).reshape(B, P, -1)

            # 初始状态约束:强制第一个时间步为当前状态
            def initial_state_constraint(xt, t, step):
                xt = xt.reshape(B, P, -1, 4)
                xt[:, :, 0, :] = current_states  # 约束当前状态
                return xt.reshape(B, P, -1)
            
            # 使用 DPM-Solver 进行快速采样
            x0 = dpm_sampler(
                self.dit,
                xT,
                other_model_params={
                    "cross_c": ego_neighbor_encoding, 
                    "route_lanes": route_lanes,
                    "neighbor_current_mask": neighbor_current_mask                            
                },
                dpm_solver_params={
                    "correcting_xt_fn": initial_state_constraint,
                },
                model_wrapper_params={
                    "classifier_fn": self._guidance_fn,  # 可选的 guidance
                    "classifier_kwargs": {...},
                    "guidance_scale": 0.5,
                    "guidance_type": "classifier" if self._guidance_fn is not None else "uncond"
                },
            )
            
            # 反归一化并去除当前状态,只保留未来轨迹
            x0 = self._state_normalizer.inverse(x0.reshape(B, P, -1, 4))[:, :, 1:]

            return {"prediction": x0}

4. 关键设计亮点

4.1 零初始化策略(Zero Initialization)

Diffusion_Planner_Decoder.initialize_weights() 中:

python 复制代码
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.decoder.dit.blocks:
    nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

# Zero-out output layers:
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].bias, 0)

原理:训练初期,adaLN 的输出为零,gate 参数也为零,DiT Block 相当于恒等映射(Identity)。这保证了:

  1. 网络在训练初期行为稳定
  2. 残差连接能够有效传递梯度
  3. 条件信息随着训练逐渐"注入"网络

4.2 多智能体联合建模

DiT 同时预测自车和周围车辆的轨迹:

  • 输入维度 P = 1(自车)+ predicted_neighbor_num(周围车)
  • 自注意力机制天然建模多车交互
  • 交叉注意力融合统一场景上下文

4.3 条件信息融合

DiT 融合三类条件信息:

  1. 时间步 t:通过 TimestepEmbedder 编码,控制去噪进度
  2. 路由信息:通过 RouteEncoder 编码,提供导航指引
  3. 场景上下文:通过 Cross-Attention 注入,提供环境感知

4.4 DPM-Solver 快速采样

推理时使用 DPM-Solver(扩散概率模型求解器),相比传统的 DDPM 采样:

  • 步数大幅减少(从 1000 步降至 10-20 步)
  • 保持生成质量的同时满足实时性要求
  • 支持初始状态约束和 classifier guidance

5. 总结

Diffusion Planner 中的 DiT 主干网络将扩散模型与 Transformer 架构深度结合,通过以下创新实现了高效的多智能体轨迹生成:

  1. adaLN-Zero 条件注入:精细控制每层特征变换,零初始化保证训练稳定
  2. Cross-Attention 场景融合:将编码器提取的丰富场景上下文注入解码器
  3. 多智能体联合建模:同时预测自车和周围车辆,通过自注意力建模交互
  4. 时间步与路由条件:融合扩散进度和导航信息,生成符合规划的轨迹
  5. DPM-Solver 快速推理:满足自动驾驶实时性要求

这种设计使得 DiT 能够在复杂的城市场景中生成多模态、交互感知且符合动力学约束的未来轨迹。


6. 核心内容总结

6.1 架构总览

复制代码
┌─────────────────────────────────────────────────────────────────────────┐
│                    Diffusion Transformer (DiT)                         │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  输入层                                                                 │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ x: [B, P, T×4] 加噪轨迹 (自车+周围车)                            │  │
│  │ t: [B] 扩散时间步                                                │  │
│  │ cross_c: [B, N, D] 场景上下文编码                                │  │
│  │ route_lanes: [B, 25, 20, 12] 路由车道信息                        │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                              │                                          │
│                              ▼                                          │
│  预处理层                                                               │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ preproj: [B,P,T×4]→[B,P,D] 输入投影                              │  │
│  │ agent_embedding: 添加智能体类型嵌入                               │  │
│  │ TimestepEmbedder(t): [B]→[B,D] 时间步编码                        │  │
│  │ RouteEncoder(route): [B,25,20,12]→[B,D] 路由编码                 │  │
│  │ y = route + t_embed 条件向量                                     │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                              │                                          │
│                              ▼                                          │
│  DiT Blocks × N                                                         │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ for each block:                                                   │  │
│  │   adaLN_modulation(y) → γ, β, gate                               │  │
│  │   x = γ * LN(x) + β                                              │  │
│  │   x = MHSA(x, x, x, attn_mask) + x  # 自注意力                   │  │
│  │   x = γ * LN(x) + β                                              │  │
│  │   x = MLP(x) + x                    # 前馈网络                   │  │
│  │   x = gate * x                      # 自适应门控                 │  │
│  │   x = MHCA(x, cross_c, cross_c) + x # 交叉注意力                 │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                              │                                          │
│                              ▼                                          │
│  输出层                                                                 │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ FinalLayer: adaLN + proj → [B,P,D]→[B,P,T×4]                    │  │
│  │ if model_type=="score":  output /= marginal_prob_std(t)          │  │
│  │ if model_type=="x_start": output = x                             │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                              │                                          │
│                              ▼                                          │
│  输出: [B, P, T, 4] 去噪轨迹预测                                        │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

6.2 核心组件总结

组件 类型 作用 关键特点
preproj Linear 输入投影 将轨迹特征映射到隐藏维度
agent_embedding Embedding 智能体类型编码 区分自车和周围车
TimestepEmbedder Sinusoidal 时间步编码 周期性位置编码,支持外推
RouteEncoder MLP-Mixer 路由信息编码 将车道线序列压缩为条件向量
DiT Block Transformer 核心编码单元 adaLN + MHSA + MHCA + MLP
adaLN_modulation MLP 自适应层归一化 根据条件向量动态调整 γ, β, gate
MHSA Multi-Head 自注意力 建模多智能体交互关系
MHCA Multi-Head 交叉注意力 融合场景上下文信息
FinalLayer Linear 输出投影 零初始化保证训练稳定

6.3 关键数学公式

前向扩散过程(VP-SDE)
复制代码
dx = -0.5 * beta_t * x dt + sqrt(beta_t) dW
mean = exp(-0.25 * (beta_max-beta_min) * t^2 - 0.5 * beta_min * t) * x
std = sqrt(1 - exp(2 * mean_log_coeff))
adaLN-Zero 调制
复制代码
h = linear(y)                    # 条件向量映射
gamma, beta, gate = split(h)     # 分割为三个参数
x = gamma * LN(x) + beta         # 自适应归一化
x = gate * x                     # 自适应门控
Score Matching 损失
复制代码
L = E_{t,x,ε} [||ε - σ_t * score(x_t, t, y)||^2]
其中 x_t = mean * x + std * ε
x_start Prediction 损失
复制代码
L = E_{t,x,ε} [||x - model(x_t, t, y)||^2]

6.4 训练与推理流程

训练流程
复制代码
1. 准备数据: ego_future, neighbors_future, route_lanes, scene_encoding
2. 采样时间步 t ~ Uniform(0, 1)
3. 前向扩散: x_t = mean * x_0 + std * ε
4. DiT 预测: pred = dit(x_t, t, cross_c, route_lanes)
5. 计算损失: MSE(pred, x_0) 或 MSE(pred * std, ε)
6. 反向传播优化
推理流程
复制代码
1. 初始化: x_T = current_state + noise
2. DPM-Solver 迭代 (10-20步):
   x_{t-1} = x_t + dt * (drift + guidance)
3. 约束当前状态: x[:, :, 0, :] = current_state
4. 输出未来轨迹: x_0[:, :, 1:]

6.5 关键设计决策

设计决策 技术实现 目的
零初始化 adaLN和输出层初始化为0 训练初期网络近似恒等映射,保证稳定
条件注入 adaLN调制每个Block 精细控制条件信息的注入位置和强度
双注意力机制 MHSA + MHCA 自注意力建模交互,交叉注意力融合上下文
两种训练模式 score / x_start score模式理论基础强,x_start模式更直观
DPM-Solver 快速采样算法 10-20步完成采样,满足实时性
初始状态约束 correcting_xt_fn 强制第一个时间步为当前观测状态

6.6 维度变化速览

复制代码
输入:
  x: [B, P, T×4]         # 轨迹序列
  t: [B]                  # 时间步
  cross_c: [B, N, D]     # 场景编码
  route_lanes: [B, 25, 20, 12]

预处理:
  x → preproj → [B, P, D]
  y = RouteEncoder(route) + TimestepEmbedder(t) → [B, D]

DiT Blocks:
  [B, P, D] → Block → [B, P, D] (重复 N 次)

输出:
  FinalLayer → [B, P, T×4] → reshape → [B, P, T, 4]

6.7 与其他模块的关系

复制代码
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│   Encoder       │     │     DiT         │     │   Loss Module   │
│ (场景编码)      │────▶│ (主干网络)       │────▶│ (损失计算)      │
└─────────────────┘     └─────────────────┘     └─────────────────┘
        │                       │                       │
        │                       ▼                       │
        │              ┌─────────────────┐              │
        │              │  SDE Module     │◀─────────────┘
        │              │ (扩散过程)      │
        │              └─────────────────┘
        │                       │
        │                       ▼
        │              ┌─────────────────┐
        │              │ Guidance Module │
        │              │ (碰撞避免)      │
        └─────────────▶│                 │
                       └─────────────────┘

6.8 性能与优化

优化策略 实现方式 效果
掩码计算 只处理有效token 减少无效计算,支持动态场景
DPM-Solver 快速采样算法 采样步数从1000降至10-20
零初始化 稳定训练初期 加速收敛,避免梯度爆炸
残差连接 每个Block都有残差 缓解梯度消失,支持深层网络
LayerNorm adaLN自适应归一化 加速训练,提高稳定性

精炼核心技术总结

一句话定位

DiT 是基于 Transformer 的扩散主干网络,通过 adaLN-Zero 实现条件注入和多智能体轨迹预测。

核心技术点

  • adaLN-Zero:自适应层归一化实现精细条件调制
  • 双注意力机制:MHSA+MHCA 分别建模自交互和上下文融合
  • 零初始化策略:训练初期网络近似恒等映射,保证稳定
  • DPM-Solver 集成:10-20步快速采样,满足实时性
  • x_start/Score 双模式:灵活选择训练目标

关键公式速查

adaLN-Zero 调制

y=s⋅LayerNorm(x)+b y = s \cdot \text{LayerNorm}(x) + b y=s⋅LayerNorm(x)+b

其中 s,bs, bs,b 由条件向量通过线性层生成。

DiT Block 输出

h′=MHSA(LN(h))+MHCA(LN(h),y)+h \text{h}' = \text{MHSA}(\text{LN}(h)) + \text{MHCA}(\text{LN}(h), y) + h h′=MHSA(LN(h))+MHCA(LN(h),y)+h

模块间一句话关联

DiT 接收 Encoder 的场景编码和 SDE 的时间步,输出预测轨迹到 Loss 模块计算损失,推理时与 Guidance 模块协同生成安全轨迹。