Deploying Isaac GR00T N1.7 on a Real Piper Arm with a Server/Client Architecture

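1. Background

Isaac GR00T N1.7 is a VLA (vision-language-action) model for general-purpose robot manipulation. Model inference needs a GPU environment, whereas the real Piper robot side depends on the CAN bus, RealSense cameras, the robot SDK, and local safety-control logic.

To decouple model inference from real-robot control, this article uses a server/client architecture: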

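GPU inference machine
  Starts the GR00T model service
  Receives observations
  Outputs action chunks

Piper robot control machine
  Captures camera images and arm state
  Queries the remote model over WebSocket
  Executes actions and records data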

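The benefit is that the GR00T model can run on a GPU machine (H100, A100, RTX, and so on), while the Piper side only has to handle hardware control, camera capture, the action queue, and safety stops.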

2. Project Structure

The files relevant to this deployment are mainly:

scripts/piper/
  start_gr00t_piper_server.sh      # Starts the GR00T WebSocket inference server
  start_gr00t_piper_client.sh      # Starts the Piper real-robot client
  gr00t_piper_ws_server.py         # Server-side adapter from GR00T to the Piper protocol
  piper_gr00t_remote_rollout.py    # Wrapper around the LeRobot Piper remote rollout
  test_remote_handshake.py         # Minimal handshake test
  README.md

third_party/lerobot/examples/piper/
  piper_remote_rollout.py          # Entry point for Piper real-robot rollouts
  remote_inference_protocol.md     # Description of the websocket-msgpack protocol

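3. Environment Setup

3.1 Base Dependencies

The project uses uv to manage its Python environment. First install the dependencies from the project root: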

uv sync --all-extras

The repository's default pyproject.toml targets x86 GPU environments and includes the following key dependencies:

torch==2.7.1
torchvision==0.22.1
transformers==4.57.3
websockets>=13.1,<17.0
msgpack==1.1.0
numpy==1.26.4
pyzmq==27.0.1

On Jetson Orin, Thor, or DGX Spark, use the platform-specific install script under scripts/deployment/ instead.

3.2 Piper Client Dependencies

The Piper side relies on LeRobot's hardware components, mainly:

remote-inference: websockets, msgpack
piper: python-can, piper_sdk
intelrealsense: pyrealsense2
dataset/viz: data recording and visualization

The client ultimately launches a LeRobot command similar to the following:

uv run \
  --extra dataset \
  --extra remote-inference \
  --extra piper \
  --extra intelrealsense \
  --extra viz \
  lerobot-rollout ...

To display data with Rerun, enable:

--display true

When using RealSense cameras, confirm that the serial numbers of the front and wrist cameras are correct.

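4. How It Works

4.1 Communication Protocol

The server and client communicate over WebSocket binary frames that carry MessagePack payloads.

The protocol version is `1`, and there are three main request types: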

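hello  - Opens the connection and synchronizes action keys, observation schema, task, and related information
infer  - The client sends the current observation and the server returns actions
reset  - Clears the server-side state of one rollout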
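
To make the three request types concrete, here is a minimal sketch of the message dictionaries a client might build before MessagePack encoding. The field names follow the server and handshake-test code listed later in this article; the helper names `build_hello`, `build_infer`, and `build_reset` are hypothetical, and treating `request_id` as an integer is an assumption.

```python
from typing import Any

PROTOCOL_VERSION = 1  # the protocol version stated above


def build_hello(action_keys: list[str], task: str) -> dict[str, Any]:
    """Handshake message: announces action keys, observation schema, and task."""
    return {
        "type": "hello",
        "protocol_version": PROTOCOL_VERSION,
        "action_keys": list(action_keys),
        "observation_features": {},  # schema left empty in this sketch
        "task": task,
    }


def build_infer(request_id: int, observation: dict[str, Any], task: str) -> dict[str, Any]:
    """Inference request: carries the current observation for one query."""
    return {
        "type": "infer",
        "protocol_version": PROTOCOL_VERSION,
        "request_id": request_id,  # assumed to be an int here
        "observation": observation,
        "task": task,
    }


def build_reset() -> dict[str, Any]:
    """Reset request: asks the server to clear its per-rollout state."""
    return {"type": "reset", "protocol_version": PROTOCOL_VERSION}


if __name__ == "__main__":
    hello = build_hello(["joint1.pos", "gripper.pos"], task="handshake test")
    print(hello["type"], hello["protocol_version"])
```

Each dictionary would then be serialized with MessagePack and sent as a single binary frame; the serialization step is left out so the sketch stays dependency-free.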

4.2 Server Responsibilities

The server entry point is:

scripts/piper/gr00t_piper_ws_server.py

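It does the following:

  1. Loads the GR00T checkpoint

  2. Initializes Gr00tPolicy

  3. Receives observations from the Piper client

  4. Converts the remote observation into the input format GR00T expects

  5. Calls policy.get_action()

  6. Remaps the GR00T output actions to Piper action keys

  7. Returns them to the client for execution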

The server's default action keys are:

joint1.pos
joint2.pos
joint3.pos
joint4.pos
joint5.pos
joint6.pos
gripper.pos

The default image mapping is:

{
    "front": "observation.images.front",
    "wrist": "observation.images.wrist",
}

The default state mapping is:

{
    "single_arm": [
        "joint1.pos",
        "joint2.pos",
        "joint3.pos",
        "joint4.pos",
        "joint5.pos",
        "joint6.pos",
    ],
    "gripper": ["gripper.pos"],
}

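The front-camera image, wrist-camera image, arm joint positions, and gripper state captured by the Piper client are rearranged on the server into the video, state, and language inputs that match how the GR00T model was trained.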
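
The regrouping can be illustrated with a simplified, dependency-free sketch. It uses plain Python lists instead of NumPy arrays and omits the batch and time dimensions that the real server adds; `to_policy_observation` is a hypothetical name, and returning the task as a bare string under "language" is a simplification.

```python
from typing import Any

VIDEO_MAP = {
    "front": "observation.images.front",
    "wrist": "observation.images.wrist",
}
STATE_MAP = {
    "single_arm": [f"joint{i}.pos" for i in range(1, 7)],
    "gripper": ["gripper.pos"],
}


def to_policy_observation(observation: dict[str, Any], task: str) -> dict[str, Any]:
    """Group a flat remote observation into video/state/language sections."""
    video = {}
    for policy_key, remote_key in VIDEO_MAP.items():
        if remote_key not in observation:
            raise KeyError(f"missing image {remote_key!r}")
        video[policy_key] = observation[remote_key]

    state = {}
    for policy_key, remote_keys in STATE_MAP.items():
        missing = [key for key in remote_keys if key not in observation]
        if missing:
            raise KeyError(f"missing state keys {missing}")
        # One scalar per remote key, in the order given by the mapping.
        state[policy_key] = [float(observation[key]) for key in remote_keys]

    return {"video": video, "state": state, "language": task}
```

The ordering inside each state group comes from the mapping, not from the observation dictionary, which is why the mapping has to match the order the model was trained with.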

4.3 Client Responsibilities

The client entry point is:

scripts/piper/piper_gr00t_remote_rollout.py

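It is essentially a wrapper around examples/piper/piper_remote_rollout.py from LeRobot 0.5.2 and is responsible for:

  1. Connecting to the Piper arm over the CAN bus

  2. Opening the RealSense cameras

  3. Capturing the current observation

  4. Querying the remote GR00T service over WebSocket

  5. Receiving the action chunk

  6. Sending the queued actions to Piper step by step

  7. Recording rollout data

  8. Supporting Rerun visualization and timing debug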
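
Steps 5 and 6 amount to buffering a received chunk and draining it one action per control step. The following is a minimal sketch of such a buffer, not the LeRobot implementation; the class name `ActionQueue` is hypothetical, and the capacity of 64 mirrors the `--max-queued-actions` default of the wrapper script shown later.

```python
from collections import deque
from typing import Optional


class ActionQueue:
    """FIFO buffer for action rows received in chunks (hypothetical helper)."""

    def __init__(self, max_queued_actions: int = 64) -> None:
        self.max_queued_actions = max_queued_actions
        self._rows = deque()

    def push_chunk(self, chunk: list[list[float]]) -> int:
        """Append rows until the queue is full; return how many were accepted."""
        free = self.max_queued_actions - len(self._rows)
        accepted = chunk[: max(free, 0)]
        self._rows.extend(accepted)
        return len(accepted)

    def pop(self) -> Optional[list[float]]:
        """Return the next action row, or None when a new chunk is needed."""
        return self._rows.popleft() if self._rows else None

    def clear(self) -> None:
        """Drop all pending actions, for example on a safety stop."""
        self._rows.clear()
```

A control loop would call `pop()` once per step and request a new chunk from the server whenever it returns None.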

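The client keeps its safety logic local. If the server times out, returns a malformed response, omits an action key, or returns actions with the wrong shape, the client fails closed so that no erroneous actions keep reaching the arm.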
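
A fail-closed check of this kind could look like the sketch below. It is an illustration under assumptions rather than the client's actual code: `validate_action_reply` and `UnsafeReplyError` are hypothetical names, and the action chunk is modeled as nested lists instead of a NumPy array.

```python
import math
from typing import Any


class UnsafeReplyError(RuntimeError):
    """Raised when a server reply must not be forwarded to the arm."""


def validate_action_reply(reply: Any, expected_keys: list[str]) -> list[list[float]]:
    """Return the action rows if the reply is well formed, otherwise raise."""
    if not isinstance(reply, dict) or reply.get("type") != "action":
        raise UnsafeReplyError(f"unexpected reply: {reply!r}")
    if list(reply.get("action_keys") or []) != list(expected_keys):
        raise UnsafeReplyError("action keys do not match the handshake")
    rows = reply.get("actions")
    if not isinstance(rows, list) or not rows:
        raise UnsafeReplyError("empty or malformed action chunk")

    checked = []
    for row in rows:
        if not isinstance(row, (list, tuple)) or len(row) != len(expected_keys):
            raise UnsafeReplyError(f"row has wrong shape: {row!r}")
        try:
            values = [float(value) for value in row]
        except (TypeError, ValueError) as exc:
            raise UnsafeReplyError(f"non-numeric action value in {row!r}") from exc
        if not all(math.isfinite(value) for value in values):
            raise UnsafeReplyError(f"non-finite action value in {row!r}")
        checked.append(values)
    return checked
```

The caller would treat `UnsafeReplyError` the same way as a timeout: stop consuming actions and leave the arm in a safe state instead of retrying blindly.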

5. Starting the Server

Run on the GPU inference machine:

MODEL_PATH=/path/to/gr00t/checkpoint \
EMBODIMENT_TAG=new_embodiment \
TASK="put the bottle into the box." \
HOST=0.0.0.0 \
PORT=8000 \
DEVICE=cuda \
bash scripts/piper/start_gr00t_piper_server.sh

The script's default parameters include:

HOST=0.0.0.0
PORT=8000
DEVICE=cuda
EMBODIMENT_TAG=new_embodiment

If the model's modality keys differ from the default mappings, specify them explicitly as JSON arguments:

bash scripts/piper/start_gr00t_piper_server.sh \
  --video-map '{"front":"observation.images.front","wrist":"observation.images.wrist"}' \
  --state-map '{"single_arm":["joint1.pos","joint2.pos","joint3.pos","joint4.pos","joint5.pos","joint6.pos"],"gripper":["gripper.pos"]}' \
  --action-map '{"single_arm":["joint1.pos","joint2.pos","joint3.pos","joint4.pos","joint5.pos","joint6.pos"],"gripper":["gripper.pos"]}'
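
Before starting the server with a custom `--action-map`, it can help to check that the mapping covers every action key the client will request, since the server rejects the handshake otherwise. Below is a small sketch of that check; `uncovered_action_keys` is a hypothetical helper, not part of the repository.

```python
import json


def uncovered_action_keys(action_map_json: str, requested_keys: list[str]) -> list[str]:
    """Return the requested action keys that no entry of the action map provides."""
    mapping = json.loads(action_map_json)
    if not isinstance(mapping, dict):
        raise ValueError("action map must be a JSON object")
    covered = {key for keys in mapping.values() for key in keys}
    return [key for key in requested_keys if key not in covered]


if __name__ == "__main__":
    action_map = '{"single_arm":["joint1.pos","joint2.pos"],"gripper":["gripper.pos"]}'
    print(uncovered_action_keys(action_map, ["joint1.pos", "joint3.pos", "gripper.pos"]))
```

An empty result means every requested key is covered by the mapping.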

6. Server Handshake Test

Before connecting the real robot, run a minimal handshake test first:

uv run --project third_party/lerobot --extra remote-inference \
  python scripts/piper/test_remote_handshake.py \
  --address ws://127.0.0.1:8000

If the server replies with `hello_ack`, the WebSocket transport, MessagePack encoding, and server protocol are basically working.
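
Beyond eyeballing the printed reply, the decoded `hello_ack` can be checked programmatically. The sketch below inspects the fields that the server code in section 8.3 places in its reply; `check_hello_ack` is a hypothetical helper, and requiring `use_length` to be a positive integer is an assumption.

```python
from typing import Any

PROTOCOL_VERSION = 1


def check_hello_ack(reply: Any, expected_action_keys: list[str]) -> list[str]:
    """Return a list of problems; an empty list means the handshake looks healthy."""
    if not isinstance(reply, dict):
        return [f"reply is {type(reply).__name__}, expected a dict"]

    problems = []
    if reply.get("type") != "hello_ack":
        problems.append(f"type is {reply.get('type')!r}, expected 'hello_ack'")
    if reply.get("protocol_version") != PROTOCOL_VERSION:
        problems.append(f"protocol_version is {reply.get('protocol_version')!r}")

    features = reply.get("action_features") or {}
    missing = [key for key in expected_action_keys if key not in features]
    if missing:
        problems.append(f"action_features lacks {missing}")

    use_length = reply.get("use_length")
    if not isinstance(use_length, int) or use_length < 1:
        problems.append(f"use_length is {use_length!r}, expected a positive int")
    return problems
```

Returning a list rather than raising on the first mismatch makes it easier to see all configuration problems in a single run.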

7. Starting the Piper Client

Run on the Piper control machine:

REMOTE_ADDRESS=ws://SERVER_IP:8000 \
TASK="put the yellow bottle into the box." \
DURATION=300 \
SPEED_PERCENT=20 \
STRATEGY=sentry \
bash scripts/piper/start_gr00t_piper_client.sh

Examples of commonly used parameters:

bash scripts/piper/start_gr00t_piper_client.sh \
  --can can0 \
  --front-camera-serial 254622078102 \
  --wrist-camera-serial 254622071629 \
  --strategy sentry \
  --display true

8. Source Code

8.1 scripts/piper/start_gr00t_piper_server.sh

#!/usr/bin/env bash

HOST="${HOST:-0.0.0.0}"
PORT="${PORT:-8000}"
DEVICE="${DEVICE:-cuda}"
MODEL_PATH="${MODEL_PATH:-}"
EMBODIMENT_TAG="${EMBODIMENT_TAG:-new_embodiment}"
MODALITY_CONFIG_PATH="${MODALITY_CONFIG_PATH:-}"
TASK="${TASK:-}"
# Respect a UV_CACHE_DIR set by the caller; fall back to /tmp/uv-cache otherwise.
export UV_CACHE_DIR="${UV_CACHE_DIR:-/tmp/uv-cache}"

echo "GR00T Piper remote server"
echo "Starting on ws://${HOST}:${PORT}"
echo "Model: ${MODEL_PATH}"
echo "Embodiment tag: ${EMBODIMENT_TAG}"
echo "Task: ${TASK}"
echo "UV cache: ${UV_CACHE_DIR}"

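cmd=(
  uv run --no-sync python scripts/piper/gr00t_piper_ws_server.py
  --host "${HOST}"
  --port "${PORT}"
  --device "${DEVICE}"
  --embodiment-tag "${EMBODIMENT_TAG}"
  --task "${TASK}"
)

# Pass --model-path only when MODEL_PATH is set, so that an unset value is
# reported by argparse (the flag is required) rather than forwarded as "".
# The flag can still be supplied through "$@".
if [[ -n "${MODEL_PATH}" ]]; then
  cmd+=(--model-path "${MODEL_PATH}")
fi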

if [[ -n "${MODALITY_CONFIG_PATH}" ]]; then
  cmd+=(--modality-config-path "${MODALITY_CONFIG_PATH}")
fi

exec "${cmd[@]}" "$@"

8.2 scripts/piper/start_gr00t_piper_client.sh

#!/usr/bin/env bash

REMOTE_ADDRESS="${REMOTE_ADDRESS:-ws://127.0.0.1:8000}"
TASK="${TASK:-}"
DURATION="${DURATION:-300}"
SPEED_PERCENT="${SPEED_PERCENT:-20}"
STRATEGY="${STRATEGY:-sentry}"
REMOTE_ACTION_JOINTS_UNIT="${REMOTE_ACTION_JOINTS_UNIT:-deg}"
DATASET_ROOT="${DATASET_ROOT:-}"
DATASET_REPO_ID="${DATASET_REPO_ID:-local}"

echo "GR00T Piper remote client"
echo "Remote: ${REMOTE_ADDRESS}"
echo "Task: ${TASK}"
echo "Strategy: ${STRATEGY}"

exec uv run --no-sync python scripts/piper/piper_gr00t_remote_rollout.py \
  --remote-address "${REMOTE_ADDRESS}" \
  --task "${TASK}" \
  --duration "${DURATION}" \
  --speed-percent "${SPEED_PERCENT}" \
  --strategy "${STRATEGY}" \
  --remote-action-joints-unit "${REMOTE_ACTION_JOINTS_UNIT}" \
  --dataset-root "${DATASET_ROOT}" \
  --dataset-repo-id "${DATASET_REPO_ID}" \
  --display true \
  "$@"

8.3 scripts/piper/gr00t_piper_ws_server.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import sys
import time
import traceback
from typing import Any

import numpy as np


PROTOCOL_VERSION = 1
DEFAULT_ACTION_KEYS = [
    "joint1.pos",
    "joint2.pos",
    "joint3.pos",
    "joint4.pos",
    "joint5.pos",
    "joint6.pos",
    "gripper.pos",
]
DEFAULT_VIDEO_MAP = {
    "front": "observation.images.front",
    "wrist": "observation.images.wrist",
}
DEFAULT_STATE_MAP = {
    "single_arm": DEFAULT_ACTION_KEYS[:6],
    "gripper": ["gripper.pos"],
}
DEFAULT_ACTION_MAP = DEFAULT_STATE_MAP

logger = logging.getLogger("gr00t_piper_ws_server")


def _pack_array(obj: Any) -> Any:
    if isinstance(obj, np.ndarray):
        if obj.dtype.kind in ("V", "O", "c"):
            raise ValueError(f"Unsupported ndarray dtype for msgpack transport: {obj.dtype}")
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }
    if isinstance(obj, np.generic):
        return {b"__npgeneric__": True, b"data": obj.item(), b"dtype": obj.dtype.str}
    return obj


def _unpack_array(obj: dict) -> Any:
    if b"__ndarray__" in obj:
        return np.ndarray(
            buffer=obj[b"data"],
            dtype=np.dtype(obj[b"dtype"]),
            shape=obj[b"shape"],
        )
    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])
    return obj


def packb(payload: Any) -> bytes:
    import msgpack

    return msgpack.packb(payload, default=_pack_array, use_bin_type=True)


def unpackb(payload: bytes) -> Any:
    import msgpack

    return msgpack.unpackb(payload, object_hook=_unpack_array, raw=False)


def _json_mapping(raw: str | None, default: dict[str, Any]) -> dict[str, Any]:
    if raw is None or raw == "":
        return dict(default)
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        raise ValueError(f"Mapping must be a JSON object, got {type(parsed).__name__}")
    return parsed


def _load_modality_config(path: str | None) -> None:
    if not path:
        return
    config_path = Path(path).expanduser().resolve()
    if not config_path.exists() or config_path.suffix != ".py":
        raise FileNotFoundError(f"Modality config must be an existing .py file: {path}")
    sys.path.append(str(config_path.parent))
    importlib.import_module(config_path.stem)
    logger.info("Loaded modality config: %s", config_path)


def _ensure_batched_video(value: Any, *, key: str) -> np.ndarray:
    arr = np.asarray(value)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.uint8, copy=False)
    if arr.ndim == 3:
        arr = arr[None, None, ...]
    elif arr.ndim == 4:
        arr = arr[None, ...]
    elif arr.ndim != 5:
        raise ValueError(f"Video key {key!r} must have shape HWC, THWC, or BTHWC, got {arr.shape}")
    if arr.shape[-1] != 3:
        raise ValueError(f"Video key {key!r} must have 3 RGB channels, got shape {arr.shape}")
    return np.ascontiguousarray(arr)


def _stack_remote_scalars(
    observation: dict[str, Any],
    remote_keys: list[str],
    *,
    key: str,
    state_names: list[str],
) -> np.ndarray:
    missing = [remote_key for remote_key in remote_keys if remote_key not in observation]
    if not missing:
        values = [
            float(np.asarray(observation[remote_key]).reshape(-1)[0]) for remote_key in remote_keys
        ]
        return np.asarray(values, dtype=np.float32)[None, None, :]

    state_key = "observation.state"
    if state_key not in observation:
        available = sorted(str(remote_key) for remote_key in observation)
        raise KeyError(
            f"Missing remote observation keys for {key!r}: {missing}. "
            f"Also did not find {state_key!r}. Available keys={available}"
        )

    state = np.asarray(observation[state_key], dtype=np.float32).reshape(-1)
    missing_state_names = [
        remote_key for remote_key in remote_keys if remote_key not in state_names
    ]
    if missing_state_names:
        raise KeyError(
            f"Remote observation uses {state_key!r}, but its names do not contain "
            f"{missing_state_names} for policy key {key!r}. state_names={state_names}"
        )
    indices = [state_names.index(remote_key) for remote_key in remote_keys]
    if max(indices) >= state.shape[0]:
        raise ValueError(
            f"{state_key!r} has length {state.shape[0]}, but mapping for {key!r} "
            f"requires indices {indices} from state_names={state_names}"
        )
    values = [float(state[index]) for index in indices]
    return np.asarray(values, dtype=np.float32)[None, None, :]


def _infer_policy_key(policy_keys: list[str], requested_key: str, mapping: dict[str, Any]) -> str:
    if requested_key in mapping:
        return requested_key
    if len(policy_keys) == 1 and len(mapping) == 1:
        return policy_keys[0]
    raise KeyError(
        f"Mapping does not provide policy key {requested_key!r}. "
        f"Policy keys={policy_keys}, mapping keys={list(mapping)}"
    )


def _flatten_policy_actions(
    policy_actions: dict[str, Any],
    *,
    action_map: dict[str, list[str]],
    requested_action_keys: list[str],
) -> np.ndarray:
    rows_by_key: dict[str, np.ndarray] = {}
    for policy_key, remote_keys in action_map.items():
        if policy_key not in policy_actions:
            raise KeyError(f"Policy did not return action key {policy_key!r}")
        arr = np.asarray(policy_actions[policy_key], dtype=np.float32)
        if arr.ndim == 3:
            if arr.shape[0] != 1:
                raise ValueError(
                    f"Only batch size 1 is supported for action {policy_key!r}, got {arr.shape}"
                )
            arr = arr[0]
        elif arr.ndim == 2:
            pass
        elif arr.ndim == 1:
            arr = arr[:, None]
        else:
            raise ValueError(
                f"Action {policy_key!r} must have shape BTD, TD, or T, got {arr.shape}"
            )
        if arr.shape[1] != len(remote_keys):
            raise ValueError(
                f"Action {policy_key!r} dimension {arr.shape[1]} does not match "
                f"mapped remote keys {remote_keys}"
            )
        for column, remote_key in enumerate(remote_keys):
            rows_by_key[remote_key] = arr[:, column]

    missing = [key for key in requested_action_keys if key not in rows_by_key]
    if missing:
        raise KeyError(f"Policy action mapping did not produce requested keys: {missing}")

    min_rows = min(rows_by_key[key].shape[0] for key in requested_action_keys)
    return np.stack([rows_by_key[key][:min_rows] for key in requested_action_keys], axis=1).astype(
        np.float32, copy=False
    )


class Gr00tPiperRemoteAdapter:
    def __init__(self, args: argparse.Namespace) -> None:
        _load_modality_config(args.modality_config_path)

        from gr00t.data.embodiment_tags import EmbodimentTag
        from gr00t.policy.gr00t_policy import Gr00tPolicy

        self.video_map = _json_mapping(args.video_map, DEFAULT_VIDEO_MAP)
        self.state_map = _json_mapping(args.state_map, DEFAULT_STATE_MAP)
        self.action_map = _json_mapping(args.action_map, DEFAULT_ACTION_MAP)
        self.action_keys = list(DEFAULT_ACTION_KEYS)
        self.observation_state_names = list(DEFAULT_ACTION_KEYS)
        self.last_task = args.task

        logger.info("Loading GR00T policy from %s", args.model_path)
        self.policy = Gr00tPolicy(
            embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
            model_path=args.model_path,
            device=args.device,
            strict=args.strict,
        )
        self.policy.reset()
        self.modality_configs = self.policy.get_modality_config()
        self.language_key = self.modality_configs["language"].modality_keys[0]
        self.use_length = len(self.modality_configs["action"].delta_indices)
        logger.info("Policy modality keys: %s", self._modality_key_summary())

    def hello(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        # Validate before mutating shared adapter state, so a rejected hello
        # does not leave the server holding unsupported action keys.
        action_keys = list(request.get("action_keys") or DEFAULT_ACTION_KEYS)
        mapped_keys = self._mapped_remote_action_keys()
        missing = [key for key in action_keys if key not in mapped_keys]
        if missing:
            raise KeyError(f"Client requested action keys not covered by --action-map: {missing}")
        self.action_keys = action_keys
        self._update_observation_state_names(request.get("observation_features"))
        self.last_task = str(request.get("task") or self.last_task or "")
        return {
            "type": "hello_ack",
            "accepted_protocols": ["websocket-msgpack"],
            "protocol_version": PROTOCOL_VERSION,
            "use_length": self.use_length,
            "action_features": {key: {"dtype": "float32", "shape": []} for key in self.action_keys},
            "policy_modality_keys": self._modality_key_summary(),
        }

    def infer(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        request_id = request.get("request_id")
        task = str(request.get("task") or self.last_task or "")
        observation = request.get("observation")
        if not isinstance(observation, dict):
            raise ValueError("infer request missing dict field 'observation'")

        start = time.perf_counter()
        gr00t_obs = self._to_gr00t_observation(observation, task=task)
        actions, _info = self.policy.get_action(gr00t_obs)
        remote_actions = _flatten_policy_actions(
            actions,
            action_map=self.action_map,
            requested_action_keys=self.action_keys,
        )
        infer_ms = (time.perf_counter() - start) * 1e3
        return {
            "type": "action",
            "request_id": request_id,
            "action_keys": list(self.action_keys),
            "actions": remote_actions,
            "server_timing": {"infer_ms": infer_ms},
        }

    def reset(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        self.policy.reset()
        return {"type": "reset_ack", "protocol_version": PROTOCOL_VERSION}

    def _to_gr00t_observation(self, observation: dict[str, Any], *, task: str) -> dict[str, Any]:
        video: dict[str, np.ndarray] = {}
        for policy_key in self.modality_configs["video"].modality_keys:
            remote_key = self.video_map.get(policy_key)
            if remote_key is None and len(self.modality_configs["video"].modality_keys) == len(
                self.video_map
            ):
                remote_key = self.video_map[list(self.video_map.keys())[len(video)]]
            if remote_key is None:
                raise KeyError(f"No --video-map entry for policy video key {policy_key!r}")
            if remote_key not in observation:
                raise KeyError(
                    f"Missing remote video key {remote_key!r} for policy key {policy_key!r}"
                )
            video[policy_key] = _ensure_batched_video(observation[remote_key], key=remote_key)

        state: dict[str, np.ndarray] = {}
        for policy_key in self.modality_configs["state"].modality_keys:
            mapped_key = _infer_policy_key(
                list(self.modality_configs["state"].modality_keys), policy_key, self.state_map
            )
            remote_keys = list(self.state_map[mapped_key])
            state[policy_key] = _stack_remote_scalars(
                observation,
                remote_keys,
                key=policy_key,
                state_names=self.observation_state_names,
            )

        language = {self.language_key: [[task]]}
        return {"video": video, "state": state, "language": language}

    def _mapped_remote_action_keys(self) -> set[str]:
        return {
            remote_key for remote_keys in self.action_map.values() for remote_key in remote_keys
        }

    def _modality_key_summary(self) -> dict[str, list[str]]:
        return {name: list(config.modality_keys) for name, config in self.modality_configs.items()}

    def _update_observation_state_names(self, observation_features: Any) -> None:
        if not isinstance(observation_features, dict):
            return
        state_feature = observation_features.get("observation.state")
        if not isinstance(state_feature, dict):
            return
        names = state_feature.get("names")
        if isinstance(names, list) and all(isinstance(name, str) for name in names):
            self.observation_state_names = list(names)
            logger.info("Remote observation.state names: %s", self.observation_state_names)

    @staticmethod
    def _check_protocol(request: dict[str, Any]) -> None:
        if request.get("protocol_version") != PROTOCOL_VERSION:
            raise ValueError(
                f"Unsupported protocol_version={request.get('protocol_version')}; "
                f"expected {PROTOCOL_VERSION}"
            )


def _serve_forever(args: argparse.Namespace) -> None:
    try:
        from websockets.sync.server import serve
    except ImportError as exc:
        raise ImportError(
            "Install websockets to run this server. In this repo, "
            "`uv run --project third_party/lerobot --extra remote-inference ...` "
            "or adding websockets to the active environment both work."
        ) from exc

    adapter = Gr00tPiperRemoteAdapter(args)
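    api_key = os.environ.get(args.api_key_env) if args.api_key_env else None
    if args.api_key_env and not api_key:
        # Fail closed: a missing key would otherwise silently disable authentication.
        raise RuntimeError(
            f"--api-key-env={args.api_key_env} was given, but that variable is unset or empty"
        )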

    def process_request(connection, request):  # noqa: ANN001 - websockets compatibility
        if api_key is None:
            return None
        auth = request.headers.get("Authorization")
        if auth == f"Bearer {api_key}":
            return None
        return connection.respond(401, "Unauthorized\n")

    def handler(ws) -> None:  # noqa: ANN001 - websockets compatibility
        logger.info("Client connected")
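        for frame in ws:
            # Reset per frame so a frame that fails to decode cannot report
            # the request_id of an earlier request.
            request: Any = None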
            if isinstance(frame, str):
                ws.send(packb({"type": "error", "error": "text frames are not supported"}))
                continue
            try:
                request = unpackb(frame)
                if not isinstance(request, dict):
                    raise ValueError(f"Request must be a dict, got {type(request).__name__}")
                request_type = request.get("type")
                if request_type == "hello":
                    response = adapter.hello(request)
                elif request_type == "infer":
                    response = adapter.infer(request)
                elif request_type == "reset":
                    response = adapter.reset(request)
                else:
                    raise ValueError(f"Unknown request type: {request_type!r}")
            except Exception as exc:
                logger.exception("Remote request failed")
                response = {
                    "type": "error",
                    "request_id": request.get("request_id") if isinstance(request, dict) else None,
                    "error": f"{type(exc).__name__}: {exc}",
                    "traceback": traceback.format_exc(),
                }
            ws.send(packb(response))
        logger.info("Client disconnected")

    kwargs: dict[str, Any] = {
        "host": args.host,
        "port": args.port,
        "compression": None,
        "max_size": None,
    }
    if api_key is not None:
        kwargs["process_request"] = process_request
    if "open_timeout" in inspect.signature(serve).parameters:
        kwargs["open_timeout"] = args.open_timeout_s

    logger.info("Serving GR00T Piper remote inference on ws://%s:%s", args.host, args.port)
    with serve(handler, **kwargs) as server:
        server.serve_forever()


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model-path", required=True, help="GR00T N1.7 checkpoint directory")
    parser.add_argument("--embodiment-tag", default="new_embodiment")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--strict", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--modality-config-path", default=None)
    parser.add_argument("--task", default="")

    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--open-timeout-s", type=float, default=10.0)
    parser.add_argument("--api-key-env", default=None)

    parser.add_argument(
        "--video-map", default=None, help="JSON policy video key -> remote observation key"
    )
    parser.add_argument(
        "--state-map", default=None, help="JSON policy state key -> remote scalar keys"
    )
    parser.add_argument(
        "--action-map", default=None, help="JSON policy action key -> remote action keys"
    )
    parser.add_argument("--log-level", default="INFO")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = parse_args(argv)
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    )
    _serve_forever(args)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
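
The `process_request` hook above compares the `Authorization` header with `==`. As an optional hardening, the comparison could be done in constant time with the standard library. The sketch below is a suggestion, not part of the original script; `is_authorized` is a hypothetical helper.

```python
import hmac
from typing import Optional


def is_authorized(header_value: Optional[str], api_key: Optional[str]) -> bool:
    """Check an Authorization header against the expected bearer token."""
    if api_key is None:
        return True  # authentication disabled, as when no API key is configured
    if header_value is None:
        return False
    expected = f"Bearer {api_key}".encode("utf-8")
    # compare_digest avoids leaking how many leading bytes matched.
    return hmac.compare_digest(header_value.encode("utf-8"), expected)
```

Inside `process_request`, the result would decide between returning None (accept) and responding with 401.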

8.4 scripts/piper/piper_gr00t_remote_rollout.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
from pathlib import Path
import subprocess


REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_LEROBOT_REPO = str(REPO_ROOT / "third_party/lerobot")
DEFAULT_REMOTE_ADDRESS = ""
DEFAULT_TASK = ""
DEFAULT_DATASET_ROOT = ""
DEFAULT_DATASET_REPO_ID = "local"


def build_command(args: argparse.Namespace, passthrough: list[str]) -> list[str]:
    script = Path(args.lerobot_repo).expanduser() / "examples/piper/piper_remote_rollout.py"
    cmd = [
        "uv",
        "run",
        "--no-sync",
        "python",
        str(script),
        f"--remote-address={args.remote_address}",
        f"--timeout-s={args.timeout_s}",
        f"--max-queued-actions={args.max_queued_actions}",
        f"--task={args.task}",
        f"--duration={args.duration}",
        f"--speed-percent={args.speed_percent}",
        f"--strategy={args.strategy}",
        f"--remote-action-joints-unit={args.remote_action_joints_unit}",
        f"--dataset-root={args.dataset_root}",
        f"--dataset-repo-id={args.dataset_repo_id}",
        f"--display={args.display}",
    ]
    if args.api_key_env:
        cmd.append(f"--api-key-env={args.api_key_env}")
    if args.dry_run:
        cmd.append("--dry-run")
    cmd.extend(passthrough)
    return cmd


def parse_args(argv: list[str] | None = None) -> tuple[argparse.Namespace, list[str]]:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--lerobot-repo", default=DEFAULT_LEROBOT_REPO)
    parser.add_argument("--remote-address", default=DEFAULT_REMOTE_ADDRESS)
    parser.add_argument("--timeout-s", type=float, default=30.0)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument("--max-queued-actions", type=int, default=64)
    parser.add_argument("--task", default=DEFAULT_TASK)
    parser.add_argument("--duration", type=float, default=300)
    parser.add_argument("--speed-percent", type=int, default=20)
    parser.add_argument(
        "--strategy", choices=["base", "sentry", "highlight", "dagger"], default="base"
    )
    parser.add_argument("--remote-action-joints-unit", choices=["deg", "rad"], default="deg")
    parser.add_argument("--dataset-root", default=DEFAULT_DATASET_ROOT)
    parser.add_argument("--dataset-repo-id", default=DEFAULT_DATASET_REPO_ID)
    parser.add_argument("--display", choices=["true", "false"], default="false")
    parser.add_argument("--dry-run", action="store_true")
    return parser.parse_known_args(argv)


def main(argv: list[str] | None = None) -> int:
    args, passthrough = parse_args(argv)
    cmd = build_command(args, passthrough)
    print("Command:", " ".join(cmd))
    return subprocess.call(cmd, cwd=args.lerobot_repo)


if __name__ == "__main__":
    raise SystemExit(main())

8.5 scripts/piper/test_remote_handshake.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
import inspect
import sys
import time


PROTOCOL_VERSION = 1
DEFAULT_ACTION_KEYS = [
    "joint1.pos",
    "joint2.pos",
    "joint3.pos",
    "joint4.pos",
    "joint5.pos",
    "joint6.pos",
    "gripper.pos",
]


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", default="")
    parser.add_argument("--timeout-s", type=float, default=10.0)
    parser.add_argument("--wait-s", type=float, default=0.0)
    args = parser.parse_args()

    import msgpack
    import websockets.sync.client as ws_client

    payload = {
        "type": "hello",
        "protocol_version": PROTOCOL_VERSION,
        "robot_type": "piper",
        "fps": 30.0,
        "task": "handshake test",
        "observation_features": {},
        "action_keys": DEFAULT_ACTION_KEYS,
    }

    kwargs = {"compression": None, "max_size": None, "open_timeout": args.timeout_s}
    if "proxy" in inspect.signature(ws_client.connect).parameters:
        kwargs["proxy"] = None

    deadline = time.monotonic() + max(args.wait_s, 0.0)
    last_exc: BaseException | None = None
    while True:
        try:
            with ws_client.connect(args.address, **kwargs) as ws:
                ws.send(msgpack.packb(payload, use_bin_type=True))
                response = ws.recv(timeout=args.timeout_s)
                if isinstance(response, str):
                    raise RuntimeError(f"Server returned text frame: {response}")
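                reply = msgpack.unpackb(response, raw=False)
                print(reply)
                # Only a hello_ack counts as success; an error reply must not exit 0.
                if not isinstance(reply, dict) or reply.get("type") != "hello_ack":
                    print("Remote handshake failed: server did not return hello_ack", file=sys.stderr)
                    return 1
                return 0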
        except Exception as exc:
            last_exc = exc
            if time.monotonic() >= deadline:
                break
            time.sleep(1.0)

    print(f"Remote handshake failed: {type(last_exc).__name__}: {last_exc}", file=sys.stderr)
    return 1


if __name__ == "__main__":
    raise SystemExit(main())

8.6 third_party/lerobot/examples/piper/piper_remote_rollout.py

#!/usr/bin/env python3

from __future__ import annotations

import argparse
import os
import subprocess
import sys
import tempfile
import threading
from contextlib import suppress
from pathlib import Path

from examples.piper.piper_camera_config import build_realsense_camera_arg

DEFAULT_CAN = "can0"
DEFAULT_ROBOT_ID = "piper"
DEFAULT_SPEED = 20
DEFAULT_FPS = 30
DEFAULT_DURATION_S = 300
DEFAULT_TASK = ""
DEFAULT_REMOTE_ADDRESS = ""
DEFAULT_DATASET_ROOT = ""
DEFAULT_DATASET_REPO_ID = "local"
DEFAULT_FRONT_CAMERA_SERIAL = ""
DEFAULT_WRIST_CAMERA_SERIAL = ""
DEFAULT_CAMERA_WIDTH = 1280
DEFAULT_CAMERA_HEIGHT = 720
DEFAULT_CAMERA_FPS = 30
DEFAULT_RERUN_PORT = 9876
DEFAULT_ROBOT_ACTION_PROCESSOR = "piper_rad_to_deg"
DEFAULT_ROBOT_OBSERVATION_PROCESSOR = "piper_deg_to_rad"

RECORDING_STRATEGIES = {"sentry", "highlight", "dagger"}
UPLOAD_INTERVAL_STRATEGIES = {"sentry", "dagger"}
STREAMING_ENCODING_STRATEGIES = {"sentry", "highlight"}


def _camera_args(args: argparse.Namespace) -> list[str]:
    return build_realsense_camera_arg(args)


def _passthrough_option_value(passthrough: list[str], option: str) -> str | None:
    value: str | None = None
    prefix = f"{option}="
    index = 0
    while index < len(passthrough):
        arg = passthrough[index]
        if arg.startswith(prefix):
            value = arg.split("=", 1)[1]
            index += 1
            continue
        if arg == option:
            if index + 1 >= len(passthrough) or passthrough[index + 1].startswith("--"):
                raise ValueError(f"{option} requires a value")
            value = passthrough[index + 1]
            index += 2
            continue
        index += 1
    return value


def _expand_policy_path(policy_path: str | None) -> str | None:
    if policy_path is None:
        return None
    policy_path = policy_path.strip()
    if policy_path.startswith("~"):
        return str(Path(policy_path).expanduser())
    return policy_path


def _resolve_policy_path(args: argparse.Namespace, passthrough: list[str]) -> tuple[str | None, bool]:
    wrapper_policy_path = _expand_policy_path(args.policy_path)
    passthrough_policy_path = _expand_policy_path(_passthrough_option_value(passthrough, "--policy.path"))
    if wrapper_policy_path and passthrough_policy_path and wrapper_policy_path != passthrough_policy_path:
        raise ValueError(
            "Conflicting policy paths: "
            f"--policy-path={wrapper_policy_path!r} and --policy.path={passthrough_policy_path!r}"
        )
    return wrapper_policy_path or passthrough_policy_path, wrapper_policy_path is not None


def _resolve_inference_type(args: argparse.Namespace, passthrough: list[str]) -> tuple[str, bool]:
    policy_path, _has_wrapper_policy_path = _resolve_policy_path(args, passthrough)
    passthrough_inference_type = _passthrough_option_value(passthrough, "--inference.type")
    if passthrough_inference_type:
        return passthrough_inference_type, False
    if policy_path:
        return "sync", True
    return "remote", True


def _write_task_file(task_file: Path, task: str) -> None:
    task_file.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = task_file.with_name(f".{task_file.name}.tmp")
    tmp_file.write_text(f"{task.strip()}\n", encoding="utf-8")
    tmp_file.replace(task_file)


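def prepare_task_file(args: argparse.Namespace) -> tuple[Path | None, bool]:
    """Return (task_file, remove_on_exit).

    A temporary task file is created (and flagged for removal) only when interactive
    task input is enabled and no --task-file was given. TUI mode uses no task file.
    """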
    if args.tui:
        return None, False

    if args.task_file:
        task_file = Path(args.task_file).expanduser()
        args.task_file = str(task_file)
        _write_task_file(task_file, args.task)
        return task_file, False

    if not args.interactive_task:
        return None, False

    fd, task_file_name = tempfile.mkstemp(
        prefix="lerobot_piper_remote_task_",
        suffix=".txt",
    )
    os.close(fd)
    task_file = Path(task_file_name)
    args.task_file = str(task_file)
    _write_task_file(task_file, args.task)
    return task_file, True


def _interactive_task_loop(task_file: Path, initial_task: str, stop_event: threading.Event) -> None:
    current_task = initial_task.strip()
    print(
        "Interactive task input enabled. Type a new task, /show, or /exit.",
        flush=True,
    )
    while not stop_event.is_set():
        line = sys.stdin.readline()
        if not line:
            return
        new_task = line.strip()
        if not new_task:
            continue
        if new_task == "/exit":
            print("Interactive task input disabled; rollout continues.", flush=True)
            return
        if new_task == "/show":
            with suppress(OSError):
                current_task = task_file.read_text(encoding="utf-8").strip() or current_task
            print(f"Current task: {current_task}", flush=True)
            continue

        current_task = new_task
        _write_task_file(task_file, current_task)
        print(f"Task updated: {current_task}", flush=True)


def build_command(args: argparse.Namespace, passthrough: list[str]) -> list[str]:
    """Assemble the `uv run ... lerobot-rollout` command from wrapper options and passthrough args."""
    dataset_root = Path(args.dataset_root).expanduser()
    records_data = args.strategy in RECORDING_STRATEGIES
    entrypoint = "lerobot-rollout-tui" if args.tui else "lerobot-rollout"
    policy_path, should_append_policy_path = _resolve_policy_path(args, passthrough)
    inference_type, should_append_inference_type = _resolve_inference_type(args, passthrough)
    is_remote_inference = inference_type == "remote"
    if not is_remote_inference and not policy_path:
        raise ValueError(
            f"--policy-path or --policy.path is required when --inference.type={inference_type}"
        )
    if not is_remote_inference and args.remote_action_joints_unit == "rad":
        raise ValueError("--remote-action-joints-unit=rad is only supported with remote inference")
    has_robot_action_processor = any(
        arg.startswith("--robot_action_processor.") or arg == "--robot_action_processor"
        for arg in passthrough
    )
    has_robot_observation_processor = any(
        arg.startswith("--robot_observation_processor.") or arg == "--robot_observation_processor"
        for arg in passthrough
    )

    cmd = [
        "uv",
        "run",
        "--extra",
        "dataset",
        "--extra",
        "remote-inference" if is_remote_inference else "pi",
        "--extra",
        "piper",
        "--extra",
        "intelrealsense",
        "--extra",
        "viz",
    ]
    if args.tui:
        cmd.extend(["--extra", "tui"])
    cmd.extend(
        [
            entrypoint,
            f"--strategy.type={args.strategy}",
            "--robot.type=piper",
            f"--robot.can_name={args.can}",
            f"--robot.speed_percent={args.speed_percent}",
            f"--robot.id={args.robot_id}",
            f"--task={args.task}",
            f"--fps={args.fps}",
        ]
    )
    if should_append_inference_type:
        cmd.append(f"--inference.type={inference_type}")
    if is_remote_inference:
        cmd.extend(
            [
                "--inference.protocol=websocket-msgpack",
                f"--inference.address={args.remote_address}",
                f"--inference.timeout_s={args.timeout_s}",
                f"--inference.max_queued_actions={args.max_queued_actions}",
            ]
        )
    elif should_append_policy_path:
        cmd.append(f"--policy.path={policy_path}")
    if not args.tui:
        cmd.append(f"--duration={args.duration}")

    if args.task_file and not args.tui:
        cmd.append(f"--task_file={Path(args.task_file).expanduser()}")

    if is_remote_inference and args.remote_action_joints_unit == "rad" and not has_robot_action_processor:
        cmd.append(f"--robot_action_processor.type={DEFAULT_ROBOT_ACTION_PROCESSOR}")
    if (
        is_remote_inference
        and args.remote_action_joints_unit == "rad"
        and not has_robot_observation_processor
    ):
        cmd.append(f"--robot_observation_processor.type={DEFAULT_ROBOT_OBSERVATION_PROCESSOR}")

    if args.strategy in UPLOAD_INTERVAL_STRATEGIES:
        cmd.append(f"--strategy.upload_every_n_episodes={args.upload_every_n_episodes}")

    if records_data:
        cmd.extend(
            [
                f"--dataset.repo_id={args.dataset_repo_id}",
                f"--dataset.root={dataset_root}",
                f"--dataset.single_task={args.task}",
                "--dataset.push_to_hub=false",
            ]
        )
        if args.strategy in STREAMING_ENCODING_STRATEGIES:
            cmd.append("--dataset.streaming_encoding=true")
        if args.encoder_threads is not None:
            cmd.append(f"--dataset.encoder_threads={args.encoder_threads}")
        if args.depth_encoder_vcodec != "auto":
            cmd.append(f"--dataset.depth_encoder.vcodec={args.depth_encoder_vcodec}")
        if args.depth_encoder_pix_fmt is not None:
            cmd.append(f"--dataset.depth_encoder.pix_fmt={args.depth_encoder_pix_fmt}")

    if args.display == "true":
        cmd.extend([f"--display_data={args.display}", f"--display_port={args.rerun_port}"])

    if args.timing_debug:
        cmd.append("--timing_debug=true")
        if args.timing_debug_path:
            cmd.append(f"--timing_debug_path={args.timing_debug_path}")

    if args.rerun_ip:
        cmd.append(f"--display_ip={args.rerun_ip}")
    if is_remote_inference and args.api_key_env:
        cmd.append(f"--inference.api_key_env={args.api_key_env}")
    if args.no_return_to_initial_position:
        cmd.append("--return_to_initial_position=false")

    cmd.extend(_camera_args(args))
    cmd.extend(passthrough)
    return cmd


def parse_args(argv: list[str] | None = None) -> tuple[argparse.Namespace, list[str]]:
    parser = argparse.ArgumentParser(
        description="Wrapper for Piper rollout with recording and Rerun visualization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("--remote-address", default=DEFAULT_REMOTE_ADDRESS)
    parser.add_argument("--timeout-s", type=float, default=30.0)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument("--max-queued-actions", type=int, default=64)
    parser.add_argument(
        "--policy-path",
        default=None,
        help="Local LeRobot policy path. Enables local sync inference unless --inference.type is passed through.",
    )

    parser.add_argument("--can", default=DEFAULT_CAN)
    parser.add_argument("--robot-id", default=DEFAULT_ROBOT_ID)
    parser.add_argument("--speed-percent", type=int, default=DEFAULT_SPEED)
    parser.add_argument("--no-return-to-initial-position", action="store_true")

    parser.add_argument("--dataset-root", default=DEFAULT_DATASET_ROOT)
    parser.add_argument("--dataset-repo-id", default=DEFAULT_DATASET_REPO_ID)
    parser.add_argument("--upload-every-n-episodes", type=int, default=1000000)
    parser.add_argument("--encoder-threads", type=int, default=2)
    parser.add_argument(
        "--depth-encoder-vcodec",
        default="auto",
        help="Depth video codec override; 'auto' leaves --dataset.depth_encoder.vcodec unset.",
    )
    parser.add_argument("--depth-encoder-pix-fmt", default=None, help="Depth pixel format override.")

    parser.add_argument("--task", default=DEFAULT_TASK)
    parser.add_argument(
        "--tui",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Run through the Textual/Rich interactive rollout UI.",
    )
    parser.add_argument("--task-file", default=None, help="File watched by lerobot-rollout for task updates.")
    parser.add_argument(
        "--interactive-task",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Read task updates from stdin while rollout is running.",
    )
    parser.add_argument("--fps", type=float, default=DEFAULT_FPS)
    parser.add_argument("--duration", type=float, default=DEFAULT_DURATION_S)
    parser.add_argument(
        "--remote-action-joints-unit",
        choices=["rad", "deg"],
        default="deg",
        help=(
            "Unit of joint1..joint6 expected/returned by the remote policy. "
            "rad inserts piper_deg_to_rad for observations and piper_rad_to_deg for actions."
        ),
    )

    parser.add_argument("--rerun-port", type=int, default=DEFAULT_RERUN_PORT)
    parser.add_argument(
        "--rerun-ip",
        default=None,
        help="Connect to an existing remote Rerun server. Leave unset to spawn a local viewer.",
    )

    parser.add_argument("--front-camera-serial", default=DEFAULT_FRONT_CAMERA_SERIAL)
    parser.add_argument("--wrist-camera-serial", default=DEFAULT_WRIST_CAMERA_SERIAL)
    parser.add_argument("--camera-width", type=int, default=DEFAULT_CAMERA_WIDTH)
    parser.add_argument("--camera-height", type=int, default=DEFAULT_CAMERA_HEIGHT)
    parser.add_argument("--camera-fps", type=int, default=DEFAULT_CAMERA_FPS)
    parser.add_argument("--no-cameras", action="store_true")
    parser.add_argument("--no-depth", action="store_true", help="Disable default raw RealSense depth streams")
    parser.add_argument(
        "--display",
        type=str,
        default="false",
        choices=["true", "false"],
        help="Enable Rerun data display (true/false, default: false)",
    )
    parser.add_argument(
        "--timing-debug",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write asynchronous rollout timing diagnostics.",
    )
    parser.add_argument(
        "--timing-debug-path",
        default="/tmp/lerobot_piper_remote_rollout_timing.parquet",
        help="Path for rollout timing diagnostics when --timing-debug is enabled.",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="base",
        choices=["base", "sentry", "highlight", "dagger"],
        help="Rollout strategy type: base, sentry, highlight, dagger",
    )

    parser.add_argument("--dry-run", action="store_true", help="Print the assembled command only.")
    return parser.parse_known_args(argv)


def main() -> int:
    args, passthrough = parse_args()
    task_file, remove_task_file = prepare_task_file(args)
    cmd = build_command(args, passthrough)
    print("Command:", " ".join(cmd))
    if args.dry_run:
        if remove_task_file and task_file is not None:
            task_file.unlink(missing_ok=True)
        return 0

    stop_event = threading.Event()
    if args.interactive_task and not args.tui:
        if task_file is None:
            raise RuntimeError("interactive task mode requires a task file")
        threading.Thread(
            target=_interactive_task_loop,
            args=(task_file, args.task, stop_event),
            daemon=True,
            name="PiperTaskInput",
        ).start()
    try:
        return subprocess.call(cmd)
    finally:
        stop_event.set()
        if remove_task_file and task_file is not None:
            task_file.unlink(missing_ok=True)


if __name__ == "__main__":
    raise SystemExit(main())
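
The wrapper above decides between local and remote inference from two inputs: its own `--policy-path` option and whatever the user passes through to `lerobot-rollout`. The following is a minimal, self-contained sketch of that decision, not the script's actual code. The names `option_value` and `resolve_inference_type` are hypothetical stand-ins for `_passthrough_option_value` and `_resolve_inference_type`, and the first half of the option scanner is an assumption, since only its tail is shown above.

```python
from __future__ import annotations


def option_value(argv: list[str], option: str) -> str | None:
    """Return the last value given for `option`, as `--opt=value` or `--opt value`.

    Hypothetical stand-in for the wrapper's passthrough scanner.
    """
    value = None
    prefix = f"{option}="
    index = 0
    while index < len(argv):
        arg = argv[index]
        if arg.startswith(prefix):
            value = arg[len(prefix):]
            index += 1
            continue
        if arg == option:
            # The separate-token form needs a following token that is not another flag.
            if index + 1 >= len(argv) or argv[index + 1].startswith("--"):
                raise ValueError(f"{option} requires a value")
            value = argv[index + 1]
            index += 2
            continue
        index += 1
    return value


def resolve_inference_type(
    wrapper_policy_path: str | None, passthrough: list[str]
) -> tuple[str, bool]:
    """Return (inference_type, should_append) following the precedence shown above."""
    explicit = option_value(passthrough, "--inference.type")
    if explicit:
        # Already present in the passthrough args, so the wrapper must not add it again.
        return explicit, False
    policy_path = wrapper_policy_path or option_value(passthrough, "--policy.path")
    if policy_path:
        return "sync", True  # a local policy path implies local synchronous inference
    return "remote", True  # no policy path: fall back to the remote server


if __name__ == "__main__":
    print(resolve_inference_type(None, []))
    print(resolve_inference_type("/path/to/policy", []))
    print(resolve_inference_type("/path/to/policy", ["--inference.type=remote"]))
```

The boolean in the returned tuple is what keeps the final command free of duplicated flags: a value the user already passed through is forwarded as-is, and only values the wrapper inferred are appended.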