【数据集】Yolo人体关键点数据集处理

文章目录

1、介绍

人体关键点检测(Human Keypoints Detection)又称为人体姿态估计2D Pose,是计算机视觉中一个相对基础的任务,是人体动作识别、行为分析、人机交互等的前置任务。一般情况下可以将人体关键点检测细分为单人/多人关键点检测、2D/3D关键点检测,同时有算法在完成关键点检测之后还会进行关键点的跟踪,也被称为人体姿态跟踪。

本次要介绍的数据集是2D关键点检测数据集,数据集主要来自COCO2017,经过对COCO数据集JSON文件进行预处理提取人体关键点信息,一共提取10000 张人体姿态数据集,以及对应的必要信息,已经转化为Yolo格式存储。

2、数据集格式

复制代码
keypoint_dataset
	|
	|_______images
	|		  |_____000000000049.jpg
	|		  |_____ ......
	|_______labels
	|		  |_____000000000049.txt
	|		  |_____ ......		  

3 、COCO人体关键点示意图

下图中,共17个关节点(鼻子x1、眼睛x2、耳朵x2、肩部x2、肘部x2、手腕x2、髋部x2、膝关节x2、脚腕x2):

4、数据集预处理

这里数据集预处理主要包括5个内容,分别是read_jsonshowjson2yolosplit_datasetvalidation

4.1、读取JSON文件

python 复制代码
def read_json(json_path):
    d = defaultdict(list)
    for json_name in tqdm(os.listdir(json_path)):
        json_full_path = json_path + '/' + json_name
        with open(json_full_path, 'r', encoding='utf-8') as file:  
            data = json.load(file)
        image_name = data['image_name']
        # category = data['category']
        # width = data['width']
        # height = data['height']
        # keypoints = data['key_points']
        # bbox = data['bbox']
        image_id = image_name.split('.')[0]
        d[int(image_id)]= data 
    print(f'json 文件读取完成,一共{len(d.keys())}个数据')
    return d

4.2、可视化

python 复制代码
def visualization(ax,keypoints,bbox):
    #随机颜色
    c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
    #关键点之间的连线
    ls = [[15,13],[13,11],[16,14],[14,12],[11,12],[5,11],[6,12],[5,6],[5,7],
          [6,8],[7,9],[8,10],[1,2],[0,1],[0,2],[1,3],[2,4],[3,5],[4,6]]
    sks = np.array(ls)
    #获取关键点坐标x,y,v
    kp = np.array(keypoints)
    x = kp[0::3]
    y = kp[1::3]
    v = kp[2::3]
    for sk in sks:
        if np.all(v[sk]>0):
            # 画点之间的连接线
            plt.plot(x[sk],y[sk], linewidth=1, color=c)
    # 画点
    p = plt.plot(x[v>0], y[v>0],'o',markersize=4, markerfacecolor=c, markeredgecolor='k',markeredgewidth=1)
    p = plt.plot(x[v>1], y[v>1],'o',markersize=4, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)
    # 画矩形边界,多边形填充+矩形边界:
    x, y, w, h = bbox[0],bbox[1],bbox[2],bbox[3]
    ax.add_patch(Polygon(xy=[[x, y], [x, y+h], [x+w, y+h], [x+w, y]], color='k', alpha=0.3))
    ax.add_patch(Rectangle(xy=(x, y), width=w, height=h, fill=False, color=c, alpha=1))
    plt.plot(x + w/2,y+h/2,'*',markersize=5, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)

def show(data,image_root):
    for image_id in data.keys():

        img = io.imread('%s/%s' % (image_root, data[image_id]['image_name']))
        plt.axis('off')
        ax = plt.gca()
        for i in range(len(data[image_id]['key_points'])):
            visualization(ax, data[image_id]['key_points'][i], data[image_id]['bbox'][i])
        plt.imshow(img)
        plt.axis('off')
        plt.show()

4.3、JSON格式转Yolo格式

yolo标注格式为:类别 、标注框中心点(x,y)、长和宽(w,h),关键点坐标以及可见度(kx1,ky1,kv1,kx2,ky2,kv2...),然后并根据图片长宽进行归一化处理

