在 Ubuntu 系统上安装 Gym-Fetch-v2，并通过 Xbox 游戏手柄控制机械臂采集演示数据

1 介绍

gym-fetch 是一个扩展了 OpenAI Gym Fetch 机器人环境的 Python 包，为机器人操作任务提供了丰富的环境集合。该项目在原有 Fetch 环境的基础上新增了 7 个操作任务。与 Meta-World 使用的 Sawyer 环境相比，Gym 的 Fetch 环境工程实现更好、初始化更快，并且最大回合长度较短（50 步），因此在这些环境上训练速度更快。gym-fetch-v2 是本人为兼容新版 Gymnasium 等依赖包而重新适配和优化的版本。

2 完整安装脚本

安装脚本（bash）：
#!/bin/bash
# install_gym_fetch.sh
#
# Installs system packages, MuJoCo 2.1.0 (~/.mujoco/mujoco210), a Python venv
# (~/envs/gym-fetch) with mujoco-py / gymnasium / pygame, and gym-fetch-v2
# (cloned to ~/gym-fetch-v2). Re-running is safe: an existing MuJoCo install,
# an existing clone and existing ~/.bashrc entries are kept as they are.

# Stop at the first failing command instead of continuing on a broken install.
set -eo pipefail

MUJOCO_ROOT="$HOME/.mujoco"
MUJOCO_DIR="$MUJOCO_ROOT/mujoco210"
MUJOCO_TAR="mujoco210-linux-x86_64.tar.gz"

echo "=== 开始安装 Gym-Fetch-v2 环境 ==="

# 1. System dependencies. patchelf is needed when mujoco-py builds its extension.
echo "正在安装系统依赖..."
sudo apt update
sudo apt install -y \
    build-essential wget curl git cmake patchelf \
    python3-dev python3-pip python3-venv \
    libosmesa6-dev libgl1-mesa-dev libglu1-mesa-dev \
    mesa-utils freeglut3-dev libglew-dev

# 2. MuJoCo 2.1.0. The release archive already unpacks into a directory named
#    "mujoco210", which is exactly where mujoco-py looks, so no rename is needed.
echo "正在安装 MuJoCo 210..."
mkdir -p "$MUJOCO_ROOT"
cd "$MUJOCO_ROOT"
if [ ! -d "$MUJOCO_DIR" ]; then
    wget -c "https://github.com/deepmind/mujoco/releases/download/2.1.0/$MUJOCO_TAR"
    tar -xzf "$MUJOCO_TAR"
fi

# 3. Environment variables.
#    Persist them in ~/.bashrc (once) for future shells, and export them here too:
#    `source ~/.bashrc` inside a script usually does nothing because Ubuntu's default
#    ~/.bashrc returns early in non-interactive shells, yet mujoco-py needs these
#    variables during `pip install` below.
echo "设置环境变量..."
if ! grep -q 'MUJOCO_PY_MUJOCO_PATH' "$HOME/.bashrc" 2>/dev/null; then
    echo 'export MUJOCO_PY_MUJOCO_PATH=$HOME/.mujoco/mujoco210' >> "$HOME/.bashrc"
    echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> "$HOME/.bashrc"
fi
export MUJOCO_PY_MUJOCO_PATH="$MUJOCO_DIR"
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$MUJOCO_DIR/bin"

# 4. Python virtual environment.
echo "创建 Python 虚拟环境..."
python3 -m venv "$HOME/envs/gym-fetch"
source "$HOME/envs/gym-fetch/bin/activate"

# 5. Python packages.
echo "安装 Python 包..."
pip install --upgrade pip
# mujoco-py 2.1.x is known to fail to build with Cython 3.x, hence the 0.29 pin.
pip install numpy scipy cython==0.29.36
pip install mujoco-py==2.1.2.14
pip install gymnasium gymnasium-robotics
# pygame is required by the Xbox-controller demo collector (collect_fetch_demo.py).
pip install pygame

# 6. gym-fetch-v2 (editable install).
echo "安装 gym-fetch-v2..."
cd "$HOME"
if [ ! -d gym-fetch-v2 ]; then
    git clone https://github.com/cityu-lm/gym-fetch-v2.git
fi
cd gym-fetch-v2
pip install -e .

echo "=== 安装完成! ==="
echo "新终端会自动加载 MuJoCo 环境变量（当前终端可执行: source ~/.bashrc）"
echo "激活环境: source ~/envs/gym-fetch/bin/activate"
echo "测试安装: cd ~/gym-fetch-v2 && python examples/demo.py"

测试安装

测试代码（Python）：
import os

import gymnasium as gym
import fetch  # noqa: F401  (registers the gym-fetch environments with gymnasium)


