PVN3D 中 SA 模块与 FP 模块详解

1. 文档目标

本文单独说明 PVN3D 点云主干 Pointnet2MSG 中:

  • SA 模块在全局网络中的位置和作用
  • FP 模块在全局网络中的位置和作用
  • 每个核心算子的职责、输入输出和实现细节
  • 用 Python 伪代码还原关键逻辑
  • 用 Graphviz 描述数据流

核心代码位置:

  • Pointnet2MSG 定义:pvn3d/lib/pvn3d.py
  • PointnetSAModuleMSG / PointnetFPModulepvn3d/lib/pointnet2_utils/pointnet2_modules.py
  • furthest_point_sample / ball_query / grouping_operation / three_nn / three_interpolatepvn3d/lib/pointnet2_utils/pointnet2_utils.py
  • SharedMLPpvn3d/lib/utils/etw_pytorch_utils/pytorch_utils.py

2. Pointnet2MSG 在 PVN3D 中的全局位置

在 PVN3D 里,点云分支不是单独工作的,它和 RGB 分支共同组成最终的融合特征。

对应关系如下:

  1. RGB 图像先进入 ModifiedResnet,得到逐像素 RGB embedding。
  2. 点云输入进入 Pointnet2MSG,得到逐点 point embedding。
  3. 两路特征进入 DenseFusion 做融合。
  4. 融合结果再送入分割头、关键点偏移头、中心点偏移头。

Pointnet2MSG 的定义和调用位于 pvn3d/lib/pvn3d.py:46-153pvn3d/lib/pvn3d.py:237 附近。它本身是一个典型的 PointNet++ segmentation backbone:

  • 前半段:4 层 SA,负责下采样和局部区域编码
  • 后半段:4 层 FP,负责把深层语义传播回高分辨率点集

这意味着:

  • SA 负责把原始点云逐步压缩成更少、更强的语义点
  • FP 负责把压缩后的高级语义重新对齐并传播到原始点级分辨率

最终 Pointnet2MSG.forward() 返回的是 l_features[0],即恢复到原始点数分辨率后的逐点特征。


3. Pointnet2MSG 的整体结构

3.1 输入与拆分

输入张量格式:

python 复制代码
pointcloud.shape == (B, N, 3 + input_channels)

_break_up_pc() 中被拆成:

python 复制代码
xyz = pc[..., 0:3]                         # (B, N, 3)
features = pc[..., 3:].transpose(1, 2)    # (B, C, N)

这里有一个很重要的约定:

  • xyz 始终保持 (B, N, 3)
  • 点特征 features 始终保持 (B, C, N)

这是后面所有 PointNet2 算子的数据布局基础。

3.2 四层 SA 配置

Pointnet2MSGpvn3d/lib/pvn3d.py:64-111 中定义了四层 SA:

层级 npoint radii nsamples 两个分支 MLP 输出通道
SA1 2048 0.0175, 0.025 16, 32 C,16,16,32 / C,32,32,64 96
SA2 1024 0.025, 0.05 16, 32 96,64,64,128 / 96,64,96,128 256
SA3 512 0.05, 0.1 16, 32 256,128,196,256 / 256,128,196,256 512
SA4 128 0.1, 0.2 16, 32 512,256,256,512 / 512,256,384,512 1024

说明:

  • 每一层都是 MSG,即 Multi-Scale Grouping
  • 同一层中会对同一个中心点做两个半径尺度的局部邻域聚合
  • 两个尺度分支各自提特征后,在通道维拼接

3.3 四层 FP 配置

Pointnet2MSGpvn3d/lib/pvn3d.py:113-117 中定义了四层 FP:

层级 传播方向 输入拼接通道 MLP 输出通道
FP4 SA4 -> SA3 1024 + 512 1536, 512, 512 512
FP3 SA3 -> SA2 512 + 256 768, 512, 512 512
FP2 SA2 -> SA1 512 + 96 608, 256, 256 256
FP1 SA1 -> Input 256 + input_channels 256 + C, 128, 128 128

注意实际代码里 FP_modules 是按从浅到深存储,但在 forward() 里通过负索引倒序调用,所以真实传播顺序是:

text 复制代码
SA4 -> SA3 -> SA2 -> SA1 -> 原始点集

4. SA 模块在全局 PointNet2MSG 中的位置和作用

4.1 SA 的位置

