PVN3D 训练与评估代码流程详解

这篇文章面向仓库中的实际代码,不是论文复述。我会把 PVN3D 的训练和评估流程,从最外层的入口一路拆解到最内层的网络结构和姿态恢复逻辑,最后再说说这个项目当前实现上的一些问题。

1. 先说说整体思路

PVN3D 在这个仓库里的核心流程可以概括成这样:

  1. 从 RGB-D 数据中采样一批 3D 点。
  2. RGB 分支给每个采样点提取图像上的语义特征,点云分支提取几何特征。
  3. 两路特征通过 DenseFusion 融合到一起。
  4. 网络对每个点同时预测三样东西:
    • 点级别的语义分割类别 pred_rgbd_seg
    • 指向多个关键点的偏移 pred_kp_of
    • 指向目标中心的偏移 pred_ctr_of
  5. 训练时直接监督这三个输出。
  6. 评估时,用"点坐标减去预测偏移"得到每个点对关键点和中心的投票,然后用 MeanShift 聚类得到最终的关键点位置,最后通过刚体配准算出位姿 RT

从代码的角度看,有三个层次比较重要:

  1. 外层驱动:训练脚本、参数配置、DataLoader、Trainer。
  2. 中间层:数据集的张量构造、model_fn_decorator、损失和评估的分发逻辑。
  3. 内层:PVN3D 网络结构、PSPNetPointNet2MSGDenseFusion、MeanShift 投票和 best_fit_transform

2. 入口在哪里:train 和 eval 怎么启动

2.1 脚本入口

仓库最外层的 shell 脚本很薄,只是把命令转发给 Python 训练脚本:

  • pvn3d/train_linemod.sh
  • pvn3d/eval_linemod.sh
  • pvn3d/train_ycb.sh
  • pvn3d/eval_ycb.sh

有个地方需要注意:这个项目没有单独的 eval.py,训练和评估用的是同一个主程序,通过 -eval_net --test 来切换模式。

具体来说:

  • LineMOD 训练:python3 -m train.train_linemod_pvn3d --cls <obj>
  • LineMOD 评估:python3 -m train.train_linemod_pvn3d -checkpoint <pth> -eval_net --test --cls <obj>
  • YCB 也是类似的模式。

2.2 训练脚本在做什么

train/train_linemod_pvn3d.pytrain/train_ycb_pvn3d.py 做的事情基本一样:

  1. 解析命令行参数。
  2. 构建 Config 配置对象。
  3. 创建 train、val、test 的数据集和 DataLoader。
  4. 构建 PVN3D 网络。
  5. 包装 SyncBN、DataParallel,设置优化器和 scheduler。
  6. model_fn_decorator 封装前向、损失和评估的逻辑。
  7. 交给 Trainer.train()Trainer.eval_epoch() 执行。

有个工程细节值得注意:Trainer 本身不知道数据的具体业务含义,所有的业务逻辑都被塞进了 model_fn_decorator。也就是说,这个项目真正的"训练语义"不在 Trainer 里,而在那个装饰器里。

3. 配置层:训练依赖哪些全局设置

配置都在 common.pyConfig 类里。

3.1 Config 负责什么

它定义了:

  • 数据集路径
  • 日志路径
  • checkpoint 保存位置
  • 采样点数 n_sample_points
  • 关键点数量 n_keypoints
  • batch size
  • 测试集预处理的缓存位置
  • 类别映射和对称物体类别列表
  • 相机内参

3.2 几个关键默认值

有几个值对理解流程很重要:

  • n_sample_points = 8192 + 4096 = 12288
  • n_keypoints = 8
  • YCB 有 22 个类别(包括背景)
  • LineMOD 在单类训练模式下只有 2 类(背景 + 当前物体)

这意味着每个 batch 里,网络要在 12288 个采样点上做预测,每个点要输出:

  • 一个分割类别
  • 8 个关键点的 3D 偏移
  • 1 个中心点的 3D 偏移

4. 数据层:数据集到底输出什么

4.1 训练样本的返回结构

不管是 LM_Dataset 还是 YCB_Dataset__getitem__ 返回的都是同一套东西:

python 复制代码
rgb,
pcld,
cld_rgb_nrm,
choose,
kp_targ_ofst,
ctr_targ_ofst,
cls_ids,
RTs,
labels,
kp_3ds,
ctr_3ds

理解这组张量是搞懂整个项目的前提。

4.2 每个张量的含义

