1 介绍
gym-fetch 是一个扩展了 OpenAI Gym Fetch 机器人环境的 Python 包，为机器人操作任务提供了丰富的环境集合。该项目在原有 Fetch 环境的基础上新增了 7 个操作任务。与 Meta-World 使用的 Sawyer 环境相比，Gym 的 Fetch 环境工程实现更好、初始化更快，且最大回合长度较短（50 步），因此在这些环境上训练速度更快。gym-fetch-v2 是本人为兼容新版 Gymnasium 等依赖包而重新优化的版本。

2 完整安装脚本
#!/bin/bash
# install_gym_fetch.sh
# Installs MuJoCo 2.1.0, mujoco-py, gymnasium and gym-fetch-v2 into a venv.
# Safe to re-run: finished steps are skipped instead of failing or duplicating.

# Abort on the first failing command instead of continuing with a broken setup.
set -e

echo "=== 开始安装 Gym-Fetch-v2 环境 ==="

# 1. System dependencies (patchelf is required to build mujoco-py).
echo "正在安装系统依赖..."
sudo apt update
sudo apt install -y \
    build-essential wget curl git cmake \
    python3-dev python3-pip python3-venv \
    libosmesa6-dev libgl1-mesa-dev libglu1-mesa-dev \
    mesa-utils freeglut3-dev libglew-dev patchelf

# 2. MuJoCo 2.1.0 binaries (the version mujoco-py 2.1.x expects).
echo "正在安装 MuJoCo 210..."
MUJOCO_DIR="$HOME/.mujoco/mujoco210"
mkdir -p "$HOME/.mujoco"
if [ ! -d "$MUJOCO_DIR" ]; then
    cd "$HOME/.mujoco"
    wget -nc https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
    # The archive already unpacks into ./mujoco210, so no rename is needed
    # (the old "mv mujoco210_linux mujoco210" always failed).
    tar -xzf mujoco210-linux-x86_64.tar.gz
fi

# 3. Environment variables: persist them for future shells (only once), and
#    export them for the rest of this script. Sourcing ~/.bashrc from a
#    non-interactive script is unreliable (Ubuntu's default ~/.bashrc returns
#    early when the shell is not interactive).
echo "设置环境变量..."
grep -qF 'MUJOCO_PY_MUJOCO_PATH' ~/.bashrc 2>/dev/null || \
    echo 'export MUJOCO_PY_MUJOCO_PATH=$HOME/.mujoco/mujoco210' >> ~/.bashrc
grep -qF '.mujoco/mujoco210/bin' ~/.bashrc 2>/dev/null || \
    echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc
export MUJOCO_PY_MUJOCO_PATH="$MUJOCO_DIR"
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$MUJOCO_DIR/bin"

# 4. Python virtual environment.
echo "创建 Python 虚拟环境..."
python3 -m venv "$HOME/envs/gym-fetch"
source "$HOME/envs/gym-fetch/bin/activate"

# 5. Python packages (mujoco-py 2.1.x needs Cython < 3).
echo "安装 Python 包..."
pip install --upgrade pip
pip install numpy scipy cython==0.29.36
pip install mujoco-py==2.1.2.14
pip install gymnasium gymnasium-robotics

# 6. gym-fetch-v2 (clone only if not already present, then editable install).
echo "安装 gym-fetch-v2..."
cd "$HOME"
if [ ! -d gym-fetch-v2 ]; then
    git clone https://github.com/cityu-lm/gym-fetch-v2.git
fi
cd gym-fetch-v2
pip install -e .

echo "=== 安装完成! ==="
echo "激活环境: source ~/envs/gym-fetch/bin/activate"
echo "测试安装: cd ~/gym-fetch-v2 && python examples/demo.py"
测试安装
python
import os

import gymnasium as gym

import fetch  # noqa: F401  (importing this module registers the gym-fetch environments)


def main():
    """Smoke-test the gym-fetch-v2 environments.

    For each environment: create it, reset it, take up to 10 random steps
    (stopping early when the episode ends), render after every step, and
    close it. Errors are printed per environment so one failure does not stop
    the remaining checks.
    """
    # Primitive single-task environments to check.
    env_names = [
        'Bin-pick-v2',
        'Bin-place-v2',
        'Box-open-v2',
        'Box-close-v2',
        'Drawer-open-v2',
        'Drawer-close-v2',
    ]
    # Use a window when an X display is available, otherwise rgb_array.
    render_mode = 'human' if os.environ.get('DISPLAY') else 'rgb_array'
    for env_name in env_names:
        print(f"Testing environment: {env_name}")
        try:
            env = gym.make(env_name)
            try:
                obs, info = env.reset()
                print(f"Environment {env_name} reset successfully.")
                # Run a few steps with random actions
                for step in range(10):
                    action = env.action_space.sample()
                    obs, reward, terminated, truncated, info = env.step(action)
                    # Some envs accept render(mode=...); when the signature
                    # rejects the keyword (TypeError), call render() instead.
                    try:
                        env.render(mode=render_mode)
                    except TypeError:
                        env.render()
                    if terminated or truncated:
                        print(f"Episode done at step {step}")
                        break
            finally:
                # Always release the simulator/viewer, even if a step failed.
                env.close()
            print(f"Environment {env_name} closed successfully.\n")
        except Exception as e:
            print(f"Error with environment {env_name}: {e}\n")


