本文基于对论文《Latent-WAM: Latent World Action Modeling for End-to-End Autonomous Driving》的深入研读和讨论,系统性地梳理该框架的核心技术要点。文章按照从宏观到微观、从训练到推理的逻辑顺序组织,结合伪代码示例,力求呈现一个完整、清晰的技术全景。
目录
- 整体架构概览
- 空间感知压缩世界编码器 (SCWE)
- 几何蒸馏:让视觉模型理解三维空间
- 动态潜在世界模型 (DLWM)
- 3D-RoPE:时空位置编码的创新
- 轨迹规划模块
- EMA编码器:自监督学习的关键
- 训练与推理
- 核心技术决策总结
一、整体架构概览
1.1 设计哲学
Latent-WAM 的核心设计哲学是:在高度压缩的潜在空间中学习具有空间意识和动态感知能力的世界表征,从而实现高效、可扩展的轨迹规划。
与现有方法相比,它有三个关键创新:
- 空间感知压缩:用少量可学习的场景标记压缩多视角图像
- 几何知识蒸馏:从冻结的几何基础模型中注入3D空间理解能力
- 动态潜在建模:用因果Transformer自回归预测未来世界状态
1.2 三大核心模块
python
class LatentWAM(nn.Module):
def __init__(self):
# 模块1: 空间感知压缩世界编码器
self.scwe = SpatialAwareCompressiveWorldEncoder()
# 模块2: 动态潜在世界模型
self.dlwm = DynamicLatentWorldModel()
# 模块3: 轨迹解码器
self.traj_decoder = TrajectoryDecoder()
def forward(self, images, ego_status, command):
# 1. SCWE编码当前世界状态
scene_tokens = self.scwe(images) # (T, M×N, Dl)
# 2. 聚合世界状态(场景 + 自车)
world_state = self.aggregate(scene_tokens, ego_status)
# 3. 轨迹解码(推理时)
trajectories = self.traj_decoder(world_state[-1], command)
return trajectories
1.3 数据流维度全景
python
# 输入维度
images: (T=4, M=3, H=224, W=448, C=3) # 4帧 × 3视角
ego_status: (T=4, 速度=2, 加速度=2, 命令=4) # 自车状态
# SCWE输出
scene_tokens: (T=4, M=3, N=16, Dl=256) # 48个scene tokens per frame
# 世界状态聚合(每帧)
world_state_per_frame: (M×N + 1, Dl) = (49, 256)
# 完整序列
world_state_seq: (T × 49, 256) = (196, 256)
二、空间感知压缩世界编码器 (SCWE)
2.1 核心思想
SCWE 解决的问题是:如何将多视角、高分辨率的图像序列压缩成少量、高效的标记,同时保留空间几何信息。
解决方案:
- 用可学习的场景查询与图像块标记交互,实现信息压缩
- 通过几何蒸馏向视觉主干注入3D空间理解能力
2.2 场景压缩的数学形式
Q ^ scene , X ^ = E ( Q scene ; X ) \hat{Q}_{\text{scene}}, \hat{X} = \mathcal{E}(Q_{\\text{scene}}; X) Q^scene,X^=E(Qscene;X)
其中:
X ∈ R^(T×M×S×De):图像块标记,S是每张图的块数Q_scene ∈ R^(T×M×N×De):可学习的场景查询,N=16ˆQ_scene:压缩后的场景标记,作为世界状态的核心
2.3 伪代码实现
python
class SpatialAwareCompressiveWorldEncoder(nn.Module):
def __init__(self, T=4, M=3, N=16, S=2500, De=768, Dl=256):
super().__init__()
# DINOv2作为视觉主干
self.vision_encoder = DINOv2_Base() # 86.6M参数
# 可学习的场景查询
self.scene_queries = nn.Parameter(
torch.randn(T, M, N, De)
)
# 投影层
self.projection = nn.Linear(De, Dl)
# 几何蒸馏的投影头
self.geo_projector = nn.Linear(De, 2048)
def forward(self, images):
# images: (T, M, 3, H, W)
T, M, C, H, W = images.shape
# 1. 编码图像 → patch tokens
# 输出: (T*M, S, De)
patch_tokens = self.vision_encoder(images.view(T*M, C, H, W))
# 2. 重塑为时序格式
patch_tokens = patch_tokens.view(T, M, S, -1) # (T, M, S, De)
# 3. 扩展scene queries到batch
queries = self.scene_queries.expand(batch_size, -1, -1, -1, -1)
# 4. 交叉注意力:queries从patch tokens中提取信息
# 简化的DINO编码器交互
scene_tokens = self.cross_attention(queries, patch_tokens)
# scene_tokens: (T, M, N, De)
# 5. 投影到潜在空间
scene_tokens = self.projection(scene_tokens) # (T, M, N, Dl)
return scene_tokens, patch_tokens
三、几何蒸馏:让视觉模型理解三维空间
3.1 为什么需要几何蒸馏?
纯视觉模型(如DINOv2)虽然能识别物体,但缺乏对三维空间结构的理解。例如:
- 不知道一条白线是"车道线"还是"停车位边界"
- 不理解近处的车和远处的车在深度上的差异
几何蒸馏通过让视觉模型模仿一个冻结的3D基础模型(WorldMirror/VGGT)的输出,来注入空间理解能力。
3.2 损失函数
L align = 1 − cos ( LN ( ϕ ( X ^ ) ) , LN ( f g ( I ) ) ) \mathcal{L}_{\text{align}} = 1 - \cos(\text{LN}(\phi(\hat{X})), \text{LN}(f_g(I))) Lalign=1−cos(LN(ϕ(X^)),LN(fg(I)))
其中:
ϕ(ˆX):学生模型(DINO)的图像块特征f_g(I):教师模型(WorldMirror)的几何特征LN:LayerNorm归一化cos:余弦相似度
3.3 几何特征的真值是什么?
VGGT/WorldMirror 对每张图像输出一个密集特征图:
python
# 教师模型输出
teacher_features = worldmirror(images) # (T, M, S, 2048)
# 每个patch位置都有一个2048维的几何特征向量
# 这些特征蕴含了该点的3D位置、深度线索、多视角对应关系
3.4 蒸馏的训练策略
python
class GeometricDistillation:
def __init__(self):
# 教师模型:冻结,特征可预先计算
self.teacher = WorldMirror().eval()
self.student = DINOv2() # 可训练
# 预先计算所有训练数据的几何特征
self.cached_teacher_features = precompute_features()
def forward(self, images):
# 学生特征
student_patches = self.student(images) # (T*M, S, 768)
student_features = self.geo_projector(student_patches) # (T*M, S, 2048)
# 教师特征(从缓存加载,无需前向)
teacher_features = self.cached_teacher_features[image_ids]
# 计算蒸馏损失
loss_align = 1 - cosine_similarity(
layer_norm(student_features),
layer_norm(teacher_features)
).mean()
return loss_align
3.5 关键设计决策
| 设计选择 | 原因 |
|---|---|
| 教师模型冻结 | 几何特征是"真理",不应随训练改变 |
| 特征预先缓存 | 避免每次训练都跑大模型,节省GPU内存 |
| 使用余弦相似度 | 关注方向而非幅度,避免特征数值崩溃 |
| LayerNorm预处理 | 消除量纲影响,稳定训练 |
四、动态潜在世界模型 (DLWM)
4.1 问题定义
DLWM 的任务是:基于历史的世界状态,预测未来的世界状态。
与视频生成方法不同,DLWM 预测的是潜在空间中的特征,而非像素。这大大降低了计算开销,同时保留了规划所需的关键信息。
4.2 世界状态的聚合
python
def aggregate_world_state(scene_tokens, ego_status):
# scene_tokens: (T, M, N, Dl) = (4, 3, 16, 256)
# ego_status: (T, Dl) = (4, 256)
T, M, N, Dl = scene_tokens.shape
# 1. 展平场景tokens
scene_flat = scene_tokens.view(T, M*N, Dl) # (4, 48, 256)
# 2. 扩展自车状态到每帧
ego_expanded = ego_status.unsqueeze(1) # (4, 1, 256)
# 3. 拼接得到完整世界状态
world_state = torch.cat([scene_flat, ego_expanded], dim=1)
# world_state: (4, 49, 256)
return world_state
4.3 因果预测的数学形式
S future = DLWM ( Q future , K V future ) S_{\text{future}} = \text{DLWM}(Q_{\text{future}}, KV_{\text{future}}) Sfuture=DLWM(Qfuture,KVfuture)
其中:
KV_future:历史世界状态(作为key-value缓存)Q_future:可学习的未来状态查询S_future:预测的未来世界状态
4.4 教师强制注意力掩码
python
def build_teacher_forcing_mask(T=4, tokens_per_frame=49):
"""
构建因果注意力掩码
- 帧内:双向注意力(全1)
- 帧间:只能看到过去(下三角)
"""
seq_len = T * tokens_per_frame
mask = torch.zeros(seq_len, seq_len)
for frame_idx in range(T):
start = frame_idx * tokens_per_frame
end = (frame_idx + 1) * tokens_per_frame
# 帧内全连接
mask[start:end, start:end] = 1
# 帧间:只能看到历史帧
if frame_idx > 0:
mask[start:end, :start] = 1
return mask
4.5 自车状态监督
除了预测场景标记,DLWM 还同时预测未来的自车状态:
python
# 从预测的世界状态中提取自车状态
predicted_ego = future_world_state[:, -1, :] # 最后一个token是自车
# 多任务预测头
cmd_pred = self.cmd_head(predicted_ego) # 驾驶命令,4分类
vel_pred = self.vel_head(predicted_ego) # 速度,2维
acc_pred = self.acc_head(predicted_ego) # 加速度,2维
# 损失计算
L_cmd = CrossEntropy(cmd_pred, cmd_gt)
L_vel = MSE(vel_pred, vel_gt)
L_acc = MSE(acc_pred, acc_gt)
L_ego = L_cmd + L_vel + L_acc
五、3D-RoPE:时空位置编码的创新
5.1 问题背景
DLWM 处理的序列中,每个 token 都有三个维度的坐标:
- 时间 (t):0, 1, 2, 3(T=4帧)
- 相机 (m):0, 1, 2(左、前、右)
- 标记内索引 (n):0~15(每个视角16个scene tokens)
标准的 RoPE 只能编码一维位置,无法区分"同一时间不同相机"和"不同时间同一相机"的 token。
5.2 解决方案:拆分注意力头维度
注意力头维度 Dh = 256
├── 第1部分 (d1=85):编码时间坐标 t,基频=50
├── 第2部分 (d2=85):编码相机索引 m,基频=10
└── 第3部分 (d3=86):编码token索引 n,基频=100
5.3 伪代码实现
python
class Rotation3D:
def __init__(self, dim=256, base_t=50, base_m=10, base_n=100):
self.dim = dim
# 分配各维度占比
self.dims = [dim//3, dim//3, dim - 2*(dim//3)] # [85, 85, 86]
self.bases = [base_t, base_m, base_n]
# 预计算各维度的频率
self.freqs = []
for d_dim, base in zip(self.dims, self.bases):
freq = base ** (-2 * torch.arange(d_dim//2) / d_dim)
self.freqs.append(freq)
def get_rotation_matrix(self, t, m, n):
"""为单个token生成旋转矩阵"""
angles = []
for coord, d_dim, freq in zip([t, m, n], self.dims, self.freqs):
# 每个坐标贡献 d_dim/2 个旋转角
angle = coord * freq # (d_dim//2,)
angles.append(angle)
all_angles = torch.cat(angles) # (128,)
# 构建旋转矩阵(块对角)
R = self.build_block_diagonal_rotation(all_angles)
return R
def apply(self, tokens, coords):
"""
tokens: (seq_len, dim)
coords: (seq_len, 3) 每个token的(t, m, n)
"""
for i, (t, m, n) in enumerate(coords):
R = self.get_rotation_matrix(t, m, n)
tokens[i] = R @ tokens[i] # 只旋转Q和K,V不变
return tokens
# 使用示例
rope_3d = Rotation3D(dim=256)
# 对Q和K施加旋转(V不变)
Q_rotated = rope_3d.apply(Q, q_coords)
K_rotated = rope_3d.apply(K, k_coords)
# 计算注意力
attn = softmax(Q_rotated @ K_rotated.T / sqrt(dim))
output = attn @ V # V不旋转
5.4 RoPE 的关键性质
相对位置编码能力:旋转后的点积只依赖于坐标差
R ( t 1 ) R ( t 2 ) T = R ( t 1 − t 2 ) R(t_1)R(t_2)^T = R(t_1 - t_2) R(t1)R(t2)T=R(t1−t2)
这意味着 Q_rotated · K_rotated 的值只取决于 t1-t2、m1-m2、n1-n2,而不是绝对坐标值。
外推能力:训练时只见过 T=4,推理时可处理 T=8,因为旋转函数是连续的。
六、轨迹规划模块
6.1 架构设计
轨迹解码器是一个轻量级 Transformer,输入当前世界状态和驾驶命令,输出多条候选轨迹。
python
class TrajectoryDecoder(nn.Module):
def __init__(self, K=6, np=40, dim=256, num_heads=8, num_layers=4):
super().__init__()
self.K = K # 候选轨迹数
self.np = np # 每个轨迹的点数
# 可学习的轨迹查询
self.traj_queries = nn.Parameter(torch.randn(K, np, dim))
# 驾驶命令编码器
self.cmd_encoder = nn.Linear(4, dim)
# Transformer解码器
decoder_layer = nn.TransformerDecoderLayer(
d_model=dim, nhead=num_heads, dim_feedforward=1024
)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
# 输出头:解码为(x, y, θ)
self.output_head = nn.Sequential(
nn.Linear(dim, 128),
nn.ReLU(),
nn.Linear(128, 3)
)
def forward(self, world_state, command_one_hot):
"""
world_state: (49, 256) - 当前帧的世界状态
command: (4,) - one-hot驾驶命令
"""
# 1. 编码命令并注入查询
cmd_embed = self.cmd_encoder(command_one_hot) # (256,)
queries = self.traj_queries + cmd_embed # (K, np, 256)
# 2. 重塑为Transformer输入
queries = queries.view(self.K * self.np, 256) # (K*np, 256)
# 3. 交叉注意力(world_state作为memory)
traj_features = self.decoder(
tgt=queries.unsqueeze(1), # (K*np, 1, 256)
memory=world_state.unsqueeze(0) # (1, 49, 256)
)
# 4. 解码为轨迹点
trajectories = self.output_head(traj_features) # (K*np, 3)
trajectories = trajectories.view(self.K, self.np, 3)
# 5. 根据命令选择最终轨迹
cmd_to_idx = {'left':0, 'right':1, 'straight':2, 'stop':3}
idx = cmd_to_idx[command]
final_trajectory = trajectories[idx]
return final_trajectory
6.2 轨迹的数学表示
每个轨迹点表示为 (x, y, θ):
x:纵向位移(前进方向)y:横向位移(左右)θ:航向角
所有量都在自车局部坐标系中,以当前时刻 t 为原点。
6.3 训练损失
L traj = L1_loss ( τ pred , τ expert ) \mathcal{L}{\text{traj}} = \text{L1\loss}(\tau{\text{pred}}, \tau{\text{expert}}) Ltraj=L1_loss(τpred,τexpert)
使用 L1 损失而非 L2,因为 L1 对异常值更鲁棒。
6.4 推理性能
| 模块 | 参数量 | 延迟 | 内存 |
|---|---|---|---|
| World Encoder | 86.6M | 100ms | - |
| Trajectory Decoder | 8.4M | 6ms | - |
| 总计 | 104M | 107ms | 1.1GB |
七、EMA编码器:自监督学习的关键
7.1 为什么需要EMA?
在自监督学习中,如果让 DLWM 直接预测 Online 编码器的输出,模型可能会"作弊"------学到恒等映射或平凡解。
EMA(指数移动平均)编码器通过提供一个缓慢变化的、平滑的目标来解决这个问题。
7.2 EMA更新机制
python
class EMAEncoder:
def __init__(self, online_encoder, momentum=0.999):
self.online = online_encoder
self.target = copy.deepcopy(online_encoder)
self.momentum = momentum
# 冻结target编码器
for param in self.target.parameters():
param.requires_grad = False
def update(self):
"""每个训练step后调用"""
with torch.no_grad():
for param_t, param_o in zip(self.target.parameters(),
self.online.parameters()):
# EMA更新:缓慢追随online编码器
param_t.data = (self.momentum * param_t.data +
(1 - self.momentum) * param_o.data)
@torch.no_grad()
def encode(self, images):
"""提供预测的真值"""
return self.target(images)
7.3 训练中的数据流
python
# 训练一个样本
def training_step(images_t0, images_t1, images_t2, images_t3, ego_status):
# 1. Online编码器处理所有帧
scene_online = online_scwe(images) # 可训练
# 2. EMA编码器处理所有帧(提供真值)
with torch.no_grad():
scene_target = ema_scwe(images) # 冻结,EMA更新
# 3. DLWM基于前3帧预测第4帧
history = scene_online[:3] # t0, t1, t2
future_pred = dlwm(history) # 预测t3
# 4. 用EMA输出作为真值计算损失
loss_wm = MSE(future_pred, scene_target[3])
# 5. 反向传播更新online编码器和DLWM
loss_wm.backward()
optimizer.step()
# 6. 更新EMA编码器
ema_scwe.update()
7.4 EMA的收敛性质
当 Online 模型收敛到最优解 θ* 时:
θ EMA = ∑ k = 0 ∞ ( 1 − m ) m k θ online ( t − k ) → θ ∗ \theta_{\text{EMA}} = \sum_{k=0}^{\infty} (1-m) m^k \theta_{\text{online}}^{(t-k)} \rightarrow \theta^* θEMA=k=0∑∞(1−m)mkθonline(t−k)→θ∗
即 EMA 模型也会收敛到相同的最优解,且具有更小的方差。
八、训练与推理
8.1 总损失函数
L = L traj + α L align + β L wm + γ L ego \mathcal{L} = \mathcal{L}{\text{traj}} + \alpha \mathcal{L}{\text{align}} + \beta \mathcal{L}{\text{wm}} + \gamma \mathcal{L}{\text{ego}} L=Ltraj+αLalign+βLwm+γLego
| 损失项 | 权重 | 作用 |
|---|---|---|
L_traj |
1.0 | 轨迹模仿学习 |
L_align |
0.1 | 几何知识蒸馏 |
L_wm |
0.2 | 世界模型自监督预测 |
L_ego |
0.1 | 自车状态监督 |
8.2 训练配置
python
# 硬件
gpus = 32 × A100
training_days = 2
# 数据
batch_size = 512
T = 4 # 时序帧数
M = 3 # 摄像头数
N = 16 # 每视角scene token数
Dl = 256 # 潜在维度
# 优化
optimizer = AdamW(lr=2e-4, weight_decay=0.05)
scheduler = CosineAnnealingLR(T_max=100, eta_min=1e-6)
warmup_steps = 0.1 * total_steps
# 精度
use_bf16 = True
8.3 推理流程
python
@torch.no_grad()
def inference(images, ego_status, command):
# 1. SCWE编码当前帧
scene_tokens = scwe(images[-1:]) # 只处理当前帧
# 2. 聚合世界状态
world_state = aggregate(scene_tokens, ego_status[-1:])
# 3. 轨迹解码
trajectory = traj_decoder(world_state, command)
return trajectory # (np, 3)
# 注意:推理时不需要:
# - EMA编码器
# - DLWM(未来预测)
# - 几何教师模型
# - 任何辅助模块
8.4 训练 vs 推理模块对比
| 模块 | 训练时需要 | 推理时需要 | 参数量 |
|---|---|---|---|
| SCWE (online) | ✅ | ✅ | 86.6M |
| SCWE (EMA) | ✅ | ❌ | 86.6M |
| 几何教师模型 | ✅ (预计算) | ❌ | - |
| DLWM | ✅ | ❌ | ~10M |
| 轨迹解码器 | ✅ | ✅ | 8.4M |
| 总计 | 191M | 104M | - |
九、核心技术决策总结
9.1 为什么不用BEV?
| BEV方法的问题 | Latent-WAM的解决方案 |
|---|---|
| 视角转换有畸变和精度损失 | 纯PV操作,保持原始视角 |
| 需要构建BEV网格,计算量大 | 直接在token上操作 |
| 投影过程丢失高度信息 | 3D-RoPE保留完整时空信息 |
| 强依赖精确的外参标定 | 通过位置编码学习视角关系 |
9.2 为什么用蒸馏而非直接拼接几何特征?
| 方法 | EPDMS | 问题 |
|---|---|---|
| 无几何信息 | 88.3 | - |
| 拼接冻结几何特征 | 88.0 | 特征与规划目标不对齐 |
| 几何蒸馏 | 89.3 | 特征内化到主干,推理无开销 |
9.3 为什么需要EMA而不是直接预测Online特征?
| 目标 | 问题 |
|---|---|
| 预测Online特征 | 模型可能学到恒等映射,坍塌到平凡解 |
| 预测固定随机特征 | 目标无意义,无法收敛 |
| 预测EMA特征 | 提供稳定、平滑、有意义的自监督目标 |
9.4 场景查询如何学到有效表征?
没有直接监督,而是通过任务驱动:
- 轨迹规划任务(
L_traj)要求它们包含规划所需信息 - 世界模型任务(
L_wm)要求它们能够预测未来 - 自车状态监督(
L_ego)要求它们与环境动态交互
这种"通过做事来学习"的设计,让模型自动发现对规划最重要的表征。
9.5 关键超参数
| 参数 | 值 | 说明 |
|---|---|---|
| T | 4 | 时序帧数 |
| M | 3 | 摄像头数(左、前、右) |
| N | 16 | 每视角scene token数 |
| Dl | 256 | 潜在特征维度 |
| 3D-RoPE基频 | (50, 10, 100) | (时间, 相机, token索引) |
| EMA动量 | 0.999 | 目标网络更新速率 |
| 损失权重 | (0.1, 0.2, 0.1) | (align, wm, ego) |
十、总结
Latent-WAM 的核心创新可以概括为:
- 空间感知压缩:用16个可学习token表示整个场景,通过几何蒸馏注入3D理解
- 动态潜在建模:用因果Transformer自回归预测未来世界状态
- 3D-RoPE:统一的时空位置编码,让模型理解token的三维坐标
- EMA自监督:提供稳定的预测目标,防止模型坍塌
- 任务驱动学习:场景查询没有直接监督,但通过规划任务学会有效表征
这些设计共同造就了 Latent-WAM 在数据效率、推理速度、规划质量三方面的优势:用更少的训练数据、更小的模型(104M参数),在 NAVSIM v2 和 HUGSIM 上达到了新的最先进水平。
参考文献
- Wang et al. Latent-WAM: Latent World Action Modeling for End-to-End Autonomous Driving. arXiv:2603.24581, 2026.
- Su et al. RoFormer: Enhanced Transformer with Rotary Position Embedding. 2021.
- Caron et al. Emerging Properties in Self-Supervised Vision Transformers. ICCV 2021.
- Grill et al. Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning. NeurIPS 2020.