《VLA 系列》复现 realtime-vla | 加速推理 | Triton后端

本文介绍了实时realtime-vla加速实现方案,通过Triton后端优化在RTX 4090/5090显卡上达到20-55ms的推理速度。

详细说明了环境搭建、权重转换(支持Pi0/Pi0.5/DM0模型)和推理测试流程,包括:

  1. 代码获取与依赖安装
  2. 模型权重下载与格式转换
  3. 推理引擎初始化与CUDA Graph优化
  4. 性能测试(端到端延迟34.3Hz,Python开销仅0.004ms)
  5. 与JAX后端的输出一致性验证(平均误差评估) 测试表明该方案在保持精度的同时显著提升推理速度,适用于实时机器人控制等场景。完整代码已开源。

论文地址Running VLAs at Real-time Speed

开源代码https://github.com/dexmal/realtime-vla

官方的加速效果

Model / Backend RTX 4090 (1 view) RTX 4090 (2 views) RTX 4090 (3 views) RTX 5090 (1 view) RTX 5090 (2 views) RTX 5090 (3 views)
Pi0 Triton 20.0ms 27.3ms 36.8ms 17.6ms 24.0ms 31.9ms
Pi05 Triton 22.1ms 29.2ms 38.9ms 20.1ms 26.6ms 34.2ms
DM0 Triton 55.8ms

1、拉取代码

执行下面代码,拉取 realtime-vla 的代码

bash 复制代码
git clone https://github.com/dexmal/realtime-vla.git

然后进入 realtime-vla 目录中

bash 复制代码
cd realtime-vla/

2、构建开发环境

这里主要对pi0、pi0.5进行加速,使用 openpi 的环境就好啦

开发环境代建参考:《VLA 系列》复现 π0.5、π0-FAST、π0 | 环境搭建 | 模型推理_pi0.5复现-CSDN博客

然后进入 openpi 代码目录中,激活openpi的环境,再来到realtime-vla 目录

3、拉取基础权重

由于pi0、pi0.5使用到 google/paligemma-3b-pt-224 权重,需要先进行下载

推荐使用国内的 modelscope,进行下载,速度比较快

先使用uv进行安装

bash 复制代码
uv pip install modelscope

然后执行下面命令进行下载

bash 复制代码
python -c "
from modelscope import snapshot_download
snapshot_download(
    'google/paligemma-3b-pt-224',
    cache_dir='./paligemma-3b-pt-224',
    local_files_only=False
)
"

运行效果:

其中,paligemma-3b-pt-224大概11G左右;

由于下载的paligemma-3b-pt-224 嵌套了多层目录,需要改为:

realtime-vla/paligemma-3b-pt-224/google/paligemma-3b-pt-224/ ---> realtime-vla/paligemma-3b-pt-224/

4、模型权重转换

pi0、pi0.5、DM0的转换示例:(因为需要将模型转为"静态图"进行加速推理,所以输入的prompt是固定的)

bash 复制代码
# pi0的转换示例
pytho3 convert_from_jax.py \
   --jax_path /path/to/checkpoint/folder\
   --output converted_checkpoint.pkl\
   --prompt "your task prompt"\
   --tokenizer_path /path/to/paligemma-3b-pt-224

# pi0.5的转换示例
python3 convert_from_jax_pi05.py \
   --jax_path /path/to/checkpoint/folder\
   --output converted_checkpoint.pkl\
   --prompt "your task prompt"\
   --tokenizer_path /path/to/paligemma-3b-pt-224

# DM0的转换示例
python3 convert_dm0_weight.py \
   --model_path /path/to/checkpoint/folder\
   --output converted_checkpoint.pt

5、pi05_base 实践示例

1)pi05_base 的模型转换

bash 复制代码
python3 convert_from_jax_pi05.py \
  --jax_path /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_base \
  --output pi05_base_converted.pkl \
  --prompt "Pick up the bottle." \
  --tokenizer_path ./paligemma-3b-pt-224

运行信息:

(openpi) (base) liguopu@untu-System-Product-Name:~/lgp_dev/project/realtime-vla $

(openpi) (base) liguopu@untu-System-Product-Name:~/lgp_dev/project/realtime-vla $ python3 convert_from_jax_pi05.py --jax_path /home/liguopu/lgp_dev/project/openpi/check points/pi05_base --output pi05_base_converted.pkl --prompt "Pick up the bottle." --tokenizer_path ./paligemma-3b-pt-224

/home/liguopu/lgp_dev/project/openpi/.venv/lib/python3.11/site-packages/torch/cuda/init.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia- ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.

import pynvml # type: ignoreimport

Loading jax weights from /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_base/params

WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be act ively worked on.

Loaded JAX params keys: dict_keys('PaliGemma', 'action_in_proj', 'action_out_proj', 'time_mlp_in', 'time_mlp_out')

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Default ing to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a sp ecific strategy to `truncation`.

Successfully converted Pi0.5 weights to pi05_base_converted.pkl

(openpi) (base) liguopu@untu-System-Product-Name:~/lgp_dev/project/realtime-vla$

2)编写一个代码,测试是否正常推理

运行指令:

bash 复制代码
python pi05_base_infer.py

代码如下所示:

python 复制代码
import pickle
import numpy as np
import torch
from pi05_infer import Pi05Inference

# 1. 加载转换后的权重
with open('pi05_base_converted.pkl', 'rb') as f:
    checkpoint = pickle.load(f)

# 2. 初始化推理引擎
infer = Pi05Inference(
    checkpoint=checkpoint,
    num_views=2,
    chunk_size=50,
    tokenizer_path="./paligemma-3b-pt-224",
    discrete_state_input=True,
    max_tokenize_len=200,
)

# 3. 准备输入
images = torch.randn(2, 224, 224, 3, dtype=torch.bfloat16, device="cuda")
noise = torch.randn(50, 32, dtype=torch.bfloat16, device="cuda")
task_prompt = "pick up the red cube"
state_tokens = np.random.randint(0, 256, size=32, dtype=np.int32)

# 4. 执行推理 ------ 参数名必须是 diffusion_noise,不是 diffusion_input_noise
actions = infer.forward(
    observation_images_normalized=images,
    diffusion_noise=noise,           # ← 修正:去掉 input_ 前缀
    task_prompt=task_prompt,
    state_tokens=state_tokens,
)

print("Actions shape:", actions.shape)
print("Actions dtype:", actions.dtype)
print("First action:", actions[0])

运行信息:

max_prompt_len: 200, max_tokenize_len: 200

Actions shape: torch.Size(50, 32)

Actions dtype: torch.bfloat16

First action: tensor([ 0.4258, -0.2637, 0.7461, 0.0029, 0.0054, 0.0515, 0.7695, 0.7812,

-0.0554, 0.4531, -0.0679, -0.8086, 0.3184, 0.4219, 0.1377, 0.8008,

-0.5391, 0.2090, 0.2793, -0.5859, 0.5625, -0.1885, -0.2754, 0.3691,

-0.0035, 0.0022, 0.0033, -0.0265, 0.0038, -0.0195, -0.0192, 0.0381],

device='cuda:0', dtype=torch.bfloat16)

3)编写一个代码,测试推理速度

用于量化 CUDA Graph 优化后的端到端推理延迟,并分析 Python 层开销占比

阶段 目的
加载转换权重 从 pickle 文件加载已转换的 Pi0.5 权重
初始化推理引擎 创建 Pi05Inference 实例,配置多视角输入、动作块长度等
构造伪输入 生成随机图像/噪声/state tokens,模拟真实推理输入
CUDA 预热 执行 3 次前向传播,稳定 GPU 状态
测试 A:端到端计时 1000 次完整 forward()(含 prompt 编码 + CUDA Graph replay)
测试 B:纯 Graph replay 1000 次仅 infer_graph.replay()(跳过 Python 前置处理)
对比分析 计算 Python 开销占比,判断瓶颈在 CPU 还是 GPU 侧

运行指令:python pi05_benchmark_1000runs.py

代码如下:

python 复制代码
import time
import pickle
import numpy as np
import torch
from pi05_infer import Pi05Inference

# ===================== 配置参数 =====================
MODEL_NAME = "Pi05"
NUM_VIEWS = 1
IMAGE_SIZE = 224
CHUNK_SIZE = 50
CHECKPOINT_PATH = "pi05_base_converted.pkl"
TOKENIZER_PATH = "./paligemma-3b-pt-224"

# ===================== 1. 加载转换后的权重 =====================
print("[1/4] 加载转换权重...")
with open(CHECKPOINT_PATH, "rb") as f:
    checkpoint = pickle.load(f)

# ===================== 2. 初始化推理引擎 =====================
print(f"[2/4] 初始化 {MODEL_NAME} Inference(views={NUM_VIEWS}, chunk={CHUNK_SIZE})...")
infer = Pi05Inference(
    checkpoint=checkpoint,
    num_views=NUM_VIEWS,
    chunk_size=CHUNK_SIZE,
    tokenizer_path=TOKENIZER_PATH,
    discrete_state_input=True,
    max_tokenize_len=200,
)

# ===================== 3. 准备输入数据 =====================
print(f"[3/4] 准备测试输入: images=({NUM_VIEWS}, {IMAGE_SIZE}, {IMAGE_SIZE}, 3), noise=({CHUNK_SIZE}, 32)")
images = torch.randn(NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 3, dtype=torch.bfloat16, device="cuda")
noise = torch.randn(CHUNK_SIZE, 32, dtype=torch.bfloat16, device="cuda")
task_prompt = "pick up the red cube"
state_tokens = np.random.randint(0, 256, size=32, dtype=np.int32)

# ===================== 4. 预热 =====================
print("[4/4] 预热 CUDA Graph(3 轮)...")
for i in range(3):
    _ = infer.forward(images, noise, task_prompt, state_tokens)
    torch.cuda.synchronize()
print("预热完成,开始正式测试\n")

# ===================== 5. 测试 A:1000 次端到端推理计时 =====================
print("=" * 60)
print("【测试 A】端到端推理(含 prompt 编码 + CUDA Graph replay)")
print("=" * 60)

end2end_times = []
for i in range(1000):
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    actions = infer.forward(images, noise, task_prompt, state_tokens)

    torch.cuda.synchronize()
    t1 = time.perf_counter()
    elapsed_ms = (t1 - t0) * 1000
    end2end_times.append(elapsed_ms)
    if (i + 1) % 100 == 0:
        print(f"  进度: {i+1}/1000")

# 测试 A 汇总
print()
print("=" * 60)
print("【测试 A 汇总】端到端推理(含 prompt 编码 + CUDA Graph replay)")
print("=" * 60)
print(f"  模型名称:     {MODEL_NAME}")
print(f"  输入视角数:   {NUM_VIEWS}")
print(f"  图像分辨率:   {IMAGE_SIZE} x {IMAGE_SIZE}")
print(f"  动作块长度:   {CHUNK_SIZE}")
print(f"  测试次数:     {len(end2end_times)}")