if __name__ == "__main__":
    main()
3 数据采集程序
执行
bash
# Use the interpreter of the environment created by the install script
# (first run: source ~/envs/gym-fetch/bin/activate), or replace "python" with
# the absolute path of your own environment's interpreter.
# --ignore-env-done is already the default; it is listed here for clarity.
python collect_fetch_demo.py \
    --env Bin-pick-v2 \
    --out demos/bin_pick \
    --max-steps 6000 \
    --control-hz 20 \
    --stop-on-success \
    --ignore-env-done \
    --verbose
# Only successful demos are saved; a new attempt then starts.

collect_fetch_demo.py
python
#!/usr/bin/env python3
"""Collect human demonstrations in gym-fetch environments using an Xbox controller.
Constraints (per request):
- This script is independent of robosuite (does not import robosuite).
- Works with gymnasium + fetch (gym-fetch env registration).
- Saves demonstrations under an output directory (npz by default, hdf5 if h5py exists).
Usage examples:
- Xbox (human):
/home/ubuntu2004/miniconda3/envs/gym-robot/bin/python collect_fetch_demo.py --env Bin-pick-v2 --out demos/bin_pick
- Headless scripted (for testing):
/home/ubuntu2004/miniconda3/envs/gym-robot/bin/python collect_fetch_demo.py --env Bin-pick-v2 --device scripted --episodes 1 --max-steps 5 --render none
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from typing import Any, Dict, Optional
import gymnasium as gym
import numpy as np
import fetch # noqa: F401 (register gym-fetch envs)
from utils import (
FETCH_ENV_IDS,
ScriptedController,
XboxController,
now_utc_iso,
rollout_episode,
try_save_demo,
)
def _force_time_limit_max_steps(env: Any, max_episode_steps: Optional[int]) -> Any:
"""Best-effort override of TimeLimit max steps.
Why: some env registrations set a default TimeLimit (e.g., 50 steps ~= 2s @ 20Hz).
If the override is not applied, the episode will truncate quickly and appear to
"auto reset".
Gymnasium supports passing max_episode_steps into gym.make(), but we also patch
the wrapper directly to be robust across versions/wrappers.
"""
if max_episode_steps is None:
return env
target = int(max_episode_steps)
if target <= 0:
return env
# Walk wrapper chain to find TimeLimit.
cur = env
for _ in range(64):
if cur.__class__.__name__ == "TimeLimit":
try:
setattr(cur, "_max_episode_steps", target)
except Exception:
pass
# Keep spec consistent if present.
try:
spec = getattr(cur, "spec", None)
if spec is not None:
spec.max_episode_steps = target
except Exception:
pass
return env
if hasattr(cur, "env"):
cur = getattr(cur, "env")
else:
break
# If no TimeLimit wrapper exists, wrap it (gymnasium path).
try:
from gymnasium.wrappers import TimeLimit
return TimeLimit(env, max_episode_steps=target)
except Exception:
return env
def _make_env(env_id: str, render_mode: Optional[str], max_episode_steps: Optional[int]) -> Any:
    """Create the gymnasium environment *env_id*.

    ``render_mode`` is accepted for interface symmetry but deliberately NOT
    forwarded. gym-fetch constructors often reject it, and rendering is done
    through explicit ``env.render()`` calls instead.

    ``max_episode_steps`` is forwarded when given, so the registration's
    default TimeLimit does not truncate episodes early. If ``gym.make`` raises
    TypeError (e.g. an older version rejecting the keyword), the env is created
    again without any extra keyword arguments.
    """
    make_kwargs: Dict[str, Any] = (
        {} if max_episode_steps is None else {"max_episode_steps": int(max_episode_steps)}
    )
    try:
        env = gym.make(env_id, **make_kwargs)
    except TypeError:
        env = gym.make(env_id)
    return env
def _parse_args() -> argparse.Namespace:
    """Define and parse the collector's command-line interface.

    Boolean pairs: ``--ignore-env-done`` is already on by default (store_true
    with default=True); ``--respect-env-done`` is the switch that turns it off.
    Likewise ``--invert-y`` defaults to on and ``--no-invert-y`` disables it.
    """
    parser = argparse.ArgumentParser()

    # Task, output and episode control.
    parser.add_argument("--env", type=str, default=FETCH_ENV_IDS[0], choices=list(FETCH_ENV_IDS))
    parser.add_argument("--out", type=str, default="demos")
    parser.add_argument("--episodes", type=int, default=0, help="0 means run until user quits")
    parser.add_argument("--max-steps", type=int, default=500)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--control-hz", type=float, default=20.0)
    parser.add_argument(
        "--success-hold-steps",
        type=int,
        default=10,
        help="End episode after is_success==1 for N consecutive steps.",
    )
    parser.add_argument(
        "--stop-on-success",
        action="store_true",
        default=False,
        help="If set, end the current demo when success is held for N steps (default off to avoid premature endings).",
    )
    parser.add_argument(
        "--ignore-env-done",
        action="store_true",
        default=True,
        help="If True, ignore env terminated/truncated (e.g., TimeLimit) and keep collecting in the same demo.",
    )
    parser.add_argument(
        "--respect-env-done",
        action="store_false",
        dest="ignore_env_done",
        help="If set, stop the demo when env signals done (terminated/truncated).",
    )
    parser.add_argument(
        "--render",
        type=str,
        default="auto",
        choices=["auto", "human", "rgb_array", "none"],
        help="Rendering mode. auto uses human if DISPLAY else rgb_array.",
    )
    parser.add_argument("--device", type=str, default="xbox", choices=["xbox", "scripted"])
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="Print why an attempt ended (TimeLimit/user reset/quit/success).",
    )

    # Xbox mapping and scales.
    parser.add_argument("--deadzone", type=float, default=0.10)
    parser.add_argument("--pos-scale", type=float, default=1.0)
    parser.add_argument("--invert-y", action="store_true", default=True)
    parser.add_argument("--no-invert-y", action="store_false", dest="invert_y")
    parser.add_argument("--gripper-open", type=float, default=1.0)
    parser.add_argument("--gripper-close", type=float, default=-1.0)

    # Axis / button indices (OS-dependent); all plain integer options.
    index_defaults = (
        ("--axis-lx", 0),
        ("--axis-ly", 1),
        ("--axis-lt", 2),
        ("--axis-rt", 5),
        ("--button-x", 2),
        ("--button-y", 3),
        ("--button-start", 7),
    )
    for flag, default_index in index_defaults:
        parser.add_argument(flag, type=int, default=default_index)

    # Scripted controller support (for tests / automation).
    parser.add_argument(
        "--scripted-json",
        type=str,
        default=None,
        help=(
            "Path to JSON containing a list of events: "
            "[{action:[4 floats], reset:false, quit:false}, ...]."
        ),
    )
    return parser.parse_args()
def _resolve_render_mode(arg: str) -> Optional[str]:
if arg == "none":
return None
if arg == "auto":
return "human" if os.environ.get("DISPLAY") else "rgb_array"
return arg
def _build_metadata(args: argparse.Namespace, env: Any) -> Dict[str, Any]:
    """Collect run-level metadata stored alongside each saved demo.

    ``env_id`` prefers ``env.spec.id`` and falls back to ``args.env`` when the
    spec is unavailable or its id is empty. Spaces are recorded as their
    ``str()`` form, and ``argv`` is ``sys.argv``.
    """
    try:
        resolved_id = env.spec.id  # type: ignore[attr-defined]
    except Exception:
        resolved_id = None
    return dict(
        created_utc=now_utc_iso(),
        env_id=resolved_id or args.env,
        seed=args.seed,
        max_steps=args.max_steps,
        control_hz=args.control_hz,
        device=args.device,
        action_space=str(env.action_space),
        observation_space=str(env.observation_space),
        argv=sys.argv,
    )
def _make_controller(args: argparse.Namespace) -> Any:
    """Build the input controller selected by ``--device`` (not yet started).

    - ``xbox``: an XboxController configured from the deadzone, scale, gripper
      and axis/button-index arguments.
    - ``scripted``: a ScriptedController replaying events from
      ``--scripted-json`` when given. Each item may carry ``action`` (default
      four zeros), ``reset`` and ``quit``. Without a file, the default pattern
      is 10 zero-action steps followed by a reset request.

    Raises:
        ValueError: if the scripted JSON file does not contain a list.
    """
    if args.device == "xbox":
        return XboxController(
            deadzone=args.deadzone,
            pos_scale=args.pos_scale,
            invert_y=args.invert_y,
            gripper_open_value=args.gripper_open,
            gripper_close_value=args.gripper_close,
            axis_lx=args.axis_lx,
            axis_ly=args.axis_ly,
            axis_lt=args.axis_lt,
            axis_rt=args.axis_rt,
            button_x=args.button_x,
            button_y=args.button_y,
            button_start=args.button_start,
        )
    # ControllerEvent is only needed for the scripted device; import it once
    # here rather than once per JSON item.
    from utils import ControllerEvent

    events = []
    if args.scripted_json:
        with open(args.scripted_json, "r", encoding="utf-8") as f:
            raw = json.load(f)
        if not isinstance(raw, list):
            raise ValueError(
                f"--scripted-json must contain a JSON list of events, got {type(raw).__name__}"
            )
        for item in raw:
            action = np.asarray(item.get("action", [0, 0, 0, 0]), dtype=np.float32)
            events.append(
                ControllerEvent(
                    action=action,
                    reset_episode=bool(item.get("reset", False)),
                    quit=bool(item.get("quit", False)),
                )
            )
    else:
        # Default scripted pattern: 10 neutral steps, then a reset request.
        for _ in range(10):
            events.append(ControllerEvent(action=np.zeros(4, dtype=np.float32)))
        events.append(ControllerEvent(action=np.zeros(4, dtype=np.float32), reset_episode=True))
    return ScriptedController(events=events)
def _time_limit_steps(env: Any) -> Optional[int]:
    """Return ``_max_episode_steps`` of the first TimeLimit found in *env*'s wrapper chain.

    Follows ``.env`` links from *env* for at most 64 levels. Returns None when
    no class named ``TimeLimit`` is encountered.
    """
    cur = env
    for _ in range(64):
        if cur.__class__.__name__ == "TimeLimit":
            return getattr(cur, "_max_episode_steps", None)
        if not hasattr(cur, "env"):
            return None
        cur = getattr(cur, "env")
    return None


def main() -> int:
    """Run the demo-collection loop.

    Creates the environment with its TimeLimit raised to ``--max-steps``,
    starts the selected controller, and repeatedly calls ``rollout_episode``.
    Only attempts flagged ``successful`` are saved (via ``try_save_demo``) and
    counted towards ``--episodes`` (0 or less means unlimited). The loop also
    ends when the controller requests quit. The controller and the environment
    are closed on every exit path, including when starting the controller fails.

    Returns:
        Process exit status: 0 on normal completion (exceptions propagate).
    """
    args = _parse_args()
    render_mode = _resolve_render_mode(args.render)
    # Override TimeLimit max_episode_steps to prevent early truncation.
    env = _make_env(args.env, render_mode, max_episode_steps=args.max_steps)
    env = _force_time_limit_max_steps(env, args.max_steps)
    controller: Any = None
    try:
        if args.verbose:
            # Print the TimeLimit config to diagnose premature truncation
            # (registrations often default to 50 steps).
            tl_steps = _time_limit_steps(env)
            try:
                spec_steps = getattr(getattr(env, "spec", None), "max_episode_steps", None)
            except Exception:
                spec_steps = None
            print(f"TimeLimit: _max_episode_steps={tl_steps} spec.max_episode_steps={spec_steps}")
        # Seed if supported
        try:
            env.reset(seed=args.seed)
        except Exception:
            pass
        # Created and started inside the try so a failure here (e.g. no
        # joystick connected) still closes the environment below.
        controller = _make_controller(args)
        controller.start()
        os.makedirs(args.out, exist_ok=True)
        episode_idx = 0
        while args.episodes <= 0 or episode_idx < args.episodes:
            demo = rollout_episode(
                env,
                controller,
                max_steps=args.max_steps,
                control_hz=args.control_hz,
                render=(render_mode == "human"),
                stop_on_success=args.stop_on_success,
                success_hold_steps=args.success_hold_steps,
                continue_on_done_until_success=args.ignore_env_done,
            )
            quit_requested = bool(demo.get("quit", False))
            if args.verbose:
                ended_by = "unknown"
                if quit_requested:
                    ended_by = "quit"
                elif demo.get("reset_by_user", False):
                    ended_by = "reset_by_user"
                elif demo.get("successful", False):
                    ended_by = "success"
                elif demo.get("done_before_success", False):
                    ended_by = "env_done_before_success"
                print(
                    "Attempt summary:",
                    f"T={len(demo.get('action', []))}",
                    f"done_resets={demo.get('done_resets', 0)}",
                    f"ended_by={ended_by}",
                )
            if demo.get("successful"):
                meta = _build_metadata(args, env)
                meta["episode_index"] = episode_idx
                meta["saved_utc"] = now_utc_iso()
                meta["quit"] = quit_requested
                meta["reset_by_user"] = bool(demo.get("reset_by_user", False))
                meta["successful"] = True
                meta["done_before_success"] = bool(demo.get("done_before_success", False))
                base = os.path.join(args.out, f"{args.env}_ep{episode_idx:04d}_{int(time.time())}")
                saved_path = try_save_demo(base, demo, meta)
                print(f"Saved demo: {saved_path}")
                episode_idx += 1
            elif not quit_requested:
                print("Demo not successful, skipping save and continuing to next attempt.")
            if quit_requested:
                break
    finally:
        if controller is not None:
            try:
                controller.close()
            except Exception:
                pass
        try:
            env.close()
        except Exception:
            pass
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
utils.py
python
"""Utilities for collecting gym-fetch demonstrations.
This module is intentionally independent of robosuite.
It provides:
- Controller interfaces (Xbox via pygame, plus scripted controller for testing)
- Rollout / recording helpers
- Demo serialization (npz by default; hdf5 if h5py is available)
"""
from __future__ import annotations
import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
# Environment ids the demo collector offers (order matters: the first entry is
# the collector's default --env). All share the "-v2" suffix.
FETCH_ENV_IDS: Tuple[str, ...] = tuple(
    f"{task}-v2"
    for task in (
        "Bin-pick",
        "Bin-place",
        "Box-open",
        "Box-close",
        "Drawer-open",
        "Drawer-close",
    )
)
def now_utc_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", utc_now)
def ensure_dir(path: str) -> str:
    """Create directory *path* (including missing parents) unless it already exists.

    Returns *path* unchanged so the call can be used inline.
    """
    os.makedirs(name=path, exist_ok=True)
    return path
def clamp(x: float, lo: float, hi: float) -> float:
    """Limit *x* to the interval [lo, hi].

    The lower bound is checked first. *x* itself is returned when it lies
    inside the interval.
    """
    return lo if x < lo else (hi if x > hi else x)
def apply_deadzone(x: float, deadzone: float) -> float:
    """Return 0.0 when ``|x|`` is strictly below *deadzone*; otherwise *x* unchanged (no rescaling)."""
    return 0.0 if abs(x) < deadzone else x
def rate_limit_sleep(start_time_s: float, target_hz: Optional[float]) -> None:
    """Sleep for whatever remains of one control period that began at *start_time_s*.

    *start_time_s* is a ``time.time()`` timestamp. Nothing happens when
    *target_hz* is None, zero or negative, or when the period has already
    elapsed.
    """
    if target_hz and target_hz > 0:
        period_s = 1.0 / float(target_hz)
        leftover_s = period_s - (time.time() - start_time_s)
        if leftover_s > 0:
            time.sleep(leftover_s)
@dataclass
class ControllerEvent:
    """A single reading from a controller.

    Fields:
        action: action to pass to ``env.step`` (rollout_episode converts it to
            a float32 array). Controllers in this module produce 4-element
            actions.
        reset_episode: when True, rollout_episode applies this action and then
            ends the current attempt.
        quit: when True, rollout_episode ends the attempt before applying the
            action and marks it as quit.
    """
    action: np.ndarray
    reset_episode: bool = False
    quit: bool = False
class BaseController:
    """Interface shared by all input controllers.

    Subclasses must implement :meth:`read`, returning a ControllerEvent whose
    ``action`` matches the environment's action-space shape. :meth:`start` and
    :meth:`close` are optional lifecycle hooks that do nothing by default.
    """

    def start(self) -> None:
        """Acquire device resources before the first read(); no-op by default."""
        return None

    def close(self) -> None:
        """Release resources acquired by start(); no-op by default."""
        return None

    def read(self) -> ControllerEvent:
        """Poll the device once and return the resulting event (subclasses must override)."""
        raise NotImplementedError
class ScriptedController(BaseController):
    """Deterministic controller that replays a fixed sequence of events (for tests).

    ``events`` is an iterable of ControllerEvent objects, returned one per
    read() in order. Once they are exhausted, every read() returns a quit
    event. Its action is a float32 copy of ``default_action`` (expected to be
    an ndarray), or a zero 4-vector when no default action was given.
    """

    def __init__(
        self,
        events: Iterable[ControllerEvent],
        default_action: Optional[np.ndarray] = None,
    ) -> None:
        self._events = list(events)
        self._index = 0
        self._default_action = default_action

    def read(self) -> ControllerEvent:
        cursor = self._index
        if cursor >= len(self._events):
            fallback = (
                np.zeros(4, dtype=np.float32)
                if self._default_action is None
                else self._default_action.astype(np.float32)
            )
            return ControllerEvent(action=fallback, quit=True)
        self._index = cursor + 1
        return self._events[cursor]
class XboxController(BaseController):
    """Xbox controller mapping for gym-fetch (action dim 4).

    Mapping (default):
    - Left stick: x/y translation (dx, dy)
    - RT/LT triggers: z translation (dz = RT - LT)
    - X button: toggle gripper open/close
    - Y button: end current episode (reset flag)
    - START button: quit

    Buttons are edge-triggered: each fires once per press (0 -> 1 transition),
    not on every read while held.

    Notes:
    - pygame is imported lazily in start(), so this module can be imported
      (e.g. by tests) without pygame installed.
    - Axis / button indices vary across OS / controller; the caller exposes
      them as arguments.
    - Depending on the driver, trigger axes report either [0, 1] (rest = 0) or
      [-1, 1] (rest = -1). An axis is treated as [-1, 1] from the first time it
      reports a negative value (see _read_trigger).
    """

    def __init__(
        self,
        *,
        deadzone: float = 0.10,
        pos_scale: float = 1.0,
        invert_y: bool = True,
        gripper_open_value: float = 1.0,
        gripper_close_value: float = -1.0,
        axis_lx: int = 0,
        axis_ly: int = 1,
        axis_lt: int = 2,
        axis_rt: int = 5,
        button_x: int = 2,
        button_y: int = 3,
        button_start: int = 7,
    ) -> None:
        self.deadzone = float(deadzone)
        self.pos_scale = float(pos_scale)
        self.invert_y = bool(invert_y)
        self.gripper_open_value = float(gripper_open_value)
        self.gripper_close_value = float(gripper_close_value)
        self.axis_lx = int(axis_lx)
        self.axis_ly = int(axis_ly)
        self.axis_lt = int(axis_lt)
        self.axis_rt = int(axis_rt)
        self.button_x = int(button_x)
        self.button_y = int(button_y)
        self.button_start = int(button_start)
        self._pygame = None
        self._joystick = None
        self._gripper_closed = False
        # Previous button states, used for press (0 -> 1) edge detection.
        self._prev_x = 0
        self._prev_y = 0
        self._prev_start = 0
        # Indices of trigger axes that have reported a negative value, i.e.
        # whose range is [-1, 1] rather than [0, 1].
        self._signed_triggers = set()

    def start(self) -> None:
        """Initialise pygame and open joystick 0.

        Raises:
            RuntimeError: if pygame is not installed or no joystick is connected.
        """
        try:
            import pygame  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError(
                "pygame is required for XboxController; install it or use --device scripted"
            ) from e
        self._pygame = pygame
        pygame.init()
        pygame.joystick.init()
        if pygame.joystick.get_count() <= 0:
            raise RuntimeError("No joystick detected. Connect an Xbox controller.")
        self._joystick = pygame.joystick.Joystick(0)
        self._joystick.init()

    def close(self) -> None:
        """Shut down pygame's joystick module and pygame itself; errors are ignored."""
        if self._pygame is not None:  # pragma: no cover
            try:
                self._pygame.joystick.quit()
                self._pygame.quit()
            except Exception:
                pass

    def _get_axis(self, idx: int) -> float:
        """Read stick axis *idx* with the deadzone applied."""
        assert self._joystick is not None
        x = float(self._joystick.get_axis(idx))
        return apply_deadzone(x, self.deadzone)

    def _get_button(self, idx: int) -> int:
        """Read button *idx* as an int (pressed -> 1)."""
        assert self._joystick is not None
        return int(self._joystick.get_button(idx))

    def _read_trigger(self, idx: int) -> float:
        """Read trigger axis *idx* normalised to [0, 1] (0 = released).

        The old per-sample rule ("map with (raw + 1) / 2 only if raw < 0") made
        a [-1, 1] trigger jump from about 0.5 back to 0 as it crossed the
        midpoint. Instead, the first negative reading marks the axis as signed,
        and from then on it is always mapped with (raw + 1) / 2.
        """
        assert self._joystick is not None
        raw = float(self._joystick.get_axis(idx))
        if raw < 0.0:
            self._signed_triggers.add(idx)
        if idx in self._signed_triggers:
            return clamp((raw + 1.0) / 2.0, 0.0, 1.0)
        return clamp(raw, 0.0, 1.0)

    def read(self) -> ControllerEvent:
        """Poll the controller once and translate it into a ControllerEvent.

        The action is ``[dx, dy, dz, gripper]``. dx/dy/dz are multiplied by
        ``pos_scale`` and every component is clamped to [-1, 1].

        Raises:
            RuntimeError: if start() has not been called.
        """
        if self._pygame is None or self._joystick is None:
            raise RuntimeError("XboxController.start() must be called before read().")
        self._pygame.event.pump()
        lx = self._get_axis(self.axis_lx)
        ly = self._get_axis(self.axis_ly)
        if self.invert_y:
            ly = -ly
        lt = self._read_trigger(self.axis_lt)
        rt = self._read_trigger(self.axis_rt)
        dz = clamp(rt - lt, -1.0, 1.0)
        # Edge-detect buttons
        curr_x = self._get_button(self.button_x)
        if curr_x == 1 and self._prev_x == 0:
            self._gripper_closed = not self._gripper_closed
        self._prev_x = curr_x
        curr_y = self._get_button(self.button_y)
        reset_episode = bool(curr_y == 1 and self._prev_y == 0)
        self._prev_y = curr_y
        curr_start = self._get_button(self.button_start)
        quit_flag = bool(curr_start == 1 and self._prev_start == 0)
        self._prev_start = curr_start
        gripper = self.gripper_close_value if self._gripper_closed else self.gripper_open_value
        action = np.array(
            [
                clamp(lx * self.pos_scale, -1.0, 1.0),
                clamp(ly * self.pos_scale, -1.0, 1.0),
                clamp(dz * self.pos_scale, -1.0, 1.0),
                clamp(gripper, -1.0, 1.0),
            ],
            dtype=np.float32,
        )
        return ControllerEvent(action=action, reset_episode=reset_episode, quit=quit_flag)
