用自己的数据集训练YOLO-NAS目标检测器

YOLO-NAS 是 Deci 开发的一种新的最先进的目标检测模型。 在本指南中,我们将讨论什么是 YOLO-NAS 以及如何在自定义数据集上训练 YOLO-NAS 模型。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

为了训练我们的自定义模型,我们将:

  • 加载预训练的YOLO-NAS模型;
  • 从 Roboflow 加载自定义数据集,或者使用UnrealSynth制作合成数据集
  • 设置超参数值;
  • 使用超级梯度 Python 包根据我们的数据训练模型;
  • 评估模型以了解结果。

话不多说,让我们开始吧!

1、什么是 YOLO-NAS?

You Only Look Once  神经架构搜索(YOLO-NAS)是最新最先进的(SOTA)实时目标检测模型。 在 COCO 数据集上进行评估并与其前身 YOLOv6 和 YOLOv8  相比,YOLO-NAS 以更低的延迟实现了更高的 mAP 值。

YOLO-NAS 作为 Deci 维护的 super-gradient包的一部分提供。

下图展示了Deci在YOLO-NAS上的基准测试结果:

YOLO-NAS 与其他顶级实时检测器在 COCO 数据集上的性能对比图

YOLO-NAS 在 Roboflow 100 数据集基准测试中也是最好的,这表明它可以轻松地在自定义数据集上进行微调。

YOLO-NAS 和其他顶级实时检测器在 RF100 数据集上的性能对比图

2、Python环境设置

在开始训练之前,我们需要准备好Python环境。 让我们从安装三个 pip 包开始。 YOLO-NAS 模型本身是使用 super-gradient 包进行分发的。 请记住,该模型仍在积极开发中。 为了保持环境的稳定性,最好固定特定版本的包。 此外,我们将安装 roboflow 和监督,这将使我们能够从 Roboflow Universe 下载数据集并分别可视化我们的训练结果。

复制代码
pip install super-gradients==3.1.1
pip install roboflow
pip install supervision

如果你在 Jupyter Notebook 中运行 YOLO-NAS,请不要忘记在安装完成后重新启动环境。

3、使用预训练模型进行推理

在开始培训之前,最好确保安装按计划进行。 最简单的方法是使用预先训练的模型之一进行测试推理。 同时,这也能让我们熟悉YOLO-NAS API。

3.1 加载YOLO-NAS模型

为了使用预训练的 COCO 模型进行推理,我们首先需要选择模型的大小。 YOLO-NAS提供三种不同的模型大小:yolo_nas_s、yolo_nas_m和yolo_nas_l。

yolo_nas_s 模型是最小且最快的,但它可能不会像较大的模型那么准确。 相反,yolo_nas_l 模型最大、最准确、最慢。 yolo_nas_m 模型提供了两者之间的中间立场。

复制代码
import torch
from super_gradients.training import models

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ARCH = 'yolo_nas_l'
#            'yolo_nas_m'
#            'yolo_nas_s'

model = models.get(MODEL_ARCH, pretrained_weights="coco").to(DEVICE)

3.2 YOLO-NAS模型推理

推理过程包括设置置信度阈值和调用预测方法。 预测方法将返回预测列表,其中每个预测对应于图像中检测到的对象。

复制代码
CONFIDENCE_TRESHOLD = 0.35

result = list(model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]

YOLO-NAS推理结果图示

3.3 YOLO-NAS 推理输出格式

YOLO-NAS 推理的输出是一个 ImageDetectionPrediction 对象,它封装了图像中检测到的对象的详细信息。 该对象包含三个字段:

  • image - 表示用于推理的图像的 NumPy 数组。
  • class_names - 模型训练期间使用的类别名称的 Python 列表。
  • Prediction -DetectionPrediction 类的实例,其中包含有关模型检测的详细信息。

DetectionPrediction对象具有三个字段:

  • bboxes_xyxy - 形状 (N, 4) 的 NumPy 数组,以 xyxy 格式表示检测到的对象的边界框。
  • confidence - 形状 (N,) 的 NumPy 数组,表示检测的置信度值。 每个值都在 0 和 1 之间。
  • labels - 形状 (N,) 的 NumPy 数组,表示检测到的对象的类 ID。 每个类 ID 对应于 class_names 列表中的一个索引。

4、使用开源数据集微调 YOLO-NAS

为了微调模型,我们需要数据。 我们将使用足球运动员检测图像数据集。

如果你已经有 YOLO 格式的数据集,请随意使用它。 如果没有,请看看 Roboflow Universe,那里拥有超过 200,000 个开源项目,并且所有项目都可以以任何格式导出。

另外一种获取数据集的方法是使用UnrealSynth,一个基于虚幻引擎开发的YOLO合成数据生成器,可以自动生成包括标注的训练数据集,非常方便:

https://tools.nsdt.cloud/UnrealSynth

复制代码
import roboflow
from roboflow import Roboflow

roboflow.login()

rf = Roboflow()
project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
dataset = project.version(PROJECT_VERSION).download("yolov5")

要训练 YOLO-NAS 模型,你需要设置几个关键参数。

首先,你需要选择模型尺寸。 有三个选项可供选择:小型、中型和大型。 请记住,较大的模型可能需要更长的时间来训练并需要更多的内存,因此如果使用的资源有限,你可能需要考虑使用较小的模型。

接下来,你需要设置批量大小。 该参数指示在训练过程的每次迭代期间将有多少图像通过神经网络。 较大的批量大小将加快训练过程,但也需要更多的内存。

复制代码
MODEL_ARCH = 'yolo_nas_l'
BATCH_SIZE = 8
MAX_EPOCHS = 25
CHECKPOINT_DIR = f'{HOME}/checkpoints'
EXPERIMENT_NAME = project.name.lower().replace(" ", "_")
LOCATION = dataset.location
CLASSES = sorted(project.classes.keys())

dataset_params = {
    'data_dir': LOCATION,
    'train_images_dir':'train/images',
    'train_labels_dir':'train/labels',
    'val_images_dir':'valid/images',
    'val_labels_dir':'valid/labels',
    'test_images_dir':'test/images',
    'test_labels_dir':'test/labels',
    'classes': CLASSES
}

from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val)

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

最后,你需要设置训练过程的纪元数。 这本质上是整个数据集通过神经网络的次数。

5、训练自定义 YOLO-NAS 模型

你可能已经注意到,训练模型的过程比 YOLOv8 更加冗长。 Ultralytics 模型中的许多功能需要在 CLI 中传递参数,而对于 YOLO-NAS,则需要编写自定义逻辑。

最后,我们准备开始训练。 在调用 train 方法之前,值得运行 TensorBoard。 这将使我们能够实时跟踪培训的关键指标。 值得一提的是,YOLO-NAS还支持W&B等最流行的实验记录仪。

YOLO-NAS 训练期间获得的指标图

复制代码
trainer.train(
    model=model, 
    training_params=train_params, 
    train_loader=train_data, 
    valid_loader=val_data
)

6、评估自定义 YOLO-NAS 模型

训练结束后,你可以使用Trainer提供的测试方法评估模型的性能。 你需要传入测试集数据加载器,训练器将返回一个指标列表,包括通常用于评估对象检测模型的平均精度(mAP)。

复制代码
trainer.test(
    model=best_model,
    test_loader=test_data,
    test_metrics_list=DetectionMetrics_050(
        score_thres=0.1, 
        top_k_predictions=300, 
        num_cls=len(dataset_params['classes']), 
        normalize_targets=True, 
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01, 
            nms_top_k=1000, 
            max_predictions=300,                                                                              
            nms_threshold=0.7
        )
    )
)

模型评估期间获得的预测与手动标注的比较

此外,你可以对测试集图像进行推理并可视化结果,以更好地了解模型在各个示例上的表现。 你还可以计算混淆矩阵,以更详细地了解每个类别的模型性能:

模型评估过程中创建的混淆矩阵

7、结束语

一夜之间,YOLO-NAS 成为实时物体检测器的新选择。 在为你的项目微调模型时,请记住要考虑所有方面------从模型准确性到推理速度,再到易于训练和许可限制。


原文链接:训练自己的YOLO-NAS --- BimAnt

相关推荐
kisshuan123962 小时前
Ghost卷积瓶颈轻量化改进YOLOv26双阶段压缩与残差学习协同突破
学习·yolo
WJSKad12353 小时前
DPCF双路径交叉融合改进YOLOv26多尺度特征融合精度与效率
yolo
房开民4 小时前
使用cuda核函数加速 yolov5后处理
yolo·macos·cocoa
WJSKad12354 小时前
Ghost瓶颈轻量化改进YOLOv26双路径特征生成与残差学习协同突破
学习·yolo
AAD555888994 小时前
AAttn区域注意力机制改进YOLOv26特征感知与表达能力提升
人工智能·yolo·目标跟踪
kisshuan123964 小时前
Focus空间通道转换改进YOLOv26特征浓缩与深度可分离卷积双重突破
yolo·机器学习·目标跟踪
kisshuan123966 小时前
ERM增强残差融合模块改进YOLOv26多尺度特征融合精度与边缘检测能力
人工智能·深度学习·yolo
童话名剑13 小时前
YOLO v3(学习笔记)
人工智能·深度学习·yolo·目标检测
南极星100516 小时前
视觉项目(k230+opencv+yolo)--云台实时追踪项目
人工智能·opencv·yolo
子夜江寒1 天前
YOLO目标检测算法简介
算法·yolo·目标检测