【深度学习】YOLOv8训练,交通灯目标检测

文章目录

一、数据处理

dart 复制代码
import traceback
import xml.etree.ElementTree as ET
import os
import shutil
import random
import cv2
import numpy as np
from tqdm import tqdm


def convert_annotation_to_list(xml_filepath, size_width, size_height, classes):
    in_file = open(xml_filepath, encoding='UTF-8')
    tree = ET.parse(in_file)
    root = tree.getroot()
    # size = root.find('size')
    # size_width = int(size.find('width').text)
    # size_height = int(size.find('height').text)
    yolo_annotations = []
    # if size_width == 0 or size_height == 0:
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = [float(xmlbox.find('xmin').text),
             float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text)]

        # 标注越界修正
        if b[1] > size_width:
            b[1] = size_width
        if b[3] > size_height:
            b[3] = size_height

        txt_data = [((b[0] + b[1]) / 2.0) / size_width, ((b[2] + b[3]) / 2.0) / size_height,
                    (b[1] - b[0]) / size_width, (b[3] - b[2]) / size_height]
        # 标注越界修正
        if txt_data[0] > 1:
            txt_data[0] = 1
        if txt_data[1] > 1:
            txt_data[1] = 1
        if txt_data[2] > 1:
            txt_data[2] = 1
        if txt_data[3] > 1:
            txt_data[3] = 1
        yolo_annotations.append(f"{cls_id} {' '.join([str(round(a, 6)) for a in txt_data])}")

    in_file.close()
    return yolo_annotations


def main():
    classes = ["red", "green", "yellow", "off"]

    root = r"/ssd/xiedong/lightyolov5"
    img_path_1 = os.path.join(root, "Traffic-Lights-Dataset-Domestic/JPEGImages")
    xml_path_1 = os.path.join(root, "Traffic-Lights-Dataset-Domestic/Annotations")
    img_path_2 = os.path.join(root, "Traffic-Lights-Dataset-Foreign/JPEGImages")
    xml_path_2 = os.path.join(root, "Traffic-Lights-Dataset-Foreign/Annotations")

    dst_yolo_root = os.path.join(root, "Traffic-Lights-Dataset-YOLO")
    dst_yolo_root_img = os.path.join(dst_yolo_root, "images")
    os.makedirs(dst_yolo_root_img, exist_ok=True)
    dst_yolo_root_txt = os.path.join(dst_yolo_root, "labels")
    os.makedirs(dst_yolo_root_txt, exist_ok=True)

    index = 0
    img_path_1_files = os.listdir(img_path_1)
    xml_path_1_files = os.listdir(xml_path_1)
    for img_id in tqdm(img_path_1_files):
        # 右边的.之前的部分
        xml_id = img_id.split(".")[0] + ".xml"
        if xml_id in xml_path_1_files:
            try:
                new_name = f"{index:06d}.jpg"
                img = cv2.imdecode(np.fromfile(os.path.join(img_path_1, img_id), dtype=np.uint8), 1)  # img是矩阵
                cv2.imwrite(os.path.join(dst_yolo_root_img, new_name), img)
                new_txt_name = f"{index:06d}.txt"
                yolo_annotations = convert_annotation_to_list(os.path.join(xml_path_1, img_id[:-4] + ".xml"),
                                                              img.shape[1],
                                                              img.shape[0],
                                                              classes)
                with open(os.path.join(dst_yolo_root_txt, new_txt_name), 'w') as f:
                    f.write('\n'.join(yolo_annotations))
                index += 1
            except:
                traceback.print_exc()

    img_path_1_files = os.listdir(img_path_2)
    xml_path_1_files = os.listdir(xml_path_2)
    for img_id in tqdm(img_path_1_files):
        # 右边的.之前的部分
        xml_id = img_id.split(".")[0] + ".xml"
        if xml_id in xml_path_1_files:
            try:
                new_name = f"{index:06d}.jpg"
                img = cv2.imdecode(np.fromfile(os.path.join(img_path_2, img_id), dtype=np.uint8), 1)  # img是矩阵
                cv2.imwrite(os.path.join(dst_yolo_root_img, new_name), img)
                new_txt_name = f"{index:06d}.txt"
                yolo_annotations = convert_annotation_to_list(os.path.join(xml_path_2, img_id[:-4] + ".xml"),
                                                              img.shape[1],
                                                              img.shape[0],
                                                              classes)
                with open(os.path.join(dst_yolo_root_txt, new_txt_name), 'w') as f:
                    f.write('\n'.join(yolo_annotations))
                index += 1
            except:
                traceback.print_exc()


if __name__ == '__main__':
    main()

二、环境

dart 复制代码
conda create -n py310_yolo8 python=3.10 -y

conda activate py310_yolo8

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia

pip install ultralytics

data.yaml

yaml 复制代码
path: /ssd/xiedong/lightyolov5/Traffic-Lights-Dataset-YOLO/
train: images
val: images
test: # test images (optional)

# Classes
names:
  0: 'red'
  1: 'green'
  2: 'yellow'
  3: 'off'

三、训练

教程:

https://docs.ultralytics.com/modes/train/#comet

新建训练代码文件train.py

python 复制代码
from ultralytics import YOLO

# Load a model
model = YOLO("yolov8s.pt")  # load a pretrained model (recommended for training)

# Train the model with 2 GPUs
results = model.train(data="data.yaml", epochs=100, imgsz=640, device=[0, 1, 2, 3], batch=128)

开启训练:

dart 复制代码
python -m torch.distributed.run --nproc_per_node 4 train.py

结果会存在这里:

训练截图:

数据分布:

相关推荐
zcg194214 分钟前
如何在CV中使用transformer
人工智能·深度学习·transformer
SuperHeroWu724 分钟前
【MindSpore】MindSpore 开源深度学习框架
人工智能·深度学习·开源·框架·mindspore
weixin_4684668535 分钟前
Airtable 零基础快速上手与实战指南
数据库·人工智能·python·深度学习·ai·大模型
三天不学习44 分钟前
YOLO + .NET 10 快速入门:从零搭建实时目标检测应用
yolo·目标检测·.net
hsg771 小时前
简述:ResNet34/ResNet50及SENet改进模型
人工智能·深度学习
雲明1 小时前
YOLO12目标检测:WebUI界面3步操作指南
目标检测·计算机视觉·webui·yolo12
weixin_468466851 小时前
图像分割新手入门:从环境搭建到实战应用
图像处理·人工智能·深度学习·计算机视觉·ai
深度知识积累AI1 小时前
RK3588 部署 YOLO26 目标检测:从训练到 NPU 推理
yolo
Ai缝合怪 博士1 小时前
【CVPR 2025即插即用】卷积模块篇 | EBlock有效编码器模块,适合低光图像增强、图像分类、实例分割、语义分割、图像去噪、边缘检测、医学图像分割、遥感目标检测等CV任务通用,涨点起飞
目标检测·低光增强·2026顶会顶刊即插即用模块·eblock有效编码器模块·图像分类、实例分割、语义分割·图像去噪、图像去模糊·darkir低光增强模型
阿_旭1 小时前
一文吃透 Grounding DINO:从原理到实战,文本驱动目标检测入门教程【附源码】
人工智能·目标检测·计算机视觉·groundingdino