yolo数据集格式按照每一个类别的比例划分数据集

写在前面: 写脚本不易,写博客不易,请多点赞关注,谢谢。10多年来,我一直免费给大家毫无保留的分享技术等,不但从来没被打赏过,而且在分享有些模型转化处理的高级脚本中,有些同胞由于自身的问题,没有转化成功就出言伤人,各位同胞们,你们要知道,随着时间推移,各个第三方库的版本一致在更新,比如onnx,torch等,这就导致这些库的接口有所变化,写法也有所变化,包括你自己的环境等等各种原因导致你没能成功,而我是无私免费分享者,却遭到不礼貌的言语伤害,我伤心,但这并没有影响我的分享热情,我会继续分享给大家。
背景: 介于在用yolov系列模型训练模型时,数据集划分对训练结果的影响较大,所以写了一个数据集重新划分的脚本,支持两类检测的数据,标签分别为0和1,针对数据集中的空标签都有所考虑。
脚本功能: 按照不同标签,按比例进行划分,比如按照1:8的比例你设置训练集占80%,即train_ratio设置为0.8,那么你标签为0的数据数量的80%将划分为训练集,标签为1的80%将划分为训练集,针对空标签,你也可以选择不划分,或者将空标签仅放在训练集里面,或者正常按照训练集和验证集来划分,若一个标签中有多个不同类别的标签,也会单独把这一类数据拿出来,按照80%比例进行划分。

1、待划分数据集格式如下:

主目录yolodatas名称可以不一样,但是其下面的子目录文件夹的名字必须保持一致,若yolodatas名称不一样则需要在脚本当中修改,详细见3.

2、划分完成之后数据集格式如下:

3、划分脚本参数介绍:

--base_train_images_path 待划分数据的训练集图像路径

--base_val_images_path 待划分数据的验证集图像路径

--save_root_path 新的数据集保存路径 ,主目录名

--train_ratio 训练集所占的比例

--emptys 空的标签正常分训练集和验证集,否则不用空标签

--emptys_trian_val 空标签放训练集和验证集,否组只放训练集,emptys为True时有效,emptys为False时此参数无效。
说明: 如果原始待划分的数据集格式和1中所示一致,则不需要该脚本中的路径base_train_images_path和base_val_images_path,如果不一致,则需要修改。其他参数参考脚本参数修改即可。

4、转化脚本如下,可拷贝使用:

python 复制代码
import os
import argparse
import shutil
import random
def makdirs(opt):
    images_train = os.path.join(opt.save_root_path, 'images/train')
    images_val = os.path.join(opt.save_root_path ,'images/val')

    label_train = os.path.join(opt.save_root_path, 'labels/train')
    label_val = os.path.join(opt.save_root_path,'labels/val')
    os.makedirs(images_train,exist_ok=True)
    os.makedirs(images_val, exist_ok=True)
    os.makedirs(label_train, exist_ok=True)
    os.makedirs(label_val, exist_ok=True)

    return images_train,images_val,label_train,label_val
def readdir(opt):
    #所有图像路径列表
    images_path_list = [os.path.join(opt.base_train_images_path, train_label_name) for train_label_name in
                               os.listdir(opt.base_train_images_path)] \
                              + [os.path.join(opt.base_val_images_path, base_label_val_patch) for base_label_val_patch in
                                 os.listdir(opt.base_val_images_path)]
    #打乱列表
    random.shuffle(images_path_list)
    # 所有标签路径列表
    txt_path_list = [one_images_path_list.replace('images','labels')[:-3]+'txt' for one_images_path_list in images_path_list]

    zero_class_images = []
    zero_class_labels = []


    one_class_images = []
    one_class_labels = []

    one_two_class_images = []
    one_two_class_labels = []

    empty_class_images = []
    empty_class_labels = []


    #所有数据集数量
    file_number = len(images_path_list)

    for idx in range(file_number):
        with open(txt_path_list[idx],'r',encoding='utf-8') as fo:
            lines = fo.readlines()

        if len(lines)>1:
            #标签文件中有多个目标时,将多个目标标签存入列表并统计不同类型标签数量
            class_index_list = []
            #每一行去除\n
            relines = [lls.strip() for lls in lines]

            #遍历每一行,并统计每个标签的数量
            for line in relines:
                #每一行第一个标签名放入列表
                # print(line,line[0])
                class_index_list.append(int(line[0]))
            # print(class_index_list)

            # print("class_index_list:", class_index_list)

            #标签类
            class_label_list = list(set(class_index_list))
            # print("class_label_list:",class_label_list)

            if len(class_label_list)==1:
                if class_label_list[0]==0:
                    zero_class_images.append(images_path_list[idx])
                    zero_class_labels.append(txt_path_list[idx])
                elif class_label_list[0]==1:
                    one_class_images.append(images_path_list[idx])
                    one_class_labels.append(txt_path_list[idx])
                else:
                    print("注意还有其他标签,请修改此处代码,确保所有标签都添加都了标签列表里面!!!!")

            else:
                print("***********************************")
                one_two_class_images.append(images_path_list[idx])
                one_two_class_labels.append(txt_path_list[idx])


        elif len(lines)==1:
            #签文件中有一个目标时
            relines = [lls.strip() for lls in lines]
            class_ids = int(relines[0][0])
            if class_ids == 0:
                zero_class_images.append(images_path_list[idx])
                zero_class_labels.append(txt_path_list[idx])
            elif class_ids == 1:
                one_class_images.append(images_path_list[idx])
                one_class_labels.append(txt_path_list[idx])
            else:
                print("注意还有其他标签,请修改此处代码,确保所有标签都添加都了标签列表里面!!!!")
        elif len(lines)==0:
            #标签为空文件时
            empty_class_images.append(images_path_list[idx])
            empty_class_labels.append(txt_path_list[idx])
        else:
            print("error")
    return zero_class_images,zero_class_labels,one_class_images,one_class_labels,one_two_class_images,one_two_class_labels,empty_class_images,empty_class_labels