rgb

  • 形状大概是 (3, H, W)
  • 原始 RGB 图像,给 CNN 用的

pcld

  • 形状 (N, 3)
  • 从深度图反投影得到的采样点云

cld_rgb_nrm

  • 形状 (N, 9)
  • 由三部分拼接:xyz 坐标、rgb 颜色、法向量
  • 这就是点云分支的输入

choose

  • 形状 (1, N)
  • 表示这 N 个采样点在图像上对应哪些像素索引
  • 这个张量很关键,因为 RGB 分支要靠它把 2D 特征图上的像素特征"捞"到每个 3D 点上

labels

  • 形状 (N,)
  • 每个采样点的语义类别标签
  • LineMOD 单类训练时:0 是背景,1 是目标物体
  • YCB 时:0 是背景,1-21 是各个物体

kp_targ_ofst

  • 形状 (N, 8, 3)
  • 每个采样点到各个 3D 关键点的偏移
  • 计算方式是:point_xyz - keypoint_xyz

ctr_targ_ofst

  • 形状 (N, 3)
  • 每个采样点到目标中心点的偏移
  • 计算方式同上

cls_ids / RTs / kp_3ds / ctr_3ds

这几个主要用于评估:

  • cls_ids:当前图里有哪些物体类别
  • RTs:真实的位姿
  • kp_3ds:真实关键点在相机坐标系下的 3D 坐标
  • ctr_3ds:真实中心点在相机坐标系下的 3D 坐标

4.3 数据预处理流程

从代码看,数据集里的处理顺序大致是:

  1. 读取 RGB、深度图、mask、位姿数据。
  2. 训练时做颜色抖动、模糊、加噪、背景混合等数据增强。
  3. 深度图补全。
  4. 用相机内参把深度图转成点云。
  5. 计算每个点的法向量。
  6. 根据 mask 和深度有效区域采样点。
  7. 如果点数不够就 wrap 补齐,太多就随机下采样。
  8. 构造每个点的类别标签。
  9. 构造关键点和中心点的偏移监督信号。

这一步做完之后,训练样本已经从"图像+深度+标注"变成了"固定长度的点级监督任务"。

5. 训练主循环:从 batch 到 loss 的路径

5.1 Trainer.train() 只负责调度

Trainer.train() 的逻辑很常规:

  1. 从 DataLoader 取一个 batch。
  2. 调用 self.model_fn(self.model, batch)
  3. loss.backward()
  4. optimizer.step()
  5. 每隔一定 iteration 跑验证。
  6. 保存 checkpoint。

真正决定损失怎么算、评估怎么触发的是 model_fn_decorator

5.2 model_fn_decorator 是核心胶水层

它做了三件事:

  1. 把 batch 里的所有张量搬到 GPU。
  2. 调用模型前向:
python 复制代码
pred_kp_of, pred_rgbd_seg, pred_ctr_of = model(cld_rgb_nrm, rgb, choose)
  1. 计算三个损失:
    • loss_rgbd_seg
    • loss_kp_of
    • loss_ctr_of

总损失是:

python 复制代码
loss = 2.0 * loss_rgbd_seg + 1.0 * loss_kp_of + 1.0 * loss_ctr_of

5.3 三个监督目标分别约束什么

语义分割损失 loss_rgbd_seg

用的是 FocalLoss,监督每个点的类别预测。目标是让网络知道哪些点属于物体、哪些属于背景,因为后续姿态恢复只会在预测为物体的点上进行。

关键点偏移损失 loss_kp_of

用的是 OFLoss,本质是 masked L1。只在前景点上计算(背景点没有对应的关键点)。这个损失让每个前景点学会回答:"从我这个点走到物体关键点,需要移动多少?"

中心偏移损失 loss_ctr_of

也是 OFLoss,只是目标换成了物体中心。这个分支在评估时有两个作用:一是帮助找到物体中心的投票,二是作为关键点投票的过滤参考。

6. 网络内部:PVN3D 的结构

6.1 整体架构

lib/pvn3d.py 里的 PVN3D 可以这样理解:

复制代码
RGB backbone (ModifiedResnet / PSPNet)
        +
Point backbone (PointNet2MSG)
        +
DenseFusion
        +
3 个预测头

6.2 RGB 分支:ModifiedResnet

这里的 ModifiedResnet 实际包装的是 Modified_PSPNet(resnet34)。计算流程是:

  1. extractors.resnet34 提取特征。
  2. 经过 PSPModule 做金字塔池化。
  3. 连续三次上采样。
  4. 输出一个 128 通道的稠密特征图。

