PVN3D 训练与评估代码流程详解

这篇文章面向仓库中的实际代码,不是论文复述。我会把 PVN3D 的训练和评估流程,从最外层的入口一路拆解到最内层的网络结构和姿态恢复逻辑,最后再说说这个项目当前实现上的一些问题。

1. 先说说整体思路

PVN3D 在这个仓库里的核心流程可以概括成这样:

  1. 从 RGB-D 数据中采样一批 3D 点。
  2. RGB 分支给每个采样点提取图像上的语义特征,点云分支提取几何特征。
  3. 两路特征通过 DenseFusion 融合到一起。
  4. 网络对每个点同时预测三样东西:
    • 点级别的语义分割类别 pred_rgbd_seg
    • 指向多个关键点的偏移 pred_kp_of
    • 指向目标中心的偏移 pred_ctr_of
  5. 训练时直接监督这三个输出。
  6. 评估时,用"点坐标减去预测偏移"得到每个点对关键点和中心的投票,然后用 MeanShift 聚类得到最终的关键点位置,最后通过刚体配准算出位姿 RT

从代码的角度看,有三个层次比较重要:

  1. 外层驱动:训练脚本、参数配置、DataLoader、Trainer。
  2. 中间层:数据集的张量构造、model_fn_decorator、损失和评估的分发逻辑。
  3. 内层:PVN3D 网络结构、PSPNetPointNet2MSGDenseFusion、MeanShift 投票和 best_fit_transform

2. 入口在哪里:train 和 eval 怎么启动

2.1 脚本入口

仓库最外层的 shell 脚本很薄,只是把命令转发给 Python 训练脚本:

  • pvn3d/train_linemod.sh
  • pvn3d/eval_linemod.sh
  • pvn3d/train_ycb.sh
  • pvn3d/eval_ycb.sh

有个地方需要注意:这个项目没有单独的 eval.py,训练和评估用的是同一个主程序,通过 -eval_net --test 来切换模式。

具体来说:

  • LineMOD 训练:python3 -m train.train_linemod_pvn3d --cls <obj>
  • LineMOD 评估:python3 -m train.train_linemod_pvn3d -checkpoint <pth> -eval_net --test --cls <obj>
  • YCB 也是类似的模式。

2.2 训练脚本在做什么

train/train_linemod_pvn3d.pytrain/train_ycb_pvn3d.py 做的事情基本一样:

  1. 解析命令行参数。
  2. 构建 Config 配置对象。
  3. 创建 train、val、test 的数据集和 DataLoader。
  4. 构建 PVN3D 网络。
  5. 包装 SyncBN、DataParallel,设置优化器和 scheduler。
  6. model_fn_decorator 封装前向、损失和评估的逻辑。
  7. 交给 Trainer.train()Trainer.eval_epoch() 执行。

有个工程细节值得注意:Trainer 本身不知道数据的具体业务含义,所有的业务逻辑都被塞进了 model_fn_decorator。也就是说,这个项目真正的"训练语义"不在 Trainer 里,而在那个装饰器里。

3. 配置层:训练依赖哪些全局设置

配置都在 common.pyConfig 类里。

3.1 Config 负责什么

它定义了:

  • 数据集路径
  • 日志路径
  • checkpoint 保存位置
  • 采样点数 n_sample_points
  • 关键点数量 n_keypoints
  • batch size
  • 测试集预处理的缓存位置
  • 类别映射和对称物体类别列表
  • 相机内参

3.2 几个关键默认值

有几个值对理解流程很重要:

  • n_sample_points = 8192 + 4096 = 12288
  • n_keypoints = 8
  • YCB 有 22 个类别(包括背景)
  • LineMOD 在单类训练模式下只有 2 类(背景 + 当前物体)

这意味着每个 batch 里,网络要在 12288 个采样点上做预测,每个点要输出:

  • 一个分割类别
  • 8 个关键点的 3D 偏移
  • 1 个中心点的 3D 偏移

4. 数据层:数据集到底输出什么

4.1 训练样本的返回结构

不管是 LM_Dataset 还是 YCB_Dataset__getitem__ 返回的都是同一套东西:

python 复制代码
rgb,
pcld,
cld_rgb_nrm,
choose,
kp_targ_ofst,
ctr_targ_ofst,
cls_ids,
RTs,
labels,
kp_3ds,
ctr_3ds

理解这组张量是搞懂整个项目的前提。

4.2 每个张量的含义

