【通识】Unitree RL Lab -模型格式与转换

从训练到部署:模型格式与转换原理

本文档介绍 Isaac Lab / RSL-RL 训练完成后,模型从 PyTorch 训练产物转换为可部署格式的完整原理与流程。理解这些概念是完成 Sim-to-Real 部署的必备基础。


目录

  1. 三种核心格式概览
  2. [PyTorch Checkpoint --- 训练产物](#PyTorch Checkpoint — 训练产物)
  3. [TorchScript --- PyTorch 的编译产物](#TorchScript — PyTorch 的编译产物)
  4. [ONNX --- 通用中间表示](#ONNX — 通用中间表示)
  5. [PyTorch → ONNX 转换原理](#PyTorch → ONNX 转换原理)
  6. 三种格式对比总结
  7. 在本项目中的实际流转
  8. 直观类比

三种核心格式概览

在强化学习 Sim-to-Real 部署中,你会频繁接触以下三种文件格式:

格式 典型后缀 本质 用途
PyTorch Checkpoint .pt, .pth Python 字典(权重 + 训练状态) 训练保存、恢复、调试
TorchScript .pt 静态计算图 + 权重 C++ 部署(LibTorch)
ONNX .onnx 通用中间表示(跨框架计算图) 多硬件加速、跨平台推理

注意.pt 后缀既可能是 Checkpoint,也可能是 TorchScript,需要根据内容和加载方式区分。


1. PyTorch Checkpoint --- 训练产物

是什么?

训练过程中 RSL-RL 保存的 model_1000.pt(或 policy.pt)本质上是一个 Python 字典(dict) ,通过 torch.save() 序列化到磁盘:

python 复制代码
checkpoint = {
    "model_state_dict": {               # 神经网络每一层的权重张量
        "actor.0.weight": tensor([...]),
        "actor.0.bias": tensor([...]),
        "critic.0.weight": tensor([...]),
        ...
    },
    "optimizer_state_dict": { ... },    # 优化器状态(Adam 的动量等)
    "iter": 1000,                       # 当前迭代轮数
    "best_reward": 42.0,                # 最佳奖励记录
}
torch.save(checkpoint, "model_1000.pt")

怎么加载?

Checkpoint 只是数据 ,不是可执行程序。加载时必须先定义网络结构:

python 复制代码
# 1. 定义模型结构(Python 类)
model = ActorCriticNet(obs_dim=47, action_dim=12)

# 2. 加载权重字典
checkpoint = torch.load("model_1000.pt")
model.load_state_dict(checkpoint["model_state_dict"])

为什么不适合直接部署?

  • 依赖 Python 源码 :没有 ActorCriticNet 的类定义,权重无法使用
  • 包含冗余信息:优化器状态、日志等占用空间且推理不需要
  • 无法在 C++ 运行:嵌入式设备、机器人控制器通常没有 Python 环境
  • 动态图开销:每次前向传播都需要 Python 解释器参与

2. TorchScript --- PyTorch 的编译产物

TorchScript 是 PyTorch 提供的序列化方案 ,目标是把 Python 写的动态模型变成静态、可移植的计算图

两种生成方式

方式 API 原理 适用场景
Tracing torch.jit.trace(model, example_input) 喂一个示例输入,记录实际执行的所有算子,生成静态图 结构固定、无数据依赖控制流的网络
Scripting torch.jit.script(model) 直接解析 Python 代码的 AST,支持控制流 包含 if / for 等动态逻辑的模型

play.py 中的实际做法

isaaclab_rl.rsl_rl 提供的 export_policy_as_jit() 内部大致逻辑:

python 复制代码
# 构造一个带假输入的示例观测
example_obs = torch.zeros(1, obs_dim, device=device)

# 使用 Tracing 生成 TorchScript
traced_model = torch.jit.trace(model, example_obs)

# 保存为独立的可执行文件
traced_model.save("exported/policy.pt")

Tracing 过程示意

你的 Python 模型(动态执行):

python 复制代码
def forward(self, x):
    x = self.linear1(x)     # [batch, 47] @ [47, 256] → [batch, 256]
    x = self.activation(x)  # ReLU
    x = self.linear2(x)     # [batch, 256] @ [256, 12] → [batch, 12]
    return x

Trace 后变成静态计算图(一份与 Python 代码无关的执行蓝图):

复制代码
graph(%x : Float(1, 47)):
  %1 = aten::linear(%x, %weight1, %bias1)
  %2 = aten::relu(%1)
  %3 = aten::linear(%2, %weight2, %bias2)
  return (%3)

C++ 端如何加载?

cpp 复制代码
#include <torch/script.h>

// 加载 TorchScript 模型(不需要 Python 源码)
torch::jit::script::Module policy = torch::jit::load("policy.pt");

// 构造输入张量
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::zeros({1, 47}));

// 前向推理
auto output = policy.forward(inputs).toTensor();

优点与局限

优点 局限
不依赖 Python 源码,C++ 直接运行 Tracing 遇到数据依赖的控制流(如 if x.sum() > 0)会出错
权重和结构打包在一个文件 仅支持 PyTorch 生态,TensorFlow 无法加载
可用 LibTorch 在 ARM/x86 嵌入式设备部署 动态 shape 支持不如 ONNX 灵活

在本项目中,loco_client_sim 的 C++ 端正是通过 LibTorch 加载 TorchScript 来执行策略推理的。


3. ONNX --- 通用中间表示

ONNX(Open Neural Network Exchange)不是某个公司的私有格式,而是一个开放的跨框架中间标准。它的定位类似编程语言里的 LLVM IR,或文档界的 PDF。

核心思想

无论你用 PyTorchTensorFlowMXNet 还是 PaddlePaddle 训练模型,最终都可以转成 ONNX,然后:

  • NVIDIA GPU → TensorRT 加速推理
  • Intel CPU → OpenVINO 加速
  • 移动端 → ONNX RuntimeMNN
  • 甚至 FPGA / ASIC 专用芯片

ONNX 的本质

.onnx 文件是一个 Protobuf(Protocol Buffers) 序列化的文件,内部描述了一张计算图(Graph)

protobuf 复制代码
graph {
  name: "policy"
  input { name: "obs", type: { tensor_type { elem_type: FLOAT, shape: { dim { dim_value: 1 }, dim { dim_value: 47 } } } } }
  output { name: "actions", type: { ... } }

  node {
    op_type: "MatMul"           # 算子类型:矩阵乘
    input: ["obs", "weight1"]
    output: ["gemm1_out"]
  }
  node {
    op_type: "Relu"
    input: ["gemm1_out"]
    output: ["relu1_out"]
  }
  node {
    op_type: "MatMul"
    input: ["relu1_out", "weight2"]
    output: ["actions"]
  }
}

注意 :ONNX 只定义了算子标准数据流,不包含训练相关的信息(如优化器、损失函数)。它是一份"说明书":输入什么、经过哪些算子、输出什么。


PyTorch → ONNX 转换原理

调用链路

复制代码
PyTorch nn.Module
      ↓
torch.onnx.export()
      ↓
ONNX 计算图(Graph)
      ↓
保存为 .onnx 文件

torch.onnx.export() 内部做了什么?

play.pyexport_policy_as_onnx() 为例,其底层调用大致如下:

python 复制代码
torch.onnx.export(
    model,                          # 要转换的 PyTorch 模型
    example_input,                  # 示例输入(用于 trace 和推断 shape)
    "exported/policy.onnx",         # 输出文件路径
    input_names=["obs"],            # 输入节点名称
    output_names=["actions"],       # 输出节点名称
    opset_version=17,               # ONNX 算子集版本(版本越高支持算子越多)
    dynamic_axes={                  # 声明哪些维度是动态的(如 batch)
        "obs": {0: "batch_size"},
        "actions": {0: "batch_size"}
    },
)

详细转换过程

Step 1: Tracing

example_input 跑一次前向传播,PyTorch 会记录所有执行的 ATen 算子,生成一个内部的计算图:

python 复制代码
# 你的模型执行
obs → linear → relu → linear → actions

# 内部记录的 ATen 算子
aten::linear → aten::relu → aten::linear
Step 2: 算子映射(ATen → ONNX)

PyTorch 维护了一张巨大的映射表,将 ATen 算子翻译为 ONNX 标准算子:

PyTorch ATen 算子 ONNX 标准算子 说明
aten::relu Relu 激活函数
aten::linear Gemm 通用矩阵乘(General Matrix Multiply)
aten::conv2d Conv 卷积
aten::batch_norm BatchNormalization 批归一化
aten::lstm LSTM 长短期记忆网络
aten::cat Concat 张量拼接
aten::transpose Transpose 维度转置

如果遇到 ONNX 不支持的算子,导出会报错,需要自定义算子或改写模型。

Step 3: 图优化

导出器会进行一系列优化:

  • 算子融合 :相邻的 Conv + ReLU 合并为一个节点
  • 死节点消除:删除对输出无贡献的中间计算
  • 常量折叠:将编译期可确定的常量表达式直接计算结果
Step 4: Protobuf 序列化

最终,用 Google Protobuf 将图结构、权重张量、元数据写入 .onnx 文件。

归一化器(Normalizer)的处理 ------ 关键细节

训练时观测通常做了动态归一化:

python 复制代码
obs_normalized = (obs - mean) / sqrt(var + eps)

如果导出的模型只包含神经网络本身,C++ 端还需要额外手动实现归一化,容易因参数不一致导致部署行为异常。

Isaac Lab 的 clever 做法 :在导出前,把 normalizer 也包装进模型内部

python 复制代码
class PolicyWithNormalizer(nn.Module):
    def forward(self, obs):
        # 第一步:归一化(使用训练时统计的 mean/var)
        obs = (obs - self.mean) / torch.sqrt(self.var + 1e-8)
        # 第二步:送网络推理
        return self.policy_net(obs)

这样导出的 ONNX / TorchScript 是端到端的:

复制代码
原始 obs  →  [Normalizer]  →  [Neural Net]  →  actions

C++ 端只需直接喂原始观测,无需关心归一化逻辑。


三种格式对比总结

特性 PyTorch Checkpoint (.pth) TorchScript (.pt) ONNX (.onnx)
本质 权重字典 + 训练状态 静态计算图 + 权重 通用计算图 + 权重
运行时依赖 Python + PyTorch 源码 LibTorch(C++) ONNX Runtime / TensorRT / OpenVINO
跨框架 ❌ 仅限 PyTorch ❌ 仅限 PyTorch ✅ PyTorch/TensorFlow/MXNet 互通
部署场景 训练、调试、断点续训 C++ 嵌入式、机器人控制器 多硬件加速、云端推理、异构计算
控制流支持 Python 原生 Tracing 有限制 有限支持(Loop/If 算子)
动态 Shape 完全支持 部分支持 良好(需声明 dynamic_axes
本项目用途 训练保存的中间产物 loco_client_sim C++ 端加载 备选方案 / 其他推理引擎

在本项目中的实际流转

以 G1 强化学习策略从训练到部署为例:

复制代码
┌─────────────────────────────────────────────────────────────┐
│  1. Isaac Lab 训练(GPU 服务器)                               │
│     RSL-RL OnPolicyRunner 训练策略                           │
│                      │                                      │
│                      ▼                                      │
│  logs/rsl_rl/g1_flat/.../checkpoints/model_1000.pt         │
│     (PyTorch Checkpoint:含权重 + 训练状态)                  │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│  2. 运行 play.py 导出(推理 + 转换)                          │
│                                                              │
│     python scripts/rsl_rl/play.py --task=...               │
│                                                              │
│     ├──► exported/policy.pt   (TorchScript)                │
│     │           │                                            │
│     │           ▼                                            │
│     │    control_simulator / loco_client_sim (C++ 端)       │
│     │           │                                            │
│     │           ▼                                            │
│     │    torch::jit::load("policy.pt")                     │
│     │    MuJoCo 仿真 或 真机部署 (NVIDIA Orin)               │
│     │                                                        │
│     └──► exported/policy.onnx (ONNX)                       │
│                 │                                            │
│                 ▼                                            │
│         TensorRT / ONNX Runtime(可选加速方案)                │
└─────────────────────────────────────────────────────────────┘

关键节点说明

  1. Checkpoint 是起点:训练完成后,checkpoint 保存了策略的所有权重。
  2. play.py 是转换器:运行推理时,脚本自动调用导出函数,生成 TorchScript 和 ONNX。
  3. TorchScript 是主线loco_client_sim 的 C++ 代码通过 LibTorch 加载 policy.pt,在 MuJoCo 仿真或真机上实时推理。
  4. ONNX 是备选:如果需要使用 TensorRT 做 GPU 推理加速,或在非 PyTorch 生态中运行,ONNX 是桥梁。

直观类比

概念 类比 说明
PyTorch Checkpoint (.pth) 建筑设计草图 + 材料清单 只有原设计师拿着草图和清单,才能复原整栋建筑
TorchScript (.pt) 完整的建筑施工图 任何按图施工的工人(LibTorch)都能建造,不需要原设计师在场
ONNX (.onnx) 国际通用的建筑标准图纸 中国的施工队、德国的工程师、日本的机器人都能看懂并施工

另一个角度:

概念 类比
Checkpoint Python 源代码 + 数据文件 ------ 需要解释器
TorchScript 编译后的二进制 ------ 直接运行,但只能在 PyTorch 虚拟机里跑
ONNX JVM 字节码或 WebAssembly ------ 只要有对应的运行时,哪里都能跑

延伸阅读

相关推荐
AIGS0011 小时前
生产运营三大瓶颈,工业AI怎么破局?
java·人工智能·人工智能ai大模型应用
TMT星球1 小时前
星源智剑指物理AI的“智谱”,智源研究院持续助力世界模型
人工智能
运维行者_1 小时前
通过Applications Manager的TCP监控确保无缝网络连接
运维·服务器·网络·数据库·人工智能
小鹿软件办公1 小时前
微软发布旗下首款具备思考能力的 MAI-Thinking-1 AI 模型
人工智能·microsoft·语音识别
段一凡-华北理工大学1 小时前
工业领域的Hadoop架构学习~系列文章12:Hadoop集群监控与运维
大数据·人工智能·hadoop·学习·架构·高炉炼铁·高炉炼铁智能化
澜舟孟子开源社区1 小时前
澜舟观点:OpenClaw引爆OPC浪潮,亟需理性引导与科技赋能
人工智能·科技
无心水1 小时前
【Harness:落地实战】23、从CI/CD到AI原生底座:Harness平台全景深度解析——现代软件交付的最终答案?
人工智能·ci/cd·ai-native·openclaw·harness·hermes·honcho
sali-tec1 小时前
C# 基于OpenCv的视觉工作流-章82-毛刺检测
图像处理·人工智能·opencv·算法·计算机视觉
我爱cope1 小时前
【Agent智能体18 | 构建AI工作流的技巧-评估】
人工智能·语言模型·职场和发展