PVN3D 中 SA 模块与 FP 模块详解

1. 文档目标

本文单独说明 PVN3D 点云主干 Pointnet2MSG 中:

  • SA 模块在全局网络中的位置和作用
  • FP 模块在全局网络中的位置和作用
  • 每个核心算子的职责、输入输出和实现细节
  • 用 Python 伪代码还原关键逻辑
  • 用 Graphviz 描述数据流

核心代码位置:

  • Pointnet2MSG 定义:pvn3d/lib/pvn3d.py
  • PointnetSAModuleMSG / PointnetFPModulepvn3d/lib/pointnet2_utils/pointnet2_modules.py
  • furthest_point_sample / ball_query / grouping_operation / three_nn / three_interpolatepvn3d/lib/pointnet2_utils/pointnet2_utils.py
  • SharedMLPpvn3d/lib/utils/etw_pytorch_utils/pytorch_utils.py

2. Pointnet2MSG 在 PVN3D 中的全局位置

在 PVN3D 里,点云分支不是单独工作的,它和 RGB 分支共同组成最终的融合特征。

对应关系如下:

  1. RGB 图像先进入 ModifiedResnet,得到逐像素 RGB embedding。
  2. 点云输入进入 Pointnet2MSG,得到逐点 point embedding。
  3. 两路特征进入 DenseFusion 做融合。
  4. 融合结果再送入分割头、关键点偏移头、中心点偏移头。

Pointnet2MSG 的定义和调用位于 pvn3d/lib/pvn3d.py:46-153pvn3d/lib/pvn3d.py:237 附近。它本身是一个典型的 PointNet++ segmentation backbone:

  • 前半段:4 层 SA,负责下采样和局部区域编码
  • 后半段:4 层 FP,负责把深层语义传播回高分辨率点集

这意味着:

  • SA 负责把原始点云逐步压缩成更少、更强的语义点
  • FP 负责把压缩后的高级语义重新对齐并传播到原始点级分辨率

最终 Pointnet2MSG.forward() 返回的是 l_features[0],即恢复到原始点数分辨率后的逐点特征。


3. Pointnet2MSG 的整体结构

3.1 输入与拆分

输入张量格式:

python 复制代码
pointcloud.shape == (B, N, 3 + input_channels)

_break_up_pc() 中被拆成:

python 复制代码
xyz = pc[..., 0:3]                         # (B, N, 3)
features = pc[..., 3:].transpose(1, 2)    # (B, C, N)

这里有一个很重要的约定:

  • xyz 始终保持 (B, N, 3)
  • 点特征 features 始终保持 (B, C, N)

这是后面所有 PointNet2 算子的数据布局基础。

3.2 四层 SA 配置

Pointnet2MSGpvn3d/lib/pvn3d.py:64-111 中定义了四层 SA:

层级 npoint radii nsamples 两个分支 MLP 输出通道
SA1 2048 [0.0175, 0.025] [16, 32] [C,16,16,32] / [C,32,32,64] 96
SA2 1024 [0.025, 0.05] [16, 32] [96,64,64,128] / [96,64,96,128] 256
SA3 512 [0.05, 0.1] [16, 32] [256,128,196,256] / [256,128,196,256] 512
SA4 128 [0.1, 0.2] [16, 32] [512,256,256,512] / [512,256,384,512] 1024

说明:

  • 每一层都是 MSG,即 Multi-Scale Grouping
  • 同一层中会对同一个中心点做两个半径尺度的局部邻域聚合
  • 两个尺度分支各自提特征后,在通道维拼接

3.3 四层 FP 配置

Pointnet2MSGpvn3d/lib/pvn3d.py:113-117 中定义了四层 FP:

层级 传播方向 输入拼接通道 MLP 输出通道
FP4 SA4 -> SA3 1024 + 512 [1536, 512, 512] 512
FP3 SA3 -> SA2 512 + 256 [768, 512, 512] 512
FP2 SA2 -> SA1 512 + 96 [608, 256, 256] 256
FP1 SA1 -> Input 256 + input_channels [256 + C, 128, 128] 128

注意实际代码里 FP_modules 是按从浅到深存储,但在 forward() 里通过负索引倒序调用,所以真实传播顺序是:

text 复制代码
SA4 -> SA3 -> SA2 -> SA1 -> 原始点集

4. SA 模块在全局 PointNet2MSG 中的位置和作用

4.1 SA 的位置