end2end_sorted = sorted(end2end_times)
n = len(end2end_times)
mean_e2e = sum(end2end_times) / n
median_e2e = end2end_sorted[n // 2]
p99_e2e = end2end_sorted[int(n * 0.99)]
min_e2e = min(end2end_times)
max_e2e = max(end2end_times)

print(f"  平均延迟:     {mean_e2e:7.3f} ms")
print(f"  中位数延迟:   {median_e2e:7.3f} ms")
print(f"  P99 延迟:     {p99_e2e:7.3f} ms")
print(f"  最小延迟:     {min_e2e:7.3f} ms")
print(f"  最大延迟:     {max_e2e:7.3f} ms")
print(f"  理论帧率(avg):{1000/mean_e2e:6.1f} Hz")
print(f"  理论帧率(med):{1000/median_e2e:6.1f} Hz")
print("=" * 60)

# ===================== 6. 测试 B:1000 次纯 CUDA Graph replay 计时 =====================
print()
print("=" * 60)
print("【测试 B】纯 CUDA Graph replay(仅 GPU 内核执行)")
print("=" * 60)

# 先手动执行一次 forward,把 prompt embeds 等准备好
_ = infer.forward(images, noise, task_prompt, state_tokens)
torch.cuda.synchronize()

graph_times = []
for i in range(1000):
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    infer.infer_graph.replay()

    torch.cuda.synchronize()
    t1 = time.perf_counter()
    elapsed_ms = (t1 - t0) * 1000
    graph_times.append(elapsed_ms)
    if (i + 1) % 100 == 0:
        print(f"  进度: {i+1}/1000")

# 测试 B 汇总
print()
print("=" * 60)
print("【测试 B 汇总】纯 CUDA Graph replay(仅 GPU 内核执行)")
print("=" * 60)
print(f"  模型名称:     {MODEL_NAME}")
print(f"  输入视角数:   {NUM_VIEWS}")
print(f"  图像分辨率:   {IMAGE_SIZE} x {IMAGE_SIZE}")
print(f"  动作块长度:   {CHUNK_SIZE}")
print(f"  测试次数:     {len(graph_times)}")

graph_sorted = sorted(graph_times)
n = len(graph_times)
mean_graph = sum(graph_times) / n
median_graph = graph_sorted[n // 2]
p99_graph = graph_sorted[int(n * 0.99)]
min_graph = min(graph_times)
max_graph = max(graph_times)

print(f"  平均延迟:     {mean_graph:7.3f} ms")
print(f"  中位数延迟:   {median_graph:7.3f} ms")
print(f"  P99 延迟:     {p99_graph:7.3f} ms")
print(f"  最小延迟:     {min_graph:7.3f} ms")
print(f"  最大延迟:     {max_graph:7.3f} ms")
print(f"  理论帧率(avg):{1000/mean_graph:6.1f} Hz")
print(f"  理论帧率(med):{1000/median_graph:6.1f} Hz")
print("=" * 60)

# ===================== 7. 对比分析 =====================
print()
print("=" * 60)
print("【A vs B 对比分析】")
print("=" * 60)
overhead_ms = mean_e2e - mean_graph
overhead_pct = (overhead_ms / mean_graph) * 100 if mean_graph > 0 else 0
print(f"  Python 前置开销: {overhead_ms:+.3f} ms ({overhead_pct:+.1f}%)")
print(f"  GPU 计算占比:   {(mean_graph/mean_e2e)*100:.1f}%")
print(f"  结论: ", end="")
if overhead_ms < 1.0:
    print("Python 开销极小,瓶颈在 GPU 计算侧")
elif overhead_ms < 5.0:
    print("Python 开销较小,可进一步优化 prompt 编码")
else:
    print("Python 开销显著,建议预编码 prompt 或优化数据拷贝")
print("=" * 60)

运行效果:(用的是NVIDIA GeForce RTX 4090 的显卡,48GB显存的)

1/4 加载转换权重...

2/4 初始化 Pi05 Inference(views=1, chunk=50)...

max_prompt_len: 200, max_tokenize_len: 200

3/4 准备测试输入: images=(1, 224, 224, 3), noise=(50, 32)

4/4 预热 CUDA Graph(3 轮)...

预热完成,开始正式测试

============================================================

【测试 A 汇总】端到端推理(含 prompt 编码 + CUDA Graph replay)

============================================================

模型名称: Pi05

输入视角数: 1

图像分辨率: 224 x 224

动作块长度: 50

测试次数: 1000

平均延迟: 29.114 ms

中位数延迟: 29.284 ms

P99 延迟: 29.522 ms

最小延迟: 28.526 ms

最大延迟: 30.537 ms

理论帧率(avg): 34.3 Hz

理论帧率(med): 34.1 Hz

============================================================

【测试 B 汇总】纯 CUDA Graph replay(仅 GPU 内核执行)

============================================================

模型名称: Pi05

输入视角数: 1

图像分辨率: 224 x 224

动作块长度: 50

测试次数: 1000

平均延迟: 29.118 ms

中位数延迟: 29.314 ms

P99 延迟: 29.500 ms

最小延迟: 28.610 ms

最大延迟: 30.462 ms

理论帧率(avg): 34.3 Hz

理论帧率(med): 34.1 Hz

============================================================

【A vs B 对比分析】

============================================================

Python 前置开销: -0.004 ms (-0.0%)

GPU 计算占比: 100.0%

结论: Python 开销极小,瓶颈在 GPU 计算侧

============================================================

通过修改代码的NUM_VIEWS = 1,可以测试2、3视角的速度

测试的速度,比官方的慢了一些,可能是魔改的4090的问题,因为测试CUDA利用率已经100%,功率也是430W左右,没有其他程序运行下进行的

4)编写一个代码,测试Pi0.5 基础权重****JAX vs Triton 端到端推理速度对比

加载同一组 Pi05 权重在 JAX 和 Triton 两种后端上,用固定输入跑 N 次端到端推理,统计并对比两者的延迟分布与加速比,判断 Triton 加速是否达到实时部署标准

模块 功能说明
输入数据准备 使用 DROID 示例数据生成固定随机种子(seed=42)的输入,确保两组后端输入完全一致
图像预处理 解析图像格式 → 保持宽高比缩放 → 零填充到目标尺寸 → 归一化到 [-1, 1]
状态预处理 拼接关节位置+夹爪位置 → 填充到 32 维 → 可选归一化 → 离散化为 256-bin token ID
JAX 后端计时 调用 jax_policy.infer(),记录含 prompt 编码的完整端到端延迟
Triton 后端计时 调用 triton_model.forward(),含 torch.cuda.synchronize() 精确 GPU 计时
统计汇总 计算平均/中位数/P99/最小/最大延迟、标准差、理论帧率、加速比

代码如下所示:

python 复制代码
import os
import json
import argparse
import pickle
import time
import numpy as np
import torch
from PIL import Image
from pathlib import Path

from pi05_infer import Pi05Inference
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


# ===================== 默认参数配置 =====================
DEFAULT_TRITON_PATH = "pi05_base_converted.pkl"
DEFAULT_JAX_PATH = "/home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_base"
DEFAULT_NORM_STATS_DIR = ""  # 速度对比不需要,设为空字符串
DEFAULT_CONFIG_NAME = "pi05_droid"
DEFAULT_PROMPT = "Pick up the bottle."
DEFAULT_TOKENIZER_PATH = "./paligemma-3b-pt-224"
DEFAULT_ACTION_DIM = 8
DEFAULT_CHUNK_SIZE = 15
DEFAULT_NUM_VIEWS = 1
DEFAULT_IMAGE_SIZE = 224
DEFAULT_SPEED_RUNS = 100
DEFAULT_WARMUP = 3


class DroidSpeedComparator:
    """仅对比 JAX 与 Triton 端到端推理速度"""

    def __init__(self, triton_path, jax_path, config_name,
                 tokenizer_path, prompt, discrete_state_input=True,
                 action_dim=8, chunk_size=15, num_views=3, image_size=224,
                 norm_stats_dir=""):
        self.prompt = prompt
        self.discrete_state_input = discrete_state_input
        self.action_dim = action_dim
        self.chunk_size = chunk_size
        self.num_views = num_views
        self.image_size = image_size
        self.model_name = "Pi05_base"

        # 加载归一化统计信息(可选,速度对比不需要)
        self.norm_stats = self._load_norm_stats(norm_stats_dir) if norm_stats_dir else None
        # 离散化用的分箱边界(256个bin)
        self._digitize_bins = np.linspace(-1, 1, 256 + 1)[:-1]

        # 加载 Triton 模型(加速后端)
        print("Loading Triton model...")
        with open(triton_path, "rb") as f:
            weights = pickle.load(f)
        self.triton_model = Pi05Inference(
            checkpoint=weights,
            num_views=num_views,
            chunk_size=chunk_size,
            tokenizer_path=tokenizer_path,
            discrete_state_input=True,
            max_tokenize_len=200,
        )

        # 加载 JAX 模型(官方后端)
        print("Loading JAX model...")
        config = _config.get_config(config_name)
        self.jax_policy = _policy_config.create_trained_policy(config, Path(jax_path))

    def _load_norm_stats(self, norm_stats_dir):
        """从指定目录加载归一化统计信息(可选)"""
        if not norm_stats_dir:
            return None
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, "r") as f:
                return json.load(f)["norm_stats"]
        return None

    def _pad_to_dim(self, x, target_dim, axis=-1):
        """将输入数组沿指定轴填充到目标维度"""
        current_dim = x.shape[axis] if len(x.shape) > 0 else len(x)
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image, height=224, width=224):
        """保持宽高比缩放图像,并用0填充到目标尺寸"""
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _preprocess_image(self, img_np):
        """图像预处理:解析格式 -> 缩放填充 -> 归一化到[-1, 1]"""
        img = np.asarray(img_np)
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        if img.shape[0] == 3:
            import einops
            img = einops.rearrange(img, "c h w -> h w c")
        img = self._resize_with_pad(img, self.image_size, self.image_size)
        img = img.astype(np.float32) / 255.0 * 2.0 - 1.0
        return img

    def _preprocess_state_discrete(self, joint_pos, gripper_pos):
        """状态预处理(离散模式):拼接 -> 填充 -> 归一化 -> 离散化,返回32个token ID"""
        if np.isscalar(gripper_pos):
            gripper_pos = np.array([gripper_pos], dtype=np.float32)
        state = np.concatenate([joint_pos, gripper_pos]).astype(np.float32)
        state = self._pad_to_dim(state, 32)

        # 速度对比中归一化统计信息可选,无则跳过归一化
        if self.norm_stats and "state" in self.norm_stats:
            q01 = np.array(self.norm_stats["state"]["q01"], dtype=np.float32)
            q99 = np.array(self.norm_stats["state"]["q99"], dtype=np.float32)
            q01 = self._pad_to_dim(q01, 32)
            q99 = self._pad_to_dim(q99, 32)
            state_normed = (state - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0
        else:
            state_normed = state

        state_normed = np.clip(state_normed, -1.0, 1.0)
        token_ids = np.digitize(state_normed, bins=self._digitize_bins) - 1
        return token_ids.astype(np.int64)

    def _prepare_inputs(self, seed=42):
        """准备一组相同的输入数据,供 JAX 和 Triton 共用"""
        np.random.seed(seed)

        droid_example = droid_policy.make_droid_example()
        noise_np = np.random.randn(self.chunk_size, 32).astype(np.float32)

        exterior = droid_example["observation/exterior_image_1_left"]
        wrist = droid_example["observation/wrist_image_left"]
        joint_pos = np.asarray(droid_example["observation/joint_position"])
        gripper_pos = np.asarray(droid_example["observation/gripper_position"])

        img_base = self._preprocess_image(exterior)
        img_left = self._preprocess_image(wrist)

        # 根据 num_views 严格构建图像列表
        images_list = []
        if self.num_views >= 1:
            images_list.append(img_base)
        if self.num_views >= 2:
            images_list.append(img_left)
        for _ in range(self.num_views - len(images_list)):
            images_list.append(np.zeros_like(img_base))

        state_tokens = self._preprocess_state_discrete(joint_pos, gripper_pos)

        # Triton 输入(Torch Tensor)
        images_torch = torch.from_numpy(
            np.stack(images_list, axis=0)
        ).to(torch.bfloat16).cuda()
        state_torch = torch.from_numpy(state_tokens).to(torch.long).cuda()
        noise_torch = torch.from_numpy(noise_np).to(torch.bfloat16).cuda()

        # JAX 输入(字典)
        jax_input = {
            "observation/exterior_image_1_left": img_base,
            "observation/wrist_image_left": img_left,
            "observation/joint_position": joint_pos,
            "observation/gripper_position": gripper_pos,
            "prompt": self.prompt,
        }

        return {
            "triton": (images_torch, noise_torch, state_torch),
            "jax": (jax_input, noise_np),
        }

    def run(self, num_runs=100, warmup=3):
        """执行速度对比测试"""
        print(f"\n{'='*70}")
        print("【速度对比】JAX 官方后端 vs Triton 加速后端")
        print(f"{'='*70}")
        print(f"  模型名称:     {self.model_name}")
        print(f"  输入视角数:   {self.num_views}")
        print(f"  图像分辨率:   {self.image_size} x {self.image_size}")
        print(f"  动作块长度:   {self.chunk_size}")
        print(f"  测试次数:     {num_runs}")
        print(f"  预热轮数:     {warmup}")
        print(f"  动作维度:     {self.action_dim}")
        print(f"{'='*70}")

        # 准备输入数据
        inputs = self._prepare_inputs(seed=42)
        triton_images, triton_noise, triton_state = inputs["triton"]
        jax_input, jax_noise = inputs["jax"]

        # ===================== JAX 预热 =====================
        print("\n[1/4] JAX 预热...")
        for _ in range(warmup):
            _ = self.jax_policy.infer(jax_input, noise=jax_noise)
        print("JAX 预热完成")

        # ===================== JAX 正式测试 =====================
        print(f"[2/4] JAX 端到端推理计时({num_runs} 次)...")
        jax_times = []
        for i in range(num_runs):
            t0 = time.perf_counter()
            _ = self.jax_policy.infer(jax_input, noise=jax_noise)
            t1 = time.perf_counter()
            jax_times.append((t1 - t0) * 1000)
            if (i + 1) % 20 == 0:
                print(f"  JAX 进度: {i+1}/{num_runs}")

        # ===================== Triton 预热 =====================
        print("\n[3/4] Triton 预热...")
        for _ in range(warmup):
            _ = self.triton_model.forward(
                triton_images, triton_noise, self.prompt, triton_state
            )
        torch.cuda.synchronize()
        print("Triton 预热完成")

        # ===================== Triton 正式测试 =====================
        print(f"[4/4] Triton 端到端推理计时({num_runs} 次)...")
        triton_times = []
        for i in range(num_runs):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = self.triton_model.forward(
                triton_images, triton_noise, self.prompt, triton_state
            )
            torch.cuda.synchronize()
            t1 = time.perf_counter()
            triton_times.append((t1 - t0) * 1000)
            if (i + 1) % 20 == 0:
                print(f"  Triton 进度: {i+1}/{num_runs}")

        # ===================== 速度汇总 =====================
        print(f"\n{'='*70}")
        print("【速度对比汇总】端到端推理(含 prompt 编码 + 推理)")
        print(f"{'='*70}")
        print(f"  模型名称:     {self.model_name}")
        print(f"  输入视角数:   {self.num_views}")
        print(f"  图像分辨率:   {self.image_size} x {self.image_size}")
        print(f"  动作块长度:   {self.chunk_size}")
        print(f"  测试次数:     {num_runs}")

        def _summary(times, name):
            sorted_t = sorted(times)
            n = len(times)
            mean_t = sum(times) / n
            median_t = sorted_t[n // 2]
            p99_t = sorted_t[int(n * 0.99)]
            min_t = min(times)
            max_t = max(times)
            std_t = np.std(times)
            print(f"\n  [{name}]")
            print(f"    平均延迟:     {mean_t:8.3f} ms")
            print(f"    中位数延迟:   {median_t:8.3f} ms")
            print(f"    P99 延迟:     {p99_t:8.3f} ms")
            print(f"    最小延迟:     {min_t:8.3f} ms")
            print(f"    最大延迟:     {max_t:8.3f} ms")
            print(f"    标准差:       {std_t:8.3f} ms")
            print(f"    理论帧率(avg):{1000/mean_t:7.1f} Hz")
            print(f"    理论帧率(med):{1000/median_t:7.1f} Hz")
            return mean_t, median_t, p99_t

        jax_mean, jax_med, jax_p99 = _summary(jax_times, "JAX 官方后端")
        tri_mean, tri_med, tri_p99 = _summary(triton_times, "Triton 加速后端")

        # 加速比
        print(f"\n{'='*70}")
        print("【加速比分析】")
        print(f"{'='*70}")
        speedup_avg = jax_mean / tri_mean if tri_mean > 0 else float('inf')
        speedup_med = jax_med / tri_med if tri_med > 0 else float('inf')
        speedup_p99 = jax_p99 / tri_p99 if tri_p99 > 0 else float('inf')

        print(f"  平均延迟加速比: {speedup_avg:.2f}x  (JAX {jax_mean:.1f}ms / Triton {tri_mean:.1f}ms)")
        print(f"  中位数加速比:   {speedup_med:.2f}x")
        print(f"  P99 加速比:     {speedup_p99:.2f}x")

        latency_reduction = (jax_mean - tri_mean) / jax_mean * 100
        print(f"\n  延迟降低:       {latency_reduction:.1f}%")
        print(f"  JAX 帧率:       {1000/jax_mean:.1f} Hz")
        print(f"  Triton 帧率:    {1000/tri_mean:.1f} Hz")

        if speedup_avg >= 2.0:
            print(f"  结论: ✅ 显著加速({speedup_avg:.1f}x),适合实时部署")
        elif speedup_avg >= 1.2:
            print(f"  结论: ⚠️ 中等加速({speedup_avg:.1f}x),有优化空间")
        else:
            print(f"  结论: ❌ 加速不明显({speedup_avg:.1f}x),需检查瓶颈")
        print(f"{'='*70}")

        return jax_times, triton_times


def main():
    parser = argparse.ArgumentParser(
        description="对比 JAX 官方后端与 Triton 加速后端的端到端推理速度"
    )
    parser.add_argument("--triton_path", type=str, default=DEFAULT_TRITON_PATH,
                        help=f"Triton 转换后权重路径 (默认: {DEFAULT_TRITON_PATH})")
    parser.add_argument("--jax_path", type=str, default=DEFAULT_JAX_PATH,
                        help=f"JAX 官方 checkpoint 路径 (默认: {DEFAULT_JAX_PATH})")
    parser.add_argument("--norm_stats_dir", type=str, default=DEFAULT_NORM_STATS_DIR,
                        help="归一化统计信息目录 (速度对比可选,默认空)")
    parser.add_argument("--config_name", type=str, default=DEFAULT_CONFIG_NAME,
                        help=f"配置名称 (默认: {DEFAULT_CONFIG_NAME})")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT,
                        help=f"任务提示词 (默认: '{DEFAULT_PROMPT}')")
    parser.add_argument("--tokenizer_path", type=str, default=DEFAULT_TOKENIZER_PATH,
                        help=f"tokenizer 路径 (默认: {DEFAULT_TOKENIZER_PATH})")
    parser.add_argument("--discrete_state_input", action="store_true", default=True,
                        help="使用离散状态输入 (默认: True)")
    parser.add_argument("--action_dim", type=int, default=DEFAULT_ACTION_DIM,
                        help=f"动作维度 (默认: {DEFAULT_ACTION_DIM})")
    parser.add_argument("--chunk_size", type=int, default=DEFAULT_CHUNK_SIZE,
                        help=f"动作块长度 (默认: {DEFAULT_CHUNK_SIZE})")
    parser.add_argument("--num_views", type=int, default=DEFAULT_NUM_VIEWS,
                        help=f"输入视角数 (默认: {DEFAULT_NUM_VIEWS})")
    parser.add_argument("--image_size", type=int, default=DEFAULT_IMAGE_SIZE,
                        help=f"图像分辨率 (默认: {DEFAULT_IMAGE_SIZE})")
    parser.add_argument("--speed_runs", type=int, default=DEFAULT_SPEED_RUNS,
                        help=f"速度测试轮数 (默认: {DEFAULT_SPEED_RUNS})")
    parser.add_argument("--warmup", type=int, default=DEFAULT_WARMUP,
                        help=f"预热轮数 (默认: {DEFAULT_WARMUP})")
    args = parser.parse_args()

    comp = DroidSpeedComparator(
        triton_path=args.triton_path,
        jax_path=args.jax_path,
        config_name=args.config_name,
        tokenizer_path=args.tokenizer_path,
        prompt=args.prompt,
        discrete_state_input=args.discrete_state_input,
        action_dim=args.action_dim,
        chunk_size=args.chunk_size,
        num_views=args.num_views,
        image_size=args.image_size,
        norm_stats_dir=args.norm_stats_dir if args.norm_stats_dir else "",
    )

    comp.run(num_runs=args.speed_runs, warmup=args.warmup)


if __name__ == "__main__":
    main()

运行信息:

Loading Triton model...

max_prompt_len: 200, max_tokenize_len: 200

Loading JAX model...

======================================================================

【速度对比】JAX 官方后端 vs Triton 加速后端

======================================================================

模型名称: Pi05_base

输入视角数: 1

图像分辨率: 224 x 224

动作块长度: 15

测试次数: 100

预热轮数: 3

动作维度: 8

======================================================================

1/4 JAX 预热...

JAX 预热完成

2/4 JAX 端到端推理计时(100 次)...

JAX 进度: 20/100

JAX 进度: 40/100

JAX 进度: 60/100

JAX 进度: 80/100

JAX 进度: 100/100

3/4 Triton 预热...

Triton 预热完成

4/4 Triton 端到端推理计时(100 次)...

Triton 进度: 20/100

Triton 进度: 40/100

Triton 进度: 60/100

Triton 进度: 80/100

Triton 进度: 100/100

======================================================================

【速度对比汇总】端到端推理(含 prompt 编码 + 推理)

======================================================================

模型名称: Pi05_base

输入视角数: 1

图像分辨率: 224 x 224

动作块长度: 15

测试次数: 100

JAX 官方后端

平均延迟: 93.975 ms

中位数延迟: 98.326 ms

P99 延迟: 103.069 ms

最小延迟: 78.759 ms

最大延迟: 103.069 ms

标准差: 7.822 ms

理论帧率(avg): 10.6 Hz

理论帧率(med): 10.2 Hz

Triton 加速后端

平均延迟: 35.590 ms

中位数延迟: 30.926 ms

P99 延迟: 52.753 ms

最小延迟: 30.463 ms

最大延迟: 52.753 ms

标准差: 6.897 ms

理论帧率(avg): 28.1 Hz

理论帧率(med): 32.3 Hz

======================================================================

【加速比分析】

======================================================================

平均延迟加速比: 2.64x (JAX 94.0ms / Triton 35.6ms)

中位数加速比: 3.18x

P99 加速比: 1.95x

延迟降低: 62.1%

JAX 帧率: 10.6 Hz

Triton 帧率: 28.1 Hz

结论: ✅ 显著加速(2.6x),适合实时部署

======================================================================

5)编写一个代码,测试Pi0.5 基础权重推理验证脚本

