旋转框目标检测自定义数据集训练测试流程

文章目录


前言

旋转框目标检测(Rotated bounding box object detection)是计算机视觉领域的一项技术,它用于检测图像中具有任意方向的目标。与传统的水平矩形框目标检测相比,旋转框目标检测能够更准确地描述物体的形状和位置,尤其是对于那些长宽比差异较大或者方向各异的物体,如遥感图像中的建筑物、文本行、车辆等,本文将详细介绍YOLOV11-OBB自定义数据集训练测试流程,帮您实现旋转框目标检测。

一、数据集制作

本次标注软件采用的是X-AnyLabeling,github地址:windows直接下载。

点击X-AnyLabeling-CPU.exe下载,下载好之后打开,界面如下所示:

使用方法如下所示:

标注好数据集之后标签是json格式,运行下面的代码将数据集转化为yolov11所需要的txt格式:

python 复制代码
def order_points(points):
    # 1. 计算中心点
    center_x = sum([p[0] for p in points]) / 4
    center_y = sum([p[1] for p in points]) / 4

    # 2. 计算每个点相对于中心点的角度,并排序
    def angle_from_center(point):
        return math.atan2(point[1] - center_y, point[0] - center_x)

    # 按角度逆时针排序
    points = sorted(points, key=angle_from_center, reverse=True)

    # 3. 按"右上、右下、左下、左上"的顺序排列
    ordered_points = [points[0], points[1], points[2], points[3]]

    return ordered_points


import math, os, json

# 定义类别映射字典,键为类别名称,值为类别索引
category_mapping = {"box_top": 0}


def order_points(points):
    # 计算四个顶点的中心点坐标
    center_x = sum([p[0] for p in points]) / 4
    center_y = sum([p[1] for p in points]) / 4
    # 计算每个点相对于中心点的角度,按逆时针方向排序,确保点的顺序一致
    points = sorted(points, key=lambda p: math.atan2(p[1] - center_y, p[0] - center_x), reverse=True)
    # 返回按顺序排列的四个点,顺序为"右上、右下、左下、左上"
    return [points[0], points[1], points[2], points[3]]


def convert_json_to_yolo11(json_folder, output_folder, category_mapping):
    os.makedirs(output_folder, exist_ok=True)  # 确保输出文件夹存在,如果不存在则创建
    for filename in os.listdir(json_folder):  # 遍历JSON文件夹中的每个文件
        if filename.endswith('.json'):  # 只处理.json结尾的文件
            json_path = os.path.join(json_folder, filename)
            try:
                with open(json_path, 'r') as f:
                    data = json.load(f)  # 读取JSON文件内容
            except (FileNotFoundError, json.JSONDecodeError) as e:
                print(f"文件读取错误或格式无效:{filename},错误信息:{e}")  # 如果读取出错,输出错误信息并跳过
                continue
            # 获取图像的宽度和高度
            image_width = data.get("imageWidth")
            image_height = data.get("imageHeight")
            if not image_width or not image_height:  # 如果图像尺寸信息缺失,输出警告并跳过该文件
                print(f"图像尺寸信息缺失:{filename}")
                continue
            yolo_lines = []  # 初始化YOLO格式的标注行
            for shape in data.get("shapes", []):  # 遍历JSON中的每个标注对象
                label = shape.get("label")  # 获取标注的类别名称
                if label in category_mapping:  # 检查类别名称是否在类别映射字典中
                    class_index = category_mapping[label]  # 获取对应的类别索引
                else:
                    print(f"未识别的类别标签:{label},跳过该标注")  # 如果类别标签未定义,输出警告并跳过该标注
                    continue
                points = shape.get("points")  # 获取标注的四个顶点坐标
                if len(points) == 4:  # 确保标注包含四个顶点,符合OBB要求
                    ordered_points = order_points(points)  # 使用order_points函数对顶点进行顺序排列
                    # 将顶点坐标归一化到0-1之间,并保留六位有效数字
                    normalized_points = [[round(x / image_width, 6), round(y / image_height, 6)] for x, y in
                                         ordered_points]
                    # 构造YOLO格式的标注行,包含类别索引和四个归一化顶点坐标
                    yolo_line = [class_index] + [coord for point in normalized_points for coord in point]
                    yolo_lines.append(" ".join(map(str, yolo_line)))  # 将标注行添加到YOLO行列表中
            if yolo_lines:  # 如果存在标注数据,则写入到对应的TXT文件中
                txt_filename = os.path.splitext(filename)[0] + ".txt"  # 生成输出TXT文件名
                output_path = os.path.join(output_folder, txt_filename)
                with open(output_path, 'w') as out_file:
                    out_file.write("\n".join(yolo_lines))  # 将所有标注行写入TXT文件
                print(f"转换完成: {output_path}")  # 输出转换完成信息