def main():
    """Smoke-test the gym-fetch-v2 primitive single-task environments.

    For each environment: create it, reset it, take up to 10 random actions
    (rendering after every step), then close it. Failures are printed per
    environment so one broken environment does not stop the others.
    """
    # Primitive Single Task envs (requested)
    env_names = [
        'Bin-pick-v2',
        'Bin-place-v2',
        'Box-open-v2',
        'Box-close-v2',
        'Drawer-open-v2',
        'Drawer-close-v2',
    ]

    # Loop-invariant: open a window only when an X display is available.
    render_mode = 'human' if os.environ.get('DISPLAY') else 'rgb_array'

    for env_name in env_names:
        print(f"Testing environment: {env_name}")
        env = None
        try:
            env = gym.make(env_name)
            obs, info = env.reset()
            print(f"Environment {env_name} reset successfully.")

            # Run a few steps with random actions
            for step in range(10):
                action = env.action_space.sample()
                obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated

                # Legacy gym-style envs take render(mode=...); envs using the
                # gymnasium-style render() reject the ``mode`` keyword with a
                # TypeError, in which case plain render() is used instead.
                try:
                    env.render(mode=render_mode)
                except TypeError:
                    env.render()

                if done:
                    print(f"Episode done at step {step}")
                    break

            env.close()
            env = None
            print(f"Environment {env_name} closed successfully.\n")
        except Exception as e:
            print(f"Error with environment {env_name}: {e}\n")
        finally:
            # Release the simulator even when reset/step/render raised.
            if env is not None:
                try:
                    env.close()
                except Exception:
                    pass


if __name__ == "__main__":
    main()

3 数据采集程序

执行

命令（bash）：
/home/iqr/miniconda3/envs/gym-robot/bin/python3 collect_fetch_demo.py \
  --env Bin-pick-v2 \
  --out demos/bin_pick \
  --max-steps 6000 \
  --control-hz 20 \
  --stop-on-success \
  --ignore-env-done \
  --verbose

# 只在成功时保存演示并开始新的样本

collect_fetch_demo.py

代码（Python）：
#!/usr/bin/env python3
"""Collect human demonstrations in gym-fetch environments using an Xbox controller.

Constraints (per request):
- This script is independent of robosuite (does not import robosuite).
- Works with gymnasium + fetch (gym-fetch env registration).
- Saves demonstrations under an output directory (npz by default, hdf5 if h5py exists).

Usage examples:
- Xbox (human):
  /home/ubuntu2004/miniconda3/envs/gym-robot/bin/python collect_fetch_demo.py --env Bin-pick-v2 --out demos/bin_pick

- Headless scripted (for testing):
  /home/ubuntu2004/miniconda3/envs/gym-robot/bin/python collect_fetch_demo.py --env Bin-pick-v2 --device scripted --episodes 1 --max-steps 5 --render none
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from typing import Any, Dict, Optional

import gymnasium as gym
import numpy as np

import fetch  # noqa: F401  (register gym-fetch envs)

from utils import (
    FETCH_ENV_IDS,
    ScriptedController,
    XboxController,
    now_utc_iso,
    rollout_episode,
    try_save_demo,
)


def _force_time_limit_max_steps(env: Any, max_episode_steps: Optional[int]) -> Any:
    """Best-effort override of TimeLimit max steps.

    Why: some env registrations set a default TimeLimit (e.g., 50 steps ~= 2s @ 20Hz).
    If the override is not applied, the episode will truncate quickly and appear to
    "auto reset".

    Gymnasium supports passing max_episode_steps into gym.make(), but we also patch
    the wrapper directly to be robust across versions/wrappers.
    """

    if max_episode_steps is None:
        return env
    target = int(max_episode_steps)
    if target <= 0:
        return env

    # Walk wrapper chain to find TimeLimit.
    cur = env
    for _ in range(64):
        if cur.__class__.__name__ == "TimeLimit":
            try:
                setattr(cur, "_max_episode_steps", target)
            except Exception:
                pass
            # Keep spec consistent if present.
            try:
                spec = getattr(cur, "spec", None)
                if spec is not None:
                    spec.max_episode_steps = target
            except Exception:
                pass
            return env

        if hasattr(cur, "env"):
            cur = getattr(cur, "env")
        else:
            break

    # If no TimeLimit wrapper exists, wrap it (gymnasium path).
    try:
        from gymnasium.wrappers import TimeLimit

        return TimeLimit(env, max_episode_steps=target)
    except Exception:
        return env


def _make_env(env_id: str, render_mode: Optional[str], max_episode_steps: Optional[int]) -> Any:
    """Create *env_id* with ``gym.make``, requesting ``max_episode_steps`` when given.

    ``render_mode`` is accepted for interface symmetry but deliberately NOT forwarded:
    gym-fetch env constructors often reject it, and a failed construction must not
    lose the episode-length override (otherwise the default TimeLimit of 50 steps,
    ~2 s at 20 Hz, would apply). Rendering is driven by explicit ``env.render()``
    calls instead.

    If ``gym.make`` raises TypeError with the extra keyword (older versions), the
    env is created with defaults and the caller is expected to patch the TimeLimit.
    """
    make_kwargs: Dict[str, Any] = (
        {} if max_episode_steps is None else {"max_episode_steps": int(max_episode_steps)}
    )
    try:
        return gym.make(env_id, **make_kwargs)
    except TypeError:
        return gym.make(env_id)


def _parse_args() -> argparse.Namespace:
    """Define the demo-collection CLI and parse ``sys.argv``.

    Option groups, in registration order: episode/output control, success and
    env-done handling, rendering/device/verbosity, Xbox scaling, raw joystick
    axis/button indices (these vary by OS and driver), and the scripted-controller
    event file.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Episode / output control.
    add("--env", type=str, default=FETCH_ENV_IDS[0], choices=list(FETCH_ENV_IDS))
    add("--out", type=str, default="demos")
    add("--episodes", type=int, default=0, help="0 means run until user quits")
    add("--max-steps", type=int, default=500)
    add("--seed", type=int, default=0)
    add("--control-hz", type=float, default=20.0)

    # Success latch and env-done handling.
    add(
        "--success-hold-steps",
        type=int,
        default=10,
        help="End episode after is_success==1 for N consecutive steps.",
    )
    add(
        "--stop-on-success",
        action="store_true",
        default=False,
        help="If set, end the current demo when success is held for N steps (default off to avoid premature endings).",
    )
    # --ignore-env-done already defaults to True; --respect-env-done (sharing the
    # same dest) is the switch that actually turns it off.
    add(
        "--ignore-env-done",
        action="store_true",
        default=True,
        help="If True, ignore env terminated/truncated (e.g., TimeLimit) and keep collecting in the same demo.",
    )
    add(
        "--respect-env-done",
        action="store_false",
        dest="ignore_env_done",
        help="If set, stop the demo when env signals done (terminated/truncated).",
    )

    # Rendering, input device, diagnostics.
    add(
        "--render",
        type=str,
        default="auto",
        choices=["auto", "human", "rgb_array", "none"],
        help="Rendering mode. auto uses human if DISPLAY else rgb_array.",
    )
    add("--device", type=str, default="xbox", choices=["xbox", "scripted"])
    add(
        "--verbose",
        action="store_true",
        default=False,
        help="Print why an attempt ended (TimeLimit/user reset/quit/success).",
    )

    # Xbox mapping / scaling.
    add("--deadzone", type=float, default=0.10)
    add("--pos-scale", type=float, default=1.0)
    add("--invert-y", action="store_true", default=True)
    add("--no-invert-y", action="store_false", dest="invert_y")
    add("--gripper-open", type=float, default=1.0)
    add("--gripper-close", type=float, default=-1.0)

    # Raw axis / button indices (OS-dependent).
    index_defaults = (
        ("--axis-lx", 0),
        ("--axis-ly", 1),
        ("--axis-lt", 2),
        ("--axis-rt", 5),
        ("--button-x", 2),
        ("--button-y", 3),
        ("--button-start", 7),
    )
    for flag, default_index in index_defaults:
        add(flag, type=int, default=default_index)

    # Scripted controller input (tests / automation).
    add(
        "--scripted-json",
        type=str,
        default=None,
        help="Path to JSON containing a list of events: [{action:[4 floats], reset:false, quit:false}, ...].",
    )

    return parser.parse_args()


