OPENPi模型源码AI解读

一、关键代码:

gemma.py 实现的是 Pi 模型里的"视觉语言模型"部分 ------负责:

  • 处理图像 patch + 文本 token(来自 Paligemma)

  • 和"动作专家"(action expert,在 pi0.py 里)做 Mixture-of-Experts 融合

  • 输出给 Flow Matching 头(在 model.py 里)

关键设计:双专家 MoE

```

输入: 图像patch tokens \| 文本prompt tokens | 动作 tokens

↓ ↓

PaliGemma expert (i=0) Action expert (i=1)

↓ ↓

Gemma-2B weights 从头/LoRA 训练

↓ ↓

输出 B, T, D 输出 B, T, D

↓ ↓

融合给 Flow Matching 头

```

这就是代码里 configs: SequenceConfig 是 列表 的原因------每个专家有独立的 config、独立的权重。

二、核心模块分层

1. Config(L44-109)------ 5 种变体

variant width depth heads kv_heads mlp 用途 dummy 64 4 8 1 128 单元测试/调试 gemma_300m 1024 18 8 1 4096 轻量版 gemma_300m_lora 1024 18 8 1 4096 300M + LoRA(rank=32) gemma_2b 2048 18 8 1 16384 Pi 默认(PaliGemma 2B) gemma_2b_lora 2048 18 8 1 16384 2B + LoRA(rank=16)

LoRA 配置 (L96-107):同时对 attn 和 ffn 各挂一个 LoRA adapter,rank=16, alpha=16。

  1. RMSNorm

这是 Pi 能做 条件动作生成 的关键------通过 AdaLN 把状态/指令/图像 embedding 注入到每层归一化里。

  1. Embedder

vocab_size = 257_152(PaliGemma 的词表大小)。

  1. Attention------ 多专家融合的核心

1. 每个专家独立做 QKV 投影

for i, (x, config) in enumerate(zip(xs, self.configs)):

qkv_einsum(x) # ← 每个专家自己的 Wq Wk Wv

2. concat 后一起做 self-attention

q, k, v = jnp.concatenate(y, axis=1) for y in zip(*qkvs)

所有专家的 token 在 token 维度拼接成一个长序列

一次 attention 就能跨专家通信!

3. 共享的 RoPE + MHA

q = _apply_rope(q, positions)

logits = einsum("BTKGH,BSKH->BKGTS", q, k, ...)

4. attention 后每个专家独立 output 投影

for i, (x, config) in enumerate(...):

out_einsum(encodedstart:end) # ← 每个专家自己的 Wo

这就是 Mixture-of-Experts Transformer 的精髓 :

  • 专家前 :每个 token 独立线性投影

  • 专家间 :一次 self-attention 跨所有专家 token 做信息交换

  • 专家后 :每个 token 独立线性投影回原维度

另外还支持 GQA(grouped-query attention)------当 num_kv_heads != num_heads 时 Q 和 KV 分开投影。

5. FeedForward

Gemma 用 SwiGLU,不是标准 ReLU FFN

ff_gate = x @ W_gate0 # gate 分支

ff1 = x @ W_gate1 # value 分支

activations = gelu(ff_gate) * ff1 # gate 调制 value

output = activations @ W_linear

  1. Block

结构:

x → RMSNorm(adarms_cond) → Attention → gated residual

→ RMSNorm(adarms_cond) → SwiGLU FFN → gated residual

每步都有 sharding constraint(JAX 自动并行)

每步都有 adarms_cond 注入(AdaLN 条件生成)

gated_residual是标准/条件两种模式:

gate is None → x + y # 普通残差

gate 不为 None → x + y * gate # 条件门控残差

  1. Module

Gemma 结构:

1. Embedder: token id → embedding (共享 vocab table)

2. 18 层 Block (Gemma-2B depth=18)

用 nn.scan 把所有层 scan 成一次 jit call(JAX 优化)

用 nn.remat 做 gradient checkpointing(省显存)

3. final RMSNorm(每专家独立)

4. AdaARM conditioning 注入到每层 RMSNorm

adarms_cond (条件注入) 在 init() 能看清楚:

use_adarms 是每个专家是否启用 AdaARM 的开关

传入 B, width 的 conditioning tensor(来自状态/指令)

每个专家独立:use_adarmsi=True 才传条件

三、关键特性总结

特性 代码位置 作用 双专家 MoE L172-201, L233-248 PaliGemma + Action expert 一次 attention 跨专家通信 多专家权重命名 L443-450 _name("attn", 0) → "attn" , _name("attn", 1) → "attn_1" , 第一个专家权重名和官方 PaliGemma checkpoint 对齐 ,直接加载 AdaARM 条件注入 L112-131, L402-403, L413-421 状态/指令 embedding 通过 AdaLN 注入每层,实现条件动作生成 RoPE 位置编码 L424-440 标准旋转位置编码,支持变长序列 KV Cache L211-214 推理时缓存 KV,支持自回归生成 LoRA 适配器 L52, L96-107 可选 LoRA 挂到 attn 和 ffn,只训小部分参数 GQA 支持 L176-199 支持 grouped-query attention 加速推理 JAX sharding L294, L307, L310, L312 每层都有 activation sharding,支持多 GPU 并行

核心主干

一、一层 Block 里有几个 RMSNorm

看 Block. call :

每层 Block 有 2 个 RMSNorm :

  • pre_attention_norm (Attention 之前)

  • pre_ffw_norm (FeedForward 之前)

二、Pi 里有几层 Block

PaliGemma 2B 的 config( gemma.py:L79-87 ):

18 层 Block × 2 个 RMSNorm = 36 个 RMSNorm 注入点 。

加上每个专家独立的 RMSNorm,实际数量:

三、AdaLN 在每个注入点做什么

看 RMSNorm. call 的 AdaLN 路径:

每个注入点做 3 件事:

参数 作用 含义 scale 乘以 (1 + scale) 缩放 特征的每个维度 shift 加上 shift 偏移 特征的每个维度 gate 给后面的 residual 用 门控 残差连接的强度

四、完整数据流(你的水瓶任务)

```

输入:

机器人当前双臂状态 B, 16

→ 状态编码器 (StateEncoder)

→ 条件向量 cond B, 2048 (和 PaliGemma hidden dim 对齐)

→ 复制 18 × 2 = 36 份(每层 Block 的每个 RMSNorm 一份)

→ 送进每个 RMSNorm 的 AdaLN

第 1 层 pre_attention_norm:

normed_image = rms_norm(image_tokens) * (1 + scale_1) +

shift_1

normed_text = rms_norm(text_tokens) * (1 + scale_2) +

shift_2

第 1 层 pre_ffw_norm:

normed_hidden = rms_norm(hidden) * (1 + scale_3) + shift_3

...

第 18 层 pre_ffw_norm:

同上结构,独立的 scale/shift/gate 参数

最终 Flow Matching 头:

收到已经被状态条件调制过 36 次的 hidden states

→ 输出条件化的动作 B, 36

```

五、这样设计的好处

对比 做法 效果 普通条件注入 把状态 concat 到输入 只能影响第 1 层,后面层状态信息丢失 AdaLN(Pi 的做法) 每层的每个 RMSNorm 都注入条件 条件信息贯穿 18 层,每层都能"看到"当前状态 好处 状态条件 × 36 次注入 模型在任何深度都能"记住"机器人姿态,生成更准确的动作

一句话 :每个 Block 里的 pre_attention_norm 和 pre_ffw_norm 各注入一次,18 层共 36 个 AdaLN 注入点 ,让机器人状态条件贯穿整个 Transformer。

QA:

  1. 这 18 层 Block 是 Transformer 的核心"特征提炼管道" ------每一层都在做"从原始信号 → 抽象语义"的一步提炼。给你分阶段讲清楚:

一、每一层做什么

每一层 Block 内部的流水线 :

二、18 层分 4 个阶段

阶段 1:L0 - L3(底层特征提取)

层 做什么 你的水瓶任务里 L0-L1 从原始像素/词元提取 底层特征 识别出"圆的东西"、"黄色的东西"、"直线边缘" L2-L3 组合底层特征成 局部模式 识别出"圆柱形"、"瓶盖纹理"、"桌面纹理"