python 复制代码
def json2yolo(data,txt_path):
    if not os.path.exists(txt_path):
        os.mkdir(txt_path)
    for image_id in tqdm(data.keys()):
        txt_name = data[image_id]['image_name'].split('.')[0]
        category = data[image_id]['category']
        width = data[image_id]['width']
        height = data[image_id]['height']
        keypoints = data[image_id]['key_points']
        bbox = data[image_id]['bbox']
        '''
        yolo标注格式
        类别 、标注框中心点、长和宽,关键点坐标以及可见度
        根据图片长高进行归一化处理
        '''
        for i in range(len(keypoints)):
            #对标注框进行预处理
            box = bbox[i]
            keypoint = keypoints[i]
            # print(keypoint)
            #x,y为左上角坐标,w,h为宽高
            x,y,w,h = box[0],box[1],box[2],box[3]
            xx = x + w/2
            yy = y + h/2
            #归一化
            xx,yy,ww,hh = xx/width,yy/height,w/width,h/height
            #关键点归一化
            proc_keypoint = []
            for p in range(0,len(keypoint),3):
                x,y,v = keypoint[p],keypoint[p+1],keypoint[p+2]
                kx,ky = x/width, y/height
                proc_keypoint.extend([kx,ky,v])
            #写入txt文件
            yolo_str = f'{category} {xx:.6f} {yy:.6f} {ww:.6f} {hh:.6f} '
            yolo_str = yolo_str + ' '.join([f'{i:.6f}' for i in proc_keypoint])
            
            with open(f'{txt_path}/{txt_name}.txt','a+') as f:
                f.write(yolo_str + '\n')
    print('转化为yolo txt格式完成!')

4.4、划分数据集

划分数据集格式为yolo需要的格式

python 复制代码
def split_dataset(image_root,txt_root,img_targe_file,label_target_file,split_ratio=0.6):
    imgs = os.listdir(image_root)
    import random
    random.seed(2024)
    random.shuffle(imgs)
    #这里仅仅取了1万张图片进行测试,由于显存限制,
    ims = imgs[:10000]
    random.shuffle(ims)
    train_num = int(len(ims)*split_ratio)
    val_num = int(len(ims)*(1-split_ratio)/2)

    train_set = ims[:train_num]
    val_set = ims[train_num:train_num+val_num]
    test_set = ims[-val_num:]

    move_file(train_set,image_root,txt_root,img_targe_file,label_target_file,mode='train')
    move_file(val_set,image_root,txt_root,img_targe_file,label_target_file,mode='val')
    move_file(test_set,image_root,txt_root,img_targe_file,label_target_file,mode='test')

def move_file(dataset,image_root,txt_root,img_targe_file,label_target_file,mode='train'):
    for img in tqdm(dataset):
        img_path = image_root + '/' + img
        txt_path = txt_root + '/' + img.replace('jpg','txt')
        img_targe_file_full = img_targe_file + '/' + mode
        label_target_file_full = label_target_file + '/' + mode


        shutil.copy(img_path, img_targe_file_full)
        shutil.copy(txt_path, label_target_file_full)

4.5、验证

python 复制代码
def validation():
    images = r'E:\datasets\keypoint_dataset\datasets\images'
    labels = r'E:\datasets\keypoint_dataset\datasets\labels'
    for mode in ['train','val','test']:
        img_path = images + '/' + mode
        label_path = labels + '/' + mode
        assert len(os.listdir(img_path)) == len(os.listdir(label_path)),f'{mode}数据划分验证失败'
        for img,lb in zip(os.listdir(img_path),os.listdir(label_path)):
            img_name = img.split('.')[0]
            lb_name = lb.split('.')[0]
            if img_name != lb_name:
                 assert '数据划分验证失败'
    print('数据划分验证成功!')
    pass

5、完整代码

python 复制代码
import json
from collections import defaultdict
import numpy as np
import os
from tqdm import tqdm
import shutil
from matplotlib import pyplot as plt
import skimage.io as io