def _resolve_render_mode(arg: str) -> Optional[str]:
    if arg == "none":
        return None
    if arg == "auto":
        return "human" if os.environ.get("DISPLAY") else "rgb_array"
    return arg


def _build_metadata(args: argparse.Namespace, env: Any) -> Dict[str, Any]:
    """Describe the collection run as a flat dict stored alongside each demo.

    ``env_id`` comes from ``env.spec.id`` when that attribute chain is readable and
    non-empty; otherwise it falls back to ``--env``. Spaces are stored as their
    ``str()`` form, and the raw command line is kept in ``argv``.
    """
    try:
        resolved_id = env.spec.id  # type: ignore[attr-defined]
    except Exception:
        resolved_id = None

    return dict(
        created_utc=now_utc_iso(),
        env_id=resolved_id or args.env,
        seed=args.seed,
        max_steps=args.max_steps,
        control_hz=args.control_hz,
        device=args.device,
        action_space=str(env.action_space),
        observation_space=str(env.observation_space),
        argv=sys.argv,
    )


def _make_controller(args: argparse.Namespace) -> Any:
    """Build the input controller selected by ``--device``.

    ``xbox``: an XboxController configured from the CLI scaling and index options.

    ``scripted``: a ScriptedController that replays events from ``--scripted-json``.
    The file holds a list of objects with optional ``action`` (default
    ``[0, 0, 0, 0]``), ``reset`` and ``quit``. Without that file, it plays 10
    neutral steps followed by one neutral step that requests an episode reset.
    """
    if args.device == "xbox":
        return XboxController(
            deadzone=args.deadzone,
            pos_scale=args.pos_scale,
            invert_y=args.invert_y,
            gripper_open_value=args.gripper_open,
            gripper_close_value=args.gripper_close,
            axis_lx=args.axis_lx,
            axis_ly=args.axis_ly,
            axis_lt=args.axis_lt,
            axis_rt=args.axis_rt,
            button_x=args.button_x,
            button_y=args.button_y,
            button_start=args.button_start,
        )

    # Only the scripted path needs the event type; import it once here.
    from utils import ControllerEvent

    if args.scripted_json:
        with open(args.scripted_json, "r", encoding="utf-8") as fh:
            entries = json.load(fh)
        events = [
            ControllerEvent(
                action=np.asarray(entry.get("action", [0, 0, 0, 0]), dtype=np.float32),
                reset_episode=bool(entry.get("reset", False)),
                quit=bool(entry.get("quit", False)),
            )
            for entry in entries
        ]
    else:
        # Default pattern: 10 neutral steps, then a neutral step that requests reset.
        events = [ControllerEvent(action=np.zeros(4, dtype=np.float32)) for _ in range(10)]
        events.append(ControllerEvent(action=np.zeros(4, dtype=np.float32), reset_episode=True))

    return ScriptedController(events=events)