# 使用示例
json_folder = "/home/build/yhgt/json"  # JSON文件夹路径,需要修改
output_folder = "/home/build/yhgt/txt"  # 输出TXT文件夹路径,需要修改
convert_json_to_yolo11(json_folder, output_folder, category_mapping)

二、模型训练

2.1 划分训练集验证集:


2.2 配置yaml文件:

python 复制代码
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
# Example usage: python train.py --data VOC.yaml
# parent
# ├── yolov5
# └── datasets
#     └── VOC  ← downloads here


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
#path: ../VOCdevkit_wpeson_Tanker_01-20/VOC2007/ImageSets/Main
train: /home/build/yhgt/images/train/ # train images (relative to 'path')  16551 images

val: /home/build/yhgt/images/val/
 
 # train images (relative to 'path')  16551 images
#val: # val images (relative to 'path')  4952 images
#  - val.txt
#test: # test images (optional)
#  - test.txt

# Classes
nc: 1 # number of classes
names: ['box_top']  # class names


# Download script/URL (optional) ---------------------------------------------------------------------------------------

2.3 训练

python 复制代码
from ultralytics import YOLO

# Load a model
model = YOLO("/home/build/下载/ultralytics-main (1)/yolo11n-obb.pt")  # load a pretrained model (recommended for training)

# Train the model with 2 GPUs
results = model.train(data="/home/build/yhgt/cc.yaml", epochs=200, imgsz=640, device="0,1")
相关推荐
我星期八休息13 分钟前
深入理解跳表(Skip List):原理、实现与应用
开发语言·数据结构·人工智能·python·算法·list
蒋星熠22 分钟前
如何在Anaconda中配置你的CUDA & Pytorch & cuNN环境(2025最新教程)
开发语言·人工智能·pytorch·python·深度学习·机器学习·ai
Hcoco_me36 分钟前
什么是机器学习?
人工智能·机器学习
Code_流苏36 分钟前
AI热点周报(9.7~9.13):阿里Qwen3-Next震撼发布、Claude 增强记忆与服务抖动、OpenAI 聚焦模型规范化...
人工智能·gpt·ai·openai·claude·qwen3-next·架构创新
合作小小程序员小小店37 分钟前
机器学习介绍
人工智能·python·机器学习·scikit-learn·安全威胁分析
这张生成的图像能检测吗40 分钟前
(综述)视觉任务的视觉语言模型
人工智能·计算机视觉·语言模型·自然语言处理·视觉语言模型
聚客AI1 小时前
🚫万能Agent兜底:当规划缺失工具时,AI如何自救
人工智能·llm·agent
Juchecar1 小时前
一文讲清 nn.Module 中 forward 函数被调用时机
人工智能
七牛云行业应用1 小时前
深度解析强化学习(RL):原理、算法与金融应用
人工智能·算法·金融
说私域2 小时前
“开源AI智能名片链动2+1模式S2B2C商城小程序”在直播公屏引流中的应用与效果
人工智能·小程序·开源