def _obs_to_numpy(obs: Any) -> Dict[str, np.ndarray]:
if not isinstance(obs, dict):
return {"observation": np.asarray(obs)}
out: Dict[str, np.ndarray] = {}
for k, v in obs.items():
out[str(k)] = np.asarray(v)
return out
def extract_is_success(info: Any) -> Optional[bool]:
    """Extract the success signal from a Gymnasium info dict.

    Goal-based robotics envs commonly report ``info["is_success"]`` as a
    float/bool (1/0). Numeric values count as success when ``>= 1.0``.
    Array-likes (numpy arrays, lists, tuples) are judged by their first
    element. Previously a list such as ``[0.0]`` fell through to
    ``bool([0.0])`` and was reported as success.

    Returns:
        None if *info* is not a dict, the key is absent, or the value is an
        empty array-like. Otherwise the success flag, falling back to
        ``bool(value)`` for non-numeric values, or None if even that fails.
    """
    if not isinstance(info, dict) or "is_success" not in info:
        return None
    v = info.get("is_success")
    try:
        if isinstance(v, (np.ndarray, list, tuple)):
            flat = np.asarray(v, dtype=np.float64).reshape(-1)
            if flat.size == 0:
                return None
            v = flat[0]
        return bool(float(v) >= 1.0)
    except Exception:
        try:
            return bool(v)
        except Exception:
            return None
