Diffusion-Planner:三、联合轨迹生成建模模块

联合轨迹生成建模模块原理讲解与代码注释

1. 概述

联合轨迹生成建模模块是 Diffusion Planner 的核心架构,负责同时生成自车和周围车辆的未来轨迹 。该模块采用 编码器-解码器架构,通过 Transformer 实现多智能体轨迹的联合预测。

核心特点:

  1. 多模态输入融合:统一处理邻居车辆、静态物体、车道线等场景信息
  2. 多智能体联合建模:通过自注意力机制建模车辆间的交互关系
  3. 碰撞避免机制:推理时引入 classifier guidance 实现安全约束

2. 整体架构

复制代码
┌─────────────────────────────────────────────────────────────────────────┐
│                        联合轨迹生成建模                                 │
├─────────────────────────────────────────────────────────────────────────┤
│  输入层                                                                 │
│  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐  ┌─────────────┐ │
│  │邻居车辆历史   │  │静态物体      │  │车道线        │  │路由信息     │ │
│  │(B,P,T,D)     │  │(B,P,D)       │  │(B,P,V,D)     │  │(B,P,V,D)    │ │
│  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘  └──────┬──────┘ │
│         │                 │                 │                 │        │
│         ▼                 ▼                 ▼                 │        │
│  ┌───────────────────────────────────────────┐                │        │
│  │           Encoder(编码器)                │                │        │
│  │  AgentFusionEncoder → StaticFusionEncoder │                │        │
│  │  → LaneFusionEncoder → FusionEncoder      │                │        │
│  └──────────────────────────┬────────────────┘                │        │
│                             ▼                                 │        │
│                   场景上下文编码 [B,N,D]                       │        │
│                             │                                 │        │
│                             ▼                                 ▼        │
│  ┌─────────────────────────────────────────────────────────────────────┐│
│  │                        Decoder(解码器)                            ││
│  │  DiT + RouteEncoder + TimestepEmbedder + Cross-Attention           ││
│  │                                                                     ││
│  │  输入: 加噪轨迹 [B,P,T,D] + 时间步 t + 路由编码 + 场景编码            ││
│  │  输出: 去噪轨迹 [B,P,T,D]                                           ││
│  └─────────────────────────────────────────────────────────────────────┘│
│                                         │                               │
│                                         ▼                               │
│                    预测轨迹 [B,1+Pn,T,4](自车 + 邻居)                   │
└─────────────────────────────────────────────────────────────────────────┘

3. 编码器模块(Encoder)

3.1 编码器主类

python 复制代码
class Encoder(nn.Module):
    """
    场景编码器:将多模态输入统一编码为场景上下文
    
    输入类型:
    1. 邻居车辆历史轨迹
    2. 静态物体(障碍物、红绿灯等)
    3. 车道线信息
    
    输出:统一的场景上下文编码 [B, N, D]
    
    N = agent_num + static_objects_num + lane_num
    """
    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim

        # 计算总 token 数:代理数 + 静态物体数 + 车道数
        self.token_num = config.agent_num + config.static_objects_num + config.lane_num

        # 三类编码器
        self.neighbor_encoder = AgentFusionEncoder(
            config.time_len, 
            drop_path_rate=config.encoder_drop_path_rate, 
            hidden_dim=config.hidden_dim, 
            depth=config.encoder_depth
        )
        self.static_encoder = StaticFusionEncoder(
            config.static_objects_state_dim, 
            drop_path_rate=config.encoder_drop_path_rate, 
            hidden_dim=config.hidden_dim
        )
        self.lane_encoder = LaneFusionEncoder(
            config.lane_len, 
            drop_path_rate=config.encoder_drop_path_rate, 
            hidden_dim=config.hidden_dim, 
            depth=config.encoder_depth
        )
    
        # 融合编码器:将三类编码融合为统一表示
        self.fusion = FusionEncoder(
            hidden_dim=config.hidden_dim, 
            num_heads=config.num_heads, 
            drop_path_rate=config.encoder_drop_path_rate, 
            depth=config.encoder_depth, 
            device=config.device
        )

        # 位置嵌入:编码 x, y, cos, sin, type
        self.pos_emb = nn.Linear(7, config.hidden_dim)

    def forward(self, inputs):
        """
        Args:
            inputs: 包含以下键的字典:
                - neighbor_agents_past: [B, P, T, D] 邻居车辆历史轨迹
                - static_objects: [B, P, D] 静态物体
                - lanes: [B, P, V, D] 车道线
                - lanes_speed_limit: [B, P, 1] 限速信息
                - lanes_has_speed_limit: [B, P, 1] 是否有限速
        
        Returns:
            encoder_outputs: 包含 'encoding' 键,值为 [B, token_num, hidden_dim]
        """
        encoder_outputs = {}

        # 提取输入
        neighbors = inputs['neighbor_agents_past']
        static = inputs['static_objects']
        lanes = inputs['lanes']
        lanes_speed_limit = inputs['lanes_speed_limit']
        lanes_has_speed_limit = inputs['lanes_has_speed_limit']

        B = neighbors.shape[0]

        # 分别编码三类输入
        encoding_neighbors, neighbors_mask, neighbor_pos = self.neighbor_encoder(neighbors)
        encoding_static, static_mask, static_pos = self.static_encoder(static)
        encoding_lanes, lanes_mask, lane_pos = self.lane_encoder(lanes, lanes_speed_limit, lanes_has_speed_limit)

        # 拼接编码结果
        encoding_input = torch.cat([encoding_neighbors, encoding_static, encoding_lanes], dim=1)

        # 处理位置嵌入
        encoding_pos = torch.cat([neighbor_pos, static_pos, lane_pos], dim=1).view(B * self.token_num, -1)
        # 只对有效token计算位置嵌入 ,避免浪费计算资源在无效输入上
        # encoding_mask 是一个 联合有效性掩码 ,用于标识所有输入实体(邻居车辆、静态物体、车道线)的有效性状态,
        # 在位置嵌入计算和自注意力机制中起到 过滤无效输入、优化计算效率 的关键作用。
        encoding_mask = torch.cat([neighbors_mask, static_mask, lanes_mask], dim=1).view(-1)
        encoding_pos = self.pos_emb(encoding_pos[~encoding_mask]) 
        
        # 填充无效位置
        encoding_pos_result = torch.zeros((B * self.token_num, self.hidden_dim), device=encoding_pos.device)
        encoding_pos_result[~encoding_mask] = encoding_pos

        # 添加位置嵌入
        encoding_input = encoding_input + encoding_pos_result.view(B, self.token_num, -1)

        # 融合编码
        encoder_outputs['encoding'] = self.fusion(encoding_input, encoding_mask.view(B, self.token_num))

        return encoder_outputs