def _time_limit_steps(env: Any) -> Optional[int]:
    """Return ``_max_episode_steps`` of the first TimeLimit wrapper around *env*, or None.

    Follows the ``.env`` chain (at most 64 levels, like _force_time_limit_max_steps)
    and identifies the wrapper by its class name.
    """
    cur = env
    for _ in range(64):
        if cur.__class__.__name__ == "TimeLimit":
            return getattr(cur, "_max_episode_steps", None)
        if not hasattr(cur, "env"):
            return None
        cur = getattr(cur, "env")
    return None


def _end_reason(demo: Dict[str, Any], max_steps: int, stop_on_success: bool) -> str:
    """Classify why a rollout attempt ended, for the --verbose summary line."""
    if demo.get("quit", False):
        return "quit"
    if demo.get("reset_by_user", False):
        return "reset_by_user"
    # "successful" only *ends* an attempt when --stop-on-success is active.
    if stop_on_success and demo.get("successful", False):
        return "success"
    if demo.get("done_before_success", False):
        return "env_done_before_success"
    if len(demo.get("action", [])) >= max_steps:
        return "max_steps"
    return "unknown"


def main() -> int:
    """Collect controller-driven demonstrations and save the successful ones.

    Runs attempts until ``--episodes`` demos have been saved (0 = unlimited) or the
    controller requests quit. Only attempts flagged ``successful`` are written under
    ``--out`` (via try_save_demo). The env and controller are closed on every exit
    path, including when the controller fails to start.

    Returns:
        Process exit status (always 0; errors propagate as exceptions).
    """
    args = _parse_args()

    render_mode = _resolve_render_mode(args.render)
    # Override TimeLimit max_episode_steps to prevent early truncation
    # (registrations often default to 50 steps).
    env = _make_env(args.env, render_mode, max_episode_steps=args.max_steps)
    env = _force_time_limit_max_steps(env, args.max_steps)

    if args.verbose:
        # Diagnose quick auto-truncation: show both the wrapper and spec limits.
        try:
            spec_steps = getattr(getattr(env, "spec", None), "max_episode_steps", None)
        except Exception:
            spec_steps = None
        print(
            f"TimeLimit: _max_episode_steps={_time_limit_steps(env)} "
            f"spec.max_episode_steps={spec_steps}"
        )

    controller = None
    try:
        # Seed if supported (best effort).
        try:
            env.reset(seed=args.seed)
        except Exception:
            pass

        # Create the output directory before touching the joystick so path problems
        # surface immediately.
        os.makedirs(args.out, exist_ok=True)

        # Created inside the try so env.close() still runs if the controller cannot
        # start (e.g. pygame missing or no joystick connected).
        controller = _make_controller(args)
        controller.start()

        episode_idx = 0
        while args.episodes <= 0 or episode_idx < args.episodes:
            demo = rollout_episode(
                env,
                controller,
                max_steps=args.max_steps,
                control_hz=args.control_hz,
                render=(render_mode == "human"),
                stop_on_success=args.stop_on_success,
                success_hold_steps=args.success_hold_steps,
                continue_on_done_until_success=args.ignore_env_done,
            )

            if args.verbose:
                print(
                    "Attempt summary:",
                    f"T={len(demo.get('action', []))}",
                    f"done_resets={demo.get('done_resets', 0)}",
                    f"ended_by={_end_reason(demo, args.max_steps, args.stop_on_success)}",
                )

            if demo.get("successful"):
                meta = _build_metadata(args, env)
                meta.update(
                    episode_index=episode_idx,
                    saved_utc=now_utc_iso(),
                    quit=bool(demo.get("quit", False)),
                    reset_by_user=bool(demo.get("reset_by_user", False)),
                    successful=bool(demo.get("successful", False)),
                    done_before_success=bool(demo.get("done_before_success", False)),
                )
                base = os.path.join(args.out, f"{args.env}_ep{episode_idx:04d}_{int(time.time())}")
                saved_path = try_save_demo(base, demo, meta)
                print(f"Saved demo: {saved_path}")
                episode_idx += 1
            else:
                print("Demo not successful, skipping save and continuing to next attempt.")

            if demo.get("quit", False):
                break
    finally:
        if controller is not None:
            try:
                controller.close()
            except Exception:
                pass
        try:
            env.close()
        except Exception:
            pass

    return 0


if __name__ == "__main__":
    raise SystemExit(main())

utils.py

代码（Python）：
"""Utilities for collecting gym-fetch demonstrations.

This module is intentionally independent of robosuite.
It provides:
- Controller interfaces (Xbox via pygame, plus scripted controller for testing)
- Rollout / recording helpers
- Demo serialization (npz by default; hdf5 if h5py is available)

"""

from __future__ import annotations

import json
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np


# Environment ids of the gym-fetch-v2 tasks supported by the collection tools.
# They are registered with gymnasium by ``import fetch``; collect_fetch_demo.py
# uses this tuple as the --env choices and its first entry as the default.
FETCH_ENV_IDS: Tuple[str, ...] = (
    "Bin-pick-v2",
    "Bin-place-v2",
    "Box-open-v2",
    "Box-close-v2",
    "Drawer-open-v2",
    "Drawer-close-v2",
)


def now_utc_iso() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
    utc_struct = time.gmtime()
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", utc_struct)


def ensure_dir(path: str) -> str:
    """Create directory *path* (with any missing parents) if needed and return it unchanged."""
    os.makedirs(name=path, exist_ok=True)
    return path