SA 位于 Pointnet2MSG 的编码端。对应代码:

python 复制代码
l_xyz, l_features = [xyz], [features]
for i in range(len(self.SA_modules)):
    li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
    l_xyz.append(li_xyz)
    l_features.append(li_features)

执行结束后:

  • l_xyz[0] / l_features[0]:原始点集
  • l_xyz[1] / l_features[1]:SA1 输出
  • l_xyz[2] / l_features[2]:SA2 输出
  • l_xyz[3] / l_features[3]:SA3 输出
  • l_xyz[4] / l_features[4]:SA4 输出

4.2 SA 的核心作用

SA 做三件事:

  1. 选择代表点
  2. 围绕代表点提取局部邻域
  3. 将邻域内的信息压缩为中心点的新特征

从网络功能上看,SA 相当于二维 CNN 中的:

  • 空间降采样
  • 局部感受野提取
  • 通道升维

与 2D CNN 不同的是,点云没有规则网格,所以 SA 必须显式完成:

  • 采样中心点
  • 构建邻域
  • 对邻域点集做对称聚合

4.3 为什么要用 MSG

MSG 是 Multi-Scale Grouping,同一层对每个中心点使用多个半径。

作用是:

  • 小半径分支保留局部几何细节
  • 大半径分支捕获更稳定、更大范围的上下文
  • 通道拼接后,同时保留细粒度和粗粒度信息

这对 PVN3D 很关键,因为 6D pose 估计既需要点级局部几何,也需要物体部件级上下文。


5. FP 模块在全局 PointNet2MSG 中的位置和作用

5.1 FP 的位置

FP 位于 Pointnet2MSG 的解码端。对应代码:

python 复制代码
for i in range(-1, -(len(self.FP_modules) + 1), -1):
    l_features[i - 1] = self.FP_modules[i](
        l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
    )

负索引展开后等价于:

python 复制代码
l_features[3] = FP4(l_xyz[3], l_xyz[4], l_features[3], l_features[4])
l_features[2] = FP3(l_xyz[2], l_xyz[3], l_features[2], l_features[3])
l_features[1] = FP2(l_xyz[1], l_xyz[2], l_features[1], l_features[2])
l_features[0] = FP1(l_xyz[0], l_xyz[1], l_features[0], l_features[1])

5.2 FP 的核心作用

FP 做三件事:

  1. 将低分辨率语义特征插值回高分辨率点集
  2. 与同层 skip feature 拼接
  3. 用共享 MLP 融合为新的逐点特征

如果只有 SA,没有 FP,那么网络最后只能在 128 个点上做预测,无法对原始 N 个点逐点输出。

FP 的意义就是把深层语义重新"铺回去",恢复到密集点级表达。

5.3 FP 为什么用 3NN 插值

点云不是规则栅格,不能直接用双线性上采样。

所以 FP 使用:

  • three_nn:为高分辨率点找到低分辨率点集中的 3 个最近邻
  • three_interpolate:按距离倒数权重做加权插值

这样能把粗粒度语义连续地传播到更细粒度点上。


6. SA 模块内部执行过程

6.1 入口模块:_PointnetSAModuleBase.forward

SA 的真正执行逻辑在 pvn3d/lib/pointnet2_utils/pointnet2_modules.py:27-72

它的流程是:

  1. furthest_point_sample 选中心点索引
  2. gather_operation 取出 new_xyz
  3. 对每个尺度分支执行 QueryAndGroup
  4. 对每个分支的局部张量执行 SharedMLP
  5. 在邻域维做 max pooling
  6. 多尺度结果在通道维拼接

6.2 SA 逻辑伪代码

python 复制代码
def sa_module_forward(xyz, features, npoint, groupers, mlps):
    # xyz: (B, N, 3)
    # features: (B, C, N)

    fps_idx = furthest_point_sample(xyz, npoint)           # (B, npoint)
    new_xyz = gather_operation(xyz.transpose(1, 2), fps_idx)
    new_xyz = new_xyz.transpose(1, 2).contiguous()         # (B, npoint, 3)

    outputs = []
    for grouper, mlp in zip(groupers, mlps):
        grouped = grouper(xyz, new_xyz, features)          # (B, Cg, npoint, nsample)
        local_feat = mlp(grouped)                          # (B, Cout, npoint, nsample)
        pooled = torch.max(local_feat, dim=3)[0]           # (B, Cout, npoint)
        outputs.append(pooled)

    new_features = torch.cat(outputs, dim=1)               # (B, sum(Cout), npoint)
    return new_xyz, new_features