rgb

  • 形状大概是 (3, H, W)
  • 原始 RGB 图像,给 CNN 用的

pcld

  • 形状 (N, 3)
  • 从深度图反投影得到的采样点云

cld_rgb_nrm

  • 形状 (N, 9)
  • 由三部分拼接:xyz 坐标、rgb 颜色、法向量
  • 这就是点云分支的输入

choose

  • 形状 (1, N)
  • 表示这 N 个采样点在图像上对应哪些像素索引
  • 这个张量很关键,因为 RGB 分支要靠它把 2D 特征图上的像素特征"捞"到每个 3D 点上

labels

  • 形状 (N,)
  • 每个采样点的语义类别标签
  • LineMOD 单类训练时:0 是背景,1 是目标物体
  • YCB 时:0 是背景,1-21 是各个物体

kp_targ_ofst

  • 形状 (N, 8, 3)
  • 每个采样点到各个 3D 关键点的偏移
  • 计算方式是:point_xyz - keypoint_xyz

ctr_targ_ofst

  • 形状 (N, 3)
  • 每个采样点到目标中心点的偏移
  • 计算方式同上

cls_ids / RTs / kp_3ds / ctr_3ds

这几个主要用于评估:

  • cls_ids:当前图里有哪些物体类别
  • RTs:真实的位姿
  • kp_3ds:真实关键点在相机坐标系下的 3D 坐标
  • ctr_3ds:真实中心点在相机坐标系下的 3D 坐标

4.3 数据预处理流程

从代码看,数据集里的处理顺序大致是:

  1. 读取 RGB、深度图、mask、位姿数据。
  2. 训练时做颜色抖动、模糊、加噪、背景混合等数据增强。
  3. 深度图补全。
  4. 用相机内参把深度图转成点云。
  5. 计算每个点的法向量。
  6. 根据 mask 和深度有效区域采样点。
  7. 如果点数不够就 wrap 补齐,太多就随机下采样。
  8. 构造每个点的类别标签。
  9. 构造关键点和中心点的偏移监督信号。

这一步做完之后,训练样本已经从"图像+深度+标注"变成了"固定长度的点级监督任务"。

5. 训练主循环:从 batch 到 loss 的路径

5.1 Trainer.train() 只负责调度

Trainer.train() 的逻辑很常规:

  1. 从 DataLoader 取一个 batch。
  2. 调用 self.model_fn(self.model, batch)
  3. loss.backward()
  4. optimizer.step()
  5. 每隔一定 iteration 跑验证。
  6. 保存 checkpoint。

真正决定损失怎么算、评估怎么触发的是 model_fn_decorator

5.2 model_fn_decorator 是核心胶水层

它做了三件事:

  1. 把 batch 里的所有张量搬到 GPU。
  2. 调用模型前向:
python 复制代码
pred_kp_of, pred_rgbd_seg, pred_ctr_of = model(cld_rgb_nrm, rgb, choose)
  1. 计算三个损失:
    • loss_rgbd_seg
    • loss_kp_of
    • loss_ctr_of

总损失是:

python 复制代码
loss = 2.0 * loss_rgbd_seg + 1.0 * loss_kp_of + 1.0 * loss_ctr_of

5.3 三个监督目标分别约束什么

语义分割损失 loss_rgbd_seg

用的是 FocalLoss,监督每个点的类别预测。目标是让网络知道哪些点属于物体、哪些属于背景,因为后续姿态恢复只会在预测为物体的点上进行。

关键点偏移损失 loss_kp_of

用的是 OFLoss,本质是 masked L1。只在前景点上计算(背景点没有对应的关键点)。这个损失让每个前景点学会回答:"从我这个点走到物体关键点,需要移动多少?"

中心偏移损失 loss_ctr_of

也是 OFLoss,只是目标换成了物体中心。这个分支在评估时有两个作用:一是帮助找到物体中心的投票,二是作为关键点投票的过滤参考。

6. 网络内部:PVN3D 的结构

6.1 整体架构

lib/pvn3d.py 里的 PVN3D 可以这样理解:

复制代码
RGB backbone (ModifiedResnet / PSPNet)
        +
Point backbone (PointNet2MSG)
        +
DenseFusion
        +
3 个预测头

6.2 RGB 分支:ModifiedResnet

这里的 ModifiedResnet 实际包装的是 Modified_PSPNet(resnet34)。计算流程是:

  1. extractors.resnet34 提取特征。
  2. 经过 PSPModule 做金字塔池化。
  3. 连续三次上采样。
  4. 输出一个 128 通道的稠密特征图。

