CenterPoint:不用锚框也能做3D检测?无锚框方法的革命

CenterPoint:不用锚框也能做3D检测?无锚框方法的革命

作者:小探 | 首发:探物 AI | 序列:3D感知网络的第3篇文章 | 转载请注明出处
上一篇我们学习了PointPillars,理解了如何用柱子化将3D点云转为伪图像实现实时检测。但PointPillars需要设计大量锚框,正负样本严重不平衡。这一篇我们来看CenterPoint------一个不用锚框、直接检测物体中心的3D检测方法,既快又准,成为当前主流!


一、从锚框检测到无锚框检测

1.1 锚框检测(PointPillars)的贡献?

复制代码
PointPillars的贡献:
✅ 实时3D检测(62 FPS)
✅ 端到端训练
✅ 工业可用

PointPillars的问题:
❌ 需要设计锚框(先验知识)
❌ 锚框数量巨大(每个位置多个锚框)
❌ 正负样本不平衡(大部分锚框是负样本)
❌ 锚框尺寸需要针对数据集调整

1.2 锚框检测的具体问题

复制代码
问题1:锚框设计需要先验知识

  不同数据集需要不同的锚框:
  - KITTI:车辆 (4.0, 1.6, 1.5)
  - nuScenes:车辆 (4.5, 1.8, 1.5)
  - Waymo:车辆 (4.7, 1.9, 1.5)

  问题:
  - 需要统计每个数据集的目标尺寸
  - 不同类别需要不同锚框
  - 迁移性差,但是话说过来工业恒定场景非常合适


问题2:正负样本不平衡

  每个位置生成2-6个锚框
  100×100网格 × 6锚框 = 60,000个锚框
  实际目标可能只有几个!

  正样本:~100个
  负样本:~59,900个
  比例:1:600

  问题:
  - 训练困难(大量简单负样本)
  - 需要复杂的采样策略


问题3:锚框编码复杂

  锚框 → 预测偏移量
  需要编码:Δx, Δy, Δz, Δw, Δl, Δh, Δyaw

  问题:
  - 编码/解码复杂
  - 需要NMS后处理
  - 容易产生重复检测

1.3 无锚框检测的核心思想

复制代码
无锚框检测(CenterPoint):
  不预测锚框的偏移,直接预测物体的中心点!

  传统方法:
    锚框 → 预测偏移 → 解码 → NMS → 结果

  CenterPoint:
    直接预测中心点 → 回归尺寸 → 结果

优势:
  ✅ 不需要设计锚框
  ✅ 不需要NMS(每个中心点只预测一次)
  ✅ 正负样本更平衡
  ✅ 训练更简单

1.4 发展路线

复制代码
PointPillars (2019)
    │  锚框检测,62 FPS
    ↓
CenterPoint (2021)
    │  无锚框检测,中心点检测
    ↓
CenterPoint + 速度估计
    │  加入时序信息
    ↓
TransFusion (2022)
    │  Transformer融合
    ↓
StreamPETR (2023)
    │  流式Transformer
    ↓
自动驾驶实际部署

二、CenterPoint的核心思想

2.1 一句话概括

将3D检测转化为中心点热力图预测,直接检测物体中心,然后回归3D尺寸和朝向,无需锚框和NMS。

2.2 与PointPillars的本质区别

复制代码
PointPillars(锚框方法):
  1. 生成大量锚框(60,000+)
  2. 预测每个锚框的偏移
  3. 解码得到边界框
  4. NMS去除重复

CenterPoint(无锚框方法):
  1. 预测中心点热力图
  2. 在热力图上找峰值(物体中心)
  3. 直接回归每个中心的3D属性
  4. 无需NMS!

对比:
┌──────────────┬──────────────┬──────────────┐
│              │ PointPillars │ CenterPoint  │
├──────────────┼──────────────┼──────────────┤
│ 检测方式     │ 锚框偏移     │ 中心点热力图 │
│ 锚框数量     │ ~60,000      │ 0            │
│ 正负样本比   │ 1:600        │ 更平衡       │
│ NMS          │ 需要         │ 不需要       │
│ 训练难度     │ 较难         │ 较易         │
│ 速度         │ 62 FPS       │ 15-30 FPS    │
│ 精度         │ 较高         │ 更高         │
└──────────────┴──────────────┴──────────────┘

2.3 CenterPoint的三大步骤

复制代码
步骤1:中心点热力图预测
  预测每个类别的物体中心位置

步骤2:3D属性回归
  对每个中心点,回归3D尺寸、朝向、高度等

步骤3:后处理(可选)
  简单的阈值过滤,无需NMS

图解整体流程

复制代码
输入点云 [N×4]
        │
        ↓