SA 位于 Pointnet2MSG 的编码端。对应代码:

python 复制代码
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
    li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
    l_xyz.append(li_xyz)
    l_features.append(li_features)

执行结束后:

  • l_xyz[0] / l_features[0]:原始点集
  • l_xyz[1] / l_features[1]:SA1 输出
  • l_xyz[2] / l_features[2]:SA2 输出
  • l_xyz[3] / l_features[3]:SA3 输出
  • l_xyz[4] / l_features[4]:SA4 输出

4.2 SA 的核心作用

SA 做三件事:

  1. 选择代表点
  2. 围绕代表点提取局部邻域
  3. 将邻域内的信息压缩为中心点的新特征

从网络功能上看,SA 相当于二维 CNN 中的:

  • 空间降采样
  • 局部感受野提取
  • 通道升维

与 2D CNN 不同的是,点云没有规则网格,所以 SA 必须显式完成:

  • 采样中心点
  • 构建邻域
  • 对邻域点集做对称聚合

4.3 为什么要用 MSG

MSG 是 Multi-Scale Grouping,同一层对每个中心点使用多个半径。

作用是:

  • 小半径分支保留局部几何细节
  • 大半径分支捕获更稳定、更大范围的上下文
  • 通道拼接后,同时保留细粒度和粗粒度信息

这对 PVN3D 很关键,因为 6D pose 估计既需要点级局部几何,也需要物体部件级上下文。


5. FP 模块在全局 PointNet2MSG 中的位置和作用

5.1 FP 的位置

FP 位于 Pointnet2MSG 的解码端。对应代码:

python 复制代码
for i in range(-1, -(len(self.FP_modules) + 1), -1):
    l_features[i - 1] = self.FP_modules[i](
        l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
    )

负索引展开后等价于:

python 复制代码
l_features[3] = FP4(l_xyz[3], l_xyz[4], l_features[3], l_features[4])
l_features[2] = FP3(l_xyz[2], l_xyz[3], l_features[2], l_features[3])
l_features[1] = FP2(l_xyz[1], l_xyz[2], l_features[1], l_features[2])
l_features[0] = FP1(l_xyz[0], l_xyz[1], l_features[0], l_features[1])

5.2 FP 的核心作用

FP 做三件事:

  1. 将低分辨率语义特征插值回高分辨率点集
  2. 与同层 skip feature 拼接
  3. 用共享 MLP 融合为新的逐点特征

如果只有 SA,没有 FP,那么网络最后只能在 128 个点上做预测,无法对原始 N 个点逐点输出。

FP 的意义就是把深层语义重新"铺回去",恢复到密集点级表达。

5.3 FP 为什么用 3NN 插值

点云不是规则栅格,不能直接用双线性上采样。

所以 FP 使用:

  • three_nn:为高分辨率点找到低分辨率点集中的 3 个最近邻
  • three_interpolate:按距离倒数权重做加权插值

这样能把粗粒度语义连续地传播到更细粒度点上。


6. SA 模块内部执行过程

6.1 入口模块:_PointnetSAModuleBase.forward

SA 的真正执行逻辑在 pvn3d/lib/pointnet2_utils/pointnet2_modules.py:27-72

它的流程是:

  1. furthest_point_sample 选中心点索引
  2. gather_operation 取出 new_xyz
  3. 对每个尺度分支执行 QueryAndGroup
  4. 对每个分支的局部张量执行 SharedMLP
  5. 在邻域维做 max pooling
  6. 多尺度结果在通道维拼接

6.2 SA 逻辑伪代码

python 复制代码
def sa_module_forward(xyz, features, npoint, groupers, mlps):
    # xyz: (B, N, 3)
    # features: (B, C, N)

    fps_idx = furthest_point_sample(xyz, npoint)           # (B, npoint)
    new_xyz = gather_operation(xyz.transpose(1, 2), fps_idx)
    new_xyz = new_xyz.transpose(1, 2).contiguous()         # (B, npoint, 3)

    outputs = []
    for grouper, mlp in zip(groupers, mlps):
        grouped = grouper(xyz, new_xyz, features)          # (B, Cg, npoint, nsample)
        local_feat = mlp(grouped)                          # (B, Cout, npoint, nsample)
        pooled = torch.max(local_feat, dim=3)[0]           # (B, Cout, npoint)
        outputs.append(pooled)

    new_features = torch.cat(outputs, dim=1)               # (B, sum(Cout), npoint)
    return new_xyz, new_features