返回值有两个:x(用于融合的 dense feature map)和 x_seg(分割图的输出)。但在 PVN3D 主流程里,真正用的是前者,rgb_seg 实际上没参与训练。

6.3 choose 的作用:把 2D 特征对齐到 3D 点

RGB 特征图是规则的网格,点云采样点是离散的像素反投影结果。两者之间的桥梁就是 choose

代码逻辑是:

  1. 把 CNN 的输出展平成 (B, C, H*W)
  2. choose 在第三维上做 gather
  3. 得到 (B, C, N) 的、每个点对应的 RGB embedding。

这一步很关键,实现了"每个 3D 点都能拿到对应像素的 2D 语义特征"。

6.4 点云分支:Pointnet2MSG

点云分支的输入是 cld_rgb_nrm,每个点 9 维信息:

  • xyz 坐标 3 维
  • rgb 颜色 3 维
  • 法向量 3 维

构造时 input_channels=6,因为 PointNet2 内部约定把 xyz 单独拿出来,其余 6 维作为特征描述子。

6.4.1 Set Abstraction 层

Pointnet2MSG 用了四层 SA 逐步下采样和局部聚合:

  1. 2048 个点
  2. 1024 个点
  3. 512 个点
  4. 128 个点

每层都用了 MSG(多尺度分组),不同半径和采样数,本质上在做:

  • FPS 采样中心点
  • ball query 邻域分组
  • 每个尺度独立 MLP
  • max pooling 得到局部几何特征
6.4.2 Feature Propagation 层

后面四层 FP 把粗层的特征逐步插值回原始点集,最终输出和输入采样点一一对应的点特征,形状大概是 (B, 128, N)

6.5 融合层:DenseFusion

DenseFusion 是这个项目最关键的部分。输入是:

  • rgb_emb: (B, 128, N)
  • cld_emb: (B, 128, N)

输出拼接了三类信息:

  1. 局部拼接 feat_1 = [rgb_emb, cld_emb],256 通道
  2. 分别卷积后增强的局部拼接 feat_2,512 通道
  3. 通过 1D conv 后做全局平均池化得到的全局特征,1024 通道,再 broadcast 回每个点

最终输出 256 + 512 + 1024 = 1792 通道。这个设计的直觉是:点级预测既需要局部外观,也需要局部几何,还需要全局上下文来避免歧义。

6.6 三个预测头

DenseFusion 后面接了三个并行的 head:

SEG_layer

输出 (B, N, num_classes),用于点级分割。

KpOF_layer

输出 (B, num_kps, N, 3),每个点指向每个关键点的偏移向量。

CtrOf_layer

输出 (B, 1, N, 3),每个点指向物体中心的偏移向量。

7. 评估流程:为什么偏移能恢复姿态

7.1 评估不是直接回归 RT

PVN3D 不直接输出旋转矩阵和平移向量,而是走一条几何链路:

  1. 分割出属于物体的点。
  2. 每个点投票出物体中心和关键点的位置。
  3. 聚类得到最终的关键点 3D 坐标。
  4. 用已知的模型关键点和预测的关键点做刚体配准,算出位姿。

这也是 PVN3D 比直接回归 6D 位姿更稳定的原因之一。

7.2 中心投票与关键点投票

pvn3d_eval_utils.py 里,先把网络输出的偏移转回绝对坐标:

复制代码
pred_ctr = pcld - ctr_of
pred_kp  = pcld - kp_of

注意,数据集的监督定义本来就是 point - target,所以预测出来后直接用 point - offset 就能得到 target 的绝对位置。这样每个前景点都能对物体中心和每个关键点投出一个 3D 位置猜测。

7.3 MeanShift 聚类

同一个物体上的很多点会给同一个关键点投票,理论上这些投票应该聚在一起。代码用 MeanShiftTorch 做聚类:

  1. 对中心投票聚类,得到中心估计。
  2. 必要时根据中心聚类结果过滤点。
  3. 对每个关键点的投票再做一次聚类,得到稳定的关键点位置。

这一步把"稠密的点级预测"压缩成了"少量结构化的关键点"。

7.4 刚体拟合恢复位姿

有了预测的关键点坐标,就能和 CAD 模型上的模板关键点做配准。具体由 Basic_Utils.best_fit_transform() 完成:

  1. 取模型坐标系下的关键点 mesh_kps
  2. 取相机坐标系下的预测关键点 cls_kps
  3. 用 SVD 求解最小二乘刚体变换
  4. 输出 3x4RT