from matplotlib.patches import Polygon,Rectangle
def visualization(ax,keypoints,bbox):
    #随机颜色
    c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
    #关键点之间的连线
    ls = [[15,13],[13,11],[16,14],[14,12],[11,12],[5,11],[6,12],[5,6],[5,7],
          [6,8],[7,9],[8,10],[1,2],[0,1],[0,2],[1,3],[2,4],[3,5],[4,6]]
    sks = np.array(ls)
    #获取关键点坐标x,y,v
    kp = np.array(keypoints)
    x = kp[0::3]
    y = kp[1::3]
    v = kp[2::3]
    for sk in sks:
        if np.all(v[sk]>0):
            # 画点之间的连接线
            plt.plot(x[sk],y[sk], linewidth=1, color=c)
    # 画点
    p = plt.plot(x[v>0], y[v>0],'o',markersize=4, markerfacecolor=c, markeredgecolor='k',markeredgewidth=1)
    p = plt.plot(x[v>1], y[v>1],'o',markersize=4, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)
    # 画矩形边界,多边形填充+矩形边界:
    x, y, w, h = bbox[0],bbox[1],bbox[2],bbox[3]
    ax.add_patch(Polygon(xy=[[x, y], [x, y+h], [x+w, y+h], [x+w, y]], color='k', alpha=0.3))
    ax.add_patch(Rectangle(xy=(x, y), width=w, height=h, fill=False, color=c, alpha=1))
    plt.plot(x + w/2,y+h/2,'*',markersize=5, markerfacecolor=c, markeredgecolor=c, markeredgewidth=1)

def read_json(json_path):
    d = defaultdict(list)
    
    for json_name in tqdm(os.listdir(json_path)):
        json_full_path = json_path + '/' + json_name
        with open(json_full_path, 'r', encoding='utf-8') as file:  
            data = json.load(file)
        image_name = data['image_name']
        # category = data['category']
        # width = data['width']
        # height = data['height']
        # keypoints = data['key_points']
        # bbox = data['bbox']
        image_id = image_name.split('.')[0]
        
        d[int(image_id)]= data 
    
    print(f'json 文件读取完成,一共{len(d.keys())}个数据')
    return d
def show(data,image_root):
    for image_id in data.keys():

        img = io.imread('%s/%s' % (image_root, data[image_id]['image_name']))
        plt.axis('off')
        ax = plt.gca()
        for i in range(len(data[image_id]['key_points'])):
            visualization(ax, data[image_id]['key_points'][i], data[image_id]['bbox'][i])
        plt.imshow(img)
        plt.axis('off')
        plt.show()
def json2yolo(data,txt_path):
    if not os.path.exists(txt_path):
        os.mkdir(txt_path)
    for image_id in tqdm(data.keys()):
        txt_name = data[image_id]['image_name'].split('.')[0]
        category = data[image_id]['category']
        width = data[image_id]['width']
        height = data[image_id]['height']
        keypoints = data[image_id]['key_points']
        bbox = data[image_id]['bbox']
        '''
        yolo标注格式
        类别 、标注框中心点、长和宽,关键点坐标以及可见度
        根据图片长高进行归一化处理
        '''
        for i in range(len(keypoints)):
            #对标注框进行预处理
            box = bbox[i]
            keypoint = keypoints[i]
            # print(keypoint)
            #x,y为左上角坐标,w,h为宽高
            x,y,w,h = box[0],box[1],box[2],box[3]
            xx = x + w/2
            yy = y + h/2
            #归一化
            xx,yy,ww,hh = xx/width,yy/height,w/width,h/height
            #关键点归一化
            proc_keypoint = []
            for p in range(0,len(keypoint),3):
                x,y,v = keypoint[p],keypoint[p+1],keypoint[p+2]
                kx,ky = x/width, y/height
                proc_keypoint.extend([kx,ky,v])
            #写入txt文件
            yolo_str = f'{category} {xx:.6f} {yy:.6f} {ww:.6f} {hh:.6f} '
            yolo_str = yolo_str + ' '.join([f'{i:.6f}' for i in proc_keypoint])
            
            with open(f'{txt_path}/{txt_name}.txt','a+') as f:
                f.write(yolo_str + '\n')
    print('转化为yolo txt格式完成!')

def split_dataset(image_root,txt_root,img_targe_file,label_target_file,split_ratio=0.6):
    imgs = os.listdir(image_root)
    import random
    random.seed(2024)
    random.shuffle(imgs)
    #这里仅仅取了1万张图片进行测试,由于显存限制,
    ims = imgs[:10000]
    random.shuffle(ims)
    train_num = int(len(ims)*split_ratio)
    val_num = int(len(ims)*(1-split_ratio)/2)

    train_set = ims[:train_num]
    val_set = ims[train_num:train_num+val_num]
    test_set = ims[-val_num:]

    move_file(train_set,image_root,txt_root,img_targe_file,label_target_file,mode='train')
    move_file(val_set,image_root,txt_root,img_targe_file,label_target_file,mode='val')
    move_file(test_set,image_root,txt_root,img_targe_file,label_target_file,mode='test')