6.3 SA 每个子步骤的本质

6.3.1 furthest_point_sample

作用:

  • 从原始 N 个点中挑出 npoint 个中心点
  • 目标不是随机采样,而是尽量让采样点在空间上覆盖均匀

核心思想:

  • 第一个点任选或固定
  • 之后每次选"到已选点集合的最小距离最大"的那个点

伪代码:

python 复制代码
def fps(xyz, npoint):
    selected = [0]
    min_dist = [inf] * len(xyz)
    for _ in range(1, npoint):
        last = xyz[selected[-1]]
        for i, p in enumerate(xyz):
            d = squared_distance(p, last)
            min_dist[i] = min(min_dist[i], d)
        selected.append(argmax(min_dist))
    return selected

实现细节:

  • Python 层只是包装器,真正计算在 _ext.furthest_point_sampling
  • 这是 CUDA 扩展算子,输出 (B, npoint) 索引
  • 该算子本身不需要梯度,反向直接返回 None

6.3.2 gather_operation

作用:

  • 根据索引从 (B, C, N) 中抽取对应列
  • 常用于把 FPS 采样出的中心点坐标从原始点集中取出来

输入输出:

  • 输入:features (B, C, N)idx (B, npoint)
  • 输出:(B, C, npoint)

实现细节:

  • 前向调用 _ext.gather_points
  • 反向调用 _ext.gather_points_grad
  • 这是一个"索引采样 + scatter-add 反传"的典型算子

6.3.3 ball_query

作用:

  • 对每个中心点,在原始点集里找半径 radius 内的邻居
  • 最多保留 nsample 个邻居

输入输出:

  • 输入:xyz (B, N, 3)new_xyz (B, npoint, 3)
  • 输出:idx (B, npoint, nsample)

实现细节:

  • 前向调用 _ext.ball_query
  • 它输出的是邻居索引,不直接输出特征
  • 这是离散邻域选择操作,因此没有定义梯度

语义上可以理解成:

python 复制代码
def ball_query(xyz, centers, radius, nsample):
    result = []
    for c in centers:
        nbrs = [i for i, p in enumerate(xyz) if distance(p, c) <= radius]
        nbrs = pad_or_truncate(nbrs, nsample)
        result.append(nbrs)
    return result

6.3.4 grouping_operation

作用:

  • idx 将点特征收集成局部块
  • 把平铺点集重排成邻域张量

输入输出:

  • 输入:features (B, C, N)idx (B, npoint, nsample)
  • 输出:(B, C, npoint, nsample)

实现细节:

  • 前向调用 _ext.group_points
  • 反向调用 _ext.group_points_grad
  • 可理解为高维 gather

6.3.5 QueryAndGroup

作用:

  • 先做 ball_query
  • 再做 grouping_operation
  • 再把邻域坐标转换成以中心点为原点的相对坐标
  • 最后将相对坐标与原始特征拼接

关键代码位于 pvn3d/lib/pointnet2_utils/pointnet2_utils.py:300-337

python 复制代码
idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
grouped_xyz = grouping_operation(xyz.transpose(1, 2), idx)
grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

这一步非常关键,因为:

  • grouped_xyz 是局部几何结构
  • grouped_features 是点的属性描述
  • 拼接后,MLP 同时看到几何和语义

6.3.6 SharedMLP

作用:

  • 对局部张量 (B, C, npoint, nsample) 做逐点共享卷积
  • 本质上是多个 1x1 Conv2d (+ BN + ReLU)

pvn3d/lib/utils/etw_pytorch_utils/pytorch_utils.py:25-50 中,SharedMLP 会根据 args 依次堆叠 Conv2d

例如:

python 复制代码
SharedMLP([99, 64, 64, 128])

等价于:

python 复制代码
Conv2d(99, 64, kernel_size=1)
Conv2d(64, 64, kernel_size=1)
Conv2d(64, 128, kernel_size=1)

为什么这里用 Conv2d 而不是 Linear

  • 数据仍然保留 npointnsample 两个空间维
  • 1x1 Conv2d 等价于对每个局部点独立做共享线性映射

6.3.7 邻域 max pooling

作用:

  • 沿邻居维 nsample 做对称聚合
  • 把一个局部点集压缩为一个中心点描述子