6.3 SA 每个子步骤的本质

6.3.1 furthest_point_sample

作用:

  • 从原始 N 个点中挑出 npoint 个中心点
  • 目标不是随机采样,而是尽量让采样点在空间上覆盖均匀

核心思想:

  • 第一个点任选或固定
  • 之后每次选"到已选点集合的最小距离最大"的那个点

伪代码:

python 复制代码
def fps(xyz, npoint):
    selected = [0]
    min_dist = [inf] * len(xyz)
    for _ in range(1, npoint):
        last = xyz[selected[-1]]
        for i, p in enumerate(xyz):
            d = squared_distance(p, last)
            min_dist[i] = min(min_dist[i], d)
        selected.append(argmax(min_dist))
    return selected

实现细节:

  • Python 层只是包装器,真正计算在 _ext.furthest_point_sampling
  • 这是 CUDA 扩展算子,输出 (B, npoint) 索引
  • 该算子本身不需要梯度,反向直接返回 None

6.3.2 gather_operation

作用:

  • 根据索引从 (B, C, N) 中抽取对应列
  • 常用于把 FPS 采样出的中心点坐标从原始点集中取出来

输入输出:

  • 输入:features (B, C, N)idx (B, npoint)
  • 输出:(B, C, npoint)

实现细节:

  • 前向调用 _ext.gather_points
  • 反向调用 _ext.gather_points_grad
  • 这是一个"索引采样 + scatter-add 反传"的典型算子

6.3.3 ball_query

作用:

  • 对每个中心点,在原始点集里找半径 radius 内的邻居
  • 最多保留 nsample 个邻居

输入输出:

  • 输入:xyz (B, N, 3)new_xyz (B, npoint, 3)
  • 输出:idx (B, npoint, nsample)

实现细节:

  • 前向调用 _ext.ball_query
  • 它输出的是邻居索引,不直接输出特征
  • 这是离散邻域选择操作,因此没有定义梯度

语义上可以理解成:

python 复制代码
def ball_query(xyz, centers, radius, nsample):
    result = []
    for c in centers:
        nbrs = [i for i, p in enumerate(xyz) if distance(p, c) <= radius]
        nbrs = pad_or_truncate(nbrs, nsample)
        result.append(nbrs)
    return result

6.3.4 grouping_operation

作用:

  • idx 将点特征收集成局部块
  • 把平铺点集重排成邻域张量

输入输出:

  • 输入:features (B, C, N)idx (B, npoint, nsample)
  • 输出:(B, C, npoint, nsample)

实现细节:

  • 前向调用 _ext.group_points
  • 反向调用 _ext.group_points_grad
  • 可理解为高维 gather

6.3.5 QueryAndGroup

作用:

  • 先做 ball_query
  • 再做 grouping_operation
  • 再把邻域坐标转换成以中心点为原点的相对坐标
  • 最后将相对坐标与原始特征拼接

关键代码位于 pvn3d/lib/pointnet2_utils/pointnet2_utils.py:300-337

python 复制代码
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
grouped_xyz = grouping_operation(xyz.transpose(1, 2), idx)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

这一步非常关键,因为:

  • grouped_xyz 是局部几何结构
  • grouped_features 是点的属性描述
  • 拼接后,MLP 同时看到几何和语义

6.3.6 SharedMLP

作用:

  • 对局部张量 (B, C, npoint, nsample) 做逐点共享卷积
  • 本质上是多个 1x1 Conv2d (+ BN + ReLU)

pvn3d/lib/utils/etw_pytorch_utils/pytorch_utils.py:25-50 中,SharedMLP 会根据 args 依次堆叠 Conv2d

例如:

python 复制代码
SharedMLP([99, 64, 64, 128])

等价于:

python 复制代码
Conv2d(99, 64, kernel_size=1)
Conv2d(64, 64, kernel_size=1)
Conv2d(64, 128, kernel_size=1)

为什么这里用 Conv2d 而不是 Linear

  • 数据仍然保留 npointnsample 两个空间维
  • 1x1 Conv2d 等价于对每个局部点独立做共享线性映射

6.3.7 邻域 max pooling

作用:

  • 沿邻居维 nsample 做对称聚合
  • 把一个局部点集压缩为一个中心点描述子

代码:

python 复制代码
new_features = torch.max(new_features, dim=3, keepdim=True)[0]
new_features = new_features[..., 0]

