从训练到部署:模型格式与转换原理
本文档介绍 Isaac Lab / RSL-RL 训练完成后,模型从 PyTorch 训练产物转换为可部署格式的完整原理与流程。理解这些概念是完成 Sim-to-Real 部署的必备基础。
目录
- 三种核心格式概览
- [PyTorch Checkpoint --- 训练产物](#PyTorch Checkpoint — 训练产物)
- [TorchScript --- PyTorch 的编译产物](#TorchScript — PyTorch 的编译产物)
- [ONNX --- 通用中间表示](#ONNX — 通用中间表示)
- [PyTorch → ONNX 转换原理](#PyTorch → ONNX 转换原理)
- 三种格式对比总结
- 在本项目中的实际流转
- 直观类比
三种核心格式概览
在强化学习 Sim-to-Real 部署中,你会频繁接触以下三种文件格式:
| 格式 | 典型后缀 | 本质 | 用途 |
|---|---|---|---|
| PyTorch Checkpoint | .pt, .pth |
Python 字典(权重 + 训练状态) | 训练保存、恢复、调试 |
| TorchScript | .pt |
静态计算图 + 权重 | C++ 部署(LibTorch) |
| ONNX | .onnx |
通用中间表示(跨框架计算图) | 多硬件加速、跨平台推理 |
注意 :
.pt后缀既可能是 Checkpoint,也可能是 TorchScript,需要根据内容和加载方式区分。
1. PyTorch Checkpoint --- 训练产物
是什么?
训练过程中 RSL-RL 保存的 model_1000.pt(或 policy.pt)本质上是一个 Python 字典(dict) ,通过 torch.save() 序列化到磁盘:
python
checkpoint = {
"model_state_dict": { # 神经网络每一层的权重张量
"actor.0.weight": tensor([...]),
"actor.0.bias": tensor([...]),
"critic.0.weight": tensor([...]),
...
},
"optimizer_state_dict": { ... }, # 优化器状态(Adam 的动量等)
"iter": 1000, # 当前迭代轮数
"best_reward": 42.0, # 最佳奖励记录
}
torch.save(checkpoint, "model_1000.pt")
怎么加载?
Checkpoint 只是数据 ,不是可执行程序。加载时必须先定义网络结构:
python
# 1. 定义模型结构(Python 类)
model = ActorCriticNet(obs_dim=47, action_dim=12)
# 2. 加载权重字典
checkpoint = torch.load("model_1000.pt")
model.load_state_dict(checkpoint["model_state_dict"])
为什么不适合直接部署?
- 依赖 Python 源码 :没有
ActorCriticNet的类定义,权重无法使用 - 包含冗余信息:优化器状态、日志等占用空间且推理不需要
- 无法在 C++ 运行:嵌入式设备、机器人控制器通常没有 Python 环境
- 动态图开销:每次前向传播都需要 Python 解释器参与
2. TorchScript --- PyTorch 的编译产物
TorchScript 是 PyTorch 提供的序列化方案 ,目标是把 Python 写的动态模型变成静态、可移植的计算图。
两种生成方式
| 方式 | API | 原理 | 适用场景 |
|---|---|---|---|
| Tracing | torch.jit.trace(model, example_input) |
喂一个示例输入,记录实际执行的所有算子,生成静态图 | 结构固定、无数据依赖控制流的网络 |
| Scripting | torch.jit.script(model) |
直接解析 Python 代码的 AST,支持控制流 | 包含 if / for 等动态逻辑的模型 |
play.py 中的实际做法
isaaclab_rl.rsl_rl 提供的 export_policy_as_jit() 内部大致逻辑:
python
# 构造一个带假输入的示例观测
example_obs = torch.zeros(1, obs_dim, device=device)
# 使用 Tracing 生成 TorchScript
traced_model = torch.jit.trace(model, example_obs)
# 保存为独立的可执行文件
traced_model.save("exported/policy.pt")
Tracing 过程示意
你的 Python 模型(动态执行):
python
def forward(self, x):
x = self.linear1(x) # [batch, 47] @ [47, 256] → [batch, 256]
x = self.activation(x) # ReLU
x = self.linear2(x) # [batch, 256] @ [256, 12] → [batch, 12]
return x
Trace 后变成静态计算图(一份与 Python 代码无关的执行蓝图):
graph(%x : Float(1, 47)):
%1 = aten::linear(%x, %weight1, %bias1)
%2 = aten::relu(%1)
%3 = aten::linear(%2, %weight2, %bias2)
return (%3)
C++ 端如何加载?
cpp
#include <torch/script.h>
// 加载 TorchScript 模型(不需要 Python 源码)
torch::jit::script::Module policy = torch::jit::load("policy.pt");
// 构造输入张量
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::zeros({1, 47}));
// 前向推理
auto output = policy.forward(inputs).toTensor();
优点与局限
| 优点 | 局限 |
|---|---|
| 不依赖 Python 源码,C++ 直接运行 | Tracing 遇到数据依赖的控制流(如 if x.sum() > 0)会出错 |
| 权重和结构打包在一个文件 | 仅支持 PyTorch 生态,TensorFlow 无法加载 |
| 可用 LibTorch 在 ARM/x86 嵌入式设备部署 | 动态 shape 支持不如 ONNX 灵活 |
在本项目中,loco_client_sim 的 C++ 端正是通过 LibTorch 加载 TorchScript 来执行策略推理的。
3. ONNX --- 通用中间表示
ONNX(Open Neural Network Exchange)不是某个公司的私有格式,而是一个开放的跨框架中间标准。它的定位类似编程语言里的 LLVM IR,或文档界的 PDF。
核心思想
无论你用 PyTorch 、TensorFlow 、MXNet 还是 PaddlePaddle 训练模型,最终都可以转成 ONNX,然后:
- NVIDIA GPU → TensorRT 加速推理
- Intel CPU → OpenVINO 加速
- 移动端 → ONNX Runtime 或 MNN
- 甚至 FPGA / ASIC 专用芯片
ONNX 的本质
.onnx 文件是一个 Protobuf(Protocol Buffers) 序列化的文件,内部描述了一张计算图(Graph):
protobuf
graph {
name: "policy"
input { name: "obs", type: { tensor_type { elem_type: FLOAT, shape: { dim { dim_value: 1 }, dim { dim_value: 47 } } } } }
output { name: "actions", type: { ... } }
node {
op_type: "MatMul" # 算子类型:矩阵乘
input: ["obs", "weight1"]
output: ["gemm1_out"]
}
node {
op_type: "Relu"
input: ["gemm1_out"]
output: ["relu1_out"]
}
node {
op_type: "MatMul"
input: ["relu1_out", "weight2"]
output: ["actions"]
}
}
注意 :ONNX 只定义了算子标准 和数据流,不包含训练相关的信息(如优化器、损失函数)。它是一份"说明书":输入什么、经过哪些算子、输出什么。
PyTorch → ONNX 转换原理
调用链路
PyTorch nn.Module
↓
torch.onnx.export()
↓
ONNX 计算图(Graph)
↓
保存为 .onnx 文件
torch.onnx.export() 内部做了什么?
以 play.py 中 export_policy_as_onnx() 为例,其底层调用大致如下:
python
torch.onnx.export(
model, # 要转换的 PyTorch 模型
example_input, # 示例输入(用于 trace 和推断 shape)
"exported/policy.onnx", # 输出文件路径
input_names=["obs"], # 输入节点名称
output_names=["actions"], # 输出节点名称
opset_version=17, # ONNX 算子集版本(版本越高支持算子越多)
dynamic_axes={ # 声明哪些维度是动态的(如 batch)
"obs": {0: "batch_size"},
"actions": {0: "batch_size"}
},
)
详细转换过程
Step 1: Tracing
用 example_input 跑一次前向传播,PyTorch 会记录所有执行的 ATen 算子,生成一个内部的计算图:
python
# 你的模型执行
obs → linear → relu → linear → actions
# 内部记录的 ATen 算子
aten::linear → aten::relu → aten::linear
Step 2: 算子映射(ATen → ONNX)
PyTorch 维护了一张巨大的映射表,将 ATen 算子翻译为 ONNX 标准算子:
| PyTorch ATen 算子 | ONNX 标准算子 | 说明 |
|---|---|---|
aten::relu |
Relu |
激活函数 |
aten::linear |
Gemm |
通用矩阵乘(General Matrix Multiply) |
aten::conv2d |
Conv |
卷积 |
aten::batch_norm |
BatchNormalization |
批归一化 |
aten::lstm |
LSTM |
长短期记忆网络 |
aten::cat |
Concat |
张量拼接 |
aten::transpose |
Transpose |
维度转置 |
如果遇到 ONNX 不支持的算子,导出会报错,需要自定义算子或改写模型。
Step 3: 图优化
导出器会进行一系列优化:
- 算子融合 :相邻的
Conv + ReLU合并为一个节点 - 死节点消除:删除对输出无贡献的中间计算
- 常量折叠:将编译期可确定的常量表达式直接计算结果
Step 4: Protobuf 序列化
最终,用 Google Protobuf 将图结构、权重张量、元数据写入 .onnx 文件。
归一化器(Normalizer)的处理 ------ 关键细节
训练时观测通常做了动态归一化:
python
obs_normalized = (obs - mean) / sqrt(var + eps)
如果导出的模型只包含神经网络本身,C++ 端还需要额外手动实现归一化,容易因参数不一致导致部署行为异常。
Isaac Lab 的 clever 做法 :在导出前,把 normalizer 也包装进模型内部:
python
class PolicyWithNormalizer(nn.Module):
def forward(self, obs):
# 第一步:归一化(使用训练时统计的 mean/var)
obs = (obs - self.mean) / torch.sqrt(self.var + 1e-8)
# 第二步:送网络推理
return self.policy_net(obs)
这样导出的 ONNX / TorchScript 是端到端的:
原始 obs → [Normalizer] → [Neural Net] → actions
C++ 端只需直接喂原始观测,无需关心归一化逻辑。
三种格式对比总结
| 特性 | PyTorch Checkpoint (.pth) | TorchScript (.pt) | ONNX (.onnx) |
|---|---|---|---|
| 本质 | 权重字典 + 训练状态 | 静态计算图 + 权重 | 通用计算图 + 权重 |
| 运行时依赖 | Python + PyTorch 源码 | LibTorch(C++) | ONNX Runtime / TensorRT / OpenVINO |
| 跨框架 | ❌ 仅限 PyTorch | ❌ 仅限 PyTorch | ✅ PyTorch/TensorFlow/MXNet 互通 |
| 部署场景 | 训练、调试、断点续训 | C++ 嵌入式、机器人控制器 | 多硬件加速、云端推理、异构计算 |
| 控制流支持 | Python 原生 | Tracing 有限制 | 有限支持(Loop/If 算子) |
| 动态 Shape | 完全支持 | 部分支持 | 良好(需声明 dynamic_axes) |
| 本项目用途 | 训练保存的中间产物 | loco_client_sim C++ 端加载 |
备选方案 / 其他推理引擎 |
在本项目中的实际流转
以 G1 强化学习策略从训练到部署为例:
┌─────────────────────────────────────────────────────────────┐
│ 1. Isaac Lab 训练(GPU 服务器) │
│ RSL-RL OnPolicyRunner 训练策略 │
│ │ │
│ ▼ │
│ logs/rsl_rl/g1_flat/.../checkpoints/model_1000.pt │
│ (PyTorch Checkpoint:含权重 + 训练状态) │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ 2. 运行 play.py 导出(推理 + 转换) │
│ │
│ python scripts/rsl_rl/play.py --task=... │
│ │
│ ├──► exported/policy.pt (TorchScript) │
│ │ │ │
│ │ ▼ │
│ │ control_simulator / loco_client_sim (C++ 端) │
│ │ │ │
│ │ ▼ │
│ │ torch::jit::load("policy.pt") │
│ │ MuJoCo 仿真 或 真机部署 (NVIDIA Orin) │
│ │ │
│ └──► exported/policy.onnx (ONNX) │
│ │ │
│ ▼ │
│ TensorRT / ONNX Runtime(可选加速方案) │
└─────────────────────────────────────────────────────────────┘
关键节点说明
- Checkpoint 是起点:训练完成后,checkpoint 保存了策略的所有权重。
- play.py 是转换器:运行推理时,脚本自动调用导出函数,生成 TorchScript 和 ONNX。
- TorchScript 是主线 :
loco_client_sim的 C++ 代码通过 LibTorch 加载policy.pt,在 MuJoCo 仿真或真机上实时推理。 - ONNX 是备选:如果需要使用 TensorRT 做 GPU 推理加速,或在非 PyTorch 生态中运行,ONNX 是桥梁。
直观类比
| 概念 | 类比 | 说明 |
|---|---|---|
| PyTorch Checkpoint (.pth) | 建筑设计草图 + 材料清单 | 只有原设计师拿着草图和清单,才能复原整栋建筑 |
| TorchScript (.pt) | 完整的建筑施工图 | 任何按图施工的工人(LibTorch)都能建造,不需要原设计师在场 |
| ONNX (.onnx) | 国际通用的建筑标准图纸 | 中国的施工队、德国的工程师、日本的机器人都能看懂并施工 |
另一个角度:
| 概念 | 类比 |
|---|---|
| Checkpoint | Python 源代码 + 数据文件 ------ 需要解释器 |
| TorchScript | 编译后的二进制 ------ 直接运行,但只能在 PyTorch 虚拟机里跑 |
| ONNX | JVM 字节码或 WebAssembly ------ 只要有对应的运行时,哪里都能跑 |
延伸阅读
- PyTorch TorchScript 官方文档
- ONNX 官方规范
- TensorRT 开发者指南
- 本项目相关代码:
isaaclab_rl/rsl_rl/exporter.py------export_policy_as_jit()/export_policy_as_onnx()实现loco_client_sim/src/loco_client_sim.cpp------ C++ 端 LibTorch 加载逻辑extern/unitree_rl_lab/scripts/rsl_rl/play.py------ 推理与导出入口