从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型

从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型

从零开始的目标检测和关键点检测(一):用labelme标注数据集

从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型

一、重写config文件

1、数据集类型即coco格式的数据集,在dataset_info声明classes、keypoint_info(关键点)、skeleton_info(骨架信息)。

python 复制代码
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'E:\\pythonproject\\mmdetection\\data\\glue_134_Keypoint\\'

# glue关键点检测数据集-元数据
dataset_info = {
    'dataset_name':'glue_134_Keypoint',
    'classes':'glue',
    'keypoint_info':{
        0:{'name':'head','id':0,'color':[255,0,0],'type': '','swap': ''},
        1:{'name':'tail','id':1,'color':[0,255,0],'type': '','swap': ''},
    },
    'skeleton_info': {
        0: {'link':('head','tail'),'id': 0,'color': [100,150,200]},
    }
}

2、训练参数

python 复制代码
# 训练超参数
max_epochs = 200 # 训练 epoch 总数
val_interval = 10 # 每隔多少个 epoch 保存一次权重文件
train_cfg = {'max_epochs': max_epochs, 'val_interval': val_interval}
train_batch_size = 32
val_batch_size = 8
stage2_num_epochs = 20
base_lr = 4e-3
randomness = dict(seed=21)

# 优化器
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# 学习率
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1.0e-5, by_epoch=False, begin=0, end=20),
    dict(
        # use cosine lr from 210 to 420 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=1024)

3、模型定义、数据预处理、数据加载

详细见源码。

二、开始训练

1、开始训练

bash 复制代码
python tools/train.py data/glue_134_Keypoint/rtmpose-t-glue.py

2、训练结果

bash 复制代码
07/27 14:34:07 - mmengine - INFO - Epoch(val) [200][6/6]    \
coco/AP: 0.851412  coco/AP .5: 1.000000  coco/AP .75: 1.000000  coco/AP (M): -1.000000 \
coco/AP (L): 0.857120  coco/AR: 0.892683  coco/AR .5: 1.000000  coco/AR .75: 1.000000  \
coco/AR (M): -1.000000  coco/AR (L): 0.892683  \
PCK: 1.000000  AUC: 0.789634  NME: 0.013435  data_time: 0.044700  time: 0.070389

测试一下训练结果

topdown测试 RTMDet + RTMPose

bash 复制代码
python demo/topdown_demo_with_mmdet.py \
    E:\\pythonproject\\mmdetection\\data\\glue_134_Keypoint\\rtmdet_tiny_glue.py \
    E:\\pythonproject\\mmdetection\\work_dirs\\rtmdet_tiny_glue\\best_coco_bbox_mAP_epoch_180.pth \
    data/glue_134_Keypoint/rtmpose-t-glue.py \
    work_dirs/rtmpose-t-glue/best_PCK_epoch_90.pth \
    --input data/glue_134_Keypoint/test_image/img.png \
    --output-root data/glue_134_Keypoint/test_image/result/ \
    --device cpu \
    --bbox-thr 0.5 \
    --kpt-thr 0.5 \
    --nms-thr 0.3 \
    --radius 5 \
    --thickness 5 \
    --draw-bbox  \
    --draw-heatmap \
    --show-kpt-idx


Pose测试 RTMPose,即手动把glue截出来再丢到网络里

bash 复制代码
python demo/image_demo.py data/glue_134_Keypoint/test_image/img_2.png \
    data/glue_134_Keypoint/rtmpose-t-glue.py \
    work_dirs/rtmpose-t-glue/best_PCK_epoch_90.pth \
    --out-file data/glue_134_Keypoint/test_image/result_2.png \
    --draw-heatmap

3、训练过程可视化

训练集损失函数

训练集准确率

测试集评估指标

测试集评估指标

三、ncnn部署

在线模型转换:Deploee

上传文件完成在线转换

相关推荐
胡耀超5 分钟前
知识图谱入门——3:工具分类与对比(知识建模工具:Protégé、 知识抽取工具:DeepDive、知识存储工具:Neo4j)
人工智能·知识图谱
陈苏同学13 分钟前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
吾名招财31 分钟前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
我是哈哈hh1 小时前
专题十_穷举vs暴搜vs深搜vs回溯vs剪枝_二叉树的深度优先搜索_算法专题详细总结
服务器·数据结构·c++·算法·机器学习·深度优先·剪枝
鼠鼠龙年发大财1 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
Tisfy1 小时前
LeetCode 2187.完成旅途的最少时间:二分查找
算法·leetcode·二分查找·题解·二分
龙的爹23331 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt
工业机器视觉设计和实现1 小时前
cnn突破四(生成卷积核与固定核对比)
人工智能·深度学习·cnn
Mephisto.java1 小时前
【力扣 | SQL题 | 每日四题】力扣2082, 2084, 2072, 2112, 180
sql·算法·leetcode
robin_suli1 小时前
滑动窗口->dd爱框框
算法