为什么必须是对称聚合:

  • 点集内部没有稳定顺序
  • 聚合函数必须对输入排列不敏感
  • max 是 PointNet/PointNet++ 最经典的 permutation-invariant 聚合

7. FP 模块内部执行过程

7.1 入口模块:PointnetFPModule.forward

FP 的执行逻辑位于 pvn3d/lib/pointnet2_utils/pointnet2_modules.py:163-207

流程是:

  1. 在低分辨率点集 known 中为高分辨率点集 unknown 找 3NN
  2. 用距离倒数权重做插值
  3. 与 skip feature 拼接
  4. SharedMLP 融合

7.2 FP 逻辑伪代码

python 复制代码
def fp_module_forward(unknown_xyz, known_xyz, unknown_feats, known_feats, mlp):
    # unknown_xyz: (B, n, 3)  高分辨率
    # known_xyz:   (B, m, 3)  低分辨率

    dist, idx = three_nn(unknown_xyz, known_xyz)               # (B, n, 3)
    inv = 1.0 / (dist + 1e-8)
    weight = inv / inv.sum(dim=2, keepdim=True)                # (B, n, 3)

    interp = three_interpolate(known_feats, idx, weight)       # (B, C2, n)

    if unknown_feats is not None:
        fused = torch.cat([interp, unknown_feats], dim=1)      # (B, C1+C2, n)
    else:
        fused = interp

    fused = fused.unsqueeze(-1)                                # (B, C, n, 1)
    fused = mlp(fused)                                         # (B, Cout, n, 1)
    return fused[..., 0]                                       # (B, Cout, n)

7.3 FP 每个子步骤的本质

7.3.1 three_nn

作用:

  • 对高分辨率点 unknown,在低分辨率点 known 中找最近的 3 个点

输入输出:

  • 输入:unknown (B, n, 3)known (B, m, 3)
  • 输出:dist (B, n, 3)idx (B, n, 3)

实现细节:

  • 前向调用 _ext.three_nn
  • 扩展层通常先返回平方距离 dist2,Python 包装里再 sqrt
  • 该算子本身是基于坐标的离散邻居搜索,没有反向梯度

7.3.2 距离倒数归一化

作用:

  • 距离越近,权重越大
  • 三个邻居权重和为 1

代码:

python 复制代码
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm

这里加 1e-8 是为了避免距离为 0 时除零。

7.3.3 three_interpolate

作用:

  • 根据 idxweight 从低分辨率特征中加权采样

输入输出:

  • 输入:features (B, C, m)idx (B, n, 3)weight (B, n, 3)
  • 输出:(B, C, n)

数学形式:

python 复制代码
out[b, c, i] =
    weight[b, i, 0] * features[b, c, idx[b, i, 0]] +
    weight[b, i, 1] * features[b, c, idx[b, i, 1]] +
    weight[b, i, 2] * features[b, c, idx[b, i, 2]]

实现细节:

  • 前向调用 _ext.three_interpolate
  • 反向调用 _ext.three_interpolate_grad
  • 反传会把梯度按 weight 分配回 3 个源点

7.3.4 skip feature 拼接

作用:

  • 插值得到的是深层语义
  • unknow_feats 保留的是当前层较浅的几何细节
  • 拼接后,网络既能利用高级语义,也不会丢掉局部细节

这就是 PointNet++ 解码路径中的跳连机制。

7.3.5 SharedMLP 融合

这里的 SharedMLP 输入是 (B, C, n, 1),等价于对每个高分辨率点做共享的逐点非线性投影。

它的作用不是插值,而是"插值后的特征融合与重编码"。


8. Pointnet2MSG 的层级张量流

假设输入为:

python 复制代码
pointcloud: (B, N, 3 + C)
xyz:        (B, N, 3)
features:   (B, C, N)

则编码和解码的大致形状如下。

8.1 SA 编码路径

层级 xyz 形状 feature 形状
L0 (B, N, 3) (B, C, N)
L1 = SA1 (B, 2048, 3) (B, 96, 2048)
L2 = SA2 (B, 1024, 3) (B, 256, 1024)
L3 = SA3 (B, 512, 3) (B, 512, 512)
L4 = SA4 (B, 128, 3) (B, 1024, 128)

8.2 FP 解码路径

层级 传播后 feature 形状
L3 <- FP4(L4) (B, 512, 512)
L2 <- FP3(L3) (B, 512, 1024)
L1 <- FP2(L2) (B, 256, 2048)
L0 <- FP1(L1) (B, 128, N)

