记录一下自己的分割代码。
注意:
- 这是在windows环境,请Linux的同学们注意。
- 标签为txt,图像为jpg,其他的我没试过喔。
训练集、验证集、测试集(7:2:1)
import os
import shutil
import random
from tqdm import tqdm
"""
标注文件是yolo格式(txt文件)
训练集:验证集:测试集 (7:2:1)
"""
def split_img(img_path, label_path, split_list):
try:
Data = './ImageSets'
# 创建文件夹结构
train_img_dir = os.path.join(Data, 'images/train')
val_img_dir = os.path.join(Data, 'images/val')
test_img_dir = os.path.join(Data, 'images/test')
train_label_dir = os.path.join(Data, 'labels/train')
val_label_dir = os.path.join(Data, 'labels/val')
test_label_dir = os.path.join(Data, 'labels/test')
# 创建所有目录
os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(val_img_dir, exist_ok=True)
os.makedirs(test_img_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)
os.makedirs(val_label_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)
except Exception as e:
print(f'创建目录时出错: {e}')
train, val, test = split_list
# 仅处理存在对应标签文件的图像
all_img_path = []
for img in os.listdir(img_path):
img_full = os.path.join(img_path, img)
label_full = toLabelPath(img_full, label_path)
if os.path.isfile(label_full):
all_img_path.append(img_full)
else:
print(f"跳过无标签的图像: {img_full}")
# 随机打乱所有图像路径
random.shuffle(all_img_path)
total = len(all_img_path)
train_num = int(total * train)
val_num = int(total * val)
test_num = total - train_num - val_num
# 划分数据集
train_img = all_img_path[:train_num]
val_img = all_img_path[train_num:train_num + val_num]
test_img = all_img_path[train_num + val_num:]
# 复制训练集
for img in tqdm(train_img, desc='训练集', unit='img'):
label = toLabelPath(img, label_path)
_copy(img, train_img_dir)
_copy(label, train_label_dir)
# 复制验证集
for img in tqdm(val_img, desc='验证集', unit='img'):
label = toLabelPath(img, label_path)
_copy(img, val_img_dir)
_copy(label, val_label_dir)
# 复制测试集
for img in tqdm(test_img, desc='测试集', unit='img'):
label = toLabelPath(img, label_path)
_copy(img, test_img_dir)
_copy(label, test_label_dir)
def _copy(from_path, to_dir):
"""复制文件到目标目录,并确保目标目录存在"""
try:
os.makedirs(to_dir, exist_ok=True)
shutil.copy(from_path, to_dir)
except Exception as e:
print(f"复制 {from_path} 到 {to_dir} 失败: {e}")
def toLabelPath(img_path, label_path):
"""根据图片路径生成对应的标签路径"""
img_filename = os.path.basename(img_path)
base = os.path.splitext(img_filename)[0] # 去除扩展名
label_filename = base + '.txt'
return os.path.join(label_path, label_filename)
if __name__ == '__main__':
# 使用原始字符串避免转义问题
img_path = r'D:\文件\PlantVillage_for_object_detection\Dataset\images' #自己的图片文件地址
label_path = r'D:\文件\PlantVillage_for_object_detection\Dataset\labels' #自己的标签文件地址
split_list = [0.7, 0.2, 0.1] # 训练集:验证集:测试集
split_img(img_path, label_path, split_list)

over
