【视觉算法系列2】在自定义数据集上训练 YOLO NAS(上篇)

提示:免费获取本文涉及的完整代码与数据集,请添加微信peaeci122

YOLO-NAS是目前最新的YOLO目标检测模型,它在准确性方面击败了所有其他 YOLO 模型。与之前的 YOLO 模型相比,预训练的 YOLO-NAS 模型能够以更高的准确度检测更多目标。

如何在自定义数据集上训练 YOLO NAS?这将是我这两篇文章的目标------在自定义数据集上训练不同的 YOLO NAS 模型。

YOLO-NAS 的主要优势是,它能比之前的模型更好地检测到较小的物体,虽然可以运行多个推理实验来分析结果,但在具有挑战性的数据集上进行训练会有更好的了解。

这里将使用三个现有的预训练 YOLO-NAS 模型进行四次训练实验,并且选择了一个无人机热成像检测数据集。

目录:

1、用于训练 YOLO NAS 的物体检测数据集

2、在自定义数据集上训练 YOLO NAS

3、微调 YOLO NAS 模型

4、使用经过训练的 YOLO NAS 模型对测试图像进行推理

5、YOLO NAS 训练模型视频推理结果

6、结论

用于训练 YOLO NAS 的物体探测数据集

先来熟悉一下无人机高空红外热数据集。

它包含夜间无人机热图像,由于是无人机的高空记录,大多数物体看起来都很小。这使得大多数物体检测模型都很难解决该数据集的问题。不过,它是一个完美的自定义数据集,可以用来训练 YOLO-NAS,检查 YOLO-NAS对小物体的准确性。

该数据集包含 5 个物体类别的 2898 幅热图像:

1、人

2、汽车

3、自行车

4、其他车辆

5、不关心

数据集已包含训练、验证和测试三部分,有 2008 个训练样本、287 个验证样本和 571 个测试样本,数据集已采用 YOLO 注释格式。

以下是数据集中一些未经标注的地面实况图像。

很显然,除了汽车之外,如果没有适当的标注,人眼是无法看到地面上的其他物体的。

为了了解每个物体的位置,可以看一些有标注的图片。

上图有助于更好地了解数据集中的物体,可以看到,在这个数据集中,我们人眼很难检测到人和自行车等物体,但经过训练后的YOLO-NAS 模型却可以检测到。

在自定义数据集上训练 YOLO NAS

接下来,开始深入研究本文的编码部分,如需下载完整本文代码,请添加老师微信 peaeci122

下面的代码将训练三个 YOLO NAS 模型:

YOLO NAS s(小型)

YOLO NAS m(中型)

YOLO NAS l(大型)

开始之前,需要先安装super-gradients,整个训练和推理过程中都需要它。

在自定义数据集上训练 YOLO NAS

微调 YOLO NAS 模型

第一步是导入所有必要的包和模块。

from super_gradients.training import Trainer
from super_gradients.training import dataloaders
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, 
    coco_detection_yolo_format_val
)
from super_gradients.training import models
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import (
    DetectionMetrics_050,
    DetectionMetrics_050_095
)
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from tqdm.auto import tqdm

import os
import requests
import zipfile
import cv2
import matplotlib.pyplot as plt
import glob
import numpy as np
import random

检查上述代码块中所有的主要导入

1、Trainer:培训师启动培训过程并设置实验目录;

2、dataloaders:Super Gradients 提供自己的数据加载器,轻松训练 YOLO NAS 模型;

3、coco_detection_yolo_format_train, coco_detection_yolo_format_val: 帮助定义训练和验证数据集;

4、models:初始化YOLO NAS模型;

5、PPYoloELoss:YOLO NAS 在训练时使用 PPYoloELoss;

6、DetectionMetrics_050, DetectionMetrics_050_095:在训练时监控 mAP在50% IoU和主要度量。

数据集下载和目录结构

接下来的几个代码块将下载数据集并将其提取到当前目录,这里我们先跳过这些代码块。所有笔记本和数据集都在父数据集目录中,结构如下:

hit-uav
├── dataset.yaml
├── images
│   ├── test
│   ├── train
│   └── val
└── labels
    ├── test
    ├── train
    └── val

训练 YOLO NAS 的数据集参数

准备数据集,进行多次训练实验,然后定义所有数据集路径和类,以及数据集参数字典。

ROOT_DIR = 'hit-uav'
train_imgs_dir = 'images/train'
train_labels_dir = 'labels/train'
val_imgs_dir = 'images/val'
val_labels_dir = 'labels/val'
test_imgs_dir = 'images/test'
test_labels_dir = 'labels/test'
classes = ['Person', 'Car', 'Bicycle', 'OtherVechicle', 'DontCare']

dataset_params = {
    'data_dir':ROOT_DIR,
    'train_images_dir':train_imgs_dir,
    'train_labels_dir':train_labels_dir,
    'val_images_dir':val_imgs_dir,
    'val_labels_dir':val_labels_dir,
    'test_images_dir':test_imgs_dir,
    'test_labels_dir':test_labels_dir,
    'classes':classes 
}