def rollout_episode(
    env: Any,
    controller: BaseController,
    *,
    max_steps: int,
    control_hz: Optional[float],
    render: bool,
    stop_on_success: bool = False,
    success_hold_steps: int = 10,
    continue_on_done_until_success: bool = True,
) -> Dict[str, Any]:
    """Roll out one demonstration attempt and record it.

    Each iteration reads one ControllerEvent, steps the env with its action,
    records the transition, optionally renders, and sleeps to hold
    ``control_hz``. The attempt ends when:

    - the controller requests quit (before the action is applied): ``quit``;
    - the controller requests a reset (after the action is applied):
      ``reset_by_user``;
    - success has been held and ``stop_on_success`` is set;
    - the env signals terminated/truncated, and either
      ``continue_on_done_until_success`` is False or success was already
      reached. ``done_before_success`` is set only in the former case;
    - ``max_steps`` actions have been taken.

    When the env signals done before success and
    ``continue_on_done_until_success`` is True, the env is reset and recording
    continues in a new segment. ``done_resets`` counts these resets;
    ``segment_starts`` holds the action index at which each segment begins.

    Success means ``extract_is_success(info)`` is True for
    ``success_hold_steps`` consecutive steps (minimum 1); a False reading
    restarts the count. Success is tracked even when ``stop_on_success`` is
    off. Before this fix, ``successful`` was only ever set when
    ``stop_on_success`` was on, so callers that save only successful attempts
    saved nothing with the default settings.

    Returns:
        A dict with:
        - ``obs``: name -> array, one entry per reset plus one per step;
        - ``action``: float32 array stacked over steps;
        - ``reward``: float32 array;
        - ``terminated`` / ``truncated``: bool arrays;
        - ``info``: list of info dicts, aligned with obs;
        - ``segment_starts``: list of int;
        - optional flags ``quit``, ``reset_by_user``, ``successful``,
          ``done_before_success`` and ``done_resets``.
    """
    obs, info = env.reset()
    obs_np = _obs_to_numpy(obs)
    traj: Dict[str, Any] = {
        "obs": {k: [v] for k, v in obs_np.items()},
        "action": [],
        "reward": [],
        "terminated": [],
        "truncated": [],
        "info": [info],
        "segment_starts": [0],
    }
    # Success latch: require `hold` consecutive success timesteps.
    hold = max(1, int(success_hold_steps))
    remaining_success = hold
    for _step in range(int(max_steps)):
        step_start = time.time()
        ev = controller.read()
        if ev.quit:
            traj["quit"] = True
            break
        action = np.asarray(ev.action, dtype=np.float32)
        traj["action"].append(action)
        obs, reward, terminated, truncated, info = env.step(action)
        for k, v in _obs_to_numpy(obs).items():
            traj["obs"].setdefault(k, []).append(v)
        traj["reward"].append(float(reward))
        traj["terminated"].append(bool(terminated))
        traj["truncated"].append(bool(truncated))
        traj["info"].append(info)
        if render:
            try:
                env.render()
            except TypeError:
                env.render(mode="human")
        if ev.reset_episode:
            traj["reset_by_user"] = True
            break
        # Always track success so the caller can tell whether the attempt
        # succeeded; only end early on success when stop_on_success is set.
        if not traj.get("successful", False):
            is_success = extract_is_success(info)
            if is_success is True:
                remaining_success -= 1
                if remaining_success <= 0:
                    traj["successful"] = True
            elif is_success is False:
                remaining_success = hold
        if stop_on_success and traj.get("successful", False):
            break
        # Env signalled done (often TimeLimit): reset and keep recording only
        # while success has not been reached yet.
        if bool(terminated) or bool(truncated):
            if continue_on_done_until_success and not traj.get("successful", False):
                traj["done_resets"] = int(traj.get("done_resets", 0)) + 1
                obs, info = env.reset()
                for k, v in _obs_to_numpy(obs).items():
                    traj["obs"].setdefault(k, []).append(v)
                traj["info"].append(info)
                traj["segment_starts"].append(len(traj["action"]))
                remaining_success = hold
            else:
                if not traj.get("successful", False):
                    traj["done_before_success"] = True
                break
        # Also runs after a mid-attempt reset, keeping the control rate steady.
        rate_limit_sleep(step_start, control_hz)
    # Convert lists to arrays where appropriate
    traj["action"] = np.asarray(traj["action"], dtype=np.float32)
    traj["reward"] = np.asarray(traj["reward"], dtype=np.float32)
    traj["terminated"] = np.asarray(traj["terminated"], dtype=bool)
    traj["truncated"] = np.asarray(traj["truncated"], dtype=bool)
    for k in list(traj["obs"].keys()):
        traj["obs"][k] = np.asarray(traj["obs"][k])
    return traj