def clamp(x: float, lo: float, hi: float) -> float:
    """Limit *x* to ``[lo, hi]``.

    The lower bound is checked first, and values that compare false against both
    bounds (e.g. NaN) are returned unchanged.
    """
    if x < lo:
        return lo
    return hi if x > hi else x


def apply_deadzone(x: float, deadzone: float) -> float:
    """Zero out *x* when its magnitude is strictly below *deadzone*; otherwise return *x* as is."""
    return 0.0 if abs(x) < deadzone else x


def rate_limit_sleep(start_time_s: float, target_hz: Optional[float]) -> None:
    """Sleep for whatever is left of one control period that began at *start_time_s*.

    The period is ``1 / target_hz`` seconds (wall clock, ``time.time()``). Nothing
    happens when *target_hz* is None/0/negative or the period has already elapsed.
    """
    if target_hz and target_hz > 0:
        period_s = 1.0 / float(target_hz)
        leftover_s = period_s - (time.time() - start_time_s)
        if leftover_s > 0:
            time.sleep(leftover_s)


@dataclass
class ControllerEvent:
    """One controller sample, produced once per control tick by a controller's read()."""

    # Action vector; rollout_episode converts it to float32 and passes it to env.step().
    action: np.ndarray
    # When True, rollout_episode ends the current attempt (marked "reset_by_user").
    reset_episode: bool = False
    # When True, rollout_episode stops before stepping and marks the attempt "quit".
    quit: bool = False


class BaseController:
    """Abstract input device that yields one ControllerEvent per control tick.

    Subclasses must implement read(), returning an event whose action matches the
    env's action shape. start() and close() are optional lifecycle hooks that do
    nothing by default.
    """

    def start(self) -> None:
        """Acquire the device; the base implementation does nothing."""

    def close(self) -> None:
        """Release the device; the base implementation does nothing."""

    def read(self) -> ControllerEvent:
        """Return the next event. Subclasses must override this."""
        raise NotImplementedError


class ScriptedController(BaseController):
    """Deterministic controller that replays a fixed list of events (tests / automation).

    Args:
        events: ControllerEvent objects, returned one per read() call in order.
        default_action: optional array-like action used for the quit event that
            read() returns once the events are exhausted. Without it, a 4-D zero
            vector is used.
    """

    def __init__(
        self,
        events: Iterable[ControllerEvent],
        default_action: Optional[np.ndarray] = None,
    ) -> None:
        self._events = list(events)
        self._index = 0
        # Normalise once so any array-like (e.g. a plain list) is accepted; the old
        # ``.astype`` call failed with AttributeError for non-ndarray inputs.
        self._default_action = (
            None if default_action is None else np.asarray(default_action, dtype=np.float32)
        )

    def read(self) -> ControllerEvent:
        """Return the next scripted event, or a quit event once all were consumed."""
        if self._index < len(self._events):
            ev = self._events[self._index]
            self._index += 1
            return ev
        if self._default_action is None:
            action = np.zeros(4, dtype=np.float32)
        else:
            # Fresh copy each call so callers mutating the action cannot affect later reads.
            action = self._default_action.copy()
        return ControllerEvent(action=action, quit=True)