3.2 邻居车辆编码器(AgentFusionEncoder)

python 复制代码
class AgentFusionEncoder(nn.Module):
    """
    邻居车辆编码器:将车辆历史轨迹编码为特征向量
    
    使用 MLP-Mixer 架构处理时序数据,捕捉车辆运动模式
    
    输入:[B, P, V, D],其中 D = 9(x, y, cos, sin, vx, vy, w, l, type(3))
    输出:[B, P, hidden_dim]
    """
    def __init__(self, time_len, drop_path_rate=0.3, hidden_dim=192, depth=3, 
                 tokens_mlp_dim=64, channels_mlp_dim=128):
        super().__init__()

        self._hidden_dim = hidden_dim
        self._channel = channels_mlp_dim

        # 类型嵌入:区分不同类型的车辆
        self.type_emb = nn.Linear(3, channels_mlp_dim)

        # Channel-wise MLP:处理每个时间步的特征
        self.channel_pre_project = Mlp(
            in_features=8+1,  # 8维状态 + 1维掩码
            hidden_features=channels_mlp_dim, 
            out_features=channels_mlp_dim, 
            act_layer=nn.GELU, 
            drop=0.
        )
        
        # Token-wise MLP:处理时间步之间的关系
        self.token_pre_project = Mlp(
            in_features=time_len, 
            hidden_features=tokens_mlp_dim, 
            out_features=tokens_mlp_dim, 
            act_layer=nn.GELU, 
            drop=0.
        )

        # MLP-Mixer 块堆叠
        self.blocks = nn.ModuleList([
            MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate) 
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(channels_mlp_dim)
        self.emb_project = Mlp(
            in_features=channels_mlp_dim, 
            hidden_features=hidden_dim, 
            out_features=hidden_dim, 
            act_layer=nn.GELU, 
            drop=drop_path_rate
        )


    def forward(self, x):
        '''
        Args:
            x: [B, P, V, D],其中 D = 9(x, y, cos, sin, vx, vy, w, l, type(3))
        
        Returns:
            encoding: [B, P, hidden_dim] 编码结果
            mask_p: [B, P] 有效掩码
            pos: [B, P, 7] 位置信息
        '''
        # 提取类型信息(最后3维)
        neighbor_type = x[:, :, -1, 8:]
        # 提取状态信息(前8维)
        x = x[..., :8]

        # 提取位置信息(最后一个时间步的 x, y, cos, sin)
        pos = x[:, :, -1, :7].clone()
        # 标记为邻居类型 [1, 0, 0]
        pos[..., -3:] = 0.0
        pos[..., -3] = 1.0
        
        B, P, V, _ = x.shape
        
        # 创建掩码:标记无效的时间步和无效的车辆
        mask_v = torch.sum(torch.ne(x[..., :8], 0), dim=-1).to(x.device) == 0
        mask_p = torch.sum(~mask_v, dim=-1) == 0
        
        # 添加掩码通道
        x = torch.cat([x, (~mask_v).float().unsqueeze(-1)], dim=-1)
        x = x.view(B * P, V, -1)

        # 过滤无效车辆
        valid_indices = ~mask_p.view(-1) 
        x = x[valid_indices] 

        # MLP-Mixer 处理
        x = self.channel_pre_project(x)
        x = x.permute(0, 2, 1)
        x = self.token_pre_project(x)
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)  

        # 全局平均池化:将时间序列压缩为单个向量
        x = torch.mean(x, dim=1)

        # 添加类型嵌入
        neighbor_type = neighbor_type.view(B * P, -1)
        neighbor_type = neighbor_type[valid_indices]
        type_embedding = self.type_emb(neighbor_type)
        x = x + type_embedding

        # 投影到隐藏维度
        x = self.emb_project(self.norm(x))

        # 填充结果
        x_result = torch.zeros((B * P, x.shape[-1]), device=x.device)
        x_result[valid_indices] = x  
        
        return x_result.view(B, P, -1), mask_p.reshape(B, -1), pos.view(B, P, -1)
3.2.1 输入输出示例详解

假设输入参数:

  • B = 2(批次大小)
  • P = 3(每场景最多3辆车)
  • V = 5(时间步)
  • D = 9(特征维度)

输入数据示例:

python 复制代码
# x: [B=2, P=3, V=5, D=9]
# 特征顺序: [x, y, cos, sin, vx, vy, w, l, type(3维独热)]
x = torch.tensor([
    # 场景1: 2辆车有效
    [
        # 车辆1: 小汽车 [1,0,0],沿x轴正方向行驶
        [[10.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
         [10.5, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
         [11.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
         [11.5, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
         [12.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0]],
        # 车辆2: 卡车 [0,1,0],沿y轴正方向行驶
        [[15.0, 25.0, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
         [15.0, 25.3, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
         [15.0, 25.6, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
         [15.0, 25.9, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
         [15.0, 26.2, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0]],
        # 车辆3: 无效(全零)
        [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...]
    ],
    # 场景2: 1辆车有效
    [
        # 车辆1: 小汽车 [1,0,0],沿45度方向行驶
        [[5.0, 10.0, 0.707, 0.707, 7.0, 7.0, 0.0, 4.0, 1,0,0], ...],
        # 车辆2-3: 无效
        [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...],
        [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...]
    ]
])

处理流程:

| 阶段 | 操作 | 形状变化 | 说明 |

| 1 | 特征分离 | [2,3,5,9][2,3,5,8] | 提取前8维状态,分离类型信息 |

提取类型信息(最后3维)

neighbor_type = x:, :, -1, 8:

shape: 2, 3, 3

场景1: \[1,0,0, 0,1,0, 0,0,0]

场景2: \[1,0,0, 0,0,0, 0,0,0]

提取状态信息(前8维)

x = x..., :8

shape: 2, 3, 5, 8

| 2 | 位置提取 | -→[2,3,7] | 取最后时间步的x,y,cos,sin,vx,vy,w |

提取最后一个时间步的位置信息

pos = x:, :, -1, :7.clone()

shape: 2, 3, 7

场景1: \[12.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, # 车辆1最后时刻

16.2, 25.0, 0.0, 1.0, 3.0, 0.0, 0.0, # 车辆2最后时刻

0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 车辆3无效

标记为邻居类型 1, 0, 0

pos..., -3: = 0.0

pos..., -3 = 1.0

现在 pos 最后3维都是 1, 0, 0,表示这是邻居车辆

| 3 | 掩码生成 | -→[2,3] | mask_p=\[F,F,T,F,T,T] |

B, P, V, _ = x.shape # B=2, P=3, V=5

mask_v: 标记无效时间步 B, P, V

mask_v = torch.sum(torch.ne(x..., :8, 0), dim=-1) == 0

场景1车辆1: False, False, False, False, False # 所有时间步有效

场景1车辆3: True, True, True, True, True # 所有时间步无效

mask_p: 标记无效车辆 B, P

mask_p = torch.sum(~mask_v, dim=-1) == 0

#场景1: False, False, True # 车辆3无效

场景2: False, True, True # 车辆2、3无效

| 4 | 展平过滤 | [2,3,5,9][3,5,9] | 保留3辆有效车 |

掩码通道添加与展平

添加掩码通道(第9维)

x = torch.cat(x, (\~mask_v).float().unsqueeze(-1), dim=-1)

shape: 2, 3, 5, 9

最后一维新增掩码:1=有效时间步,0=无效时间步

展平批次和车辆维度

x = x.view(B * P, V, -1)

shape: 6, 5, 9 (2*3=6辆车)

过滤无效车辆

valid_indices = ~mask_p.view(-1)

False, False, True, False, True, True

对应6辆车中:第0、1、3辆有效(场景1车辆1、2;场景2车辆1)

x = xvalid_indices

shape: 3, 5, 9 (只保留3辆有效车)

| 5 | MLP-Mixer | [3,5,9][3,64] | 时间序列编码+全局池化 |

假设 hidden_dim=64

x = self.channel_pre_project(x) # 3, 5, 93, 5, 64

x = x.permute(0, 2, 1) # 3, 5, 643, 64, 5

x = self.token_pre_project(x) # 3, 64, 53, 64, 5

x = x.permute(0, 2, 1) # 3, 64, 53, 5, 64

for block in self.blocks:

x = block(x) # MLP-Mixer 块处理

全局平均池化:压缩时间维度

x = torch.mean(x, dim=1)

shape: 3, 64 (时间序列 → 单向量)

| 6 | 类型嵌入 | [3,64][3,64] | 添加车辆类型信息 |

类型嵌入

neighbor_type = neighbor_type.view(B * P, -1) # 6, 3

neighbor_type = neighbor_typevalid_indices # 3, 3

type_embedding = self.type_emb(neighbor_type) # 3, 64

x = x + type_embedding # 添加类型信息

最终投影

x = self.emb_project(self.norm(x)) # 3, 643, hidden_dim

| 7 | 结果填充 | [3,64][2,3,64] | 恢复原始形状,无效位置填零 |

x_result = torch.zeros((B * P, x.shape-1), device=x.device)

shape: 6, hidden_dim

x_resultvalid_indices = x

将有效车辆的编码放入对应位置

无效车辆位置保持为零向量

**恢复原始形状

encoding = x_result.view(B, P, -1)

shape: 2, 3, hidden_dim

输出结果:

python 复制代码
# encoding: [2, 3, 64]
# 场景1车辆3、场景2车辆2-3 位置为零向量

# mask_p: [2, 3]
# [[False, False, True],   # 场景1:车辆3无效
#  [False, True, True]]    # 场景2:车辆2-3无效

# pos: [2, 3, 7]
# 包含最后时刻位置+速度+朝向,最后3维为[1,0,0]表示邻居类型

数据流向图:

复制代码
输入: [2,3,5,9]
           │
           ▼
    ┌──────────────┐
    │ 特征分离     │ → neighbor_type [2,3,3]
    └──────┬───────┘
           │ x [2,3,5,8]
           ▼
    ┌──────────────┐
    │ 位置提取     │ → pos [2,3,7]
    └──────┬───────┘
           │
           ▼
    ┌──────────────┐
    │ 掩码生成     │ → mask_p [2,3]
    └──────┬───────┘
           │
           ▼
    ┌──────────────┐
    │ 过滤无效车   │ → [3,5,9]
    └──────┬───────┘
           │
           ▼
    ┌──────────────┐
    │ MLP-Mixer    │ → [3,64]
    └──────┬───────┘
           │
           ▼
    ┌──────────────┐
    │ 类型嵌入     │ → [3,64]
    └──────┬───────┘
           │
           ▼
    ┌──────────────┐
    │ 结果填充     │ → [2,3,64]
    └──────────────┘

3.3 静态物体编码器(StaticFusionEncoder)

python 复制代码
class StaticFusionEncoder(nn.Module):
    """
    静态物体编码器:将静态物体(障碍物、红绿灯等)编码为特征向量
    
    输入:[B, P, D],其中 D = 11(x, y, cos, sin, w, l, type(4))
    输出:[B, P, hidden_dim]
    """
    def __init__(self, dim, drop_path_rate=0.3, hidden_dim=192, device='cuda'):
        super().__init__()
        self._hidden_dim = hidden_dim
        self.projection = Mlp(
            in_features=dim, 
            hidden_features=hidden_dim, 
            out_features=hidden_dim, 
            act_layer=nn.GELU, 
            drop=drop_path_rate
        )

    def forward(self, x):
        '''
        Args:
            x: [B, P, D],其中 D = 11(x, y, cos, sin, w, l, type(4))
        
        Returns:
            encoding: [B, P, hidden_dim]
            mask_p: [B, P]
            pos: [B, P, 7]
        ''' 
        B, P, _ = x.shape

        # 提取位置信息
        pos = x[:, :, :7].clone()
        # 标记为静态物体类型 [0, 1, 0]
        pos[..., -3:] = 0.0
        pos[..., -2] = 1.0

        # 初始化结果
        x_result = torch.zeros((B * P, self._hidden_dim), device=x.device)

        # 创建掩码
        mask_p = torch.sum(torch.ne(x[..., :10], 0), dim=-1).to(x.device) == 0
        valid_indices = ~mask_p.view(-1) 

        # 处理有效静态物体
        if valid_indices.sum() > 0:
            x = x.view(B * P, -1)
            x = x[valid_indices]
            x = self.projection(x)
            x_result[valid_indices] = x

        return x_result.view(B, P, -1), mask_p.view(B, P), pos.view(B, P, -1)

3.4 车道线编码器(LaneFusionEncoder)

python 复制代码
class LaneFusionEncoder(nn.Module):
    """
    车道线编码器:将车道线信息编码为特征向量
    
    输入:车道线点序列 + 限速信息 + 红绿灯状态
    输出:车道线特征表示
    
    使用 MLP-Mixer 处理车道线的空间结构
    """
    def __init__(self, lane_len, drop_path_rate=0.3, hidden_dim=192, depth=3, 
                 tokens_mlp_dim=64, channels_mlp_dim=128):
        super().__init__()

        self._lane_len = lane_len
        self._channel = channels_mlp_dim

        # 限速嵌入
        self.speed_limit_emb = nn.Linear(1, channels_mlp_dim)
        self.unknown_speed_emb = nn.Embedding(1, channels_mlp_dim)
        # 红绿灯嵌入
        self.traffic_emb = nn.Linear(4, channels_mlp_dim)

        self.channel_pre_project = Mlp(
            in_features=8, 
            hidden_features=channels_mlp_dim, 
            out_features=channels_mlp_dim, 
            act_layer=nn.GELU, 
            drop=0.
        )
        self.token_pre_project = Mlp(
            in_features=lane_len, 
            hidden_features=tokens_mlp_dim, 
            out_features=tokens_mlp_dim, 
            act_layer=nn.GELU, 
            drop=0.
        )

        self.blocks = nn.ModuleList([
            MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate) 
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(channels_mlp_dim)
        self.emb_project = Mlp(
            in_features=channels_mlp_dim, 
            hidden_features=hidden_dim, 
            out_features=hidden_dim, 
            act_layer=nn.GELU, 
            drop=drop_path_rate
        )

    def forward(self, x, speed_limit, has_speed_limit):
        '''
        Args:
            x: [B, P, V, D] 车道线点,D = 12(x, y, x'-x, y'-y, x_left-x, y_left-y, x_right-x, y_right-y, traffic(4))
            speed_limit: [B, P, 1] 限速值
            has_speed_limit: [B, P, 1] 是否有限速
        
        Returns:
            encoding: [B, P, hidden_dim]
            mask_p: [B, P]
            pos: [B, P, 7]
        '''
        # 提取红绿灯信息
        traffic = x[:, :, 0, 8:]
        x = x[..., :8]

        # 提取中间点的位置作为车道线位置
        pos = x[:, :, int(self._lane_len / 2), :7].clone()
        # 计算朝向角
        heading = torch.atan2(pos[..., 3], pos[..., 2])
        pos[..., 2] = torch.cos(heading)
        pos[..., 3] = torch.sin(heading)
        # 标记为车道类型 [0, 0, 1]
        pos[..., -3:] = 0.0
        pos[..., -1] = 1.0

        B, P, V, _ = x.shape
        mask_v = torch.sum(torch.ne(x[..., :8], 0), dim=-1).to(x.device) == 0
        mask_p = torch.sum(~mask_v, dim=-1) == 0
        x = x.view(B * P, V, -1)

        valid_indices = ~mask_p.view(-1) 
        x = x[valid_indices] 

        # MLP-Mixer 处理
        x = self.channel_pre_project(x)
        x = x.permute(0, 2, 1)
        x = self.token_pre_project(x)
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)  

        # 全局平均池化
        x = torch.mean(x, dim=1)

        # 处理限速信息
        speed_limit = speed_limit.view(B * P, 1)
        has_speed_limit = has_speed_limit.view(B * P, 1)
        traffic = traffic.view(B * P, -1)

        has_speed_limit = has_speed_limit[valid_indices].squeeze(-1)
        speed_limit = speed_limit[valid_indices].squeeze(-1)
        speed_limit_embedding = torch.zeros((speed_limit.shape[0], self._channel), device=x.device)

        # 有速限制的车道
        if has_speed_limit.sum() > 0:
            speed_limit_with_limit = self.speed_limit_emb(speed_limit[has_speed_limit].unsqueeze(-1))
            speed_limit_embedding[has_speed_limit] = speed_limit_with_limit

        # 无速限制的车道
        if (~has_speed_limit).sum() > 0:
            speed_limit_no_limit = self.unknown_speed_emb.weight.expand(
                (~has_speed_limit).sum().item(), -1
            )
            speed_limit_embedding[~has_speed_limit] = speed_limit_no_limit

        # 处理红绿灯信息
        traffic = traffic[valid_indices]
        traffic_light_embedding = self.traffic_emb(traffic)

        # 融合所有信息
        x = x + speed_limit_embedding + traffic_light_embedding
        x = self.emb_project(self.norm(x))

        x_result = torch.zeros((B * P, x.shape[-1]), device=x.device)
        x_result[valid_indices] = x  
        
        return x_result.view(B, P, -1), mask_p.reshape(B, -1), pos.view(B, P, -1)

3.5 融合编码器(FusionEncoder)

python 复制代码
class FusionEncoder(nn.Module):
    """
    融合编码器:将不同类型的编码融合为统一的场景表示
    
    使用自注意力机制让不同类型的实体之间进行交互
    """
    def __init__(self, hidden_dim=192, num_heads=6, drop_path_rate=0.3, depth=3, device='cuda'):
        super().__init__()

        dpr = drop_path_rate

        # 堆叠自注意力块
        self.blocks = nn.ModuleList(
            [SelfAttentionBlock(hidden_dim, num_heads, dropout=dpr) for i in range(depth)]
        )

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask):
        """
        Args:
            x: [B, N, D] 拼接后的编码
            mask: [B, N] 有效掩码
        
        Returns:
            [B, N, D] 融合后的场景编码
        """
        # 确保第一个 token(自车)始终有效
        mask[:, 0] = False

        for b in self.blocks:
            x = b(x, mask)

        return self.norm(x)

4. Diffusion Planner 主模型

python 复制代码
class Diffusion_Planner(nn.Module):
    """
    Diffusion Planner 主模型:整合编码器和解码器
    
    职责:
    - 训练时:处理输入,计算损失
    - 推理时:生成多智能体轨迹
    """
    def __init__(self, config):
        super().__init__()

        self.encoder = Diffusion_Planner_Encoder(config)
        self.decoder = Diffusion_Planner_Decoder(config)

    @property
    def sde(self):
        """返回 SDE 对象,用于损失计算和采样"""
        return self.decoder.decoder.sde
    
    def forward(self, inputs):
        """
        Args:
            inputs: 包含所有输入数据的字典
        
        Returns:
            encoder_outputs: 编码器输出
            decoder_outputs: 解码器输出(预测轨迹)
        """
        encoder_outputs = self.encoder(inputs)
        decoder_outputs = self.decoder(encoder_outputs, inputs)

        return encoder_outputs, decoder_outputs


class Diffusion_Planner_Encoder(nn.Module):
    """
    编码器包装类:包含权重初始化逻辑
    """
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.initialize_weights()

    def initialize_weights(self):
        """
        权重初始化策略:
        - Linear: Xavier uniform
        - LayerNorm: 权重=1,偏置=0
        - Embedding: 正态分布 N(0, 0.02)
        """
        def _basic_init(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
        self.apply(_basic_init)

        # 特殊初始化嵌入层
        nn.init.normal_(self.encoder.pos_emb.weight, std=0.02)
        nn.init.normal_(self.encoder.neighbor_encoder.type_emb.weight, std=0.02)
        nn.init.normal_(self.encoder.lane_encoder.speed_limit_emb.weight, std=0.02)
        nn.init.normal_(self.encoder.lane_encoder.traffic_emb.weight, std=0.02)

    def forward(self, inputs):
        return self.encoder(inputs)


class Diffusion_Planner_Decoder(nn.Module):
    """
    解码器包装类:包含权重初始化逻辑
    """
    def __init__(self, config):
        super().__init__()
        self.decoder = Decoder(config)
        self.initialize_weights()

    def initialize_weights(self):
        """
        解码器权重初始化:
        - 基础初始化同编码器
        - adaLN 调制层零初始化(保证训练稳定性)
        - 输出层零初始化(训练初期接近恒等映射)
        """
        def _basic_init(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
        self.apply(_basic_init)

        # 时间步嵌入 MLP 初始化
        nn.init.normal_(self.decoder.dit.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.decoder.dit.t_embedder.mlp[2].weight, std=0.02)

        # 关键:adaLN 调制层零初始化
        for block in self.decoder.dit.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # 关键:输出层零初始化
        nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.decoder.dit.final_layer.proj[-1].weight, 0)
        nn.init.constant_(self.decoder.dit.final_layer.proj[-1].bias, 0)

    def forward(self, encoder_outputs, inputs):
        return self.decoder(encoder_outputs, inputs)

5. 碰撞避免引导机制(Collision Guidance)

python 复制代码
def collision_guidance_fn(x, t, cond, inputs, *args, **kwargs) -> torch.Tensor:
    """
    碰撞避免引导函数:在推理时引导轨迹生成远离碰撞
    
    原理:Classifier Guidance
    通过计算碰撞奖励的梯度,将其注入扩散过程,使生成的轨迹更安全
    
    Args:
        x: [B, P, T+1, 4] 当前扩散状态(轨迹)
        t: [B, 1] 当前时间步
        cond: 条件信息
        inputs: 输入数据字典
    
    Returns:
        reward: [B] 引导奖励(用于调整采样方向)
    """
    B, P, T, _ = x.shape
    neighbor_current_mask = inputs["neighbor_current_mask"] # [B, Pn]
    
    x = x.reshape(B, P, -1, 4)
    
    # 只在特定时间步范围内应用引导(t < 0.1 且 t > 0.005)
    # 这是因为在扩散后期(t接近0),轨迹已基本确定
    mask_diffusion_time = (t < 0.1 and t > 0.005)
    x = torch.where(mask_diffusion_time, x, x.detach())
    
    # 归一化朝向向量
    x = torch.cat([
        x[:, :, :, :2], 
        x[:, :, :, 2:].detach() / torch.norm(x[:, :, :, 2:].detach(), dim=-1, keepdim=True)
    ], dim=-1)
    
    # 提取自车预测(跳过当前时刻)
    ego_pred = x[:, :1, 1:, :]
    # 调整自车位置到后轴中心(COG_TO_REAR = 1.67m)
    cos_h, sin_h = ego_pred[..., 2:3], ego_pred[..., 3:4]
    ego_pred = torch.cat([
        ego_pred[..., 0:1] + cos_h * COG_TO_REAR, 
        ego_pred[..., 1:2] + sin_h * COG_TO_REAR, 
        ego_pred[..., 2:]
    ], dim=-1)
    
    # 提取邻居预测
    neighbors_pred = x[:, 1:, 1:, :]
    
    B, Pn, T, _ = neighbors_pred.shape

    # 拼接自车和邻居预测(邻居使用detach避免梯度回传)
    predictions = torch.cat([ego_pred, neighbors_pred.detach()], dim=1)
    
    # 获取车辆尺寸(自车固定尺寸,邻居从输入获取)
    lw = torch.cat([
        torch.tensor(ego_size, device=predictions.device)[None, None, :].repeat(B, 1, 1),
        inputs["neighbor_agents_past"][:, :Pn, -1, [7, 6]]
    ], dim=1)
    
    # 构建 bounding box(位置 + 尺寸)
    bbox = torch.cat([
        predictions,
        lw.unsqueeze(2).expand(-1, -1, T, -1) + INFLATION  # INFLATION = 1.0m 安全膨胀
    ], dim=-1)
    
    # 将中心表示转换为四个角点
    bbox = center_rect_to_points(bbox.reshape(-1, 6)).reshape(B, Pn + 1, T, 4, 2)
    
    # 计算自车与每个邻居的碰撞距离
    ego_bbox = bbox[:, :1, :, :, :].expand(-1, Pn, -1, -1, -1)[~neighbor_current_mask].reshape(-1, 4, 2)
    neighbor_bbox = bbox[:, 1:, :, :, :][~neighbor_current_mask].reshape(-1, 4, 2)
    
    # 计算有符号距离(SDF)
    distances = batch_signed_distance_rect(ego_bbox, neighbor_bbox)
    
    # 裁剪距离(CLIP_DISTANCE = 1.0m)
    clip_distances = torch.maximum(1 - distances / CLIP_DISTANCE, torch.tensor(0.0, device=distances.device))
            
    # 计算碰撞惩罚奖励
    reward = - (
        torch.sum(clip_distances[clip_distances > 1]) / (torch.sum((clip_distances[clip_distances > 1].detach() > 0).float()) + 1e-5) +
        torch.sum(clip_distances[clip_distances <= 1]) / (torch.sum((clip_distances[clip_distances <= 1].detach() > 0).float()) + 1e-5)
    ).exp()
    
    # 计算奖励对轨迹的梯度(仅对自车位置)
    x_aux = torch.autograd.grad(
        reward.sum(), 
        x, 
        retain_graph=True, 
        allow_unused=True
    )[0][:, 0, :, :2]
    
    # 坐标变换矩阵(将局部坐标转换为全局坐标)
    T_total = T + 1
    x_mat = torch.einsum("btd,nd->btn", x[:, 0, :, 2:], 
        torch.tensor([[1., 0], [0, 1], [0, -1], [1, 0]], device=x.device)
    ).reshape(B, T_total, 2, 2)
    
    # 应用坐标变换
    x_aux = torch.einsum("btij,btj->bti", x_mat, x_aux)
    
    # 时间平滑:使用高斯核卷积
    x_aux = torch.stack([  
        torch.einsum("bt,it->bi", x_aux[..., 0], 
            torch.tril((-torch.linspace(0, 1, T_total, device=x.device)).exp().unsqueeze(0).repeat(T_total, 1))) * 0,
        F.conv1d(
            F.pad(x_aux[:, None, :, 1], (10, 10), mode='replicate'), 
            torch.ones(1, 1, 21, device=x.device) * \
            (- torch.linspace(-2, 2, 21, device=x.device) ** 2 / 4).exp()
        )[:, 0] * 1.0
    ], dim=2)
    
    # 反向坐标变换
    x_aux = torch.einsum("btji,btj->bti", x_mat, x_aux)
    
    # 计算最终奖励
    reward = torch.sum(x_aux.detach() * x[:, 0, :, :2], dim=(1, 2))
    
    # 乘以引导系数(3.0)
    return 3.0 * reward


def batch_signed_distance_rect(rect1, rect2):
    """
    计算两个矩形之间的有符号距离(Signed Distance Function)
    
    使用分离轴定理(Separating Axis Theorem)
    - 负距离表示重叠(碰撞)
    - 正距离表示分离
    
    Args:
        rect1: [B, 4, 2] 第一个矩形的四个角点
        rect2: [B, 4, 2] 第二个矩形的四个角点
    
    Returns:
        [B] 有符号距离
    """
    B, _, _ = rect1.shape
    # 计算四个分离轴(两个矩形的四条边的法向量)
    norm_vec = torch.stack([
        rect1[:, 0] - rect1[:, 1], 
        rect1[:, 1] - rect1[:, 2], 
        rect2[:, 0] - rect2[:, 1], 
        rect2[:, 1] - rect2[:, 2]
    ], dim=1)
    norm_vec = norm_vec / torch.norm(norm_vec, dim=2, keepdim=True)
    
    # 投影到分离轴上
    proj1 = torch.einsum('bij,bkj->bik', norm_vec, rect1)
    proj1_min, proj1_max = proj1.min(dim=2)[0], proj1.max(dim=2)[0]
    
    proj2 = torch.einsum('bij,bkj->bik', norm_vec, rect2)
    proj2_min, proj2_max = proj2.min(dim=2)[0], proj2.max(dim=2)[0]
    
    # 计算重叠情况
    overlap = torch.cat([proj1_min - proj2_max, proj2_min - proj1_max], dim=1)
    
    # 确定是否碰撞及距离
    positive_distance = torch.where(overlap < 0, 1e5, overlap)
    is_overlap = (overlap < 0).all(dim=1)
    distance = torch.where(
        is_overlap, 
        overlap.max(dim=1).values,  # 碰撞时取最大重叠深度(负值)
        positive_distance.min(dim=1).values  # 分离时取最小距离
    )   
    
    return distance


def center_rect_to_points(rect):
    """
    将中心表示的矩形转换为四个角点
    
    Args:
        rect: [B, 6] (x, y, cos_h, sin_h, l, w)
    
    Returns:
        [B, 4, 2] 四个角点坐标
    """
    B, _ = rect.shape
    xy, cos_h, sin_h, lw = rect[:, :2], rect[:, 2], rect[:, 3], rect[:, 4:]
    
    # 旋转矩阵
    rot = torch.stack([cos_h, -sin_h, sin_h, cos_h], dim=1).reshape(-1, 2, 2)
    # 四个角点相对于中心的偏移
    lw = torch.einsum('bj,ij->bij', lw, 
        torch.tensor([[1., 1], [-1, 1], [-1, -1], [1, -1]], device=lw.device) / 2)
    # 应用旋转
    lw = torch.einsum('bij,bkj->bik', lw, rot)
    
    # 添加中心偏移
    rect = xy[:, None, :] + lw
    
    return rect

6. 关键设计亮点

6.1 多模态输入统一编码

编码器同时处理三类输入:

  • 邻居车辆:使用 MLP-Mixer 处理时序轨迹
  • 静态物体:简单的 MLP 投影
  • 车道线:结合限速和红绿灯信息

通过位置嵌入区分不同类型:[1,0,0] 表示邻居,[0,1,0] 表示静态物体,[0,0,1] 表示车道线。

6.2 多智能体联合建模

DiT 解码器同时处理自车和邻居车辆:

  • 输入维度 P = 1 + predicted_neighbor_num
  • 自注意力机制建模车辆间的交互
  • 智能体类型嵌入区分自车和邻居

6.3 零初始化策略

解码器的 adaLN 调制层和输出层采用零初始化:

  • 训练初期模型接近恒等映射
  • 条件信息逐渐注入
  • 保证训练稳定性

6.4 Collision Guidance

推理时使用 Classifier Guidance 实现碰撞避免:

  • 基于分离轴定理计算碰撞距离
  • 通过梯度引导轨迹远离碰撞
  • 只在特定时间步应用(平衡探索与安全)

7. 配置参数总结

参数 默认值 说明
agent_num 25 最大邻居车辆数
static_objects_num 30 最大静态物体数
lane_num 70 最大车道线数
time_len 20 历史时间步数
lane_len 20 每条车道的点数
hidden_dim 192 隐藏层维度
num_heads 6 注意力头数
encoder_depth 3 编码器层数
decoder_depth 6 解码器层数
encoder_drop_path_rate 0.3 编码器 drop path 率
decoder_drop_path_rate 0.2 解码器 drop path 率

8. 核心内容总结

8.1 架构总览

复制代码
┌─────────────────────────────────────────────────────────────────────────┐
│                    联合轨迹生成建模模块                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌──────────────────────────────────────────────────────────────────┐   │
│  │                        Encoder                                   │   │
│  │                                                                  │   │
│  │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │   │
│  │  │AgentFusion   │  │StaticFusion  │  │LaneFusion    │           │   │
│  │  │ Encoder      │  │ Encoder      │  │ Encoder      │           │   │
│  │  │ (邻居车辆)    │  │ (静态物体)    │  │ (车道线)      │           │   │
│  │  │ MLP-Mixer    │  │ MLP          │  │ MLP-Mixer    │           │   │
│  │  └──────┬───────┘  └──────┬───────┘  └──────┬───────┘           │   │
│  │         │                 │                 │                   │   │
│  │         └────────┬────────┴────────┬────────┘                   │   │
│  │                  ▼                ▼                             │   │
│  │           ┌──────────────┐                                       │   │
│  │           │ FusionEncoder│                                       │   │
│  │           │ (自注意力融合)│                                      │   │
│  │           └──────┬───────┘                                       │   │
│  │                  │                                               │   │
│  │           输出: encoding [B, N, D]                               │   │
│  └──────────────────────────────────────────────────────────────────┘   │
│                                    │                                    │
│                                    ▼                                    │
│  ┌──────────────────────────────────────────────────────────────────┐   │
│  │                    Diffusion_Planner                             │   │
│  │                                                                  │   │
│  │  ┌─────────────────┐     ┌─────────────────┐                     │   │
│  │  │ Diffusion_Planner│     │ Diffusion_Planner│                     │   │
│  │  │ _Encoder        │     │ _Decoder        │                     │   │
│  │  │ (编码器包装)     │     │ (解码器包装)     │                     │   │
│  │  │ +权重初始化      │     │ +零初始化策略    │                     │   │
│  │  └────────┬────────┘     └────────┬────────┘                     │   │
│  │           │                       │                              │   │
│  │           │ encoding              │ DiT + DPM-Solver             │   │
│  │           └───────────┬───────────┘                              │   │
│  │                       ▼                                          │   │
│  │              ┌─────────────────┐                                 │   │
│  │              │   Collision     │                                 │   │
│  │              │   Guidance      │                                 │   │
│  │              │  (推理时可选)    │                                 │   │
│  │              └─────────────────┘                                 │   │
│  │                                                                  │   │
│  └──────────────────────────────────────────────────────────────────┘   │
│                                    │                                    │
│                                    ▼                                    │
│                          输出: 联合轨迹 [B, P, T, 4]                     │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

8.2 核心组件总结

组件 类型 职责 关键特点
Encoder nn.Module 统一编码多模态输入 整合三类子编码器
AgentFusionEncoder nn.Module 邻居车辆编码 MLP-Mixer处理时序
StaticFusionEncoder nn.Module 静态物体编码 MLP投影
LaneFusionEncoder nn.Module 车道线编码 含限速和红绿灯
FusionEncoder nn.Module 多模态特征融合 自注意力机制
Diffusion_Planner nn.Module 主模型入口 编码器+解码器
Diffusion_Planner_Encoder nn.Module 编码器包装 权重初始化
Diffusion_Planner_Decoder nn.Module 解码器包装 零初始化策略
collision_guidance_fn function 碰撞避免引导 Classifier Guidance
batch_signed_distance_rect function 碰撞距离计算 分离轴定理

8.3 关键数据流

编码器数据流
复制代码
输入:
  neighbors: [B, P, V, 9]     # 邻居车辆历史轨迹
  static: [B, S, D]           # 静态物体
  lanes: [B, L, V, D]         # 车道线

处理:
  neighbors → AgentFusionEncoder → [B, P, D], mask_p, pos
  static    → StaticFusionEncoder → [B, S, D], mask_p, pos
  lanes     → LaneFusionEncoder   → [B, L, D], mask_p, pos

融合:
  concat([neighbors, static, lanes]) → [B, N, D]
  FusionEncoder(encoding, mask) → [B, N, D]

输出:
  encoding: [B, N, D]
  mask: [B, N]
解码器数据流
复制代码
输入:
  x_t: [B, P, T×4]            # 加噪轨迹
  t: [B]                       # 时间步
  cross_c: [B, N, D]          # 场景编码
  route_lanes: [B, 25, 20, 12]

处理:
  preproj → [B, P, D]
  + agent_embedding
  + TimestepEmbedder(t)
  + RouteEncoder(route_lanes)

DiT Blocks × N:
  MHSA → MLP → MHCA → ...

输出:
  x_start / score: [B, P, T, 4]

8.4 碰撞避免引导机制

复制代码
推理时:
  x_t → compute_bboxes() → ego_bbox, neighbor_bbox
           │
           ▼
  batch_signed_distance_rect() → distances
           │
           ▼
  clip_distances = max(1 - distances/CLIP_DISTANCE, 0)
           │
           ▼
  reward = -exp(mean(clip_distances))
           │
           ▼
  x_aux = grad(reward, x)  # 计算梯度
           │
           ▼
  x_aux → coordinate_transform → time_smoothing → coordinate_transform_inv
           │
           ▼
  guidance = 3.0 * sum(x_aux * x[:, 0, :, :2])

8.5 关键设计决策

设计决策 技术实现 目的
多模态统一编码 三类子编码器+FusionEncoder 整合邻居、静态物体、车道线信息
MLP-Mixer架构 AgentFusionEncoder、LaneFusionEncoder 高效处理时序数据,无需注意力机制
动态掩码 mask_p、encoding_mask 处理可变数量输入实体
位置嵌入 pos_emb + type标记 区分不同类型实体
零初始化 adaLN和输出层初始化为0 训练稳定,条件信息逐渐注入
Classifier Guidance collision_guidance_fn 推理时引导轨迹避免碰撞
分离轴定理 batch_signed_distance_rect 高效精确的碰撞检测
时间自适应引导 guidance_scale随t变化 平衡探索与安全性

8.6 训练与推理流程

训练流程
复制代码
1. 数据准备: neighbors, static, lanes, ego_future, neighbors_future
2. 编码器前向: encoder_outputs = encoder(neighbors, static, lanes)
3. 时间步采样: t ~ Uniform(eps, 1-eps)
4. 前向扩散: x_t = mean * x_0 + std * noise
5. 解码器前向: pred = decoder(x_t, t, encoder_outputs, route_lanes)
6. 损失计算: loss = diffusion_loss_func(pred, x_0, noise, t)
7. 反向传播: loss.backward()
推理流程
复制代码
1. 编码器前向: encoder_outputs = encoder(neighbors, static, lanes)
2. 初始化: x_T = current_state + noise
3. DPM-Solver迭代 (10步):
   for step in range(10):
       t = T - step * dt
       pred = decoder(x_t, t, encoder_outputs, route_lanes)
       if guidance_fn:
           guidance = guidance_fn(x_t, t, ...)
       x_t = dpm_solver_step(x_t, t, pred, guidance)
       x_t = correcting_xt_fn(x_t)  # 约束当前状态
4. 输出: prediction = x_0[:, :, 1:]

8.7 模块协作关系

复制代码
┌─────────────────────────────────────────────────────────────────────┐
│                        模块依赖关系                                  │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌─────────────┐                                                    │
│  │   Dataset   │───→ 原始数据                                        │
│  └──────┬──────┘                                                    │
│         │                                                           │
│         ▼                                                           │
│  ┌─────────────────────────────────────────────────────────────┐    │
│  │                    Encoder                                  │    │
│  │  AgentFusion + StaticFusion + LaneFusion + FusionEncoder    │    │
│  └──────┬──────────────────────────────────────────────────────┘    │
│         │                                                           │
│         ▼                                                           │
│  ┌─────────────────────────────────────────────────────────────┐    │
│  │                    Decoder                                  │    │
│  │  DiT + RouteEncoder + StateNormalizer                       │    │
│  └──────┬──────────────────────────────────────────────────────┘    │
│         │                                                           │
│         ▼                                                           │
│  ┌─────────────────────────────────────────────────────────────┐    │
│  │              Guidance (推理时)                               │    │
│  │  collision_guidance_fn + SDF + Gradient                     │    │
│  └─────────────────────────────────────────────────────────────┘    │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

8.8 关键数学公式

分离轴定理(碰撞检测)
复制代码
对于两个矩形 rect1 和 rect2:
1. 计算4个分离轴(两个矩形的四条边法向量)
2. 将两个矩形投影到每个轴上
3. 检查投影是否重叠
4. 如果所有轴都重叠 → 碰撞(返回负距离)
5. 如果至少一个轴不重叠 → 分离(返回最小正距离)
Classifier Guidance
复制代码
x_{t-1} = x_t + dt * (drift + scale * grad(reward, x))

其中 reward = -exp(mean(clip_distances))
      clip_distances = max(1 - distances/CLIP_DISTANCE, 0)
MLP-Mixer 结构
复制代码
Block:
  x = ChannelMLP(x) + x       # 通道维度变换
  x = TokenMLP(x) + x         # Token维度变换

ChannelMLP: Linear → GELU → Linear
TokenMLP: Linear → GELU → Linear

8.9 性能与优化

优化策略 实现方式 效果
动态掩码 只处理有效实体 减少无效计算,支持可变输入
MLP-Mixer 替代Transformer 降低复杂度,适合时序数据
零初始化 稳定训练初期 加速收敛,避免梯度爆炸
DPM-Solver 快速采样 10-20步完成推理
时间自适应引导 引导强度随t变化 平衡探索与安全
高斯卷积平滑 时间维度平滑 生成更平滑的轨迹

8.10 与其他模块的关系

复制代码
┌─────────────────────────────────────────────────────────────────────┐
│                    联合轨迹模块与其他模块的交互                        │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐           │
│  │   Encoder   │────▶│   DiT       │────▶│   Loss      │           │
│  │ (场景编码)   │     │ (主干网络)   │     │ (损失计算)   │           │
│  └─────────────┘     └──────┬──────┘     └─────────────┘           │
│                             │                                       │
│                             ▼                                       │
│                    ┌─────────────┐                                  │
│                    │   SDE       │                                  │
│                    │ (扩散过程)   │                                  │
│                    └─────────────┘                                  │
│                             │                                       │
│                             ▼                                       │
│                    ┌─────────────┐                                  │
│                    │ Guidance    │                                  │
│                    │ (碰撞避免)   │                                  │
│                    └─────────────┘                                  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

精炼核心技术总结

一句话定位

联合轨迹生成建模模块通过多模态编码器融合场景信息,驱动 DiT 生成自车与邻居车辆的联合轨迹。

核心技术点

  • 多模态编码器:AgentFusion+StaticFusion+LaneFusion 分别处理不同输入类型
  • MLP-Mixer:替代 Transformer,高效处理时序轨迹数据
  • FusionEncoder:自注意力机制实现多模态特征交互
  • 联合轨迹建模:同时预测自车和邻居车辆轨迹
  • 掩码机制:支持可变数量的输入实体

关键公式速查

MLP-Mixer Block

h′=ChannelMLP(h)+h,h′′=TokenMLP(h′)+h′ \text{h}' = \text{ChannelMLP}(\text{h}) + \text{h}, \quad \text{h}'' = \text{TokenMLP}(\text{h}') + \text{h}' h′=ChannelMLP(h)+h,h′′=TokenMLP(h′)+h′

FusionEncoder 融合

encoding=SelfAttention(concat(agent,static,lane)) \text{encoding} = \text{SelfAttention}(\text{concat}(\text{agent}, \text{static}, \text{lane})) encoding=SelfAttention(concat(agent,static,lane))

模块间一句话关联

编码器将原始感知数据转换为统一场景表示,输入 DiT 进行扩散生成,推理时结合 Guidance 模块实现碰撞避免。