使用 dataset_params,可以轻松创建所需格式的数据集。此外,还需要定义训练周期数、批量大小和数据处理的工作人员数量。

# Global parameters.
EPOCHS = 50
BATCH_SIZE = 16
WORKERS = 8

上述超参数和数据集参数将全部用于三个训练实验。

训练 YOLO NAS 的数据集准备

所有参数都准备就绪后,就可以定义训练数据集和验证数据集了。

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size':BATCH_SIZE,
        'num_workers':WORKERS
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size':BATCH_SIZE,
        'num_workers':WORKERS
    }
)

在上面的代码块中,分别创建train_data和val_data用于训练和验证。

使用dataset_params字典来获取数据集路径和数据集类值,还为每个数据集定义了 dataloader_params,用于接受数据加载过程的批量大小和工人数量。

现阶段先排除测试分割,将其保留到以后使用,以便使用训练有素的 YOLO NAS 模型进行推理。

定义YOLO NAS训练的变换和增强

数据增强是训练任何深度学习模型的重要组成部分,可以让我们在不过度拟合的情况下训练更长时间的稳健模型。在训练过程中,模型还会学习到不同的特征,而这些特征可能并不属于原始数据集。

Super Gradients的数据转换管道为应用和控制增强提供了一种有效的方法。

下面将打印应用于数据集的所有默认转换。

  train_data.dataset.transforms

会看到以下输出

[DetectionMosaic('additional_samples_count': 3, 'non_empty_targets': False, 'prob': 1.0, 'input_dim': [640, 640], 'enable_mosaic': True, 'border_value': 114),

 DetectionRandomAffine('additional_samples_count': 0, 'non_empty_targets': False, 'degrees': 10.0, 
'translate': 0.1, 'scale': [0.1, 2], 'shear': 2.0, 'target_size': [640, 640], 'enable': True, 'filter_box_candidates': True, 'wh_thr': 2, 'ar_thr': 20, 'area_thr': 0.1, 'border_value': 114),

 DetectionMixup('additional_samples_count': 1, 'non_empty_targets': True, 'input_dim': [640, 640], 'mixup_scale': [0.5, 1.5], 'prob': 1.0, 'enable_mixup': True, 'flip_prob': 0.5, 'border_value': 114),
.
.
.
XYCoordinateFormat object at 0x7efcc5c07340>), ('labels', name=labels length=1)]), 'output_format': OrderedDict([('labels', name=labels length=1), ('bboxes', name=bboxes length=4 format=
<super_gradients.training.datasets.data_formats.bbox_formats.cxcywh.CXCYWHCoordinateFormat object at 0x7efcc5976170>)]), 'max_targets': 120, 'min_bbox_edge_size': 1, 'input_dim': [640, 640], 'targets_format_converter': 
<super_gradients.training.datasets.data_formats.format_converter.ConcatenatedTensorFormatConverter object at 0x7efcb956b5b0>)]

在众多增强技术中,MixUp 是其中之一,但是,如果处理不当,MixUp 会导致深度学习模型极难感知数据集。目前通过初步实验,发现数据集的 MixUp 增强会使数据集变得非常混乱,因为它会将一个图像重叠在另一个图像上,移除 Mixup 增强功能并保持其他功能不变会提高性能。

可以通过弹出索引 2 中的元素,从列表中删除 MixUp 增强。

############ An example on how to modify augmentations ###########
train_data.dataset.transforms.pop(2)

可以使用以下代码将转换后的图像可视化

train_data.dataset.plot(plot_transformed_data=True)

除了 MixUp,还对图像进行了马赛克、旋转、放大和缩小等增强处理。

未完待续

相关推荐
良月澪二12 分钟前
CSP-S 2021 T1廊桥分配
算法·图论
Elastic 中国社区官方博客14 分钟前
使用 Vertex AI Gemini 模型和 Elasticsearch Playground 快速创建 RAG 应用程序
大数据·人工智能·elasticsearch·搜索引擎·全文检索
说私域39 分钟前
地理定位营销与开源AI智能名片O2O商城小程序的融合与发展
人工智能·小程序
Q_w77421 小时前
计算机视觉小目标检测模型
人工智能·目标检测·计算机视觉
wangyue41 小时前
c# 线性回归和多项式拟合
算法
&梧桐树夏1 小时前
【算法系列-链表】删除链表的倒数第N个结点
数据结构·算法·链表
创意锦囊1 小时前
ChatGPT推出Canvas功能
人工智能·chatgpt
QuantumStack1 小时前
【C++ 真题】B2037 奇偶数判断
数据结构·c++·算法
知来者逆1 小时前
V3D——从单一图像生成 3D 物体
人工智能·计算机视觉·3d·图像生成
今天好像不上班1 小时前
软件验证与确认实验二-单元测试
测试工具·算法