最终输出:

python 复制代码
Pointnet2MSG(pointcloud) -> (B, 128, N)

这就是后续 DenseFusion 使用的逐点点云 embedding。


9. Graphviz 数据流图

下面这份 dot 可以直接保存为 .dot 文件后用 Graphviz 渲染:

如果要更细致地画 SA 单层数据流,可以用下面这份:


10. 从"算子"角度理解 SA 与 FP

如果只看算子级别,SAFP 可以分别抽象成下面两种模式。

10.1 SA = 采样 + 邻域构图 + 局部编码 + 对称聚合

公式化表达:

python 复制代码
centers = FPS(xyz)
neighbors = BallQuery(xyz, centers)
local = Concat(relative_xyz, point_features)
encoded = SharedMLP(local)
center_feature = MaxPool(encoded, axis=neighbor)

这是一种从"点集合"到"更小点集合"的映射。

10.2 FP = 邻近插值 + 跳连融合 + 逐点重编码

公式化表达:

python 复制代码
idx, dist = ThreeNN(high_res_xyz, low_res_xyz)
interp = WeightedInterpolate(low_res_feat, idx, inverse_distance(dist))
fused = Concat(interp, skip_feat)
high_res_feat = SharedMLP(fused)

这是一种从"稀疏点集合"到"稠密点集合"的映射。


11. SA 与 FP 在 PVN3D 任务中的实际意义

PVN3D 的目标不是只分类点,而是要做:

  • 逐点语义分割
  • 逐点关键点偏移回归
  • 逐点中心点偏移回归

这决定了点云主干必须同时满足两件事:

  1. 有足够强的高层语义表达能力
  2. 能回到原始点级分辨率输出密集特征

因此:

  • SA 负责逐层扩大感受野、提炼高层语义
  • FP 负责把这些语义精确地传播回每个原始点

如果没有 SA:

  • 网络看不到稳定的局部几何层次
  • 高层上下文弱,pose 估计会不稳

如果没有 FP:

  • 只能在稀疏点上有强特征
  • 无法输出高质量逐点分割和偏移

所以 SA 和 FP 在 Pointnet2MSG 中是成对工作的编码器/解码器。


12. 结论

Pointnet2MSG 是 PVN3D 点云分支的核心骨干,其内部结构可以概括为:

  • SA:用 FPS 选中心点,用 Ball Query 建邻域,用 SharedMLP + MaxPool 完成多尺度局部编码
  • FP:用 3NN 找对应关系,用距离倒数权重插值,把深层语义传播回高分辨率点,再与 skip feature 融合

从实现角度看:

  • 几何相关、索引相关、插值相关操作主要由 pointnet2_utils._ext CUDA 扩展完成
  • Python 层负责模块组织、张量拼接、共享 MLP、池化和整体前向流程

从网络功能角度看:

  • SA 决定特征"抽得够不够深"
  • FP 决定特征"回得够不够密"

二者共同保证 PVN3D 既有强几何语义,又能保留逐点预测能力。

相关推荐
机器学习之心2 小时前
贝叶斯优化+卷积神经网络+多目标优化+多属性决策!BO-CNN+NSGAII+熵权TOPSIS,附实验报告!
人工智能·神经网络·cnn·多目标优化·多属性决策
苯酸氨酰糖化物2 小时前
基于深度学习(U-Net架构下改良GAN与ViT算法)的高效肺部多模态疾病预测模型
人工智能·深度学习·算法·生成对抗网络·视觉检测
kishu_iOS&AI2 小时前
深度学习 —— 浅析&Pytorch入门
人工智能·pytorch·深度学习
清章一2 小时前
HTML头部元信息避坑指南大纲
人工智能
一直会游泳的小猫2 小时前
Pascal Editor:基于 WebGPU 的开源 3D 建筑编辑器技术解析
3d·开源·编辑器
大模型实验室Lab4AI2 小时前
MAG-3D: Multi-Agent Grounded Reasoning for 3D Understanding
人工智能·计算机视觉·3d
沪漂阿龙2 小时前
循环神经网络(RNN)深度解析:从数学原理到智能输入法实战
人工智能·rnn·深度学习
来两个炸鸡腿2 小时前
【Datawhale2604】Hello-agents task01 智能体经典范式构建
人工智能·大模型·智能体
njsgcs2 小时前
我需要ai理解鼠标在工程图里的位置,要能理解标注的任务
人工智能