class XboxController(BaseController):
    """Xbox controller mapping for gym-fetch (action dim 4).

    Mapping (default):
    - Left stick: x/y translation (dx, dy)
    - RT/LT triggers: z translation (dz = RT - LT)
    - X button: toggle gripper open/close
    - Y button: end current episode (reset flag)
    - START button: quit

    Notes:
    - pygame is imported lazily in start(), so constructing the object does not
      require pygame (tests can use the scripted controller instead).
    - Axis / button indices vary across OS / controller; they are exposed via args
      in the caller, and start() checks them against the connected device.
    - Depending on the driver, a trigger reads either [0, 1] (rest 0) or [-1, 1]
      (rest -1). Each trigger switches to the [-1, 1] mapping the first time it
      reads clearly negative (< -0.5), and then stays there. A per-sample rule
      such as "rescale only when negative" is discontinuous at raw 0 on [-1, 1]
      triggers: a half press would snap dz back to 0.
    - Stick axes and normalised trigger amounts below ``deadzone`` are zeroed.
    """

    # A [0, 1] trigger never reads far below 0, so a reading under this threshold
    # identifies a [-1, 1] trigger (small negative noise is ignored).
    _SIGNED_TRIGGER_THRESHOLD = -0.5

    def __init__(
        self,
        *,
        deadzone: float = 0.10,
        pos_scale: float = 1.0,
        invert_y: bool = True,
        gripper_open_value: float = 1.0,
        gripper_close_value: float = -1.0,
        axis_lx: int = 0,
        axis_ly: int = 1,
        axis_lt: int = 2,
        axis_rt: int = 5,
        button_x: int = 2,
        button_y: int = 3,
        button_start: int = 7,
    ) -> None:
        self.deadzone = float(deadzone)
        self.pos_scale = float(pos_scale)
        self.invert_y = bool(invert_y)
        self.gripper_open_value = float(gripper_open_value)
        self.gripper_close_value = float(gripper_close_value)

        self.axis_lx = int(axis_lx)
        self.axis_ly = int(axis_ly)
        self.axis_lt = int(axis_lt)
        self.axis_rt = int(axis_rt)

        self.button_x = int(button_x)
        self.button_y = int(button_y)
        self.button_start = int(button_start)

        self._pygame = None
        self._joystick = None

        self._gripper_closed = False
        # Previous button states, for rising-edge detection.
        self._prev_x = 0
        self._prev_y = 0
        self._prev_start = 0

        # Latched once a trigger is seen reporting on the [-1, 1] range.
        self._lt_signed = False
        self._rt_signed = False

    def start(self) -> None:
        """Initialise pygame and open joystick 0.

        Raises:
            RuntimeError: pygame is missing, no joystick is connected, or a
                configured axis/button index does not exist on the device.
        """
        try:
            import pygame  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError(
                "pygame is required for XboxController; install it or use --device scripted"
            ) from e

        self._pygame = pygame
        pygame.init()
        pygame.joystick.init()
        if pygame.joystick.get_count() <= 0:
            raise RuntimeError("No joystick detected. Connect an Xbox controller.")

        joystick = pygame.joystick.Joystick(0)
        joystick.init()
        self._joystick = joystick

        # Fail fast with an actionable message instead of a pygame error on the
        # first read() when the CLI indices do not match this device/driver.
        n_axes = int(joystick.get_numaxes())
        n_buttons = int(joystick.get_numbuttons())
        bad_axes = [
            i for i in (self.axis_lx, self.axis_ly, self.axis_lt, self.axis_rt)
            if not 0 <= i < n_axes
        ]
        bad_buttons = [
            i for i in (self.button_x, self.button_y, self.button_start)
            if not 0 <= i < n_buttons
        ]
        if bad_axes or bad_buttons:
            raise RuntimeError(
                f"Joystick '{joystick.get_name()}' has {n_axes} axes and {n_buttons} buttons; "
                f"invalid axis indices {bad_axes}, invalid button indices {bad_buttons}. "
                "Adjust the --axis-*/--button-* arguments."
            )

    def close(self) -> None:
        """Shut pygame down (best effort) and drop the device handles."""
        if self._pygame is not None:  # pragma: no cover
            try:
                self._pygame.joystick.quit()
                self._pygame.quit()
            except Exception:
                pass
        # Clear handles so a read() after close() raises the clear "call start()"
        # error instead of calling into an uninitialised pygame.
        self._pygame = None
        self._joystick = None

    def _get_axis(self, idx: int) -> float:
        """Read axis *idx* with the deadzone applied."""
        assert self._joystick is not None
        x = float(self._joystick.get_axis(idx))
        return apply_deadzone(x, self.deadzone)

    def _get_button(self, idx: int) -> int:
        """Read button *idx* as 0/1."""
        assert self._joystick is not None
        return int(self._joystick.get_button(idx))

    def _normalize_trigger(self, raw: float, signed: bool) -> float:
        """Map a raw trigger reading to a press amount in [0, 1], deadzone applied."""
        amount = (raw + 1.0) / 2.0 if signed else raw
        return apply_deadzone(clamp(amount, 0.0, 1.0), self.deadzone)

    def read(self) -> ControllerEvent:
        """Poll the controller and return the current action plus reset/quit flags.

        The action is ``[dx, dy, dz, gripper]``, each clamped to [-1, 1]. Buttons
        are edge-triggered: holding X toggles the gripper only once, and holding
        Y or START reports reset/quit only on the press.
        """
        if self._pygame is None or self._joystick is None:
            raise RuntimeError("XboxController.start() must be called before read().")

        self._pygame.event.pump()

        lx = self._get_axis(self.axis_lx)
        ly = self._get_axis(self.axis_ly)
        if self.invert_y:
            ly = -ly

        lt_raw = float(self._joystick.get_axis(self.axis_lt))
        rt_raw = float(self._joystick.get_axis(self.axis_rt))
        if lt_raw < self._SIGNED_TRIGGER_THRESHOLD:
            self._lt_signed = True
        if rt_raw < self._SIGNED_TRIGGER_THRESHOLD:
            self._rt_signed = True
        lt = self._normalize_trigger(lt_raw, self._lt_signed)
        rt = self._normalize_trigger(rt_raw, self._rt_signed)

        dz = clamp(rt - lt, -1.0, 1.0)

        # Edge-detect buttons
        curr_x = self._get_button(self.button_x)
        if curr_x == 1 and self._prev_x == 0:
            self._gripper_closed = not self._gripper_closed
        self._prev_x = curr_x

        curr_y = self._get_button(self.button_y)
        reset_episode = bool(curr_y == 1 and self._prev_y == 0)
        self._prev_y = curr_y

        curr_start = self._get_button(self.button_start)
        quit_flag = bool(curr_start == 1 and self._prev_start == 0)
        self._prev_start = curr_start

        gripper = self.gripper_close_value if self._gripper_closed else self.gripper_open_value

        action = np.array(
            [
                clamp(lx * self.pos_scale, -1.0, 1.0),
                clamp(ly * self.pos_scale, -1.0, 1.0),
                clamp(dz * self.pos_scale, -1.0, 1.0),
                clamp(gripper, -1.0, 1.0),
            ],
            dtype=np.float32,
        )

        return ControllerEvent(action=action, reset_episode=reset_episode, quit=quit_flag)


def _obs_to_numpy(obs: Any) -> Dict[str, np.ndarray]:
    if not isinstance(obs, dict):
        return {"observation": np.asarray(obs)}
    out: Dict[str, np.ndarray] = {}
    for k, v in obs.items():
        out[str(k)] = np.asarray(v)
    return out


