联合轨迹生成建模模块原理讲解与代码注释
1. 概述
联合轨迹生成建模模块是 Diffusion Planner 的核心架构,负责同时生成自车和周围车辆的未来轨迹 。该模块采用 编码器-解码器架构,通过 Transformer 实现多智能体轨迹的联合预测。
核心特点:
- 多模态输入融合:统一处理邻居车辆、静态物体、车道线等场景信息
- 多智能体联合建模:通过自注意力机制建模车辆间的交互关系
- 碰撞避免机制:推理时引入 classifier guidance 实现安全约束
2. 整体架构
┌─────────────────────────────────────────────────────────────────────────┐
│ 联合轨迹生成建模 │
├─────────────────────────────────────────────────────────────────────────┤
│ 输入层 │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌─────────────┐ │
│ │邻居车辆历史 │ │静态物体 │ │车道线 │ │路由信息 │ │
│ │(B,P,T,D) │ │(B,P,D) │ │(B,P,V,D) │ │(B,P,V,D) │ │
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ └──────┬──────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ │ │
│ ┌───────────────────────────────────────────┐ │ │
│ │ Encoder(编码器) │ │ │
│ │ AgentFusionEncoder → StaticFusionEncoder │ │ │
│ │ → LaneFusionEncoder → FusionEncoder │ │ │
│ └──────────────────────────┬────────────────┘ │ │
│ ▼ │ │
│ 场景上下文编码 [B,N,D] │ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐│
│ │ Decoder(解码器) ││
│ │ DiT + RouteEncoder + TimestepEmbedder + Cross-Attention ││
│ │ ││
│ │ 输入: 加噪轨迹 [B,P,T,D] + 时间步 t + 路由编码 + 场景编码 ││
│ │ 输出: 去噪轨迹 [B,P,T,D] ││
│ └─────────────────────────────────────────────────────────────────────┘│
│ │ │
│ ▼ │
│ 预测轨迹 [B,1+Pn,T,4](自车 + 邻居) │
└─────────────────────────────────────────────────────────────────────────┘
3. 编码器模块(Encoder)
3.1 编码器主类
python
class Encoder(nn.Module):
"""
场景编码器:将多模态输入统一编码为场景上下文
输入类型:
1. 邻居车辆历史轨迹
2. 静态物体(障碍物、红绿灯等)
3. 车道线信息
输出:统一的场景上下文编码 [B, N, D]
N = agent_num + static_objects_num + lane_num
"""
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_dim
# 计算总 token 数:代理数 + 静态物体数 + 车道数
self.token_num = config.agent_num + config.static_objects_num + config.lane_num
# 三类编码器
self.neighbor_encoder = AgentFusionEncoder(
config.time_len,
drop_path_rate=config.encoder_drop_path_rate,
hidden_dim=config.hidden_dim,
depth=config.encoder_depth
)
self.static_encoder = StaticFusionEncoder(
config.static_objects_state_dim,
drop_path_rate=config.encoder_drop_path_rate,
hidden_dim=config.hidden_dim
)
self.lane_encoder = LaneFusionEncoder(
config.lane_len,
drop_path_rate=config.encoder_drop_path_rate,
hidden_dim=config.hidden_dim,
depth=config.encoder_depth
)
# 融合编码器:将三类编码融合为统一表示
self.fusion = FusionEncoder(
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
drop_path_rate=config.encoder_drop_path_rate,
depth=config.encoder_depth,
device=config.device
)
# 位置嵌入:编码 x, y, cos, sin, type
self.pos_emb = nn.Linear(7, config.hidden_dim)
def forward(self, inputs):
"""
Args:
inputs: 包含以下键的字典:
- neighbor_agents_past: [B, P, T, D] 邻居车辆历史轨迹
- static_objects: [B, P, D] 静态物体
- lanes: [B, P, V, D] 车道线
- lanes_speed_limit: [B, P, 1] 限速信息
- lanes_has_speed_limit: [B, P, 1] 是否有限速
Returns:
encoder_outputs: 包含 'encoding' 键,值为 [B, token_num, hidden_dim]
"""
encoder_outputs = {}
# 提取输入
neighbors = inputs['neighbor_agents_past']
static = inputs['static_objects']
lanes = inputs['lanes']
lanes_speed_limit = inputs['lanes_speed_limit']
lanes_has_speed_limit = inputs['lanes_has_speed_limit']
B = neighbors.shape[0]
# 分别编码三类输入
encoding_neighbors, neighbors_mask, neighbor_pos = self.neighbor_encoder(neighbors)
encoding_static, static_mask, static_pos = self.static_encoder(static)
encoding_lanes, lanes_mask, lane_pos = self.lane_encoder(lanes, lanes_speed_limit, lanes_has_speed_limit)
# 拼接编码结果
encoding_input = torch.cat([encoding_neighbors, encoding_static, encoding_lanes], dim=1)
# 处理位置嵌入
encoding_pos = torch.cat([neighbor_pos, static_pos, lane_pos], dim=1).view(B * self.token_num, -1)
# 只对有效token计算位置嵌入 ,避免浪费计算资源在无效输入上
# encoding_mask 是一个 联合有效性掩码 ,用于标识所有输入实体(邻居车辆、静态物体、车道线)的有效性状态,
# 在位置嵌入计算和自注意力机制中起到 过滤无效输入、优化计算效率 的关键作用。
encoding_mask = torch.cat([neighbors_mask, static_mask, lanes_mask], dim=1).view(-1)
encoding_pos = self.pos_emb(encoding_pos[~encoding_mask])
# 填充无效位置
encoding_pos_result = torch.zeros((B * self.token_num, self.hidden_dim), device=encoding_pos.device)
encoding_pos_result[~encoding_mask] = encoding_pos
# 添加位置嵌入
encoding_input = encoding_input + encoding_pos_result.view(B, self.token_num, -1)
# 融合编码
encoder_outputs['encoding'] = self.fusion(encoding_input, encoding_mask.view(B, self.token_num))
return encoder_outputs
3.2 邻居车辆编码器(AgentFusionEncoder)
python
class AgentFusionEncoder(nn.Module):
"""
邻居车辆编码器:将车辆历史轨迹编码为特征向量
使用 MLP-Mixer 架构处理时序数据,捕捉车辆运动模式
输入:[B, P, V, D],其中 D = 9(x, y, cos, sin, vx, vy, w, l, type(3))
输出:[B, P, hidden_dim]
"""
def __init__(self, time_len, drop_path_rate=0.3, hidden_dim=192, depth=3,
tokens_mlp_dim=64, channels_mlp_dim=128):
super().__init__()
self._hidden_dim = hidden_dim
self._channel = channels_mlp_dim
# 类型嵌入:区分不同类型的车辆
self.type_emb = nn.Linear(3, channels_mlp_dim)
# Channel-wise MLP:处理每个时间步的特征
self.channel_pre_project = Mlp(
in_features=8+1, # 8维状态 + 1维掩码
hidden_features=channels_mlp_dim,
out_features=channels_mlp_dim,
act_layer=nn.GELU,
drop=0.
)
# Token-wise MLP:处理时间步之间的关系
self.token_pre_project = Mlp(
in_features=time_len,
hidden_features=tokens_mlp_dim,
out_features=tokens_mlp_dim,
act_layer=nn.GELU,
drop=0.
)
# MLP-Mixer 块堆叠
self.blocks = nn.ModuleList([
MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate)
for i in range(depth)
])
self.norm = nn.LayerNorm(channels_mlp_dim)
self.emb_project = Mlp(
in_features=channels_mlp_dim,
hidden_features=hidden_dim,
out_features=hidden_dim,
act_layer=nn.GELU,
drop=drop_path_rate
)
def forward(self, x):
'''
Args:
x: [B, P, V, D],其中 D = 9(x, y, cos, sin, vx, vy, w, l, type(3))
Returns:
encoding: [B, P, hidden_dim] 编码结果
mask_p: [B, P] 有效掩码
pos: [B, P, 7] 位置信息
'''
# 提取类型信息(最后3维)
neighbor_type = x[:, :, -1, 8:]
# 提取状态信息(前8维)
x = x[..., :8]
# 提取位置信息(最后一个时间步的 x, y, cos, sin)
pos = x[:, :, -1, :7].clone()
# 标记为邻居类型 [1, 0, 0]
pos[..., -3:] = 0.0
pos[..., -3] = 1.0
B, P, V, _ = x.shape
# 创建掩码:标记无效的时间步和无效的车辆
mask_v = torch.sum(torch.ne(x[..., :8], 0), dim=-1).to(x.device) == 0
mask_p = torch.sum(~mask_v, dim=-1) == 0
# 添加掩码通道
x = torch.cat([x, (~mask_v).float().unsqueeze(-1)], dim=-1)
x = x.view(B * P, V, -1)
# 过滤无效车辆
valid_indices = ~mask_p.view(-1)
x = x[valid_indices]
# MLP-Mixer 处理
x = self.channel_pre_project(x)
x = x.permute(0, 2, 1)
x = self.token_pre_project(x)
x = x.permute(0, 2, 1)
for block in self.blocks:
x = block(x)
# 全局平均池化:将时间序列压缩为单个向量
x = torch.mean(x, dim=1)
# 添加类型嵌入
neighbor_type = neighbor_type.view(B * P, -1)
neighbor_type = neighbor_type[valid_indices]
type_embedding = self.type_emb(neighbor_type)
x = x + type_embedding
# 投影到隐藏维度
x = self.emb_project(self.norm(x))
# 填充结果
x_result = torch.zeros((B * P, x.shape[-1]), device=x.device)
x_result[valid_indices] = x
return x_result.view(B, P, -1), mask_p.reshape(B, -1), pos.view(B, P, -1)
3.2.1 输入输出示例详解
假设输入参数:
- B = 2(批次大小)
- P = 3(每场景最多3辆车)
- V = 5(时间步)
- D = 9(特征维度)
输入数据示例:
python
# x: [B=2, P=3, V=5, D=9]
# 特征顺序: [x, y, cos, sin, vx, vy, w, l, type(3维独热)]
x = torch.tensor([
# 场景1: 2辆车有效
[
# 车辆1: 小汽车 [1,0,0],沿x轴正方向行驶
[[10.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
[10.5, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
[11.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
[11.5, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0],
[12.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, 4.5, 1,0,0]],
# 车辆2: 卡车 [0,1,0],沿y轴正方向行驶
[[15.0, 25.0, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
[15.0, 25.3, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
[15.0, 25.6, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
[15.0, 25.9, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0],
[15.0, 26.2, 0.0, 1.0, 0.0, 3.0, 0.0, 6.0, 0,1,0]],
# 车辆3: 无效(全零)
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...]
],
# 场景2: 1辆车有效
[
# 车辆1: 小汽车 [1,0,0],沿45度方向行驶
[[5.0, 10.0, 0.707, 0.707, 7.0, 7.0, 0.0, 4.0, 1,0,0], ...],
# 车辆2-3: 无效
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0,0,0], ...]
]
])
处理流程:
| 阶段 | 操作 | 形状变化 | 说明 |
| 1 | 特征分离 | [2,3,5,9]→[2,3,5,8] | 提取前8维状态,分离类型信息 |
提取类型信息(最后3维)
neighbor_type = x:, :, -1, 8:
shape: 2, 3, 3
场景1: \[1,0,0, 0,1,0, 0,0,0]
场景2: \[1,0,0, 0,0,0, 0,0,0]
提取状态信息(前8维)
x = x..., :8
shape: 2, 3, 5, 8
| 2 | 位置提取 | -→[2,3,7] | 取最后时间步的x,y,cos,sin,vx,vy,w |
提取最后一个时间步的位置信息
pos = x:, :, -1, :7.clone()
shape: 2, 3, 7
场景1: \[12.0, 20.0, 1.0, 0.0, 5.0, 0.0, 0.0, # 车辆1最后时刻
16.2, 25.0, 0.0, 1.0, 3.0, 0.0, 0.0, # 车辆2最后时刻
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # 车辆3无效
标记为邻居类型 1, 0, 0
pos..., -3: = 0.0
pos..., -3 = 1.0
现在 pos 最后3维都是 1, 0, 0,表示这是邻居车辆
| 3 | 掩码生成 | -→[2,3] | mask_p=\[F,F,T,F,T,T] |
B, P, V, _ = x.shape # B=2, P=3, V=5
mask_v: 标记无效时间步 B, P, V
mask_v = torch.sum(torch.ne(x..., :8, 0), dim=-1) == 0
场景1车辆1: False, False, False, False, False # 所有时间步有效
场景1车辆3: True, True, True, True, True # 所有时间步无效
mask_p: 标记无效车辆 B, P
mask_p = torch.sum(~mask_v, dim=-1) == 0
#场景1: False, False, True # 车辆3无效
场景2: False, True, True # 车辆2、3无效
| 4 | 展平过滤 | [2,3,5,9]→[3,5,9] | 保留3辆有效车 |
掩码通道添加与展平
添加掩码通道(第9维)
x = torch.cat(x, (\~mask_v).float().unsqueeze(-1), dim=-1)
shape: 2, 3, 5, 9
最后一维新增掩码:1=有效时间步,0=无效时间步
展平批次和车辆维度
x = x.view(B * P, V, -1)
shape: 6, 5, 9 (2*3=6辆车)
过滤无效车辆
valid_indices = ~mask_p.view(-1)
False, False, True, False, True, True
对应6辆车中:第0、1、3辆有效(场景1车辆1、2;场景2车辆1)
x = xvalid_indices
shape: 3, 5, 9 (只保留3辆有效车)
| 5 | MLP-Mixer | [3,5,9]→[3,64] | 时间序列编码+全局池化 |
假设 hidden_dim=64
x = self.channel_pre_project(x) # 3, 5, 9 → 3, 5, 64
x = x.permute(0, 2, 1) # 3, 5, 64 → 3, 64, 5
x = self.token_pre_project(x) # 3, 64, 5 → 3, 64, 5
x = x.permute(0, 2, 1) # 3, 64, 5 → 3, 5, 64
for block in self.blocks:
x = block(x) # MLP-Mixer 块处理
全局平均池化:压缩时间维度
x = torch.mean(x, dim=1)
shape: 3, 64 (时间序列 → 单向量)
| 6 | 类型嵌入 | [3,64]→[3,64] | 添加车辆类型信息 |
类型嵌入
neighbor_type = neighbor_type.view(B * P, -1) # 6, 3
neighbor_type = neighbor_typevalid_indices # 3, 3
type_embedding = self.type_emb(neighbor_type) # 3, 64
x = x + type_embedding # 添加类型信息
最终投影
x = self.emb_project(self.norm(x)) # 3, 64 → 3, hidden_dim
| 7 | 结果填充 | [3,64]→[2,3,64] | 恢复原始形状,无效位置填零 |
x_result = torch.zeros((B * P, x.shape-1), device=x.device)
shape: 6, hidden_dim
x_resultvalid_indices = x
将有效车辆的编码放入对应位置
无效车辆位置保持为零向量
**恢复原始形状
encoding = x_result.view(B, P, -1)
shape: 2, 3, hidden_dim
输出结果:
python
# encoding: [2, 3, 64]
# 场景1车辆3、场景2车辆2-3 位置为零向量
# mask_p: [2, 3]
# [[False, False, True], # 场景1:车辆3无效
# [False, True, True]] # 场景2:车辆2-3无效
# pos: [2, 3, 7]
# 包含最后时刻位置+速度+朝向,最后3维为[1,0,0]表示邻居类型
数据流向图:
输入: [2,3,5,9]
│
▼
┌──────────────┐
│ 特征分离 │ → neighbor_type [2,3,3]
└──────┬───────┘
│ x [2,3,5,8]
▼
┌──────────────┐
│ 位置提取 │ → pos [2,3,7]
└──────┬───────┘
│
▼
┌──────────────┐
│ 掩码生成 │ → mask_p [2,3]
└──────┬───────┘
│
▼
┌──────────────┐
│ 过滤无效车 │ → [3,5,9]
└──────┬───────┘
│
▼
┌──────────────┐
│ MLP-Mixer │ → [3,64]
└──────┬───────┘
│
▼
┌──────────────┐
│ 类型嵌入 │ → [3,64]
└──────┬───────┘
│
▼
┌──────────────┐
│ 结果填充 │ → [2,3,64]
└──────────────┘
3.3 静态物体编码器(StaticFusionEncoder)
python
class StaticFusionEncoder(nn.Module):
"""
静态物体编码器:将静态物体(障碍物、红绿灯等)编码为特征向量
输入:[B, P, D],其中 D = 11(x, y, cos, sin, w, l, type(4))
输出:[B, P, hidden_dim]
"""
def __init__(self, dim, drop_path_rate=0.3, hidden_dim=192, device='cuda'):
super().__init__()
self._hidden_dim = hidden_dim
self.projection = Mlp(
in_features=dim,
hidden_features=hidden_dim,
out_features=hidden_dim,
act_layer=nn.GELU,
drop=drop_path_rate
)
def forward(self, x):
'''
Args:
x: [B, P, D],其中 D = 11(x, y, cos, sin, w, l, type(4))
Returns:
encoding: [B, P, hidden_dim]
mask_p: [B, P]
pos: [B, P, 7]
'''
B, P, _ = x.shape
# 提取位置信息
pos = x[:, :, :7].clone()
# 标记为静态物体类型 [0, 1, 0]
pos[..., -3:] = 0.0
pos[..., -2] = 1.0
# 初始化结果
x_result = torch.zeros((B * P, self._hidden_dim), device=x.device)
# 创建掩码
mask_p = torch.sum(torch.ne(x[..., :10], 0), dim=-1).to(x.device) == 0
valid_indices = ~mask_p.view(-1)
# 处理有效静态物体
if valid_indices.sum() > 0:
x = x.view(B * P, -1)
x = x[valid_indices]
x = self.projection(x)
x_result[valid_indices] = x
return x_result.view(B, P, -1), mask_p.view(B, P), pos.view(B, P, -1)
3.4 车道线编码器(LaneFusionEncoder)
python
class LaneFusionEncoder(nn.Module):
"""
车道线编码器:将车道线信息编码为特征向量
输入:车道线点序列 + 限速信息 + 红绿灯状态
输出:车道线特征表示
使用 MLP-Mixer 处理车道线的空间结构
"""
def __init__(self, lane_len, drop_path_rate=0.3, hidden_dim=192, depth=3,
tokens_mlp_dim=64, channels_mlp_dim=128):
super().__init__()
self._lane_len = lane_len
self._channel = channels_mlp_dim
# 限速嵌入
self.speed_limit_emb = nn.Linear(1, channels_mlp_dim)
self.unknown_speed_emb = nn.Embedding(1, channels_mlp_dim)
# 红绿灯嵌入
self.traffic_emb = nn.Linear(4, channels_mlp_dim)
self.channel_pre_project = Mlp(
in_features=8,
hidden_features=channels_mlp_dim,
out_features=channels_mlp_dim,
act_layer=nn.GELU,
drop=0.
)
self.token_pre_project = Mlp(
in_features=lane_len,
hidden_features=tokens_mlp_dim,
out_features=tokens_mlp_dim,
act_layer=nn.GELU,
drop=0.
)
self.blocks = nn.ModuleList([
MixerBlock(tokens_mlp_dim, channels_mlp_dim, drop_path_rate)
for i in range(depth)
])
self.norm = nn.LayerNorm(channels_mlp_dim)
self.emb_project = Mlp(
in_features=channels_mlp_dim,
hidden_features=hidden_dim,
out_features=hidden_dim,
act_layer=nn.GELU,
drop=drop_path_rate
)
def forward(self, x, speed_limit, has_speed_limit):
'''
Args:
x: [B, P, V, D] 车道线点,D = 12(x, y, x'-x, y'-y, x_left-x, y_left-y, x_right-x, y_right-y, traffic(4))
speed_limit: [B, P, 1] 限速值
has_speed_limit: [B, P, 1] 是否有限速
Returns:
encoding: [B, P, hidden_dim]
mask_p: [B, P]
pos: [B, P, 7]
'''
# 提取红绿灯信息
traffic = x[:, :, 0, 8:]
x = x[..., :8]
# 提取中间点的位置作为车道线位置
pos = x[:, :, int(self._lane_len / 2), :7].clone()
# 计算朝向角
heading = torch.atan2(pos[..., 3], pos[..., 2])
pos[..., 2] = torch.cos(heading)
pos[..., 3] = torch.sin(heading)
# 标记为车道类型 [0, 0, 1]
pos[..., -3:] = 0.0
pos[..., -1] = 1.0
B, P, V, _ = x.shape
mask_v = torch.sum(torch.ne(x[..., :8], 0), dim=-1).to(x.device) == 0
mask_p = torch.sum(~mask_v, dim=-1) == 0
x = x.view(B * P, V, -1)
valid_indices = ~mask_p.view(-1)
x = x[valid_indices]
# MLP-Mixer 处理
x = self.channel_pre_project(x)
x = x.permute(0, 2, 1)
x = self.token_pre_project(x)
x = x.permute(0, 2, 1)
for block in self.blocks:
x = block(x)
# 全局平均池化
x = torch.mean(x, dim=1)
# 处理限速信息
speed_limit = speed_limit.view(B * P, 1)
has_speed_limit = has_speed_limit.view(B * P, 1)
traffic = traffic.view(B * P, -1)
has_speed_limit = has_speed_limit[valid_indices].squeeze(-1)
speed_limit = speed_limit[valid_indices].squeeze(-1)
speed_limit_embedding = torch.zeros((speed_limit.shape[0], self._channel), device=x.device)
# 有速限制的车道
if has_speed_limit.sum() > 0:
speed_limit_with_limit = self.speed_limit_emb(speed_limit[has_speed_limit].unsqueeze(-1))
speed_limit_embedding[has_speed_limit] = speed_limit_with_limit
# 无速限制的车道
if (~has_speed_limit).sum() > 0:
speed_limit_no_limit = self.unknown_speed_emb.weight.expand(
(~has_speed_limit).sum().item(), -1
)
speed_limit_embedding[~has_speed_limit] = speed_limit_no_limit
# 处理红绿灯信息
traffic = traffic[valid_indices]
traffic_light_embedding = self.traffic_emb(traffic)
# 融合所有信息
x = x + speed_limit_embedding + traffic_light_embedding
x = self.emb_project(self.norm(x))
x_result = torch.zeros((B * P, x.shape[-1]), device=x.device)
x_result[valid_indices] = x
return x_result.view(B, P, -1), mask_p.reshape(B, -1), pos.view(B, P, -1)
3.5 融合编码器(FusionEncoder)
python
class FusionEncoder(nn.Module):
"""
融合编码器:将不同类型的编码融合为统一的场景表示
使用自注意力机制让不同类型的实体之间进行交互
"""
def __init__(self, hidden_dim=192, num_heads=6, drop_path_rate=0.3, depth=3, device='cuda'):
super().__init__()
dpr = drop_path_rate
# 堆叠自注意力块
self.blocks = nn.ModuleList(
[SelfAttentionBlock(hidden_dim, num_heads, dropout=dpr) for i in range(depth)]
)
self.norm = nn.LayerNorm(hidden_dim)
def forward(self, x, mask):
"""
Args:
x: [B, N, D] 拼接后的编码
mask: [B, N] 有效掩码
Returns:
[B, N, D] 融合后的场景编码
"""
# 确保第一个 token(自车)始终有效
mask[:, 0] = False
for b in self.blocks:
x = b(x, mask)
return self.norm(x)
4. Diffusion Planner 主模型
python
class Diffusion_Planner(nn.Module):
"""
Diffusion Planner 主模型:整合编码器和解码器
职责:
- 训练时:处理输入,计算损失
- 推理时:生成多智能体轨迹
"""
def __init__(self, config):
super().__init__()
self.encoder = Diffusion_Planner_Encoder(config)
self.decoder = Diffusion_Planner_Decoder(config)
@property
def sde(self):
"""返回 SDE 对象,用于损失计算和采样"""
return self.decoder.decoder.sde
def forward(self, inputs):
"""
Args:
inputs: 包含所有输入数据的字典
Returns:
encoder_outputs: 编码器输出
decoder_outputs: 解码器输出(预测轨迹)
"""
encoder_outputs = self.encoder(inputs)
decoder_outputs = self.decoder(encoder_outputs, inputs)
return encoder_outputs, decoder_outputs
class Diffusion_Planner_Encoder(nn.Module):
"""
编码器包装类:包含权重初始化逻辑
"""
def __init__(self, config):
super().__init__()
self.encoder = Encoder(config)
self.initialize_weights()
def initialize_weights(self):
"""
权重初始化策略:
- Linear: Xavier uniform
- LayerNorm: 权重=1,偏置=0
- Embedding: 正态分布 N(0, 0.02)
"""
def _basic_init(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
self.apply(_basic_init)
# 特殊初始化嵌入层
nn.init.normal_(self.encoder.pos_emb.weight, std=0.02)
nn.init.normal_(self.encoder.neighbor_encoder.type_emb.weight, std=0.02)
nn.init.normal_(self.encoder.lane_encoder.speed_limit_emb.weight, std=0.02)
nn.init.normal_(self.encoder.lane_encoder.traffic_emb.weight, std=0.02)
def forward(self, inputs):
return self.encoder(inputs)
class Diffusion_Planner_Decoder(nn.Module):
"""
解码器包装类:包含权重初始化逻辑
"""
def __init__(self, config):
super().__init__()
self.decoder = Decoder(config)
self.initialize_weights()
def initialize_weights(self):
"""
解码器权重初始化:
- 基础初始化同编码器
- adaLN 调制层零初始化(保证训练稳定性)
- 输出层零初始化(训练初期接近恒等映射)
"""
def _basic_init(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
self.apply(_basic_init)
# 时间步嵌入 MLP 初始化
nn.init.normal_(self.decoder.dit.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.decoder.dit.t_embedder.mlp[2].weight, std=0.02)
# 关键:adaLN 调制层零初始化
for block in self.decoder.dit.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# 关键:输出层零初始化
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].weight, 0)
nn.init.constant_(self.decoder.dit.final_layer.proj[-1].bias, 0)
def forward(self, encoder_outputs, inputs):
return self.decoder(encoder_outputs, inputs)
5. 碰撞避免引导机制(Collision Guidance)
python
def collision_guidance_fn(x, t, cond, inputs, *args, **kwargs) -> torch.Tensor:
"""
碰撞避免引导函数:在推理时引导轨迹生成远离碰撞
原理:Classifier Guidance
通过计算碰撞奖励的梯度,将其注入扩散过程,使生成的轨迹更安全
Args:
x: [B, P, T+1, 4] 当前扩散状态(轨迹)
t: [B, 1] 当前时间步
cond: 条件信息
inputs: 输入数据字典
Returns:
reward: [B] 引导奖励(用于调整采样方向)
"""
B, P, T, _ = x.shape
neighbor_current_mask = inputs["neighbor_current_mask"] # [B, Pn]
x = x.reshape(B, P, -1, 4)
# 只在特定时间步范围内应用引导(t < 0.1 且 t > 0.005)
# 这是因为在扩散后期(t接近0),轨迹已基本确定
mask_diffusion_time = (t < 0.1 and t > 0.005)
x = torch.where(mask_diffusion_time, x, x.detach())
# 归一化朝向向量
x = torch.cat([
x[:, :, :, :2],
x[:, :, :, 2:].detach() / torch.norm(x[:, :, :, 2:].detach(), dim=-1, keepdim=True)
], dim=-1)
# 提取自车预测(跳过当前时刻)
ego_pred = x[:, :1, 1:, :]
# 调整自车位置到后轴中心(COG_TO_REAR = 1.67m)
cos_h, sin_h = ego_pred[..., 2:3], ego_pred[..., 3:4]
ego_pred = torch.cat([
ego_pred[..., 0:1] + cos_h * COG_TO_REAR,
ego_pred[..., 1:2] + sin_h * COG_TO_REAR,
ego_pred[..., 2:]
], dim=-1)
# 提取邻居预测
neighbors_pred = x[:, 1:, 1:, :]
B, Pn, T, _ = neighbors_pred.shape
# 拼接自车和邻居预测(邻居使用detach避免梯度回传)
predictions = torch.cat([ego_pred, neighbors_pred.detach()], dim=1)
# 获取车辆尺寸(自车固定尺寸,邻居从输入获取)
lw = torch.cat([
torch.tensor(ego_size, device=predictions.device)[None, None, :].repeat(B, 1, 1),
inputs["neighbor_agents_past"][:, :Pn, -1, [7, 6]]
], dim=1)
# 构建 bounding box(位置 + 尺寸)
bbox = torch.cat([
predictions,
lw.unsqueeze(2).expand(-1, -1, T, -1) + INFLATION # INFLATION = 1.0m 安全膨胀
], dim=-1)
# 将中心表示转换为四个角点
bbox = center_rect_to_points(bbox.reshape(-1, 6)).reshape(B, Pn + 1, T, 4, 2)
# 计算自车与每个邻居的碰撞距离
ego_bbox = bbox[:, :1, :, :, :].expand(-1, Pn, -1, -1, -1)[~neighbor_current_mask].reshape(-1, 4, 2)
neighbor_bbox = bbox[:, 1:, :, :, :][~neighbor_current_mask].reshape(-1, 4, 2)
# 计算有符号距离(SDF)
distances = batch_signed_distance_rect(ego_bbox, neighbor_bbox)
# 裁剪距离(CLIP_DISTANCE = 1.0m)
clip_distances = torch.maximum(1 - distances / CLIP_DISTANCE, torch.tensor(0.0, device=distances.device))
# 计算碰撞惩罚奖励
reward = - (
torch.sum(clip_distances[clip_distances > 1]) / (torch.sum((clip_distances[clip_distances > 1].detach() > 0).float()) + 1e-5) +
torch.sum(clip_distances[clip_distances <= 1]) / (torch.sum((clip_distances[clip_distances <= 1].detach() > 0).float()) + 1e-5)
).exp()
# 计算奖励对轨迹的梯度(仅对自车位置)
x_aux = torch.autograd.grad(
reward.sum(),
x,
retain_graph=True,
allow_unused=True
)[0][:, 0, :, :2]
# 坐标变换矩阵(将局部坐标转换为全局坐标)
T_total = T + 1
x_mat = torch.einsum("btd,nd->btn", x[:, 0, :, 2:],
torch.tensor([[1., 0], [0, 1], [0, -1], [1, 0]], device=x.device)
).reshape(B, T_total, 2, 2)
# 应用坐标变换
x_aux = torch.einsum("btij,btj->bti", x_mat, x_aux)
# 时间平滑:使用高斯核卷积
x_aux = torch.stack([
torch.einsum("bt,it->bi", x_aux[..., 0],
torch.tril((-torch.linspace(0, 1, T_total, device=x.device)).exp().unsqueeze(0).repeat(T_total, 1))) * 0,
F.conv1d(
F.pad(x_aux[:, None, :, 1], (10, 10), mode='replicate'),
torch.ones(1, 1, 21, device=x.device) * \
(- torch.linspace(-2, 2, 21, device=x.device) ** 2 / 4).exp()
)[:, 0] * 1.0
], dim=2)
# 反向坐标变换
x_aux = torch.einsum("btji,btj->bti", x_mat, x_aux)
# 计算最终奖励
reward = torch.sum(x_aux.detach() * x[:, 0, :, :2], dim=(1, 2))
# 乘以引导系数(3.0)
return 3.0 * reward
def batch_signed_distance_rect(rect1, rect2):
"""
计算两个矩形之间的有符号距离(Signed Distance Function)
使用分离轴定理(Separating Axis Theorem)
- 负距离表示重叠(碰撞)
- 正距离表示分离
Args:
rect1: [B, 4, 2] 第一个矩形的四个角点
rect2: [B, 4, 2] 第二个矩形的四个角点
Returns:
[B] 有符号距离
"""
B, _, _ = rect1.shape
# 计算四个分离轴(两个矩形的四条边的法向量)
norm_vec = torch.stack([
rect1[:, 0] - rect1[:, 1],
rect1[:, 1] - rect1[:, 2],
rect2[:, 0] - rect2[:, 1],
rect2[:, 1] - rect2[:, 2]
], dim=1)
norm_vec = norm_vec / torch.norm(norm_vec, dim=2, keepdim=True)
# 投影到分离轴上
proj1 = torch.einsum('bij,bkj->bik', norm_vec, rect1)
proj1_min, proj1_max = proj1.min(dim=2)[0], proj1.max(dim=2)[0]
proj2 = torch.einsum('bij,bkj->bik', norm_vec, rect2)
proj2_min, proj2_max = proj2.min(dim=2)[0], proj2.max(dim=2)[0]
# 计算重叠情况
overlap = torch.cat([proj1_min - proj2_max, proj2_min - proj1_max], dim=1)
# 确定是否碰撞及距离
positive_distance = torch.where(overlap < 0, 1e5, overlap)
is_overlap = (overlap < 0).all(dim=1)
distance = torch.where(
is_overlap,
overlap.max(dim=1).values, # 碰撞时取最大重叠深度(负值)
positive_distance.min(dim=1).values # 分离时取最小距离
)
return distance
def center_rect_to_points(rect):
"""
将中心表示的矩形转换为四个角点
Args:
rect: [B, 6] (x, y, cos_h, sin_h, l, w)
Returns:
[B, 4, 2] 四个角点坐标
"""
B, _ = rect.shape
xy, cos_h, sin_h, lw = rect[:, :2], rect[:, 2], rect[:, 3], rect[:, 4:]
# 旋转矩阵
rot = torch.stack([cos_h, -sin_h, sin_h, cos_h], dim=1).reshape(-1, 2, 2)
# 四个角点相对于中心的偏移
lw = torch.einsum('bj,ij->bij', lw,
torch.tensor([[1., 1], [-1, 1], [-1, -1], [1, -1]], device=lw.device) / 2)
# 应用旋转
lw = torch.einsum('bij,bkj->bik', lw, rot)
# 添加中心偏移
rect = xy[:, None, :] + lw
return rect
6. 关键设计亮点
6.1 多模态输入统一编码
编码器同时处理三类输入:
- 邻居车辆:使用 MLP-Mixer 处理时序轨迹
- 静态物体:简单的 MLP 投影
- 车道线:结合限速和红绿灯信息
通过位置嵌入区分不同类型:[1,0,0] 表示邻居,[0,1,0] 表示静态物体,[0,0,1] 表示车道线。
6.2 多智能体联合建模
DiT 解码器同时处理自车和邻居车辆:
- 输入维度
P = 1 + predicted_neighbor_num - 自注意力机制建模车辆间的交互
- 智能体类型嵌入区分自车和邻居
6.3 零初始化策略
解码器的 adaLN 调制层和输出层采用零初始化:
- 训练初期模型接近恒等映射
- 条件信息逐渐注入
- 保证训练稳定性
6.4 Collision Guidance
推理时使用 Classifier Guidance 实现碰撞避免:
- 基于分离轴定理计算碰撞距离
- 通过梯度引导轨迹远离碰撞
- 只在特定时间步应用(平衡探索与安全)
7. 配置参数总结
| 参数 | 默认值 | 说明 |
|---|---|---|
| agent_num | 25 | 最大邻居车辆数 |
| static_objects_num | 30 | 最大静态物体数 |
| lane_num | 70 | 最大车道线数 |
| time_len | 20 | 历史时间步数 |
| lane_len | 20 | 每条车道的点数 |
| hidden_dim | 192 | 隐藏层维度 |
| num_heads | 6 | 注意力头数 |
| encoder_depth | 3 | 编码器层数 |
| decoder_depth | 6 | 解码器层数 |
| encoder_drop_path_rate | 0.3 | 编码器 drop path 率 |
| decoder_drop_path_rate | 0.2 | 解码器 drop path 率 |
8. 核心内容总结
8.1 架构总览
┌─────────────────────────────────────────────────────────────────────────┐
│ 联合轨迹生成建模模块 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ Encoder │ │
│ │ │ │
│ │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │AgentFusion │ │StaticFusion │ │LaneFusion │ │ │
│ │ │ Encoder │ │ Encoder │ │ Encoder │ │ │
│ │ │ (邻居车辆) │ │ (静态物体) │ │ (车道线) │ │ │
│ │ │ MLP-Mixer │ │ MLP │ │ MLP-Mixer │ │ │
│ │ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │ │
│ │ │ │ │ │ │
│ │ └────────┬────────┴────────┬────────┘ │ │
│ │ ▼ ▼ │ │
│ │ ┌──────────────┐ │ │
│ │ │ FusionEncoder│ │ │
│ │ │ (自注意力融合)│ │ │
│ │ └──────┬───────┘ │ │
│ │ │ │ │
│ │ 输出: encoding [B, N, D] │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────┐ │
│ │ Diffusion_Planner │ │
│ │ │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ │ │
│ │ │ Diffusion_Planner│ │ Diffusion_Planner│ │ │
│ │ │ _Encoder │ │ _Decoder │ │ │
│ │ │ (编码器包装) │ │ (解码器包装) │ │ │
│ │ │ +权重初始化 │ │ +零初始化策略 │ │ │
│ │ └────────┬────────┘ └────────┬────────┘ │ │
│ │ │ │ │ │
│ │ │ encoding │ DiT + DPM-Solver │ │
│ │ └───────────┬───────────┘ │ │
│ │ ▼ │ │
│ │ ┌─────────────────┐ │ │
│ │ │ Collision │ │ │
│ │ │ Guidance │ │ │
│ │ │ (推理时可选) │ │ │
│ │ └─────────────────┘ │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 输出: 联合轨迹 [B, P, T, 4] │
│ │
└─────────────────────────────────────────────────────────────────────────┘
8.2 核心组件总结
| 组件 | 类型 | 职责 | 关键特点 |
|---|---|---|---|
| Encoder | nn.Module | 统一编码多模态输入 | 整合三类子编码器 |
| AgentFusionEncoder | nn.Module | 邻居车辆编码 | MLP-Mixer处理时序 |
| StaticFusionEncoder | nn.Module | 静态物体编码 | MLP投影 |
| LaneFusionEncoder | nn.Module | 车道线编码 | 含限速和红绿灯 |
| FusionEncoder | nn.Module | 多模态特征融合 | 自注意力机制 |
| Diffusion_Planner | nn.Module | 主模型入口 | 编码器+解码器 |
| Diffusion_Planner_Encoder | nn.Module | 编码器包装 | 权重初始化 |
| Diffusion_Planner_Decoder | nn.Module | 解码器包装 | 零初始化策略 |
| collision_guidance_fn | function | 碰撞避免引导 | Classifier Guidance |
| batch_signed_distance_rect | function | 碰撞距离计算 | 分离轴定理 |
8.3 关键数据流
编码器数据流
输入:
neighbors: [B, P, V, 9] # 邻居车辆历史轨迹
static: [B, S, D] # 静态物体
lanes: [B, L, V, D] # 车道线
处理:
neighbors → AgentFusionEncoder → [B, P, D], mask_p, pos
static → StaticFusionEncoder → [B, S, D], mask_p, pos
lanes → LaneFusionEncoder → [B, L, D], mask_p, pos
融合:
concat([neighbors, static, lanes]) → [B, N, D]
FusionEncoder(encoding, mask) → [B, N, D]
输出:
encoding: [B, N, D]
mask: [B, N]
解码器数据流
输入:
x_t: [B, P, T×4] # 加噪轨迹
t: [B] # 时间步
cross_c: [B, N, D] # 场景编码
route_lanes: [B, 25, 20, 12]
处理:
preproj → [B, P, D]
+ agent_embedding
+ TimestepEmbedder(t)
+ RouteEncoder(route_lanes)
DiT Blocks × N:
MHSA → MLP → MHCA → ...
输出:
x_start / score: [B, P, T, 4]
8.4 碰撞避免引导机制
推理时:
x_t → compute_bboxes() → ego_bbox, neighbor_bbox
│
▼
batch_signed_distance_rect() → distances
│
▼
clip_distances = max(1 - distances/CLIP_DISTANCE, 0)
│
▼
reward = -exp(mean(clip_distances))
│
▼
x_aux = grad(reward, x) # 计算梯度
│
▼
x_aux → coordinate_transform → time_smoothing → coordinate_transform_inv
│
▼
guidance = 3.0 * sum(x_aux * x[:, 0, :, :2])
8.5 关键设计决策
| 设计决策 | 技术实现 | 目的 |
|---|---|---|
| 多模态统一编码 | 三类子编码器+FusionEncoder | 整合邻居、静态物体、车道线信息 |
| MLP-Mixer架构 | AgentFusionEncoder、LaneFusionEncoder | 高效处理时序数据,无需注意力机制 |
| 动态掩码 | mask_p、encoding_mask | 处理可变数量输入实体 |
| 位置嵌入 | pos_emb + type标记 | 区分不同类型实体 |
| 零初始化 | adaLN和输出层初始化为0 | 训练稳定,条件信息逐渐注入 |
| Classifier Guidance | collision_guidance_fn | 推理时引导轨迹避免碰撞 |
| 分离轴定理 | batch_signed_distance_rect | 高效精确的碰撞检测 |
| 时间自适应引导 | guidance_scale随t变化 | 平衡探索与安全性 |
8.6 训练与推理流程
训练流程
1. 数据准备: neighbors, static, lanes, ego_future, neighbors_future
2. 编码器前向: encoder_outputs = encoder(neighbors, static, lanes)
3. 时间步采样: t ~ Uniform(eps, 1-eps)
4. 前向扩散: x_t = mean * x_0 + std * noise
5. 解码器前向: pred = decoder(x_t, t, encoder_outputs, route_lanes)
6. 损失计算: loss = diffusion_loss_func(pred, x_0, noise, t)
7. 反向传播: loss.backward()
推理流程
1. 编码器前向: encoder_outputs = encoder(neighbors, static, lanes)
2. 初始化: x_T = current_state + noise
3. DPM-Solver迭代 (10步):
for step in range(10):
t = T - step * dt
pred = decoder(x_t, t, encoder_outputs, route_lanes)
if guidance_fn:
guidance = guidance_fn(x_t, t, ...)
x_t = dpm_solver_step(x_t, t, pred, guidance)
x_t = correcting_xt_fn(x_t) # 约束当前状态
4. 输出: prediction = x_0[:, :, 1:]
8.7 模块协作关系
┌─────────────────────────────────────────────────────────────────────┐
│ 模块依赖关系 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ │
│ │ Dataset │───→ 原始数据 │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Encoder │ │
│ │ AgentFusion + StaticFusion + LaneFusion + FusionEncoder │ │
│ └──────┬──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Decoder │ │
│ │ DiT + RouteEncoder + StateNormalizer │ │
│ └──────┬──────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Guidance (推理时) │ │
│ │ collision_guidance_fn + SDF + Gradient │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
8.8 关键数学公式
分离轴定理(碰撞检测)
对于两个矩形 rect1 和 rect2:
1. 计算4个分离轴(两个矩形的四条边法向量)
2. 将两个矩形投影到每个轴上
3. 检查投影是否重叠
4. 如果所有轴都重叠 → 碰撞(返回负距离)
5. 如果至少一个轴不重叠 → 分离(返回最小正距离)
Classifier Guidance
x_{t-1} = x_t + dt * (drift + scale * grad(reward, x))
其中 reward = -exp(mean(clip_distances))
clip_distances = max(1 - distances/CLIP_DISTANCE, 0)
MLP-Mixer 结构
Block:
x = ChannelMLP(x) + x # 通道维度变换
x = TokenMLP(x) + x # Token维度变换
ChannelMLP: Linear → GELU → Linear
TokenMLP: Linear → GELU → Linear
8.9 性能与优化
| 优化策略 | 实现方式 | 效果 |
|---|---|---|
| 动态掩码 | 只处理有效实体 | 减少无效计算,支持可变输入 |
| MLP-Mixer | 替代Transformer | 降低复杂度,适合时序数据 |
| 零初始化 | 稳定训练初期 | 加速收敛,避免梯度爆炸 |
| DPM-Solver | 快速采样 | 10-20步完成推理 |
| 时间自适应引导 | 引导强度随t变化 | 平衡探索与安全 |
| 高斯卷积平滑 | 时间维度平滑 | 生成更平滑的轨迹 |
8.10 与其他模块的关系
┌─────────────────────────────────────────────────────────────────────┐
│ 联合轨迹模块与其他模块的交互 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Encoder │────▶│ DiT │────▶│ Loss │ │
│ │ (场景编码) │ │ (主干网络) │ │ (损失计算) │ │
│ └─────────────┘ └──────┬──────┘ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ SDE │ │
│ │ (扩散过程) │ │
│ └─────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Guidance │ │
│ │ (碰撞避免) │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
精炼核心技术总结
一句话定位
联合轨迹生成建模模块通过多模态编码器融合场景信息,驱动 DiT 生成自车与邻居车辆的联合轨迹。
核心技术点
- 多模态编码器:AgentFusion+StaticFusion+LaneFusion 分别处理不同输入类型
- MLP-Mixer:替代 Transformer,高效处理时序轨迹数据
- FusionEncoder:自注意力机制实现多模态特征交互
- 联合轨迹建模:同时预测自车和邻居车辆轨迹
- 掩码机制:支持可变数量的输入实体
关键公式速查
MLP-Mixer Block:
h′=ChannelMLP(h)+h,h′′=TokenMLP(h′)+h′ \text{h}' = \text{ChannelMLP}(\text{h}) + \text{h}, \quad \text{h}'' = \text{TokenMLP}(\text{h}') + \text{h}' h′=ChannelMLP(h)+h,h′′=TokenMLP(h′)+h′
FusionEncoder 融合:
encoding=SelfAttention(concat(agent,static,lane)) \text{encoding} = \text{SelfAttention}(\text{concat}(\text{agent}, \text{static}, \text{lane})) encoding=SelfAttention(concat(agent,static,lane))
模块间一句话关联
编码器将原始感知数据转换为统一场景表示,输入 DiT 进行扩散生成,推理时结合 Guidance 模块实现碰撞避免。