从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型
从零开始的目标检测和关键点检测(一):用labelme标注数据集
从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型
一、重写config文件
1、数据集类型即coco格式的数据集,在dataset_info声明classes、keypoint_info(关键点)、skeleton_info(骨架信息)。
python
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'E:\\pythonproject\\mmdetection\\data\\glue_134_Keypoint\\'
# glue关键点检测数据集-元数据
dataset_info = {
'dataset_name':'glue_134_Keypoint',
'classes':'glue',
'keypoint_info':{
0:{'name':'head','id':0,'color':[255,0,0],'type': '','swap': ''},
1:{'name':'tail','id':1,'color':[0,255,0],'type': '','swap': ''},
},
'skeleton_info': {
0: {'link':('head','tail'),'id': 0,'color': [100,150,200]},
}
}
2、训练参数
python
# 训练超参数
max_epochs = 200 # 训练 epoch 总数
val_interval = 10 # 每隔多少个 epoch 保存一次权重文件
train_cfg = {'max_epochs': max_epochs, 'val_interval': val_interval}
train_batch_size = 32
val_batch_size = 8
stage2_num_epochs = 20
base_lr = 4e-3
randomness = dict(seed=21)
# 优化器
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
# 学习率
param_scheduler = [
dict(
type='LinearLR', start_factor=1.0e-5, by_epoch=False, begin=0, end=20),
dict(
# use cosine lr from 210 to 420 epoch
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]
# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=1024)
3、模型定义、数据预处理、数据加载
详细见源码。
二、开始训练
1、开始训练
bash
python tools/train.py data/glue_134_Keypoint/rtmpose-t-glue.py
2、训练结果
bash
07/27 14:34:07 - mmengine - INFO - Epoch(val) [200][6/6] \
coco/AP: 0.851412 coco/AP .5: 1.000000 coco/AP .75: 1.000000 coco/AP (M): -1.000000 \
coco/AP (L): 0.857120 coco/AR: 0.892683 coco/AR .5: 1.000000 coco/AR .75: 1.000000 \
coco/AR (M): -1.000000 coco/AR (L): 0.892683 \
PCK: 1.000000 AUC: 0.789634 NME: 0.013435 data_time: 0.044700 time: 0.070389
测试一下训练结果
topdown测试 RTMDet + RTMPose
bash
python demo/topdown_demo_with_mmdet.py \
E:\\pythonproject\\mmdetection\\data\\glue_134_Keypoint\\rtmdet_tiny_glue.py \
E:\\pythonproject\\mmdetection\\work_dirs\\rtmdet_tiny_glue\\best_coco_bbox_mAP_epoch_180.pth \
data/glue_134_Keypoint/rtmpose-t-glue.py \
work_dirs/rtmpose-t-glue/best_PCK_epoch_90.pth \
--input data/glue_134_Keypoint/test_image/img.png \
--output-root data/glue_134_Keypoint/test_image/result/ \
--device cpu \
--bbox-thr 0.5 \
--kpt-thr 0.5 \
--nms-thr 0.3 \
--radius 5 \
--thickness 5 \
--draw-bbox \
--draw-heatmap \
--show-kpt-idx
Pose测试 RTMPose
,即手动把glue截出来再丢到网络里
bash
python demo/image_demo.py data/glue_134_Keypoint/test_image/img_2.png \
data/glue_134_Keypoint/rtmpose-t-glue.py \
work_dirs/rtmpose-t-glue/best_PCK_epoch_90.pth \
--out-file data/glue_134_Keypoint/test_image/result_2.png \
--draw-heatmap
3、训练过程可视化
训练集损失函数
训练集准确率
测试集评估指标
测试集评估指标
三、ncnn部署
在线模型转换:Deploee
上传文件完成在线转换