作者:SkyXZ
CSDN:SkyXZ~-CSDN博客
博客园:SkyXZ - 博客园
- PointNet论文Arxiv地址:[1612.00593] PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation
传统的目标检测算法已经非常成熟,例如 YOLO 系列、DETR、Faster R-CNN 等,它们主要处理的是规则的二维图像数据。在图像中,像素按照规则网格排列,不同网格之间排列的不同会导致图像结果完全不同,这种有序性非常适合卷积神经网络进行特征提取。然而,3D 的点云完全不同。它是一组离散、无序且稀疏分布的空间点,没有固定的拓扑结构和排列顺序,也就是说点与点之间的邻居关系不是固定的。想象一下,你有14个乒乓球,他们随机地散落在桌子上,但共同组成了一个雨伞的形状。
这些小球就像点云中的点:它们位置无序,没有行列坐标;即使你把小球拿起来打乱顺序再放回去,雨伞的形状依然不变,而理想情况下,一个点云处理模型也应该具有这样的**"顺序不变性"** ------输入顺序变了,输出识别结果不变。然而传统基于卷积的网络并不具备这种能力,卷积依赖规则栅格来共享权重和提取局部特征,直接用在点云上不仅效率低下,还会让提取到的特征对点输入顺序非常敏感,并且难以捕捉物体的整体形状和局部几何关系。为了解决这些问题,斯坦福大学提出了一种开创性的思路,他们通过共享多层感知机(MLP)对每个点独立提取特征,再使用全局对称函数将这些特征汇总,从而将无序的点云信息转化为顺序不敏感的全局特征,实现了对点云的端到端学习与识别,这也是本文将要介绍的核心算法PointNet。
PS:💻 项目完整代码已上传至Github:,这篇文章是我的学习总结,如果你在阅读中有任何问题、建议或错误指出,也欢迎在评论区与我讨论,我们共同进步!
一、白话解析PointNet架构设计及各组件原理
在传统的检测算法中,模型通常依赖规则栅格和卷积运算来捕捉局部邻域特征,需要在输入端人为地建立点与点之间的空间关系;而 PointNet 则直接在点集上进行特征学习 ,平等的对待桌面上的每一个"小球",认为他们同等重要,在这个理论下,PointNet网络通过三个步骤来解决点云检测的问题:先把点云"摆正",再给每个点"贴标签",最后把信息"汇总起来":
- 先把点云"摆正"(T-Net)
前面我们提到过,点云数据具有无序性 和空间姿态不确定性 ,对于同一辆车来说,雷达从不同角度扫描得到的点云可能完全不一样,甚至整片点云都会发生旋转或平移,而PointNet的第一步便是让这些点先**"坐好"** ,其用一个叫 T-Net 的小网络来预测一个\(3×3\)或更高维的变换矩阵 \(T\),并将输入点云 {x_i \\in \\mathbb{R}\^3} \\(映射到一个更"规范"的坐标系中\\) x\^′_i=T⋅x_i.,这样,不s管车是转了个方向还是稍微偏移,后面处理点云的步骤都在一个"标准姿态"下进行,后续特征提取过程就能对空间扰动更稳健。
- 给每个点"贴标签"(共享 MLP)
点云里每个点都可能包含一些形状信息,但这些点是无序的,就像一堆散落在桌上的小球 ,传统的卷积神经网络习惯让像素按行按列排好队来逐个卷积得到特征顺序,可点云不讲规矩,所以PointNet不强行整理"队伍",而是平等对待点云信息中的每一个点 ,其将点云信息中的每个点都单独通过同一组多层感知机(MLP)来进行特征映射,我们设 MLP 为函数 h(\\cdot)\\(,其对每个点独立计算:\\)f_i=h(x_i\^′),i=1,...,n.,由于 MLP 权重是共享的,网络会平等对待每一个点,而不关心他们的输入顺序,因此不管点的顺序怎么换,其得到的结果都会是一样的 。
- 把信息"汇总起来"(最大值池化MaxPooling)
在上一步中我们将每个点都经过了MLP进行处理得到了每个点单独的特征维度,但由于每个点的顺序都是乱的,如果只是简单的对特征进行相加或者拼接得到的结果便会依赖于点云的数据,这时候斯坦福的研究人员便想到了使用最大值池化MaxPooling,也就是对每个特征维度\(j\),我们从所有点的特征里挑一个最大值出来\(g_j=max(f_{ij}),j=1,...,k\),组成一个固定维度的最大值特征向量,而由于最大值运算与输入顺序无关,因此这一步保证了整个网络具有数学意义上的"顺序不变性"。
我们可以用下面这段动画来辅助理解这三个步骤,假设我们有按任意顺序排列的五个小球且每个小球都有\((x,y,z)\)三维数据,我们将其放进一个多层感知机(MLP)中,将原来每个点的三维特征升维为八维特征,并将这个五个点对应的八个维度中不同通道的最大值保存下来得到一个最终的最大值向量,而由于取每个通道最大值的操作和点云的处理顺序是无关的,所以不管怎么改变点云的排列顺序我们最终得到的特征向量也是不变的,因此我们如果使用最终得到的这个无关的输入的特征向量进行分类、检测、分割便可以得到一个无论输入如何改变结果也不会发生任何变化的模型啦
而在PointNet网络中"小球"的数量不止五个,而MLP处理后的维度也不是八维,而是先从3维提升至64维再提升至1024维接着便使用最大池化得到一个1024维的特征向量后便可以直接将这个向量接入一个全连接层预测整个点云属于哪个类别;那如果我们需要接入分割任务的话,由于需要对每个点的类别概率进行预测,因此我们需要将前面MLP提取的各点各自的局部特征向量与全局的整体特征向量进行拼接,这样每个点既保留了整体点云的全局特征也保留了其自身的局部差异性特征,至此我们便可以对每个点的类别进行预测完成分割任务啦!