返回值有两个:x(用于融合的 dense feature map)和 x_seg(分割图的输出)。但在 PVN3D 主流程里,真正用的是前者,rgb_seg 实际上没参与训练。

6.3 choose 的作用:把 2D 特征对齐到 3D 点

RGB 特征图是规则的网格,点云采样点是离散的像素反投影结果。两者之间的桥梁就是 choose

代码逻辑是:

  1. 把 CNN 的输出展平成 (B, C, H*W)
  2. choose 在第三维上做 gather
  3. 得到 (B, C, N) 的、每个点对应的 RGB embedding。

这一步很关键,实现了"每个 3D 点都能拿到对应像素的 2D 语义特征"。

6.4 点云分支:Pointnet2MSG

点云分支的输入是 cld_rgb_nrm,每个点 9 维信息:

  • xyz 坐标 3 维
  • rgb 颜色 3 维
  • 法向量 3 维

构造时 input_channels=6,因为 PointNet2 内部约定把 xyz 单独拿出来,其余 6 维作为特征描述子。

6.4.1 Set Abstraction 层

Pointnet2MSG 用了四层 SA 逐步下采样和局部聚合:

  1. 2048 个点
  2. 1024 个点
  3. 512 个点
  4. 128 个点

每层都用了 MSG(多尺度分组),不同半径和采样数,本质上在做:

  • FPS 采样中心点
  • ball query 邻域分组
  • 每个尺度独立 MLP
  • max pooling 得到局部几何特征
6.4.2 Feature Propagation 层

后面四层 FP 把粗层的特征逐步插值回原始点集,最终输出和输入采样点一一对应的点特征,形状大概是 (B, 128, N)

6.5 融合层:DenseFusion

DenseFusion 是这个项目最关键的部分。输入是:

  • rgb_emb: (B, 128, N)
  • cld_emb: (B, 128, N)

输出拼接了三类信息:

  1. 局部拼接 feat_1 = [rgb_emb, cld_emb],256 通道
  2. 分别卷积后增强的局部拼接 feat_2,512 通道
  3. 通过 1D conv 后做全局平均池化得到的全局特征,1024 通道,再 broadcast 回每个点

最终输出 256 + 512 + 1024 = 1792 通道。这个设计的直觉是:点级预测既需要局部外观,也需要局部几何,还需要全局上下文来避免歧义。

6.6 三个预测头

DenseFusion 后面接了三个并行的 head:

SEG_layer

输出 (B, N, num_classes),用于点级分割。

KpOF_layer

输出 (B, num_kps, N, 3),每个点指向每个关键点的偏移向量。

CtrOf_layer

输出 (B, 1, N, 3),每个点指向物体中心的偏移向量。

7. 评估流程:为什么偏移能恢复姿态

7.1 评估不是直接回归 RT

PVN3D 不直接输出旋转矩阵和平移向量,而是走一条几何链路:

  1. 分割出属于物体的点。
  2. 每个点投票出物体中心和关键点的位置。
  3. 聚类得到最终的关键点 3D 坐标。
  4. 用已知的模型关键点和预测的关键点做刚体配准,算出位姿。

这也是 PVN3D 比直接回归 6D 位姿更稳定的原因之一。

7.2 中心投票与关键点投票

pvn3d_eval_utils.py 里,先把网络输出的偏移转回绝对坐标:

复制代码
pred_ctr = pcld - ctr_of
pred_kp  = pcld - kp_of

注意,数据集的监督定义本来就是 point - target,所以预测出来后直接用 point - offset 就能得到 target 的绝对位置。这样每个前景点都能对物体中心和每个关键点投出一个 3D 位置猜测。

7.3 MeanShift 聚类

同一个物体上的很多点会给同一个关键点投票,理论上这些投票应该聚在一起。代码用 MeanShiftTorch 做聚类:

  1. 对中心投票聚类,得到中心估计。
  2. 必要时根据中心聚类结果过滤点。
  3. 对每个关键点的投票再做一次聚类,得到稳定的关键点位置。

这一步把"稠密的点级预测"压缩成了"少量结构化的关键点"。

7.4 刚体拟合恢复位姿