┌─────────────────────────────────────────────┐
│ 步骤1:特征提取(PointPillars骨干)          │
│                                             │
│ - 柱子化                                     │
│ - 2D卷积骨干                                 │
│ - 特征金字塔                                 │
└─────────────────────────────────────────────┘
        │
        ↓
   BEV特征图 [C×H×W]
        │
        ↓
┌─────────────────────────────────────────────┐
│ 步骤2:中心点热力图预测                      │
│                                             │
│ - 对每个类别预测热力图                       │
│ - 热力图峰值 = 物体中心                     │
│ - 无需锚框!                                │
└─────────────────────────────────────────────┘
        │
        ↓
   热力图 [K×H×W]
   (K个类别)
        │
        ↓
┌─────────────────────────────────────────────┐
│ 步骤3:3D属性回归                            │
│                                             │
│ - 对每个中心点回归:                         │
│   - 3D尺寸 (w, l, h)                       │
│   - 高度 z                                  │
│   - 朝向角 yaw                              │
│   - 速度 (vx, vy)                           │
└─────────────────────────────────────────────┘
        │
        ↓
3D检测结果:
  - 中心点位置 (x, y)
  - 高度 z
  - 尺寸 (w, l, h)
  - 朝向 yaw
  - 速度 (vx, vy)
  - 类别 + 置信度

2.4 为什么这样做更好?

复制代码
1. 不需要锚框设计
   - 不需要统计目标尺寸
   - 不同数据集通用
   - 迁移性好

2. 正负样本更平衡
   - 只在物体中心附近是正样本
   - 其他位置是负样本
   - 比例更合理

3. 不需要NMS
   - 每个物体只有一个中心点
   - 热力图峰值天然去重
   - 后处理更简单

4. 训练更简单
   - 损失函数更直观
   - 不需要复杂的采样策略
   - 收敛更快

5. 精度更高
   - 中心点定位更准
   - 减少锚框带来的误差

三、CenterPoint的网络结构详解

3.1 整体架构

复制代码
输入点云
    │
    ↓
┌─────────────────────────────────────────────┐
│ PointPillars特征提取                         │
│                                             │
│  柱子化 → 柱子编码器 → 伪图像               │
│  → 2D骨干网络 → 特征金字塔                   │
└─────────────────────────────────────────────┘
    │
    ↓
BEV特征图 [C×H×W]
    │
    ├──→ 中心点热力图头 ──→ 热力图 [K×H×W]
    │
    ├──→ 尺寸回归头 ──→ 尺寸图 [3×H×W]
    │
    ├──→ 高度回归头 ──→ 高度图 [1×H×W]
    │
    └──→ 朝向回归头 ──→ 朝向图 [2×H×W]
         │
         ↓
    后处理 → 3D检测结果

3.2 中心点热力图预测(核心创新)

核心思想:用热力图表示物体中心位置,峰值处即为中心

python 复制代码
class CenterPointHead(nn.Module):
    def forward(self, bev_feat):
        return {
            "heatmap": self.heatmap_head(bev_feat),       # 中心点概率图
            "size": self.size_head(bev_feat),             # w, l, h
            "height": self.height_head(bev_feat),         # z
            "orientation": self.orientation_head(bev_feat), # sin(yaw), cos(yaw)
            "velocity": self.velocity_head(bev_feat),     # vx, vy(可选)
        }

初始化部分本质上就是几组卷积头:一个负责"哪里是中心",其余负责"这个中心对应的3D属性"。

热力图图解

3.3 高斯热力图生成(标签编码)

python 复制代码
def generate_heatmap(targets):
    heatmap = zeros([num_classes, H, W])

    for x, y, cls_id in targets:
        cx, cy = world_to_grid(x, y)
        draw_gaussian(heatmap[cls_id], center=(cx, cy), sigma=sigma)

    return heatmap

高斯热力图图解

复制代码
单个物体的热力图标签:

真实中心点:(10, 20)
高斯核大小:sigma = 2

热力图值分布:
        18    19    20    21    22
  8  [ 0.0   0.0   0.0   0.0   0.0 ]
  9  [ 0.0   0.1   0.2   0.1   0.0 ]
  10 [ 0.0   0.3   0.8   0.3   0.0 ]  ← 中心行
  11 [ 0.0   0.1   0.2   0.1   0.0 ]
  12 [ 0.0   0.0   0.0   0.0   0.0 ]

中心值最高(接近1),周围逐渐衰减

为什么用高斯而不是二值?
  - 二值:只有中心是1,其他是0
  - 高斯:中心附近都有值
  - 优势:提供更多监督信号,训练更稳定

3.4 3D属性回归