def save_demo_npz(path: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Save *demo* as a compressed ``.npz`` archive.

    Archive keys:
    - ``metadata_json`` / ``info_json``: JSON strings. Values that are not
      JSON-serialisable are stringified.
    - ``action``, ``reward``, ``terminated``, ``truncated``: the demo arrays.
    - ``segment_starts``: int64 action indices at which each env segment
      starts, written when present in *demo*. Without it, observations, which
      gain an extra entry at every mid-attempt reset, cannot be aligned with
      actions.
    - ``obs__<key>``: one entry per observation key.

    Returns:
        The path of the written file. numpy appends ``.npz`` to paths that lack
        it, and the returned path reflects that.
    """
    ensure_dir(os.path.dirname(path) or ".")
    # info dicts are not np arrays; store them (and metadata) as JSON strings.
    packed: Dict[str, Any] = {
        "metadata_json": json.dumps(metadata, ensure_ascii=False, default=str),
        "info_json": json.dumps(demo.get("info", []), ensure_ascii=False, default=str),
        "action": demo["action"],
        "reward": demo["reward"],
        "terminated": demo["terminated"],
        "truncated": demo["truncated"],
    }
    if "segment_starts" in demo:
        packed["segment_starts"] = np.asarray(demo["segment_starts"], dtype=np.int64)
    for k, v in demo["obs"].items():
        packed[f"obs__{k}"] = v
    np.savez_compressed(path, **packed)
    return path if path.endswith(".npz") else path + ".npz"
def save_demo_hdf5(path: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Save *demo* to an HDF5 file at *path* (requires h5py).

    Layout:
    - gzip-compressed datasets ``action``, ``reward``, ``terminated``,
      ``truncated`` and ``obs/<key>``;
    - ``segment_starts`` (int64, written when present in *demo*), needed to
      align observations with actions across mid-attempt resets;
    - file attributes ``metadata_json`` and ``info_json`` (JSON strings; values
      that are not JSON-serialisable are stringified).

    If writing ``info_json`` as an attribute fails, it is stored as a scalar
    string dataset ``info_json`` instead. HDF5 attributes are meant to be small
    (roughly < 64 KiB), and long demos can exceed that.

    Returns:
        *path*.

    Raises:
        RuntimeError: if h5py is not installed.
    """
    try:
        import h5py  # type: ignore
    except Exception as e:
        raise RuntimeError("h5py is not installed; use save_demo_npz instead") from e
    ensure_dir(os.path.dirname(path) or ".")
    # Serialise before opening the file so a JSON error cannot leave a partial file.
    metadata_json = json.dumps(metadata, ensure_ascii=False, default=str)
    info_json = json.dumps(demo.get("info", []), ensure_ascii=False, default=str)
    with h5py.File(path, "w") as f:
        f.attrs["metadata_json"] = metadata_json
        for name in ("action", "reward", "terminated", "truncated"):
            f.create_dataset(name, data=demo[name], compression="gzip")
        if "segment_starts" in demo:
            f.create_dataset(
                "segment_starts",
                data=np.asarray(demo["segment_starts"], dtype=np.int64),
                compression="gzip",
            )
        obs_grp = f.create_group("obs")
        for k, v in demo["obs"].items():
            obs_grp.create_dataset(k, data=v, compression="gzip")
        try:
            f.attrs["info_json"] = info_json
        except (RuntimeError, OSError, ValueError):
            # NOTE(review): exact exception type for oversized attributes
            # varies across h5py versions; these cover the known ones.
            f.create_dataset("info_json", data=info_json)
    return path
def try_save_demo(path_base: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Save *demo* as ``<path_base>.hdf5`` if possible, otherwise ``<path_base>.npz``.

    Any exception from the HDF5 writer (h5py missing, or a write error)
    triggers the npz fallback. A partially written ``.hdf5`` file created by the
    failed attempt is removed first, so only one complete file remains.

    Returns:
        The path of the file actually written.
    """
    h5_path = path_base + ".hdf5"
    existed_before = os.path.exists(h5_path)
    try:
        return save_demo_hdf5(h5_path, demo, metadata)
    except Exception:
        # Never delete a file that was there before this call.
        if not existed_before and os.path.exists(h5_path):
            try:
                os.remove(h5_path)
            except OSError:
                pass  # best-effort cleanup; the npz save below still proceeds
        return save_demo_npz(path_base + ".npz", demo, metadata)