用于确认 Triton 加速后的推理引擎在加载基础(非微调)权重后,输出是否稳定、合理且一致。

阶段 目的
加载基础权重 从 pickle 加载 pi05_base_converted.pkl(原始预训练权重,非 DROID 微调版)
初始化推理引擎 创建 Pi05Inference,配置单视角、50 步动作块、离散状态输入
构造伪输入 模拟 DROID 数据格式(单视角 224×224 图像 + 随机噪声 + 文本指令 + 状态 token)
预热 3 次前向传播稳定 GPU 状态
多次推理验证 10 次端到端推理,记录延迟与输出动作分布
一致性校验 对比 10 次输出,检查相同输入是否产生相同结果(确定性验证)
合理性检查 检测 NaN / Inf,确认输出张量形状与数值范围正常

运行指令:python test_verify_triton_droid.py

代码如下所示:

python 复制代码
import time
import pickle
import numpy as np
import torch
from pi05_infer import Pi05Inference

# 1. 加载转换后的微调权重
print("Loading converted weights...")
with open('pi05_base_converted.pkl', 'rb') as f:
    checkpoint = pickle.load(f)

# 2. 初始化推理引擎
print("Initializing Pi05Inference...")
infer = Pi05Inference(
    checkpoint=checkpoint,
    num_views=1,  # DROID 通常是单视角(外部相机)
    chunk_size=50,
    tokenizer_path="./paligemma-3b-pt-224",
    discrete_state_input=True,
    max_tokenize_len=200,
)

