PVN3D Full ONNX 导出与自定义算子说明

1. 文档目标

本文面向 full_onnx 全量模型导出链,统一说明三部分内容:

  1. full_onnx 的基本导出流程
  2. 为了让导出的自定义节点可执行,构建算子所必需的成员函数和实现方法
  3. 本仓库 PointNet2 六个自定义算子的具体实现细节

对应代码入口:

  • 全量导出脚本:deploy/scripts/export_full_onnx.py
  • 全量导出 wrapper:deploy/model_wrappers.py
  • PointNet2 symbolic 注册:deploy/scripts/pointnet2_custom_ops.py
  • ORT CPU custom ops:deploy/ort_custom_ops/src/pvn3d_pointnet2_ops.cc
  • ORT GPU custom ops:deploy/ort_custom_ops_gpu/src/pvn3d_pointnet2_ops_cuda.cc

本文关注的是下面这条完整链:

text 复制代码
PyTorch PVN3D
-> FullModelExportWrapper
-> torch.onnx.export
-> full_onnx(custom nodes)
-> ORT custom op library
-> ORT 执行完整图

2. full_onnx 的基本流程

2.1 什么是 full_onnx

这里的 full_onnx 不是只导出 RGB backbone 或 fusion head 的子图,而是导出单个完整 PVN3D 图:

  • 输入:pointcloud, rgb, choose
  • 输出:pred_kp_of, pred_rgbd_seg, pred_ctr_of

这条图里既包含标准 ONNX 节点,也包含 PointNet2 的自定义节点。

在仓库中,对应产物通常是:

  • deploy/models/onnx_ape/pvn3d_full.onnx

2.2 full_onnx 导出时解决的核心问题

全量导出与拆分导出的区别在于:PointNet2 这一路不能直接变成标准 ONNX 算子。

所以 full_onnx 的关键不是"把整网直接导出",而是:

  1. 用 wrapper 固定完整模型入口
  2. 给 PointNet2 扩展算子注册 symbolic
  3. 让导出器把这些算子写成 ai.onnx.contrib 域下的 custom nodes
  4. 导出后再用 ORT custom ops 去补执行能力

如果没有第 2 步和第 3 步,导出器会卡在 PyTorch 扩展算子上。

2.3 full_onnx 导出链总览

下面这份 dot 可以直接渲染:

dot 复制代码
digraph FullOnnxExportFlow {
  rankdir=LR;
  node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];

  model [label="PVN3D PyTorch Model"];
  wrapper [label="FullModelExportWrapper\n(pointcloud, rgb, choose)"];
  symbolic [label="Register PointNet2 symbolic\nai.onnx.contrib::*"];
  export [label="torch.onnx.export"];
  onnx [label="pvn3d_full.onnx\nstandard nodes + custom nodes"];
  check [label="onnx.checker\nshape_inference"];
  runtime [label="ORT + custom ops library"];
  outputs [label="pred_kp_of\npred_rgbd_seg\npred_ctr_of"];

  model -> wrapper -> export -> onnx -> check -> runtime -> outputs;
  symbolic -> export;
}

2.4 代码级 full_onnx 导出流程

deploy/scripts/export_full_onnx.py 中的主流程可以概括为:

  1. 解析导出参数
  2. 校验 opset 是否受当前 PyTorch 支持
  3. 生成 manifest
  4. 注册 PointNet2 symbolic
  5. 构建完整模型 wrapper
  6. 构造示例输入
  7. 调用 torch.onnx.export
  8. 对导出结果执行 onnx.checkershape_inference
  9. 输出 report / node list / inferred model

2.4.1 参数解析

脚本接受以下关键参数:

  • --checkpoint
  • --output
  • --num-classes
  • --num-points
  • --height
  • --width
  • --batch-size
  • --opset
  • --device
  • --manifest
  • --report
  • --node-list
  • --inferred-output

这些参数共同决定:

  • 模型权重
  • 导出输入规格
  • ONNX opset
  • 导出产物路径

2.4.2 构建完整模型 wrapper

deploy/model_wrappers.py 中的 FullModelExportWrapper 非常简单:

python 复制代码
class FullModelExportWrapper(nn.Module):
    def __init__(self, full_model):
        super().__init__()
        self.full_model = full_model

    def forward(self, pointcloud, rgb, choose):
        return self.full_model(pointcloud, rgb, choose)

它的意义不是改模型,而是为导出器固定一个清晰、稳定的导出入口。

2.4.3 构造示例输入

create_example_inputs() 里会生成:

python 复制代码
rgb:        (B, 3, H, W)
pointcloud: (B, N, 9)
choose:     (B, 1, N)

其中:

  • pointcloud9 表示 xyz + feature
  • choose 用于从 RGB feature map 中按点选取像素位置对应特征

2.4.4 注册 PointNet2 symbolic

这是 full_onnx 导出的核心步骤。

deploy/scripts/pointnet2_custom_ops.py 中,脚本会把以下 PyTorch 扩展算子绑定到自定义 symbolic:

  • FurthestPointSampling
  • GatherOperation
  • BallQuery
  • GroupingOperation
  • ThreeNN
  • ThreeInterpolate

导出时不再尝试把它们翻译成标准 ONNX,而是直接写成:

text 复制代码
ai.onnx.contrib::PVN3D_FurthestPointSample
ai.onnx.contrib::PVN3D_GatherPoints
ai.onnx.contrib::PVN3D_BallQuery
ai.onnx.contrib::PVN3D_GroupPoints
ai.onnx.contrib::PVN3D_ThreeNN
ai.onnx.contrib::PVN3D_ThreeInterpolate

2.4.5 调用 torch.onnx.export

导出调用形态是:

python 复制代码
torch.onnx.export(
    wrapper,
    (inputs["pointcloud"], inputs["rgb"], inputs["choose"]),
    str(output_path),
    input_names=["pointcloud", "rgb", "choose"],
    output_names=["pred_kp_of", "pred_rgbd_seg", "pred_ctr_of"],
    opset_version=args.opset,
    do_constant_folding=True,
    custom_opsets={CUSTOM_DOMAIN: CUSTOM_OPSET_VERSION},
)

这里有两个关键点:

  1. custom_opsets={CUSTOM_DOMAIN: CUSTOM_OPSET_VERSION}
  2. 输入输出名字被显式固定

第一点保证 custom domain 能写进 ONNX 模型,第二点便于后续调试和 ORT 侧联调。

2.4.6 导出后检查

导出成功后,脚本会继续执行:

  • onnx.checker.check_model(model)
  • shape_inference.infer_shapes(model)

并输出:

  • *.inferred.onnx
  • *.nodes.json
  • *.report.json

这些文件解决的是"图是否能被解析"和"图里到底有哪些 custom nodes"。


3. full_onnx 导出与 custom op 的关系

3.1 symbolic 只负责"落图",不负责"执行"

这是 full_onnx 流程里最容易混淆的地方。

在导出阶段:

  • symbolic 负责把 PyTorch 扩展算子写成 ONNX 节点

在运行阶段:

  • ORT custom op .so 负责真正执行这些节点

所以这两部分职责分别是:

  • 导出端:定义节点长什么样
  • 运行端:定义节点怎么算

3.2 Graphviz:导出端和运行端的衔接

dot 复制代码
digraph ExportVsRuntime {
  rankdir=LR;
  node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];

  pytorch [label="PyTorch ext op"];
  symbolic [label="symbolic()\nmap to custom ONNX node"];
  onnx [label="full ONNX graph"];
  session [label="ORT Session"];
  reg [label="RegisterCustomOps()"];
  kernel [label="Kernel::Compute()"];

  pytorch -> symbolic -> onnx -> session -> kernel;
  reg -> session;
}

4. 构建 ORT 算子必要的成员函数和实现方法

full_onnx 导出完成后,要让 ORT 真正运行这些节点,就需要在 deploy/ort_custom_opsdeploy/ort_custom_ops_gpu 中实现 custom ops。

本仓库采用的结构是:

  • Kernel:负责计算
  • Op:负责向 ORT 描述接口

4.1 CreateKernel

作用:

  • 为图中的某个节点实例创建一个 kernel 对象

典型写法:

cpp 复制代码
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
  return new SomeKernel(api, info);
}

如果某个算子带属性,例如:

  • npoint
  • radius
  • nsample

那么通常在 SomeKernel 构造函数中读取。

4.2 GetName

作用:

  • 返回 ONNX 节点的 op_type

例如:

cpp 复制代码
const char* GetName() const { return "PVN3D_ThreeNN"; }

要求:

  • 必须和导出阶段 symbolic 写入的节点名字完全一致

4.3 GetInputTypeCount

作用:

  • 告诉 ORT 该算子有多少个输入

例如:

  • GatherPoints 是 2 个输入
  • ThreeInterpolate 是 3 个输入

4.4 GetInputType

作用:

  • 指定每个输入张量的 dtype

例如:

cpp 复制代码
ONNXTensorElementDataType GetInputType(size_t index) const {
  return index == 1 ? ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64
                    : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

4.5 GetOutputTypeCount

作用:

  • 告诉 ORT 该算子有多少个输出

例如:

  • ThreeNN 有 2 个输出
  • 其他大多数只有 1 个输出

4.6 GetOutputType

作用:

  • 指定每个输出的 dtype

例如 ThreeNN

  • 输出 0:float
  • 输出 1:int64

4.7 GetExecutionProviderType

作用:

  • 绑定执行 provider

在本仓库:

  • CPU 版通常不实现该函数
  • GPU 版返回 "CUDAExecutionProvider"

这样 ORT 才会把节点分派给 CUDA custom op。

4.8 Kernel 构造函数

作用:

  • 从 ONNX 节点属性中读取参数

例如:

cpp 复制代码
explicit BallQueryKernel(const OrtApi& api, const OrtKernelInfo* info) {
  Ort::ConstKernelInfo kernel_info(info);
  radius_ = kernel_info.GetAttribute<float>("radius");
  nsample_ = kernel_info.GetAttribute<int64_t>("nsample");
}

full_onnx 导出阶段写入的属性,就是在这里被读回来的。

4.9 Compute

这是单个 custom op 的真正执行入口。

典型步骤:

  1. Ort::KernelContext ctx(context)
  2. GetInput()
  3. 解析 shape
  4. 检查 rank / last dim
  5. GetOutput() 分配输出
  6. 获取输入输出数据指针
  7. 执行 CPU 循环或调用 CUDA kernel

4.10 RegisterCustomOps

这是整个 .so 暴露给 ORT 的注册入口。

本仓库逻辑是:

  1. Ort::InitApi
  2. 构造 Ort::CustomOpDomain("ai.onnx.contrib")
  3. 将六个 PointNet2 op 加入 domain
  4. session_options.Add(domain)

如果没有这一步,full_onnx 图里的 custom nodes 仍然无法执行。


5. full_onnx 导出阶段的 symbolic 实现方法

在这条链里,除了 ORT custom op 成员函数,导出端的 symbolic 也很关键。

5.1 symbolic 的职责

symbolic 做的是:

  • 把 PyTorch 运算映射成 ONNX 节点
  • 指定 ONNX 节点的名字、输入顺序、属性、输出数量

它不做实际数值计算。

5.2 symbolic 的典型写法

例如 furthest_point_sample

python 复制代码
def fps_symbolic(g, xyz, npoint):
    return g.op(
        f"{CUSTOM_DOMAIN}::PVN3D_FurthestPointSample",
        xyz,
        npoint_i=parse_int(npoint),
    )

例如 three_nn

python 复制代码
def three_nn_symbolic(g, unknown, known):
    dist2, idx = g.op(
        f"{CUSTOM_DOMAIN}::PVN3D_ThreeNN",
        unknown,
        known,
        outputs=2,
    )
    return g.op("Sqrt", dist2), idx

这里有个重要细节:

  • ORT custom op PVN3D_ThreeNN 输出的是 dist2
  • 但 PyTorch 原始接口返回的是 dist
  • 所以 symbolic 里补了一个 Sqrt

这说明:

  • symbolic 负责让导出的 ONNX 语义与原始 PyTorch 前向保持一致

6. 每个算子的实现细节

下面以 full_onnx 图中实际出现的六个 custom ops 为主线说明。

6.1 PVN3D_FurthestPointSample

6.1.1 在 full_onnx 中的位置

它出现在 SA 路径中,用于从输入点集或中间点集里选择中心点。

导出端 symbolic:

python 复制代码
g.op("ai.onnx.contrib::PVN3D_FurthestPointSample", xyz, npoint_i=...)

运行端输入输出:

  • 输入:xyz (B, N, 3)
  • 输出:idx (B, npoint)

6.1.2 核心逻辑

最远点采样的思想是:

  • 每次选择"距离已选集合最远"的点

CPU 版实现中:

  1. 维护每个点到已选点集合的最小距离
  2. 第一个点固定为 0
  3. 逐轮更新最优候选

6.1.3 工程细节

  • 属性 npoint 在 kernel 构造函数中读取
  • 输出类型是 int64
  • 近零点会被跳过,避免无效点参与采样

6.2 PVN3D_GatherPoints

6.2.1 在 full_onnx 中的位置

它通常紧跟在 FPS 之后,用于按采样索引取出中心点坐标或特征。

输入输出:

  • 输入:
    • points (B, C, N)
    • idx (B, M)
  • 输出:
    • (B, C, M)

6.2.2 核心逻辑

本质上是按索引做列采样:

python 复制代码
out[b, c, j] = points[b, c, idx[b, j]]

6.2.3 工程细节

  • idx 类型是 int64
  • 会做越界检查
  • 它是很多 PointNet2 高层操作的基础子算子

6.3 PVN3D_BallQuery

6.3.1 在 full_onnx 中的位置

它出现在 SA 模块里,在选出中心点后,为每个中心点建立局部球邻域。

输入输出:

  • 输入:
    • new_xyz (B, npoint, 3)
    • xyz (B, N, 3)
  • 属性:
    • radius
    • nsample
  • 输出:
    • idx (B, npoint, nsample)

6.3.2 核心逻辑

逻辑是:

  1. 对每个中心点扫描全部原始点
  2. 找出 distance^2 < radius^2 的点
  3. 最多保留 nsample

6.3.3 工程细节

  • 邻域不足时,用首个命中索引补齐
  • 若完全没有邻域,保持初始化 0
  • 它只输出邻居索引,不输出特征

6.4 PVN3D_GroupPoints

6.4.1 在 full_onnx 中的位置

它一般接在 BallQuery 后面,把邻居索引转换成真正的局部块特征张量。

输入输出:

  • 输入:
    • points (B, C, N)
    • idx (B, npoint, nsample)
  • 输出:
    • (B, C, npoint, nsample)

6.4.2 核心逻辑

本质上是高维 gather:

python 复制代码
out[b, c, i, j] = points[b, c, idx[b, i, j]]

6.4.3 工程细节

  • 用于生成局部邻域张量
  • 后续会与相对坐标、MLP、pooling 结合形成 SA 编码

6.5 PVN3D_ThreeNN

6.5.1 在 full_onnx 中的位置

它出现在 FP 路径中,用于高分辨率点向低分辨率点查找 3 个最近邻。

输入输出:

  • 输入:
    • unknown (B, n, 3)
    • known (B, m, 3)
  • 输出:
    • dist2 (B, n, 3)
    • idx (B, n, 3)

6.5.2 核心逻辑

对每个 unknown 点,扫描 known 全集,维护最小的三个平方距离和对应索引。

6.5.3 full_onnx 里的一个关键语义

这里要特别强调:

  • ORT custom op 输出的是 dist2
  • symbolic 在导出时把它接了一个 Sqrt

所以 full_onnx 中的外部可见语义仍然是:

  • dist
  • idx

而底层 custom node 本身仍然只负责 dist2


6.6 PVN3D_ThreeInterpolate

6.6.1 在 full_onnx 中的位置

它是 FP 路径中的插值核心,用于将低分辨率特征传播回高分辨率点集。

输入输出:

  • 输入:
    • points (B, C, M)
    • idx (B, N, 3)
    • weight (B, N, 3)
  • 输出:
    • (B, C, N)

6.6.2 核心逻辑

本质上是三项加权和:

python 复制代码
out[b, c, j] =
    points[b, c, idx[b, j, 0]] * weight[b, j, 0] +
    points[b, c, idx[b, j, 1]] * weight[b, j, 1] +
    points[b, c, idx[b, j, 2]] * weight[b, j, 2]

6.6.3 工程细节

  • 它不负责生成 weight
  • weight 一般来自 ThreeNN 后的距离倒数归一化链
  • 这一步是 FP 恢复逐点分辨率的关键

7. 从 full_onnx 到 ORT 执行的闭环

把前面的内容串起来,完整闭环是:

  1. FullModelExportWrapper 固定完整模型导出入口
  2. register_pointnet2_symbolics() 把 PyTorch 扩展算子映射成 ai.onnx.contrib custom nodes
  3. torch.onnx.export 生成 pvn3d_full.onnx
  4. onnx.checker / shape_inference 验证导出结果
  5. ORT 加载 custom ops .so
  6. RegisterCustomOps() 注册同名算子
  7. ORT 根据节点名称、domain、provider 调用相应 kernel
  8. Compute() 执行每个 PointNet2 基础算子

下面这份 dot 是完整闭环图:

dot 复制代码
digraph FullClosure {
  rankdir=LR;
  node [shape=box, style="rounded,filled", fillcolor="#f5f5f5", color="#666666"];

  wrap [label="FullModelExportWrapper"];
  sym [label="PointNet2 symbolic"];
  export [label="torch.onnx.export"];
  fullonnx [label="pvn3d_full.onnx"];
  lib [label="ORT custom ops .so"];
  reg [label="RegisterCustomOps()"];
  dispatch [label="ORT node dispatch"];
  compute [label="Kernel::Compute()"];
  pred [label="pred_kp_of / pred_rgbd_seg / pred_ctr_of"];

  wrap -> export -> fullonnx -> dispatch -> compute -> pred;
  sym -> export;
  lib -> reg -> dispatch;
}

8. 结论

full_onnx 导出流程的本质不是"把 PVN3D 直接另存成 ONNX",而是:

  • 用 wrapper 固定整网导出入口
  • 用 symbolic 把 PointNet2 扩展算子保留成 custom nodes
  • 用 ORT custom ops 在运行时补上这些节点的执行能力

因此这条链实际上分成两半:

  1. 导出半链:wrapper + symbolic + torch.onnx.export
  2. 运行半链:RegisterCustomOps + Kernel::Compute

六个 PointNet2 基础算子正是这条链的桥梁:

  • 在导出阶段,它们以 ai.onnx.contrib::PVN3D_* 节点落入 full_onnx
  • 在运行阶段,它们由 ORT CPU/GPU custom ops 实现真正计算

这也是为什么 full_onnx、symbolic、ORT custom ops 必须放在同一条工程链上理解,而不能分开看。

相关推荐
新缸中之脑2 小时前
Magika:文件类型检测小模型
人工智能
渣渣xiong2 小时前
从零开始:前端转型AI agent直到就业第十二天-第十三天
前端·人工智能
齐齐大魔王2 小时前
机器学习(一)
人工智能·机器学习
云和数据.ChenGuang2 小时前
机器学习之方差和标准差计算
人工智能·python·机器学习·django·pygame·deepseek
北京耐用通信2 小时前
破局工业通讯壁垒!耐达讯自动化EtherCAT转RS232网关,老设备焕新核心桥梁
服务器·网络·人工智能·科技·物联网·网络协议·自动化
永霖光电_UVLED2 小时前
AIXTRON(爱思强)于2026年的业务指引实现上调
大数据·人工智能
云起SAAS2 小时前
AI词元理财系统完整源码 | 多级分销返利+虚拟挖矿+复利投资 | Vue3前后端分离
人工智能·广告联盟·看广告变现轻·看广告激励积分兑换系统app·ai词元理财系统完整源码
m0_694845572 小时前
VoxCPM部署教程:构建AI语音交互系统
服务器·人工智能·后端·自动化
eastyuxiao2 小时前
多机 OpenClaw 互联完整方案
人工智能·架构