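1. Background
Isaac GR00T N1.7 is a VLA (vision-language-action) model for general-purpose robot manipulation. Model inference needs a GPU environment, while the Piper real-robot side depends on the CAN bus, RealSense cameras, the robot SDK, and local safety-control logic.
To decouple model inference from real-robot control, this article uses a server/client architecture:
GPU inference machine
- Starts the GR00T model service
- Receives observations
- Outputs action chunks
Piper robot control machine
- Captures camera images and arm state
- Requests the remote model over WebSocket
- Executes actions and records data
The benefit is that the GR00T model can run on GPU machines such as H100, A100, or RTX cards, while the Piper side only handles hardware control, camera capture, the action queue, and safety stops.
2. Project structure
The files relevant to this deployment are mainly in: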
scripts/piper/
  start_gr00t_piper_server.sh      # Starts the GR00T WebSocket inference server
  start_gr00t_piper_client.sh      # Starts the Piper real-robot client
  gr00t_piper_ws_server.py         # Server that adapts GR00T to the Piper protocol
  piper_gr00t_remote_rollout.py    # Wrapper around the LeRobot Piper remote rollout
  test_remote_handshake.py         # Minimal handshake test
  README.md
third_party/lerobot/examples/piper/
  piper_remote_rollout.py          # Entry point for the Piper real-robot rollout
  remote_inference_protocol.md     # websocket-msgpack protocol description
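3. Environment setup
3.1. Basic dependencies
The project uses uv to manage the Python environment. First install the dependencies from the project root: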
uv sync --all-extras
The repository's default pyproject.toml targets x86 GPU environments and includes the following key dependencies:
torch==2.7.1
torchvision==0.22.1
transformers==4.57.3
websockets>=13.1,<17.0
msgpack==1.1.0
numpy==1.26.4
pyzmq==27.0.1
On Jetson Orin, Thor, or DGX Spark, use the installation script for the corresponding platform under scripts/deployment/.
3.2. Piper client dependencies
The Piper side depends on LeRobot's hardware components, mainly:
remote-inference: websockets, msgpack
piper: python-can, piper_sdk
intelrealsense: pyrealsense2
dataset/viz: data recording and visualization
The client ultimately launches a LeRobot command similar to this:
uv run \
  --extra dataset \
  --extra remote-inference \
  --extra piper \
  --extra intelrealsense \
  --extra viz \
  lerobot-rollout ...
To display data with Rerun, enable:
--display true
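If you use RealSense cameras, confirm that the serial numbers of the front and wrist cameras are correct.
4. How it works
4.1. Communication protocol
The server and client communicate over WebSocket binary frames carrying MessagePack payloads.
The protocol version is `1`, and there are three main request types:
hello - establishes the connection and synchronizes the action keys, observation schema, task, and related information
infer - the client sends the current observation and the server returns actions
reset - clears the server-side state for one rollout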
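The three request types can be sketched as plain dictionaries, as they would look just before MessagePack encoding. This is a minimal sketch rather than the client's actual code: the field names are the ones read by the server listed in section 8.3 and sent by the handshake test in section 8.5, while the helper names `build_hello`, `build_infer`, and `build_reset` and the sample values are hypothetical.

```python
from __future__ import annotations

from typing import Any

PROTOCOL_VERSION = 1  # protocol version used throughout this article


def build_hello(action_keys: list[str], task: str, fps: float = 30.0) -> dict[str, Any]:
    """Handshake request: announces action keys, observation schema, and task."""
    return {
        "type": "hello",
        "protocol_version": PROTOCOL_VERSION,
        "robot_type": "piper",
        "fps": fps,
        "task": task,
        "observation_features": {},  # left empty in this sketch
        "action_keys": list(action_keys),
    }


def build_infer(observation: dict[str, Any], task: str, request_id: int) -> dict[str, Any]:
    """Inference request: carries the current observation (hypothetical request_id)."""
    return {
        "type": "infer",
        "protocol_version": PROTOCOL_VERSION,
        "request_id": request_id,
        "task": task,
        "observation": observation,
    }


def build_reset() -> dict[str, Any]:
    """Reset request: asks the server to clear its per-rollout state."""
    return {"type": "reset", "protocol_version": PROTOCOL_VERSION}


if __name__ == "__main__":
    hello = build_hello(["joint1.pos", "gripper.pos"], task="put the bottle into the box.")
    print(sorted(hello))
```

Every request carries `protocol_version`, because the server rejects any message whose version differs from its own.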
4.2. Server responsibilities
The server entry point is:
scripts/piper/gr00t_piper_ws_server.py
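It does the following:
- Loads the GR00T checkpoint
- Initializes Gr00tPolicy
- Receives observations from the Piper client
- Converts the remote observation into the input format GR00T expects
- Calls policy.get_action()
- Remaps GR00T's output actions to Piper action keys
- Returns them to the client for execution
The server's default action keys are: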
joint1.pos
joint2.pos
joint3.pos
joint4.pos
joint5.pos
joint6.pos
gripper.pos
The default image mapping is:
{
  "front": "observation.images.front",
  "wrist": "observation.images.wrist",
}
The default state mapping is:
{
  "single_arm": [
    "joint1.pos",
    "joint2.pos",
    "joint3.pos",
    "joint4.pos",
    "joint5.pos",
    "joint6.pos",
  ],
  "gripper": ["gripper.pos"],
}
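The front-camera and wrist-camera images, arm joint states, and gripper state captured by the Piper client are rearranged on the server, together with the task string, into the video, state, and language inputs that match what the GR00T model was trained on.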
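The state part of that conversion amounts to regrouping flat per-joint scalars under the policy's state keys. Below is a minimal sketch with plain Python lists; the server itself builds NumPy arrays with a leading batch and time dimension of 1. `group_state` is a hypothetical helper name and the sample joint values are made up.

```python
from __future__ import annotations

STATE_MAP = {
    "single_arm": [
        "joint1.pos", "joint2.pos", "joint3.pos",
        "joint4.pos", "joint5.pos", "joint6.pos",
    ],
    "gripper": ["gripper.pos"],
}


def group_state(
    observation: dict[str, float],
    state_map: dict[str, list[str]],
) -> dict[str, list[list[list[float]]]]:
    """Regroup flat scalars into {policy_key: [[[v1, ..., vD]]]} (batch=1, time=1)."""
    grouped = {}
    for policy_key, remote_keys in state_map.items():
        missing = [key for key in remote_keys if key not in observation]
        if missing:
            raise KeyError(f"Missing observation keys for {policy_key!r}: {missing}")
        grouped[policy_key] = [[[float(observation[key]) for key in remote_keys]]]
    return grouped


if __name__ == "__main__":
    obs = {f"joint{i}.pos": float(i) for i in range(1, 7)}  # made-up joint values
    obs["gripper.pos"] = 0.5
    print(group_state(obs, STATE_MAP))
```

Raising on a missing key, instead of filling in a default, keeps a schema mismatch between client and server visible at the first request.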
4.3. Client responsibilities
The client entry point is:
scripts/piper/piper_gr00t_remote_rollout.py
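It is essentially a wrapper around examples/piper/piper_remote_rollout.py from LeRobot 0.5.2, and is responsible for:
- Connecting to the Piper arm over the CAN bus
- Opening the RealSense cameras
- Capturing the current observation
- Requesting the remote GR00T service over WebSocket
- Receiving the action chunk
- Sending actions to the Piper step by step from a queue
- Recording rollout data
- Supporting Rerun visualization and timing debug
The client keeps its safety logic local. If the server times out, returns a malformed response, omits an action key, or returns actions with the wrong shape, the client fails closed so that it does not keep sending bad actions to the arm.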
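A fail-closed check of this kind might look like the sketch below. It is not the LeRobot client's actual code: `validate_action_response` and `ActionQueue` are hypothetical names, actions are plain nested lists rather than arrays, and the queue limit simply mirrors the `--max-queued-actions` option of the wrapper script.

```python
from __future__ import annotations

from collections import deque
from typing import Any


def validate_action_response(response: Any, expected_keys: list[str]) -> list[list[float]]:
    """Return the action chunk as rows ordered by expected_keys, or raise ValueError."""
    if not isinstance(response, dict) or response.get("type") != "action":
        raise ValueError(f"Unexpected response: {response!r}")
    keys = response.get("action_keys")
    if not isinstance(keys, list):
        raise ValueError("Response has no action_keys list")
    missing = [key for key in expected_keys if key not in keys]
    if missing:
        raise ValueError(f"Missing action keys: {missing}")
    rows = response.get("actions")
    if not isinstance(rows, list) or not rows:
        raise ValueError("Empty or malformed action chunk")
    columns = [keys.index(key) for key in expected_keys]
    chunk = []
    for row in rows:
        if not isinstance(row, list) or len(row) != len(keys):
            raise ValueError(f"Action row has the wrong shape: {row!r}")
        chunk.append([float(row[column]) for column in columns])
    return chunk


class ActionQueue:
    """Bounded FIFO of pending actions that is emptied on any validation failure."""

    def __init__(self, max_queued_actions: int = 64) -> None:
        self._max = max_queued_actions
        self._queue = deque()

    def push_response(self, response: Any, expected_keys: list[str]) -> bool:
        try:
            chunk = validate_action_response(response, expected_keys)
        except (ValueError, TypeError):
            self._queue.clear()  # fail closed: drop everything still pending
            return False
        room = self._max - len(self._queue)
        self._queue.extend(chunk[:room])  # never exceed the configured limit
        return True

    def pop(self) -> list[float] | None:
        """Next action to send to the arm, or None when nothing is queued."""
        return self._queue.popleft() if self._queue else None
```

Clearing the queue on failure is the design choice that matters here: once a response looks wrong, actions queued from earlier responses are no longer trusted either, so the control loop has nothing left to send.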
5. Starting the server
On the GPU inference machine, run:
MODEL_PATH=/path/to/gr00t/checkpoint \
EMBODIMENT_TAG=new_embodiment \
TASK="put the bottle into the box." \
HOST=0.0.0.0 \
PORT=8000 \
DEVICE=cuda \
bash scripts/piper/start_gr00t_piper_server.sh
The script's default parameters include:
HOST=0.0.0.0
PORT=8000
DEVICE=cuda
EMBODIMENT_TAG=new_embodiment
If the model's modality keys differ from the default mapping, specify them explicitly through JSON arguments:
bash scripts/piper/start_gr00t_piper_server.sh \
  --video-map '{"front":"observation.images.front","wrist":"observation.images.wrist"}' \
  --state-map '{"single_arm":["joint1.pos","joint2.pos","joint3.pos","joint4.pos","joint5.pos","joint6.pos"],"gripper":["gripper.pos"]}' \
  --action-map '{"single_arm":["joint1.pos","joint2.pos","joint3.pos","joint4.pos","joint5.pos","joint6.pos"],"gripper":["gripper.pos"]}'
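Because the server rejects the handshake when the client requests an action key that the action map does not cover, it can help to check a custom `--action-map` value before starting the server. The following is a small sketch using only the standard library; `uncovered_action_keys` is a hypothetical helper and is not part of the repository.

```python
from __future__ import annotations

import json

DEFAULT_ACTION_KEYS = [f"joint{i}.pos" for i in range(1, 7)] + ["gripper.pos"]


def uncovered_action_keys(action_map_json: str, action_keys: list[str]) -> list[str]:
    """Return the action keys that no entry of the JSON action map provides."""
    action_map = json.loads(action_map_json)
    if not isinstance(action_map, dict):
        raise ValueError("Action map must be a JSON object")
    covered = {key for remote_keys in action_map.values() for key in remote_keys}
    return [key for key in action_keys if key not in covered]


if __name__ == "__main__":
    # A deliberately incomplete map: only the first three joints are listed.
    partial = '{"single_arm": ["joint1.pos", "joint2.pos", "joint3.pos"]}'
    print(uncovered_action_keys(partial, DEFAULT_ACTION_KEYS))
```

An empty result means every requested key is covered; anything else lists the keys that would make the handshake fail.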
6. Server handshake test
Before connecting the real robot, run a minimal handshake test first:
uv run --project third_party/lerobot --extra remote-inference \
  python scripts/piper/test_remote_handshake.py \
  --address ws://127.0.0.1:8000
If the server returns `hello_ack`, then WebSocket, MessagePack, and the server protocol are basically working.
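Beyond reading the printed response, the decoded reply can also be checked programmatically. The sketch below works on the already-decoded dictionary; `check_hello_ack` is a hypothetical helper, and the fields it inspects are the ones the server in section 8.3 puts into its `hello_ack` and error replies.

```python
from __future__ import annotations

from typing import Any


def check_hello_ack(
    response: Any,
    action_keys: list[str],
    protocol_version: int = 1,
) -> list[str]:
    """Return a list of problems found in a decoded hello reply (empty when OK)."""
    if not isinstance(response, dict):
        return [f"response is not a dict: {type(response).__name__}"]
    if response.get("type") == "error":
        return [f"server error: {response.get('error')}"]
    problems = []
    if response.get("type") != "hello_ack":
        problems.append(f"unexpected type: {response.get('type')!r}")
    if response.get("protocol_version") != protocol_version:
        problems.append(f"protocol_version is {response.get('protocol_version')!r}")
    features = response.get("action_features") or {}
    missing = [key for key in action_keys if key not in features]
    if missing:
        problems.append(f"action_features lacks keys: {missing}")
    use_length = response.get("use_length")
    if not isinstance(use_length, int) or use_length < 1:
        problems.append(f"use_length is not a positive integer: {use_length!r}")
    return problems


if __name__ == "__main__":
    reply = {
        "type": "hello_ack",
        "protocol_version": 1,
        "use_length": 16,  # made-up value for illustration
        "action_features": {"gripper.pos": {"dtype": "float32", "shape": []}},
    }
    print(check_hello_ack(reply, ["gripper.pos"]))
```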
7. Starting the Piper client
On the Piper robot control machine, run:
REMOTE_ADDRESS=ws://SERVER_IP:8000 \
TASK="put the yellow bottle into the box." \
DURATION=300 \
SPEED_PERCENT=20 \
STRATEGY=sentry \
bash scripts/piper/start_gr00t_piper_client.sh
Common parameter examples:
bash scripts/piper/start_gr00t_piper_client.sh \
  --can can0 \
  --front-camera-serial 254622078102 \
  --wrist-camera-serial 254622071629 \
  --strategy sentry \
  --display true
8. Source code
8.1. scripts/piper/start_gr00t_piper_server.sh
#!/usr/bin/env bash
HOST="${HOST:-0.0.0.0}"
PORT="${PORT:-8000}"
DEVICE="${DEVICE:-cuda}"
MODEL_PATH="${MODEL_PATH:-}"
EMBODIMENT_TAG="${EMBODIMENT_TAG:-new_embodiment}"
MODALITY_CONFIG_PATH="${MODALITY_CONFIG_PATH:-}"
TASK="${TASK:-}"
export UV_CACHE_DIR="/tmp/uv-cache"
echo "GR00T Piper remote server"
echo "Starting on ws://${HOST}:${PORT}"
echo "Model: ${MODEL_PATH}"
echo "Embodiment tag: ${EMBODIMENT_TAG}"
echo "Task: ${TASK}"
echo "UV cache: ${UV_CACHE_DIR}"
cmd=(
  uv run --no-sync python scripts/piper/gr00t_piper_ws_server.py
  --host "${HOST}"
  --port "${PORT}"
  --device "${DEVICE}"
  --model-path "${MODEL_PATH}"
  --embodiment-tag "${EMBODIMENT_TAG}"
  --task "${TASK}"
)
if [[ -n "${MODALITY_CONFIG_PATH}" ]]; then
  cmd+=(--modality-config-path "${MODALITY_CONFIG_PATH}")
fi
exec "${cmd[@]}" "$@"
8.2. scripts/piper/start_gr00t_piper_client.sh
#!/usr/bin/env bash
REMOTE_ADDRESS="${REMOTE_ADDRESS:-ws://127.0.0.1:8000}"
TASK="${TASK:-}"
DURATION="${DURATION:-300}"
SPEED_PERCENT="${SPEED_PERCENT:-20}"
STRATEGY="${STRATEGY:-sentry}"
REMOTE_ACTION_JOINTS_UNIT="${REMOTE_ACTION_JOINTS_UNIT:-deg}"
DATASET_ROOT="${DATASET_ROOT:-}"
DATASET_REPO_ID="${DATASET_REPO_ID:-local}"
echo "GR00T Piper remote client"
echo "Remote: ${REMOTE_ADDRESS}"
echo "Task: ${TASK}"
echo "Strategy: ${STRATEGY}"
exec uv run --no-sync python scripts/piper/piper_gr00t_remote_rollout.py \
  --remote-address "${REMOTE_ADDRESS}" \
  --task "${TASK}" \
  --duration "${DURATION}" \
  --speed-percent "${SPEED_PERCENT}" \
  --strategy "${STRATEGY}" \
  --remote-action-joints-unit "${REMOTE_ACTION_JOINTS_UNIT}" \
  --dataset-root "${DATASET_ROOT}" \
  --dataset-repo-id "${DATASET_REPO_ID}" \
  --display true \
  "$@"
8.3. scripts/piper/gr00t_piper_ws_server.py
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import sys
import time
import traceback
from typing import Any

import numpy as np

PROTOCOL_VERSION = 1
DEFAULT_ACTION_KEYS = [
    "joint1.pos",
    "joint2.pos",
    "joint3.pos",
    "joint4.pos",
    "joint5.pos",
    "joint6.pos",
    "gripper.pos",
]
DEFAULT_VIDEO_MAP = {
    "front": "observation.images.front",
    "wrist": "observation.images.wrist",
}
DEFAULT_STATE_MAP = {
    "single_arm": DEFAULT_ACTION_KEYS[:6],
    "gripper": ["gripper.pos"],
}
DEFAULT_ACTION_MAP = DEFAULT_STATE_MAP

logger = logging.getLogger("gr00t_piper_ws_server")


def _pack_array(obj: Any) -> Any:
    if isinstance(obj, np.ndarray):
        if obj.dtype.kind in ("V", "O", "c"):
            raise ValueError(f"Unsupported ndarray dtype for msgpack transport: {obj.dtype}")
        return {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }
    if isinstance(obj, np.generic):
        return {b"__npgeneric__": True, b"data": obj.item(), b"dtype": obj.dtype.str}
    return obj


def _unpack_array(obj: dict) -> Any:
    if b"__ndarray__" in obj:
        return np.ndarray(
            buffer=obj[b"data"],
            dtype=np.dtype(obj[b"dtype"]),
            shape=obj[b"shape"],
        )
    if b"__npgeneric__" in obj:
        return np.dtype(obj[b"dtype"]).type(obj[b"data"])
    return obj


def packb(payload: Any) -> bytes:
    import msgpack

    return msgpack.packb(payload, default=_pack_array, use_bin_type=True)


def unpackb(payload: bytes) -> Any:
    import msgpack

    return msgpack.unpackb(payload, object_hook=_unpack_array, raw=False)


def _json_mapping(raw: str | None, default: dict[str, Any]) -> dict[str, Any]:
    if raw is None or raw == "":
        return dict(default)
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        raise ValueError(f"Mapping must be a JSON object, got {type(parsed).__name__}")
    return parsed


def _load_modality_config(path: str | None) -> None:
    if not path:
        return
    config_path = Path(path).expanduser().resolve()
    if not config_path.exists() or config_path.suffix != ".py":
        raise FileNotFoundError(f"Modality config must be an existing .py file: {path}")
    sys.path.append(str(config_path.parent))
    importlib.import_module(config_path.stem)
    logger.info("Loaded modality config: %s", config_path)


def _ensure_batched_video(value: Any, *, key: str) -> np.ndarray:
    arr = np.asarray(value)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.uint8, copy=False)
    if arr.ndim == 3:
        arr = arr[None, None, ...]
    elif arr.ndim == 4:
        arr = arr[None, ...]
    elif arr.ndim != 5:
        raise ValueError(f"Video key {key!r} must have shape HWC, THWC, or BTHWC, got {arr.shape}")
    if arr.shape[-1] != 3:
        raise ValueError(f"Video key {key!r} must have 3 RGB channels, got shape {arr.shape}")
    return np.ascontiguousarray(arr)


def _stack_remote_scalars(
    observation: dict[str, Any],
    remote_keys: list[str],
    *,
    key: str,
    state_names: list[str],
) -> np.ndarray:
    missing = [remote_key for remote_key in remote_keys if remote_key not in observation]
    if not missing:
        values = [
            float(np.asarray(observation[remote_key]).reshape(-1)[0]) for remote_key in remote_keys
        ]
        return np.asarray(values, dtype=np.float32)[None, None, :]
    state_key = "observation.state"
    if state_key not in observation:
        available = sorted(str(remote_key) for remote_key in observation)
        raise KeyError(
            f"Missing remote observation keys for {key!r}: {missing}. "
            f"Also did not find {state_key!r}. Available keys={available}"
        )
    state = np.asarray(observation[state_key], dtype=np.float32).reshape(-1)
    missing_state_names = [
        remote_key for remote_key in remote_keys if remote_key not in state_names
    ]
    if missing_state_names:
        raise KeyError(
            f"Remote observation uses {state_key!r}, but its names do not contain "
            f"{missing_state_names} for policy key {key!r}. state_names={state_names}"
        )
    indices = [state_names.index(remote_key) for remote_key in remote_keys]
    if max(indices) >= state.shape[0]:
        raise ValueError(
            f"{state_key!r} has length {state.shape[0]}, but mapping for {key!r} "
            f"requires indices {indices} from state_names={state_names}"
        )
    values = [float(state[index]) for index in indices]
    return np.asarray(values, dtype=np.float32)[None, None, :]


def _infer_policy_key(policy_keys: list[str], requested_key: str, mapping: dict[str, Any]) -> str:
    if requested_key in mapping:
        return requested_key
    if len(policy_keys) == 1 and len(mapping) == 1:
        return policy_keys[0]
    raise KeyError(
        f"Mapping does not provide policy key {requested_key!r}. "
        f"Policy keys={policy_keys}, mapping keys={list(mapping)}"
    )


def _flatten_policy_actions(
    policy_actions: dict[str, Any],
    *,
    action_map: dict[str, list[str]],
    requested_action_keys: list[str],
) -> np.ndarray:
    rows_by_key: dict[str, np.ndarray] = {}
    for policy_key, remote_keys in action_map.items():
        if policy_key not in policy_actions:
            raise KeyError(f"Policy did not return action key {policy_key!r}")
        arr = np.asarray(policy_actions[policy_key], dtype=np.float32)
        if arr.ndim == 3:
            if arr.shape[0] != 1:
                raise ValueError(
                    f"Only batch size 1 is supported for action {policy_key!r}, got {arr.shape}"
                )
            arr = arr[0]
        elif arr.ndim == 2:
            pass
        elif arr.ndim == 1:
            arr = arr[:, None]
        else:
            raise ValueError(
                f"Action {policy_key!r} must have shape BTD, TD, or T, got {arr.shape}"
            )
        if arr.shape[1] != len(remote_keys):
            raise ValueError(
                f"Action {policy_key!r} dimension {arr.shape[1]} does not match "
                f"mapped remote keys {remote_keys}"
            )
        for column, remote_key in enumerate(remote_keys):
            rows_by_key[remote_key] = arr[:, column]
    missing = [key for key in requested_action_keys if key not in rows_by_key]
    if missing:
        raise KeyError(f"Policy action mapping did not produce requested keys: {missing}")
    min_rows = min(rows_by_key[key].shape[0] for key in requested_action_keys)
    return np.stack([rows_by_key[key][:min_rows] for key in requested_action_keys], axis=1).astype(
        np.float32, copy=False
    )


class Gr00tPiperRemoteAdapter:
    def __init__(self, args: argparse.Namespace) -> None:
        _load_modality_config(args.modality_config_path)
        from gr00t.data.embodiment_tags import EmbodimentTag
        from gr00t.policy.gr00t_policy import Gr00tPolicy

        self.video_map = _json_mapping(args.video_map, DEFAULT_VIDEO_MAP)
        self.state_map = _json_mapping(args.state_map, DEFAULT_STATE_MAP)
        self.action_map = _json_mapping(args.action_map, DEFAULT_ACTION_MAP)
        self.action_keys = list(DEFAULT_ACTION_KEYS)
        self.observation_state_names = list(DEFAULT_ACTION_KEYS)
        self.last_task = args.task
        logger.info("Loading GR00T policy from %s", args.model_path)
        self.policy = Gr00tPolicy(
            embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
            model_path=args.model_path,
            device=args.device,
            strict=args.strict,
        )
        self.policy.reset()
        self.modality_configs = self.policy.get_modality_config()
        self.language_key = self.modality_configs["language"].modality_keys[0]
        self.use_length = len(self.modality_configs["action"].delta_indices)
        logger.info("Policy modality keys: %s", self._modality_key_summary())

    def hello(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        self.action_keys = list(request.get("action_keys") or DEFAULT_ACTION_KEYS)
        self._update_observation_state_names(request.get("observation_features"))
        self.last_task = str(request.get("task") or self.last_task or "")
        missing = [key for key in self.action_keys if key not in self._mapped_remote_action_keys()]
        if missing:
            raise KeyError(f"Client requested action keys not covered by --action-map: {missing}")
        return {
            "type": "hello_ack",
            "accepted_protocols": ["websocket-msgpack"],
            "protocol_version": PROTOCOL_VERSION,
            "use_length": self.use_length,
            "action_features": {key: {"dtype": "float32", "shape": []} for key in self.action_keys},
            "policy_modality_keys": self._modality_key_summary(),
        }

    def infer(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        request_id = request.get("request_id")
        task = str(request.get("task") or self.last_task or "")
        observation = request.get("observation")
        if not isinstance(observation, dict):
            raise ValueError("infer request missing dict field 'observation'")
        start = time.perf_counter()
        gr00t_obs = self._to_gr00t_observation(observation, task=task)
        actions, _info = self.policy.get_action(gr00t_obs)
        remote_actions = _flatten_policy_actions(
            actions,
            action_map=self.action_map,
            requested_action_keys=self.action_keys,
        )
        infer_ms = (time.perf_counter() - start) * 1e3
        return {
            "type": "action",
            "request_id": request_id,
            "action_keys": list(self.action_keys),
            "actions": remote_actions,
            "server_timing": {"infer_ms": infer_ms},
        }

    def reset(self, request: dict[str, Any]) -> dict[str, Any]:
        self._check_protocol(request)
        self.policy.reset()
        return {"type": "reset_ack", "protocol_version": PROTOCOL_VERSION}

    def _to_gr00t_observation(self, observation: dict[str, Any], *, task: str) -> dict[str, Any]:
        video: dict[str, np.ndarray] = {}
        for policy_key in self.modality_configs["video"].modality_keys:
            remote_key = self.video_map.get(policy_key)
            if remote_key is None and len(self.modality_configs["video"].modality_keys) == len(
                self.video_map
            ):
                remote_key = self.video_map[list(self.video_map.keys())[len(video)]]
            if remote_key is None:
                raise KeyError(f"No --video-map entry for policy video key {policy_key!r}")
            if remote_key not in observation:
                raise KeyError(
                    f"Missing remote video key {remote_key!r} for policy key {policy_key!r}"
                )
            video[policy_key] = _ensure_batched_video(observation[remote_key], key=remote_key)
        state: dict[str, np.ndarray] = {}
        for policy_key in self.modality_configs["state"].modality_keys:
            mapped_key = _infer_policy_key(
                list(self.modality_configs["state"].modality_keys), policy_key, self.state_map
            )
            remote_keys = list(self.state_map[mapped_key])
            state[policy_key] = _stack_remote_scalars(
                observation,
                remote_keys,
                key=policy_key,
                state_names=self.observation_state_names,
            )
        language = {self.language_key: [[task]]}
        return {"video": video, "state": state, "language": language}

    def _mapped_remote_action_keys(self) -> set[str]:
        return {
            remote_key for remote_keys in self.action_map.values() for remote_key in remote_keys
        }

    def _modality_key_summary(self) -> dict[str, list[str]]:
        return {name: list(config.modality_keys) for name, config in self.modality_configs.items()}

    def _update_observation_state_names(self, observation_features: Any) -> None:
        if not isinstance(observation_features, dict):
            return
        state_feature = observation_features.get("observation.state")
        if not isinstance(state_feature, dict):
            return
        names = state_feature.get("names")
        if isinstance(names, list) and all(isinstance(name, str) for name in names):
            self.observation_state_names = list(names)
            logger.info("Remote observation.state names: %s", self.observation_state_names)

    @staticmethod
    def _check_protocol(request: dict[str, Any]) -> None:
        if request.get("protocol_version") != PROTOCOL_VERSION:
            raise ValueError(
                f"Unsupported protocol_version={request.get('protocol_version')}; "
                f"expected {PROTOCOL_VERSION}"
            )


def _serve_forever(args: argparse.Namespace) -> None:
    try:
        from websockets.sync.server import serve
    except ImportError as exc:
        raise ImportError(
            "Install websockets to run this server. In this repo, "
            "`uv run --project third_party/lerobot --extra remote-inference ...` "
            "or adding websockets to the active environment both work."
        ) from exc
    adapter = Gr00tPiperRemoteAdapter(args)
    api_key = os.environ.get(args.api_key_env) if args.api_key_env else None

    def process_request(connection, request):  # noqa: ANN001 - websockets compatibility
        if api_key is None:
            return None
        auth = request.headers.get("Authorization")
        if auth == f"Bearer {api_key}":
            return None
        return connection.respond(401, "Unauthorized\n")

    def handler(ws) -> None:  # noqa: ANN001 - websockets compatibility
        logger.info("Client connected")
        for frame in ws:
            # Reset per frame so an error reply never reuses a stale request_id
            # left over from an earlier, successfully decoded frame.
            request: Any = None
            if isinstance(frame, str):
                ws.send(packb({"type": "error", "error": "text frames are not supported"}))
                continue
            try:
                request = unpackb(frame)
                if not isinstance(request, dict):
                    raise ValueError(f"Request must be a dict, got {type(request).__name__}")
                request_type = request.get("type")
                if request_type == "hello":
                    response = adapter.hello(request)
                elif request_type == "infer":
                    response = adapter.infer(request)
                elif request_type == "reset":
                    response = adapter.reset(request)
                else:
                    raise ValueError(f"Unknown request type: {request_type!r}")
            except Exception as exc:
                logger.exception("Remote request failed")
                response = {
                    "type": "error",
                    "request_id": request.get("request_id") if isinstance(request, dict) else None,
                    "error": f"{type(exc).__name__}: {exc}",
                    "traceback": traceback.format_exc(),
                }
            ws.send(packb(response))
        logger.info("Client disconnected")

    kwargs: dict[str, Any] = {
        "host": args.host,
        "port": args.port,
        "compression": None,
        "max_size": None,
    }
    if api_key is not None:
        kwargs["process_request"] = process_request
    if "open_timeout" in inspect.signature(serve).parameters:
        kwargs["open_timeout"] = args.open_timeout_s
    logger.info("Serving GR00T Piper remote inference on ws://%s:%s", args.host, args.port)
    with serve(handler, **kwargs) as server:
        server.serve_forever()


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model-path", required=True, help="GR00T N1.7 checkpoint directory")
    parser.add_argument("--embodiment-tag", default="new_embodiment")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--strict", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--modality-config-path", default=None)
    parser.add_argument("--task", default="")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--open-timeout-s", type=float, default=10.0)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument(
        "--video-map", default=None, help="JSON policy video key -> remote observation key"
    )
    parser.add_argument(
        "--state-map", default=None, help="JSON policy state key -> remote scalar keys"
    )
    parser.add_argument(
        "--action-map", default=None, help="JSON policy action key -> remote action keys"
    )
    parser.add_argument("--log-level", default="INFO")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = parse_args(argv)
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    )
    _serve_forever(args)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
8.4. scripts/piper/piper_gr00t_remote_rollout.py
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
from pathlib import Path
import subprocess
from urllib.parse import urlparse

REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_LEROBOT_REPO = str(REPO_ROOT / "third_party/lerobot")
DEFAULT_REMOTE_ADDRESS = ""
DEFAULT_TASK = ""
DEFAULT_DATASET_ROOT = ""
DEFAULT_DATASET_REPO_ID = "local"


def build_command(args: argparse.Namespace, passthrough: list[str]) -> list[str]:
    script = Path(args.lerobot_repo).expanduser() / "examples/piper/piper_remote_rollout.py"
    cmd = [
        "uv",
        "run",
        "--no-sync",
        "python",
        str(script),
        f"--remote-address={args.remote_address}",
        f"--timeout-s={args.timeout_s}",
        f"--max-queued-actions={args.max_queued_actions}",
        f"--task={args.task}",
        f"--duration={args.duration}",
        f"--speed-percent={args.speed_percent}",
        f"--strategy={args.strategy}",
        f"--remote-action-joints-unit={args.remote_action_joints_unit}",
        f"--dataset-root={args.dataset_root}",
        f"--dataset-repo-id={args.dataset_repo_id}",
        f"--display={args.display}",
    ]
    if args.api_key_env:
        cmd.append(f"--api-key-env={args.api_key_env}")
    if args.dry_run:
        cmd.append("--dry-run")
    cmd.extend(passthrough)
    return cmd


def parse_args(argv: list[str] | None = None) -> tuple[argparse.Namespace, list[str]]:
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--lerobot-repo", default=DEFAULT_LEROBOT_REPO)
    parser.add_argument("--remote-address", default=DEFAULT_REMOTE_ADDRESS)
    parser.add_argument("--timeout-s", type=float, default=30.0)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument("--max-queued-actions", type=int, default=64)
    parser.add_argument("--task", default=DEFAULT_TASK)
    parser.add_argument("--duration", type=float, default=300)
    parser.add_argument("--speed-percent", type=int, default=20)
    parser.add_argument(
        "--strategy", choices=["base", "sentry", "highlight", "dagger"], default="base"
    )
    parser.add_argument("--remote-action-joints-unit", choices=["deg", "rad"], default="deg")
    parser.add_argument("--dataset-root", default=DEFAULT_DATASET_ROOT)
    parser.add_argument("--dataset-repo-id", default=DEFAULT_DATASET_REPO_ID)
    parser.add_argument("--display", choices=["true", "false"], default="false")
    parser.add_argument("--dry-run", action="store_true")
    return parser.parse_known_args(argv)


def main(argv: list[str] | None = None) -> int:
    args, passthrough = parse_args(argv)
    cmd = build_command(args, passthrough)
    print("Command:", " ".join(cmd))
    return subprocess.call(cmd, cwd=args.lerobot_repo)


if __name__ == "__main__":
    raise SystemExit(main())
8.5. scripts/piper/test_remote_handshake.py
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import inspect
import sys
import time

PROTOCOL_VERSION = 1
DEFAULT_ACTION_KEYS = [
    "joint1.pos",
    "joint2.pos",
    "joint3.pos",
    "joint4.pos",
    "joint5.pos",
    "joint6.pos",
    "gripper.pos",
]


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", default="")
    parser.add_argument("--timeout-s", type=float, default=10.0)
    parser.add_argument("--wait-s", type=float, default=0.0)
    args = parser.parse_args()

    import msgpack
    import websockets.sync.client as ws_client

    payload = {
        "type": "hello",
        "protocol_version": PROTOCOL_VERSION,
        "robot_type": "piper",
        "fps": 30.0,
        "task": "handshake test",
        "observation_features": {},
        "action_keys": DEFAULT_ACTION_KEYS,
    }
    kwargs = {"compression": None, "max_size": None, "open_timeout": args.timeout_s}
    if "proxy" in inspect.signature(ws_client.connect).parameters:
        kwargs["proxy"] = None
    deadline = time.monotonic() + max(args.wait_s, 0.0)
    last_exc: BaseException | None = None
    while True:
        try:
            with ws_client.connect(args.address, **kwargs) as ws:
                ws.send(msgpack.packb(payload, use_bin_type=True))
                response = ws.recv(timeout=args.timeout_s)
                if isinstance(response, str):
                    raise RuntimeError(f"Server returned text frame: {response}")
                print(msgpack.unpackb(response, raw=False))
                return 0
        except Exception as exc:
            last_exc = exc
            if time.monotonic() >= deadline:
                break
            time.sleep(1.0)
    print(f"Remote handshake failed: {type(last_exc).__name__}: {last_exc}", file=sys.stderr)
    return 1


if __name__ == "__main__":
    raise SystemExit(main())
8.6. third_party/lerobot/examples/piper/piper_remote_rollout.py
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import subprocess
import sys
import tempfile
import threading
from contextlib import suppress
from pathlib import Path

from examples.piper.piper_camera_config import build_realsense_camera_arg

DEFAULT_CAN = "can0"
DEFAULT_ROBOT_ID = "piper"
DEFAULT_SPEED = 20
DEFAULT_FPS = 30
DEFAULT_DURATION_S = 300
DEFAULT_TASK = ""
DEFAULT_REMOTE_ADDRESS = ""
DEFAULT_DATASET_ROOT = ""
DEFAULT_DATASET_REPO_ID = "local"
DEFAULT_FRONT_CAMERA_SERIAL = ""
DEFAULT_WRIST_CAMERA_SERIAL = ""
DEFAULT_CAMERA_WIDTH = 1280
DEFAULT_CAMERA_HEIGHT = 720
DEFAULT_CAMERA_FPS = 30
DEFAULT_RERUN_PORT = 9876
DEFAULT_ROBOT_ACTION_PROCESSOR = "piper_rad_to_deg"
DEFAULT_ROBOT_OBSERVATION_PROCESSOR = "piper_deg_to_rad"
RECORDING_STRATEGIES = {"sentry", "highlight", "dagger"}
UPLOAD_INTERVAL_STRATEGIES = {"sentry", "dagger"}
STREAMING_ENCODING_STRATEGIES = {"sentry", "highlight"}


def _camera_args(args: argparse.Namespace) -> list[str]:
    return build_realsense_camera_arg(args)


def _passthrough_option_value(passthrough: list[str], option: str) -> str | None:
    value: str | None = None
    prefix = f"{option}="
    index = 0
    while index < len(passthrough):
        arg = passthrough[index]
        if arg.startswith(prefix):
            value = arg.split("=", 1)[1]
            index += 1
            continue
        if arg == option:
            if index + 1 >= len(passthrough) or passthrough[index + 1].startswith("--"):
                raise ValueError(f"{option} requires a value")
            value = passthrough[index + 1]
            index += 2
            continue
        index += 1
    return value


def _expand_policy_path(policy_path: str | None) -> str | None:
    if policy_path is None:
        return None
    policy_path = policy_path.strip()
    if policy_path.startswith("~"):
        return str(Path(policy_path).expanduser())
    return policy_path


def _resolve_policy_path(args: argparse.Namespace, passthrough: list[str]) -> tuple[str | None, bool]:
    wrapper_policy_path = _expand_policy_path(args.policy_path)
    passthrough_policy_path = _expand_policy_path(_passthrough_option_value(passthrough, "--policy.path"))
    if wrapper_policy_path and passthrough_policy_path and wrapper_policy_path != passthrough_policy_path:
        raise ValueError(
            "Conflicting policy paths: "
            f"--policy-path={wrapper_policy_path!r} and --policy.path={passthrough_policy_path!r}"
        )
    return wrapper_policy_path or passthrough_policy_path, wrapper_policy_path is not None


def _resolve_inference_type(args: argparse.Namespace, passthrough: list[str]) -> tuple[str, bool]:
    policy_path, _has_wrapper_policy_path = _resolve_policy_path(args, passthrough)
    passthrough_inference_type = _passthrough_option_value(passthrough, "--inference.type")
    if passthrough_inference_type:
        return passthrough_inference_type, False
    if policy_path:
        return "sync", True
    return "remote", True


def _write_task_file(task_file: Path, task: str) -> None:
    task_file.parent.mkdir(parents=True, exist_ok=True)
    tmp_file = task_file.with_name(f".{task_file.name}.tmp")
    tmp_file.write_text(f"{task.strip()}\n", encoding="utf-8")
    tmp_file.replace(task_file)


def prepare_task_file(args: argparse.Namespace) -> tuple[Path | None, bool]:
    if args.tui:
        return None, False
    if args.task_file:
        task_file = Path(args.task_file).expanduser()
        args.task_file = str(task_file)
        _write_task_file(task_file, args.task)
        return task_file, False
    if not args.interactive_task:
        return None, False
    fd, task_file_name = tempfile.mkstemp(
        prefix="lerobot_piper_remote_task_",
        suffix=".txt",
    )
    os.close(fd)
    task_file = Path(task_file_name)
    args.task_file = str(task_file)
    _write_task_file(task_file, args.task)
    return task_file, True


def _interactive_task_loop(task_file: Path, initial_task: str, stop_event: threading.Event) -> None:
    current_task = initial_task.strip()
    print(
        "Interactive task input enabled. Type a new task, /show, or /exit.",
        flush=True,
    )
    while not stop_event.is_set():
        line = sys.stdin.readline()
        if not line:
            return
        new_task = line.strip()
        if not new_task:
            continue
        if new_task == "/exit":
            print("Interactive task input disabled; rollout continues.", flush=True)
            return
        if new_task == "/show":
            with suppress(OSError):
                current_task = task_file.read_text(encoding="utf-8").strip() or current_task
            print(f"Current task: {current_task}", flush=True)
            continue
        current_task = new_task
        _write_task_file(task_file, current_task)
        print(f"Task updated: {current_task}", flush=True)


def build_command(args: argparse.Namespace, passthrough: list[str]) -> list[str]:
    dataset_root = Path(args.dataset_root).expanduser()
    records_data = args.strategy in RECORDING_STRATEGIES
    entrypoint = "lerobot-rollout-tui" if args.tui else "lerobot-rollout"
    policy_path, should_append_policy_path = _resolve_policy_path(args, passthrough)
    inference_type, should_append_inference_type = _resolve_inference_type(args, passthrough)
    is_remote_inference = inference_type == "remote"
    if not is_remote_inference and not policy_path:
        raise ValueError(
            f"--policy-path or --policy.path is required when --inference.type={inference_type}"
        )
    if not is_remote_inference and args.remote_action_joints_unit == "rad":
        raise ValueError("--remote-action-joints-unit=rad is only supported with remote inference")
    has_robot_action_processor = any(
        arg.startswith("--robot_action_processor.") or arg == "--robot_action_processor"
        for arg in passthrough
    )
    has_robot_observation_processor = any(
        arg.startswith("--robot_observation_processor.") or arg == "--robot_observation_processor"
        for arg in passthrough
    )
    cmd = [
        "uv",
        "run",
        "--extra",
        "dataset",
        "--extra",
        "remote-inference" if is_remote_inference else "pi",
        "--extra",
        "piper",
        "--extra",
        "intelrealsense",
        "--extra",
        "viz",
    ]
    if args.tui:
        cmd.extend(["--extra", "tui"])
    cmd.extend(
        [
            entrypoint,
            f"--strategy.type={args.strategy}",
            "--robot.type=piper",
            f"--robot.can_name={args.can}",
            f"--robot.speed_percent={args.speed_percent}",
            f"--robot.id={args.robot_id}",
            f"--task={args.task}",
            f"--fps={args.fps}",
        ]
    )
    if should_append_inference_type:
        cmd.append(f"--inference.type={inference_type}")
    if is_remote_inference:
        cmd.extend(
            [
                "--inference.protocol=websocket-msgpack",
                f"--inference.address={args.remote_address}",
                f"--inference.timeout_s={args.timeout_s}",
                f"--inference.max_queued_actions={args.max_queued_actions}",
            ]
        )
    elif should_append_policy_path:
        cmd.append(f"--policy.path={policy_path}")
    if not args.tui:
        cmd.append(f"--duration={args.duration}")
    if args.task_file and not args.tui:
        cmd.append(f"--task_file={Path(args.task_file).expanduser()}")
    if is_remote_inference and args.remote_action_joints_unit == "rad" and not has_robot_action_processor:
        cmd.append(f"--robot_action_processor.type={DEFAULT_ROBOT_ACTION_PROCESSOR}")
    if (
        is_remote_inference
        and args.remote_action_joints_unit == "rad"
        and not has_robot_observation_processor
    ):
        cmd.append(f"--robot_observation_processor.type={DEFAULT_ROBOT_OBSERVATION_PROCESSOR}")
    if args.strategy in UPLOAD_INTERVAL_STRATEGIES:
        cmd.append(f"--strategy.upload_every_n_episodes={args.upload_every_n_episodes}")
    if records_data:
        cmd.extend(
            [
                f"--dataset.repo_id={args.dataset_repo_id}",
                f"--dataset.root={dataset_root}",
                f"--dataset.single_task={args.task}",
                "--dataset.push_to_hub=false",
            ]
        )
        if args.strategy in STREAMING_ENCODING_STRATEGIES:
            cmd.append("--dataset.streaming_encoding=true")
        if args.encoder_threads is not None:
            cmd.append(f"--dataset.encoder_threads={args.encoder_threads}")
        if args.depth_encoder_vcodec != "auto":
            cmd.append(f"--dataset.depth_encoder.vcodec={args.depth_encoder_vcodec}")
        if args.depth_encoder_pix_fmt is not None:
            cmd.append(f"--dataset.depth_encoder.pix_fmt={args.depth_encoder_pix_fmt}")
    if args.display == "true":
        cmd.extend([f"--display_data={args.display}", f"--display_port={args.rerun_port}"])
    if args.timing_debug:
        cmd.append("--timing_debug=true")
    if args.timing_debug_path:
        cmd.append(f"--timing_debug_path={args.timing_debug_path}")
    if args.rerun_ip:
        cmd.append(f"--display_ip={args.rerun_ip}")
    if is_remote_inference and args.api_key_env:
        cmd.append(f"--inference.api_key_env={args.api_key_env}")
    if args.no_return_to_initial_position:
        cmd.append("--return_to_initial_position=false")
    cmd.extend(_camera_args(args))
    cmd.extend(passthrough)
    return cmd


def parse_args(argv: list[str] | None = None) -> tuple[argparse.Namespace, list[str]]:
    parser = argparse.ArgumentParser(
        description="Wrapper for Piper rollout with recording and Rerun visualization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--remote-address", default=DEFAULT_REMOTE_ADDRESS)
    parser.add_argument("--timeout-s", type=float, default=30.0)
    parser.add_argument("--api-key-env", default=None)
    parser.add_argument("--max-queued-actions", type=int, default=64)
    parser.add_argument(
        "--policy-path",
        default=None,
        help="Local LeRobot policy path. Enables local sync inference unless --inference.type is passed through.",
    )
    parser.add_argument("--can", default=DEFAULT_CAN)
    parser.add_argument("--robot-id", default=DEFAULT_ROBOT_ID)
    parser.add_argument("--speed-percent", type=int, default=DEFAULT_SPEED)
    parser.add_argument("--no-return-to-initial-position", action="store_true")
    parser.add_argument("--dataset-root", default=DEFAULT_DATASET_ROOT)
    parser.add_argument("--dataset-repo-id", default=DEFAULT_DATASET_REPO_ID)
    parser.add_argument("--upload-every-n-episodes", type=int, default=1000000)
    parser.add_argument("--encoder-threads", type=int, default=2)
    parser.add_argument(
        "--depth-encoder-vcodec",
        default="auto",
    )
    parser.add_argument("--depth-encoder-pix-fmt", default=None, help="Depth pixel format override.")
    parser.add_argument("--task", default=DEFAULT_TASK)
    parser.add_argument(
        "--tui",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Run through the Textual/Rich interactive rollout UI.",
    )
    parser.add_argument("--task-file", default=None, help="File watched by lerobot-rollout for task updates.")
    parser.add_argument(
        "--interactive-task",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Read task updates from stdin while rollout is running.",
    )
    parser.add_argument("--fps", type=float, default=DEFAULT_FPS)
    parser.add_argument("--duration", type=float, default=DEFAULT_DURATION_S)
    parser.add_argument(
        "--remote-action-joints-unit",
        choices=["rad", "deg"],
        default="deg",
        help=(
            "Unit of joint1..joint6 expected/returned by the remote policy. "
            "rad inserts piper_deg_to_rad for observations and piper_rad_to_deg for actions."
        ),
    )
    parser.add_argument("--rerun-port", type=int, default=DEFAULT_RERUN_PORT)
    parser.add_argument(
        "--rerun-ip",
        default=None,
        help="Connect to an existing remote Rerun server. Leave unset to spawn a local viewer.",
    )
    parser.add_argument("--front-camera-serial", default=DEFAULT_FRONT_CAMERA_SERIAL)
    parser.add_argument("--wrist-camera-serial", default=DEFAULT_WRIST_CAMERA_SERIAL)
    parser.add_argument("--camera-width", type=int, default=DEFAULT_CAMERA_WIDTH)
    parser.add_argument("--camera-height", type=int, default=DEFAULT_CAMERA_HEIGHT)
    parser.add_argument("--camera-fps", type=int, default=DEFAULT_CAMERA_FPS)
    parser.add_argument("--no-cameras", action="store_true")
    parser.add_argument("--no-depth", action="store_true", help="Disable default raw RealSense depth streams")
    parser.add_argument(
        "--display",
        type=str,
        default="false",
        choices=["true", "false"],
        help="Enable Rerun data display (true/false, default: false)",
    )
    parser.add_argument(
        "--timing-debug",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write asynchronous rollout timing diagnostics.",
    )
    parser.add_argument(
        "--timing-debug-path",
        default="/tmp/lerobot_piper_remote_rollout_timing.parquet",
        help="Path for rollout timing diagnostics when --timing-debug is enabled.",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default="base",
        choices=["base", "sentry", "highlight", "dagger"],
        help="Rollout strategy type: base, sentry, highlight, dagger",
    )
    parser.add_argument("--dry-run", action="store_true", help="Print the assembled command only.")
    # Unknown options are returned separately so they can be forwarded to lerobot-rollout.
    return parser.parse_known_args(argv)


def main() -> int:
    args, passthrough = parse_args()
    task_file, remove_task_file = prepare_task_file(args)
    cmd = build_command(args, passthrough)
    print("Command:", " ".join(cmd))
    if args.dry_run:
        if remove_task_file and task_file is not None:
            task_file.unlink(missing_ok=True)
        return 0
    stop_event = threading.Event()
    if args.interactive_task and not args.tui:
        if task_file is None:
            raise RuntimeError("interactive task mode requires a task file")
        threading.Thread(
            target=_interactive_task_loop,
            args=(task_file, args.task, stop_event),
            daemon=True,
            name="PiperTaskInput",
        ).start()
    try:
        return subprocess.call(cmd)
    finally:
        stop_event.set()
        if remove_task_file and task_file is not None:
            task_file.unlink(missing_ok=True)


if __name__ == "__main__":
    raise SystemExit(main())
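
The wrapper above relies on two standard `argparse` features: `parse_known_args`, which returns unrecognized options in a separate list so they can be forwarded unchanged, and `BooleanOptionalAction`, which generates a paired `--flag` / `--no-flag` option. The following is a minimal, self-contained sketch of that pattern, not the wrapper itself. The function names `parse_wrapper_args` and `assemble`, the default fps of 30.0, and the forwarded option `--some.forwarded_flag` are hypothetical and chosen only for illustration.

```python
import argparse
import shlex


def parse_wrapper_args(argv):
    """Split argv into options this wrapper understands and options to forward."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--fps", type=float, default=30.0)  # hypothetical default
    parser.add_argument("--task", default="pick up the cube")  # hypothetical default
    # BooleanOptionalAction (Python 3.9+) registers both --tui and --no-tui.
    parser.add_argument("--tui", action=argparse.BooleanOptionalAction, default=False)
    # parse_known_args returns (namespace, list_of_unrecognized_arguments).
    return parser.parse_known_args(argv)


def assemble(args, passthrough):
    """Build the child command; forwarded options are appended last."""
    cmd = ["lerobot-rollout", f"--fps={args.fps}", f"--task={args.task}"]
    if not args.tui:
        # Stand-in for the options the wrapper only adds outside TUI mode.
        cmd.append("--duration=60")  # hypothetical value
    cmd.extend(passthrough)
    return cmd


if __name__ == "__main__":
    args, rest = parse_wrapper_args(
        ["--fps", "15", "--no-tui", "--some.forwarded_flag=1"]
    )
    # shlex.join quotes arguments containing spaces, so the line can be pasted into a shell.
    print(shlex.join(assemble(args, rest)))
```

Because forwarded options come last in the list, they are the final arguments the child process sees. Printing with `shlex.join` rather than `" ".join` keeps a multi-word task string as a single quoted argument in the printed command.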