有了预测的关键点坐标,就能和 CAD 模型上的模板关键点做配准。具体由 Basic_Utils.best_fit_transform() 完成:

  1. 取模型坐标系下的关键点 mesh_kps
  2. 取相机坐标系下的预测关键点 cls_kps
  3. 用 SVD 求解最小二乘刚体变换
  4. 输出 3x4RT

这一层是纯几何计算,不再依赖神经网络。

7.5 评估指标

YCB 和 LineMOD 都用 ADD / ADD-S,但输出组织略有不同:

  • YCB 最终统计每类的 AUC 和总体 AUC
  • LineMOD 额外统计 ADD < 0.1 * diameter

对称物体用 ADD-S,非对称物体用 ADD。

8. 流程图怎么读

pvn3d_train_eval_flow.dot 的时候,建议按这个顺序:

  1. Entry Layer:看命令落到哪个 Python 主程序
  2. Config + Dataset:看训练样本怎么组织和采样
  3. Batch Tensors:明确每个张量的形状和含义
  4. PVN3D Forward:理解 RGB / point / fusion / heads 四段式结构
  5. Training Loop:理解损失怎么汇总和反传
  6. Evaluation / Pose Recovery:理解 offset 投票、聚类和刚体配准

如果后面要单步调试,这也是最合理的断点顺序。

9. 推荐的代码阅读顺序

如果想深入源码,建议按这个顺序读:

  1. train/train_linemod_pvn3d.py
  2. datasets/linemod/linemod_dataset.py
  3. lib/pvn3d.py
  4. lib/pspnet.py
  5. lib/extractors.py
  6. lib/pointnet2_utils/pointnet2_modules.py
  7. lib/loss.py
  8. lib/utils/pvn3d_eval_utils.py
  9. lib/utils/basic_utils.py

这样读的好处是:先建立全局调用链,再深入每层的张量变换,最后看几何后处理和指标。

10. 这个项目目前的不足

这部分不讨论论文本身,只从"当前仓库代码实现"的角度说。

10.1 训练与评估耦合太重

好处是代码少,坏处是职责不清。表现为:

  • 没有独立的评估程序
  • 训练脚本承担了参数解析、数据构建、模型构建、训练、验证、测试、指标汇总全部职责

后果是调试路径长、复用困难,想做独立的 benchmark、profiling、ablation 都不太方便。

10.2 model_fn_decorator 塞了太多业务逻辑

Trainer 很薄,model_fn_decorator 把好几件事绑在一起:

  • batch 解包
  • 前向调用
  • 损失计算
  • 精度统计
  • 测试时的位姿评估

这导致单元测试不好写,前向调试和评估调试耦合在一起,想替换 loss 或者只跑推理也比较别扭。

10.3 数据集层用裸 except,错误会被吞掉

LM_Dataset.get_item()YCB_Dataset.get_item() 最后都有:

python 复制代码
except:
    return None

这挺危险的。文件路径错误、标注格式错误、深度转点云失败、法向量计算失败这些本应暴露的问题,都会被静默吞掉,只表现为"偶尔 sample 是 None"。排查训练不稳定、数据损坏、偶发崩溃的时候,会非常头疼。

10.4 验证/测试集对 None 样本不够稳健

训练集在拿到 None 时会继续随机重采,但验证/测试集可能直接把 None 返回给 DataLoader,默认的 collate_fnNone 并不友好,坏样本可能直接导致评估崩溃。

10.5 工程参数和路径硬编码比较多

比如 DataLoader worker 数、eval_frequency、训练总 epoch、各种路径和目录结构、LineMOD/YCB 的专有路径布局,都写死在代码里。这带来两个问题:一是迁移到新环境容易踩路径坑,二是做实验对比时不够声明式和可追踪。

10.6 评估的线程模型有风险

TorchEval.eval_pose_parallel()ThreadPoolExecutor,在每个线程里继续操作 CUDA tensor 和 torch 逻辑。这类设计不是一定错,但可预测性差、调试困难,对 CUDA 上下文和线程调度不够透明。在吞吐上未必总有收益,但出错时排查成本会很高。

10.7 Point normal 依赖较重,YCB 没有退化路径

当前实现高度依赖 python-pcl

  • LineMOD 数据集做了 fallback,没有 PCL 时返回零法向量
  • YCB 数据集没有对应的 fallback

这会导致环境部署不统一、数据集间行为不一致,某些环境下 YCB 更容易直接失败。

10.8 一些实现细节不够干净