这一层是纯几何计算,不再依赖神经网络。

7.5 评估指标

YCB 和 LineMOD 都用 ADD / ADD-S,但输出组织略有不同:

  • YCB 最终统计每类的 AUC 和总体 AUC
  • LineMOD 额外统计 ADD < 0.1 * diameter

对称物体用 ADD-S,非对称物体用 ADD。

8. 流程图怎么读

pvn3d_train_eval_flow.dot 的时候,建议按这个顺序:

  1. Entry Layer:看命令落到哪个 Python 主程序
  2. Config + Dataset:看训练样本怎么组织和采样
  3. Batch Tensors:明确每个张量的形状和含义
  4. PVN3D Forward:理解 RGB / point / fusion / heads 四段式结构
  5. Training Loop:理解损失怎么汇总和反传
  6. Evaluation / Pose Recovery:理解 offset 投票、聚类和刚体配准

如果后面要单步调试,这也是最合理的断点顺序。

9. 推荐的代码阅读顺序

如果想深入源码,建议按这个顺序读:

  1. train/train_linemod_pvn3d.py
  2. datasets/linemod/linemod_dataset.py
  3. lib/pvn3d.py
  4. lib/pspnet.py
  5. lib/extractors.py
  6. lib/pointnet2_utils/pointnet2_modules.py
  7. lib/loss.py
  8. lib/utils/pvn3d_eval_utils.py
  9. lib/utils/basic_utils.py

这样读的好处是:先建立全局调用链,再深入每层的张量变换,最后看几何后处理和指标。

10. 这个项目目前的不足

这部分不讨论论文本身,只从"当前仓库代码实现"的角度说。

10.1 训练与评估耦合太重

好处是代码少,坏处是职责不清。表现为:

  • 没有独立的评估程序
  • 训练脚本承担了参数解析、数据构建、模型构建、训练、验证、测试、指标汇总全部职责

后果是调试路径长、复用困难,想做独立的 benchmark、profiling、ablation 都不太方便。

10.2 model_fn_decorator 塞了太多业务逻辑

Trainer 很薄,model_fn_decorator 把好几件事绑在一起:

  • batch 解包
  • 前向调用
  • 损失计算
  • 精度统计
  • 测试时的位姿评估

这导致单元测试不好写,前向调试和评估调试耦合在一起,想替换 loss 或者只跑推理也比较别扭。

10.3 数据集层用裸 except,错误会被吞掉

LM_Dataset.get_item()YCB_Dataset.get_item() 最后都有:

python 复制代码
except:
    return None

这挺危险的。文件路径错误、标注格式错误、深度转点云失败、法向量计算失败这些本应暴露的问题,都会被静默吞掉,只表现为"偶尔 sample 是 None"。排查训练不稳定、数据损坏、偶发崩溃的时候,会非常头疼。

10.4 验证/测试集对 None 样本不够稳健

训练集在拿到 None 时会继续随机重采,但验证/测试集可能直接把 None 返回给 DataLoader,默认的 collate_fnNone 并不友好,坏样本可能直接导致评估崩溃。

10.5 工程参数和路径硬编码比较多

比如 DataLoader worker 数、eval_frequency、训练总 epoch、各种路径和目录结构、LineMOD/YCB 的专有路径布局,都写死在代码里。这带来两个问题:一是迁移到新环境容易踩路径坑,二是做实验对比时不够声明式和可追踪。

10.6 评估的线程模型有风险

TorchEval.eval_pose_parallel()ThreadPoolExecutor,在每个线程里继续操作 CUDA tensor 和 torch 逻辑。这类设计不是一定错,但可预测性差、调试困难,对 CUDA 上下文和线程调度不够透明。在吞吐上未必总有收益,但出错时排查成本会很高。

10.7 Point normal 依赖较重,YCB 没有退化路径

当前实现高度依赖 python-pcl

  • LineMOD 数据集做了 fallback,没有 PCL 时返回零法向量
  • YCB 数据集没有对应的 fallback

这会导致环境部署不统一、数据集间行为不一致,某些环境下 YCB 更容易直接失败。

10.8 一些实现细节不够干净

举几个例子:

  • DenseFusion_1super(DenseFusion, self).__init__() 明显不规范,虽然当前没被使用
  • of_l1_loss()if reduce: torch.mean(in_loss) 没有把结果赋回去,分支逻辑实际上无效
  • 训练脚本里大量用 os.system('mkdir -p ...')os.system('echo ...'),可维护性较差
  • rgb_seg 在主流程里基本没用,说明 RGB backbone 里有部分冗余输出

