openpi 入门教程

系列文章目录

目录

系列文章目录

前言

一、运行要求

二、安装

三、模型检查点

[3.1 基础模型](#3.1 基础模型)

[3.2 微调模型](#3.2 微调模型)

四、运行预训练模型的推理

五、在自己的数据上微调基础模型

[5.1. 将数据转换为 LeRobot 数据集](#5.1. 将数据转换为 LeRobot 数据集)

[5.3. 启动策略服务器并运行推理](#5.3. 启动策略服务器并运行推理)

[5.4 更多示例](#5.4 更多示例)

六、故障排除

[七、远程运行 openpi 模型](#七、远程运行 openpi 模型)

[7.1 启动远程策略服务器](#7.1 启动远程策略服务器)

[7.2 从机器人代码中查询远程策略服务器](#7.2 从机器人代码中查询远程策略服务器)

八、推理教程

[8.1 策略推断](#8.1 策略推断)

[8.2 使用实时模型](#8.2 使用实时模型)

九、策略记录代码


前言

openpi 包含物理智能团队发布的机器人开源模型和软件包。

目前,该 repo 包含两种模型:

  • π₀ 模型,一种基于流的扩散视觉-语言-动作模型 (VLA)
  • π₀-FAST 模型,一种基于 FAST 动作标记器的自回归 VLA。

对于这两种模型,我们都提供了在 10K+ 小时的机器人数据上预先训练过的基本模型检查点,以及用于开箱即用或根据您自己的数据集进行微调的示例。

这是一次实验:π0 是为我们自己的机器人开发的,与 ALOHA 和 DROID 等广泛使用的平台不同,尽管我们乐观地认为,研究人员和从业人员将能够进行创造性的新实验,将π0 适应到他们自己的平台上,但我们并不指望每一次这样的尝试都能成功。综上所述:π0 可能对你有用,也可能对你没用,但我们欢迎你去试试看!


一、运行要求

要运行本资源库中的模型,您需要至少具备以下规格的英伟达™(NVIDIA®)图形处理器。这些估算假设使用的是单 GPU,但您也可以通过在训练配置中配置 fsdp_devices,使用多 GPU 并行模型来减少每个 GPU 的内存需求。还请注意,当前的训练脚本还不支持多节点训练。

Mode Memory Required Example GPU
Inference > 8 GB RTX 4090
Fine-Tuning (LoRA) > 22.5 GB RTX 4090
Fine-Tuning (Full) > 70 GB A100 (80GB) / H100

该软件包已在 Ubuntu 22.04 上进行了测试,目前不支持其他操作系统。

二、安装

克隆此 repo 时,确保更新子模块:

bash 复制代码
git clone --recurse-submodules [email protected]:Physical-Intelligence/openpi.git

# Or if you already cloned the repo:
git submodule update --init --recursive

我们使用 uv 来管理 Python 的依赖关系。请参阅 uv 安装说明进行设置。安装好 uv 后,运行以下命令来设置环境:

bash 复制代码
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .

注意:需要 GIT_LFS_SKIP_SMUDGE=1 才能将 LeRobot 作为依赖项。

Docker 作为 uv 安装的替代方案,我们提供了使用 Docker 安装 openpi 的说明。如果遇到系统设置问题,可以考虑使用 Docker 简化安装。更多详情,请参阅 Docker 安装。

三、模型检查点

3.1 基础模型

我们提供多个基础 VLA 模型检查点。这些检查点已在 10k+ 小时的机器人数据上进行了预训练,可用于微调。

Model Use Case Description Checkpoint Path
π0 Fine-Tuning Base diffusion π₀ model for fine-tuning s3://openpi-assets/checkpoints/pi0_base
π0-FAST Fine-Tuning Base autoregressive π₀-FAST model for fine-tuning s3://openpi-assets/checkpoints/pi0_fast_base

3.2 微调模型

我们还为各种机器人平台和任务提供 "专家 "检查点。这些模型在上述基础模型的基础上进行了微调,旨在直接在目标机器人上运行。这些模型不一定适用于您的特定机器人。由于这些检查点是在使用 ALOHA 和 DROID Franka 等更广泛使用的机器人收集的相对较小的数据集上进行微调的,因此它们可能无法适用于您的特定设置,不过我们发现其中一些检查点,尤其是 DROID 检查点,在实践中具有相当广泛的适用性。

Model Use Case Description Checkpoint Path
π0-FAST-DROID Inference π0-FAST model fine-tuned on the DROID dataset, can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform s3://openpi-assets/checkpoints/pi0_fast_droid
π0-DROID Fine-Tuning π0 model fine-tuned on the DROID dataset, faster inference than π0-FAST-DROID, but may not follow language commands as well s3://openpi-assets/checkpoints/pi0_droid
π0-ALOHA-towel Inference π0 model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on ALOHA robot platforms s3://openpi-assets/checkpoints/pi0_aloha_towel
π0-ALOHA-tupperware Inference π0 model fine-tuned on internal ALOHA data, can unpack food from a tupperware container s3://openpi-assets/checkpoints/pi0_aloha_tupperware
π0-ALOHA-pen-uncap Inference π0 model fine-tuned on public ALOHA data, can uncap a pen s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap

默认情况下,检查点会自动从 s3://openpi-assets 下载,并在需要时缓存到 ~/.cache/openpi 中。你可以通过设置 OPENPI_DATA_HOME 环境变量来覆盖下载路径。

四、运行预训练模型的推理

我们的预训练模型检查点只需几行代码即可运行(此处为我们的 π0-FAST-DROID 模型):

python 复制代码
from openpi.training import config
from openpi.policies import policy_config
from openpi.shared import download

config = config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example.
example = {
    "observation/exterior_image_1_left": ...,
    "observation/wrist_image_left": ...,
    ...
    "prompt": "pick up the fork"
}
action_chunk = policy.infer(example)["actions"]

您也可以在示例笔记本中进行测试。

我们提供了在 DROID 和 ALOHA 机器人上运行预训练检查点推理的详细分步示例。

  • 远程推理: 我们提供了远程运行模型推理的示例和代码:模型可以在不同的服务器上运行,并通过 websocket 连接向机器人发送动作流。这样就可以轻松地在机器人外使用更强大的 GPU,并将机器人和策略环境分开。
  • 在没有机器人的情况下测试推理: 我们提供了一个脚本,用于在没有机器人的情况下测试推理。该脚本将生成随机观测数据,并使用模型运行推理。更多详情,请参阅此处。

五、在自己的数据上微调基础模型

我们将在 Libero 数据集上微调 π0-FAST 模型,作为如何在自己的数据上微调基础模型的运行示例。我们将解释三个步骤:

  1. 将您的数据转换为 LeRobot 数据集(我们使用该数据集进行训练)
  2. 定义训练配置并运行训练
  3. 启动策略服务器并运行推理

5.1. 将数据转换为 LeRobot 数据集

我们在 examples/libero/convert_libero_data_to_lerobot.py 中提供了将 Libero 数据转换为 LeRobot 数据集的最小示例脚本。您可以轻松修改它,转换自己的数据!您可以从这里下载原始的 Libero 数据集,并使用以下命令运行脚本:

python 复制代码
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data

5.2. 定义训练配置和运行训练

要在自己的数据上对基础模型进行微调,您需要定义用于数据处理和训练的配置。下面我们提供了带有详细注释的 Libero 配置示例,您可以根据自己的数据集进行修改:

  • LiberoInputs 和 LiberoOutputs: 定义从 Libero 环境到模型的数据映射,反之亦然。将用于训练和推理。
  • LeRobotLiberoDataConfig: 定义如何处理 LeRobot 数据集中用于训练的 Libero 原始数据。
  • TrainConfig:训练配置: 定义微调超参数、数据配置和权重加载器。

我们提供了π₀和π₀-FAST 在 Libero 数据上的微调配置示例。

在运行训练之前,我们需要计算训练数据的归一化统计量。使用训练配置的名称运行下面的脚本:

bash 复制代码
uv run scripts/compute_norm_stats.py --config-name pi0_fast_libero

现在,我们可以使用以下命令启动训练(如果使用相同配置重新运行微调,则 --overwrite 标志用于覆盖现有检查点):

bash 复制代码
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=my_experiment --overwrite

该命令会将训练进度记录到控制台,并将检查点保存到检查点目录。您还可以在权重与偏差仪表板上监控训练进度。为了最大限度地使用 GPU 内存,请在运行训练之前设置 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -- 这将使 JAX 能够使用高达 90% 的 GPU 内存(默认值为 75%)。

注:我们提供了从预训练开始重新加载状态/动作归一化统计数据的功能。如果您要对预训练混合物中的机器人新任务进行微调,这将非常有用。有关如何重新加载归一化统计数据的详细信息,请参阅 norm_stats.md 文件。

5.3. 启动策略服务器并运行推理

训练完成后,我们就可以启动策略服务器,然后通过 Libero 评估脚本进行查询,从而运行推理。启动模型服务器非常简单(本例使用迭代 20,000 的检查点,可根据需要修改):

bash 复制代码
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero --policy.dir=checkpoints/pi0_fast_libero/my_experiment/20000

这将启动一个服务器,该服务器监听 8000 端口,并等待向其发送观察结果。然后,我们就可以运行 Libero 评估脚本来查询服务器。有关如何安装 Libero 和运行评估脚本的说明,请参阅 Libero README。

如果你想在自己的机器人运行时中嵌入策略服务器调用,我们在远程推理文档中提供了一个最简单的示例。

5.4 更多示例

我们在以下 READMEs 中提供了更多示例,说明如何在 ALOHA 平台上使用我们的模型进行微调和推理:

  • ALOHA 模拟器
  • ALOHA 真实
  • UR5

六、故障排除

我们将在此收集常见问题及其解决方案。如果遇到问题,请先查看此处。如果找不到解决方案,请在软件仓库中提交问题(参见此处的指导原则)。

Issue Resolution
uv 同步因依赖关系冲突而失败 尝试删除虚拟环境目录(rm -rf .venv)并重新运行 uv 同步。如果问题仍然存在,请检查是否安装了最新版本的 uv(uv self update)。
训练耗尽 GPU 内存 确保在运行训练之前设置 XLA_PYTHON_CLIENT_MEM_FRACTION=0.9,以允许 JAX 使用更多 GPU 内存。您也可以尝试在训练配置中减少批量大小。
策略服务器连接错误 检查服务器是否正在运行,是否在预期端口上监听。验证客户端和服务器之间的网络连接和防火墙设置。
训练时缺失常模统计错误 在开始训练前使用配置名称运行 scripts/compute_norm_stats.py。
数据集下载失败 检查网络连接。如果使用 local_files_only=True,请确认数据集是否存在于本地。对于 HuggingFace 数据集,请确保已登录(huggingface-cli 登录)。
CUDA/GPU 错误 验证英伟达驱动程序和 CUDA 工具包是否安装正确。对于 Docker,确保已安装 nvidia-container-toolkit。检查 GPU 兼容性。
运行示例时出现导入错误 确保使用 uv sync 安装了所有依赖项并激活了虚拟环境。某些示例的 READMEs 中可能列出了其他要求。
动作尺寸不匹配 验证您的数据处理转换是否与机器人的预期输入/输出尺寸相匹配。检查策略类中的动作空间定义。

七、远程运行 openpi 模型

我们提供了远程运行 openpi 模型的实用程序。这对于在机器人外更强大的 GPU 上运行推理非常有用,还有助于将机器人环境和策略环境分开(例如,避免机器人软件的依赖性地狱)。

7.1 启动远程策略服务器

要启动远程策略服务器,只需运行以下命令即可:

bash 复制代码
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]

env 参数指定应加载哪个 π0 检查点。在脚本引擎盖下,该脚本将执行类似下面的命令,你可以用它来启动策略服务器,例如为你自己训练的检查点启动策略服务器(这里以 DROID 环境为例):

bash 复制代码
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid

这将启动一个策略服务器,为 config 和 dir 参数指定的策略提供服务。策略将通过指定端口(默认:8000)提供。

7.2 从机器人代码中查询远程策略服务器

我们提供的客户端实用程序依赖性极低,您可以轻松将其嵌入到任何机器人代码库中。

首先,在机器人环境中安装 openpi-client 软件包:

bash 复制代码
cd $OPENPI_ROOT/packages/openpi-client
pip install -e .

然后,您就可以使用客户端从机器人代码中查询远程策略服务器。下面举例说明如何做到这一点:

python 复制代码
from openpi_client import image_tools
from openpi_client import websocket_client_policy

# Outside of episode loop, initialize the policy client.
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

for step in range(num_steps):
    # Inside the episode loop, construct the observation.
    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
    # We provide utilities for resizing images + uint8 conversion so you match the training routines.
    # The typical resize_size for pre-trained pi0 models is 224.
    # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
    observation = {
        "observation/image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(img, 224, 224)
        ),
        "observation/wrist_image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(wrist_img, 224, 224)
        ),
        "observation/state": state,
        "prompt": task_instruction,
    }

    # Call the policy server with the current observation.
    # This returns an action chunk of shape (action_horizon, action_dim).
    # Note that you typically only need to call the policy every N steps and execute steps
    # from the predicted action chunk open-loop in the remaining steps.
    action_chunk = client.infer(observation)["actions"]

    # Execute the actions in the environment.
    ...

这里,主机和端口参数指定了远程策略服务器的 IP 地址和端口。您也可以将这些参数指定为机器人代码的命令行参数,或在机器人代码库中硬编码。观察结果是观察结果和提示的字典,与您所服务的策略的策略输入相一致。在简单的客户端示例中,我们提供了如何在不同环境下构建该字典的具体示例。

八、推理教程

python 复制代码
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

8.1 策略推断

下面的示例展示了如何从检查点创建策略,并在虚拟示例上运行推理。

python 复制代码
config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions shape:", result["actions"].shape)

8.2 使用实时模型

下面的示例展示了如何从检查点创建实时模型并计算训练损失。首先,我们将演示如何使用假数据。

python 复制代码
config = _config.get_config("pi0_aloha_sim")

checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")
key = jax.random.key(0)

# Create a model from the checkpoint.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Sample actions from the model.
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

现在,我们将创建一个数据加载器,并使用一批真实的训练数据来计算损失。

python 复制代码
# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Sample actions from the model.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)

九、策略记录代码

python 复制代码
import pathlib

import numpy as np

record_path = pathlib.Path("../policy_records")
num_steps = len(list(record_path.glob("step_*.npy")))

records = []
for i in range(num_steps):
    record = np.load(record_path / f"step_{i}.npy", allow_pickle=True).item()
    records.append(record)
python 复制代码
print("length of records", len(records))
print("keys in records", records[0].keys())

for k in records[0]:
    print(f"{k} shape: {records[0][k].shape}")
python 复制代码
from PIL import Image


def get_image(step: int, idx: int = 0):
    img = (255 * records[step]["inputs/image"]).astype(np.uint8)
    return img[idx].transpose(1, 2, 0)


def show_image(step: int, idx_lst: list[int]):
    imgs = [get_image(step, idx) for idx in idx_lst]
    return Image.fromarray(np.hstack(imgs))


for i in range(2):
    display(show_image(i, [0])
python 复制代码
import pandas as pd


def get_axis(name, axis):
    return np.array([record[name][axis] for record in records])


# qpos is [..., 14] of type float:
# 0-5: left arm joint angles
# 6: left arm gripper
# 7-12: right arm joint angles
# 13: right arm gripper
names = [("left_joint", 6), ("left_gripper", 1), ("right_joint", 6), ("right_gripper", 1)]


def make_data():
    cur_dim = 0
    in_data = {}
    out_data = {}
    for name, dim_size in names:
        for i in range(dim_size):
            in_data[f"{name}_{i}"] = get_axis("inputs/qpos", cur_dim)
            out_data[f"{name}_{i}"] = get_axis("outputs/qpos", cur_dim)
            cur_dim += 1
    return pd.DataFrame(in_data), pd.DataFrame(out_data)


in_data, out_data = make_data()
python 复制代码
for name in in_data.columns:
    data = pd.DataFrame({f"in_{name}": in_data[name], f"out_{name}": out_data[name]})
    data.plot()
相关推荐
每天一个秃顶小技巧3 分钟前
02.Golang 切片(slice)源码分析(一、定义与基础操作实现)
开发语言·后端·python·golang
ai产品老杨4 分钟前
AI赋能安全生产,推进数智化转型的智慧油站开源了。
前端·javascript·vue.js·人工智能·ecmascript
机器人之树小风35 分钟前
KUKA机器人安装包选项KUKA.PLC mxAutomation软件
经验分享·科技·机器人
明月醉窗台36 分钟前
[20250507] AI边缘计算开发板行业调研报告 (2024年最新版)
人工智能·边缘计算
Blossom.1181 小时前
低代码开发:开启软件开发的新篇章
人工智能·深度学习·安全·低代码·机器学习·计算机视觉·数据挖掘
安特尼1 小时前
招行数字金融挑战赛数据赛道赛题一
人工智能·python·机器学习·金融·数据分析
带娃的IT创业者1 小时前
《AI大模型应知应会100篇》第59篇:Flowise:无代码搭建大模型应用
人工智能
serve the people1 小时前
解决osx-arm64平台上conda默认源没有提供 python=3.7 的官方编译版本的问题
开发语言·python·conda
深蓝学院1 小时前
降低60.6%碰撞率!复旦大学&地平线CorDriver:首次引入「走廊」增强端到端自动驾驶安全性
自动驾驶
柒七爱吃麻辣烫2 小时前
在Linux中安装JDK并且搭建Java环境
java·linux·开发语言