举几个例子:

  • DenseFusion_1super(DenseFusion, self).__init__() 明显不规范,虽然当前没被使用
  • of_l1_loss()if reduce: torch.mean(in_loss) 没有把结果赋回去,分支逻辑实际上无效
  • 训练脚本里大量用 os.system('mkdir -p ...')os.system('echo ...'),可维护性较差
  • rgb_seg 在主流程里基本没用,说明 RGB backbone 里有部分冗余输出

这些问题不一定马上导致结果错误,但会降低代码可信度。

10.9 模块边界不够清晰,不利于替换 backbone 或 head

从设计上看,PVN3D 很适合做结构替换实验:换 RGB backbone、换 point backbone、换 fusion、换 pose voting 策略。但当前代码把很多行为写死在 lib/pvn3d.py 和训练脚本里,模块接口不够抽象。这会影响做 ablation study、接入新 backbone、做 transformer / sparse convolution 替换实验。

11. 从调试角度,最值得下断点的位置

如果想真正"深入剖析",建议按下面顺序打断点:

数据是否正确

  • LM_Dataset.get_itemYCB_Dataset.get_item
  • 重点检查:choosecld 是否一一对应,labels 是否只有预期类别,kp_targ_ofst / ctr_targ_ofst 是否真的是 point - target

前向 shape 是否匹配

  • PVN3D.forward
  • 重点检查:out_rgb 的空间分辨率,choose.repeat(1, di, 1) 后的 gather 是否越界,pcld_embrgb_emb 的点数维度是否一致

loss 是否只在前景点生效

  • of_l1_loss
  • 重点确认:labels > 1e-8 形成的 mask 是否符合数据定义,LineMOD 单类时前景标签是否确实为 1

评估恢复的几何链条是否自洽

  • cal_frame_posescal_frame_poses_lmbest_fit_transform
  • 重点确认:offset 的正负号、坐标系是否一致、模型关键点和预测关键点的顺序是否一致

12. 最后一层理解:PVN3D 的本质

把实现细节都剥掉,PVN3D 的本质其实是:

复制代码
点级实例分割
+ 点到结构关键点的几何投票
+ 基于 3D 对应点的刚体配准

它不是"直接回归位姿",而是在学习一个更容易监督、更符合几何约束的中间表示:点属于谁,点指向哪。只要这两个问题回答得足够好,最终位姿就能通过传统几何方法稳定恢复出来。

所以理解 PVN3D 的时候,不能只盯着网络结构,必须同时看数据监督的定义、offset 的几何意义、MeanShift 投票、SVD 刚体拟合,这四者缺一不可。

13. 总结

这个仓库里的 PVN3D 代码虽然年代感比较强,但主线很清晰:

  1. 数据集把 RGB-D 样本变成固定长度的点级监督任务。
  2. PVN3D 用 RGB backbone 和 PointNet2 分别提取外观和几何特征。
  3. DenseFusion 把局部和全局信息拼起来。
  4. 网络对每个点同时预测类别、关键点偏移、中心偏移。
  5. 评估时通过点级投票、MeanShift 聚类和刚体拟合恢复位姿。

它的优势不在网络特别深,而在于把神经网络输出和几何恢复链条绑得很自然。

但从工程实现看,这个仓库有明显的"老代码"特征:train/eval 耦合、错误处理粗糙、配置硬编码、模块边界不清晰、一些实现细节不够干净。

相关推荐
前端大波2 小时前
Vue 项目中让 AI 更稳:AGENTS.md + Prompt 模板实践
vue.js·人工智能·prompt
珠海西格电力2 小时前
零碳园区能源互联技术路径适配方案的成本效益分析
大数据·人工智能·架构·智慧城市·能源
Daydream.V2 小时前
OpenCV——DNN模块实现风格迁移
人工智能·opencv·dnn
jinglong.zha2 小时前
OpenClaw核心概念速览
人工智能·ai·大模型·openclaw·养龙虾
摄影图2 小时前
AI神经网络数据可视化图片素材 多格式多场景助力设计高效开展
人工智能·aigc·插画
IT大师兄吖2 小时前
MatAnyone2 视频去除背景 懒人整合包
人工智能·音视频
小超同学你好2 小时前
面向 LLM 的程序设计 1:API 契约设计:从 REST 到「能力端点」
人工智能·语言模型
程序员Shawn2 小时前
【机器学习 | 第八篇】- 朴素贝叶斯
人工智能·机器学习
A 小码农2 小时前
亲测AI智能小助手-IDEA中使用腾讯混元大模型
java·人工智能·intellij-idea