阶段 2:L4 - L8(语义组合)

层 做什么 你的水瓶任务里 L4-L5 形成 局部语义 "这个物体是水瓶"、"水瓶在桌子左边" L6-L7 跨模态融合 图像和文本"grasp the water bottle"对齐 L8 状态条件注入 让模型"看到"当前双臂姿态

阶段 3:L9 - L14(推理与规划)

层 做什么 你的水瓶任务里 L9-L10 形成 动作意图 "我应该去抓这个水瓶" L11-L12 规划 动作序列 "先张开夹爪 → 向前移动 → 夹紧" L13-L14 细化 连续轨迹 "夹爪应该以 5cm/s 速度接近"

阶段 4:L15 - L17(条件化输出)

层 做什么 你的水瓶任务里 L15-L16 精调最终特征 让特征和 Flow Matching 头对齐 L17 最后整合 输出给 Flow Matching 头的 embedding

三、每层之间的"信息传递"

每层的输入输出 维度相同 (2048),通过残差连接保证梯度不消失。

四、AdaLN 条件注入在每层的作用

每个 Block 的 2 个 RMSNorm 都会注入条件,但每层"看到"的条件做的事不同:

层 AdaLN 条件作用 L0-L3 调制底层特征提取:"当前状态是双臂张开 → 视觉注意力集中在桌面中央" L4-L8 调制语义融合:"当前状态是已定位 → 文本 prompt 和图像更紧密对齐" L9-L14 调制动作规划:"当前位置在 (x,y,z) → 规划路径从这里开始" L15-L17 调制最终输出:"当前力传感器读数为 0 → 输出抓握前的动作"

五、一句话总结

18 层 Block = 18 次"特征提炼 + 跨专家通信 + 条件调制" :

  • 从 L0 的"识别边缘"到 L17 的"输出可执行动作"

  • 每层都有跨专家 self-attention(PaliGemma 和 Action 专家交互)

  • 每层都有 AdaLN 条件注入(机器人状态贯穿始终)

  • 最终给 Flow Matching 头提供 完全条件化的语义特征 来生成 36 维连续动作

2 。 为什么用mse loss,能用其他loss吗

SOLO Agent

思考过程

直接回答:

一、为什么 Flow Matching 用 MSE Loss

1. Flow Matching 的数学本质

Flow Matching 学习的是一个 速度场 o bj ec tO bj ec t v ( x , t ) ,满足 ODE:

o bj ec tO bj ec t d t d x ​ = v θ ​ ( x , t )

训练目标是让 o bj ec tO bj ec t v θ ​ 匹配 真实的速度场 :

o bj ec tO bj ec t u t ​ = noise − action

所以 loss 是:

o bj ec tO bj ec t \] L = E t , x t ​ ​ \[ ∥ v θ ​ ( x t ​ , t ) − u t ​ ∥ 2

这就是 MSE (均方误差)。

2. MSE 是 Flow Matching 的标准选择

方面 原因 数学一致性 Flow Matching 推导时就是用 o bj ec tO bj ec t L 2 ​ 损失,MSE 是理论要求 优化稳定性 MSE 对连续值回归最稳定,梯度光滑 物理意义 速度场的 o bj ec tO bj ec t L 2 ​ 距离就是 能量 ,有物理意义 简单高效 不用调超参,直接用

3. 代码里的体现(pi0.py:L229)

就是对每个 36 维的 o bj ec tO bj ec t v t ​ 和 o bj ec tO bj ec t u t ​ 计算 squared L2 距离,再取均值。

二、能用其他 Loss 吗?

可以,但要理解每一种的影响:

Loss 公式 效果 适合场景 MSE(当前) o bj ec tO bj ec t ∥ v − u ∥ 2 对大误差惩罚重 默认推荐 MAE / L1 o bj ec tO bj ec t ∥ v − u ∥ 对异常值鲁棒 动作有少量异常值时 Huber 混合 MSE 和 MAE 可调鲁棒性 平衡 MSE 和 MAE Smooth L1 Huber 的变体 对大误差线性惩罚 动作平滑性要求高 Cosine Similarity o bj ec tO bj ec t 1 − cos ( v , u ) 关注方向而非幅值 只关心动作方向

三、不同 Loss 对水瓶抓放任务的影响

MSE(当前默认)

MAE / L1

Huber Loss(推荐试试)

Cosine Similarity

四、如果要改,怎么改

pi0.py:L229 改一行:

五、我的建议

先保持 MSE 不动 ,原因:

  1. MSE 是 Flow Matching 的理论标准 ------pi0-base 预训练时就用 MSE,你 finetune 时换 loss 会导致模型权重和新 loss 不匹配

  2. 你的数据量不大 ------几百条 episode,MSE 已经足够稳定

  3. 先训起来看效果 ------如果 MSE 训完效果不好(比如某些瓶子抓不稳),再考虑换 Huber

如果要试 , 优先试 Huber Loss ------比 MSE 更鲁棒,适合机器人动作这种有少量异常值的场景,而且不会丢失 MSE 的平滑梯度特性。

Q3:

Flow Matching 头提供 完全条件化的语义特征 来生成 36 维连续动作 这一部分的代码在哪

找到了。Flow Matching 头 就在 pi0.py 里,从"18 层 Gemma 输出"到"36 维连续动作"的完整链路:

一、代码位置分布

组件 文件位置 作用 初始化(所有层定义) pi0.py:L66-103 定义 action_in_proj, time_mlp, action_out_proj embed_prefix(图像+文本) pi0.py:L106-152 编码图像+文本 tokens embed_suffix(状态+动作+时间) pi0.py:L154-201 编码状态+噪声动作+时间步 主干 Gemma forward pi0.py:L224-226 18 层 Block 处理 Flow Matching 输出头 pi0.py:L227 action_out_proj 线性层 损失计算(Flow Matching loss) pi0.py:L229 MSE loss 推理采样(Flow Matching ODE) pi0.py:L231+ 迭代去噪生成动作

二、完整数据流(训练时)

Step 1:初始化(L66-103)

Step 2:embed_prefix --- 编码图像+文本(L106-152)

Step 3:embed_suffix --- 编码状态+噪声动作+时间(L154-201)

Step 4:Gemma 18 层 forward(L224-226)

Step 5:Flow Matching 头输出(L227)⭐

这就是 Flow Matching 头 ------一个从 2048 维投影到 36 维的线性层。

Step 6:Flow Matching 损失(L229)

三、推理时(sample_actions)

四、总结

组件 代码位置 维度变换 图像→SigLIP pi0.py:L118 B,3,224,224B,196,2048 文本→Embedding pi0.py:L133 B,5B,5,2048 动作→投影 pi0.py:L174 B,15,36B,15,2048 时间步→MLP pi0.py:L179-182 B,1B,2048 18 层 Gemma pi0.py:L224-226 B,221,2048B,221,2048 Flow Matching 头 pi0.py:L227 B,15,2048B,15,36

Flow Matching 头就是 action_out_proj = Linear(2048 → 36) ------把 Gemma 输出的 2048 维语义特征,投影成 36 维连续动作(双臂 14 关节 + 夹爪 2 × 15 步)。

Q4: 关于激活函数

我用 你当前训练的水瓶抓放任务 (batch=2,3相机,文本prompt,action_horizon=15,action_dim=36)的 具体维度 ,结合代码,把前向传播求预测值的 每一步计算、激活函数位置、维度变化 完全具象化:

一、先明确:前向传播的2种场景

场景 求的预测值是什么 代码入口 训练时 带噪声动作的 速度场预测值 o bj ec tO bj ec t v t ​ pi0.py:L204-229 推理时 最终干净的 36维连续动作 pi0.py:L231+

二、训练时前向传播(求 objectObjectvt​ 预测值)

以 batch=2,3相机,10个文本token,15步动作horizon,36维动作 为例,每一步对应代码和维度:

Step 1:embed_prefix(编码图像+文本)

代码位置: pi0.py:L106-152

Step 2:embed_suffix(编码噪声动作+时间步)

代码位置: pi0.py:L154-201

Step 3:Gemma 18层前向(核心计算)

代码位置: pi0.py:L224-226

Step 4:Flow Matching头输出 objectObjectvt​ 预测值

代码位置: pi0.py:L227

Step 5:计算Loss

代码位置: pi0.py:L229

三、推理时前向传播(求最终干净动作)

推理时是 多次迭代 的前向(10步默认),每一步都是上面的前向流程,只是噪声动作的初始值不同:

代码入口: pi0.py:L231+

四、核心总结

前向传播的计算本质 每一步都是「线性变换 + 激活函数」的组合 ,最终通过 输出层线性投影 得到预测值。

激活函数在前向的位置

激活函数 位置 作用 Softmax Attention层 把注意力权重归一化(非线性) GELU SwiGLU FFN层 引入非线性,拟合复杂模式 Swish 时间步MLP 处理连续时间步的非线性变换 无激活 输出层(action_out_proj) 直接投影到36维连续动作(线性输出)

维度变化全链路

指明 这几个激活函数的代码 和逻辑处理,为什么要这样用

思考过程

我将精确对应 Pi 模型中用到的 3 种核心激活函数 ,明确其 代码位置、逻辑处理、针对机器人任务的设计原因 :

一、激活函数总览

激活函数 代码位置 核心作用 针对机器人任务的设计原因 Softmax gemma.py:L198-205 注意力权重归一化 动态聚焦关键空间信息(抓取点/目标位置) GELU (SwiGLU门控) gemma.py:L260-263 FFN层非线性门控 拟合连续时序动作的平滑特征,自适应筛选动作模式 Swish (SiLU) pi0.py:L176-183 时间步编码MLP非线性 精准捕捉Flow Matching去噪过程的 连续时间特性

二、逐个激活函数详解(代码+逻辑+设计原因)

1. Softmax(注意力权重激活) 代码位置: gemma.py:L198-205 逻辑处理:

  • 把 任意实数的注意力得分 映射到 0,1 区间,且 所有得分加和为1

  • 转换公式: o bj ec tO bj ec t S o f t ma x ( x i ​ ) = ∑ j ​ e x j ​ e x i ​ ​

  • 作用:让注意力权重具有 概率意义 ,明确告诉模型"每个位置的重要性占比" 设计原因(针对机器人任务):

  1. 多相机/多模态融合的精准聚焦 :机器人任务需要同时处理3个相机图像、文本指令、 proprioception(本体状态),Softmax能让模型 动态聚焦关键信息 (比如抓水瓶时聚焦瓶身,放箱子时聚焦箱子内部),而非平均分配注意力。

  2. 避免无关信息干扰 :机器人场景中存在大量噪声(比如背景杂物、机械臂的无关移动),Softmax的归一化特性能抑制噪声区域的权重,让模型专注于 任务核心区域 。

  3. 时序动作的关联建模 :18层Transformer的注意力能捕捉 跨时间步的动作关联 (比如"当前抓取力度"和"后续放置位置"的关联),Softmax让这种关联更清晰。

2. GELU(SwiGLU FFN的门控激活) 代码位置: gemma.py:L260-263 逻辑处理:

  • GELU激活公式: o bj ec tO bj ec t GE LU ( x ) = x ⋅ Φ ( x ) ( o bj ec tO bj ec t Φ 是标准高斯CDF,比ReLU平滑,负值区域有非零输出)

  • SwiGLU的门控逻辑 :

  1. 两路并行线性变换,一路过GELU当**"门控信号" (控制信息通过量),另一路当 "信息载体"**(传递特征)

  2. 门控信号(GELU输出)× 信息载体 = 最终有效特征, 自适应筛选有用特征 设计原因(针对机器人任务):

  3. 连续动作的平滑拟合 :机器人动作是 连续时序信号 (比如手臂移动的速度/轨迹),ReLU的硬截断(负值变0)会丢失连续信息,GELU的平滑特性保证 动作的连续性 ,避免机器人移动出现突兀卡顿。

  4. 自适应动作特征筛选 :

  • 抓水瓶时,门控自动筛选 夹爪闭合的力/位置 特征,忽略机械臂的无关移动;

  • 放箱子时,门控自动筛选 手臂高度/水平位置 特征,聚焦放置动作。

  1. Flow Matching的速度场预测 :Flow Matching需要预测 连续的速度场 (从噪声到动作的插值速度),GELU的平滑特性能更精准地拟合速度场的非线性变化,提升去噪精度。

3. Swish(SiLU,时间步编码的激活) 代码位置: pi0.py:L176-183 逻辑处理:

  • Swish激活公式: o bj ec tO bj ec t Sw i s h ( x ) = x ⋅ σ ( x ) ( o bj ec tO bj ec t σ 是Sigmoid,自门控特性:负值有非零输出,正值接近线性)

  • 作用:把**Flow Matching的连续时间步t(0→1)**编码成高维向量,注入AdaLN层, 动态调制每一层的归一化参数 设计原因(针对Flow Matching任务):

  1. 捕捉去噪过程的非线性特性 :Flow Matching的时间步t从1(纯噪声)→0(干净动作)是 非线性的去噪过程 :
  • t接近1时:需要强去噪(大幅调整动作);

  • t接近0时:需要微调动作(小幅修正误差);

Swish的自门控特性能精准编码这种 非线性的时间特性 ,让AdaLN能自适应调整每层的去噪强度。

  1. 适合连续标量输入 :时间步t是 单个连续标量 (0~1),Swish的自门控特性( o bj ec tO bj ec t x ⋅ σ ( x ) )比GELU更灵活,能更好地捕捉标量输入的细微变化(比如t从0.9到0.8的去噪强度变化)。

  2. 提升AdaLN的条件注入精度 :AdaLN是Pi实现"条件动作生成"的核心(把时间步/状态/指令注入每一层),Swish的非线性编码能让注入的条件更 精准匹配当前去噪阶段的需求 。

三、关键补充:激活函数的"学习"本质

所有这些激活函数(Softmax/GELU/Swish)都是 固定的数学函数 ,不参与训练。真正学习的是 激活函数前后的权重矩阵 :

  • 比如GELU前后的 w_gating 、 w_linear (SwiGLU的权重)

  • 比如Swish前后的 Dense 层权重(时间步编码的权重)

  • 比如Softmax前后的 QKV 投影权重(注意力的权重)

激活函数只是 引入非线性的"工具" ,让模型能学习 连续动作的复杂模式 (比如抓水瓶的力控制、放箱子的位置控制),避免模型退化为线性变换。

相关推荐
果丁智能44 分钟前
物联网智能锁赋能网约房与民宿:身份核验、远程授权的技术落地与成本优化方案
人工智能·物联网·智能家居
科技每日热闻1 小时前
618 AI显示器选购指南!爱攻AGON AI定制芯片电竞显示器AG277UX,适合哪些玩家?
人工智能·科技·游戏·计算机外设
初圣魔门首席弟子1 小时前
AI Agent 核心原理:工具调用(Function Calling)完整工作流程详解
前端·数据库·人工智能
MacroZheng1 小时前
横空出世!Claude Code画图神器来了,比Visio快10倍!
java·人工智能·后端
FII工业富联科技服务1 小时前
“可持续灯塔工厂”技术解密:AI+IoT如何落地端到端碳管理闭环
大数据·人工智能·物联网·ai·数据分析·自动化·制造
AniShort1 小时前
AniShort携3D世界+3D导演台王炸组合AI短剧协作平台亮相2026横店AI短剧大会 近亿元融资赋能短剧工业化
人工智能·microsoft·3d
大山佬1 小时前
ARM 汇编优化:NEON 指令与内存访问的实战技巧
人工智能
Bright16681 小时前
从零打造 Cursor 平替:基于 VS Code 二开的 AI 编程编辑器 CodexaX
人工智能·开源·编辑器
AI客栈1 小时前
AI大模型微服务网关架构下的动态限频与负载均衡设计:生产环境突发故障排查与优化
人工智能