python 复制代码
def decode_predictions(pred):
    peaks = find_local_peaks(pred["heatmap"], threshold=0.3)
    boxes = []

    for cls_id, y, x, score in peaks:
        center = grid_to_world(x, y)
        size = read_at(pred["size"], y, x)
        z = read_at(pred["height"], y, x)
        yaw = decode_yaw(read_at(pred["orientation"], y, x))

        boxes.append((center, z, size, yaw, cls_id, score))

    return boxes

解码过程图解

复制代码
热力图:
┌─────────────────────────────────────┐
│                                     │
│              ▓▓▓                    │  ← 峰值位置
│              ▓▓▓                    │
│                                     │
└─────────────────────────────────────┘

找到峰值位置 (x_idx, y_idx)

在该位置读取回归值:
  - 高度图 → z = 0.5m
  - 尺寸图 → w=1.6m, l=4.0m, h=1.5m
  - 朝向图 → sin(yaw)=0.5, cos(yaw)=0.866 → yaw=30°

组装3D检测框:
  (x, y, 0.5, 1.6, 4.0, 1.5, 30°)

3.5 损失函数

python 复制代码
loss = focal_loss(pred["heatmap"], target["heatmap"])

for name in ["size", "height", "orientation", "velocity"]:
    # 只在真实中心点附近计算属性回归损失
    loss += l1_loss(pred[name], target[name], mask=target["center_mask"])

损失函数图解

复制代码
热力图损失(Focal Loss):
  正样本(中心点):loss = -(1-p)^α * log(p)
  负样本(其他):loss = -(1-t)^β * p^α * log(1-p)

  效果:
  - 正样本:预测越准,损失越小
  - 负样本:简单负样本损失小,难负样本损失大
  - 解决正负样本不平衡

属性损失(L1 Loss):
  只在热力图峰值位置计算

  位置 (10, 20):
    预测:w=1.5, l=3.9, h=1.4
    真实:w=1.6, l=4.0, h=1.5
    损失:|1.5-1.6| + |3.9-4.0| + |1.4-1.5| = 0.3

  其他位置:不计算损失

四、CenterPoint代码骨架(精简版)

4.1 网络前向流程

python 复制代码
class CenterPoint(nn.Module):
    def forward(self, pillars, indices, num_points, grid_size):
        pillar_feat = self.pillar_encoder(pillars, indices, num_points)
        pseudo_image = create_pseudo_image(pillar_feat, indices, grid_size)
        bev_feat = self.backbone(pseudo_image)
        bev_feat = self.fpn(bev_feat)

        return self.head(bev_feat)

4.2 训练流程

python 复制代码
for batch in train_loader:
    pred = model(batch["pillars"], batch["indices"], batch["num_points"], grid_size)
    loss = centerpoint_loss(pred, batch["targets"])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

五、CenterPoint的改进版本

5.1 加入时序信息(CenterPoint-Temporal)

复制代码
问题:单帧检测无法估计速度

解决:融合多帧信息

CenterPoint-Temporal:
  当前帧特征 → 检测
       ↓
  上一帧特征 → 检测 → 变换到当前帧
       ↓
  融合 → 更准确的检测 + 速度估计
python 复制代码
current_feat = detector.extract_features(current_points)
previous_feat = detector.extract_features(previous_points)

# 先把上一帧对齐到当前帧坐标系,再融合
previous_feat = warp_to_current_frame(previous_feat, ego_motion)
fused_feat = temporal_fusion(current_feat, previous_feat)

pred = detector.head(fused_feat)

5.2 加入速度估计

python 复制代码
# 单帧:直接从当前BEV特征预测速度
velocity = velocity_head(current_feat)

# 多帧:加入"当前帧 - 上一帧"的特征差,速度更稳定
velocity = velocity_head(concat([current_feat, current_feat - previous_feat]))

5.3 CenterPoint vs PointPillars

复制代码
KITTI数据集对比:

┌─────────────────┬───────────┬───────────┬───────────┐
│ 方法            │ AP Easy   │ AP Mod    │ AP Hard   │
├─────────────────┼───────────┼───────────┼───────────┤
│ PointPillars    │ 82.58     │ 74.31     │ 68.99     │
│ CenterPoint     │ 85.62     │ 77.45     │ 72.13     │
└─────────────────┴───────────┴───────────┴───────────┘

nuScenes数据集对比:

┌─────────────────┬───────────┬───────────┬───────────┐
│ 方法            │ NDS       │ mAP       │ FPS       │
├─────────────────┼───────────┼───────────┼───────────┤
│ PointPillars    │ 45.3      │ 30.5      │ 62        │
│ CenterPoint     │ 67.3      │ 59.6      │ 15        │
└─────────────────┴───────────┴───────────┴───────────┘