def extract_is_success(info: Any) -> Optional[bool]:
    """Read the success flag from a Gymnasium ``info`` dict.

    Goal-based robotics envs usually report ``info["is_success"]`` as 1/0
    (float, bool, or numpy scalar/array).

    Returns:
        None when *info* is not a dict, the key is missing, or the value is an
        empty array. Otherwise, True iff the value (the first element, for
        arrays) converts to a float >= 1.0. Values that cannot be converted to
        float fall back to ``bool(value)``.
    """
    if not isinstance(info, dict) or "is_success" not in info:
        return None

    flag = info.get("is_success")
    try:
        if isinstance(flag, np.ndarray):
            if flag.size == 0:
                return None
            flag = flag.reshape(-1)[0]
        return bool(float(flag) >= 1.0)
    except Exception:
        pass

    try:
        return bool(flag)
    except Exception:
        return None


def rollout_episode(
    env: Any,
    controller: BaseController,
    *,
    max_steps: int,
    control_hz: Optional[float],
    render: bool,
    stop_on_success: bool = False,
    success_hold_steps: int = 10,
    continue_on_done_until_success: bool = True,
) -> Dict[str, Any]:
    """Run one demonstration attempt driven by *controller* and record it.

    Each tick: read a ControllerEvent, step the env with its action, render if
    requested, update the success latch, and sleep to hold ``control_hz``.

    The attempt ends when:
    - the controller requests quit or reset,
    - ``max_steps`` ticks have run,
    - success has been held and ``stop_on_success`` is set, or
    - the env reports terminated/truncated and ``continue_on_done_until_success``
      is False.

    When that last flag is True, the env is reset instead and recording continues
    in the same attempt.

    Success ("successful" key) is latched once ``info["is_success"]`` has been true
    for ``success_hold_steps`` consecutive steps. This is tracked on every step,
    whether or not ``stop_on_success`` is set; that flag only decides whether
    reaching success ends the attempt. Callers that save only successful demos
    therefore also work when stopping on success is disabled.

    Returns:
        A dict with:
        - ``obs``: {key: array}. One entry for the initial reset, one per env.step()
          and one per mid-attempt reset, i.e. ``len(action) + 1 + done_resets``.
        - ``action``, ``reward``, ``terminated``, ``truncated``: arrays of length T.
        - ``info``: list of info dicts, indexed like ``obs``.
        - ``segment_starts``: action index at which each env segment began (starts
          with 0; one more entry per mid-attempt reset).
        - optional flags ``quit``, ``reset_by_user``, ``successful``,
          ``done_before_success``, and the counter ``done_resets``.
    """

    obs, info = env.reset()

    traj: Dict[str, Any] = {
        "obs": {k: [v] for k, v in _obs_to_numpy(obs).items()},
        "action": [],
        "reward": [],
        "terminated": [],
        "truncated": [],
        "info": [info],
        "segment_starts": [0],
    }

    # Success latch: require N consecutive success timesteps.
    hold = max(1, int(success_hold_steps))
    remaining_success = hold

    for _step in range(int(max_steps)):
        step_start = time.time()
        ev = controller.read()
        if ev.quit:
            traj["quit"] = True
            break

        action = np.asarray(ev.action, dtype=np.float32)
        traj["action"].append(action)

        obs, reward, terminated, truncated, info = env.step(action)
        for k, v in _obs_to_numpy(obs).items():
            traj["obs"].setdefault(k, []).append(v)

        traj["reward"].append(float(reward))
        traj["terminated"].append(bool(terminated))
        traj["truncated"].append(bool(truncated))
        traj["info"].append(info)

        if render:
            try:
                env.render()
            except TypeError:
                env.render(mode="human")

        if ev.reset_episode:
            traj["reset_by_user"] = True
            break

        # Track success on every step; stop_on_success only decides whether a held
        # success ends the attempt.
        is_success = extract_is_success(info)
        if is_success is True:
            remaining_success -= 1
            if remaining_success <= 0:
                traj["successful"] = True
                if stop_on_success:
                    break
        elif is_success is False:
            remaining_success = hold

        # The environment signalled done (often the TimeLimit).
        if bool(terminated) or bool(truncated):
            if not continue_on_done_until_success:
                if not traj.get("successful", False):
                    traj["done_before_success"] = True
                break

            # Reset and keep recording in the same attempt as a new segment.
            traj["done_resets"] = int(traj.get("done_resets", 0)) + 1
            obs, info = env.reset()
            for k, v in _obs_to_numpy(obs).items():
                traj["obs"].setdefault(k, []).append(v)
            traj["info"].append(info)
            traj["segment_starts"].append(len(traj["action"]))
            remaining_success = hold

        # Also applied after a mid-attempt reset, so every tick honours control_hz.
        rate_limit_sleep(step_start, control_hz)

    # Convert lists to arrays where appropriate
    traj["action"] = np.asarray(traj["action"], dtype=np.float32)
    traj["reward"] = np.asarray(traj["reward"], dtype=np.float32)
    traj["terminated"] = np.asarray(traj["terminated"], dtype=bool)
    traj["truncated"] = np.asarray(traj["truncated"], dtype=bool)
    for k in list(traj["obs"].keys()):
        traj["obs"][k] = np.asarray(traj["obs"][k])

    return traj