def move_file(dataset,image_root,txt_root,img_targe_file,label_target_file,mode='train'):
    for img in tqdm(dataset):
        img_path = image_root + '/' + img
        txt_path = txt_root + '/' + img.replace('jpg','txt')
        img_targe_file_full = img_targe_file + '/' + mode
        label_target_file_full = label_target_file + '/' + mode


        shutil.copy(img_path, img_targe_file_full)
        shutil.copy(txt_path, label_target_file_full)
def validation():
    images = r'E:\datasets\keypoint_dataset\datasets\images'
    labels = r'E:\datasets\keypoint_dataset\datasets\labels'
    for mode in ['train','val','test']:
        img_path = images + '/' + mode
        label_path = labels + '/' + mode
        assert len(os.listdir(img_path)) == len(os.listdir(label_path)),f'{mode}数据划分验证失败'
        for img,lb in zip(os.listdir(img_path),os.listdir(label_path)):
            img_name = img.split('.')[0]
            lb_name = lb.split('.')[0]
            if img_name != lb_name:
                 assert '数据划分验证失败'
    print('数据划分验证成功!')
    pass

if __name__ == '__main__':
    root = 'E:/datasets/keypoint_dataset'
    image_root = 'E:/datasets/keypoint_dataset/images'
    json_root = 'E:/datasets/keypoint_dataset/labels'
    txt_root = root + '/'+ 'txt'
    #读取json文件
    data = read_json(json_root)
    #可视化
    # show(data,image_root)
    #json2yolo
    # json2yolo(data,txt_root)
    #split dataset
    # img_target_file = root + '/'+ 'datasets' + '/images'
    # label_target_file = root + '/'+ 'datasets' + '/labels'
    # split_dataset(image_root,txt_root,img_target_file,label_target_file,split_ratio=0.7)
    #validation验证
    validation()

6、数据集下载链接

数据集下载链接

相关推荐
向哆哆4 天前
高精度织物缺陷检测数据集(适用YOLO系列/1000+标注)(已标注+划分/可直接训练)
yolo·目标检测
前网易架构师-高司机4 天前
带标注的驾驶员安全带识别数据集,识别率99.5%,可识别有无系安全带,支持yolo,coco json,pascal voc xml格式
xml·yolo·数据集·交通·安全带
向哆哆4 天前
粉尘环境分类检测千张图数据集(适用YOLO系列)(已标注+划分/可直接训练)
yolo·分类·数据挖掘
琅琊榜首20205 天前
移动端AI挂机新范式:YOLOv8+NCNN实现无Root视觉自动化
人工智能·yolo·自动化
智驱力人工智能5 天前
地铁隧道轨道障碍物实时检测方案 守护城市地下动脉的工程实践 轨道障碍物检测 高铁站区轨道障碍物AI预警 铁路轨道异物识别系统价格
人工智能·算法·yolo·目标检测·计算机视觉·边缘计算
智驱力人工智能5 天前
机场鸟类活动智能监测 守护航空安全的精准工程实践 飞鸟检测 机场鸟击预防AI预警系统方案 机场停机坪鸟类干扰实时监测机场航站楼鸟击预警
人工智能·opencv·算法·安全·yolo·目标检测·边缘计算
前端摸鱼匠5 天前
YOLOv8使用 Ultralytics 内置功能简化格式转换:介绍如何使用 yolo mode=data 等相关功能或辅助工具来加速和简化数据格式的准备工作
人工智能·yolo·目标检测·机器学习·目标跟踪·视觉检测
hans汉斯5 天前
《数据挖掘》期刊推介&征稿指南
图像处理·人工智能·算法·yolo·数据挖掘·超分辨率重建·汉斯出版社
卓越软件开发5 天前
毕设全栈开发一条龙:Java/SpringBoot/Vue/ 小程序 / Python / 安卓 / AI 图像识别 人脸检测 车牌识别 YOLO
开发语言·spring boot·python·yolo·小程序·毕业设计·课程设计
向哆哆6 天前
单车/共享单车目标检测数据集(适用YOLO系列)(已标注+划分/可直接训练)
人工智能·yolo·目标检测