# 3. 准备测试输入(模拟 DROID 格式)
images = torch.randn(1, 224, 224, 3, dtype=torch.bfloat16, device="cuda")
noise = torch.randn(50, 32, dtype=torch.bfloat16, device="cuda")
task_prompt = "Pick up the bottle."
state_tokens = np.random.randint(0, 256, size=32, dtype=np.int32)

# 4. 预热
print("Warming up...")
for _ in range(3):
    _ = infer.forward(images, noise, task_prompt, state_tokens)
    torch.cuda.synchronize()

# 5. 多次推理验证输出稳定性
print("\nRunning 10 inference tests...")
actions_list = []
times = []

for i in range(10):
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    actions = infer.forward(images, noise, task_prompt, state_tokens)

    torch.cuda.synchronize()
    t1 = time.perf_counter()

    elapsed_ms = (t1 - t0) * 1000
    times.append(elapsed_ms)
    actions_list.append(actions.cpu().float().numpy())

    print(f"  Test {i+1}: latency={elapsed_ms:.2f}ms, "
          f"actions_range=[{actions.min():.4f}, {actions.max():.4f}], "
          f"mean={actions.mean():.4f}")

# 6. 验证输出一致性(相同输入应产生相同输出)
print("\nVerifying output consistency...")
for i in range(1, 10):
    diff = np.abs(actions_list[i] - actions_list[0]).max()
    print(f"  Test {i+1} vs Test 1 max diff: {diff:.6f}")

# 7. 统计
print(f"\nLatency stats: mean={np.mean(times):.2f}ms, median={np.median(times):.2f}ms")

# 8. 合理性检查
print(f"\nOutput shape: {actions_list[0].shape}")
print(f"Output dtype: {actions.dtype}")
print(f"Contains NaN: {np.isnan(actions_list[0]).any()}")
print(f"Contains Inf: {np.isinf(actions_list[0]).any()}")

if not np.isnan(actions_list[0]).any() and not np.isinf(actions_list[0]).any():
    print("\n✅ Triton inference output is valid!")
else:
    print("\n❌ Triton inference output contains NaN or Inf!")

运行信息:

Loading converted weights...

Initializing Pi05Inference...

max_prompt_len: 200, max_tokenize_len: 200

Warming up...

Running 10 inference tests...

Test 1: latency=61.08ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 2: latency=57.96ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 3: latency=63.85ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 4: latency=64.14ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 5: latency=63.84ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 6: latency=63.70ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 7: latency=57.46ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 8: latency=34.04ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 9: latency=33.08ms, actions_range=-0.8125, 0.8398, mean=0.0354

Test 10: latency=33.89ms, actions_range=-0.8125, 0.8398, mean=0.0354

Verifying output consistency...

Test 2 vs Test 1 max diff: 0.000000

Test 3 vs Test 1 max diff: 0.000000

Test 4 vs Test 1 max diff: 0.000000

Test 5 vs Test 1 max diff: 0.000000

Test 6 vs Test 1 max diff: 0.000000

Test 7 vs Test 1 max diff: 0.000000

Test 8 vs Test 1 max diff: 0.000000

Test 9 vs Test 1 max diff: 0.000000

Test 10 vs Test 1 max diff: 0.000000

Latency stats: mean=53.30ms, median=59.52ms

Output shape: (50, 32)

Output dtype: torch.bfloat16

Contains NaN: False

Contains Inf: False

✅ Triton inference output is valid!

6、pi05 微调权重 实践示例

1)pi05 微调后 的模型转换

python 复制代码
python3 convert_from_jax_pi05.py \
   --jax_path /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999 \
   --output pi05_droid_finetune_low_mem_discrete_false_converted.pkl \
   --prompt "Pick up the bottle." \

运行信息:

Loading jax weights from /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999/params

WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.

Loaded JAX params keys: dict_keys('PaliGemma', 'action_in_proj', 'action_out_proj', 'time_mlp_in', 'time_mlp_out')

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.

Successfully converted Pi0.5 weights to pi05_droid_finetune_low_mem_discrete_false_converted.pkl

2)编写一个代码,测试****JAX 官方后端 vs Triton 加速后端的推理一致性对比验证

阶段 目的
双后端加载 同时加载 Triton 转换权重和 JAX 官方策略,建立对照基准
输入对齐 对图像和状态执行完全相同的预处理,消除输入差异
并行推理 同一组输入分别送入 Triton 和 JAX,获取各自输出
误差量化 计算全局 MAE、Max Error、逐维度 MAE
一致性评级 按误差阈值自动判定通过/警告/可部署/失败

运行命令:

bash 复制代码
python test_pi05_compare_jax_triton_droid-v1.py \
  --triton_path pi05_droid_finetune_low_mem_converted.pkl \
  --jax_path /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999 \
  --norm_stats_dir /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999/assets/droid \
  --config_name pi05_droid \
  --prompt "Pick up the bottle." \
  --tokenizer_path ./paligemma-3b-pt-224 \
  --action_dim 8 \
  --chunk_size 15

代码如下所示:

python 复制代码
import os
import json
import argparse
import pickle
import numpy as np
import torch
from PIL import Image
from pathlib import Path

from pi05_infer import Pi05Inference
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


