AlpamayoR1 通过巧妙地结合VLM的推理能力和扩散模型的生成能力和KV Cache的优化技术,实现了高效、高质量的轨迹预测。
📊 模型架构总结
1️⃣ ReasoningVLA (基础模型) - base_model.py
核心组件:
- VLM Backbone: 使用 Qwen3-VL-8B-Instruct 作为视觉语言模型
- 轨迹Tokenizer : 将连续轨迹编码为离散token(
<i0>-<i767>) - 轨迹融合Mixin: 将轨迹token融合到输入序列中
特殊Token系统:
- 轨迹相关:
<|traj_history|>,<|traj_future|>,<|traj_history_start|>,<|traj_future_start|> - 推理相关:
<|cot_start|>,<|cot_end|>(Chain of Thought) - 动作相关:
<|meta_action_start|>,<|meta_action_end|> - 路由相关:
<|route_start|>,<|route_pad|>,<|route_end|>
2️⃣ AlpamayoR1 (专家模型) - alpamayo_r1.py
继承自 ReasoningVLA,添加了动作生成能力:
核心模块:
- Expert模型: 基于VLM文本配置,用于动作空间推理
- 动作空间: 支持多种动作表示(如unicycle_accel_curvature)
- 扩散模型: 使用Flow Matching等方法进行动作生成
- 动作投影层 :
action_in_proj(动作→embedding) 和action_out_proj(embedding→动作)
3️⃣ 三阶段推理流程 - sample_trajectories_from_data_with_vlm_rollout()
阶段1: VLM自回归生成
- 输入:图像 + 文本prompt + 历史轨迹token
- 输出:CoT推理文本 + 未来轨迹token
- 使用
ExpertLogitsProcessor屏蔽轨迹token的logits
阶段2: Expert去噪
- 定义
step_fn(x, t): 噪声动作 → token embedding → expert推理 → 预测噪声 - 复用VLM的KV cache,避免重复计算
阶段3: 扩散采样
diffusion.sample()生成多个轨迹样本- 通过
action_to_traj()解码为xyz和rot
4️⃣ 关键技术特性
- 轨迹Token化: 历史轨迹16个token,未来轨迹64个token
- KV Cache优化 : VLM生成后缓存,Expert复用,
prompt_cache.crop()保持长度 - 位置编码处理 :
rope_deltas处理多模态输入的位置偏移 - Logits处理 :
ExpertLogitsProcessor屏蔽轨迹token,StopAfterEOS在<|traj_future_start|>后停止
5️⃣ 数据流
输入数据 (图像+文本+历史轨迹)
↓
[轨迹融合] fuse_traj_tokens()
↓
VLM生成 (CoT文本+未来轨迹token)
↓
[KV Cache] 缓存prompt的KV
↓
[扩散采样] 迭代去噪 (action_in_proj → expert → action_out_proj)
↓
[动作解码] action_to_traj()
↓
输出: pred_xyz, pred_rot
6️⃣ 模型优势
- 模块化设计: VLM和Expert分离,职责清晰
- 高效推理: KV cache复用,避免重复计算
- 多模态融合: 图像、文本、轨迹统一处理
- 灵活采样: 支持多轨迹采样,提高鲁棒性
- 可扩展性: 支持不同的动作空间和扩散方法
该架构巧妙地结合了VLM的推理能力和扩散模型的生成能力,实现了高质量的轨迹预测。
📊 ReasoningVLA 模型深度解析
一、模型架构概览
1.1 继承关系
PreTrainedModel (HuggingFace)
↓
ReasoningVLA (base_model.py:285)
↓
AlpamayoR1 (alpamayo_r1.py:73)
1.2 核心设计理念
- VLM作为推理引擎: 使用Qwen3-VL进行视觉理解和推理
- Expert作为动作生成器: 专门负责动作空间的去噪和生成
- 扩散模型: 通过迭代去噪生成高质量轨迹
- KV Cache优化: 复用VLM的注意力缓存,避免重复计算
二、ReasoningVLA 基础模型详解
2.1 配置类 ReasoningVLAConfig
python
核心参数:
├── vlm_name_or_path: "Qwen/Qwen3-VL-8B-Instruct"
├── vlm_backend: "qwenvl3"
├── traj_vocab_size: 768 # 轨迹离散token数量
├── tokens_per_history_traj: 16 # 历史轨迹token数
├── tokens_per_future_traj: 64 # 未来轨迹token数
├── model_dtype: "bfloat16"
├── attn_implementation: "flash_attention_2"
├── min_pixels/max_pixels: 图像分辨率控制
└── add_special_tokens: 是否添加特殊token
2.2 特殊Token系统
轨迹相关Token (TRAJ_TOKEN)
python
{
"history": "<|traj_history|>",
"future": "<|traj_future|>",
"history_start": "<|traj_history_start|>",
"future_start": "<|traj_future_start|>",
"history_end": "<|traj_history_end|>",
"future_end": "<|traj_future_end|>"
}
完整特殊Token列表 (SPECIAL_TOKENS)
- 推理:
cot_start,cot_end - 动作:
meta_action_start,meta_action_end - 路由:
route_start,route_pad,route_end - 问答:
question_start,question_end,answer_start,answer_end - 向量化水印:
vectorized_wm_start,vectorized_wm_end,vectorized_wm_pre_tkn
2.3 VLM Backbone初始化
Qwen3-VL配置 (_initialize_qwenvl3_vlm)
python
1. 加载Qwen3VLConfig
2. 保存原始vocab_size
3. 扩展vocab_size以容纳轨迹token
4. 创建Qwen3VLForConditionalGeneration实例
关键点:
- 使用Flash Attention 2加速
- 支持多模态输入(图像+文本)
- 动态调整词汇表大小
2.4 轨迹Tokenizer系统
DeltaTrajectoryTokenizer (delta_tokenizer.py:21)
编码过程 (encode()):
python
输入: fut_xyz [B, Tf, 3], fut_rot [B, Tf, 3, 3]
↓
1. 计算delta: fut_xyz[:, 1:] - fut_xyz[:, :-1]
↓
2. 归一化到[0, 1]: (delta - min) / (max - min)
↓
3. 量化到离散bin: round(delta * (num_bins - 1))
↓
4. 展平: [B, Tf, 3] → [B, Tf*3]
↓
输出: tokens [B, num_tokens_per_trajectory]
解码过程 (decode()):
python
输入: tokens [B, num_tokens_per_trajectory]
↓
1. 重塑: [B, num_tokens] → [B, Tf, 3]
↓
2. 反归一化: tokens / (num_bins - 1) * (max - min) + min
↓
3. 累加恢复: cumsum(delta) → fut_xyz
↓
4. 计算旋转矩阵: 使用多项式拟合计算yaw角
↓
输出: fut_xyz [B, Tf, 3], fut_rot [B, Tf, 3, 3]
Yaw角计算 (get_yaw_rotation_matrices()):
python
1. 滑动窗口 (window_size=10)
2. 多项式拟合 x(t) 和 y(t) (poly_order=3)
3. 计算导数 dx/dt, dy/dt
4. yaw = atan2(dy, dx)
5. 构造旋转矩阵:
[[cos(yaw), -sin(yaw), 0],
[sin(yaw), cos(yaw), 0],
[0, 0, 1]]
2.5 轨迹融合Mixin
TrajectoryFusionMixin (base_model.py:125)
核心方法 fuse_traj_tokens():
python
输入: input_ids [B, n_token], traj_data
↓
1. 验证必需属性 (hist_traj_tokenizer, hist_token_start_idx, config)
↓
2. 编码历史轨迹: tokenize_history_trajectory()
↓
3. 替换pad token: replace_pad_token()
↓
输出: input_ids [B, n_token] (轨迹token已融合)
历史轨迹编码 (tokenize_history_trajectory()):
python
输入: traj_data [B, n_traj, T, 3]
↓
1. 展平: [B, n_traj, T, 3] → [B*n_traj, T, 3]
↓
2. 编码: tokenizer.encode(hist_xyz, hist_rot, fut_xyz, fut_rot)
↓
3. 添加偏移: + start_idx
↓
4. 重塑: [B*n_traj, tokens] → [B, n_traj*tokens]
↓
输出: hist_idx [B, n_traj * tokens_per_history_traj]
三、AlpamayoR1 专家模型详解
📦 核心组件详解
1️⃣ 配置类 - AlpamayoR1Config
继承自 ReasoningVLAConfig,添加了以下配置:
python
diffusion_cfg: dict # 扩散模型配置
action_space_cfg: dict # 动作空间配置
action_in_proj_cfg: dict # 动作输入投影配置
action_out_proj_cfg: dict # 动作输出投影配置
expert_cfg: dict # Expert模型配置
keep_same_dtype: bool = True # 保持相同数据类型
expert_non_causal_attention: bool = True # Expert使用非因果注意力
2️⃣ Expert模型 - self.expert
初始化过程:
python
# 1. 深拷贝VLM的文本配置
expert_config = copy.deepcopy(self.vlm.config.text_config)
# 2. 应用expert特定配置
if config.expert_cfg is not None:
for key, value in config.expert_cfg.items():
setattr(expert_config, key, value)
# 3. 从配置创建Expert模型
self.expert = AutoModel.from_config(expert_config)
# 4. 删除embed_tokens(共享VLM的embedding)
del self.expert.embed_tokens
关键特性:
- 共享VLM的embedding层,减少参数量
- 专注于动作空间的推理
- 支持非因果注意力(
expert_non_causal_attention)
3️⃣ 动作空间 - ActionSpace
抽象基类,定义了动作空间的接口:
核心方法:
-
traj_to_action(): 将轨迹转换为动作- 输入:历史轨迹 + 未来轨迹
- 输出:动作表示
-
action_to_traj(): 将动作转换为轨迹- 输入:动作 + 历史轨迹
- 输出:未来轨迹(xyz + rot)
-
get_action_space_dims(): 获取动作空间维度 -
is_within_bounds(): 检查动作是否在有效范围内
实现示例 :unicycle_accel_curvature 动作空间
- 动作维度:
(T, 3)- 加速度、曲率、速度 - 支持车辆运动学约束
4️⃣ 扩散模型 - BaseDiffusion
抽象基类,定义了扩散模型的接口:
核心方法:
sample(): 从扩散模型采样- 参数:
batch_size: 批次大小step_fn: 去噪步骤函数device: 设备return_all_steps: 是否返回所有步骤
- 返回:采样结果
[B, *x_dims]
- 参数:
StepFn协议 (StepFn):
python
def step_fn(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
x: 噪声输入 [B, *action_dims]
t: 时间步
返回: 去噪后的输出或预测的噪声/向量场
"""
实现示例:Flow Matching
- 使用连续时间扩散
- 预测向量场而非噪声
- 更快的采样速度
5️⃣ 动作投影层
5.1 动作输入投影 - PerWaypointActionInProjV2
架构:
动作输入 [B, T, action_dim]
↓
Fourier编码(每个动作维度独立)
↓
时间步Fourier编码
↓
拼接 [B, T, num_fourier_feats * action_dim + num_fourier_feats]
↓
MLP编码器(多层 + RMSNorm + SiLU)
↓
LayerNorm
↓
输出 [B, T, hidden_size]
关键组件:
-
FourierEncoderV2 (
FourierEncoderV2)- 使用对数间隔的频率
- 输出:
[sin(2πxf), cos(2πxf)] * √2 - 更好的高频信息表示
-
MLPEncoder (
MLPEncoder)- 多层MLP:
Linear → SiLU → (RMSNorm → Linear → SiLU)×N → RMSNorm → Linear - 默认4层,隐藏维度1024
- 多层MLP:
-
RMSNorm (
RMSNorm)- 均方根归一化
- 更稳定的训练
参数:
python
in_dims: list[int] # 输入维度,最后一个元素是动作维度
out_dim: int # 输出维度(expert的hidden_size)
num_enc_layers: int = 4 # MLP层数
hidden_size: int = 1024 # MLP隐藏维度
max_freq: float = 100.0 # Fourier最大频率
num_fourier_feats: int = 20 # Fourier特征数
5.2 动作输出投影 - action_out_proj
简单的线性投影:
python
nn.Linear(hidden_size, action_dim)
- 输入:
[B, T, hidden_size] - 输出:
[B, T, action_dim]
6️⃣ Logits处理器 - ExpertLogitsProcessor
作用:在VLM生成时屏蔽轨迹token的logits,专注于CoT生成
实现:
python
def __call__(self, input_ids, scores):
# 将轨迹token位置的logits设为-inf
scores[:, traj_token_offset:traj_token_offset + traj_vocab_size] = float('-inf')
return scores
效果:
- VLM不会生成轨迹token
- 专注于生成推理文本(CoT)
- 轨迹由扩散模型生成
7️⃣ 停止条件 - StopAfterEOS
作用:在遇到EOS token后继续生成一个token,然后停止
实现逻辑:
python
1. 初始化 eos_found = [False, False, ...]
2. 检查最后一个token是否为EOS
3. 如果是,标记该序列为已找到EOS
4. 当所有序列都找到EOS时,返回True停止生成
为什么需要多生成一个token?
- KV cache在生成下一个token后更新
- 需要确保cache包含EOS token的KV
四、推理流程深度解析
主函数 :sample_trajectories_from_data_with_vlm_rollout()
阶段1:数据准备 (第152-163行)
python
# 1. 提取历史轨迹
ego_history_xyz = data["ego_history_xyz"] # [B, n_traj, T, 3]
ego_history_rot = data["ego_history_rot"] # [B, n_traj, T, 3]
# 2. 获取tokenized数据
tokenized_data = data["tokenized_data"]
input_ids = tokenized_data.pop("input_ids")
# 3. 准备轨迹数据用于VLM
traj_data_vlm = {
"ego_history_xyz": ego_history_xyz,
"ego_history_rot": ego_history_rot,
}
# 4. 融合轨迹token到输入序列
input_ids = self.fuse_traj_tokens(input_ids, traj_data_vlm)
轨迹融合过程:
- 使用
hist_traj_tokenizer编码历史轨迹 - 将编码后的token替换输入序列中的
<|traj_history|>占位符 - 输出:包含历史轨迹token的完整输入序列
阶段2:VLM自回归生成 (第165-207行)
python
# 1. 配置生成参数
generation_config.top_p = 0.98
generation_config.temperature = 0.6
generation_config.do_sample = True
generation_config.num_return_sequences = num_traj_samples # 多采样
generation_config.max_new_tokens = max_generation_length
generation_config.output_logits = True
generation_config.return_dict_in_generate = True
# 2. 设置停止条件
eos_token_id = tokenizer.convert_tokens_to_ids("<|traj_future_start|>")
stopping_criteria = StoppingCriteriaList([StopAfterEOS(eos_token_id)])
# 3. 设置logits处理器(屏蔽轨迹token)
logits_processor = LogitsProcessorList([
ExpertLogitsProcessor(
traj_token_offset=config.traj_token_start_idx,
traj_vocab_size=config.traj_vocab_size,
)
])
# 4. 执行生成
vlm_outputs = self.vlm.generate(
input_ids=input_ids,
generation_config=generation_config,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
**tokenized_data,
)
# 5. 保存rope_deltas(位置编码偏移)
vlm_outputs.rope_deltas = self.vlm.model.rope_deltas
# 6. 替换EOS后的padding
vlm_outputs.sequences = replace_padding_after_eos(
token_ids=vlm_outputs.sequences,
eos_token_id=eos_token_id,
pad_token_id=tokenizer.pad_token_id,
)
生成内容:
- CoT推理文本(在
<|cot_start|>和<|cot_end|>之间) - 元动作信息(在
<|meta_action_start|>和<|meta_action_end|>之间) <|traj_future_start|>token(停止点)
阶段3:KV Cache处理 (第207-230行)
python
# 1. 获取prompt cache
prompt_cache = vlm_outputs.past_key_values
prefill_seq_len = prompt_cache.get_seq_length()
# 2. 找到每个序列的<|traj_future_start|>位置
traj_future_start_mask = vlm_outputs.sequences == eos_token_id
has_traj_future_start = traj_future_start_mask.any(dim=1)
traj_future_start_positions = traj_future_start_mask.int().argmax(dim=1)
# 3. 如果没有找到,使用最后一个token位置
last_token_positions = torch.full(
(b_star,), vlm_outputs.sequences.shape[1] - 1, device=device
)
valid_token_pos_id = torch.where(
has_traj_future_start, traj_future_start_positions, last_token_positions
)
# 4. 计算偏移量(用于位置编码)
offset = valid_token_pos_id + 1
KV Cache的作用:
- 缓存prompt的key和value
- Expert推理时复用,避免重复计算
- 显著提升推理速度
阶段4:位置编码和注意力掩码 (第232-248行)
python
# 1. 创建位置编码
n_diffusion_tokens = self.action_space.get_action_space_dims()[0]
position_ids = torch.arange(n_diffusion_tokens, device=device)
position_ids = einops.repeat(position_ids, "l -> 3 b l", b=b_star).clone()
# 2. 添加rope_deltas和offset
delta = vlm_outputs.rope_deltas + offset[:, None]
position_ids += delta.to(position_ids.device)
# 3. 创建注意力掩码
attention_mask = torch.zeros(
(b_star, 1, n_diffusion_tokens, prompt_cache.get_seq_length() + n_diffusion_tokens),
dtype=torch.float32,
device=device,
)
# 4. 设置掩码(屏蔽padding部分)
for i in range(b_star):
attention_mask[i, :, :, offset[i] : -n_diffusion_tokens] = torch.finfo(
attention_mask.dtype
).min
关键点:
rope_deltas: 处理多模态输入的位置偏移offset: 确保扩散token的位置正确attention_mask: 控制expert的注意力范围
阶段5:Expert定义去噪步骤函数 (第254-284行)
python
def step_fn(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
x: 噪声动作 [B*, *action_dim]
t: 时间步
返回: 预测的噪声/向量场
"""
b_star = x.shape[0]
# 1. 投影噪声动作到token embedding
future_token_embeds = self.action_in_proj(x, t)
if future_token_embeds.dim() == 2:
future_token_embeds = future_token_embeds.view(b_star, n_diffusion_tokens, -1)
# 2. 使用expert推理(复用KV cache)
expert_out_base = self.expert(
inputs_embeds=future_token_embeds,
position_ids=position_ids,
past_key_values=prompt_cache,
attention_mask=attention_mask,
use_cache=True,
is_causal=False if config.expert_non_causal_attention else True,
)
# 3. 裁剪cache(保持原始长度)
prompt_cache.crop(prefill_seq_len)
# 4. 提取最后hidden state
last_hidden = expert_out_base.last_hidden_state[:, -n_diffusion_tokens:]
# 5. 投影到动作空间
pred = self.action_out_proj(last_hidden).view(
-1, *self.action_space.get_action_space_dims()
)
return pred
关键优化:
prompt_cache.crop(prefill_seq_len): 避免cache无限增长is_causal=False: 允许expert看到所有未来token(非因果注意力)
阶段6:扩散采样 (第286-297行)
python
# 1. 计算总批次大小
total_batch = B * num_traj_samples * num_traj_sets
# 2. 执行扩散采样
sampled_action = self.diffusion.sample(
batch_size=total_batch,
step_fn=step_fn,
device=device,
return_all_steps=False,
**diffusion_kwargs,
)
扩散采样过程(以Flow Matching为例):
-
初始化:
x_0 ~ N(0, I) -
迭代去噪:
for t in [T, T-1, ..., 1]: v_pred = step_fn(x_t, t) x_{t-1} = x_t + dt * v_pred -
输出:
x_0(清洁动作)
阶段7:动作解码 (第299-317行)
python
# 1. 重复历史轨迹以匹配采样数量
hist_xyz_rep = einops.repeat(
ego_history_xyz[:, -1], "b ... -> (b n) ...", n=n_samples_total
)
hist_rot_rep = einops.repeat(
ego_history_rot[:, -1], "b ... -> (b n) ...", n=n_samples_total
)
# 2. 将动作转换为轨迹
pred_xyz, pred_rot = self.action_space.action_to_traj(
sampled_action, hist_xyz_rep, hist_rot_rep
)
# 3. 重塑为 [B, num_traj_sets, num_traj_samples, ...]
pred_xyz = einops.rearrange(
pred_xyz, "(b ns nj) ... -> b ns nj ...",
ns=num_traj_sets, nj=num_traj_samples
)
pred_rot = einops.rearrange(
pred_rot, "(b ns nj) ... -> b ns nj ...",
ns=num_traj_sets, nj=num_traj_samples
)
输出格式:
pred_xyz:[B, num_traj_sets, num_traj_samples, T, 3]pred_rot:[B, num_traj_sets, num_traj_samples, T, 3, 3]
阶段8:提取文本token (第320-327行)
python
if kwargs.get("return_extra", False):
extra = extract_text_tokens(self.tokenizer, vlm_outputs.sequences)
# 重塑文本token以匹配轨迹形状
for text_tokens in extra.keys():
extra[text_tokens] = np.array(extra[text_tokens]).reshape(
[input_ids.shape[0], num_traj_sets, num_traj_samples]
)
return pred_xyz, pred_rot, extra
提取的文本内容:
cot: Chain of Thought推理meta_action: 元动作信息answer: 答案(如果有)
五、关键技术细节
5.1 KV Cache机制
工作原理:
python
# VLM生成阶段
vlm_outputs = self.vlm.generate(...)
prompt_cache = vlm_outputs.past_key_values # [B, n_layers, 2, n_heads, seq_len, head_dim]
# Expert推理阶段
expert_out = self.expert(
inputs_embeds=future_token_embeds,
past_key_values=prompt_cache, # 复用cache
...
)
# 裁剪cache
prompt_cache.crop(prefill_seq_len) # 保持长度不变
优势:
- VLM生成后缓存prompt的KV
- Expert复用cache,避免重复计算
prompt_cache.crop()保持cache长度- 显著减少计算量
- 支持批量推理
5.2 位置编码处理
RoPE Delta机制:
python
# 多模态输入导致位置编码不连续
rope_deltas = vlm.model.rope_deltas # [B, 3]
# 调整扩散token的位置
position_ids = torch.arange(n_diffusion_tokens)
position_ids = einops.repeat(position_ids, "l -> 3 b l", b=b_star)
delta = rope_deltas + offset[:, None]
position_ids += delta
作用:
rope_deltas: 处理图像patch和文本token多模态输入的位置偏移- 动态调整
position_ids以匹配扩散token - 确保expert的位置编码正确
5.3 注意力掩码设计
掩码结构:
python
attention_mask: [B, 1, n_diffusion_tokens, total_seq_len]
↓
[B, 1, n_diffusion_tokens, prefill_len + n_diffusion_tokens]
# 设置padding区域为负无穷
attention_mask[:, :, :, offset:-n_diffusion_tokens] = -inf
作用:
- 屏蔽padding token
- 确保只关注有效token
- 支持变长序列
5.4 非因果注意力
配置:
python
expert_non_causal_attention: True
效果:
python
expert_out = self.expert(
...,
is_causal=False # 双向注意力
)
优势:
expert_non_causal_attention=True- 允许expert看到所有未来token
- 提高轨迹生成的连贯性
5.5 多轨迹采样
num_traj_samples: 每个输入生成多个轨迹num_traj_sets: 多组采样- 提高鲁棒性和多样性
5.6 Logits处理
ExpertLogitsProcessor: 屏蔽轨迹tokenStopAfterEOS: 在<|traj_future_start|>后停止- 专注于CoT生成
六、数据流完整图
输入数据
├── ego_history_xyz [B, n_traj, T, 3]
├── ego_history_rot [B, n_traj, T, 3]
├── 图像
└── 文本prompt
↓
[轨迹融合] fuse_traj_tokens()
↓
VLM输入序列 [B, L]
↓
[VLM生成] → CoT文本 + <|traj_future_start|>
↓
[KV Cache] → 缓存prompt的KV
↓
[扩散采样]
├── 初始化噪声动作 [B*ns*nj, T, action_dim]
├── 迭代去噪 (step_fn)
│ ├── action_in_proj: [B*ns*nj, T, action_dim] → [B*ns*nj, T, hidden_size]
│ ├── expert: 使用KV cache推理
│ │ └── 输出: [B*ns*nj, T, hidden_size]
│ └── action_out_proj: [B*ns*nj, T, hidden_size] → [B*ns*nj, T, action_dim]
└── 输出: 清洁动作 [B*ns*nj, T, action_dim]
↓
[动作解码] action_to_traj()
├── 输入: 动作 [B*ns*nj, T, action_dim]
└── 输出: pred_xyz [B, ns, nj, T, 3], pred_rot [B, ns, nj, T, 3, 3]
┌─────────────────────────────────────────────────────────────┐
│ 输入数据 │
├─────────────────────────────────────────────────────────────┤
│ ego_history_xyz: [B, n_traj, T, 3] │
│ ego_history_rot: [B, n_traj, T, 3, 3] │
│ 图像: [B, C, H, W] │
│ 文本prompt: "预测未来轨迹..." │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 轨迹融合 │
├─────────────────────────────────────────────────────────────┤
│ tokenize_history_trajectory() │
│ ↓ │
│ hist_idx: [B, n_traj * 16] │
│ ↓ │
│ replace_pad_token() │
│ ↓ │
│ input_ids: [B, n_token] (轨迹token已融合) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ VLM自回归生成 │
├─────────────────────────────────────────────────────────────┤
│ vlm.generate( │
│ input_ids, │
│ stopping_criteria=StopAfterEOS("<|traj_future_start|>"), │
│ logits_processor=ExpertLogitsProcessor() │
│ ) │
│ ↓ │
│ sequences: [B*ns, seq_len] │
│ 包含: CoT文本 + <|traj_future_start|> │
│ ↓ │
│ replace_padding_after_eos() │
│ ↓ │
│ prompt_cache: KV cache [B*ns, n_layers, 2, n_heads, ...] │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 位置编码和注意力掩码 │
├─────────────────────────────────────────────────────────────┤
│ 查找<|traj_future_start|>位置 │
│ ↓ │
│ 计算offset = position + 1 │
│ ↓ │
│ position_ids = arange(n_diffusion_tokens) + delta │
│ ↓ │
│ attention_mask: 设置padding区域为-inf │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 扩散采样 │
├─────────────────────────────────────────────────────────────┤
│ 初始化: x_T ~ N(0, I) │
│ ↓ │
│ 迭代去噪: │
│ for t in [T, T-1, ..., 1]: │
│ step_fn(x_t, t): │
│ 1. action_in_proj(x_t, t) → token_embeds │
│ 2. expert(token_embeds, past_key_values=prompt_cache)│
│ 3. action_out_proj(hidden) → pred │
│ x_{t-1} = x_t + dt * pred │
│ ↓ │
│ sampled_action: [B*ns*nj, *action_dims] │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 动作解码 │
├─────────────────────────────────────────────────────────────┤
│ action_to_traj(sampled_action, hist_xyz, hist_rot) │
│ ↓ │
│ pred_xyz: [B*ns*nj, T, 3] │
│ pred_rot: [B*ns*nj, T, 3, 3] │
│ ↓ │
│ rearrange: [B*ns*nj, ...] → [B, ns, nj, ...] │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ 输出 │
├─────────────────────────────────────────────────────────────┤
│ pred_xyz: [B, num_traj_sets, num_traj_samples, T, 3] │
│ pred_rot: [B, num_traj_sets, num_traj_samples, T, 3, 3] │
│ (可选) CoT文本, meta_action, answer │
└─────────────────────────────────────────────────────────────┘
七、模型优势总结
7.1 架构优势
- 模块化设计: VLM和Expert分离,职责清晰
- 参数高效: 共享embedding,减少参数量
- 可扩展性: 支持不同的动作空间和扩散方法
7.2 推理优势
- KV Cache优化: 复用注意力缓存,避免重复计算
- 批量采样: 一次推理生成多个轨迹
- 灵活采样: 支持top-p、top-k、temperature等采样策略
7.3 性能优势
- Flash Attention 2: 加速注意力计算
- 非因果注意力: 更好的全局建模
- Flow Matching: 更少的采样步数
7.4 功能优势
- 多模态融合: 图像、文本、轨迹统一处理
- 推理能力: CoT生成,可解释性强
- 鲁棒性: 多轨迹采样,提高可靠性
八、关键代码位置索引
| 功能 | 文件 | 行号 |
|---|---|---|
| ReasoningVLA类 | base_model.py | 285 |
| ReasoningVLAConfig | base_model.py | 200 |
| AlpamayoR1类 | alpamayo_r1.py | 73 |
| AlpamayoR1Config | config.py | 23 |
| 推理流程 | alpamayo_r1.py | 122 |
| DeltaTrajectoryTokenizer | delta_tokenizer.py | 21 |
| StopAfterEOS | token_utils.py | 172 |
| ExpertLogitsProcessor | alpamayo_r1.py | 41 |
| ActionSpace基类 | action_space.py | 23 |
| BaseDiffusion基类 | diffusion/base.py | 45 |