二、PyTorch完整复现PointNet
在理解了PointNet的原理之后我们便可以开始着手复现PointNet啦,这个网络并不复杂,接下来,我们将一步步用 PyTorch 搭建 PointNet,包括输入处理、T-Net 空间变换模块、共享 MLP 层、全局特征提取以及分类和分割任务的实现
1.数据加载DataLoad部分
ModelNet40 是由普林斯顿大学提出的一个 3D 形状分类数据集,包含 40 个不同类别的三维模型 (例如椅子、桌子、飞机、汽车等),总共有 12,311 个 CAD 模型,其中训练集 9,843 个,测试集 2,468 个,由于模型均由CAD模型转化而来,因此每个点云模均无噪声,且无背景,仅单个物体,我们可以使用如下方式来获取本数据集:
bash
wget http://modelnet.cs.princeton.edu/ModelNet40.zip
下载并解压之后的文件结构如下所示,所有数据按照类别划分并且分为 train
和 test
两个部分,文件格式为 .off(Object File Format),这是一种比较常见的三维模型存储格式,其中包含点的数量、面片数量,每个点的三维坐标,每个面片的顶点索引(描述三角面或多边形):
yaml
# off格式示例
- 第一行:固定为 `OFF`
- 第二行:三个整数,分别是 顶点数、面数、边数
- 接下来顶点坐标:每行一个点的 (x y z)
- 最后是面片数据:每行的第一个数是该面的顶点数 n,后面是 n 个顶点的索引(从 0 开始计数)
# 示例:
OFF #第一行
8 6 12 # 顶点数 面数 边数
0.0 0.0 0.0 # 顶点坐标
1.0 0.0 0.0 # 顶点坐标
1.0 1.0 0.0 # 顶点坐标
0.0 1.0 0.0 # 顶点坐标
0.0 0.0 1.0 # 顶点坐标
1.0 0.0 1.0 # 顶点坐标
1.0 1.0 1.0 # 顶点坐标
0.0 1.0 1.0 # 顶点坐标
4 0 1 2 3 # 面片数据
4 7 6 5 4 # 面片数据
4 0 4 5 1 # 面片数据
4 1 5 6 2 # 面片数据
4 2 6 7 3 # 面片数据
4 3 7 4 0 # 面片数据