class DroidComparator:
    def __init__(self, triton_path, jax_path, norm_stats_dir, config_name,
                 tokenizer_path, prompt, discrete_state_input=True,  # 推荐 True
                 action_dim=8, chunk_size=15):
        self.prompt = prompt
        self.discrete_state_input = discrete_state_input
        self.action_dim = action_dim
        self.chunk_size = chunk_size

        # 加载归一化统计信息
        self.norm_stats = self._load_norm_stats(norm_stats_dir)
        # 离散化用的分箱边界(256个bin)
        self._digitize_bins = np.linspace(-1, 1, 256 + 1)[:-1]

        # 加载 Triton 模型(你的优化后端)
        print("Loading Triton model...")
        with open(triton_path, "rb") as f:
            weights = pickle.load(f)
        self.triton_model = Pi05Inference(
            checkpoint=weights,
            num_views=3,
            chunk_size=chunk_size,
            tokenizer_path=tokenizer_path,
            discrete_state_input=True,  # 与测试目标一致
            max_tokenize_len=200,
        )

        # 加载 JAX 模型(官方后端)
        print("Loading JAX model...")
        config = _config.get_config(config_name)
        self.jax_policy = _policy_config.create_trained_policy(config, Path(jax_path))

    def _load_norm_stats(self, norm_stats_dir):
        """从指定目录加载归一化统计信息"""
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, "r") as f:
                return json.load(f)["norm_stats"]
        return None

    def _pad_to_dim(self, x, target_dim, axis=-1):
        """将输入数组沿指定轴填充到目标维度"""
        current_dim = x.shape[axis] if len(x.shape) > 0 else len(x)
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image, height=224, width=224):
        """保持宽高比缩放图像,并用0填充到目标尺寸(完全对齐官方test.py)"""
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _preprocess_image(self, img_np):
        """图像预处理:解析格式 -> 缩放填充 -> 归一化到[-1, 1](完全对齐官方test.py)"""
        img = np.asarray(img_np)
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        if img.shape[0] == 3:
            import einops
            img = einops.rearrange(img, "c h w -> h w c")
        img = self._resize_with_pad(img, 224, 224)
        img = img.astype(np.float32) / 255.0 * 2.0 - 1.0
        return img

    def _preprocess_state_discrete(self, joint_pos, gripper_pos):
        """
        状态预处理(离散模式):拼接 -> 填充 -> 归一化 -> 离散化
        返回32个token ID(int64),完全对齐官方test.py
        """
        if np.isscalar(gripper_pos):
            gripper_pos = np.array([gripper_pos], dtype=np.float32)
        state = np.concatenate([joint_pos, gripper_pos]).astype(np.float32)
        state = self._pad_to_dim(state, 32)

        # 使用分位数归一化到[-1, 1]
        if self.norm_stats and "state" in self.norm_stats:
            q01 = np.array(self.norm_stats["state"]["q01"], dtype=np.float32)
            q99 = np.array(self.norm_stats["state"]["q99"], dtype=np.float32)
            q01 = self._pad_to_dim(q01, 32)
            q99 = self._pad_to_dim(q99, 32)
            state_normed = (state - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0
        else:
            state_normed = state

        # 离散化到256个bin
        state_normed = np.clip(state_normed, -1.0, 1.0)
        token_ids = np.digitize(state_normed, bins=self._digitize_bins) - 1
        return token_ids.astype(np.int64)

    def _unnormalize_actions(self, actions_normed):
        """将归一化后的动作反归一化到原始空间"""
        if self.norm_stats and "actions" in self.norm_stats:
            q01 = np.array(self.norm_stats["actions"]["q01"], dtype=np.float32)
            q99 = np.array(self.norm_stats["actions"]["q99"], dtype=np.float32)
            q01 = self._pad_to_dim(q01, 32)
            q99 = self._pad_to_dim(q99, 32)
            actions = (actions_normed + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
        else:
            actions = actions_normed
        return actions[:, :self.action_dim]

    def compare(self, num_tests=10):
        print(f"\n{'='*60}")
        print("JAX vs Triton DROID 推理对比 (discrete_state_input=True)")
        print(f"{'='*60}")

        for i in range(num_tests):
            # 1. 生成一组随机输入(状态、噪声)
            droid_example = droid_policy.make_droid_example()
            noise_np = np.random.randn(self.chunk_size, 32).astype(np.float32)

            # 2. 提取原始图像和状态
            exterior = droid_example["observation/exterior_image_1_left"]
            wrist = droid_example["observation/wrist_image_left"]
            joint_pos = np.asarray(droid_example["observation/joint_position"])
            gripper_pos = np.asarray(droid_example["observation/gripper_position"])

            # 3. 【核心修复】对图像执行完全相同的预处理,供 JAX 和 Triton 共用
            img_base = self._preprocess_image(exterior)
            img_left = self._preprocess_image(wrist)
            img_right = np.zeros_like(img_base)  # 第三视角用全0填充

            # 4. 预处理状态(离散token)
            state_tokens = self._preprocess_state_discrete(joint_pos, gripper_pos)

            # 5. 【Triton 推理】将预处理后的数据转为 Torch Tensor 并送入 GPU
            images = torch.from_numpy(
                np.stack([img_base, img_left, img_right], axis=0)
            ).to(torch.bfloat16).cuda()
            state_torch = torch.from_numpy(state_tokens).to(torch.long).cuda()
            noise_torch = torch.from_numpy(noise_np).to(torch.bfloat16).cuda()
            triton_output = self.triton_model.forward(images, noise_torch, self.prompt, state_torch)
            triton_actions = self._unnormalize_actions(triton_output.cpu().float().numpy())

            # 6. 【JAX 推理】构造与 Triton 完全相同的输入字典(图像+状态+提示词)
            #    注意:JAX 的 policy.infer 内部会自行完成归一化和离散化,
            #    因此我们传入原始状态(joint_pos, gripper_pos)即可,但图像必须预处理。
            jax_input = {
                "observation/exterior_image_1_left": img_base,   # 已预处理
                "observation/wrist_image_left": img_left,        # 已预处理
                "observation/joint_position": joint_pos,         # 原始状态
                "observation/gripper_position": gripper_pos,     # 原始状态
                "prompt": self.prompt,
            }
            jax_result = self.jax_policy.infer(jax_input, noise=noise_np)
            jax_actions = jax_result["actions"]

            # 7. 对齐动作长度并计算误差
            min_steps = min(jax_actions.shape[0], triton_actions.shape[0])
            jax_aligned = jax_actions[:min_steps]
            triton_aligned = triton_actions[:min_steps]

            print(f"\n--- Test {i+1}/{num_tests} ---")
            print(f"JAX    actions: shape={jax_actions.shape}, range=[{jax_actions.min():.4f}, {jax_actions.max():.4f}]")
            print(f"Triton actions: shape={triton_actions.shape}, range=[{triton_actions.min():.4f}, {triton_actions.max():.4f}]")

            mae = np.mean(np.abs(triton_aligned - jax_aligned))
            max_err = np.max(np.abs(triton_aligned - jax_aligned))
            print(f"Global MAE: {mae:.6f}, Max Error: {max_err:.6f}")

            print("Per-dimension MAE:")
            for dim in range(self.action_dim):
                dim_mae = np.mean(np.abs(triton_aligned[:, dim] - jax_aligned[:, dim]))
                print(f"  dim {dim}: {dim_mae:.6f}")

            # 根据误差级别给出评估
            if mae < 1e-3:
                print("✅ PASS: MAE < 1e-3")
            elif mae < 1e-2:
                print("⚠️  WARNING: MAE < 1e-2 (acceptable)")
            elif mae < 5e-2:
                print("📊 MAE < 5e-2 (design trade-off, deployable)")
            else:
                print("❌ FAIL: MAE too large")

        print(f"\n{'='*60}")
        print("对比完成")
        print(f"{'='*60}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--triton_path", type=str, required=True)
    parser.add_argument("--jax_path", type=str, required=True)
    parser.add_argument("--norm_stats_dir", type=str, required=True)
    parser.add_argument("--config_name", type=str, default="pi05_droid")
    parser.add_argument("--prompt", type=str, default="Pick up the bottle.")
    parser.add_argument("--tokenizer_path", type=str, default="./paligemma-3b-pt-224")
    parser.add_argument("--discrete_state_input", action="store_true", default=True)
    parser.add_argument("--action_dim", type=int, default=8)
    parser.add_argument("--chunk_size", type=int, default=15)
    args = parser.parse_args()

    comp = DroidComparator(
        triton_path=args.triton_path,
        jax_path=args.jax_path,
        norm_stats_dir=args.norm_stats_dir,
        config_name=args.config_name,
        tokenizer_path=args.tokenizer_path,
        prompt=args.prompt,
        discrete_state_input=args.discrete_state_input,
        action_dim=args.action_dim,
        chunk_size=args.chunk_size,
    )
    comp.compare(num_tests=10)


if __name__ == "__main__":
    main()

运行信息:

Loading Triton model...

max_prompt_len: 200, max_tokenize_len: 200

Loading JAX model...

============================================================

JAX vs Triton DROID 推理对比 (discrete_state_input=True)

============================================================

--- Test 1/10 ---

JAX actions: shape=(15, 8), range=-0.2873, 0.8841

Triton actions: shape=(15, 8), range=-0.2408, 0.8475

Global MAE: 0.043570, Max Error: 0.184403

Per-dimension MAE:

dim 0: 0.012646

dim 1: 0.026804

dim 2: 0.074144

dim 3: 0.032435

dim 4: 0.028762

dim 5: 0.091247

dim 6: 0.051733

dim 7: 0.030788

📊 MAE < 5e-2 (design trade-off, deployable)

--- Test 2/10 ---

JAX actions: shape=(15, 8), range=-0.6280, 0.6690

Triton actions: shape=(15, 8), range=-0.6153, 0.5471

Global MAE: 0.077840, Max Error: 0.259895

Per-dimension MAE:

dim 0: 0.114950

dim 1: 0.098268

dim 2: 0.019811

dim 3: 0.115439

dim 4: 0.062824

dim 5: 0.081949

dim 6: 0.096452

dim 7: 0.033029

❌ FAIL: MAE too large

--- Test 3/10 ---

JAX actions: shape=(15, 8), range=-0.5544, 0.3675

Triton actions: shape=(15, 8), range=-0.4773, 0.3750

Global MAE: 0.079263, Max Error: 0.319341

Per-dimension MAE:

dim 0: 0.033593

dim 1: 0.097391

dim 2: 0.037398

dim 3: 0.063690

dim 4: 0.070841

dim 5: 0.174956

dim 6: 0.121281

dim 7: 0.034952

❌ FAIL: MAE too large

--- Test 4/10 ---

JAX actions: shape=(15, 8), range=-0.4042, 0.8830

Triton actions: shape=(15, 8), range=-0.2996, 0.9139

Global MAE: 0.064592, Max Error: 0.211363

Per-dimension MAE:

dim 0: 0.019950

dim 1: 0.046448

dim 2: 0.056883

dim 3: 0.068979

dim 4: 0.078694

dim 5: 0.098469

dim 6: 0.128584

dim 7: 0.018731

❌ FAIL: MAE too large

--- Test 5/10 ---

JAX actions: shape=(15, 8), range=-0.6994, 0.6707

Triton actions: shape=(15, 8), range=-0.6767, 0.6503

Global MAE: 0.088871, Max Error: 0.393171

Per-dimension MAE:

dim 0: 0.051963

dim 1: 0.037229

dim 2: 0.039792

dim 3: 0.050262

dim 4: 0.113197

dim 5: 0.065095

dim 6: 0.186282

dim 7: 0.167150

❌ FAIL: MAE too large

--- Test 6/10 ---

JAX actions: shape=(15, 8), range=-0.6909, 0.3633

Triton actions: shape=(15, 8), range=-0.7374, 0.7616

Global MAE: 0.128749, Max Error: 0.597238

Per-dimension MAE:

dim 0: 0.029933

dim 1: 0.050044

dim 2: 0.056748

dim 3: 0.031456

dim 4: 0.169972

dim 5: 0.073704

dim 6: 0.192077

dim 7: 0.426054

❌ FAIL: MAE too large

--- Test 7/10 ---

JAX actions: shape=(15, 8), range=-0.4557, 0.8646

Triton actions: shape=(15, 8), range=-0.4083, 0.8885

Global MAE: 0.062593, Max Error: 0.266613

Per-dimension MAE:

dim 0: 0.011620

dim 1: 0.023778

dim 2: 0.050466

dim 3: 0.014120

dim 4: 0.148288

dim 5: 0.035732

dim 6: 0.133349

dim 7: 0.083391

❌ FAIL: MAE too large

--- Test 8/10 ---

JAX actions: shape=(15, 8), range=-0.6022, 0.6886

Triton actions: shape=(15, 8), range=-0.5626, 0.6398

Global MAE: 0.082532, Max Error: 0.254548

Per-dimension MAE:

dim 0: 0.073425

dim 1: 0.063760

dim 2: 0.050406

dim 3: 0.057179

dim 4: 0.114202

dim 5: 0.028413

dim 6: 0.135704

dim 7: 0.137168

❌ FAIL: MAE too large

--- Test 9/10 ---

JAX actions: shape=(15, 8), range=-0.1179, 0.1835

Triton actions: shape=(15, 8), range=-0.3947, 0.6019

Global MAE: 0.160285, Max Error: 0.418451

Per-dimension MAE:

dim 0: 0.190313

dim 1: 0.068125

dim 2: 0.181805

dim 3: 0.201560

dim 4: 0.087499

dim 5: 0.130815

dim 6: 0.066352

dim 7: 0.355810

❌ FAIL: MAE too large

--- Test 10/10 ---

JAX actions: shape=(15, 8), range=-0.4795, 0.8724

Triton actions: shape=(15, 8), range=-0.5005, 0.8983

Global MAE: 0.034000, Max Error: 0.150526

Per-dimension MAE:

dim 0: 0.018649

dim 1: 0.019654

dim 2: 0.037001

dim 3: 0.043121

dim 4: 0.044089

dim 5: 0.042278

dim 6: 0.038548

dim 7: 0.028658

📊 MAE < 5e-2 (design trade-off, deployable)

============================================================

对比完成

============================================================

3)编写一个代码,测试****JAX 官方后端 vs Triton 加速后端的推理一致性对比验证,添加可视化

计算 MAE/RMSE/Max Error 等多维误差指标,并自动生成 可视化图片

功能模块 说明
双后端推理对比 同时加载 Triton 优化后端(Pi05Inference)和 JAX 官方后端(openpi policy),对同一组输入进行并行推理
输入对齐修复 核心修复:JAX 和 Triton 使用完全相同的预处理图像img_base/img_left/img_right),消除预处理差异导致的误差
离散状态输入 支持 discrete_state_input=True 模式:关节位置+夹爪位置 → 填充到32维 → 分位数归一化 → 256-bin 离散化 → int64 token IDs
动作反归一化 基于 norm_stats.json 中的 q01/q99 分位数,将模型输出的归一化动作还原到原始空间
多维度误差评估 计算 Global MAE / Max Error / RMSE,以及每个动作维度(Dim 0~7)的独立 MAE
阈值判定 根据 MAE 自动判定:✅ PASS(<1e-3) / ⚠️ WARNING(<1e-2) / 📊 DEPLOYABLE(<5e-2) / ❌ FAIL(≥5e-2)

运行指令:

复制代码
 python test_pi05_compare_jax_triton_droid-vis.py \
  --triton_path pi05_droid_finetune_low_mem_converted.pkl \
  --jax_path /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999 \
  --norm_stats_dir /home/liguopu/lgp_dev/project/openpi/checkpoints/pi05_droid_finetune_low_mem/my_experiment/1999/assets/droid \
  --config_name pi05_droid \
  --prompt "Pick up the bottle." \
  --tokenizer_path ./paligemma-3b-pt-224 \
  --action_dim 8 \
  --chunk_size 15 \
  --output_dir ./comparison_results

参数说明:

参数 说明
--triton_path Triton 转换后的 .pkl 权重路径
--jax_path JAX 官方 checkpoint 目录路径
--norm_stats_dir 归一化统计信息目录(含 norm_stats.json
--config_name 模型配置名(默认 pi05_droid
--prompt 文本提示词
--tokenizer_path PaliGemma tokenizer 路径
--action_dim 动作维度(默认 8)
--chunk_size 动作块长度(默认 15)
--output_dir 新增 :可视化输出目录(默认 ./comparison_results

参考代码:

python 复制代码
# test_pi05_compare_jax_triton_droid.py ------ 已修复:JAX 和 Triton 使用完全相同的预处理输入
# 增强版:添加多维度可视化对比(区分不同组数据,分别保存+汇总大图)
import os
import json
import argparse
import pickle
import numpy as np
import torch
from PIL import Image
from pathlib import Path

from pi05_infer import Pi05Inference
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config

# ========== 新增:可视化依赖 ==========
import matplotlib
matplotlib.use("Agg")  # 无头环境使用Agg后端
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


class DroidComparator:
    def __init__(self, triton_path, jax_path, norm_stats_dir, config_name,
                 tokenizer_path, prompt, discrete_state_input=True,
                 action_dim=8, chunk_size=15, output_dir="./comparison_results"):
        self.prompt = prompt
        self.discrete_state_input = discrete_state_input
        self.action_dim = action_dim
        self.chunk_size = chunk_size
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)  # 创建输出目录

        # 加载归一化统计信息
        self.norm_stats = self._load_norm_stats(norm_stats_dir)
        # 离散化用的分箱边界(256个bin)
        self._digitize_bins = np.linspace(-1, 1, 256 + 1)[:-1]

        # 加载 Triton 模型(你的优化后端)
        print("Loading Triton model...")
        with open(triton_path, "rb") as f:
            weights = pickle.load(f)
        self.triton_model = Pi05Inference(
            checkpoint=weights,
            num_views=3,
            chunk_size=chunk_size,
            tokenizer_path=tokenizer_path,
            discrete_state_input=True,
            max_tokenize_len=200,
        )

        # 加载 JAX 模型(官方后端)
        print("Loading JAX model...")
        config = _config.get_config(config_name)
        self.jax_policy = _policy_config.create_trained_policy(config, Path(jax_path))

        # ========== 新增:存储所有测试结果用于汇总可视化 ==========
        self.all_results = []  # 存储每组测试的完整数据

    def _load_norm_stats(self, norm_stats_dir):
        """从指定目录加载归一化统计信息"""
        norm_stats_path = os.path.join(norm_stats_dir, "norm_stats.json")
        if os.path.exists(norm_stats_path):
            with open(norm_stats_path, "r") as f:
                return json.load(f)["norm_stats"]
        return None

    def _pad_to_dim(self, x, target_dim, axis=-1):
        """将输入数组沿指定轴填充到目标维度"""
        current_dim = x.shape[axis] if len(x.shape) > 0 else len(x)
        if current_dim < target_dim:
            pad_width = [(0, 0)] * len(x.shape)
            pad_width[axis] = (0, target_dim - current_dim)
            return np.pad(x, pad_width)
        return x

    def _resize_with_pad(self, image, height=224, width=224):
        """保持宽高比缩放图像,并用0填充到目标尺寸(完全对齐官方test.py)"""
        pil_image = Image.fromarray(image)
        cur_width, cur_height = pil_image.size
        if cur_width == width and cur_height == height:
            return image
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_image = pil_image.resize((resized_width, resized_height), resample=Image.BILINEAR)
        zero_image = Image.new(resized_image.mode, (width, height), 0)
        pad_height = max(0, int((height - resized_height) / 2))
        pad_width = max(0, int((width - resized_width) / 2))
        zero_image.paste(resized_image, (pad_width, pad_height))
        return np.array(zero_image)

    def _preprocess_image(self, img_np):
        """图像预处理:解析格式 -> 缩放填充 -> 归一化到[-1, 1](完全对齐官方test.py)"""
        img = np.asarray(img_np)
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        if img.shape[0] == 3:
            import einops
            img = einops.rearrange(img, "c h w -> h w c")
        img = self._resize_with_pad(img, 224, 224)
        img = img.astype(np.float32) / 255.0 * 2.0 - 1.0
        return img

    def _preprocess_state_discrete(self, joint_pos, gripper_pos):
        """
        状态预处理(离散模式):拼接 -> 填充 -> 归一化 -> 离散化
        返回32个token ID(int64),完全对齐官方test.py
        """
        if np.isscalar(gripper_pos):
            gripper_pos = np.array([gripper_pos], dtype=np.float32)
        state = np.concatenate([joint_pos, gripper_pos]).astype(np.float32)
        state = self._pad_to_dim(state, 32)

        # 使用分位数归一化到[-1, 1]
        if self.norm_stats and "state" in self.norm_stats:
            q01 = np.array(self.norm_stats["state"]["q01"], dtype=np.float32)
            q99 = np.array(self.norm_stats["state"]["q99"], dtype=np.float32)
            q01 = self._pad_to_dim(q01, 32)
            q99 = self._pad_to_dim(q99, 32)
            state_normed = (state - q01) / (q99 - q01 + 1e-6) * 2.0 - 1.0
        else:
            state_normed = state

        # 离散化到256个bin
        state_normed = np.clip(state_normed, -1.0, 1.0)
        token_ids = np.digitize(state_normed, bins=self._digitize_bins) - 1
        return token_ids.astype(np.int64)

    def _unnormalize_actions(self, actions_normed):
        """将归一化后的动作反归一化到原始空间"""
        if self.norm_stats and "actions" in self.norm_stats:
            q01 = np.array(self.norm_stats["actions"]["q01"], dtype=np.float32)
            q99 = np.array(self.norm_stats["actions"]["q99"], dtype=np.float32)
            q01 = self._pad_to_dim(q01, 32)
            q99 = self._pad_to_dim(q99, 32)
            actions = (actions_normed + 1.0) / 2.0 * (q99 - q01 + 1e-6) + q01
        else:
            actions = actions_normed
        return actions[:, :self.action_dim]

    # ========== 新增:单组测试可视化(保存独立图片) ==========
    def _plot_single_test(self, test_idx, jax_actions, triton_actions, 
                          global_mae, global_max_err, global_rmse, per_dim_mae):
        """
        为单个测试绘制8维度时序对比图
        区分不同维度的数据组,用不同颜色/线型展示
        """
        fig, axes = plt.subplots(4, 2, figsize=(14, 16))
        axes = axes.flatten()
        t = np.arange(self.chunk_size)

        for d in range(self.action_dim):
            ax = axes[d]
            # JAX 动作:蓝色实线+圆点标记
            ax.plot(t, jax_actions[:, d], 'b-o', label='JAX', 
                    markersize=5, linewidth=1.8, alpha=0.85, zorder=3)
            # Triton 动作:红色虚线+方块标记
            ax.plot(t, triton_actions[:, d], 'r-s', label='Triton', 
                    markersize=5, linewidth=1.8, alpha=0.85, zorder=3)
            # 误差填充区域:紫色半透明
            ax.fill_between(t, jax_actions[:, d], triton_actions[:, d], 
                           alpha=0.25, color='purple', label='Error', zorder=1)

            # 标注维度MAE
            ax.set_title(f'Dim {d}  (MAE={per_dim_mae[d]:.4f})', 
                        fontsize=12, fontweight='bold', color='#2C3E50')
            ax.set_xlabel('Time Step', fontsize=10)
            ax.set_ylabel('Action Value', fontsize=10)
            ax.legend(fontsize=9, loc='best', framealpha=0.9)
            ax.grid(alpha=0.3, linestyle='--')
            ax.set_xticks(t)

        # 全局标题
        status = "PASS" if global_mae < 1e-3 else ("WARNING" if global_mae < 1e-2 else 
                 ("DEPLOYABLE" if global_mae < 5e-2 else "FAIL"))
        status_color = {"PASS": "green", "WARNING": "orange", "DEPLOYABLE": "blue", "FAIL": "red"}[status]

        fig.suptitle(
            f'Test {test_idx+1}: JAX vs Triton Action Sequence Comparison
'
            f'Global MAE={global_mae:.4f} | MaxErr={global_max_err:.4f} | RMSE={global_rmse:.4f} | Status: {status}',
            fontsize=14, fontweight='bold', color=status_color, y=1.01
        )
        plt.tight_layout()

        save_path = os.path.join(self.output_dir, f"test_{test_idx+1:02d}_timeseries.png")
        plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        print(f"  📊 已保存单组可视化: {save_path}")

    # ========== 新增:汇总可视化(所有测试整合到一张大图) ==========
    def _plot_summary_dashboard(self):
        """
        汇总所有测试结果,生成综合对比仪表板大图
        区分不同测试组的数据,用子图网格展示
        """
        if not self.all_results:
            print("⚠️ 无测试结果,跳过汇总可视化")
            return

        num_tests = len(self.all_results)
        # 提取所有测试的指标
        global_maes = [r["global_mae"] for r in self.all_results]
        global_maxs = [r["global_max_err"] for r in self.all_results]
        global_rmses = [r["global_rmse"] for r in self.all_results]
        per_dim_maes = np.array([r["per_dim_mae"] for r in self.all_results])  # (num_tests, action_dim)

        # 收集所有误差数据
        all_errors = []
        for r in self.all_results:
            all_errors.append(r["triton_actions"] - r["jax_actions"])
        all_errors = np.array(all_errors)  # (num_tests, chunk_size, action_dim)

        fig = plt.figure(figsize=(22, 26))

        # ---- 子图A: 全局指标柱状图(区分不同测试组) ----
        ax1 = plt.subplot2grid((5, 2), (0, 0), colspan=2)
        x = np.arange(num_tests)
        width = 0.25
        bars1 = ax1.bar(x - width, global_maes, width, label='MAE', color='#E74C3C', alpha=0.85, edgecolor='black', linewidth=0.5)
        bars2 = ax1.bar(x, global_rmses, width, label='RMSE', color='#3498DB', alpha=0.85, edgecolor='black', linewidth=0.5)
        bars3 = ax1.bar(x + width, global_maxs, width, label='Max Error', color='#F39C12', alpha=0.85, edgecolor='black', linewidth=0.5)

        # 阈值参考线
        ax1.axhline(y=1e-3, color='green', linestyle='--', linewidth=2, alpha=0.8, label='Pass (<1e-3)')
        ax1.axhline(y=1e-2, color='orange', linestyle='--', linewidth=2, alpha=0.8, label='Warning (<1e-2)')
        ax1.axhline(y=5e-2, color='red', linestyle='--', linewidth=2, alpha=0.8, label='Deployable (<5e-2)')

        ax1.set_title('(A) Global Error Metrics per Test Group', fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax1.set_ylabel('Error Value', fontsize=11)
        ax1.set_xticks(x)
        ax1.set_xticklabels([f'Test {i+1}' for i in range(num_tests)], rotation=30, ha='right')
        ax1.legend(loc='upper left', fontsize=9, ncol=2)
        ax1.grid(axis='y', alpha=0.3, linestyle='--')

        # 柱顶标注数值
        for bars in [bars1, bars2, bars3]:
            for bar in bars:
                height = bar.get_height()
                ax1.annotate(f'{height:.3f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3), textcoords="offset points",
                            ha='center', va='bottom', fontsize=7, rotation=90)

        # ---- 子图B: 各维度MAE热力图(区分维度组) ----
        ax2 = plt.subplot2grid((5, 2), (1, 0))
        im = ax2.imshow(per_dim_maes.T, cmap='YlOrRd', aspect='auto', vmin=0, vmax=0.7)
        ax2.set_xticks(np.arange(num_tests))
        ax2.set_xticklabels([f'T{i+1}' for i in range(num_tests)], fontsize=9)
        ax2.set_yticks(np.arange(self.action_dim))
        ax2.set_yticklabels([f'Dim {d}' for d in range(self.action_dim)], fontsize=9)
        ax2.set_title('(B) Per-Dimension MAE Heatmap', fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        # 热力图数值标注
        for i in range(num_tests):
            for d in range(self.action_dim):
                val = per_dim_maes[i, d]
                text_color = 'white' if val > 0.35 else 'black'
                ax2.text(i, d, f'{val:.3f}', ha='center', va='center', 
                        color=text_color, fontsize=7, fontweight='bold')
        plt.colorbar(im, ax=ax2, label='MAE', fraction=0.046)

        # ---- 子图C: 各维度MAE箱线图(区分维度组) ----
        ax3 = plt.subplot2grid((5, 2), (1, 1))
        dim_data = [per_dim_maes[:, d] for d in range(self.action_dim)]
        bp = ax3.boxplot(dim_data, tick_labels=[f'D{d}' for d in range(self.action_dim)],
                          patch_artist=True, showmeans=True, meanline=True,
                          boxprops=dict(linewidth=1.5),
                          medianprops=dict(color='black', linewidth=2),
                          meanprops=dict(color='green', linewidth=2, linestyle='--'))
        colors = plt.cm.Set3(np.linspace(0, 1, self.action_dim))
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
            patch.set_edgecolor('black')
            patch.set_linewidth(1.5)
        ax3.axhline(y=5e-2, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Deployable threshold')
        ax3.set_title('(C) Per-Dimension MAE Distribution', fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax3.set_ylabel('MAE', fontsize=11)
        ax3.legend(fontsize=9)
        ax3.grid(axis='y', alpha=0.3, linestyle='--')

        # ---- 子图D: 最佳测试时序对比(Dim 0-3,偏移展示区分组) ----
        ax4 = plt.subplot2grid((5, 2), (2, 0))
        best_idx = int(np.argmin(global_maes))
        t = np.arange(self.chunk_size)
        dim_colors = plt.cm.tab10(np.linspace(0, 1, 4))
        for d in range(4):
            offset = d * 2.5
            ax4.plot(t, self.all_results[best_idx]["jax_actions"][:, d] + offset, 
                    '-', color=dim_colors[d], linewidth=2, alpha=0.9, label=f'JAX Dim{d}')
            ax4.plot(t, self.all_results[best_idx]["triton_actions"][:, d] + offset, 
                    '--', color=dim_colors[d], linewidth=2, alpha=0.9, label=f'Triton Dim{d}')
            ax4.fill_between(t, 
                            self.all_results[best_idx]["jax_actions"][:, d] + offset,
                            self.all_results[best_idx]["triton_actions"][:, d] + offset,
                            alpha=0.15, color=dim_colors[d])
        ax4.set_title(f'(D) Best Test (T{best_idx+1}, MAE={global_maes[best_idx]:.3f}): Dims 0-3', 
                     fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax4.set_xlabel('Time Step', fontsize=10)
        ax4.set_ylabel('Action Value (offset by dim)', fontsize=10)
        ax4.grid(alpha=0.3, linestyle='--')
        # 简化图例
        legend_elements = [
            Line2D([0], [0], color='black', lw=2, linestyle='-', label='JAX'),
            Line2D([0], [0], color='black', lw=2, linestyle='--', label='Triton')
        ]
        ax4.legend(handles=legend_elements, loc='upper right', fontsize=9)

        # ---- 子图E: 最差测试时序对比(Dim 0-3) ----
        ax5 = plt.subplot2grid((5, 2), (2, 1))
        worst_idx = int(np.argmax(global_maes))
        for d in range(4):
            offset = d * 2.5
            ax5.plot(t, self.all_results[worst_idx]["jax_actions"][:, d] + offset, 
                    '-', color=dim_colors[d], linewidth=2, alpha=0.9)
            ax5.plot(t, self.all_results[worst_idx]["triton_actions"][:, d] + offset, 
                    '--', color=dim_colors[d], linewidth=2, alpha=0.9)
            ax5.fill_between(t, 
                            self.all_results[worst_idx]["jax_actions"][:, d] + offset,
                            self.all_results[worst_idx]["triton_actions"][:, d] + offset,
                            alpha=0.15, color=dim_colors[d])
        ax5.set_title(f'(E) Worst Test (T{worst_idx+1}, MAE={global_maes[worst_idx]:.3f}): Dims 0-3', 
                     fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax5.set_xlabel('Time Step', fontsize=10)
        ax5.set_ylabel('Action Value (offset by dim)', fontsize=10)
        ax5.grid(alpha=0.3, linestyle='--')
        ax5.legend(handles=legend_elements, loc='upper right', fontsize=9)

        # ---- 子图F: 误差分布直方图(所有测试所有维度) ----
        ax6 = plt.subplot2grid((5, 2), (3, 0))
        all_err_flat = all_errors.flatten()
        ax6.hist(all_err_flat, bins=60, color='steelblue', edgecolor='black', 
                alpha=0.75, linewidth=0.8)
        ax6.axvline(x=0, color='red', linestyle='--', linewidth=2.5, label='Zero Error')
        ax6.axvline(x=np.mean(all_err_flat), color='green', linestyle='--', linewidth=2.5, 
                   label=f'Mean={np.mean(all_err_flat):.4f}')
        ax6.axvline(x=np.median(all_err_flat), color='purple', linestyle='--', linewidth=2.5,
                   label=f'Median={np.median(all_err_flat):.4f}')
        ax6.set_title('(F) Error Distribution (All Tests, All Dims)', fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax6.set_xlabel('Error (Triton - JAX)', fontsize=10)
        ax6.set_ylabel('Frequency', fontsize=10)
        ax6.legend(fontsize=9)
        ax6.grid(alpha=0.3, linestyle='--')

        # ---- 子图G: MAE/RMSE趋势+阈值区域 ----
        ax7 = plt.subplot2grid((5, 2), (3, 1))
        test_nums = np.arange(1, num_tests + 1)
        ax7.plot(test_nums, global_maes, 'o-', color='#E74C3C', linewidth=2.5, 
                markersize=9, markerfacecolor='white', markeredgewidth=2, label='MAE')
        ax7.plot(test_nums, global_rmses, 's-', color='#3498DB', linewidth=2.5, 
                markersize=9, markerfacecolor='white', markeredgewidth=2, label='RMSE')
        # 阈值区域填充
        ax7.axhline(y=1e-3, color='green', linestyle='--', linewidth=1.5, alpha=0.7)
        ax7.axhline(y=1e-2, color='orange', linestyle='--', linewidth=1.5, alpha=0.7)
        ax7.axhline(y=5e-2, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
        ax7.fill_between(test_nums, 0, 1e-3, alpha=0.15, color='green', label='Pass zone')
        ax7.fill_between(test_nums, 1e-3, 1e-2, alpha=0.15, color='orange', label='Warning zone')
        ax7.fill_between(test_nums, 1e-2, 5e-2, alpha=0.15, color='red', label='Deployable zone')
        ax7.set_title('(G) MAE/RMSE Trend with Threshold Zones', fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax7.set_xlabel('Test Index', fontsize=10)
        ax7.set_ylabel('Error', fontsize=10)
        ax7.set_xticks(test_nums)
        ax7.legend(loc='upper right', fontsize=9, ncol=2)
        ax7.grid(alpha=0.3, linestyle='--')

        # ---- 子图H: 全测试误差热力图(时间步×维度,按测试分组) ----
        ax8 = plt.subplot2grid((5, 2), (4, 0), colspan=2)
        # 将所有测试的误差拼接成一个大矩阵展示
        # 每测试之间插入分隔行
        separator = np.full((1, self.action_dim), np.nan)
        display_matrix = []
        for i in range(num_tests):
            display_matrix.append(all_errors[i].T)  # (action_dim, chunk_size)
            if i < num_tests - 1:
                display_matrix.append(separator.T)
        display_matrix = np.concatenate(display_matrix, axis=1)  # (action_dim, total_timesteps)

        im2 = ax8.imshow(display_matrix, cmap='RdBu_r', aspect='auto', 
                        vmin=-0.7, vmax=0.7, interpolation='nearest')
        ax8.set_title('(H) Error Heatmap: Triton - JAX (All Tests × Time Steps × Dims)', 
                     fontsize=14, fontweight='bold', loc='left', color='#2C3E50')
        ax8.set_xlabel('Time Step (grouped by test)', fontsize=10)
        ax8.set_ylabel('Action Dimension', fontsize=10)
        ax8.set_yticks(np.arange(self.action_dim))
        ax8.set_yticklabels([f'Dim {d}' for d in range(self.action_dim)], fontsize=9)
        # 标注测试分界线
        cumsum = 0
        for i in range(num_tests - 1):
            cumsum += self.chunk_size
            ax8.axvline(x=cumsum + i - 0.5, color='black', linewidth=1.5, linestyle='-')
        plt.colorbar(im2, ax=ax8, label='Error', fraction=0.02, pad=0.01)

        # 总标题
        fig.suptitle(
            'JAX vs Triton DROID Inference: Comprehensive Error Analysis Dashboard
'
            f'Total Tests: {num_tests} | Action Dim: {self.action_dim} | Chunk Size: {self.chunk_size}',
            fontsize=18, fontweight='bold', color='#1A252F', y=0.995
        )
        plt.tight_layout(rect=[0, 0, 1, 0.98])

        save_path = os.path.join(self.output_dir, "summary_dashboard.png")
        plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        print(f"
📊 汇总仪表板已保存: {save_path}")

    # ========== 新增:关键维度专项分析图(Dim 7通常是误差最大的维度) ==========
    def _plot_critical_dim_analysis(self):
        """对误差最大的维度(通常是夹爪/末端执行器维度)进行专项分析"""
        if not self.all_results:
            return

        num_tests = len(self.all_results)
        # 找出平均MAE最大的维度
        per_dim_avg = np.mean([r["per_dim_mae"] for r in self.all_results], axis=0)
        critical_dim = int(np.argmax(per_dim_avg))

        fig, axes = plt.subplots(2, 2, figsize=(14, 12))

        # 1. 关键维度在所有测试中的MAE对比
        ax = axes[0, 0]
        dim_maes = [r["per_dim_mae"][critical_dim] for r in self.all_results]
        colors_bar = ['#27AE60' if v < 1e-3 else '#F39C12' if v < 1e-2 else '#E74C3C' for v in dim_maes]
        bars = ax.bar(range(1, num_tests+1), dim_maes, color=colors_bar, edgecolor='black', alpha=0.85)
        ax.axhline(y=5e-2, color='red', linestyle='--', linewidth=2, label='Deployable threshold')
        ax.set_title(f'Critical Dim {critical_dim}: MAE per Test', fontsize=13, fontweight='bold')
        ax.set_xlabel('Test Index')
        ax.set_ylabel('MAE')
        ax.set_xticks(range(1, num_tests+1))
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        for bar, val in zip(bars, dim_maes):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                   f'{val:.3f}', ha='center', va='bottom', fontsize=8, fontweight='bold')

        # 2. 关键维度时序对比(最佳测试)
        ax = axes[0, 1]
        best_idx = int(np.argmin([r["global_mae"] for r in self.all_results]))
        t = np.arange(self.chunk_size)
        ax.plot(t, self.all_results[best_idx]["jax_actions"][:, critical_dim], 
               'b-o', label='JAX', markersize=5, linewidth=2)
        ax.plot(t, self.all_results[best_idx]["triton_actions"][:, critical_dim], 
               'r-s', label='Triton', markersize=5, linewidth=2)
        ax.fill_between(t, 
                       self.all_results[best_idx]["jax_actions"][:, critical_dim],
                       self.all_results[best_idx]["triton_actions"][:, critical_dim],
                       alpha=0.3, color='purple')
        ax.set_title(f'Best Test (T{best_idx+1}): Dim {critical_dim} Sequence', fontsize=13, fontweight='bold')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Action Value')
        ax.legend()
        ax.grid(alpha=0.3)

        # 3. 关键维度时序对比(最差测试)
        ax = axes[1, 0]
        worst_idx = int(np.argmax([r["global_mae"] for r in self.all_results]))
        ax.plot(t, self.all_results[worst_idx]["jax_actions"][:, critical_dim], 
               'b-o', label='JAX', markersize=5, linewidth=2)
        ax.plot(t, self.all_results[worst_idx]["triton_actions"][:, critical_dim], 
               'r-s', label='Triton', markersize=5, linewidth=2)
        ax.fill_between(t, 
                       self.all_results[worst_idx]["jax_actions"][:, critical_dim],
                       self.all_results[worst_idx]["triton_actions"][:, critical_dim],
                       alpha=0.3, color='purple')
        ax.set_title(f'Worst Test (T{worst_idx+1}): Dim {critical_dim} Sequence', fontsize=13, fontweight='bold')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Action Value')
        ax.legend()
        ax.grid(alpha=0.3)

        # 4. 关键维度误差分布
        ax = axes[1, 1]
        critical_errors = np.array([r["triton_actions"][:, critical_dim] - r["jax_actions"][:, critical_dim] 
                                    for r in self.all_results]).flatten()
        ax.hist(critical_errors, bins=40, color='crimson', edgecolor='black', alpha=0.75)
        ax.axvline(x=0, color='black', linestyle='--', linewidth=2.5, label='Zero')
        ax.axvline(x=np.mean(critical_errors), color='green', linestyle='--', linewidth=2.5, 
                  label=f'Mean={np.mean(critical_errors):.4f}')
        ax.set_title(f'Dim {critical_dim} Error Distribution', fontsize=13, fontweight='bold')
        ax.set_xlabel('Error (Triton - JAX)')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(alpha=0.3)

        fig.suptitle(f'Critical Dimension Analysis: Dim {critical_dim} (Avg MAE={per_dim_avg[critical_dim]:.4f})',
                    fontsize=15, fontweight='bold', y=1.01)
        plt.tight_layout()

        save_path = os.path.join(self.output_dir, "critical_dim_analysis.png")
        plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        print(f"  📊 关键维度分析图已保存: {save_path}")

    def compare(self, num_tests=10):
        print(f"
{'='*60}")
        print("JAX vs Triton DROID 推理对比 (discrete_state_input=True)")
        print(f"{'='*60}")
        print(f"可视化输出目录: {self.output_dir}")
        print(f"{'='*60}
")

        for i in range(num_tests):
            # 1. 生成一组随机输入(状态、噪声)
            droid_example = droid_policy.make_droid_example()
            noise_np = np.random.randn(self.chunk_size, 32).astype(np.float32)

            # 2. 提取原始图像和状态
            exterior = droid_example["observation/exterior_image_1_left"]
            wrist = droid_example["observation/wrist_image_left"]
            joint_pos = np.asarray(droid_example["observation/joint_position"])
            gripper_pos = np.asarray(droid_example["observation/gripper_position"])

            # 3. 【核心修复】对图像执行完全相同的预处理,供 JAX 和 Triton 共用
            img_base = self._preprocess_image(exterior)
            img_left = self._preprocess_image(wrist)
            img_right = np.zeros_like(img_base)  # 第三视角用全0填充

            # 4. 预处理状态(离散token)
            state_tokens = self._preprocess_state_discrete(joint_pos, gripper_pos)

            # 5. 【Triton 推理】将预处理后的数据转为 Torch Tensor 并送入 GPU
            images = torch.from_numpy(
                np.stack([img_base, img_left, img_right], axis=0)
            ).to(torch.bfloat16).cuda()
            state_torch = torch.from_numpy(state_tokens).to(torch.long).cuda()
            noise_torch = torch.from_numpy(noise_np).to(torch.bfloat16).cuda()
            triton_output = self.triton_model.forward(images, noise_torch, self.prompt, state_torch)
            triton_actions = self._unnormalize_actions(triton_output.cpu().float().numpy())

            # 6. 【JAX 推理】构造与 Triton 完全相同的输入字典
            jax_input = {
                "observation/exterior_image_1_left": img_base,
                "observation/wrist_image_left": img_left,
                "observation/joint_position": joint_pos,
                "observation/gripper_position": gripper_pos,
                "prompt": self.prompt,
            }
            jax_result = self.jax_policy.infer(jax_input, noise=noise_np)
            jax_actions = jax_result["actions"]

            # 7. 对齐动作长度并计算误差
            min_steps = min(jax_actions.shape[0], triton_actions.shape[0])
            jax_aligned = jax_actions[:min_steps]
            triton_aligned = triton_actions[:min_steps]

            # 计算指标
            mae = np.mean(np.abs(triton_aligned - jax_aligned))
            max_err = np.max(np.abs(triton_aligned - jax_aligned))
            rmse = np.sqrt(np.mean((triton_aligned - jax_aligned) ** 2))
            per_dim_mae_list = []
            for dim in range(self.action_dim):
                dim_mae = np.mean(np.abs(triton_aligned[:, dim] - jax_aligned[:, dim]))
                per_dim_mae_list.append(dim_mae)
            per_dim_mae_arr = np.array(per_dim_mae_list)

            # 打印结果(保留原有格式)
            print(f"
--- Test {i+1}/{num_tests} ---")
            print(f"JAX    actions: shape={jax_actions.shape}, range=[{jax_actions.min():.6f}, {jax_actions.max():.6f}]")
            print(f"Triton actions: shape={triton_actions.shape}, range=[{triton_actions.min():.6f}, {triton_actions.max():.6f}]")
            print(f"Global MAE: {mae:.6f}, Max Error: {max_err:.6f}, RMSE: {rmse:.6f}")
            print("Per-dimension MAE:")
            for dim in range(self.action_dim):
                print(f"  dim {dim}: {per_dim_mae_arr[dim]:.6f}")

            if mae < 1e-3:
                print("✅ PASS: MAE < 1e-3")
            elif mae < 1e-2:
                print("⚠️  WARNING: MAE < 1e-2 (acceptable)")
            elif mae < 5e-2:
                print("📊 MAE < 5e-2 (design trade-off, deployable)")
            else:
                print("❌ FAIL: MAE too large")

            # ========== 新增:存储结果并绘制单组可视化 ==========
            result = {
                "test_idx": i,
                "jax_actions": jax_aligned.copy(),
                "triton_actions": triton_aligned.copy(),
                "global_mae": mae,
                "global_max_err": max_err,
                "global_rmse": rmse,
                "per_dim_mae": per_dim_mae_arr.copy(),
            }
            self.all_results.append(result)

            # 绘制并保存单组测试的时序对比图
            self._plot_single_test(i, jax_aligned, triton_aligned, mae, max_err, rmse, per_dim_mae_arr)

        # ========== 新增:所有测试完成后,生成汇总可视化 ==========
        print(f"
{'='*60}")
        print("生成汇总可视化...")
        self._plot_summary_dashboard()
        self._plot_critical_dim_analysis()

        print(f"
{'='*60}")
        print("对比完成")
        print(f"所有可视化文件保存在: {self.output_dir}")
        print(f"{'='*60}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--triton_path", type=str, required=True)
    parser.add_argument("--jax_path", type=str, required=True)
    parser.add_argument("--norm_stats_dir", type=str, required=True)
    parser.add_argument("--config_name", type=str, default="pi05_droid")
    parser.add_argument("--prompt", type=str, default="Pick up the bottle.")
    parser.add_argument("--tokenizer_path", type=str, default="./paligemma-3b-pt-224")
    parser.add_argument("--discrete_state_input", action="store_true", default=True)
    parser.add_argument("--action_dim", type=int, default=8)
    parser.add_argument("--chunk_size", type=int, default=15)
    # ========== 新增:输出目录参数 ==========
    parser.add_argument("--output_dir", type=str, default="./comparison_results",
                       help="可视化结果输出目录")
    args = parser.parse_args()

    comp = DroidComparator(
        triton_path=args.triton_path,
        jax_path=args.jax_path,
        norm_stats_dir=args.norm_stats_dir,
        config_name=args.config_name,
        tokenizer_path=args.tokenizer_path,
        prompt=args.prompt,
        discrete_state_input=args.discrete_state_input,
        action_dim=args.action_dim,
        chunk_size=args.chunk_size,
        output_dir=args.output_dir,
    )
    comp.compare(num_tests=10)


if __name__ == "__main__":
    main()

运行效果,查看汇总分析:

汇总分析2:(整体的偏差比较大)

查看某组数据差异对比1:

查看某组数据差异对比2:

后续的DM0,有待更新中~

相关推荐
feasibility.3 小时前
ROS2+Gazebo+VLM服务:纯仿真环境下的具身智能闭环系统| 大脑-小脑分离控制
人工智能·机器人·ros·仿真·具身智能·vla·vlm
一颗小树x4 小时前
《VLA 系列》realtime-vla | 论文解读 加速推理 30Hz+
加速·vla·推理优化·realtime-vla
传说故事7 天前
【论文阅读】MEM: Multi-Scale Embodied Memory for Vision Language Action Models
论文阅读·人工智能·具身智能·vla
传说故事8 天前
【论文阅读】RLDX-1
论文阅读·人工智能·具身智能·vla
传说故事8 天前
【论文阅读】StereoVLA: Enhancing Vision-Language-Action Models with Stereo Vision
论文阅读·人工智能·具身智能·vla
Mike_66616 天前
推流和推理什么区别
推流·推理·cpu推理·cpu推流·gpu推流·gpu推理
qcx2317 天前
阿里 RynnVLA-002 源码深度拆解:一个 7B 模型如何同时当机器人大脑和世界模拟器
ai·机器人·llm·agent·具身智能·vla
AIDF202622 天前
第六篇:实战出击——深度学习的“减脂”与“提速”
人工智能·深度学习·框架·推理
AIDF20261 个月前
我们看一份报告的时候主要看什么
运维·服务器·推理·vllm