1. 文档目标
本文面向 full_onnx 全量模型导出链,统一说明三部分内容:
full_onnx的基本导出流程- 为了让导出的自定义节点可执行,构建算子所必需的成员函数和实现方法
- 本仓库 PointNet2 六个自定义算子的具体实现细节
对应代码入口:
- 全量导出脚本:
deploy/scripts/export_full_onnx.py - 全量导出 wrapper:
deploy/model_wrappers.py - PointNet2 symbolic 注册:
deploy/scripts/pointnet2_custom_ops.py - ORT CPU custom ops:
deploy/ort_custom_ops/src/pvn3d_pointnet2_ops.cc - ORT GPU custom ops:
deploy/ort_custom_ops_gpu/src/pvn3d_pointnet2_ops_cuda.cc
本文关注的是下面这条完整链:
text
PyTorch PVN3D
-> FullModelExportWrapper
-> torch.onnx.export
-> full_onnx(custom nodes)
-> ORT custom op library
-> ORT 执行完整图
2. full_onnx 的基本流程
2.1 什么是 full_onnx
这里的 full_onnx 不是只导出 RGB backbone 或 fusion head 的子图,而是导出单个完整 PVN3D 图:
- 输入:
pointcloud,rgb,choose - 输出:
pred_kp_of,pred_rgbd_seg,pred_ctr_of
这条图里既包含标准 ONNX 节点,也包含 PointNet2 的自定义节点。
在仓库中,对应产物通常是:
deploy/models/onnx_ape/pvn3d_full.onnx
2.2 full_onnx 导出时解决的核心问题
全量导出与拆分导出的区别在于:PointNet2 这一路不能直接变成标准 ONNX 算子。
所以 full_onnx 的关键不是"把整网直接导出",而是:
- 用 wrapper 固定完整模型入口
- 给 PointNet2 扩展算子注册 symbolic
- 让导出器把这些算子写成
ai.onnx.contrib域下的 custom nodes - 导出后再用 ORT custom ops 去补执行能力
如果没有第 2 步和第 3 步,导出器会卡在 PyTorch 扩展算子上。
2.3 full_onnx 导出链总览
下面这份 dot 可以直接渲染:
dot
digraph FullOnnxExportFlow {
rankdir=LR;
node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];
model [label="PVN3D PyTorch Model"];
wrapper [label="FullModelExportWrapper\n(pointcloud, rgb, choose)"];
symbolic [label="Register PointNet2 symbolic\nai.onnx.contrib::*"];
export [label="torch.onnx.export"];
onnx [label="pvn3d_full.onnx\nstandard nodes + custom nodes"];
check [label="onnx.checker\nshape_inference"];
runtime [label="ORT + custom ops library"];
outputs [label="pred_kp_of\npred_rgbd_seg\npred_ctr_of"];
model -> wrapper -> export -> onnx -> check -> runtime -> outputs;
symbolic -> export;
}
2.4 代码级 full_onnx 导出流程
deploy/scripts/export_full_onnx.py 中的主流程可以概括为:
- 解析导出参数
- 校验 opset 是否受当前 PyTorch 支持
- 生成 manifest
- 注册 PointNet2 symbolic
- 构建完整模型 wrapper
- 构造示例输入
- 调用
torch.onnx.export - 对导出结果执行
onnx.checker和shape_inference - 输出 report / node list / inferred model
2.4.1 参数解析
脚本接受以下关键参数:
--checkpoint--output--num-classes--num-points--height--width--batch-size--opset--device--manifest--report--node-list--inferred-output
这些参数共同决定:
- 模型权重
- 导出输入规格
- ONNX opset
- 导出产物路径
2.4.2 构建完整模型 wrapper
deploy/model_wrappers.py 中的 FullModelExportWrapper 非常简单:
python
class FullModelExportWrapper(nn.Module):
def __init__(self, full_model):
super().__init__()
self.full_model = full_model
def forward(self, pointcloud, rgb, choose):
return self.full_model(pointcloud, rgb, choose)
它的意义不是改模型,而是为导出器固定一个清晰、稳定的导出入口。
2.4.3 构造示例输入
create_example_inputs() 里会生成:
python
rgb: (B, 3, H, W)
pointcloud: (B, N, 9)
choose: (B, 1, N)
其中:
pointcloud的9表示xyz + featurechoose用于从 RGB feature map 中按点选取像素位置对应特征
2.4.4 注册 PointNet2 symbolic
这是 full_onnx 导出的核心步骤。
deploy/scripts/pointnet2_custom_ops.py 中,脚本会把以下 PyTorch 扩展算子绑定到自定义 symbolic:
FurthestPointSamplingGatherOperationBallQueryGroupingOperationThreeNNThreeInterpolate
导出时不再尝试把它们翻译成标准 ONNX,而是直接写成:
text
ai.onnx.contrib::PVN3D_FurthestPointSample
ai.onnx.contrib::PVN3D_GatherPoints
ai.onnx.contrib::PVN3D_BallQuery
ai.onnx.contrib::PVN3D_GroupPoints
ai.onnx.contrib::PVN3D_ThreeNN
ai.onnx.contrib::PVN3D_ThreeInterpolate
2.4.5 调用 torch.onnx.export
导出调用形态是:
python
torch.onnx.export(
wrapper,
(inputs["pointcloud"], inputs["rgb"], inputs["choose"]),
str(output_path),
input_names=["pointcloud", "rgb", "choose"],
output_names=["pred_kp_of", "pred_rgbd_seg", "pred_ctr_of"],
opset_version=args.opset,
do_constant_folding=True,
custom_opsets={CUSTOM_DOMAIN: CUSTOM_OPSET_VERSION},
)
这里有两个关键点:
custom_opsets={CUSTOM_DOMAIN: CUSTOM_OPSET_VERSION}- 输入输出名字被显式固定
第一点保证 custom domain 能写进 ONNX 模型,第二点便于后续调试和 ORT 侧联调。
2.4.6 导出后检查
导出成功后,脚本会继续执行:
onnx.checker.check_model(model)shape_inference.infer_shapes(model)
并输出:
*.inferred.onnx*.nodes.json*.report.json
这些文件解决的是"图是否能被解析"和"图里到底有哪些 custom nodes"。
3. full_onnx 导出与 custom op 的关系
3.1 symbolic 只负责"落图",不负责"执行"
这是 full_onnx 流程里最容易混淆的地方。
在导出阶段:
symbolic负责把 PyTorch 扩展算子写成 ONNX 节点
在运行阶段:
- ORT custom op
.so负责真正执行这些节点
所以这两部分职责分别是:
- 导出端:定义节点长什么样
- 运行端:定义节点怎么算
3.2 Graphviz:导出端和运行端的衔接
dot
digraph ExportVsRuntime {
rankdir=LR;
node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];
pytorch [label="PyTorch ext op"];
symbolic [label="symbolic()\nmap to custom ONNX node"];
onnx [label="full ONNX graph"];
session [label="ORT Session"];
reg [label="RegisterCustomOps()"];
kernel [label="Kernel::Compute()"];
pytorch -> symbolic -> onnx -> session -> kernel;
reg -> session;
}
4. 构建 ORT 算子必要的成员函数和实现方法
full_onnx 导出完成后,要让 ORT 真正运行这些节点,就需要在 deploy/ort_custom_ops 或 deploy/ort_custom_ops_gpu 中实现 custom ops。
本仓库采用的结构是:
Kernel:负责计算Op:负责向 ORT 描述接口
4.1 CreateKernel
作用:
- 为图中的某个节点实例创建一个 kernel 对象
典型写法:
cpp
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new SomeKernel(api, info);
}
如果某个算子带属性,例如:
npointradiusnsample
那么通常在 SomeKernel 构造函数中读取。
4.2 GetName
作用:
- 返回 ONNX 节点的
op_type
例如:
cpp
const char* GetName() const { return "PVN3D_ThreeNN"; }
要求:
- 必须和导出阶段 symbolic 写入的节点名字完全一致
4.3 GetInputTypeCount
作用:
- 告诉 ORT 该算子有多少个输入
例如:
GatherPoints是 2 个输入ThreeInterpolate是 3 个输入
4.4 GetInputType
作用:
- 指定每个输入张量的 dtype
例如:
cpp
ONNXTensorElementDataType GetInputType(size_t index) const {
return index == 1 ? ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
: ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}
4.5 GetOutputTypeCount
作用:
- 告诉 ORT 该算子有多少个输出
例如:
ThreeNN有 2 个输出- 其他大多数只有 1 个输出
4.6 GetOutputType
作用:
- 指定每个输出的 dtype
例如 ThreeNN:
- 输出 0:
float - 输出 1:
int64
4.7 GetExecutionProviderType
作用:
- 绑定执行 provider
在本仓库:
- CPU 版通常不实现该函数
- GPU 版返回
"CUDAExecutionProvider"
这样 ORT 才会把节点分派给 CUDA custom op。
4.8 Kernel 构造函数
作用:
- 从 ONNX 节点属性中读取参数
例如:
cpp
explicit BallQueryKernel(const OrtApi& api, const OrtKernelInfo* info) {
Ort::ConstKernelInfo kernel_info(info);
radius_ = kernel_info.GetAttribute<float>("radius");
nsample_ = kernel_info.GetAttribute<int64_t>("nsample");
}
full_onnx 导出阶段写入的属性,就是在这里被读回来的。
4.9 Compute
这是单个 custom op 的真正执行入口。
典型步骤:
Ort::KernelContext ctx(context)GetInput()- 解析 shape
- 检查 rank / last dim
GetOutput()分配输出- 获取输入输出数据指针
- 执行 CPU 循环或调用 CUDA kernel
4.10 RegisterCustomOps
这是整个 .so 暴露给 ORT 的注册入口。
本仓库逻辑是:
Ort::InitApi- 构造
Ort::CustomOpDomain("ai.onnx.contrib") - 将六个 PointNet2 op 加入 domain
session_options.Add(domain)
如果没有这一步,full_onnx 图里的 custom nodes 仍然无法执行。
5. full_onnx 导出阶段的 symbolic 实现方法
在这条链里,除了 ORT custom op 成员函数,导出端的 symbolic 也很关键。
5.1 symbolic 的职责
symbolic 做的是:
- 把 PyTorch 运算映射成 ONNX 节点
- 指定 ONNX 节点的名字、输入顺序、属性、输出数量
它不做实际数值计算。
5.2 symbolic 的典型写法
例如 furthest_point_sample:
python
def fps_symbolic(g, xyz, npoint):
return g.op(
f"{CUSTOM_DOMAIN}::PVN3D_FurthestPointSample",
xyz,
npoint_i=parse_int(npoint),
)
例如 three_nn:
python
def three_nn_symbolic(g, unknown, known):
dist2, idx = g.op(
f"{CUSTOM_DOMAIN}::PVN3D_ThreeNN",
unknown,
known,
outputs=2,
)
return g.op("Sqrt", dist2), idx
这里有个重要细节:
- ORT custom op
PVN3D_ThreeNN输出的是dist2 - 但 PyTorch 原始接口返回的是
dist - 所以 symbolic 里补了一个
Sqrt
这说明:
- symbolic 负责让导出的 ONNX 语义与原始 PyTorch 前向保持一致
6. 每个算子的实现细节
下面以 full_onnx 图中实际出现的六个 custom ops 为主线说明。
6.1 PVN3D_FurthestPointSample
6.1.1 在 full_onnx 中的位置
它出现在 SA 路径中,用于从输入点集或中间点集里选择中心点。
导出端 symbolic:
python
g.op("ai.onnx.contrib::PVN3D_FurthestPointSample", xyz, npoint_i=...)
运行端输入输出:
- 输入:
xyz (B, N, 3) - 输出:
idx (B, npoint)
6.1.2 核心逻辑
最远点采样的思想是:
- 每次选择"距离已选集合最远"的点
CPU 版实现中:
- 维护每个点到已选点集合的最小距离
- 第一个点固定为
0 - 逐轮更新最优候选
6.1.3 工程细节
- 属性
npoint在 kernel 构造函数中读取 - 输出类型是
int64 - 近零点会被跳过,避免无效点参与采样
6.2 PVN3D_GatherPoints
6.2.1 在 full_onnx 中的位置
它通常紧跟在 FPS 之后,用于按采样索引取出中心点坐标或特征。
输入输出:
- 输入:
points (B, C, N)idx (B, M)
- 输出:
(B, C, M)
6.2.2 核心逻辑
本质上是按索引做列采样:
python
out[b, c, j] = points[b, c, idx[b, j]]
6.2.3 工程细节
idx类型是int64- 会做越界检查
- 它是很多 PointNet2 高层操作的基础子算子
6.3 PVN3D_BallQuery
6.3.1 在 full_onnx 中的位置
它出现在 SA 模块里,在选出中心点后,为每个中心点建立局部球邻域。
输入输出:
- 输入:
new_xyz (B, npoint, 3)xyz (B, N, 3)
- 属性:
radiusnsample
- 输出:
idx (B, npoint, nsample)
6.3.2 核心逻辑
逻辑是:
- 对每个中心点扫描全部原始点
- 找出
distance^2 < radius^2的点 - 最多保留
nsample个
6.3.3 工程细节
- 邻域不足时,用首个命中索引补齐
- 若完全没有邻域,保持初始化
0 - 它只输出邻居索引,不输出特征
6.4 PVN3D_GroupPoints
6.4.1 在 full_onnx 中的位置
它一般接在 BallQuery 后面,把邻居索引转换成真正的局部块特征张量。
输入输出:
- 输入:
points (B, C, N)idx (B, npoint, nsample)
- 输出:
(B, C, npoint, nsample)
6.4.2 核心逻辑
本质上是高维 gather:
python
out[b, c, i, j] = points[b, c, idx[b, i, j]]
6.4.3 工程细节
- 用于生成局部邻域张量
- 后续会与相对坐标、MLP、pooling 结合形成 SA 编码
6.5 PVN3D_ThreeNN
6.5.1 在 full_onnx 中的位置
它出现在 FP 路径中,用于高分辨率点向低分辨率点查找 3 个最近邻。
输入输出:
- 输入:
unknown (B, n, 3)known (B, m, 3)
- 输出:
dist2 (B, n, 3)idx (B, n, 3)
6.5.2 核心逻辑
对每个 unknown 点,扫描 known 全集,维护最小的三个平方距离和对应索引。
6.5.3 full_onnx 里的一个关键语义
这里要特别强调:
- ORT custom op 输出的是
dist2 - symbolic 在导出时把它接了一个
Sqrt
所以 full_onnx 中的外部可见语义仍然是:
distidx
而底层 custom node 本身仍然只负责 dist2
6.6 PVN3D_ThreeInterpolate
6.6.1 在 full_onnx 中的位置
它是 FP 路径中的插值核心,用于将低分辨率特征传播回高分辨率点集。
输入输出:
- 输入:
points (B, C, M)idx (B, N, 3)weight (B, N, 3)
- 输出:
(B, C, N)
6.6.2 核心逻辑
本质上是三项加权和:
python
out[b, c, j] =
points[b, c, idx[b, j, 0]] * weight[b, j, 0] +
points[b, c, idx[b, j, 1]] * weight[b, j, 1] +
points[b, c, idx[b, j, 2]] * weight[b, j, 2]
6.6.3 工程细节
- 它不负责生成
weight weight一般来自ThreeNN后的距离倒数归一化链- 这一步是 FP 恢复逐点分辨率的关键
7. 从 full_onnx 到 ORT 执行的闭环
把前面的内容串起来,完整闭环是:
FullModelExportWrapper固定完整模型导出入口register_pointnet2_symbolics()把 PyTorch 扩展算子映射成ai.onnx.contribcustom nodestorch.onnx.export生成pvn3d_full.onnxonnx.checker/shape_inference验证导出结果- ORT 加载 custom ops
.so RegisterCustomOps()注册同名算子- ORT 根据节点名称、domain、provider 调用相应 kernel
Compute()执行每个 PointNet2 基础算子
下面这份 dot 是完整闭环图:
dot
digraph FullClosure {
rankdir=LR;
node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];
wrap [label="FullModelExportWrapper"];
sym [label="PointNet2 symbolic"];
export [label="torch.onnx.export"];
fullonnx [label="pvn3d_full.onnx"];
lib [label="ORT custom ops .so"];
reg [label="RegisterCustomOps()"];
dispatch [label="ORT node dispatch"];
compute [label="Kernel::Compute()"];
pred [label="pred_kp_of / pred_rgbd_seg / pred_ctr_of"];
wrap -> export -> fullonnx -> dispatch -> compute -> pred;
sym -> export;
lib -> reg -> dispatch;
}
8. 结论
full_onnx 导出流程的本质不是"把 PVN3D 直接另存成 ONNX",而是:
- 用 wrapper 固定整网导出入口
- 用 symbolic 把 PointNet2 扩展算子保留成 custom nodes
- 用 ORT custom ops 在运行时补上这些节点的执行能力
因此这条链实际上分成两半:
- 导出半链:
wrapper + symbolic + torch.onnx.export - 运行半链:
RegisterCustomOps + Kernel::Compute
六个 PointNet2 基础算子正是这条链的桥梁:
- 在导出阶段,它们以
ai.onnx.contrib::PVN3D_*节点落入 full_onnx - 在运行阶段,它们由 ORT CPU/GPU custom ops 实现真正计算
这也是为什么 full_onnx、symbolic、ORT custom ops 必须放在同一条工程链上理解,而不能分开看。