这些问题不一定马上导致结果错误,但会降低代码可信度。

10.9 模块边界不够清晰,不利于替换 backbone 或 head

从设计上看,PVN3D 很适合做结构替换实验:换 RGB backbone、换 point backbone、换 fusion、换 pose voting 策略。但当前代码把很多行为写死在 lib/pvn3d.py 和训练脚本里,模块接口不够抽象。这会影响做 ablation study、接入新 backbone、做 transformer / sparse convolution 替换实验。

11. 从调试角度,最值得下断点的位置

如果想真正"深入剖析",建议按下面顺序打断点:

数据是否正确

  • LM_Dataset.get_itemYCB_Dataset.get_item
  • 重点检查:choosecld 是否一一对应,labels 是否只有预期类别,kp_targ_ofst / ctr_targ_ofst 是否真的是 point - target

前向 shape 是否匹配

  • PVN3D.forward
  • 重点检查:out_rgb 的空间分辨率,choose.repeat(1, di, 1) 后的 gather 是否越界,pcld_embrgb_emb 的点数维度是否一致

loss 是否只在前景点生效

  • of_l1_loss
  • 重点确认:labels > 1e-8 形成的 mask 是否符合数据定义,LineMOD 单类时前景标签是否确实为 1

评估恢复的几何链条是否自洽

  • cal_frame_posescal_frame_poses_lmbest_fit_transform
  • 重点确认:offset 的正负号、坐标系是否一致、模型关键点和预测关键点的顺序是否一致

12. 最后一层理解:PVN3D 的本质

把实现细节都剥掉,PVN3D 的本质其实是:

复制代码
点级实例分割
+ 点到结构关键点的几何投票
+ 基于 3D 对应点的刚体配准

它不是"直接回归位姿",而是在学习一个更容易监督、更符合几何约束的中间表示:点属于谁,点指向哪。只要这两个问题回答得足够好,最终位姿就能通过传统几何方法稳定恢复出来。

所以理解 PVN3D 的时候,不能只盯着网络结构,必须同时看数据监督的定义、offset 的几何意义、MeanShift 投票、SVD 刚体拟合,这四者缺一不可。

13. 总结

这个仓库里的 PVN3D 代码虽然年代感比较强,但主线很清晰:

  1. 数据集把 RGB-D 样本变成固定长度的点级监督任务。
  2. PVN3D 用 RGB backbone 和 PointNet2 分别提取外观和几何特征。
  3. DenseFusion 把局部和全局信息拼起来。
  4. 网络对每个点同时预测类别、关键点偏移、中心偏移。
  5. 评估时通过点级投票、MeanShift 聚类和刚体拟合恢复位姿。

它的优势不在网络特别深,而在于把神经网络输出和几何恢复链条绑得很自然。

但从工程实现看,这个仓库有明显的"老代码"特征:train/eval 耦合、错误处理粗糙、配置硬编码、模块边界不清晰、一些实现细节不够干净。

相关推荐
Techblog of HaoWANG几秒前
目标检测与跟踪(16)-- Ubuntu 20.04 下 ROS1 + Conda 虚拟环境开机自启动方案(兼容 ROS2 共存)
人工智能·目标检测·ubuntu·机器人·视觉检测·conda·控制
TechWayfarer3 分钟前
边缘计算节点的IP管理:如何精准定位全球部署的AI推理节点?
人工智能·tcp/ip·边缘计算
财经资讯数据_灵砚智能18 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年4月20日
人工智能·python·信息可视化·自然语言处理·ai编程
j_xxx404_20 分钟前
【AI大模型入门(二)】提示词工程进阶
人工智能·ai·prompt
程序员cxuan29 分钟前
vibe coding 凉了,wish coding 来了
人工智能·后端·程序员
传说故事34 分钟前
【论文阅读】ViVa: A Video-Generative Value Model for Robot Reinforcement Learning
论文阅读·人工智能·强化学习·具身智能
keineahnung234544 分钟前
PyTorch 張量尺寸為 1 時,步長為何不具語意?
人工智能·pytorch·python·深度学习
小t说说1 小时前
2026年PPT生成工具评测及使用体验
大数据·前端·人工智能
NineData1 小时前
NineData 将亮相 2026 德国汉诺威工业博览会
数据库·人工智能·数据库管理工具·ninedata·数据库迁移工具·玖章算术