def copytofile(opt):
    images_train, images_val, label_train, label_val= makdirs(opt)
    zero_class_images, zero_class_labels, one_class_images, one_class_labels, one_two_class_images, one_two_class_labels, empty_class_images, empty_class_labels =readdir(opt)
    if len(zero_class_images)!=len(zero_class_labels) or len(one_class_images)!=len(one_class_labels):
        print("数据集列表不匹配1!!!!")
        return 0
    if len(one_two_class_images)!=len(one_two_class_labels) or len(empty_class_images)!=len(empty_class_labels):
        print("数据集列表不匹配2!!!!")
        return 0
    number_train_zero = int(opt.train_ratio*len(zero_class_images))
    number_train_one = int(opt.train_ratio*len(one_class_images))
    number_train_one_two = int(opt.train_ratio*len(one_two_class_images))
    number_train_empty = int(opt.train_ratio*len(empty_class_images))
    len_zero,len_one,len_one_two,len_empty = len(zero_class_images),len(one_class_images),len(one_two_class_images),len(empty_class_images)
    for z in range(len_zero):
        #标签为0的数据
        if z<number_train_zero:
            shutil.copy2(zero_class_images[z], os.path.join(images_train,zero_class_images[z].split('\\')[-1]))
            shutil.copy2(zero_class_labels[z], os.path.join(label_train, zero_class_labels[z].split('\\')[-1]))
        else:
            shutil.copy2(zero_class_images[z], os.path.join(images_val, zero_class_images[z].split('\\')[-1]))
            shutil.copy2(zero_class_labels[z], os.path.join(label_val, zero_class_labels[z].split('\\')[-1]))

    for o in range(len_one):
        #标签为1的数据
        print(one_class_images[o],os.path.join(images_train,one_class_images[o].split('\\')[-1]))
        print(one_class_labels[o],os.path.join(label_train, one_class_labels[o].split('\\')[-1]))
        if o<number_train_one:
            shutil.copy2(one_class_images[o], os.path.join(images_train,one_class_images[o].split('\\')[-1]))
            shutil.copy2(one_class_labels[o], os.path.join(label_train, one_class_labels[o].split('\\')[-1]))
        else:
            shutil.copy2(one_class_images[o], os.path.join(images_val, one_class_images[o].split('\\')[-1]))
            shutil.copy2(one_class_labels[o], os.path.join(label_val, one_class_labels[o].split('\\')[-1]))


    for ot in range(len_one_two):
        #标签为0和1的数据
        if ot<number_train_one_two:
            shutil.copy2(one_two_class_images[ot], os.path.join(images_train,one_two_class_images[ot].split('\\')[-1]))
            shutil.copy2(one_two_class_labels[ot], os.path.join(label_train, one_two_class_labels[ot].split('\\')[-1]))
        else:
            shutil.copy2(one_two_class_images[ot], os.path.join(images_val, one_two_class_images[ot].split('\\')[-1]))
            shutil.copy2(one_two_class_labels[ot], os.path.join(label_val, one_two_class_labels[ot].split('\\')[-1]))

    if opt.emptys:
        if opt.emptys_trian_val:
            for en in range(len_empty):
                #标签为空的数据
                if en<number_train_empty:
                    shutil.copy2(empty_class_images[en], os.path.join(images_train,empty_class_images[en].split('\\')[-1]))
                    shutil.copy2(empty_class_labels[en], os.path.join(label_train, empty_class_labels[en].split('\\')[-1]))
                else:
                    shutil.copy2(empty_class_images[en], os.path.join(images_val, empty_class_images[en].split('\\')[-1]))
                    shutil.copy2(empty_class_labels[en], os.path.join(label_val, empty_class_labels[en].split('\\')[-1]))
        else:
            for en in range(len_empty):
                #标签为空的数据
                shutil.copy2(empty_class_images[en], os.path.join(images_train,empty_class_images[en].split('\\')[-1]))
                shutil.copy2(empty_class_labels[en], os.path.join(label_train, empty_class_labels[en].split('\\')[-1]))


    print('#------------------------------------------------------总数据集数量为-----------------------------------------------------------------------#')
    print("划分前总数据集:",len_zero+len_one+len_one_two+len_empty)
    print('#-----------------------------------------------------所有数据集排布情况---------------------------------------------------------------------#')
    print("zero:%s\none:%s\nzero_two:%s\nempty:%s" % (len_zero,len_one,len_one_two,len_empty))
    print('#---------------------------------------------------数据集划分后训练集分布--------------------------------------------------------------------#')
    print("zero:%s\none:%s\nzero_two:%s\nempty:%s\n训练集总数:%s"%(number_train_zero,number_train_one,number_train_one_two,number_train_empty,number_train_zero+number_train_one+number_train_one_two+number_train_empty))
    print('#---------------------------------------------------数据集划分后验证集分布--------------------------------------------------------------------#')
    print("zero:%s\none:%s\nzero_two:%s\nempty:%s\n验证集总数:%s" % (len_zero-number_train_zero, len_one-number_train_one, len_one_two-number_train_one_two, len_empty-number_train_empty,len_zero+len_one+len_one_two+len_empty-(number_train_zero+number_train_one+number_train_one_two+number_train_empty)))
    print('#---------------------------------------------------数据集划分后数据集总数--------------------------------------------------------------------#')
    print("划分后总数据集:",number_train_zero+number_train_one+number_train_one_two+number_train_empty+len_zero+len_one+len_one_two+len_empty-(number_train_zero+number_train_one+number_train_one_two+number_train_empty))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_train_images_path', type=str, default='yolodatas/images/train',help='原始图像的训练集路径')
    parser.add_argument('--base_val_images_path', type=str, default='yolodatas/images/val', help='原始图像的验证集路径')
    parser.add_argument('--save_root_path', nargs='+', type=str, default='./mynewyolodata', help='新的数据集保存路径')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='训练集所占的比例')
    parser.add_argument('--emptys', type=bool, default=True, help='空的标签正常分训练集和验证集,否则不用空标签')
    parser.add_argument('--emptys_trian_val', type=bool, default=True, help='空标签放训练集和验证集,否组只放训练集')
    opt = parser.parse_args()

    copytofile(opt)