我们首先对应读取每个OFF文件,ModelNet40数据集中的OFF文件除了标准的OFF num_vertices num_faces num_edges
格式外,还存在一些特殊的格式变体,比如连体格式OFF4528(OFF后面直接跟顶点数)以及简化格式num_vertices num_faces num_edges,因此我们需要识别并处理这些特殊的off格式,同时还要自动跳过注释行和空行,然后读取完数据之后返回标准的点云数据格式Points: [N, 3]
:
python
def read_off_file(self, file_path):
with open(file_path, 'r') as f:
lines = [line.strip() for line in f if line.strip() and not line.startswith('#')]
if not lines[0].upper().startswith('OFF'):
raise ValueError(f"{file_path} is not a valid OFF file.")
# 解析顶点数、面数
if len(lines[0].split()) == 1:
# 标准格式:OFF\nnum_vertices num_faces num_edges
num_vertices, num_faces, *_ = map(int, lines[1].split())
start = 2
else:
# 简化格式:OFF num_vertices num_faces [num_edges]
parts = lines[0].split()
num_vertices, num_faces = map(int, parts[1:3])
start = 1
points = []
for i in range(start, start + num_vertices):
x, y, z = map(float, lines[i].split()[:3])
points.append([x, y, z])
return np.array(points, dtype=np.float32)
在完成了OFF文件的获取之后我们接下来完成对点云的采样,通常一个 OFF 文件包含的点数\(N\)并不固定,但在训练模型时,我们通常希望每个点云有相同数量的点 num_points
。这时就需要对点云进行采样。最简单的方法是随机采样,同时由于一些简单物体的原始点云点数可能会少于我们要求的点数,这时候在采集点云的时候则允许重复采样,我们的实现代码如下:
python
def sample_points(self, points):
n = len(points)
replace = n < self.num_points
indices = np.random.choice(n, self.num_points, replace=replace)
return points[indices]
由于DataLoad比较基础,在本任务中唯一有难度的仅有OFF文件数据的处理,其余的与Torch
正常的数据加载流程一致,因此这里仅介绍如何处理off文件格式,其余的文件获取加载部分不再赘述直接上源代码:
python
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import glob
class ModelNet40Dataset(Dataset):
def __init__(self, root_dir, split='train', num_points=1024):
self.root_dir = root_dir
self.split = split
self.num_points = num_points
self.classes = sorted([d for d in os.listdir(root_dir)
if os.path.isdir(os.path.join(root_dir, d))])
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
self.file_paths = []
self.labels = []
for class_name in self.classes:
class_dir = os.path.join(root_dir, class_name, split)
if os.path.exists(class_dir):
files = glob.glob(os.path.join(class_dir, "*.off"))
self.file_paths.extend(files)
self.labels.extend([self.class_to_idx[class_name]] * len(files))
print(f"Found {len(self.file_paths)} {split} files in {len(self.classes)} classes")
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
file_path = self.file_paths[idx]
label = self.labels[idx]
points = self.read_off_file(file_path)
points = self.sample_points(points)
# 点云归一化(中心化 + 缩放到单位球)
points = points - np.mean(points, axis=0) # 中心化
scale = np.max(np.linalg.norm(points, axis=1)) # 最大距离
points = points / scale # 缩放
points = torch.FloatTensor(points)
label = torch.LongTensor([label])
return points, label
def read_off_file(self, file_path):
with open(file_path, 'r') as f:
lines = [line.strip() for line in f if line.strip() and not line.startswith('#')]
if not lines[0].upper().startswith('OFF'):
raise ValueError(f"{file_path} is not a valid OFF file.")
# 解析顶点数、面数
if len(lines[0].split()) == 1:
# 标准格式:OFF\nnum_vertices num_faces num_edges
num_vertices, num_faces, *_ = map(int, lines[1].split())
start = 2
else:
# 简化格式:OFF num_vertices num_faces [num_edges]
parts = lines[0].split()
num_vertices, num_faces = map(int, parts[1:3])
start = 1
points = []
for i in range(start, start + num_vertices):
x, y, z = map(float, lines[i].split()[:3])
points.append([x, y, z])
return np.array(points, dtype=np.float32)
def sample_points(self, points):
n = len(points)
replace = n < self.num_points
indices = np.random.choice(n, self.num_points, replace=replace)
return points[indices]
def dataloader(root_dir, batch_size=32, num_points=1024, num_workers=4):
train_dataset = ModelNet40Dataset(
root_dir=root_dir,
split='train',
num_points=num_points
)
test_dataset = ModelNet40Dataset(
root_dir=root_dir,
split='test',
num_points=num_points
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
return train_loader, test_loader, len(train_dataset.classes)
2. PointNet主干网络
完成了数据集加载部分之后,我们开始复现主体的网络部分,按照我们前面所介绍的,PointNet网络主要分为三个组件分别是T-Net、MLP以及完成对应的分类分割任务的任务头,接下来我们一个模块一个模块的来完成:
2.1 T-Net
T-Net 的任务是学出一个 k×k
的仿射变换矩阵 ,用来把输入点云或者中间特征"摆正",让后续的特征提取不再受旋转、平移这些刚体变化的影响,在具体的实现上,我们先用三层MLP(用 Conv1d
+ BatchNorm
+ ReLU
来写)提取局部特征,然后通过全局最大池化得到一个 1024 维的全局向量,再接三层全连接层(512→256→k×k)直接输出变换矩阵,同时在这里我们把最后一层的权重初始化为 0,偏置初始化成单位矩阵,保证一开始输出就是"原封不动"的恒等变换,在Forward
中,输入的点云数据 [B, N, k]
会先转置为 [B, k, N]
,然后经卷积提特征后再压缩成全局描述,最后再reshape
成 [B, k, k]
矩阵用来对齐点云或特征,相当于给 PointNet 上了个"自动配准的前处理":
python
# -----------------------------
# T-Net 变换网络
# -----------------------------
class TNet(nn.Module):
def __init__(self, k=3):
super(TNet, self).__init__()
self.k = k
# MLP 层用 Conv1d 实现共享MLP
self.conv1 = nn.Conv1d(k, 64, 1)
self.bn1 = nn.BatchNorm1d(64)
self.conv2 = nn.Conv1d(64, 128, 1)
self.bn2 = nn.BatchNorm1d(128)
self.conv3 = nn.Conv1d(128, 1024, 1)
self.bn3 = nn.BatchNorm1d(1024)
# 全连接层生成变换矩阵
self.fc1 = nn.Linear(1024, 512)
self.bn4 = nn.BatchNorm1d(512)
self.fc2 = nn.Linear(512, 256)
self.bn5 = nn.BatchNorm1d(256)
self.fc3 = nn.Linear(256, k * k)
# 初始化偏置为单位矩阵
nn.init.constant_(self.fc3.weight, 0)
identity = torch.eye(k)
self.fc3.bias.data.copy_(identity.view(-1))
def forward(self, x):
B, N, _ = x.size()
x = x.transpose(1, 2) # [B, k, N]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2)[0] # 全局最大池化 [B, 1024]
x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x) # [B, k*k]
x = x.view(B, self.k, self.k)
return x
2.2 MLP
PointNet 的特征提取核心就是一组共享参数的多层感知机(Shared MLP) ,其本质上就是对每个点独立做相同的非线性映射,然后在通道维上堆叠成新的特征描述,我们在具体实现上用一维卷积 Conv1d
的 kernel_size=1 来实现 MLP模块,配合 BatchNorm1d
做归一化,让训练更加稳定,再接一个 ReLU 激活函数增加非线性,从T-Net
模块处理之后的点云输入数据 [B, C_in, N]
经过一层 MLP j就能变换到 [B, C_out, N]
,这样处理的特点是对点云顺序不敏感,因为每个点的特征提取完全共享权重,不会引入与点顺序相关的结构信息,这样我们就能方便地堆叠多层 MLP 逐步提升特征维度,为后续的全局池化和分类分割任务打好基础:
python
# -----------------------------
# MLP
# -----------------------------
class MLP(nn.Module):
def __init__(self, in_channels, out_channels, use_bn=True):
super(MLP, self).__init__()
self.use_bn = use_bn
self.conv = nn.Conv1d(in_channels, out_channels, 1)
if use_bn:
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
# x: [B, C, N]
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
x = F.relu(x)
return x
2.3 PointNet主体
PointNet主体就和前面我们介绍的网络架构一样,我们首先使用两个T-Net + MLP堆叠实现对输入数据的处理与升维,然后再根据不同的任务接入不同的任务头,分类任务通过全局最大池化提取全局特征向量并经过多层全连接输出类别概率,而分割任务则将每个点的局部特征与重复的全局特征拼接,通过全连接和 1×1 卷积生成每个点的类别预测,实现局部与全局信息的充分融合:
python
# -----------------------------
# PointNet 主体
# -----------------------------
class PointNet(nn.Module):
def __init__(self, num_classes=40, input_channels=3, task='cls'):
super(PointNet, self).__init__()
self.task = task
# T-Net
self.input_transform = TNet(k=input_channels)
# MLP - 64
self.mlp1 = nn.Sequential(
MLP(input_channels, 64),
MLP(64, 64)
)
# T-Net
self.feature_transform = TNet(k=64)
# MLP - 1024
self.mlp2 = nn.Sequential(
MLP(64, 64),
MLP(64, 128),
MLP(128, 1024)
)
# CLS-Head
if task == 'cls':
self.fc_cls = nn.Sequential(
nn.Linear(1024, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
# SEG-Head
elif task == 'seg':
self.fc_seg = nn.Sequential(
nn.Linear(1088, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128)
)
self.conv_seg = nn.Conv1d(128, num_classes, 1)
def forward(self, x):
B, N, C = x.size()
trans = self.input_transform(x) # [B, C, C]
x = torch.bmm(x, trans) # [B, N, C]
x = x.transpose(1, 2) # [B, C, N] 用于MLP
x = self.mlp1(x) # [B, 64, N]
x_trans = x.transpose(1, 2) # [B, N, 64]
trans_feat = self.feature_transform(x_trans) # [B, 64, 64]
x = torch.bmm(x_trans, trans_feat) # [B, N, 64]
x = x.transpose(1, 2) # [B, 64, N]
x = self.mlp2(x) # [B, 1024, N]
if self.task == 'cls':
x = torch.max(x, 2)[0] # [B, 1024]
output = self.fc_cls(x)
elif self.task == 'seg':
local_feat = x
global_feat = torch.max(x, 2, keepdim=True)[0]
global_feat = global_feat.repeat(1, 1, N)
x = torch.cat([local_feat, global_feat], 1) # [B, 1088, N]
x = x.transpose(1, 2) # [B, N, 1088]
x = self.fc_seg(x) # [B, N, 128]
x = x.transpose(1, 2) # [B, 128, N]
output = self.conv_seg(x)
output = output.transpose(1, 2) # [B, N, num_classes]
return output, trans, trans_feat
3. Train训练部分
最后我们便可以完成训练部分的代码啦,这里的训练部分和其他模型的搭建方式类似就不过多赘述啦,但需要特别关注的是 PointNet 的损失函数设计,除了常规的交叉熵分类损失,PointNet还额外引入了对齐矩阵的正则化项,用于约束 T-Net 和特征变换矩阵以保持正交性,从而避免学习到奇异变换。PointNet通过\(I - A^T A\)的范数,将其作为额外的惩罚项加入总损失,这样可以确保网络学到的仿射矩阵接近旋转矩阵,既能实现空间对齐,又不至于破坏点云结构。其余的部分按照常规流程进行即可,完整代码如下:
python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time
import argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
from model import PointNet
from dataset import dataloader
def feature_transform_regularizer(trans):
batch_size = trans.size(0)
k = trans.size(1)
# 计算 I - A^T*A
identity = torch.eye(k).to(trans.device)
trans_square = torch.bmm(trans.transpose(1, 2), trans)
reg_loss = torch.mean(torch.norm(identity - trans_square, dim=(1, 2)))
return reg_loss
def pointnet_loss(pred, target, trans, trans_feat, weight=0.001):
cls_loss = nn.CrossEntropyLoss()(pred, target)
trans_loss = feature_transform_regularizer(trans)
trans_feat_loss = feature_transform_regularizer(trans_feat)
total_loss = cls_loss + weight * (trans_loss + trans_feat_loss)
return total_loss, cls_loss, trans_loss, trans_feat_loss
def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
model.train()
total_loss = 0
total_cls_loss = 0
total_reg_loss = 0
correct = 0
total = 0
pbar = tqdm(train_loader, desc=f'Epoch {epoch} [Train]')
for batch_idx, (points, labels) in enumerate(pbar):
points = points.to(device)
labels = labels.squeeze().to(device)
pred, trans, trans_feat = model(points)
loss, cls_loss, trans_loss, trans_feat_loss = criterion(pred, labels, trans, trans_feat)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
total_cls_loss += cls_loss.item()
total_reg_loss += (trans_loss + trans_feat_loss).item()
_, predicted = pred.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
pbar.set_postfix({
'Loss': f'{loss.item():.4f}',
'Cls': f'{cls_loss.item():.4f}',
'Reg': f'{(trans_loss + trans_feat_loss).item():.4f}',
'Acc': f'{100.*correct/total:.2f}%'
})
avg_loss = total_loss / len(train_loader)
avg_cls_loss = total_cls_loss / len(train_loader)
avg_reg_loss = total_reg_loss / len(train_loader)
accuracy = 100. * correct / total
return avg_loss, avg_cls_loss, avg_reg_loss, accuracy
def validate(model, test_loader, criterion, device):
model.eval()
total_loss = 0
total_cls_loss = 0
total_reg_loss = 0
correct = 0
total = 0
with torch.no_grad():
pbar = tqdm(test_loader, desc='[Val]')
for points, labels in pbar:
points = points.to(device)
labels = labels.squeeze().to(device)
pred, trans, trans_feat = model(points)
loss, cls_loss, trans_loss, trans_feat_loss = criterion(pred, labels, trans, trans_feat)
total_loss += loss.item()
total_cls_loss += cls_loss.item()
total_reg_loss += (trans_loss + trans_feat_loss).item()
_, predicted = pred.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
pbar.set_postfix({
'Loss': f'{loss.item():.4f}',
'Cls': f'{cls_loss.item():.4f}',
'Reg': f'{(trans_loss + trans_feat_loss).item():.4f}',
'Acc': f'{100.*correct/total:.2f}%'
})
avg_loss = total_loss / len(test_loader)
avg_cls_loss = total_cls_loss / len(test_loader)
avg_reg_loss = total_reg_loss / len(test_loader)
accuracy = 100. * correct / total
return avg_loss, avg_cls_loss, avg_reg_loss, accuracy
def plot_training_curves(train_losses, train_accs, val_losses, val_accs, save_path):
epochs = range(1, len(train_losses) + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, 'b-', label='Train Loss')
plt.plot(epochs, val_losses, 'r-', label='Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accs, 'b-', label='Train Acc')
plt.plot(epochs, val_accs, 'r-', label='Val Acc')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_path, 'training_curves.png'), dpi=300, bbox_inches='tight')
plt.close()
def main():
parser = argparse.ArgumentParser(description='PointNet Training')
parser.add_argument('--data_path', type=str,
default='/home/qi.xiong/Data_Qi/Dataset/Point/ModelNet40',
help='Path to ModelNet40 dataset')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument("--task", type=str, default="cls", help="Task type: cls or seg")
parser.add_argument('--num_points', type=int, default=1024, help='Number of points')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
parser.add_argument('--save_path', type=str, default='./checkpoints', help='Save path')
args = parser.parse_args()
os.makedirs(args.save_path, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print('Loading dataset...')
train_loader, test_loader, num_classes = dataloader(
root_dir=args.data_path,
batch_size=args.batch_size,
num_points=args.num_points,
num_workers=args.num_workers
)
print(f'Dataset loaded: {num_classes} classes')
print(f'Train samples: {len(train_loader.dataset)}')
print(f'Test samples: {len(test_loader.dataset)}')
model = PointNet(num_classes=num_classes, task=args.task, test_mode=False).to(device)
print(f'Model created with {sum(p.numel() for p in model.parameters()):,} parameters')
criterion = pointnet_loss
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
best_acc = 0
train_losses = []
train_cls_losses = []
train_reg_losses = []
train_accs = []
val_losses = []
val_cls_losses = []
val_reg_losses = []
val_accs = []
print('Starting training...')
start_time = time.time()
for epoch in range(1, args.epochs + 1):
train_loss, train_cls_loss, train_reg_loss, train_acc = train_epoch(
model, train_loader, criterion, optimizer, device, epoch
)
val_loss, val_cls_loss, val_reg_loss, val_acc = validate(model, test_loader, criterion, device)
scheduler.step()
train_losses.append(train_loss)
train_cls_losses.append(train_cls_loss)
train_reg_losses.append(train_reg_loss)
train_accs.append(train_acc)
val_losses.append(val_loss)
val_cls_losses.append(val_cls_loss)
val_reg_losses.append(val_reg_loss)
val_accs.append(val_acc)
print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f} (Cls: {train_cls_loss:.4f}, Reg: {train_reg_loss:.4f}), '
f'Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f} (Cls: {val_cls_loss:.4f}, Reg: {val_reg_loss:.4f}), '
f'Val Acc: {val_acc:.2f}%, LR: {scheduler.get_last_lr()[0]:.6f}')
if val_acc > best_acc:
best_acc = val_acc
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
'train_losses': train_losses,
'train_cls_losses': train_cls_losses,
'train_reg_losses': train_reg_losses,
'train_accs': train_accs,
'val_losses': val_losses,
'val_cls_losses': val_cls_losses,
'val_reg_losses': val_reg_losses,
'val_accs': val_accs,
}, os.path.join(args.save_path, 'best_model.pth'))
print(f'New best model saved with accuracy: {best_acc:.2f}%')
if epoch % 50 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
'train_losses': train_losses,
'train_cls_losses': train_cls_losses,
'train_reg_losses': train_reg_losses,
'train_accs': train_accs,
'val_losses': val_losses,
'val_cls_losses': val_cls_losses,
'val_reg_losses': val_reg_losses,
'val_accs': val_accs,
}, os.path.join(args.save_path, f'checkpoint_epoch_{epoch}.pth'))
total_time = time.time() - start_time
print(f'Training completed in {total_time/3600:.2f} hours! Best validation accuracy: {best_acc:.2f}%')
plot_training_curves(train_losses, train_accs, val_losses, val_accs, args.save_path)
print(f'Training curves saved to {args.save_path}/training_curves.png')
torch.save({
'epoch': args.epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'best_acc': best_acc,
'train_losses': train_losses,
'train_cls_losses': train_cls_losses,
'train_reg_losses': train_reg_losses,
'train_accs': train_accs,
'val_losses': val_losses,
'val_cls_losses': val_cls_losses,
'val_reg_losses': val_reg_losses,
'val_accs': val_accs,
}, os.path.join(args.save_path, 'final_model.pth'))
print('Final model saved!')
if __name__ == '__main__':
main()
参考的显存占用及训练截图如下:

