【深度学习】YOLOv8训练,交通灯目标检测

文章目录

一、数据处理

python 复制代码
import traceback
import xml.etree.ElementTree as ET
import os
import shutil
import random
import cv2
import numpy as np
from tqdm import tqdm


def convert_annotation_to_list(xml_filepath, size_width, size_height, classes):
    """Convert one Pascal-VOC style XML annotation file into YOLO label lines.

    Each kept ``<object>`` becomes ``"<class_id> <cx> <cy> <w> <h>"``. The box
    centre and size are normalised by the image size and rounded to 6 decimals.

    Objects are skipped when their class is not in ``classes``, when they are
    marked ``<difficult>1</difficult>``, or when the box has no area after
    being clipped to the image.

    Args:
        xml_filepath: Path to the XML annotation file (read as UTF-8).
        size_width: Image width in pixels, used for normalisation and clipping.
        size_height: Image height in pixels, used for normalisation and clipping.
        classes: Ordered list of class names; a name's index is its YOLO class id.

    Returns:
        A list of YOLO label lines (possibly empty).

    Raises:
        ValueError: If ``size_width`` or ``size_height`` is not positive.
    """
    if size_width <= 0 or size_height <= 0:
        raise ValueError(f"invalid image size: {size_width}x{size_height}")

    # ``with`` guarantees the file is closed even if parsing raises.
    with open(xml_filepath, encoding='UTF-8') as in_file:
        root = ET.parse(in_file).getroot()

    yolo_annotations = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        # Some annotation tools omit <difficult> (or leave it empty); treat that as 0.
        difficult_node = obj.find('difficult')
        if difficult_node is not None and difficult_node.text:
            difficult = int(difficult_node.text)
        else:
            difficult = 0
        if cls not in classes or difficult == 1:
            continue
        cls_id = classes.index(cls)

        xmlbox = obj.find('bndbox')
        # Clip the box to the image on all four sides. Annotations may spill past
        # any edge, including negative coordinates.
        xmin = min(max(float(xmlbox.find('xmin').text), 0.0), size_width)
        xmax = min(max(float(xmlbox.find('xmax').text), 0.0), size_width)
        ymin = min(max(float(xmlbox.find('ymin').text), 0.0), size_height)
        ymax = min(max(float(xmlbox.find('ymax').text), 0.0), size_height)

        # A box that is empty after clipping (or has swapped corners) is not a
        # valid training target.
        if xmax <= xmin or ymax <= ymin:
            continue

        # Both corners lie in [0, size], so every value below is within [0, 1].
        txt_data = [((xmin + xmax) / 2.0) / size_width,
                    ((ymin + ymax) / 2.0) / size_height,
                    (xmax - xmin) / size_width,
                    (ymax - ymin) / size_height]
        yolo_annotations.append(f"{cls_id} {' '.join(str(round(a, 6)) for a in txt_data)}")

    return yolo_annotations


def _convert_subset(img_dir, xml_dir, dst_img_dir, dst_txt_dir, classes, start_index):
    """Convert one JPEGImages/Annotations folder pair into YOLO image + label files.

    Outputs are named sequentially (``{index:06d}.jpg`` / ``{index:06d}.txt``),
    starting at ``start_index``. Images without a matching ``<stem>.xml`` are
    skipped. Per-image failures are printed and skipped, so one bad file does
    not abort the whole conversion.

    Returns:
        The next unused index, so several subsets can be numbered contiguously.
    """
    xml_names = set(os.listdir(xml_dir))  # set: O(1) membership tests in the loop
    index = start_index
    for img_name in tqdm(os.listdir(img_dir)):
        # splitext strips only the final extension, so "a.b.jpg" -> "a.b.xml" and
        # "x.jpeg" -> "x.xml". The same name is used for the check and for opening.
        xml_name = os.path.splitext(img_name)[0] + ".xml"
        if xml_name not in xml_names:
            continue
        try:
            img_src = os.path.join(img_dir, img_name)
            # imdecode + np.fromfile (instead of cv2.imread) also handles non-ASCII paths.
            img = cv2.imdecode(np.fromfile(img_src, dtype=np.uint8), 1)
            if img is None:
                raise ValueError(f"could not decode image: {img_src}")
            # Parse the annotation before writing anything. Otherwise a bad XML
            # file could leave an image with no label file, which YOLO would
            # treat as a background image.
            yolo_annotations = convert_annotation_to_list(os.path.join(xml_dir, xml_name),
                                                          img.shape[1],
                                                          img.shape[0],
                                                          classes)
            with open(os.path.join(dst_txt_dir, f"{index:06d}.txt"), 'w') as f:
                f.write('\n'.join(yolo_annotations))
            img_dst = os.path.join(dst_img_dir, f"{index:06d}.jpg")
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(img_dst, img):
                raise OSError(f"could not write image: {img_dst}")
        except Exception:
            # Report and move on. The index is not advanced, so the next image
            # reuses (overwrites) this slot.
            traceback.print_exc()
            continue
        index += 1
    return index


def main(root=r"/ssd/xiedong/lightyolov5"):
    """Merge the Domestic and Foreign traffic-light datasets into one YOLO dataset.

    Writes ``images/`` and ``labels/`` under ``<root>/Traffic-Lights-Dataset-YOLO``.

    Args:
        root: Directory containing the ``Traffic-Lights-Dataset-*`` folders.
    """
    classes = ["red", "green", "yellow", "off"]

    dst_yolo_root = os.path.join(root, "Traffic-Lights-Dataset-YOLO")
    dst_yolo_root_img = os.path.join(dst_yolo_root, "images")
    dst_yolo_root_txt = os.path.join(dst_yolo_root, "labels")
    os.makedirs(dst_yolo_root_img, exist_ok=True)
    os.makedirs(dst_yolo_root_txt, exist_ok=True)

    index = 0
    for subset in ("Traffic-Lights-Dataset-Domestic", "Traffic-Lights-Dataset-Foreign"):
        index = _convert_subset(os.path.join(root, subset, "JPEGImages"),
                                os.path.join(root, subset, "Annotations"),
                                dst_yolo_root_img,
                                dst_yolo_root_txt,
                                classes,
                                index)


if __name__ == '__main__':
    main()

二、环境

bash 复制代码
conda create -n py310_yolo8 python=3.10 -y

conda activate py310_yolo8

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia

pip install ultralytics

data.yaml

yaml 复制代码
# Dataset root written by the conversion script; train/val/test below are relative to it.
path: /ssd/xiedong/lightyolov5/Traffic-Lights-Dataset-YOLO/
train: images
# NOTE(review): val points at the same directory as train, so validation metrics
# are computed on training images and will be optimistic. Consider a held-out split.
val: images
test: # test images (optional)

# Classes
# Ids must match the order of the `classes` list used during conversion.
names:
  0: 'red'
  1: 'green'
  2: 'yellow'
  3: 'off'

三、训练

教程:

https://docs.ultralytics.com/modes/train/ (多 GPU 训练见其中的 "Multi-GPU Training" 一节)

新建训练代码文件train.py

python 复制代码
from ultralytics import YOLO

# Load a pretrained model (recommended as the starting point for training)
model = YOLO("yolov8s.pt")

# Train the model on 4 GPUs (device ids 0-3).
# `batch` is the total batch size across all GPUs (128 / 4 = 32 images per GPU).
results = model.train(data="data.yaml", epochs=100, imgsz=640, device=[0, 1, 2, 3], batch=128)

开启训练:

bash 复制代码
python -m torch.distributed.run --nproc_per_node 4 train.py

结果默认保存在 runs/detect/train 目录下(再次训练时依次为 train2、train3 …):

训练截图:

数据分布:

相关推荐
瑶光守护者24 分钟前
【卫星通信】超低码率语音编码ULBC:EnCodec神经音频编解码器架构深度解析
深度学习·音视频·卫星通信·语音编解码·ulbc
Uzuki7 小时前
LLM 指标 | PPL vs. BLEU vs. ROUGE-L vs. METEOR vs. CIDEr
深度学习·机器学习·llm·vlm
2501_9248905212 小时前
商超场景徘徊识别误报率↓79%!陌讯多模态时序融合算法落地优化
java·大数据·人工智能·深度学习·算法·目标检测·计算机视觉
SalvoGao12 小时前
空转学习 | cell-level 与 spot-level的区别
人工智能·深度学习·学习
什么都想学的阿超13 小时前
【大语言模型 15】因果掩码与注意力掩码实现:深度学习中的信息流控制艺术
人工智能·深度学习·语言模型
SHIPKING39314 小时前
【机器学习&深度学习】大模型分布式推理概述:从显存困境到高并发挑战的解决方案
人工智能·深度学习
十八岁牛爷爷18 小时前
通过官方文档详解Ultralytics YOLO 开源工程-熟练使用 YOLO11实现分割、分类、旋转框检测和姿势估计(附测试代码)
人工智能·yolo·目标跟踪
没有梦想的咸鱼185-1037-166320 小时前
AI大模型支持下的:CMIP6数据分析与可视化、降尺度技术与气候变化的区域影响、极端气候分析
人工智能·python·深度学习·机器学习·chatgpt·数据挖掘·数据分析
灵智工坊LingzhiAI1 天前
基于深度学习的中草药识别系统:从零到部署的完整实践
人工智能·深度学习
2501_924731111 天前
智慧矿山误报率↓83%!陌讯多模态融合算法在矿用设备监控的落地优化
人工智能·算法·目标检测·视觉检测