写在前面: 写脚本不易,写博客不易,请多点赞关注,谢谢。10多年来,我一直免费给大家毫无保留的分享技术等,不但从来没被打赏过,而且在分享有些模型转化处理的高级脚本中,有些同胞由于自身的问题,没有转化成功就出言伤人,各位同胞们,你们要知道,随着时间推移,各个第三方库的版本一致在更新,比如onnx,torch等,这就导致这些库的接口有所变化,写法也有所变化,包括你自己的环境等等各种原因导致你没能成功,而我是无私免费分享者,却遭到不礼貌的言语伤害,我伤心,但这并没有影响我的分享热情,我会继续分享给大家。
背景: 介于在用yolov系列模型训练模型时,数据集划分对训练结果的影响较大,所以写了一个数据集重新划分的脚本,支持两类检测的数据,标签分别为0和1,针对数据集中的空标签都有所考虑。
脚本功能: 按照不同标签,按比例进行划分,比如按照1:8的比例你设置训练集占80%,即train_ratio设置为0.8,那么你标签为0的数据数量的80%将划分为训练集,标签为1的80%将划分为训练集,针对空标签,你也可以选择不划分,或者将空标签仅放在训练集里面,或者正常按照训练集和验证集来划分,若一个标签中有多个不同类别的标签,也会单独把这一类数据拿出来,按照80%比例进行划分。
1、待划分数据集格式如下:
主目录yolodatas名称可以不一样,但是其下面的子目录文件夹的名字必须保持一致,若yolodatas名称不一样则需要在脚本当中修改,详细见3.
2、划分完成之后数据集格式如下:
3、划分脚本参数介绍:
--base_train_images_path 待划分数据的训练集图像路径
--base_val_images_path 待划分数据的验证集图像路径
--save_root_path 新的数据集保存路径 ,主目录名
--train_ratio 训练集所占的比例
--emptys 空的标签正常分训练集和验证集,否则不用空标签
--emptys_trian_val 空标签放训练集和验证集,否组只放训练集,emptys为True时有效,emptys为False时此参数无效。
说明: 如果原始待划分的数据集格式和1中所示一致,则不需要该脚本中的路径base_train_images_path和base_val_images_path,如果不一致,则需要修改。其他参数参考脚本参数修改即可。
4、转化脚本如下,可拷贝使用:
python
import os
import argparse
import shutil
import random
def makdirs(opt):
images_train = os.path.join(opt.save_root_path, 'images/train')
images_val = os.path.join(opt.save_root_path ,'images/val')
label_train = os.path.join(opt.save_root_path, 'labels/train')
label_val = os.path.join(opt.save_root_path,'labels/val')
os.makedirs(images_train,exist_ok=True)
os.makedirs(images_val, exist_ok=True)
os.makedirs(label_train, exist_ok=True)
os.makedirs(label_val, exist_ok=True)
return images_train,images_val,label_train,label_val
def readdir(opt):
#所有图像路径列表
images_path_list = [os.path.join(opt.base_train_images_path, train_label_name) for train_label_name in
os.listdir(opt.base_train_images_path)] \
+ [os.path.join(opt.base_val_images_path, base_label_val_patch) for base_label_val_patch in
os.listdir(opt.base_val_images_path)]
#打乱列表
random.shuffle(images_path_list)
# 所有标签路径列表
txt_path_list = [one_images_path_list.replace('images','labels')[:-3]+'txt' for one_images_path_list in images_path_list]
zero_class_images = []
zero_class_labels = []
one_class_images = []
one_class_labels = []
one_two_class_images = []
one_two_class_labels = []
empty_class_images = []
empty_class_labels = []
#所有数据集数量
file_number = len(images_path_list)
for idx in range(file_number):
with open(txt_path_list[idx],'r',encoding='utf-8') as fo:
lines = fo.readlines()
if len(lines)>1:
#标签文件中有多个目标时,将多个目标标签存入列表并统计不同类型标签数量
class_index_list = []
#每一行去除\n
relines = [lls.strip() for lls in lines]
#遍历每一行,并统计每个标签的数量
for line in relines:
#每一行第一个标签名放入列表
# print(line,line[0])
class_index_list.append(int(line[0]))
# print(class_index_list)
# print("class_index_list:", class_index_list)
#标签类
class_label_list = list(set(class_index_list))
# print("class_label_list:",class_label_list)
if len(class_label_list)==1:
if class_label_list[0]==0:
zero_class_images.append(images_path_list[idx])
zero_class_labels.append(txt_path_list[idx])
elif class_label_list[0]==1:
one_class_images.append(images_path_list[idx])
one_class_labels.append(txt_path_list[idx])
else:
print("注意还有其他标签,请修改此处代码,确保所有标签都添加都了标签列表里面!!!!")
else:
print("***********************************")
one_two_class_images.append(images_path_list[idx])
one_two_class_labels.append(txt_path_list[idx])
elif len(lines)==1:
#签文件中有一个目标时
relines = [lls.strip() for lls in lines]
class_ids = int(relines[0][0])
if class_ids == 0:
zero_class_images.append(images_path_list[idx])
zero_class_labels.append(txt_path_list[idx])
elif class_ids == 1:
one_class_images.append(images_path_list[idx])
one_class_labels.append(txt_path_list[idx])
else:
print("注意还有其他标签,请修改此处代码,确保所有标签都添加都了标签列表里面!!!!")
elif len(lines)==0:
#标签为空文件时
empty_class_images.append(images_path_list[idx])
empty_class_labels.append(txt_path_list[idx])
else:
print("error")
return zero_class_images,zero_class_labels,one_class_images,one_class_labels,one_two_class_images,one_two_class_labels,empty_class_images,empty_class_labels
def copytofile(opt):
images_train, images_val, label_train, label_val= makdirs(opt)
zero_class_images, zero_class_labels, one_class_images, one_class_labels, one_two_class_images, one_two_class_labels, empty_class_images, empty_class_labels =readdir(opt)
if len(zero_class_images)!=len(zero_class_labels) or len(one_class_images)!=len(one_class_labels):
print("数据集列表不匹配1!!!!")
return 0
if len(one_two_class_images)!=len(one_two_class_labels) or len(empty_class_images)!=len(empty_class_labels):
print("数据集列表不匹配2!!!!")
return 0
number_train_zero = int(opt.train_ratio*len(zero_class_images))
number_train_one = int(opt.train_ratio*len(one_class_images))
number_train_one_two = int(opt.train_ratio*len(one_two_class_images))
number_train_empty = int(opt.train_ratio*len(empty_class_images))
len_zero,len_one,len_one_two,len_empty = len(zero_class_images),len(one_class_images),len(one_two_class_images),len(empty_class_images)
for z in range(len_zero):
#标签为0的数据
if z<number_train_zero:
shutil.copy2(zero_class_images[z], os.path.join(images_train,zero_class_images[z].split('\\')[-1]))
shutil.copy2(zero_class_labels[z], os.path.join(label_train, zero_class_labels[z].split('\\')[-1]))
else:
shutil.copy2(zero_class_images[z], os.path.join(images_val, zero_class_images[z].split('\\')[-1]))
shutil.copy2(zero_class_labels[z], os.path.join(label_val, zero_class_labels[z].split('\\')[-1]))
for o in range(len_one):
#标签为1的数据
print(one_class_images[o],os.path.join(images_train,one_class_images[o].split('\\')[-1]))
print(one_class_labels[o],os.path.join(label_train, one_class_labels[o].split('\\')[-1]))
if o<number_train_one:
shutil.copy2(one_class_images[o], os.path.join(images_train,one_class_images[o].split('\\')[-1]))
shutil.copy2(one_class_labels[o], os.path.join(label_train, one_class_labels[o].split('\\')[-1]))
else:
shutil.copy2(one_class_images[o], os.path.join(images_val, one_class_images[o].split('\\')[-1]))
shutil.copy2(one_class_labels[o], os.path.join(label_val, one_class_labels[o].split('\\')[-1]))
for ot in range(len_one_two):
#标签为0和1的数据
if ot<number_train_one_two:
shutil.copy2(one_two_class_images[ot], os.path.join(images_train,one_two_class_images[ot].split('\\')[-1]))
shutil.copy2(one_two_class_labels[ot], os.path.join(label_train, one_two_class_labels[ot].split('\\')[-1]))
else:
shutil.copy2(one_two_class_images[ot], os.path.join(images_val, one_two_class_images[ot].split('\\')[-1]))
shutil.copy2(one_two_class_labels[ot], os.path.join(label_val, one_two_class_labels[ot].split('\\')[-1]))
if opt.emptys:
if opt.emptys_trian_val:
for en in range(len_empty):
#标签为空的数据
if en<number_train_empty:
shutil.copy2(empty_class_images[en], os.path.join(images_train,empty_class_images[en].split('\\')[-1]))
shutil.copy2(empty_class_labels[en], os.path.join(label_train, empty_class_labels[en].split('\\')[-1]))
else:
shutil.copy2(empty_class_images[en], os.path.join(images_val, empty_class_images[en].split('\\')[-1]))
shutil.copy2(empty_class_labels[en], os.path.join(label_val, empty_class_labels[en].split('\\')[-1]))
else:
for en in range(len_empty):
#标签为空的数据
shutil.copy2(empty_class_images[en], os.path.join(images_train,empty_class_images[en].split('\\')[-1]))
shutil.copy2(empty_class_labels[en], os.path.join(label_train, empty_class_labels[en].split('\\')[-1]))
print('#------------------------------------------------------总数据集数量为-----------------------------------------------------------------------#')
print("划分前总数据集:",len_zero+len_one+len_one_two+len_empty)
print('#-----------------------------------------------------所有数据集排布情况---------------------------------------------------------------------#')
print("zero:%s\none:%s\nzero_two:%s\nempty:%s" % (len_zero,len_one,len_one_two,len_empty))
print('#---------------------------------------------------数据集划分后训练集分布--------------------------------------------------------------------#')
print("zero:%s\none:%s\nzero_two:%s\nempty:%s\n训练集总数:%s"%(number_train_zero,number_train_one,number_train_one_two,number_train_empty,number_train_zero+number_train_one+number_train_one_two+number_train_empty))
print('#---------------------------------------------------数据集划分后验证集分布--------------------------------------------------------------------#')
print("zero:%s\none:%s\nzero_two:%s\nempty:%s\n验证集总数:%s" % (len_zero-number_train_zero, len_one-number_train_one, len_one_two-number_train_one_two, len_empty-number_train_empty,len_zero+len_one+len_one_two+len_empty-(number_train_zero+number_train_one+number_train_one_two+number_train_empty)))
print('#---------------------------------------------------数据集划分后数据集总数--------------------------------------------------------------------#')
print("划分后总数据集:",number_train_zero+number_train_one+number_train_one_two+number_train_empty+len_zero+len_one+len_one_two+len_empty-(number_train_zero+number_train_one+number_train_one_two+number_train_empty))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--base_train_images_path', type=str, default='yolodatas/images/train',help='原始图像的训练集路径')
parser.add_argument('--base_val_images_path', type=str, default='yolodatas/images/val', help='原始图像的验证集路径')
parser.add_argument('--save_root_path', nargs='+', type=str, default='./mynewyolodata', help='新的数据集保存路径')
parser.add_argument('--train_ratio', type=float, default=0.8, help='训练集所占的比例')
parser.add_argument('--emptys', type=bool, default=True, help='空的标签正常分训练集和验证集,否则不用空标签')
parser.add_argument('--emptys_trian_val', type=bool, default=True, help='空标签放训练集和验证集,否组只放训练集')
opt = parser.parse_args()
copytofile(opt)