相关推荐
AI街潜水的八角10 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
金色旭光15 小时前
目标检测高频评价指标的计算过程
算法·yolo
AI街潜水的八角1 天前
PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统
pytorch·深度学习·yolo
Hugh&1 天前
(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台
python·yolo·django·tensorflow
天天代码码天天2 天前
C# OpenCvSharp 部署读光-票证检测矫正模型(cv_resnet18_card_correction)
人工智能·深度学习·yolo·目标检测·计算机视觉·c#·票证检测矫正
前网易架构师-高司机2 天前
行人识别检测数据集,yolo格式,PASICAL VOC XML,COCO JSON,darknet等格式的标注都支持,准确识别率可达99.5%
xml·yolo·行人检测数据集
abments3 天前
C# OpenCvSharp Yolov8 Face Landmarks 人脸特征检测
开发语言·yolo·c#
Coovally AI模型快速验证3 天前
目标检测新视野 | YOLO、SSD与Faster R-CNN三大目标检测模型深度对比分析
人工智能·yolo·目标检测·计算机视觉·目标跟踪·r语言·cnn
那年一路北3 天前
深入探究 YOLOv5:从优势到模型导出全方位解析
人工智能·yolo·目标跟踪
明月下4 天前
【数据分析】coco格式数据生成yolo数据可视化
yolo·信息可视化·数据分析