结论:
  CenterPoint精度更高,但速度较慢
  适合对精度要求高的场景

六、CenterPoint在自动驾驶中的应用

6.1 检测流水线

python 复制代码
def detect(point_cloud):
    pillars = preprocess(point_cloud)
    pred = model(pillars)

    boxes = decode_predictions(pred)
    boxes = filter_by_score(boxes, threshold=0.3)

    return topk(boxes, k=100)

6.2 与其他传感器融合

python 复制代码
lidar_result = lidar_detector(lidar_points)

if camera_images is not None:
    camera_result = camera_detector(camera_images)
    result = fuse(lidar_result, camera_result)
else:
    result = lidar_result

七、常见问题解答(FAQ)

Q1: CenterPoint为什么不需要NMS?

复制代码
传统方法(PointPillars):
  - 每个位置有多个锚框
  - 一个物体可能被多个锚框检测到
  - 需要NMS去除重复

CenterPoint:
  - 每个物体只有一个中心点
  - 热力图峰值天然去重
  - 每个峰值只预测一个物体

图解:
  PointPillars:              CenterPoint:
  ┌───┬───┬───┐              ┌───┬───┬───┐
  │ A │ A │   │              │   │ ● │   │
  ├───┼───┼───┤              ├───┼───┼───┤
  │ A │ A │   │              │   │   │   │
  └───┴───┴───┘              └───┴───┴───┘
  A = 锚框(多个)            ● = 中心点(一个)
  需要NMS去重                 不需要NMS

Q2: 热力图峰值检测怎么实现?

复制代码
方法1:最大值池化
  - 3x3最大值池化
  - 找局部最大值
  - 简单有效

方法2:阈值过滤
  - 热力图值 > 阈值
  - 可能有多个点
  - 需要后处理

方法3:非极大值抑制
  - 对热力图做NMS
  - 更精确
  - 计算量稍大

常用方法:最大值池化 + 阈值过滤

Q3: CenterPoint的速度为什么比PointPillars慢?

复制代码
原因1:检测头更复杂
  PointPillars:分类 + 回归
  CenterPoint:热力图 + 尺寸 + 高度 + 朝向 + 速度

原因2:后处理不同
  PointPillars:简单NMS
  CenterPoint:峰值检测 + 解码

原因3:精度更高
  更复杂的网络 → 更高的精度

解决方法:
1. 轻量化骨干网络
2. 优化后处理
3. TensorRT加速

Q4: 如何处理遮挡物体?

复制代码
CenterPoint本身不处理遮挡

解决方法:
1. 时序融合
   - 融合多帧信息
   - 利用历史位置

2. 多传感器融合
   - LiDAR + Camera
   - 不同视角互补

3. 3D占用感知
   - 预测遮挡区域
   - 结合CenterPoint使用

实际系统:
  CenterPoint(快速检测)
  + 时序融合(速度估计)
  + 占用感知(遮挡处理)

Q5: CenterPoint能检测多远的物体?

复制代码
取决于激光雷达和设置

典型设置:
- KITTI:0-70m
- nuScenes:0-100m
- Waymo:0-150m

远处物体的问题:
- 点数少(稀疏)
- 热力图峰值弱
- 检测困难

解决方法:
1. 多尺度特征(FPN)
2. 远处用更大的感受野
3. 结合Camera(远处更清晰)

八、总结:CenterPoint的精髓

8.1 核心思想

  1. 中心点热力图:用热力图表示物体中心,峰值处即为中心
  2. 无锚框:不需要设计锚框,直接预测中心点
  3. 属性回归:对每个中心点回归3D属性
  4. 无需NMS:热力图峰值天然去重

8.2 一句话总结

CenterPoint通过将3D检测转化为中心点热力图预测,直接检测物体中心并回归3D属性,无需锚框和NMS,实现了更高的精度和更简洁的设计。

8.3 关键创新

创新 作用
中心点热力图 直接预测物体中心
无锚框设计 不需要先验知识
Focal Loss 处理正负样本不平衡
属性回归 直接预测3D属性

8.4 下一步学习

  1. 3D占用感知:从检测到分割,完整场景理解
  2. BEVFormer:鸟瞰图视角的多传感器融合
  3. TransFusion:Transformer融合

附录:关键术语表

术语 英文 含义
无锚框 Anchor-Free 不使用预定义锚框的检测方法
热力图 Heatmap 表示物体中心概率的2D图
Focal Loss - 处理正负样本不平衡的损失函数
NMS Non-Maximum Suppression 非极大值抑制
BEV Bird's Eye View 鸟瞰图视角
峰值检测 Peak Detection 在热力图上找局部最大值

下期预告:《从检测到感知------3D占用网格如何让自动驾驶"看穿"遮挡?》