# python
# 复制代码 ("Copy code" — label left over from copying this script off a web page)
import os
import random
import time
from pathlib import Path
import shutil
import tkinter as tk
from tkinter import filedialog
from loguru import logger
import xml.etree.ElementTree as ET
class AnalysisXML:
    """Clean up a folder of Pascal-VOC XML annotations and turn it into a YOLO dataset.

    Pipeline (driven by ``start``):

    1. ``xml_img_split`` copies the XML files and images found in the chosen
       folder into ``xml_labels/`` and ``images/`` created in that folder's parent.
    2. ``xml_to_txt`` writes one YOLO label file per XML into ``xml_labels/labels/``.
    3. ``to_dataset`` copies image/label pairs into
       ``datasets/{train,val}/{images,labels}/``, also in the parent folder.
    """

    # Image extensions handled by both xml_img_split and to_dataset. They are
    # compared case-insensitively so that e.g. ``.JPG`` files are not skipped.
    _IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self):
        """Ask the user to pick the folder holding the mixed XML/image files.

        Raises:
            ValueError: if the dialog is cancelled. Without this check an empty
                selection resolves to the current working directory, which
                would then be processed silently.
        """
        root = tk.Tk()
        root.withdraw()  # hide the empty main window; only the dialog is shown
        root.attributes('-topmost', 1)  # keep the dialog above other windows
        try:
            self.directory = filedialog.askdirectory()  # open the directory picker
        finally:
            root.destroy()
        logger.warning(f'路径选择:【{self.directory}】')
        if not self.directory:
            # askdirectory() returns an empty value when the user cancels.
            raise ValueError('No directory selected')

    @classmethod
    def _is_image_name(cls, filename):
        """Return True if *filename* ends in one of ``_IMAGE_SUFFIXES`` (any case)."""
        return os.path.splitext(filename)[1].lower() in cls._IMAGE_SUFFIXES

    def xml_img_split(self):
        """Copy the XML files and images of the chosen folder into separate folders.

        Creates ``xml_labels`` and ``images`` in the parent of ``self.directory``
        and copies (does not move) the matching files into them. Sets
        ``self.xml_labels_path`` and ``self.images_path`` for the later steps.
        """
        logger.info('---------------------------分割图片和xml------------------------')
        source_dir = Path(self.directory)
        self.images_path = source_dir.parent.joinpath('images')
        self.xml_labels_path = source_dir.parent.joinpath('xml_labels')
        self.images_path.mkdir(parents=True, exist_ok=True)
        self.xml_labels_path.mkdir(parents=True, exist_ok=True)
        for item in source_dir.iterdir():
            if not item.is_file():
                continue
            if item.suffix.lower() == '.xml':
                new_path = self.xml_labels_path.joinpath(item.name)
            elif self._is_image_name(item.name):
                new_path = self.images_path.joinpath(item.name)
            else:
                continue
            # Files are copied, so the originally selected folder stays untouched.
            logger.debug(f'复制:【{item}】 -> 【{new_path}】')
            shutil.copy(str(item), str(new_path))

    @staticmethod
    def _object_name(obj):
        """Return the text of an ``<object>``'s ``<name>`` tag, or None if missing/empty."""
        name_node = obj.find('name')
        if name_node is None or not name_node.text:
            return None
        return name_node.text

    @staticmethod
    def _image_size(root):
        """Return ``(width, height)`` from the ``<size>`` tag, or None if missing/invalid/non-positive."""
        size = root.find('size')
        if size is None:
            return None
        try:
            width = float(size.find('width').text)
            height = float(size.find('height').text)
        except (AttributeError, TypeError, ValueError):
            return None
        if width <= 0 or height <= 0:
            return None
        return width, height

    def xml_to_txt(self):
        """Convert every XML in ``self.xml_labels_path`` into a YOLO ``.txt`` label file.

        Each output line is ``<class_index> <x_center> <y_center> <w> <h>``, with
        coordinates normalised by the image size in the XML ``<size>`` tag.
        Class indices follow the *sorted* list of object names found across all
        XML files, so a given label set always yields the same mapping.

        Requires ``xml_img_split`` to have run (it sets ``xml_labels_path``).
        Sets ``self.txt_labels`` to the output folder. XML files that cannot be
        parsed, or whose image size is missing or zero, are logged and skipped.
        """
        logger.info('----------------------------正在将xml转为txt-----------------------')
        self.txt_labels = self.xml_labels_path.joinpath('labels')  # output folder for the txt labels
        os.makedirs(self.txt_labels, exist_ok=True)

        # Parse each XML once; the trees are reused by both passes below.
        parsed = []
        for filename in sorted(os.listdir(self.xml_labels_path)):
            if not filename.lower().endswith('.xml'):
                continue
            xml_path = os.path.join(self.xml_labels_path, filename)
            try:
                root = ET.parse(xml_path).getroot()
            except ET.ParseError as err:
                logger.error(f'xml解析失败,跳过:【{xml_path}】 {err}')
                continue
            parsed.append((filename, xml_path, root))

        # Collect all object names. Sorting is essential: iterating a set of
        # strings follows hash order, which changes between interpreter runs
        # (hash randomization), so the class indices would not be stable.
        names_set = set()
        for _, _, root in parsed:
            for obj in root.findall('object'):
                name = self._object_name(obj)
                if name is not None:
                    names_set.add(name)
        categories = sorted(names_set)
        logger.success(f'标注的内容names:【{categories}】')
        category_to_index = {category: index for index, category in enumerate(categories)}

        for filename, xml_path, root in parsed:
            logger.warning(f'正在处理:【{xml_path}】')
            size = self._image_size(root)
            if size is None:
                # Normalisation needs a positive image size; dividing by 0 would crash.
                logger.error(f'图片尺寸无效,跳过:【{xml_path}】')
                continue
            width, height = size

            lines = []
            for obj in root.findall('object'):
                name = self._object_name(obj)
                bndbox = obj.find('bndbox')
                if name not in category_to_index or bndbox is None:
                    continue  # nothing usable to convert for this object
                # float() also accepts fractional coordinates written by some tools.
                xmin = float(bndbox.find('xmin').text)
                ymin = float(bndbox.find('ymin').text)
                xmax = float(bndbox.find('xmax').text)
                ymax = float(bndbox.find('ymax').text)
                # Corner box -> centre point + size, normalised to [0, 1].
                x = ((xmin + xmax) / 2.0) / width
                y = ((ymin + ymax) / 2.0) / height
                w = (xmax - xmin) / width
                h = (ymax - ymin) / height
                lines.append(f"{category_to_index[name]} {x:.6f} {y:.6f} {w:.6f} {h:.6f}")

            # One txt per XML, same stem; an XML without objects yields an empty file.
            txt_filename = os.path.splitext(filename)[0] + '.txt'
            txt_path = os.path.join(self.txt_labels, txt_filename)
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.writelines(line + '\n' for line in lines)

    @staticmethod
    def _copy_split(images, image_src, label_src, image_dst, label_dst):
        """Copy each image in *images* plus its same-stem ``.txt`` label into the destination folders.

        Images without a label file are logged and skipped.
        """
        for image in images:
            label = os.path.splitext(image)[0] + '.txt'
            src_image = os.path.join(image_src, image)
            src_label = os.path.join(label_src, label)
            if not os.path.exists(src_label):
                logger.error(f"Warning: Label file {label} not found for image {image}")
                continue
            dst_image = os.path.join(image_dst, image)
            dst_label = os.path.join(label_dst, label)
            shutil.copy(src_image, dst_image)
            shutil.copy(src_label, dst_label)
            logger.debug(f'【{src_image}】 --> 【{dst_image}】')
            logger.success(f'【{src_label}】 --> 【{dst_label}】')

    def to_dataset(self, test_ratio, seed=None):
        """Split image/label pairs randomly into ``datasets/train`` and ``datasets/val``.

        Args:
            test_ratio: fraction of images (0 to 1) placed in the validation
                split; the count is rounded down.
            seed: optional seed for a reproducible split. ``None`` (default)
                uses the module-level ``random`` state, as before.

        Raises:
            ValueError: if ``test_ratio`` is outside ``[0, 1]``.

        Requires ``xml_img_split`` and ``xml_to_txt`` to have run. Files left in
        ``datasets`` by an earlier run are not removed.
        """
        if not 0 <= test_ratio <= 1:
            raise ValueError(f'test_ratio must be between 0 and 1, got {test_ratio!r}')
        # Same parent folder that xml_img_split used for images/ and xml_labels/.
        output_folder = os.path.join(str(Path(self.directory).parent), 'datasets')
        if os.path.isdir(output_folder) and os.listdir(output_folder):
            # A previous run's random split may differ; stale copies would then
            # leave the same image in both train and val.
            logger.warning(f'输出目录已存在且非空:【{output_folder}】,旧文件不会被清理,训练集和验证集可能重复')
        input_image_folder = self.images_path
        input_label_folder = self.txt_labels
        train_images_folder = os.path.join(output_folder, 'train', 'images')
        train_labels_folder = os.path.join(output_folder, 'train', 'labels')
        val_images_folder = os.path.join(output_folder, 'val', 'images')
        val_labels_folder = os.path.join(output_folder, 'val', 'labels')
        for folder in (train_images_folder, train_labels_folder, val_images_folder, val_labels_folder):
            os.makedirs(folder, exist_ok=True)

        # Sort first so a given seed always produces the same split, independent
        # of the (filesystem-dependent) order returned by os.listdir.
        images = sorted(f for f in os.listdir(input_image_folder) if self._is_image_name(f))
        if seed is None:
            random.shuffle(images)
        else:
            random.Random(seed).shuffle(images)

        val_size = int(len(images) * test_ratio)
        self._copy_split(images[:val_size], input_image_folder, input_label_folder,
                         val_images_folder, val_labels_folder)
        self._copy_split(images[val_size:], input_image_folder, input_label_folder,
                         train_images_folder, train_labels_folder)

    def start(self):
        """Run the whole pipeline: split files, convert XML to txt, build a 80/20 dataset."""
        time.sleep(1)
        self.xml_img_split()
        time.sleep(1)
        self.xml_to_txt()
        time.sleep(1)
        self.to_dataset(0.2)
if __name__ == '__main__':
    # Start every run with a fresh log file placed next to this script.
    log_file = os.path.join(os.path.dirname(__file__), 'log.log')
    if os.path.exists(log_file):
        os.unlink(log_file)
    logger.add(log_file)

    # Show the expected folder layout before asking for confirmation.
    for banner in ('...第一层文件夹', ' -->第二层文件夹↓', ' -->[xml和img混合文件夹]', '\n'):
        print(banner)

    # input() always returns a str, so only the text '99' confirms.
    confirmed = input('请确认xml和图片在同一个文件夹(99:确认)(任意值:取消):') == '99'
    if confirmed:
        AnalysisXML().start()
        logger.success('系统完成')
        report = logger.success
    else:
        logger.error('系统退出!')
        report = logger.error

    # Short countdown so the console output stays readable before exit.
    for remaining in range(3, 0, -1):
        time.sleep(1)
        report(f'{remaining}/秒')