def save_demo_npz(path: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Write *demo* to a compressed ``.npz`` archive and return the written file path.

    Archive keys:
    - ``metadata_json`` / ``info_json``: JSON strings. Info dicts are not arrays,
      and values JSON cannot encode are stringified with ``str``.
    - ``action``, ``reward``, ``terminated``, ``truncated``: per-step arrays.
    - ``obs__<key>``: one array per observation key.
    - ``segment_starts`` (int64) and ``done_resets`` (int64 scalar). These are
      needed to align ``obs`` with actions, because ``obs`` gains an extra entry
      at every mid-attempt env reset. Defaults are ``[0]`` and ``0`` when the demo
      lacks them.

    numpy appends ``.npz`` to paths that lack it. The suffix is added here first,
    so the returned path always names the file that actually exists.
    """
    if not path.endswith(".npz"):
        path += ".npz"
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    packed: Dict[str, Any] = {
        "metadata_json": json.dumps(metadata, ensure_ascii=False, default=str),
        "info_json": json.dumps(demo.get("info", []), ensure_ascii=False, default=str),
        "action": demo["action"],
        "reward": demo["reward"],
        "terminated": demo["terminated"],
        "truncated": demo["truncated"],
        "segment_starts": np.asarray(demo.get("segment_starts", [0]), dtype=np.int64),
        "done_resets": np.asarray(int(demo.get("done_resets", 0)), dtype=np.int64),
    }
    for k, v in demo["obs"].items():
        packed[f"obs__{k}"] = v

    np.savez_compressed(path, **packed)
    return path


def save_demo_hdf5(path: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Write *demo* to an HDF5 file at *path* (requires h5py) and return *path*.

    Layout:
    - root attrs ``metadata_json`` / ``info_json``: JSON strings, with non-JSON
      values stringified.
    - datasets ``action``, ``reward``, ``terminated``, ``truncated``: gzip-compressed,
      one entry per step.
    - group ``obs``: one gzip-compressed dataset per observation key.
    - dataset ``segment_starts`` and root attr ``done_resets``. These are needed to
      align ``obs`` (an extra entry per mid-attempt env reset) with actions;
      defaults are ``[0]`` and ``0``.

    Raises:
        RuntimeError: h5py is not installed (checked before anything is written).
    """

    try:
        import h5py  # type: ignore
    except Exception as e:
        raise RuntimeError("h5py is not installed; use save_demo_npz instead") from e

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    with h5py.File(path, "w") as f:
        f.attrs["metadata_json"] = json.dumps(metadata, ensure_ascii=False, default=str)
        f.create_dataset("action", data=demo["action"], compression="gzip")
        f.create_dataset("reward", data=demo["reward"], compression="gzip")
        f.create_dataset("terminated", data=demo["terminated"], compression="gzip")
        f.create_dataset("truncated", data=demo["truncated"], compression="gzip")
        f.create_dataset(
            "segment_starts",
            data=np.asarray(demo.get("segment_starts", [0]), dtype=np.int64),
        )
        f.attrs["done_resets"] = int(demo.get("done_resets", 0))
        obs_grp = f.create_group("obs")
        for k, v in demo["obs"].items():
            obs_grp.create_dataset(k, data=v, compression="gzip")
        f.attrs["info_json"] = json.dumps(demo.get("info", []), ensure_ascii=False, default=str)

    return path


def try_save_demo(path_base: str, demo: Dict[str, Any], metadata: Dict[str, Any]) -> str:
    """Save *demo* as ``<path_base>.hdf5`` when possible, else ``<path_base>.npz``.

    Any failure of the HDF5 writer (h5py missing, data h5py cannot store, I/O
    error) falls back to npz. If the failed attempt created a new ``.hdf5`` file,
    that partial file is removed so it is not mistaken for a valid demo. A file
    that already existed before the call is never deleted.

    Returns:
        Path of the file that was written.
    """
    hdf5_path = path_base + ".hdf5"
    existed_before = os.path.exists(hdf5_path)
    try:
        return save_demo_hdf5(hdf5_path, demo, metadata)
    except Exception:
        if not existed_before and os.path.exists(hdf5_path):
            try:
                os.remove(hdf5_path)
            except OSError:
                pass
        return save_demo_npz(path_base + ".npz", demo, metadata)
相关推荐
2501_948122632 小时前
React Native for OpenHarmony 实战:Steam 资讯 App 设置页面
javascript·react native·react.js·游戏·ecmascript·harmonyos
斯文by累2 小时前
Ubuntu系统上安装Kafka 8.0
linux·ubuntu·kafka
2501_948122632 小时前
React Native for OpenHarmony 实战:Steam 资讯 App 意见反馈实现
javascript·react native·react.js·游戏·ecmascript·harmonyos
HIT_Weston12 小时前
95、【Ubuntu】【Hugo】搭建私人博客:_default&partials
linux·运维·ubuntu
论迹13 小时前
【Git】-- Git安装 & 卸载(ubuntu)
git·ubuntu·elasticsearch
i橡皮擦14 小时前
TheIsle恐龙岛读取游戏基址做插件(C#语言)
开发语言·游戏·c#·恐龙岛·theisle
论迹14 小时前
【Git】-- Git基本操作
git·ubuntu
相偎15 小时前
Ubuntu搭建svn服务器
服务器·ubuntu·svn
oMcLin18 小时前
如何在Ubuntu 22.04 LTS上配置并优化MySQL 8.0分区表,提高大规模数据集查询的效率与性能?
android·mysql·ubuntu