代码:

python 复制代码
new_features = torch.max(new_features, dim=3, keepdim=True)[0]
new_features = new_features[..., 0]

为什么必须是对称聚合:

  • 点集内部没有稳定顺序
  • 聚合函数必须对输入排列不敏感
  • max 是 PointNet/PointNet++ 最经典的 permutation-invariant 聚合

7. FP 模块内部执行过程

7.1 入口模块:PointnetFPModule.forward

FP 的执行逻辑位于 pvn3d/lib/pointnet2_utils/pointnet2_modules.py:163-207

流程是:

  1. 在低分辨率点集 known 中为高分辨率点集 unknown 找 3NN
  2. 用距离倒数权重做插值
  3. 与 skip feature 拼接
  4. SharedMLP 融合

7.2 FP 逻辑伪代码

python 复制代码
def fp_module_forward(unknown_xyz, known_xyz, unknown_feats, known_feats, mlp):
    # unknown_xyz: (B, n, 3)  高分辨率
    # known_xyz:   (B, m, 3)  低分辨率

    dist, idx = three_nn(unknown_xyz, known_xyz)               # (B, n, 3)
    inv = 1.0 / (dist + 1e-8)
    weight = inv / inv.sum(dim=2, keepdim=True)                # (B, n, 3)

    interp = three_interpolate(known_feats, idx, weight)       # (B, C2, n)

    if unknown_feats is not None:
        fused = torch.cat([interp, unknown_feats], dim=1)      # (B, C1+C2, n)
    else:
        fused = interp

    fused = fused.unsqueeze(-1)                                # (B, C, n, 1)
    fused = mlp(fused)                                         # (B, Cout, n, 1)
    return fused[..., 0]                                       # (B, Cout, n)

7.3 FP 每个子步骤的本质

7.3.1 three_nn

作用:

  • 对高分辨率点 unknown,在低分辨率点 known 中找最近的 3 个点

输入输出:

  • 输入:unknown (B, n, 3)known (B, m, 3)
  • 输出:dist (B, n, 3)idx (B, n, 3)

实现细节:

  • 前向调用 _ext.three_nn
  • 扩展层通常先返回平方距离 dist2,Python 包装里再 sqrt
  • 该算子本身是基于坐标的离散邻居搜索,没有反向梯度

7.3.2 距离倒数归一化

作用:

  • 距离越近,权重越大
  • 三个邻居权重和为 1

代码:

python 复制代码
dist_recip = 1.0 / (dist + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm

这里加 1e-8 是为了避免距离为 0 时除零。

7.3.3 three_interpolate

作用:

  • 根据 idxweight 从低分辨率特征中加权采样

输入输出:

  • 输入:features (B, C, m)idx (B, n, 3)weight (B, n, 3)
  • 输出:(B, C, n)

数学形式:

python 复制代码
out[b, c, i] =
    weight[b, i, 0] * features[b, c, idx[b, i, 0]] +
    weight[b, i, 1] * features[b, c, idx[b, i, 1]] +
    weight[b, i, 2] * features[b, c, idx[b, i, 2]]

实现细节:

  • 前向调用 _ext.three_interpolate
  • 反向调用 _ext.three_interpolate_grad
  • 反传会把梯度按 weight 分配回 3 个源点

7.3.4 skip feature 拼接

作用:

  • 插值得到的是深层语义
  • unknow_feats 保留的是当前层较浅的几何细节
  • 拼接后,网络既能利用高级语义,也不会丢掉局部细节

这就是 PointNet++ 解码路径中的跳连机制。

7.3.5 SharedMLP 融合

这里的 SharedMLP 输入是 (B, C, n, 1),等价于对每个高分辨率点做共享的逐点非线性投影。

它的作用不是插值,而是"插值后的特征融合与重编码"。


8. Pointnet2MSG 的层级张量流

假设输入为:

python 复制代码
pointcloud: (B, N, 3 + C)
xyz:        (B, N, 3)
features:   (B, C, N)

则编码和解码的大致形状如下。

8.1 SA 编码路径

层级 xyz 形状 feature 形状
L0 (B, N, 3) (B, C, N)
L1 = SA1 (B, 2048, 3) (B, 96, 2048)
L2 = SA2 (B, 1024, 3) (B, 256, 1024)
L3 = SA3 (B, 512, 3) (B, 512, 512)
L4 = SA4 (B, 128, 3) (B, 1024, 128)

8.2 FP 解码路径

层级 传播后 feature 形状
L3 <- FP4(L4) (B, 512, 512)
L2 <- FP3(L3) (B, 512, 1024)
L1 <- FP2(L2) (B, 256, 2048)
L0 <- FP1(L1) (B, 128, N)

最终输出:

python 复制代码
Pointnet2MSG(pointcloud) -> (B, 128, N)

这就是后续 DenseFusion 使用的逐点点云 embedding。


9. Graphviz 数据流图

下面这份 dot 可以直接保存为 .dot 文件后用 Graphviz 渲染:

如果要更细致地画 SA 单层数据流,可以用下面这份:


10. 从"算子"角度理解 SA 与 FP

如果只看算子级别,SAFP 可以分别抽象成下面两种模式。

10.1 SA = 采样 + 邻域构图 + 局部编码 + 对称聚合

公式化表达:

python 复制代码
centers = FPS(xyz)
neighbors = BallQuery(xyz, centers)
local = Concat(relative_xyz, point_features)
encoded = SharedMLP(local)
center_feature = MaxPool(encoded, axis=neighbor)

这是一种从"点集合"到"更小点集合"的映射。

10.2 FP = 邻近插值 + 跳连融合 + 逐点重编码

公式化表达:

python 复制代码
idx, dist = ThreeNN(high_res_xyz, low_res_xyz)
interp = WeightedInterpolate(low_res_feat, idx, inverse_distance(dist))
fused = Concat(interp, skip_feat)
high_res_feat = SharedMLP(fused)

这是一种从"稀疏点集合"到"稠密点集合"的映射。


11. SA 与 FP 在 PVN3D 任务中的实际意义

PVN3D 的目标不是只分类点,而是要做:

  • 逐点语义分割
  • 逐点关键点偏移回归
  • 逐点中心点偏移回归

这决定了点云主干必须同时满足两件事:

  1. 有足够强的高层语义表达能力
  2. 能回到原始点级分辨率输出密集特征

因此:

  • SA 负责逐层扩大感受野、提炼高层语义
  • FP 负责把这些语义精确地传播回每个原始点

如果没有 SA:

  • 网络看不到稳定的局部几何层次
  • 高层上下文弱,pose 估计会不稳

如果没有 FP:

  • 只能在稀疏点上有强特征
  • 无法输出高质量逐点分割和偏移

所以 SA 和 FP 在 Pointnet2MSG 中是成对工作的编码器/解码器。


12. 结论

Pointnet2MSG 是 PVN3D 点云分支的核心骨干,其内部结构可以概括为:

  • SA:用 FPS 选中心点,用 Ball Query 建邻域,用 SharedMLP + MaxPool 完成多尺度局部编码
  • FP:用 3NN 找对应关系,用距离倒数权重插值,把深层语义传播回高分辨率点,再与 skip feature 融合

从实现角度看:

  • 几何相关、索引相关、插值相关操作主要由 pointnet2_utils._ext CUDA 扩展完成
  • Python 层负责模块组织、张量拼接、共享 MLP、池化和整体前向流程

从网络功能角度看:

  • SA 决定特征"抽得够不够深"
  • FP 决定特征"回得够不够密"

二者共同保证 PVN3D 既有强几何语义,又能保留逐点预测能力。

相关推荐
米小虾32 分钟前
Loop Engineering —— 循环的设计与自主执行
人工智能·agent
米小虾1 小时前
Harness Engineering —— 系统的安全护栏
人工智能·agent
火山引擎开发者社区1 小时前
积分当钱花,火山引擎开发者激励计划首月消费双倍回馈
人工智能
aqi002 小时前
15天学会AI应用开发(十)把文本嵌入模型换成国产模型
人工智能·python·ai编程
MobotStone2 小时前
为什么在AI时代,“好奇心”成了最值钱的能力?
人工智能
武子康3 小时前
调查研究-200 llama.cpp b9754:一次很小但很关键的 Agent 工具调用修复
人工智能·agent·llama
Ralph_Salar3 小时前
从0到1搭建AI智能支付风控助手Stage1-RAG知识库升级 — 元数据让检索更精准
人工智能
武子康3 小时前
调查研究-199 MCP Zero-Touch OAuth:为什么它是 MCP 进入企业生产的关键门槛?
人工智能·agent·mcp