dataloaders.py
utils\dataloaders.py
目录
[2.def get_hash(paths):](#2.def get_hash(paths):)
[3.def exif_size(img):](#3.def exif_size(img):)
[4.def exif_transpose(image):](#4.def exif_transpose(image):)
[5.def seed_worker(worker_id):](#5.def seed_worker(worker_id):)
[6.def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, close_mosaic=False, quad=False, min_items=0, prefix='', shuffle=False):](#6.def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, close_mosaic=False, quad=False, min_items=0, prefix='', shuffle=False):)
[7.class InfiniteDataLoader(dataloader.DataLoader):](#7.class InfiniteDataLoader(dataloader.DataLoader):)
[8.class _RepeatSampler:](#8.class _RepeatSampler:)
[9.class LoadScreenshots:](#9.class LoadScreenshots:)
[10.class LoadImages:](#10.class LoadImages:)
[11.class LoadStreams:](#11.class LoadStreams:)
[12.def img2label_paths(img_paths):](#12.def img2label_paths(img_paths):)
[13.class LoadImagesAndLabels(Dataset):](#13.class LoadImagesAndLabels(Dataset):)
[14.def flatten_recursive(path=DATASETS_DIR / 'coco128'):](#14.def flatten_recursive(path=DATASETS_DIR / 'coco128'):)
[15.def extract_boxes(path=DATASETS_DIR / 'coco128'):](#15.def extract_boxes(path=DATASETS_DIR / 'coco128'):)
[16.def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):](#16.def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):)
[17.def verify_image_label(args):](#17.def verify_image_label(args):)
[18.class HUBDatasetStats():](#18.class HUBDatasetStats():)
[19.class ClassificationDataset(torchvision.datasets.ImageFolder):](#19.class ClassificationDataset(torchvision.datasets.ImageFolder):)
[20.def create_classification_dataloader(path,imgsz=224, batch_size=16, augment=True, cache=False, rank=-1, workers=8, shuffle=True):](#20.def create_classification_dataloader(path,imgsz=224, batch_size=16, augment=True, cache=False, rank=-1, workers=8, shuffle=True):)
1.所需的库和模块
python
import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first
# 这段代码定义了一些参数,这些参数通常用于配置和运行深度学习模型,特别是在处理图像和视频数据时。
# Parameters
# 定义了一个帮助链接,指向 YOLOv5 官方 GitHub 仓库的 Wiki 页面,提供了关于如何使用 YOLOv5 训练自定义数据集的指南。
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
# 定义了一个元组,包含了支持的图像文件格式。这些格式是在处理图像数据时可能会遇到的,例如 BMP、JPEG、PNG 等。
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
# 定义了一个元组,包含了支持的视频文件格式。这些格式是在处理视频数据时可能会遇到的,例如 MP4、AVI、MOV 等。
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
# 使用 os.getenv 函数获取环境变量 LOCAL_RANK 的值,并将其转换为整数。如果环境变量未设置,则默认为 -1 。
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
# 类似于 LOCAL_RANK ,获取环境变量 RANK 的值,并将其转换为整数。如果环境变量未设置,则默认为 -1 。 RANK 也用于分布式训练中,表示当前进程在整个训练过程中的全局排名或编号。
RANK = int(os.getenv('RANK', -1))
# 获取环境变量 PIN_MEMORY 的值,如果未设置则默认为 True 。将值转换为字符串,并转换为小写。
# 检查 PIN_MEMORY 的值是否等于 'true' ,如果等于,则 PIN_MEMORY 为 True ,否则为 False 。
# PIN_MEMORY 参数通常用于 PyTorch 的 DataLoader 中,用于指示是否将数据加载到锁定的 CUDA 内存中,这可以提高数据传输到 GPU 的效率。
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
# 这些参数为深度学习模型的训练和数据加载提供了灵活的配置选项,特别是在分布式训练和多格式数据处理方面。
# 这段代码是在处理图像的EXIF(Exchangeable Image File Format)数据时,用于检索图像的方向(Orientation)标签。EXIF数据包含了图像的元数据,如拍摄时间、相机设置、图像方向等。
# Get orientation exif tag
# 开始一个循环,遍历 ExifTags.TAGS 字典的所有键。 ExifTags.TAGS 是一个模块级别的字典,它将EXIF标签的数值代码映射到它们对应的字符串名称。
for orientation in ExifTags.TAGS.keys():
# 在循环内部,检查当前键 orientation 对应的值是否等于 'Orientation' 。这是通过访问 ExifTags.TAGS 字典,并比较其值来完成的。
if ExifTags.TAGS[orientation] == 'Orientation':
# 如果找到匹配的键,即EXIF标签的名称为 'Orientation' 的键,则执行 break 语句退出循环。
break
# 这段代码的目的是找到表示图像方向的EXIF标签的键。由于EXIF标签是以数值代码存储的,而我们通常需要知道对应的字符串名称来理解每个标签的含义,因此需要这样的转换。一旦找到 'Orientation' 标签的键,就可以使用这个键来获取和处理图像的方向信息。
2.def get_hash(paths):
python
# 这段代码定义了一个名为 get_hash 的函数,它用于计算一个包含文件或目录路径列表的哈希值。
# 定义了一个名为 get_hash 的函数,它接受一个参数。
# 1.paths :这是一个包含文件或目录路径的列表。
def get_hash(paths):
# Returns a single hash value of a list of paths (files or dirs) 返回路径列表(文件或目录)的单个哈希值。
# 使用列表推导式计算 paths 列表中每个路径的大小总和。 os.path.getsize(p) 获取路径 p 的文件大小。 if os.path.exists(p) 确保路径存在,避免对不存在的文件计算大小。 sum(...) 将所有存在的文件大小加起来。
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
# hashlib.md5([data])
# 在Python中, hashlib.md5() 函数是 hashlib 模块提供的一个方法,用于创建一个新的MD5(Message-Digest Algorithm 5)哈希对象。MD5是一种广泛使用的哈希函数,它可以产生一个128位(16字节)的哈希值。
# 参数 :
# data :(可选)一个字符串(在Python 3中通常是字节串)。如果提供,将立即使用该数据初始化哈希对象。
# 返回值 :
# 返回一个新的MD5哈希对象。
# 请注意,MD5哈希函数已经不再被认为是安全的,因为它容易受到冲突攻击和碰撞攻击。因此,在需要高安全性的场合(如密码存储)中,建议使用更安全的哈希函数,如SHA-256或SHA-3。然而,MD5仍然适用于一些非安全关键的应用,如数据完整性检查。
# 使用 hashlib.md5 创建一个新的 MD5 哈希对象。 str(size).encode() 将大小总和转换为字符串,并编码为字节,以便可以用于哈希计算。
h = hashlib.md5(str(size).encode()) # hash sizes
# 将 paths 列表中的所有路径连接成一个字符串,然后编码为字节。 使用 h.update() 方法更新哈希对象,将路径信息加入到哈希计算中。
h.update(''.join(paths).encode()) # hash path
# h.hexdigest() 方法计算最终的哈希值,并以十六进制字符串的形式返回。
return h.hexdigest() # return hash
# 这个函数的目的是生成一个基于文件大小和路径的单一哈希值,可以用来快速检查文件列表是否发生变化。例如,如果文件的内容或文件列表发生变化,那么生成的哈希值也会不同。这种方法常用于缓存验证,确保在文件内容更新后重新执行某些操作。
3.def exif_size(img):
python
# 这段代码定义了一个名为 exif_size 的函数,它用于返回考虑了EXIF数据(特别是旋转信息)的PIL图像的实际尺寸。
# 定义了一个名为 exif_size 的函数,它接受一个参数。
# 1.img :这是一个PIL(Python Imaging Library)图像对象。
def exif_size(img):
# Returns exif-corrected PIL size 返回经过 exif 校正的 PIL 大小。
# 获取图像的原始尺寸, img.size 返回一个元组 (width, height) 。
s = img.size # (width, height)
# 使用 contextlib.suppress(Exception) 作为一个上下文管理器,用于捕获和忽略在获取EXIF数据时可能发生的任何异常。这样做是因为某些图像可能不包含EXIF信息,或者在读取时可能会遇到错误。
with contextlib.suppress(Exception):
# 尝试从图像的EXIF数据中获取旋转信息。 img._getexif() 返回图像的EXIF数据, .items() 将其转换为键值对列表, dict() 将其转换为字典。 orientation 是之前提到的EXIF标签,用于获取图像的旋转方向。
rotation = dict(img._getexif().items())[orientation]
# 检查旋转值是否为270度或90度。EXIF标准中,6代表90度顺时针旋转,8代表270度顺时针旋转(或90度逆时针旋转)。
if rotation in [6, 8]: # rotation 270 or 90
# 如果图像需要旋转,那么交换宽度和高度的值,因为旋转后的图像尺寸会发生变化。
s = (s[1], s[0])
# 返回调整后的图像尺寸。
return s
# 这个函数的目的是确保在处理带有EXIF旋转信息的图像时,能够返回正确的尺寸。这对于图像显示和处理非常重要,因为不正确的尺寸可能会导致图像显示不正确或者在某些应用中出现问题。通过考虑EXIF数据中的旋转信息, exif_size 函数可以提供更准确的图像尺寸。
4.def exif_transpose(image):
python
# 这段代码是一个名为 exif_transpose 的函数,它的作用是处理图像的EXIF信息,根据图像的方向标签(Orientation)来调整图像的方向,以确保图像是正确的方向显示。EXIF信息是图像文件中存储的元数据,其中包含了图像的各种属性,包括拍摄信息、图像方向等。
# 定义了一个函数 exif_transpose ,它接受一个参数。
# 1.image :这个参数应该是一个图像对象。
def exif_transpose(image):
# 如果 PIL 图像具有 EXIF Orientation 标签,则相应地对其进行转置。
"""
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
:param image: The image to transpose.
:return: An image.
"""
# 获取图像的EXIF信息,并将其存储在变量 exif 中。
exif = image.getexif()
# 从EXIF信息中获取 图像的方向 (Orientation),如果没有找到,则默认为1。
orientation = exif.get(0x0112, 1) # default 1
# 如果方向值大于1,说明图像需要被旋转或翻转。
if orientation > 1:
# 定义了一个字典 method ,它将方向值映射到PIL库(Python Imaging Library,现在称为Pillow)的图像变换方法。根据 orientation 的值,选择相应的变换方法。
# 这段代码是Python字典的一个应用,它映射了EXIF方向标签(Orientation)到Pillow库(PIL)中相应的图像变换方法。每个键值对代表一个方向标签和对应的图像处理操作。
method = {
# 如果方向标签为2,表示图像需要水平翻转(从左到右翻转)。
2: Image.FLIP_LEFT_RIGHT,
# 如果方向标签为3,表示图像需要旋转180度。
3: Image.ROTATE_180,
# 如果方向标签为4,表示图像需要垂直翻转(从上到下翻转)。
4: Image.FLIP_TOP_BOTTOM,
# 如果方向标签为5,表示图像需要进行转置(即旋转90度并交换宽度和高度)。
5: Image.TRANSPOSE,
# 如果方向标签为6,表示图像需要顺时针旋转270度。
6: Image.ROTATE_270,
# 如果方向标签为7,表示图像需要进行对角线翻转(即旋转270度并交换宽度和高度)。
7: Image.TRANSVERSE,
# 如果方向标签为8,表示图像需要顺时针旋转90度。
# .get(orientation) 方法用于从字典中获取与 orientation 值对应的图像变换方法。如果 orientation 在字典中有对应的值,那么 .get(orientation) 将返回该值;如果没有找到对应的值,它将返回 None 。
8: Image.ROTATE_90}.get(orientation)
# 如果找到了对应的变换方法,执行以下操作。
if method is not None:
# Image.transpose(method) -> Image
# Pillow库中的 transpose() 函数用于翻转或旋转图像。
# 参数 :
# method :该参数指定了翻转或旋转操作的类型,必须是以下预定义的常量之一 :
# Image.FLIP_LEFT_RIGHT :水平翻转图像(即,镜像效果)。
# Image.FLIP_TOP_BOTTOM :垂直翻转图像。
# Image.ROTATE_90 :将图像逆时针旋转90度。
# Image.ROTATE_180 :将图像旋转180度。
# Image.ROTATE_270 :将图像逆时针旋转270度(或顺时针旋转90度)。
# Image.TRANSPOSE :交换图像的X和Y轴(即,转置图像)。
# Image.TRANSVERSE :交换图像的X和Y轴,并旋转180度。
# 返回值 :
# 函数返回一个翻转或旋转后的图像副本。
# 使用Pillow库的 transpose 方法对图像进行变换。
image = image.transpose(method)
# 删除EXIF信息中的方向标签,因为图像已经被调整为正确的方向。
del exif[0x0112]
# 将更新后的EXIF信息重新写入图像的 info 属性中。
image.info["exif"] = exif.tobytes()
# 返回调整方向后的图像对象。
return image
# 这个函数的目的是确保图像在显示时不会因为EXIF中的旋转信息而显示错误的方向。这段代码没有处理所有可能的方向值,只处理了常见的几种情况。如果EXIF信息中的方向值不在处理范围内,图像将不会被变换。
5.def seed_worker(worker_id):
python
# 这段代码定义了一个名为 seed_worker 的函数,它用于为 PyTorch 的每个数据加载工作进程设置一个随机种子。
# 定义了一个名为 seed_worker 的函数,它接受一个参数。
# 1.worker_id :这是当前工作进程的编号。
def seed_worker(worker_id):
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader 设置数据加载器工作者种子 https://pytorch.org/docs/stable/notes/randomness.html#dataloader 。
# 使用 PyTorch 的 torch.initial_seed() 函数获取当前进程的随机种子。 计算 2 ** 32 ,即 4294967296 ,这是32位无符号整数的最大值。 通过取模运算 % ,确保 worker_seed 是一个32位整数。
worker_seed = torch.initial_seed() % 2 ** 32
# 使用 NumPy 的 np.random.seed() 函数设置 NumPy 随机数生成器的种子。这确保了 NumPy 生成的随机数在每个工作进程中是可重复的。
np.random.seed(worker_seed)
# 使用 Python 标准库 random 的 random.seed() 函数设置 Python 随机数生成器的种子。这确保了 Python 生成的随机数在每个工作进程中是可重复的。
random.seed(worker_seed)
# 这个函数通常用作 PyTorch DataLoader 的 worker_init_fn 参数,以确保在数据加载过程中,每个工作进程生成的随机数是独立的,从而避免数据加载时的随机性导致的潜在问题。这对于确保模型训练的可重复性非常重要,特别是在使用数据增强或其他需要随机性的操作时。通过为每个工作进程设置不同的随机种子,可以保证每个进程生成的随机数序列是不同的,从而避免了数据加载时的随机性对模型训练的影响。
6.def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, close_mosaic=False, quad=False, min_items=0, prefix='', shuffle=False):
python
# 这段代码定义了一个名为 create_dataloader 的函数,它用于创建一个 PyTorch 数据加载器(DataLoader),这个加载器可以用于深度学习模型的训练或验证。
# 参数解释。
# 1.path : 数据集的路径。
# 2.imgsz : 图像的输入尺寸。
# 3.batch_size : 每个批次的样本数量。
# 4.stride : 模型的步长。
# 5.single_cls : 是否为单类别数据集。
# 6.hyp : 超参数。
# 7.augment : 是否应用数据增强。
# 8.cache : 是否缓存图像。
# 9.pad : 填充比例。
# 10.rect : 是否使用矩形批次。
# 11.rank : 分布式训练的排名。
# 12.workers : 数据加载工作进程的数量。
# 13.image_weights : 是否使用图像权重。
# 14.close_mosaic : 是否关闭马赛克数据增强。
# 15.quad : 是否使用四倍数据增强。
# 16.min_items : 数据集中的最小项目数。
# 17.prefix : 日志前缀。
# 18.shuffle : 是否在每个epoch开始时打乱数据。
def create_dataloader(path,
imgsz,
batch_size,
stride,
single_cls=False,
hyp=None,
augment=False,
cache=False,
pad=0.0,
rect=False,
rank=-1,
workers=8,
image_weights=False,
close_mosaic=False,
quad=False,
min_items=0,
prefix='',
shuffle=False):
# 这段代码是 create_dataloader 函数的一部分,它处理数据加载器的配置和初始化。
# 在深度学习训练中,特别是在目标检测任务中, rect ( rectangular batches ) 和 shuffle 是数据加载器的两个不同参数,它们各自有不同的作用 :
# rect (矩形批次) :
# rect 参数通常用于目标检测任务,它指示数据加载器在创建批次时应该保持图像的原始长宽比,即使这意味着某些图像会被填充或裁剪以适应批次中的其他图像。
# 在矩形批次中,所有图像都被调整到一个固定的尺寸(例如,由 imgsz 参数指定),这有助于模型学习不同尺寸的图像。
# rect 模式通常与 True 一起使用,以确保批次中的所有图像都有相同的尺寸,这对于某些深度学习框架和模型架构是必要的。
# shuffle (数据打乱) :
# shuffle 参数指示数据加载器在每个epoch开始时是否应该随机打乱数据集中的图像顺序。
# 数据打乱有助于模型泛化,因为它确保了模型不会对数据集中的特定顺序产生依赖,并且可以帮助模型学习到更鲁棒的特征。
# 当 shuffle 设置为 True 时,每个epoch的数据顺序都会不同,这有助于模型训练的多样性。
# 为什么不兼容?
# rect 和 shuffle 不兼容的原因在于,当使用 rect 时,数据加载器需要保持图像的原始尺寸和比例,这意味着它不能随机打乱图像,因为这样做可能会导致批次中的图像尺寸不一致。
# 如果在 rect 模式下打乱数据,那么每个批次中的图像可能会有不同的尺寸,这将破坏 rect 模式的目的,即保持批次中所有图像的尺寸一致。
# 因此,当 rect 被设置为 True 时,为了保证批次中图像尺寸的一致性,必须关闭 shuffle ,或者至少不能在打乱数据时改变图像的尺寸。
# 总结来说, rect 和 shuffle 参数在目标检测任务中有不同的用途,它们在处理图像尺寸和顺序方面存在冲突,因此在实际应用中需要根据具体的训练需求和模型架构来选择使用。
# 一个条件语句,检查是否同时设置了 rect (矩形批次)和 shuffle (数据打乱)参数。
if rect and shuffle:
# 如果同时设置了 rect 和 shuffle ,使用 LOGGER 输出一条警告信息,指出这两个参数不兼容,并自动将 shuffle 设置为 False 。
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') # 警告⚠️ --rect 与 DataLoader shuffle 不兼容,请设置 shuffle=False 。
# 将 shuffle 参数强制设置为 False ,以确保数据加载时不会打乱。
shuffle = False
# 使用 torch_distributed_zero_first 作为上下文管理器,确保在分布式数据并行(Distributed Data Parallel, DDP)环境中,只有 rank 为 0 的进程首先初始化数据集的缓存文件(*.cache)。
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
# 创建 LoadImagesAndLabels 类的实例,初始化数据集。
dataset = LoadImagesAndLabels(
# 数据集的路径。
path,
# 图像的输入尺寸。
imgsz,
# 每个批次的样本数量。
batch_size,
# 是否应用数据增强。
augment=augment, # augmentation
# 超参数。
hyp=hyp, # hyperparameters
# 是否使用矩形批次。
rect=rect, # rectangular batches
# 是否缓存图像。
cache_images=cache,
# 是否为单类别数据集。
single_cls=single_cls,
# 模型的步长,转换为整数。
stride=int(stride),
# 填充比例。
pad=pad,
# 是否使用图像权重。
image_weights=image_weights,
# 数据集中的最小项目数。
min_items=min_items,
# 日志前缀。
prefix=prefix)
# 这段代码的目的是确保在分布式训练环境中,数据集的初始化只执行一次,并且处理了 rect 和 shuffle 参数的兼容性问题。通过这种方式,可以避免在分布式训练中不必要的数据重复加载和潜在的数据不一致问题。
# 这段代码是 create_dataloader 函数的后半部分,它负责设置和返回一个 PyTorch 数据加载器(DataLoader)。
# 确保 batch_size 不大于数据集的大小。如果数据集较小,则自动减小 batch_size 。
batch_size = min(batch_size, len(dataset))
# 获取当前系统中可用的 CUDA 设备数量。
nd = torch.cuda.device_count() # number of CUDA devices
# 计算用于数据加载的工作进程数( num_workers )。这个值是 CPU 核心数除以 CUDA 设备数(至少为1), batch_size (如果大于1),和用户提供的 workers 参数中的最小值。
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
# torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
# torch.utils.data.distributed.DistributedSampler 类的构造函数用于创建一个新的分布式采样器实例,它主要用于分布式训练环境中,以确保每个进程只处理数据集的一部分,从而实现数据的均匀分配。
# 参数 :
# dataset ( Dataset ) :要采样的数据集对象。
# num_replicas ( int ,可选) :分布式环境中的总副本(进程)数量。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 world_size 。
# rank ( int ,可选) :当前进程的排名或ID。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 rank 。
# shuffle ( bool ) :是否在每个epoch开始时打乱数据集的采样顺序。默认值为 True 。
# seed ( int ) :用于打乱数据集的随机种子。确保在所有进程中使用相同的种子以获得一致的打乱结果。默认值为 0 。
# drop_last ( bool ) :如果为 True ,则在数据集不能被均匀分配时,丢弃最后一部分数据以确保每个进程处理相同数量的数据。如果为 False ,则可能有些进程会处理更多的数据。默认值为 False 。
# 返回值 :
# 返回一个新的 DistributedSampler 实例。
# DistributedSampler 类在 PyTorch 中用于分布式训练,以下是它的一些常用属性和方法 :
# 属性 :
# dataset : 返回与采样器关联的数据集。
# num_replicas : 返回分布式环境中的总副本(进程)数量。
# rank : 返回当前进程的排名或ID。
# epoch : 返回当前的epoch数。这个属性在每个epoch开始时通过调用 set_epoch() 方法更新。
# 方法 :
# set_epoch(epoch) : 设置当前的epoch数。这对于确保在每个epoch中数据被打乱是必要的,特别是在 shuffle=True 时。
# __iter__() : 返回一个迭代器,该迭代器产生当前epoch中被采样器选中的数据集索引。
# __len__() : 返回当前epoch中被采样器选中的数据集索引的数量。
# update() : 更新采样器的状态,这个方法在 PyTorch 的某些版本中存在,用于重新配置采样器的参数。
# DistributedSampler 的主要作用是确保在分布式训练中,每个进程都能够处理数据集的不同部分,从而提高数据加载的效率和训练的可扩展性。通过在每个epoch开始时调用 set_epoch() 方法,可以确保数据在每个epoch中都被重新打乱,这对于模型的训练是非常重要的。
# 如果 rank 为 -1 (即不在分布式训练环境中),则不使用 sampler 。否则,创建一个 DistributedSampler 实例,它确保在分布式训练中每个进程只处理数据集的一部分。
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
#loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
# 根据 image_weights 或 close_mosaic 参数的值,选择使用标准的 DataLoader 或自定义的 InfiniteDataLoader 。
loader = DataLoader if image_weights or close_mosaic else InfiniteDataLoader
# torch.Generator(device='cpu')
# torch.Generator 是 PyTorch 中用于生成随机数的类。它管理生成伪随机数的算法状态,并在许多需要随机采样的函数中作为关键字参数使用。
# 参数 :
# device :生成器所在的设备,默认为 'cpu' 。也可以设置为 'cuda' ,表示生成器位于特定的 CUDA 设备上。
# 方法 :
# device :返回生成器的当前设备。
# get_state() :以 torch.ByteTensor 的形式返回生成器状态。
# initial_seed() :返回用于生成随机数的初始种子。
# manual_seed(seed) :设置生成随机数的种子,并返回 torch.Generator 对象。任何 32 位整数都是有效的种子。
# seed() :从 std::random_device 或当前时间获取非确定性随机数,并使用它来为生成器提供种子。
# set_state(new_state) :设置生成器的状态。
# torch.Generator 的主要作用是提供可预测和可复现的随机数生成。在分布式训练或需要确保随机性一致性的场景中,可以通过设置相同的种子来确保不同进程或设备上生成相同的随机数序列。这对于调试和复现实验结果非常重要。
# 创建一个 PyTorch 随机数生成器实例。
generator = torch.Generator()
# 为随机数生成器设置种子,确保在分布式训练中每个进程的随机性是独立的。
generator.manual_seed(6148914691236517205 + RANK)
# 返回 创建的数据加载器实例 和 数据集实例。
# 创建的数据加载器实例,包括以下参数 :
# 数据集对象。
return loader(dataset,
# 批次大小。
batch_size=batch_size,
# 是否打乱数据,只有在没有使用 sampler 时才有效。
shuffle=shuffle and sampler is None,
# 工作进程数。
num_workers=nw,
# 分布式训练中的采样器。
sampler=sampler,
# 是否将数据加载到锁定的 CUDA 内存中。
pin_memory=PIN_MEMORY,
# 定制的批处理函数,根据 quad 参数选择不同的函数。
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
# 工作进程初始化函数,用于设置随机种子。
# def seed_worker(worker_id): -> 用于为 PyTorch 的每个数据加载工作进程设置一个随机种子。
worker_init_fn=seed_worker,
# 随机数生成器。
generator=generator), dataset
# 这段代码的目的是配置和返回一个数据加载器,它可以根据是否在分布式训练环境中、是否需要图像权重、是否需要关闭马赛克数据增强等因素进行定制。通过这种方式,可以确保数据加载过程既高效又适合特定的训练需求。
# 这个函数的目的是提供一个灵活且可配置的方式来创建数据加载器,它考虑了多种数据增强和分布式训练的选项。通过这种方式,用户可以根据具体的训练需求来调整数据加载过程。
7.class InfiniteDataLoader(dataloader.DataLoader):
python
# 这段代码定义了一个名为 InfiniteDataLoader 的类,它继承自 PyTorch 的 DataLoader 类。 InfiniteDataLoader 用于创建一个可以无限重复产出数据的加载器,这在某些训练场景中非常有用,比如当数据集较小或者需要重复使用数据集多次时。
class InfiniteDataLoader(dataloader.DataLoader):
# 重用工作器的数据加载器。
# 使用与原始数据加载器相同的语法。
""" Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""
# 这段代码是 InfiniteDataLoader 类的构造函数 __init__ 的定义,它继承自 PyTorch 的 DataLoader 类。这个构造函数的目的是初始化一个可以无限重复数据的 DataLoader 。
# 这是 InfiniteDataLoader 类的构造函数,它接受任意数量的位置参数 *args 和关键字参数 **kwargs ,这样它就可以接受与 DataLoader 构造函数相同的参数。
def __init__(self, *args, **kwargs):
# 调用父类 DataLoader 的构造函数,使用传递给 InfiniteDataLoader 的相同参数来初始化 DataLoader 实例。
super().__init__(*args, **kwargs)
# object.__setattr__(name, value)
# 在Python中, object.__setattr__() 是一个特殊方法,用于设置对象的属性。它是 object 类的一个方法,而 object 是Python中所有类的基类。 __setattr__() 方法在设置对象属性时被自动调用,但也可以在子类中被重写以自定义属性赋值的行为。
# 参数 :
# name :要设置的属性的名称。
# value :属性的值。
# 行为 :
# 当对一个对象的属性进行赋值操作时,例如 obj.attr = value ,Python会自动调用该对象的 __setattr__() 方法。这个方法的默认实现会设置一个名为 name 的属性,其值为 value 。
# 为什么使用 object.__setattr__ :
# 在某些情况下,你可能需要直接调用 __setattr__() 方法,特别是当你需要绕过属性赋值的默认行为时。例如,你可能想要在设置属性之前执行一些额外的检查或操作。
# 注意事项 :
# 使用 object.__setattr__() 时,应该谨慎,因为它会绕过属性的正常赋值机制,包括可能的属性监视器或装饰器。
# 在大多数情况下,直接使用 obj.attr = value 就足够了,除非有特殊需求需要自定义属性赋值的行为。
# 使用 object.__setattr__ 方法替换 DataLoader 实例的 batch_sampler 属性。这是因为 DataLoader 的 batch_sampler 属性是只读的,不能直接设置。这里,它被替换为 _RepeatSampler 实例,该实例包装了原始的 batch_sampler ,使得批次可以无限重复。
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
# 获取 DataLoader 实例的迭代器,并将其保存在 self.iterator 中。这样, InfiniteDataLoader 就可以在需要时无限次地迭代产出数据。
self.iterator = super().__iter__()
# 通过这种方式, InfiniteDataLoader 实现了无限重复数据的功能,这对于某些训练场景非常有用,比如当数据集较小或者需要重复使用数据集多次时。这个自定义的 DataLoader 可以在训练循环中使用,而不必担心数据集的大小限制。
# 这段代码定义了 InfiniteDataLoader 类的 __len__ 方法,它用于返回与 InfiniteDataLoader 关联的采样器中的样本数量。
# 这是 InfiniteDataLoader 类的 __len__ 方法,它是一个特殊方法,允许实例使用 len() 函数。
def __len__(self):
# 返回 self.batch_sampler 中采样器的长度。
# self.batch_sampler 是在 InfiniteDataLoader 的构造函数中设置的 _RepeatSampler 实例。
# self.batch_sampler.sampler 指的是 _RepeatSampler 内部包装的原始采样器,即 DataLoader 中的 batch_sampler 。
# len(self.batch_sampler.sampler) 计算并返回这个原始采样器中的样本数量。
return len(self.batch_sampler.sampler)
# 这个方法的实现意味着 InfiniteDataLoader 的长度与内部采样器的长度相同,即使 InfiniteDataLoader 本身可以无限重复产出数据。这在需要知道单个周期内数据集的长度时非常有用,例如,当需要确定训练周期中的迭代次数时。
# 需要注意的是,尽管 InfiniteDataLoader 可以无限重复数据,但在实际训练中,通常会设置一个固定的周期数(epochs),并在每个周期开始时通过调用 sampler.set_epoch(epoch) 方法来更新采样器的状态,以确保数据在每个周期中被打乱。因此, __len__ 方法返回的长度实际上只在第一个周期内有效,之后的周期中数据会被重复使用。
# 这段代码定义了 InfiniteDataLoader 类的 __iter__ 方法,它使得 InfiniteDataLoader 实例可以被迭代,即使它被设计为无限重复产出数据。
# 这是 InfiniteDataLoader 类的 __iter__ 方法,它是一个特殊方法,允许实例使用 for 循环进行迭代。
def __iter__(self):
# range(len(self)) 创建一个无限的迭代器,因为 InfiniteDataLoader 的长度是无限的(或者至少是非常大的,取决于内部采样器的长度)。 这里的 _ 是一个惯用的占位符变量名,表示我们不关心循环的索引值。
for _ in range(len(self)):
# 在每次迭代中,使用 next(self.iterator) 从 self.iterator 中获取下一个批次的数据,并使用 yield 语句产出这个批次。
# self.iterator 是在 InfiniteDataLoader 的构造函数中初始化的,它是一个迭代器,用于迭代 DataLoader 的批次。
# 由于 self.iterator 被封装在一个无限循环中, InfiniteDataLoader 可以无限次地重复产出数据。
yield next(self.iterator)
# 这个方法的实现意味着 InfiniteDataLoader 可以在训练循环中使用,而不必担心数据集的大小限制。它特别适合于数据集较小或者需要重复使用数据集多次的场景。例如,半监督学习或者数据增强技术中可能会用到无限数据加载器。
# 需要注意的是,虽然 InfiniteDataLoader 可以无限重复产出数据,但在实际使用时,通常会在训练循环中设置一个外部的停止条件,比如固定的迭代次数或者达到一定的性能指标后停止训练。
# InfiniteDataLoader 类通常用于创建一个可以无限重复产出数据的加载器,这在某些情况下非常有用,比如在数据增强、半监督学习或者数据集较小时,可以通过无限重复使用数据集来模拟更大的数据集。
8.class _RepeatSampler:
python
# 这段代码定义了一个名为 _RepeatSampler 的类,它是一个包装器(wrapper),用于重复一个给定的采样器(sampler)无限次。
class _RepeatSampler:
# 永远重复的采样器。
""" Sampler that repeats forever
Args:
sampler (Sampler)
"""
# 这是 _RepeatSampler 类的构造函数,它接受一个参数。
# 1.sampler :这是一个采样器对象。
def __init__(self, sampler):
# 构造函数将传入的采样器保存在实例变量 self.sampler 中。
self.sampler = sampler
# 这是 _RepeatSampler 类的 __iter__ 方法,它使得 _RepeatSampler 实例可以被迭代。
def __iter__(self):
# 方法内部有一个无限循环 while True ,这意味着迭代将无限次重复。
while True:
# 在循环内部,使用 yield from 语句从 self.sampler 中迭代并产出( yield )所有元素。
# iter(self.sampler) 调用 self.sampler 的 __iter__ 方法,获取其迭代器,然后 yield from 产出该迭代器中的所有元素。
# 由于循环是无限的, _RepeatSampler 将无限次重复产出 self.sampler 中的元素。
yield from iter(self.sampler)
# _RepeatSampler 类通常用于创建一个可以无限重复产出元素的采样器,这在某些情况下非常有用,例如在数据加载时,当数据集的大小不足以满足训练需求时,可以通过 _RepeatSampler 来重复使用数据集,直到训练完成。
9.class LoadScreenshots:
python
# 这段代码定义了一个名为 LoadScreenshots 的类,它用于从屏幕截图中加载图像数据,通常用于实时数据流或视频流的处理。
class LoadScreenshots:
# YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"` YOLOv5截图数据加载器,即`python detector.py --source "屏幕 0 100 100 512 256"`。
# 这段代码是 LoadScreenshots 类的构造函数 __init__ 的定义,它用于初始化类实例并设置屏幕截图的相关参数。
# 这是 LoadScreenshots 类的构造函数,它接受以下参数。
# 1.self :类的实例本身。
# 2.source :屏幕截图的来源,可以是屏幕编号或包含屏幕位置和尺寸的参数列表。
# 3.img_size :输出图像的尺寸,默认为640像素。
# 4.stride :用于图像处理的步长,默认为32像素。
# 5.auto :是否自动调整图像尺寸,默认为True。
# 6.transforms :可选的图像转换函数。
def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
# source = [screen_number left top width height] (pixels)
# 调用 check_requirements 函数来检查 mss 库是否已安装。 mss 是一个用于屏幕截图和监控操作的Python库。
# def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''): -> 它用于检查是否安装了满足YOLO要求的依赖项。如果某些依赖项未安装或版本不兼容,函数会尝试自动安装它们。
check_requirements('mss')
# 导入 mss 库,以便使用其功能进行屏幕截图。
import mss
# 将 source 字符串拆分为多个参数, source 是屏幕编号, *params 是包含屏幕位置和尺寸的参数列表。
source, *params = source.split()
# 初始化 屏幕编号 和 屏幕位置尺寸参数 ,默认为全屏0 。
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
# 参数解析。
# 如果 params 长度为1,设置屏幕编号。
if len(params) == 1:
self.screen = int(params[0])
# 如果 params 长度为4,设置左上角坐标和宽度高度。
elif len(params) == 4:
left, top, width, height = (int(x) for x in params)
# 如果 params 长度为5,同时设置屏幕编号和左上角坐标及宽度高度。
elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params)
# 设置输出图像的尺寸。
self.img_size = img_size
# 设置图像处理的步长。
self.stride = stride
# 设置可选的图像转换函数。
self.transforms = transforms
# 设置是否自动调整图像尺寸。
self.auto = auto
# 设置模式为流模式。
self.mode = 'stream'
# 初始化帧编号为0。
self.frame = 0
# mss.mss()
# mss.mss() 是 Python mss 库中的一个函数,用于创建一个 MSS(Multiple Screen Shots)对象,该对象可以捕获屏幕截图。
# 函数定义 :
# with mss.mss() as sct: # 在此执行屏幕截图操作
# 参数 :无参数。
# 返回值 :
# 返回一个 mss 对象,该对象提供了屏幕截图的相关方法。
# 注意事项 :
# mss.mss() 函数通常与 with 语句一起使用,以确保资源的正确管理。
# sct.monitors 返回的监视器列表中,每个监视器都是一个字典,包含监视器的详细信息,如宽度、高度、左上角坐标等。
# sct.grab() 方法可以接收一个监视器字典或一个包含截图区域坐标的字典,用于指定截图的区域。
# mss.mss() 函数是 mss 库的核心,用于捕获屏幕截图,并提供了灵活的截图选项。
# MSS(Multiple Screen Shots)对象,即 mss.mss() 实例,提供了以下属性和方法 :
# 属性 :
# monitors :一个包含所有监视器信息的列表。每个监视器都是一个字典,包含该监视器的分辨率、位置等信息。
# 方法 :
# grab(monitor) :根据提供的监视器信息或区域截图。 monitor 参数可以是一个字典,包含截图区域的 top 、 left 、 width 、 height 等信息。
# shot(output=None, mon=-1, callback=None, title=None, include_layered=False, bbox=None) :保存第一个监视器的截图。如果提供 output 参数,截图将被保存到指定的文件路径。
# save(bbox, output) :这是一个迭代器,用于保存截图到指定路径。 bbox 参数定义了截图区域, output 参数指定了文件路径。
# 图像数据 :
# rgb : 包含截图的 RGB(去除透明度通道) 图像数据。
# bgra : 包含截图的 BGRA(包含透明度通道) 图像数据。
# 其他工具 :
# mss.tools.to_png(rgb, size, output) : 将 RGB 图像数据保存为 PNG 文件。
# MSS 对象提供了一个高效且灵活的方式来捕获屏幕截图,并支持多种操作系统。通过这些属性和方法,可以轻松地实现全屏截图、部分屏幕截图以及将截图保存为文件。
# 创建 mss 实例,用于进行屏幕截图。
self.sct = mss.mss()
# 监控器形状解析。
# Parse monitor shape
# 获取监控器的形状信息。
monitor = self.sct.monitors[self.screen]
# 如果 top 和 left 为 None ,则使用监控器的原始位置;否则,将偏移量加到监控器的原始位置上。
self.top = monitor["top"] if top is None else (monitor["top"] + top)
self.left = monitor["left"] if left is None else (monitor["left"] + left)
# 如果 width 和 height 为 None ,则使用监控器的原始尺寸;否则,使用指定的宽度和高度。
self.width = width or monitor["width"]
self.height = height or monitor["height"]
# 创建一个字典,保存监控器的位置和尺寸信息。
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
# 这个构造函数的主要作用是初始化屏幕截图的相关参数,并准备 mss 实例以便进行屏幕截图操作。通过这种方式, LoadScreenshots 类可以灵活地从指定的屏幕区域捕获图像,并根据需要进行图像处理。
# 这段代码定义了 LoadScreenshots 类的 __iter__ 方法,它是一个特殊方法,用于表示该类的对象是可迭代的。
# 这个方法没有接受任何参数(除了 self ),并且返回对象本身 self 。
def __iter__(self):
return self
# 在 Python 中,当一个对象需要被用作迭代器时,它必须实现 __iter__ 方法。这个方法应该返回一个迭代器,该迭代器能够产生序列中的下一个元素。
# 在 LoadScreenshots 类中, __iter__ 方法返回对象本身,这意味着 LoadScreenshots 类实例被用作迭代器。这通常意味着类中还应该实现 __next__ 方法,以提供迭代过程中的元素。
# 用途 :
# 这个方法使得 LoadScreenshots 类的对象可以被用于 for 循环或任何期望迭代器的上下文中。
# 当你使用 for 循环遍历 LoadScreenshots 的实例时,Python 会自动调用 __iter__ 方法来获取迭代器,然后不断调用 __next__ 方法来获取下一个元素,直到 __next__ 方法引发 StopIteration 异常,表示迭代结束。
# 这段代码定义了 LoadScreenshots 类的 __next__ 方法,它是 Python 迭代协议的一部分,用于在迭代过程中产生下一个元素。
# 这个方法是 LoadScreenshots 类的 __next__ 方法,它在每次迭代时被调用以产生下一个元素。
def __next__(self):
# mss screen capture: get raw pixels from the screen as np array mss 屏幕截图:从屏幕获取原始像素作为 np 数组。
# 使用 mss 库的 grab 方法从屏幕捕获原始像素数据,并将其转换为 NumPy 数组。 self.monitor 包含了要捕获的屏幕区域的位置和尺寸信息。 [:, :, :3] 将 BGRA 格式的像素数据转换为 BGR 格式。
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
# 创建一个字符串 s ,包含屏幕编号和捕获区域的位置尺寸信息。
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " # 屏幕 {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: 。
# 如果提供了 transforms 函数,则使用该函数对图像 im0 进行转换。
if self.transforms:
im = self.transforms(im0) # transforms
# 如果没有提供 transforms 函数,则使用 letterbox 函数对图像进行填充和尺寸调整。
else:
# etterbox 函数用于将图像调整到指定的尺寸 self.img_size ,同时保持图像的长宽比。
# def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# -> 用于将输入图像 im 调整大小并填充,以适应新的尺寸 new_shape ,同时保持图像的宽高比,并确保结果图像的尺寸是 stride 的倍数。返回 填充后的图像 、 缩放比例 和 填充尺寸 。
# -> return im, ratio, (dw, dh)
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
# 将图像从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式,并将 BGR 颜色通道顺序转换为 RGB。
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
# numpy.ascontiguousarray(a, dtype=None)
# np.ascontiguousarray 是 NumPy 库中的一个函数,它用于返回一个连续的内存块数组。如果输入数组已经是连续的,它将返回原数组;如果不是连续的,它将返回一个新数组,该数组是原数组的连续副本。
# 参数 :
# a :输入数组。
# dtype :数据类型,可选参数。如果指定,将输入数组转换为指定的数据类型。
# 返回值 :
# 返回一个连续的数组。用途 np.ascontiguousarray 通常用于确保数组在内存中是连续存储的,这在某些情况下是非常重要的,例如 :
# 当数组需要被传递给某些期望连续数据的库或函数时。
# 在进行内存操作或数组切片时,连续的数组可以提高性能。
# 确保图像数据在内存中是连续的,这对于某些深度学习框架是必要的。
im = np.ascontiguousarray(im) # contiguous
# 增加帧编号,用于跟踪迭代的进度。
self.frame += 1
# 返回一个元组,包含 屏幕编号 、 转换后的图像 、 原始图像 、一个 None 值 (可能用于其他数据,但在这种情况下没有使用) 和 屏幕信息字符串 。
return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
# __next__ 方法使得 LoadScreenshots 类可以作为一个迭代器使用,它在每次迭代时捕获屏幕的当前状态,并将其转换为适合深度学习模型输入的格式。这对于实时屏幕监控、游戏流媒体、或者其他需要实时图像处理的应用非常有用。
# 这个类的主要用途是实时从屏幕捕获图像,并将其转换为适合深度学习模型输入的格式。这在实时目标检测、图像分类或其他需要实时图像处理的应用中非常有用。
10.class LoadImages:
python
# 这段代码定义了一个名为 LoadImages 的类,它用于加载图像和视频文件,准备它们以供深度学习模型使用。
class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4` YOLOv5 图像/视频数据加载器,即"python detect.py --source image.jpg/vid.mp4"。
# 这段代码是 LoadImages 类的构造函数 __init__ 的定义,它用于初始化类实例并准备加载图像或视频文件。
# 这是 LoadImages 类的构造函数,它接受以下参数 :
# 1.self :类的实例本身。
# 2.path :图像或视频文件的路径,可以是单个路径、路径列表或通配符模式。
# 3.img_size :输出图像的尺寸,默认为640。
# 4.stride :用于图像处理的步长,默认为32。
# 5.auto :是否自动调整图像尺寸,默认为True。
# 6.transforms :可选的图像转换函数。
# 7.vid_stride :视频帧率步长,默认为1。
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
# 初始化一个空列表 files 用于存储所有找到的文件路径。
files = []
# 在Python中, sorted() 函数可以对任何可迭代的对象进行排序,包括列表、元组、字符串等。当 sorted(path) 被调用时,它会按照以下规则对路径进行排序 :
# 字符串排序 :如果 path 是一个字符串列表或元组, sorted() 函数会按照字符串的字典顺序(lexicographical order)进行排序。这意味着字符串会按照字符从左到右的顺序进行比较,类似于字典中单词的排序方式。
# 数字排序 :如果 path 中的元素是数字, sorted() 函数会按照数值大小进行排序。
# 混合类型 :如果 path 包含混合类型的元素(例如,同时包含字符串和数字),则在Python 3中,这种混合类型的排序会引发 TypeError 。但在Python 2中,数字会被排序到字符串前面,因为字符串在比较时会被转换为它们的ASCII值。
# 自定义排序 :可以通过传递一个 key 参数给 sorted() 函数来自定义排序逻辑。例如,如果路径包含文件名,可以提供一个键函数来按照文件名的长度、修改日期或其他标准进行排序。
# 在 LoadImages 类的上下文中, path 参数通常是一个包含文件路径的列表或元组。因此, sorted(path) 会按照文件路径的字典顺序进行排序。这意味着路径会按照字符串的字母顺序进行排序,不考虑文件系统中的实际路径结构。
# 检查 path 是否为列表或元组,如果是,则遍历排序后的路径;如果不是,则遍历包含单个路径的列表。
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
# 将路径 p 转换为字符串,并使用 Path.resolve() 解析其绝对路径。
p = str(Path(p).resolve())
# 如果路径中包含通配符 * ,则使用 glob.glob() 匹配所有文件,并扩展到 files 列表中。
if '*' in p:
files.extend(sorted(glob.glob(p, recursive=True))) # glob
# 如果路径是一个目录,则使用 glob.glob() 匹配该目录下的所有文件,并扩展到 files 列表中。
elif os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
# 如果路径是一个文件,则将其添加到 files 列表中。
elif os.path.isfile(p):
files.append(p) # files
# 如果路径不存在,则抛出 FileNotFoundError 异常。
else:
raise FileNotFoundError(f'{p} does not exist') # {p} 不存在。
# 从 files 列表中筛选出图像文件,基于文件扩展名是否在 IMG_FORMATS 元组中。
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
# 从 files 列表中筛选出视频文件,基于文件扩展名是否在 VID_FORMATS 元组中。
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
# 计算图像文件和视频文件的数量。
ni, nv = len(images), len(videos)
# 设置输出图像的尺寸。
self.img_size = img_size
# 设置图像处理的步长。
self.stride = stride
# 合并 图像 和 视频 文件列表,并存储在 self.files 中。
self.files = images + videos
# 计算总文件数量。
self.nf = ni + nv # number of files
# 创建一个标志列表,用于区分图像和视频文件。
self.video_flag = [False] * ni + [True] * nv
# 初始化模式为 'image'。
self.mode = 'image'
# 设置自动调整图像尺寸的标志。
self.auto = auto
# 设置可选的图像转换函数。
self.transforms = transforms # optional
# 设置视频帧率步长。
self.vid_stride = vid_stride # video frame-rate stride
# 如果有视频文件,则创建一个新的视频捕获对象;否则,设置 self.cap 为 None 。
if any(videos):
self._new_video(videos[0]) # new video
else:
self.cap = None
# 确保至少找到了一个图像或视频文件,否则抛出异常。
assert self.nf > 0, f'No images or videos found in {p}. ' \
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}' # 在 {p} 中未找到任何图片或视频。支持的格式为:\n图片:{IMG_FORMATS}\n视频:{VID_FORMATS}。
# 这个构造函数的主要作用是初始化 LoadImages 类实例,准备加载图像和视频文件,并设置相关的参数和标志。通过这种方式, LoadImages 类可以灵活地处理各种图像和视频文件,并为深度学习模型提供输入数据。
# 这段代码定义了 LoadImages 类的 __iter__ 方法,它是一个特殊方法,用于表示该类的对象是可迭代的。
# 这个方法是 LoadImages 类的 __iter__ 方法,它在对象需要被迭代时被调用。
def __iter__(self):
# 初始化一个计数器 count ,用于跟踪当前迭代的位置。这个计数器在每次迭代时递增,以确保迭代可以按顺序访问每个元素。
self.count = 0
# 返回对象本身,表明这个类的实例本身就是迭代器。这意味着你可以使用 for 循环直接迭代 LoadImages 的实例。
return self
# 这个方法使得 LoadImages 类的对象可以被用于 for 循环或任何期望迭代器的上下文中。
# 当你使用 for 循环遍历 LoadImages 的实例时,Python 会自动调用 __iter__ 方法来获取迭代器,然后不断调用 __next__ 方法来获取下一个元素,直到 __next__ 方法引发 StopIteration 异常,表示迭代结束。
# 这段代码定义了 LoadImages 类的 __next__ 方法,它是 Python 迭代协议的一部分,用于在迭代过程中产生下一个元素。
# 这个方法是 LoadImages 类的 __next__ 方法,它在每次迭代时被调用以产生下一个元素。
def __next__(self):
# 检查是否已经迭代完所有的文件(图像和视频)。
if self.count == self.nf:
# 如果已经迭代完所有的文件,抛出 StopIteration 异常,表示迭代结束。
raise StopIteration
# 获取当前迭代的文件路径。
path = self.files[self.count]
# 检查当前文件是否是视频。
if self.video_flag[self.count]:
# Read video
# 设置模式为 'video'。
self.mode = 'video'
# 根据视频帧率步长,抓取视频帧。
for _ in range(self.vid_stride):
# ret, frame = cap.grab()
# 在计算机视觉和图像处理领域, grab() 函数通常与视频捕获相关,特别是在使用 OpenCV (cv2) 库时。 grab() 函数用于从视频流或相机捕获设备中获取一帧图像,而不会将其从队列中移除。这意味着,如果连续调用 grab() 函数,它会连续读取视频流中的后续帧。
# cap :是一个 VideoCapture 对象,它表示一个视频流或相机捕获设备。
# ret :是一个布尔值,表示是否成功获取了帧。
# frame :是一个图像矩阵,如果 ret 为 True ,则 frame 包含捕获的帧数据。
# 行为 :
# 当 grab() 被调用时,它会尝试从视频流或相机捕获设备中读取下一帧,但不会自动将该帧从队列中移除。因此,如果再次调用 grab() ,它会读取下一帧,而不是重复读取同一帧。
# 如果视频流结束或捕获设备没有数据, grab() 将返回 False ,表示没有更多的帧可以读取。
self.cap.grab()
# ret, frame = cap.retrieve()
# 在 OpenCV (cv2) 库中, retrieve() 方法是与 VideoCapture 对象一起使用的,用于在调用 grab() 方法之后检索(读取)最近捕获的视频帧。这个方法通常在视频流处理中使用,特别是在需要对捕获的帧进行进一步处理时。
# cap :是一个 VideoCapture 对象,它表示一个视频流或相机捕获设备。
# ret :是一个布尔值,表示是否成功检索到帧。
# frame :是一个图像矩阵,如果 ret 为 True ,则 frame 包含检索到的帧数据。
# 行为 :
# 当 retrieve() 被调用时,它会返回 grab() 方法最近捕获的帧。如果 grab() 方法成功捕获了帧, retrieve() 将返回该帧的数据。
# 如果 grab() 方法没有成功捕获帧,或者视频流结束, retrieve() 将返回 False ,并且 frame 将不包含有效的数据。
# 请注意, read() 方法在内部调用 grab() 和 retrieve() ,所以通常不需要单独调用这两个方法。直接使用 read() 方法可以简化代码并提高效率。
# 从视频捕获对象中检索帧。
ret_val, im0 = self.cap.retrieve()
# 如果检索到的帧不有效,处理异常情况。
while not ret_val:
# 增加迭代计数器。
self.count += 1
# ret = cap.release()
# 在 OpenCV (cv2) 库中, release() 方法是 VideoCapture 类的一个成员函数,用于释放由 VideoCapture 对象所占用的资源。这通常在完成视频捕获或处理后被调用,以确保系统资源被正确释放,特别是在使用视频文件或相机流时。
# cap :是一个 VideoCapture 对象,它表示一个视频流或相机捕获设备。
# ret :是一个布尔值,表示释放操作是否成功。在最新版本的 OpenCV 中, release() 方法不返回任何值。
# 行为 :
# 当 release() 被调用时,它会释放与 VideoCapture 对象相关联的所有资源,包括打开的视频文件或相机设备。
# 在调用 release() 之后, VideoCapture 对象不再可用,任何进一步的 grab() 、 retrieve() 或 read() 调用都会失败。
# 注意事项 :
# 确保在程序结束前调用 release() 方法,特别是在异常处理或多个返回路径的情况下,以避免资源泄露。
# 在 Python 中,使用 with 语句可以自动管理资源的释放,使用 with 语句可以确保即使在发生异常时,资源也能被正确释放。
# 释放当前视频捕获对象。
self.cap.release()
# 如果已经迭代完所有的视频,抛出 StopIteration 异常。
if self.count == self.nf: # last video
raise StopIteration
# 获取下一个文件路径。
path = self.files[self.count]
# 为下一个视频创建新的视频捕获对象。
self._new_video(path)
# ret, frame = cap.read()
# 在 OpenCV (cv2) 库中, read() 方法是 VideoCapture 类的一个成员函数,用于从视频文件或相机捕获设备中读取帧。这个方法结合了 grab() 和 retrieve() 两个步骤,先捕获一帧,然后检索它。
# cap :是一个 VideoCapture 对象,它表示一个视频流或相机捕获设备。
# ret :是一个布尔值,表示是否成功读取帧。
# frame :是一个图像矩阵,如果 ret 为 True ,则 frame 包含读取的帧数据。
# 行为 :
# 当 read() 被调用时,它会尝试从 VideoCapture 对象关联的视频流或相机中读取下一帧。
# 如果成功读取帧, ret 将为 True ,并且 frame 将包含该帧的图像数据。
# 如果没有更多的帧可以读取(例如,视频结束或相机关闭), ret 将为 False ,并且 frame 将不包含有效的数据。
# 读取下一个视频帧。
ret_val, im0 = self.cap.read()
# 增加视频帧计数器。
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
# 创建一个字符串,包含视频的迭代信息。
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
# 如果当前文件是图像,而不是视频。
else:
# Read image
# 增加迭代计数器。
self.count += 1
# 使用 cv2.imread 读取图像。
im0 = cv2.imread(path) # BGR
# 确保图像被成功读取。
assert im0 is not None, f'Image Not Found {path}' # 未找到图片 {path}。
# 创建一个字符串,包含图像的迭代信息。
s = f'image {self.count}/{self.nf} {path}: '
# 如果提供了转换函数。
if self.transforms:
# 应用转换函数。
im = self.transforms(im0) # transforms
# 如果没有提供转换函数。
else:
# 使用 letterbox 函数对图像进行填充和尺寸调整。
# def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# -> 用于将输入图像 im 调整大小并填充,以适应新的尺寸 new_shape ,同时保持图像的宽高比,并确保结果图像的尺寸是 stride 的倍数。返回 填充后的图像 、 缩放比例 和 填充尺寸 。
# -> return im, ratio, (dw, dh)
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
# 将图像从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式,并将 BGR 颜色通道顺序转换为 RGB。
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
# 确保图像数据在内存中是连续的。
im = np.ascontiguousarray(im) # contiguous
# 返回 文件路径 、 转换后的图像 、 原始图像 、 视频捕获对象 和 文件信息字符串 。
return path, im, im0, self.cap, s
# 这个方法的主要作用是迭代加载图像和视频文件,并将其转换为适合深度学习模型输入的格式。这对于实时屏幕监控、游戏流媒体、或者其他需要实时图像处理的应用非常有用。
# 这段代码定义了 LoadImages 类中的一个私有方法 _new_video ,它用于创建一个新的视频捕获对象,并初始化与视频相关的一些属性。
# 这个方法是 LoadImages 类的一个私有方法,不接受外部调用,仅供类内部使用。 它接受一个参数。
# 1.path :这是要加载的视频文件的路径。
def _new_video(self, path):
# Create a new video capture object 创建新的视频捕获对象。
# 初始化视频帧计数器 frame 为 0。这个计数器将用于跟踪当前处理的视频帧编号。
self.frame = 0
# 创建一个新的 cv2.VideoCapture 对象 cap ,并用 path 参数初始化它。这个对象将用于从视频文件中捕获帧。
self.cap = cv2.VideoCapture(path)
# 调用 get 方法获取视频的总帧数,并除以视频帧率步长 vid_stride ,然后将结果转换为整数。这个属性 frames 表示在考虑步长后,视频中可读取的帧数。
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
# 调用 get 方法获取视频的元数据中的 CAP_PROP_ORIENTATION_META 属性,这个属性包含了视频的旋转角度信息。然后将结果转换为整数,并存储在 self.orientation 中。
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# 这行代码被注释掉了,如果取消注释,它将设置视频捕获对象的自动旋转属性为关闭状态。这意味着视频帧将不会自动旋转以适应设备的屏幕方向。
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
# 这个方法的主要作用是为类准备一个新的视频捕获对象,并根据视频文件的属性初始化一些重要的类属性。这样,当需要处理新视频时,可以重用 LoadImages 类的实例,并调用 _new_video 方法来更新状态。
# 这段代码定义了 LoadImages 类中的一个私有方法 _cv2_rotate ,它用于根据视频的旋转元数据手动旋转视频帧。
# 这个方法是 LoadImages 类的一个私有方法,不接受外部调用,仅供类内部使用。它接受一个参数。
# 1.im :这是要旋转的视频帧图像。
def _cv2_rotate(self, im):
# Rotate a cv2 video manually 手动旋转 cv2 视频。
# 检查视频的旋转角度是否为0度。
if self.orientation == 0:
# cv2.rotate(src, rotateCode)
# cv2.rotate 是 OpenCV (cv2) 库中的一个函数,用于旋转图像。这个函数提供了一个方便的方式来旋转图像到指定的角度。
# 参数 :
# src :输入图像,即需要旋转的源图像。
# rotateCode :旋转的类型,可以是以下值之一 :
# cv2.ROTATE_90_CLOCKWISE :顺时针旋转90度。
# cv2.ROTATE_180 :旋转180度。
# cv2.ROTATE_90_COUNTERCLOCKWISE :逆时针旋转90度。
# 返回值 :
# 函数返回旋转后的图像。
# 请注意, cv2.rotate 函数在旋转图像时会保持图像的原始尺寸,但旋转后的图像可能会超出原始画布,导致部分图像被裁剪。如果需要保持图像内容的完整性,可能需要在旋转前对图像进行适当的缩放或填充。
# 如果旋转角度为0度,则将图像顺时针旋转90度。
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
# 检查视频的旋转角度是否为180度。
elif self.orientation == 180:
# 如果旋转角度为180度,则将图像逆时针旋转90度,这等同于旋转两次90度。
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
# 检查视频的旋转角度是否为90度。
elif self.orientation == 90:
# 如果旋转角度为90度,则将图像旋转180度。
return cv2.rotate(im, cv2.ROTATE_180)
# 如果视频的旋转角度不是0度、90度或180度中的任何一个,那么不进行旋转,直接返回原始图像。
return im
# 这个方法的主要作用是根据视频的旋转元数据手动旋转视频帧,以确保视频帧的方向正确。这在处理从手机或其他设备录制的视频时特别有用,因为这些设备可能会根据设备的朝向自动旋转视频帧。通过手动旋转,可以确保视频帧在处理和显示时方向正确。
def __len__(self):
return self.nf # number of files
# 这个类的主要用途是加载图像和视频文件,并将其转换为适合深度学习模型输入的格式。这在实时目标检测、图像分类或其他需要实时图像处理的应用中非常有用。
11.class LoadStreams:
python
# 这段代码定义了一个名为 LoadStreams 的类,它用于从多个视频流(包括IP摄像头、视频文件或YouTube视频)中加载图像数据。
class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` YOLOv5 流加载器,即 `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP、RTMP、HTTP 流`。
# 这段代码是 LoadStreams 类的 __init__ 方法,它负责初始化类实例并设置视频流的读取。
# 这是类的构造函数(初始化方法),它定义了创建类实例时可以传入的参数。
# 1.sources : 视频流的来源,可以是一个包含视频流URL或路径的文本文件,默认为 'streams.txt' 。
# 2.img_size : 处理视频流时的目标图像尺寸,默认为 640 。
# 3.stride : 模型输入时的步长,影响模型的下采样率,默认为 32 。
# 4.auto : 是否自动调整图像大小以适应模型输入,默认为 True 。
# 5.transforms : 可选的图像变换函数,可以对图像进行预处理。
# 6.vid_stride : 视频帧率步长,用于控制从视频流中读取帧的频率,默认为 1 。
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
# 这段代码是 LoadStreams 类的构造函数 __init__ 的一部分,它负责初始化类实例并准备从视频流中加载图像数据。
# 设置 PyTorch 的 cuDNN 后端为基准测试模式,这可以加速固定大小输入的推理过程。
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
# 设置类的一个属性 mode ,表示当前模式为处理视频流。
self.mode = 'stream'
# 将传入的 img_size 参数值赋给类的 img_size 属性。
self.img_size = img_size
# 将传入的 stride 参数值赋给类的 stride 属性。
self.stride = stride
# 将传入的 vid_stride 参数值赋给类的 vid_stride 属性,这个属性用于控制视频帧的读取频率。
self.vid_stride = vid_stride # video frame-rate stride
# Path.read_text(encoding='utf-8', errors='strict')
# read_text() 函数是 Python 标准库 pathlib 模块中的 Path 类的一个方法,用于读取文件内容并返回一个字符串。
# 参数 :
# encoding :一个字符串参数,指定用于解码文件内容的编码,默认为 'utf-8' 。
# errors :一个字符串参数,指定如何处理解码错误,默认为 'strict' ,意味着遇到错误会抛出异常。
# 返回值 :
# 方法返回一个字符串,包含文件的内容。
# 异常 :
# 如果文件无法读取或解码过程中出现错误,方法将引发异常。
# 注意事项 :
# Path 对象需要指向一个存在的文件,否则在尝试读取时会抛出 FileNotFoundError 。
# 如果文件包含无法用指定编码解码的字符,并且 errors 参数设置为 'strict' ,则会抛出 UnicodeDecodeError 。
# 你可以指定不同的编码和错误处理策略,例如 encoding='latin-1' 或 errors='ignore' ,以适应不同的文件内容和需求。
# 通过使用 read_text() 方法,你可以方便地以字符串形式读取文件内容,这在处理文本文件时非常有用。
# str.rsplit(separator=None, maxsplit=-1)
# 在Python中, rsplit() 是字符串( str )对象的一个方法,用于在字符串末尾进行分割操作。这个方法从字符串的右侧(末尾)开始分割,而不是默认的左侧(开头)。
# 参数 :
# separator : 分隔符,用于指定分隔字符串的字符或字符串。如果未指定或为 None ,则任何空白字符(如空格、换行 \n 、制表符 \t 等)都被视为分隔符。
# maxsplit :最大分割次数。如果设置为 -1 (默认值),则没有分割次数限制,字符串会被完全分割直到没有分隔符为止。如果设置为其他整数,则在达到指定的分割次数后停止分割。
# 返回值 :
# rsplit() 方法返回一个列表,包含分割后的子字符串。
# 检查 sources 参数是否指向一个文件。如果是文件,则读取文件内容并使用 rsplit() 方法分割成列表;如果不是文件,则直接将 sources 作为列表处理。
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
# 计算视频流源的数量,并将其存储在变量 n 中。
n = len(sources)
# 使用列表推导式和 clean_str 函数清理每个视频流源名称,并将结果存储在 self.sources 属性中。
# def clean_str(s): -> 清除字符串 s 中的特殊字符,并用下划线 _ 替换它们。返回值。函数返回替换后的字符串。 -> return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
self.sources = [clean_str(x) for x in sources] # clean source names for later
# 初始化四个列表,分别用于存储每个视频流的 当前帧 ( self.imgs )、 帧率 ( self.fps )、 总帧数 ( self.frames ) 和 线程对象 ( self.threads )。每个列表的长度都是 n ,即视频流源的数量。初始值分别为 None 和 0 。
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
# 这些初始化步骤为后续的视频流处理打下了基础,包括从视频流中读取帧、更新帧、以及可能的帧转换等操作。通过这种方式, LoadStreams 类可以灵活地处理多个视频流,并为深度学习模型提供输入数据。
# 这段代码是 LoadStreams 类的 __init__ 方法中的一部分,它负责遍历视频流源列表 sources ,并为每个视频流源设置一个线程来读取帧。
# 使用 enumerate 函数遍历 sources 列表, i 是索引(从0开始), s 是当前的流源。
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
# 创建一个字符串 st ,用于日志记录,显示当前处理的是第几个视频流源以及视频流源的名称。
st = f'{i + 1}/{n}: {s}... '
# result = urlparse(urlstring, scheme='', allow_fragments=True)
# urlparse() 函数是 Python 标准库 urllib.parse 模块中的一个函数,用于解析 URL(统一资源定位符)并将其分解为组件。这个函数在处理网络地址时非常有用,因为它可以将复杂的 URL 分解成易于管理的部分。
# 参数 :
# urlstring : 要解析的 URL 字符串。
# scheme : (可选)如果提供,将用于覆盖 URL 中的方案部分。
# allow_fragments : (可选)一个布尔值,指示是否允许解析 URL 的片段部分(即 # 后面的部分)。默认为 True 。
# 返回值 :
# urlparse() 函数返回一个 ParseResult 对象,该对象包含以下属性 :
# scheme : URL 的方案部分(例如 http 、 https )。
# netloc : 网络位置部分(例如域名和端口)。
# path : URL 的路径部分。
# params : URL 的参数部分( ? 后面的部分)。
# query : URL 的查询部分( ? 后面的部分,不包括 # )。
# fragment : URL 的片段部分( # 后面的部分)。
# urlparse() 函数是处理 URL 的基础工具,常用于网络编程、Web 开发和任何需要解析或构造 URL 的场景。
# 检查当前视频流源 s 是否是YouTube视频。如果是,执行以下操作。 urlparse(s).hostname 用于获取 URL 的 hostname 部分,以便判断视频流源是否来自特定的域名(如 YouTube)。
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
# 确保安装了所需的 pafy 和 youtube_dl 库。
# def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''): -> 它用于检查是否安装了满足YOLO要求的依赖项。如果某些依赖项未安装或版本不兼容,函数会尝试自动安装它们。
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
# 导入 pafy 库。
import pafy
# 使用 pafy 获取YouTube视频的最佳质量MP4链接,并将其赋值给 s 。
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
# 如果 s 是数字字符串,使用 eval 将其转换为整数(例如, '0' 转换为 0 ),这通常用于表示本地摄像头。
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
# 如果 s 等于0,检查是否在Colab或Kaggle环境中,因为这些环境不支持本地摄像头,并抛出错误信息。
if s == 0:
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.' # --source 0 网络摄像头在 Colab 上不受支持。请在本地环境中重新运行命令。
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.' # --source 0 网络摄像头在 Kaggle 上不受支持。在本地环境中重新运行命令。
# 使用 cv2.VideoCapture 打开视频流源 s 。
cap = cv2.VideoCapture(s)
# ret = cap.isOpened()
# cap.isOpened() 是 OpenCV 库(cv2模块)中 VideoCapture 类的一个方法。这个方法用于检查视频捕获对象是否成功打开了视频文件或视频流。
# ap : 一个 VideoCapture 对象。
# ret : 一个布尔值,如果视频捕获对象成功打开,则返回 True ;否则返回 False 。
# 返回值 :
# isOpened() 方法返回一个布尔值,指示 VideoCapture 对象是否已经成功连接到视频源。
# 确保视频流成功打开,否则抛出错误信息。
assert cap.isOpened(), f'{st}Failed to open {s}' # {st}无法打开{s}。
# ret = cap.get(propId)
# cap.get() 函数是 OpenCV 库中 VideoCapture 类的一个成员函数,用于获取视频流或视频文件的各种属性。这个函数提供了一种查询视频捕获设备或文件当前状态的方法。
# cap : 一个 VideoCapture 对象,代表视频捕获设备或视频文件。
# propId : 一个指定的属性标识符,它是一个特定的枚举值,用于指定要检索的属性类型。
# ret : 返回的属性值,其类型和值取决于请求的属性。
# 属性标识符(propId) propId 参数可以是以下值之一,用于指定要检索的属性 :
# cv2.CAP_PROP_POS_MSEC :视频文件的当前位置,以毫秒为单位。
# cv2.CAP_PROP_POS_FRAMES :视频文件的当前位置,以帧为单位。
# cv2.CAP_PROP_POS_AVI_RATIO :视频文件的当前位置,以AVI文件的百分比表示。
# cv2.CAP_PROP_FRAME_WIDTH :视频流的帧宽度。
# cv2.CAP_PROP_FRAME_HEIGHT :视频流的帧高度。
# cv2.CAP_PROP_FPS :视频流的帧率。
# cv2.CAP_PROP_FOURCC :视频流的四字符代码(Four-Character Code),表示视频编码格式。
# cv2.CAP_PROP_FRAME_COUNT :视频文件中的总帧数。
# cv2.CAP_PROP_FORMAT :视频流的格式。
# cv2.CAP_PROP_MODE :视频捕获模式。
# cv2.CAP_PROP_BRIGHTNESS 、 cv2.CAP_PROP_CONTRAST 、 cv2.CAP_PROP_SATURATION 、 cv2.CAP_PROP_HUE :摄像头的控制属性。
# 返回值 :
# get() 方法返回一个值,该值是请求属性的当前值。返回值的类型取决于请求的属性,可能是整数、浮点数或字符串。
# 获取视频流的宽度,并转换为整数。
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# 获取视频流的高度,并转换为整数。
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# 获取视频流的帧率,注意可能会返回0或NaN。
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
# 获取视频流的总帧数,如果没有帧数则默认为无穷大。
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
# 获取视频流的帧率,如果没有有效的帧率则默认为30。
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
# ret, frame = cap.read()
# 在 OpenCV (cv2) 库中, read() 方法是 VideoCapture 类的一个成员函数,用于从视频文件或相机捕获设备中读取帧。这个方法结合了 grab() 和 retrieve() 两个步骤,先捕获一帧,然后检索它。
# cap :是一个 VideoCapture 对象,它表示一个视频流或相机捕获设备。
# ret :是一个布尔值,表示是否成功读取帧。
# frame :是一个图像矩阵,如果 ret 为 True ,则 frame 包含读取的帧数据。
# 行为 :
# 当 read() 被调用时,它会尝试从 VideoCapture 对象关联的视频流或相机中读取下一帧。
# 如果成功读取帧, ret 将为 True ,并且 frame 将包含该帧的图像数据。
# 如果没有更多的帧可以读取(例如,视频结束或相机关闭), ret 将为 False ,并且 frame 将不包含有效的数据。
# 读取视频流的第一帧,并存储在 self.imgs 列表中。
_, self.imgs[i] = cap.read() # guarantee first frame
# Thread(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None)
# 在 Python 中, Thread 是一个类,它属于 threading 模块,用于创建和管理线程。
# 参数 :
# group :这个参数已经被废弃,不需要使用。
# target :一个可调用的对象,它将在这个线程中执行。 target 函数必须接受一个参数,即 Thread 实例本身。
# name :线程的名称。如果未提供,则线程将没有名称。
# args :一个元组,包含传递给 target 函数的参数。
# kwargs :一个字典,包含传递给 target 函数的关键字参数。
# daemon :一个布尔值,表示线程是否作为守护线程运行。如果设置为 True ,则当主程序退出时,线程也会自动退出。
# 方法 :
# start() :启动线程。线程将在 target 函数中执行。
# run() :这是一个在 Thread 类中定义的方法,可以被重写。如果提供了 target 参数,则 run() 方法不会被直接调用,而是调用 target 。
# join(timeout=None) :等待线程终止。 timeout 参数是可选的,表示等待的秒数。
# is_alive() :返回线程是否仍然活跃。
# Thread 类是 Python 中实现多线程编程的基础,允许程序同时执行多个任务。
# 为每个视频流创建一个线程,用于在后台持续读取视频帧。
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
# 使用 LOGGER 记录视频流的成功打开信息,包括 总帧数 、 分辨率 和 帧率 。
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)") # {st} 成功({self.frames[i]} 帧 {w}x{h} 以 {self.fps[i]:.2f} FPS)。
# 启动线程。
self.threads[i].start()
# 在日志中添加一个空行,用于分隔不同的视频流信息。
LOGGER.info('') # newline
# 这段代码负责为每个视频流源设置一个线程来读取帧,并记录视频流的基本信息。它处理了YouTube视频流的特殊情况,并确保了视频流可以成功打开。此外,它还记录了视频流的基本信息,如总帧数、分辨率和帧率。
# 这段代码是 LoadStreams 类中的一部分,用于检查所有视频流是否具有相同的图像尺寸和形状,这对于后续的图像处理和模型推理非常重要。
# check for common shapes
# 使用列表推导式遍历 self.imgs 中的每个图像。
# 对每个图像 x ,调用 letterbox 函数,该函数用于将图像调整到指定的 img_size ,同时保持图像的纵横比。
# letterbox 函数返回一个元组,其中第一个元素是调整后的图像, [0] 用于提取这个图像。
# .shape 获取调整后图像的形状(高度,宽度,通道数)。
# np.stack 将所有图像的形状堆叠成一个 NumPy 数组 s 。
# def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# -> 用于将输入图像 im 调整大小并填充,以适应新的尺寸 new_shape ,同时保持图像的宽高比,并确保结果图像的尺寸是 stride 的倍数。返回 填充后的图像 、 缩放比例 和 填充尺寸 。
# -> return im, ratio, (dw, dh)
s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
# np.unique 函数用于找出数组 s 中的唯一形状。
# axis=0 指定沿着第一个轴(即每个形状的维度)进行操作。
# .shape[0] 获取唯一形状的数量。
# 如果唯一形状的数量为1,即所有图像具有相同的形状, self.rect 被设置为 True ,否则为 False 。这表示是否可以进行矩形推理(即所有图像形状相同)。
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
# 如果 auto 参数为 True 且 self.rect 也为 True ,则 self.auto 被设置为 True ,表示可以自动调整图像大小。
self.auto = auto and self.rect
# transforms 参数是一个可选的参数,用于存储图像预处理的变换函数。这些函数可以在图像被送入模型之前对其进行进一步的处理。
self.transforms = transforms # optional
# 如果 self.rect 为 False ,即图像形状不一致,执行以下操作。
if not self.rect:
# 使用 LOGGER 记录一条警告信息,提示用户为了获得最佳性能,应提供形状相似的视频流。
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.') # 警告 ⚠️ 流形状不同。为获得最佳性能,请提供形状相似的流。
# 这段代码检查所有视频流的图像是否具有相同的形状,这对于确保模型能够正确处理输入至关重要。如果形状不一致,会记录警告信息,提示用户优化视频流的输入形状。
# 这个方法是视频流处理的核心,负责初始化视频流的读取和处理。它使用了多线程来提高视频流读取的效率,并处理了多种视频流源类型,包括本地文件、网络URL和YouTube视频。
# 这段代码定义了一个名为 update 的方法,它是 LoadStreams 类的一部分。这个方法在一个守护线程中运行,负责持续从视频流中读取帧,并更新类实例中存储的当前帧。
# 这是 update 方法的定义,它接受三个参数。
# 1.self :类的实例自身。
# 2.i :当前视频流的索引。
# 3.cap : VideoCapture 对象,用于捕获视频流。
# 4.stream :视频流的源,可以是文件路径或摄像头索引。
def update(self, i, cap, stream):
# Read stream `i` frames in daemon thread
# 初始化两个变量。 n 用于记录当前读取的帧数, f 存储视频流的总帧数(从类的 self.frames 属性中获取)。
n, f = 0, self.frames[i] # frame number, frame array
# 一个循环,只要 cap 对象处于打开状态且当前帧数小于总帧数,就继续执行。
while cap.isOpened() and n < f:
# 每次循环迭代,帧数 n 增加 1。
n += 1
# 调用 cap.grab() 方法来从视频流中抓取一帧。这是一个非阻塞调用,它告诉视频捕获设备准备好下一帧。
cap.grab() # .read() = .grab() followed by .retrieve()
# 检查 当前帧数 除以 视频帧率步长 self.vid_stride 的余数是否为 0。如果是,表示需要处理当前帧。
if n % self.vid_stride == 0:
# 调用 cap.retrieve() 方法来检索由 cap.grab() 准备好的帧。该方法返回一个布尔值 success ,指示帧是否成功检索,以及帧图像 im 。
success, im = cap.retrieve()
# 如果帧成功检索,将帧图像 im 赋值给 self.imgs[i] ,更新类实例中存储的当前帧。
if success:
self.imgs[i] = im
# 如果帧未能成功检索,记录一条警告日志,并设置 self.imgs[i] 为一个与之前相同形状的零矩阵,表示当前帧无效。
else:
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') # 警告⚠️视频流无响应,请检查您的 IP 摄像机连接。
self.imgs[i] = np.zeros_like(self.imgs[i])
# 如果帧未能成功检索,尝试重新打开视频流。
cap.open(stream) # re-open stream if signal was lost
# time.sleep(seconds)
# time.sleep() 是 Python 标准库 time 模块中的一个函数,用于让当前线程暂停执行指定的秒数。
# 参数 :
# seconds : 一个浮点数或整数,表示线程暂停的秒数。
# 返回值 :time.sleep() 函数没有返回值,它只是简单地让调用它的线程休眠。
# 注意事项 :
# time.sleep() 函数接受的参数是秒数,可以是整数或浮点数,表示精确的暂停时间。
# 在多线程程序中, time.sleep() 只暂停当前线程,不影响程序中的其他线程。
# 在某些操作系统上, time.sleep() 的实现可能无法精确到微秒级别,实际的暂停时间可能会比指定的时间长一点点。
# 在需要精确控制时间的场合,可能需要结合其他同步机制,如 threading 模块中的 Event 或 Condition 对象。
# 在循环的最后,使用 time.sleep(0.0) 暂停线程。虽然这里暂停时间是 0 秒,但这通常用于在循环中避免 100% 的 CPU 使用率。
time.sleep(0.0) # wait time
# update 方法用于在后台线程中持续从视频流中读取帧,并在满足条件时更新类实例中的当前帧。如果视频流出现问题,它会尝试重新连接视频流,并记录相应的日志信息。
# 这段代码定义了 LoadStreams 类的 __iter__ 方法。在 Python 中, __iter__ 方法是一个特殊的方法,当一个对象需要被迭代时(比如在 for 循环中使用时),会被自动调用。这个方法的目的是返回一个迭代器对象,该对象实现了 Python 的迭代器协议,即拥有 __next__ 方法。
# 这是 __iter__ 方法的定义,它不接受除了 self 之外的任何参数。
def __iter__(self):
# 在迭代开始之前,将一个名为 count 的实例变量初始化为 -1 。这个变量通常用于跟踪迭代过程中的当前位置或状态。初始化为 -1 可能是因为迭代器协议要求在第一次调用 __next__ 方法时返回第一个元素,所以这里先设置一个初始值,以便在 __next__ 方法中递增到 0 。
self.count = -1
# 返回类的实例自身,这意味着类的实例既是可迭代对象,也是迭代器。这是 Python 中的一个常见模式,称为"迭代器模式",其中同一个对象实现了迭代器协议。
return self
# __iter__ 方法使得类的实例可以被用作迭代器,这是实现迭代器协议的一部分。在类的实例上使用 for 循环或其他迭代上下文时,Python 会自动调用这个方法。
# 这段代码定义了 LoadStreams 类的 __next__ 方法,它是迭代器协议的一部分,用于在每次迭代中返回下一个元素。
# 这是 __next__ 方法的定义,它不接受除了 self 之外的任何参数。
def __next__(self):
# 将实例变量 count 的值增加 1,这个变量用于跟踪迭代的次数。
self.count += 1
# 使用 all() 函数检查 self.threads 列表中的所有线程是否仍然存活(即是否还在运行)。 cv2.waitKey(1) 是 OpenCV 的一个函数,用于等待一个键盘输入,参数 1 表示等待时间为 1 毫秒。如果用户按下了 'q' 键( ord('q') 是 'q' 的 ASCII 码),则返回 True 。如果任何一个条件为 True ,则执行以下代码。
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
# 关闭所有 OpenCV 创建的窗口。
cv2.destroyAllWindows()
# 抛出 StopIteration 异常,表示迭代结束。
raise StopIteration
# 创建 self.imgs 列表中图像的副本,以便在处理过程中不修改原始图像。
im0 = self.imgs.copy()
# 如果提供了 self.transforms (一个包含变换函数的列表),则对每个图像应用这些变换。
if self.transforms:
# 应用变换并堆叠结果。
im = np.stack([self.transforms(x) for x in im0]) # transforms
# 如果没有提供 self.transforms ,则执行以下操作。
else:
# 对每个图像应用 letterbox 函数进行调整大小。
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
# 将图像从 BGR 格式转换为 RGB 格式,并将数据从 BHWC 格式转换为 BCHW 格式(Batch, Channel, Height, Width, Channel)。
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
# 确保图像数据在内存中是连续的,这对于某些 NumPy 操作和 OpenCV 函数是必要的。
im = np.ascontiguousarray(im) # contiguous
# 返回一个元组,包含以下元素 : self.sources 视频流的源列表 、 im 处理后的图像 、 im0 原始图像的副本 、 None 可能用于其他数据,这里没有提供 、 '' 一个空字符串,可能用于其他信息,这里没有提供。
return self.sources, im, im0, None, ''
# __next__ 方法用于在迭代过程中返回下一个元素,它处理图像,应用变换,并在适当的时候结束迭代。这个方法是实现迭代器协议的关键部分,使得类的实例可以在 for 循环中使用。
# 这段代码定义了 LoadStreams 的 __len__ 方法,它用于返回对象的长度,即视频流源的数量。
# 这是 __len__ 方法的定义,它不接受除了 self 之外的任何参数。
def __len__(self):
# 返回 self.sources 列表的长度,即视频流源的数量。
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
# __len__ 方法使得类的实例可以被 len() 函数查询长度,这在很多情况下非常有用,比如当你需要知道有多少个视频流源时。这个方法是 Python 中的一个特殊方法,它允许对象模拟内置的序列类型(如列表、元组、字符串等)的行为。
# 注释中的 "1E12 frames = 32 streams at 30 FPS for 30 years" 提供了一个示例计算,说明如果每个视频流以每秒 30 帧的速度连续播放 30 年,那么 32 个这样的视频流总共会产生大约 1E12(10^12,即一万亿)帧。这个注释可能是为了说明 __len__ 方法返回值的潜在规模,或者是对视频流数据处理潜在复杂性的一个提示。
# 这个类的设计目的是为了能够从多个视频源(包括本地文件、网络摄像头、YouTube视频等)中实时加载和处理视频帧,适用于需要视频流处理的应用,如视频监控、实时目标检测等。代码中使用了多线程来提高性能,并且支持图像的预处理和转换。
12.def img2label_paths(img_paths):
python
# 这段代码定义了一个名为 img2label_paths 的函数,其目的是将包含图像文件路径的列表转换为对应的标签文件路径列表。这个函数假定图像文件和标签文件存放在不同的目录中,且这两个目录具有相同的结构,除了最后的目录名不同(一个是 images ,另一个是 labels )。
# 这是 img2label_paths 函数的定义,它接受一个参数。
# 1.img_paths :这是一个包含图像文件路径的列表。
def img2label_paths(img_paths):
# Define label paths as a function of image paths 将标签路径定义为图像路径的函数。
# os.sep
# os.sep 是 Python 标准库 os 模块中的一个属性,它代表操作系统特定的路径分隔符。这个属性在不同的操作系统中有不同的值 :
# 在 Windows 系统中, os.sep 是反斜杠 \ 。
# 在 Unix 和 Unix-like 系统(包括 Linux 和 macOS)中, os.sep 是正斜杠 / 。
# os.sep 用于确保代码在不同操作系统中处理文件路径时具有兼容性。例如,当你需要在路径字符串中插入分隔符时,使用 os.sep 而不是硬编码分隔符可以使得代码更加可移植。
# 这两行代码定义了两个字符串 sa 和 sb ,它们分别代表图像目录和标签目录的路径后缀。 os.sep 是操作系统特定的路径分隔符(例如,在Windows上是 \ ,在Unix/Linux上是 / )。
# sa 是图像目录的后缀,例如 /images/ 。
# sb 是标签目录的后缀,例如 /labels/ 。
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
# str.rsplit(separator=None, maxsplit=-1)
# 在Python中, rsplit() 是字符串( str )对象的一个方法,用于在字符串末尾进行分割操作。这个方法从字符串的右侧(末尾)开始分割,而不是默认的左侧(开头)。
# 参数 :
# separator : 分隔符,用于指定分隔字符串的字符或字符串。如果未指定或为 None ,则任何空白字符(如空格、换行 \n 、制表符 \t 等)都被视为分隔符。
# maxsplit :最大分割次数。如果设置为 -1 (默认值),则没有分割次数限制,字符串会被完全分割直到没有分隔符为止。如果设置为其他整数,则在达到指定的分割次数后停止分割。
# 返回值 :
# rsplit() 方法返回一个列表,包含分割后的子字符串。
# 这是一个列表推导式,它遍历 img_paths 列表中的每个图像路径 x
# x.rsplit(sa, 1) 从路径 x 的末尾开始分割,最多分割一次,将图像目录后缀 sa 替换掉。
# sb.join(...) 将替换后的路径与标签目录后缀 sb 连接起来,形成新的路径。
# .rsplit('.', 1)[0] 再次从新路径的末尾开始分割,最多分割一次,去掉文件扩展名,只保留路径和文件名。
# '.txt' 将 .txt 扩展名添加到结果路径和文件名的末尾,因为标签文件通常以 .txt 结尾。
# 最终,这个列表推导式返回一个新的列表,其中包含每个图像文件对应的标签文件路径。
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
# img2label_paths 函数通过替换路径中的 images 子串为 labels 并添加 .txt 扩展名,将图像文件路径转换为标签文件路径。这个函数在处理图像和相关标签文件时非常有用,特别是在机器学习和计算机视觉任务中,例如目标检测或图像分割。
13.class LoadImagesAndLabels(Dataset):
python
# 这段代码定义了一个名为 LoadImagesAndLabels 的类,它继承自 Dataset 类(指的是 PyTorch 的 torch.utils.data.Dataset )。这个类用于加载图像和对应的标签,并提供数据增强、缓存等功能,以便于在深度学习模型训练中使用。
class LoadImagesAndLabels(Dataset):
# YOLOv5 train_loader/val_loader, loads images and labels for training and validation YOLOv5 train_loader/val_loader,加载用于训练和验证的图像和标签。
# 类属性。
# 数据集标签缓存的版本号。
cache_version = 0.6 # dataset labels *.cache version
# 随机插值方法列表,用于数据增强时的图像大小调整。
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
# 这段代码是 LoadImagesAndLabels 类的构造函数 __init__ ,它初始化类的实例并设置了一系列参数来配置数据加载和预处理的行为。
# 这是构造函数的定义,它接受多个参数来定制数据集的行为。
# 1.self :这是 Python 类实例方法的第一个参数,它指向类实例本身。
# 2.path :一个字符串或字符串列表,表示图像数据的路径。可以是一个文件夹路径,也可以是一个包含图像路径的文件。
# 3.img_size :一个整数,指定加载图像的目标尺寸,默认为 640 像素。
# 4.batch_size :一个整数,指定每个批次中的图像数量,默认为 16。
# 5.augment :一个布尔值,指示是否对图像进行数据增强,默认为 False 。
# 6.hyp :一个字典或对象,存储超参数,如学习率等,默认为 None 。
# 7.rect :一个布尔值,指示是否使用矩形训练模式,默认为 False 。
# 8.image_weights :一个布尔值,指示是否使用图像权重进行训练,默认为 False 。
# 9.cache_images :一个布尔值,指示是否将图像缓存到内存或磁盘中以加快训练速度,默认为 False 。
# 10.single_cls :一个布尔值,指示是否进行单类别训练,默认为 False 。
# 11.stride :一个整数,指定模型的步长,默认为 32 像素。
# 12.pad :一个浮点数,指定在调整图像大小时的填充比例,默认为 0.0。
# 13.min_items :一个整数,指定每个图像的最小项目数(如标签数),少于该数量的图像将被过滤掉,默认为 0。
# 14.prefix :一个字符串,用于在日志和错误消息前添加前缀,默认为空字符串。
# 这些参数提供了对数据加载和预处理流程的精细控制,使得 LoadImagesAndLabels 类能够适应不同的训练需求和配置。
def __init__(self,
path,
img_size=640,
batch_size=16,
augment=False,
hyp=None,
rect=False,
image_weights=False,
cache_images=False,
single_cls=False,
stride=32,
pad=0.0,
min_items=0,
prefix=''):
# 设置实例变量 img_size ,表示加载图像的目标尺寸。
self.img_size = img_size
# 设置实例变量 augment ,一个布尔值,指示是否应用数据增强。
self.augment = augment
# 设置实例变量 hyp ,通常用于存储超参数。
self.hyp = hyp
# 设置实例变量 image_weights ,一个布尔值,指示是否使用图像权重。
self.image_weights = image_weights
# 根据 image_weights 的值设置 rect 变量。如果 image_weights 为 True ,则 rect 设置为 False ;否则,使用 rect 参数的值。
self.rect = False if image_weights else rect
# 设置实例变量 mosaic ,一个布尔值,指示是否在训练时将4张图像加载到一个镶嵌(mosaic)中。
self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
# 设置实例变量 mosaic_border ,表示镶嵌的边界。
self.mosaic_border = [-img_size // 2, -img_size // 2]
# 设置实例变量 stride ,表示模型的步长。
self.stride = stride
# 设置实例变量 path ,表示数据集的路径。
self.path = path
# 如果启用了数据增强( augment 为 True ),则实例化 Albumentations 对象,并设置图像尺寸;否则,设置为 None 。
# class Albumentations:
# -> 用于图像增强,特别是在目标检测任务中。这个类使用了 albumentations 库来应用一系列的图像变换,以增加数据集的多样性并提高模型的泛化能力。
# -> def __init__(self, size=640):
self.albumentations = Albumentations(size=img_size) if augment else None
# 这个构造函数为数据集类提供了丰富的配置选项,允许用户根据需要定制数据加载和预处理流程。通过设置这些参数,用户可以控制数据增强的行为、图像尺寸、是否使用图像权重等关键因素。
# 这段代码是 LoadImagesAndLabels 类的构造函数 __init__ 中的一部分,它负责根据提供的路径 path 收集图像文件的路径列表。
# 开始一个 try 块,用于捕获在加载数据时可能发生的任何异常。
try:
# 初始化一个空列表 f ,用于存储找到的图像文件路径。
f = [] # image files
# 检查 path 参数是否是一个列表,如果是,则直接迭代;如果不是,则将其放入列表中以便迭代。
for p in path if isinstance(path, list) else [path]:
# 将路径转换为 Path 对象,这使得路径操作跨平台兼容。
p = Path(p) # os-agnostic
# 检查 Path 对象 p 是否指向一个目录。
if p.is_dir(): # dir
# 如果 p 是目录,使用 glob 模块递归地搜索所有文件,并将匹配的文件路径添加到列表 f 中。
# glob.glob : glob.glob 是一个函数,用于从目录树中搜索匹配特定模式的文件路径。
# str(p / '**' / '*.*') :
# p 是一个 Path 对象, p / '**' / '*.*' 使用路径分隔符 / 来构建一个路径模式。
# ** 表示任意数量的目录, *.* 匹配目录下的所有文件( * 表示任意字符, . 表示文件扩展名)。
# str() 函数将 Path 对象转换为字符串路径,因为 glob.glob 需要字符串参数。
# recursive=True :这个参数指示 glob.glob 函数递归地搜索所有子目录。
# f += : f 是一个列表,用于存储找到的文件路径。 += 运算符将 glob.glob 的结果(一个列表)添加到 f 列表的末尾。
# 这行代码的作用是将指定目录 p 及其所有子目录中的所有文件的路径收集到列表 f 中。这是加载数据集图像文件路径的常用方法,尤其是在处理大型数据集时。
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
# f = list(p.rglob('*.*')) # pathlib
# 如果 p 是文件,执行以下操作。
elif p.is_file(): # file
# 打开文件并读取其内容,文件中的每一行都包含一个图像路径。
with open(p) as t:
# 这行代码执行了三个字符串方法,用于处理从文件中读取的内容 :
# t.read() :这个方法从文件中读取全部内容,并将其作为一个字符串返回。
# .strip() : strip() 方法去除字符串两端的空白字符,包括空格、换行符 \n 、制表符 \t 等。这有助于清理文件内容,避免由于额外的空格或换行符引起的问题。
# .splitlines() : splitlines() 方法将字符串分割成多行,每行作为一个列表元素返回。默认情况下,它会使用任何标准的空白字符(如 \n 或 \r )作为行分隔符。
# t = t.read().strip().splitlines() 这行代码读取文件的全部内容,去除两端的空白字符,并将内容按行分割,存储在变量 t 中。这样, t 就包含了一个字符串列表,每个字符串代表文件中的一行。这种处理方式常用于读取包含多行数据的文本文件,例如图像路径列表或配置文件。
t = t.read().strip().splitlines()
# 获取文件的父目录路径,并添加操作系统特定的路径分隔符。
parent = str(p.parent) + os.sep
# 将文件中的相对路径转换为绝对路径,并将这些路径添加到列表 f 中。
f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t] # to global path
# f += [p.parent / x.lstrip(os.sep) for x in t] # to global path (pathlib)
# 如果 p 既不是目录也不是文件,抛出 FileNotFoundError 异常。
else:
raise FileNotFoundError(f'{prefix}{p} does not exist') # {prefix}{p} 不存在。
# 过滤列表 f ,只保留文件扩展名在 IMG_FORMATS 中的图像文件路径,并将路径中的正斜杠替换为操作系统特定的路径分隔符,然后对结果进行排序。
self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
# 确保 self.im_files 不为空,否则抛出异常,表示没有找到图像。
assert self.im_files, f'{prefix}No images found' # {prefix}未找到任何图片。
# 如果在 try 块中发生任何异常,捕获它并抛出一个新的异常,包含前缀、路径和错误信息。
except Exception as e:
# 抛出一个新的异常,包含错误信息和帮助链接 HELP_URL ,并保留原始异常的上下文。
# HELP_URL -> 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e # {prefix}从 {path} 加载数据时出错:{e}\n{HELP_URL}。
# 这段代码确保了无论提供的路径是单个目录、单个文件还是文件列表,都能正确地加载图像文件路径,并处理了各种可能的错误情况。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,它负责检查和创建缓存文件。
# Check cache
# 调用 img2label_paths 函数将图像文件路径转换为对应的标签文件路径,并存储在 self.label_files 中。
# def img2label_paths(img_paths): -> 将包含图像文件路径的列表转换为对应的标签文件路径列表。这个列表推导式返回一个新的列表,其中包含每个图像文件对应的标签文件路径。 -> return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
self.label_files = img2label_paths(self.im_files) # labels
# path.with_suffix(suffix)
# with_suffix 是 Python pathlib 模块中 Path 类的一个方法,它用于修改路径对象的后缀(扩展名)。
# path : Path 类的实例。
# suffix :要设置的新后缀。如果为空字符串,则移除路径的当前后缀。
# 返回值 :
# 返回一个新的 Path 对象,其后缀被修改为指定的 suffix 。
# 方法功能 :
# 如果原始路径没有后缀, with_suffix 方法会将指定的后缀追加到路径的末尾。
# 如果原始路径已经有后缀, with_suffix 方法会替换为指定的新后缀。
# 如果指定的 suffix 是空字符串, with_suffix 方法会移除路径的当前后缀。
# 注意事项 :
# with_suffix 方法不会修改原始的 Path 对象,而是返回一个新的 Path 对象。
# 这个方法在处理文件扩展名时非常有用,尤其是在你需要动态更改文件类型或处理不同格式的文件时。
# 确定缓存文件的路径。如果 p 是一个文件路径,则使用该路径;如果不是,使用第一个标签文件的父目录。 给路径添加 .cache 后缀,以创建缓存文件的名称。
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
# 开始一个 try 块,用于捕获加载缓存文件时可能发生的任何异常。
try:
# 使用 np.load 函数加载缓存文件, allow_pickle=True 允许加载 pickle 对象。 .item() 方法将 NumPy 数组转换为 Python 字典。 exists 设置为 True ,表示缓存文件已成功加载。
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
# 断言缓存文件的版本与当前版本一致。
assert cache['version'] == self.cache_version # matches current version
# 断言缓存文件的哈希值与当前标签文件和图像文件的哈希值一致,确保文件没有变化。
# def get_hash(paths): -> 用于计算一个包含文件或目录路径列表的哈希值。计算最终的哈希值,并以十六进制字符串的形式返回。 -> return h.hexdigest() # return hash
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
# 如果在 try 块中发生任何异常,执行 except 块中的代码。
except Exception:
# 如果异常发生,调用 self.cache_labels 方法来创建新的缓存文件,并设置 exists 为 False 。
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
# 这段代码的目的是确保使用最新的、正确的缓存文件。如果缓存文件不存在或不匹配,它会触发缓存文件的创建过程。这是为了提高数据加载的效率,避免每次运行程序时都重复相同的计算。
# 这段代码是 LoadImagesAndLabels 类构造函数中的一部分,用于显示缓存文件的扫描结果。
# Display cache
# 从缓存字典 cache 中弹出键为 'results' 的项,该项包含五个值。 nf (找到的图像数量) 、 nm (缺失的背景数量) 、 ne (空背景数量) 、 nc (损坏的图像数量) 和 n (总图像数量) 。 cache.pop('results') 会从字典中移除这个键值对,并返回其值。
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
# 检查缓存文件是否存在( exists 为 True )以及当前进程的 LOCAL_RANK 是否为 -1 或 0 。 LOCAL_RANK 通常用于分布式训练,其中 -1 表示非分布式环境, 0 表示分布式环境中的主进程。
# LOCAL_RANK -> 使用 os.getenv 函数获取环境变量 LOCAL_RANK 的值,并将其转换为整数。如果环境变量未设置,则默认为 -1 。
if exists and LOCAL_RANK in {-1, 0}:
# 创建一个字符串 d ,用于描述扫描缓存的结果,包括 缓存文件路径 、 找到的图像数量 、 背景数量 和 损坏的图像数量 。
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" # 扫描 {cache_path}...{nf} 幅图像、{nm + ne} 幅背景、{nc} 幅损坏图像
# 使用 tqdm 库显示一个进度条,其中 None 表示没有迭代器, desc 是进度条的描述, total 是总进度, initial 是初始进度, bar_format 是进度条的格式。
# TQDM_BAR_FORMAT -> '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
# 检查缓存字典中是否包含 'msgs' 键,该键对应于一些警告消息。检查缓存字典中是否包含 'msgs' 键,该键对应于一些警告消息。
if cache['msgs']:
# 如果存在警告消息,则使用 LOGGER 将这些消息记录为信息日志。
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
# 使用 assert 语句确保如果数据增强未启用( not augment ),则不要求必须找到标签;如果启用了数据增强,则至少找到一个标签。如果没有找到标签且尝试进行数据增强,则抛出异常,提示没有找到标签,无法开始训练,并给出帮助链接。
# HELP_URL -> 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}' # {prefix}在 {cache_path} 中未找到标签,无法开始训练。{HELP_URL}。
# 这段代码的目的是向用户提供有关缓存文件内容的反馈,包括图像和标签的统计信息,并确保在尝试训练模型之前,至少有一些有效的标签数据可用。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于从缓存中读取数据并进行处理。
# Read cache
# 这是一个列表推导式,用于从缓存字典 cache 中移除键 'hash' 、 'version' 和 'msgs' 及其对应的值。 cache.pop(k) 会从字典中移除键 k 并返回其值,这个操作会改变原始的 cache 字典。
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
# 使用 zip 函数将 cache.values() 的结果(一个迭代器,包含所有值)解压成三个独立的迭代器。 *cache.values() 是一个星号表达式,它将 cache.values() 的结果解包成独立的参数传递给 zip 函数。 labels 存储所有标签数据, shapes 存储所有图像的形状数据, self.segments 存储所有分割数据。
labels, shapes, self.segments = zip(*cache.values())
# 使用 np.concatenate 函数将 labels 列表中的所有数组沿着第一个轴(0轴)连接成一个大数组。 len 函数计算连接后的数组的长度,即 总标签数量 。
nl = len(np.concatenate(labels, 0)) # number of labels
# 使用 assert 语句确保如果数据增强未启用( not augment ),则不要求必须有标签;如果启用了数据增强,则至少有一个标签。 如果所有标签都为空且尝试进行数据增强,则抛出异常,提示所有标签都为空,无法开始训练,并给出帮助链接。
# HELP_URL -> 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}' # {prefix}{cache_path} 中的所有标签都为空,无法开始训练。{HELP_URL}。
# 将 labels 列表转换为类的属性 self.labels 。
self.labels = list(labels)
# 将 shapes 列表转换为 NumPy 数组,并存储在 self.shapes 中。
self.shapes = np.array(shapes)
# 将缓存字典 cache 的键(即图像文件路径)转换为列表,并更新 self.im_files 属性。
self.im_files = list(cache.keys()) # update
# 调用 img2label_paths 函数将 图像文件路径 转换为对应的 标签文件路径 ,并更新 self.label_files 属性。
# def img2label_paths(img_paths): -> 将包含图像文件路径的列表转换为对应的标签文件路径列表。这个列表推导式返回一个新的列表,其中包含每个图像文件对应的标签文件路径。 -> return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
self.label_files = img2label_paths(cache.keys()) # update
# 这段代码的目的是处理缓存中的数据,确保数据的完整性,并将其存储在类的属性中,以便后续的训练和处理。通过检查标签的数量,它还确保了数据增强的可行性。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于根据 min_items 参数过滤图像和对应的标签。
# Filter images
# 这个条件检查是否设置了 min_items 参数(即每个图像的最小项目数,如标签数)。
if min_items:
# 创建一个布尔数组,其中每个元素对应于 self.labels 中的标签数组是否满足最小项目数 min_items 。
# len(x) >= min_items 检查每个标签数组 x 的长度是否大于或等于 min_items 。 np.array(...) 将布尔列表转换为 NumPy 数组。 .nonzero()[0] 获取布尔数组中为 True 的索引。 .astype(int) 将索引转换为整数类型。
include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
# 使用 LOGGER 记录信息,显示从数据集中过滤掉的图像数量。 n 是总图像数量, len(include) 是满足 min_items 条件的图像数量。 prefix 是日志消息前缀。
LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset') # 从数据集中过滤的 {prefix}{n - len(include)}/{n} 个图像。
# 更新 self.im_files 属性,只保留满足 min_items 条件的图像文件路径。
self.im_files = [self.im_files[i] for i in include]
# 更新 self.label_files 属性,只保留满足 min_items 条件的标签文件路径。
self.label_files = [self.label_files[i] for i in include]
# 更新 self.labels 属性,只保留满足 min_items 条件的标签数据。
self.labels = [self.labels[i] for i in include]
# 如果使用了分割数据,更新 self.segments 属性,只保留满足 min_items 条件的分割数据。
self.segments = [self.segments[i] for i in include]
# 更新 self.shapes 属性,只保留满足 min_items 条件的图像形状数据。
self.shapes = self.shapes[include] # wh
# 这段代码的目的是确保数据集中的每个图像都至少包含 min_items 个项目(通常是标签)。如果图像的项目数少于 min_items ,则该 图像 及其 对应的标签 、 分割数据 和 形状数据 将从数据集中移除。这有助于提高训练效率,特别是在处理不平衡数据集时。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于创建数据集的索引和批量索引。
# Create indices
# 计算 图像的总数 ,并将其存储在变量 n 中。
n = len(self.shapes) # number of images
# 使用 np.arange(n) 创建一个从 0 到 n-1 的整数序列,代表每个图像的索引。 将每个索引除以 batch_size ,计算每个图像属于哪个批次。 np.floor 函数向下取整,确保每个图像都被分配到正确的批次。 .astype(int) 将结果转换为整数类型。 bi 是一个数组,其中每个元素代表对应图像的批次索引。
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
# 通过取 bi 数组中的最大值并加 1,计算 总的批次数量 nb 。
nb = bi[-1] + 1 # number of batches
# 将批次索引数组 bi 存储为类的属性 self.batch 。
self.batch = bi # batch index of image
# 将图像总数 n 存储为类的属性 self.n 。
self.n = n
# 创建一个从 0 到 n-1 的整数序列,并将其存储为类的属性 self.indices 。这个序列代表所有图像的索引,用于随机访问或迭代。
self.indices = range(n)
# 这段代码的目的是为数据集创建必要的索引信息,以便在训练过程中能够根据批次索引和图像总数来组织和访问数据。这些索引对于实现数据的批量处理和迭代器协议至关重要。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于更新标签数据,包括过滤特定类别的标签和处理单类别训练的情况。
# Update labels
# 定义一个空列表 include_class ,用于存储需要包含的类别索引。如果这个列表不为空,只有这些类别的标签会被保留。
include_class = [] # filter labels to include only these classes (optional) 过滤标签以仅包含这些类别(可选)。
# 将 include_class 列表转换为 NumPy 数组,并将其重塑为一行多列的数组,以便进行广播比较。
include_class_array = np.array(include_class).reshape(1, -1)
# 使用 enumerate 函数和 zip 函数遍历 self.labels 和 self.segments ,同时获取索引 i 和对应的标签 label 和分割数据 segment 。
for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
# 如果 include_class 列表不为空,执行以下操作。
if include_class:
# 比较标签中的类别索引与 include_class_array ,返回一个布尔数组 j ,表示哪些标签属于指定的类别。
j = (label[:, 0:1] == include_class_array).any(1)
# 更新 self.labels 中的标签,只保留属于指定类别的标签。
self.labels[i] = label[j]
# 如果分割数据 segment 存在,同样更新分割数据,只保留属于指定类别的部分。
if segment:
self.segments[i] = segment[j]
# 如果设置了单类别训练( single_cls 为 True ),则执行以下操作。
if single_cls: # single-class training, merge all classes into 0
# 将所有标签的第一个元素(通常是类别索引)设置为 0,这样在训练时就只有一个类别。
self.labels[i][:, 0] = 0
# 这段代码的目的是确保标签数据符合训练要求,包括过滤特定类别的标签和适应单类别训练的场景。过滤特定类别的标签可以在多类别数据集中进行特定类别的训练,而单类别训练则将所有标签合并为一个类别,这在某些特定的训练场景中可能会用到。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于处理矩形训练(Rectangular Training),也就是根据图像的宽高比(aspect ratio)对图像进行排序。
# Rectangular Training
# 检查是否设置了矩形训练模式( self.rect 为 True )。
if self.rect:
# Sort by aspect ratio
# 获取图像的形状数据, s 是一个数组,其中每行包含两个值。 宽度( w )和高度( h )。
s = self.shapes # wh
# 计算每个图像的宽高比( ar ),即图像宽度除以高度。
ar = s[:, 1] / s[:, 0] # aspect ratio
# 使用 argsort() 方法对宽高比进行排序,并获取排序后的索引。
irect = ar.argsort()
# 根据排序后的索引 irect ,对图像文件路径列表 self.im_files 进行重新排序。
self.im_files = [self.im_files[i] for i in irect]
# 对标签文件路径列表 self.label_files 进行重新排序。
self.label_files = [self.label_files[i] for i in irect]
# 对标签数据列表 self.labels 进行重新排序。
self.labels = [self.labels[i] for i in irect]
# 如果使用了分割数据,对分割数据列表 self.segments 进行重新排序。
self.segments = [self.segments[i] for i in irect]
# 对图像形状数据 self.shapes 进行重新排序。
self.shapes = s[irect] # wh
# 对宽高比数组 ar 进行重新排序。
ar = ar[irect]
# 这段代码的目的是对数据集中的图像进行排序,使得具有相似宽高比的图像在训练时放在一起处理。这有助于提高训练效率,特别是在进行批量处理时,因为具有相似尺寸的图像可以更有效地进行批处理。此外,这也有助于保持训练过程中的一致性,因为相似尺寸的图像可能需要相似的处理方式。
# 这段代码继续处理 LoadImagesAndLabels 类的构造函数中的矩形训练设置,用于确定每个批次的训练图像形状。
# Set training image shapes
# 初始化一个形状列表 shapes ,其中每个元素都是 [1, 1] ,这个列表的长度等于批次数量 nb 。这个列表将存储每个批次的目标图像形状。
shapes = [[1, 1]] * nb
# 遍历每个批次, i 是批次索引。
for i in range(nb):
# 从宽高比数组 ar 中选择属于当前批次 i 的所有宽高比值。
ari = ar[bi == i]
# 计算当前批次中宽高比的最小值 mini 和最大值 maxi 。
mini, maxi = ari.min(), ari.max()
# 如果最大宽高比小于 1,说明该批次中的图像主要是宽度大于高度的肖像模式图像。
if maxi < 1:
# 设置该批次的目标形状为 [maxi, 1] ,意味着高度被调整为 1,宽度按比例缩小。
shapes[i] = [maxi, 1]
# 如果最小宽高比大于 1,说明该批次中的图像主要是高度大于宽度的landscape模式图像。
elif mini > 1:
# 设置该批次的目标形状为 [1, 1 / mini] ,意味着宽度被调整为 1,高度按比例缩小。
shapes[i] = [1, 1 / mini]
# 将 shapes 列表转换为 NumPy 数组,并计算每个批次的实际图像形状,考虑了 img_size 、 stride 和 pad 。
# np.array(shapes) * img_size 将形状按图像尺寸放大。
# / stride 应用步长,确保形状是步长的整数倍。
# + pad 添加填充。
# np.ceil(...) 对结果向上取整,确保形状在放大和填充后仍然是整数。
# .astype(int) 将结果转换为整数类型。
# 最后,结果乘以 stride ,确保最终形状符合模型的输入要求。
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
# 这段代码的目的是根据每个批次中图像的宽高比来调整图像的形状,以便在训练时使用。这种调整有助于优化网络的性能,特别是在处理不同宽高比的图像时。通过这种方式,可以确保每个批次的图像在送入网络之前都被适当地调整和对齐。
# 这段代码是 LoadImagesAndLabels 类构造函数的一部分,用于将图像缓存到内存(RAM)或磁盘上,以便加快训练速度。
# Cache images into RAM/disk for faster training
# 检查是否设置了将图像缓存到RAM ( cache_images == 'ram' ) 并且是否通过了缓存到RAM的检查( self.check_cache_ram 方法)。 如果没有通过检查,则不缓存图像到RAM, cache_images 设置为 False 。
if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
cache_images = False
# 初始化一个列表 self.ims ,用于存储加载到内存中的图像数据,长度为 n (图像总数)。
self.ims = [None] * n
# 为每个图像文件路径创建一个对应的 .npy 文件路径列表 self.npy_files ,这些文件将用于存储缓存的图像数据。
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
# 检查是否需要缓存图像,无论是到RAM还是磁盘。
if cache_images:
# 初始化变量 b 用于跟踪已缓存图像的字节数, gb 为每吉字节的字节数(1GB = 2^30 bytes)。
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
# 初始化两个列表 self.im_hw0 和 self.im_hw ,用于存储 原始 和 调整后 的图像尺寸,长度为 n 。
self.im_hw0, self.im_hw = [None] * n, [None] * n
# 根据缓存位置(RAM或磁盘),设置函数 fcn 为 self.cache_images_to_disk 或 self.load_image 。
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
# 使用线程池 ThreadPool 和 imap 函数并行地加载图像, NUM_THREADS 是线程数量, range(n) 是图像索引的范围。
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
# 创建一个 tqdm 进度条来显示图像加载的进度。
# TQDM_BAR_FORMAT -> '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
# LOCAL_RANK -> 使用 os.getenv 函数获取环境变量 LOCAL_RANK 的值,并将其转换为整数。如果环境变量未设置,则默认为 -1 。
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
# 遍历进度条中的每个项, i 是索引, x 是加载的图像数据。
for i, x in pbar:
# 如果缓存到磁盘,更新已缓存的字节数 b 。
if cache_images == 'disk':
# 将第 i 个 .npy 文件的大小加到 b 。
b += self.npy_files[i].stat().st_size
# 如果缓存到RAM,更新图像数据和已缓存的字节数 b 。
else: # 'ram'
# 将加载的图像数据存储到 self.ims 。从加载图像的函数返回值中解包数据,并将其存储在类的相应属性中。
# self.ims[i], self.im_hw0[i], self.im_hw[i] = x
# 这行代码假设 x 是一个包含三个元素的元组,这些元素分别代表 : self.ims[i] 第 i 个图像加载到内存中的数组形式。 self.im_hw0[i] 第 i 个图像的原始宽度和高度。 self.im_hw[i] 第 i 个图像加载后(可能经过调整大小)的宽度和高度。
# x
# x 是从 self.load_image 或 self.cache_images_to_disk 函数返回的结果,这些函数负责加载图像并进行一些预处理,如调整大小、归一化等。
# 解包操作 :
# self.ims[i] 存储第 i 个图像的数据,通常是一个 NumPy 数组。
# self.im_hw0[i] 存储第 i 个图像的原始尺寸,这在某些情况下(如需要恢复图像到原始尺寸)非常有用。
# self.im_hw[i] 存储第 i 个图像加载后的尺寸,这与 self.im_hw0[i] 不同,如果图像在加载过程中被调整了大小。
# 这种解包操作使得代码更加简洁,并且能够清晰地将函数返回的多个值分配给类的多个属性。这样做的好处是可以在后续的处理中方便地访问这些数据,而不需要再次查询或计算。
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
# 将第 i 个图像的大小加到 b 。
b += self.ims[i].nbytes
# 更新进度条的描述,显示已缓存的图像大小和缓存位置。
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})' # {prefix}缓存图像 ({b / gb:.1f}GB {cache_images})。
# 关闭进度条。
pbar.close()
# 这段代码的目的是提高数据加载效率,通过将图像预先加载到内存或磁盘上,减少训练过程中的I/O操作。这对于大型数据集或训练速度较慢的环境尤其有用。
# 这个类提供了一个灵活的方式来加载和处理图像和标签数据,支持多种数据增强和缓存策略,适用于深度学习模型的训练。
# 这段代码定义了一个名为 check_cache_ram 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法用于检查是否有足够的可用内存来将整个数据集缓存到 RAM 中。
# 这是 check_cache_ram 方法的定义,它接受两个参数。
# 1.safety_margin :安全边际比例,默认为 0.1。
# 2.prefix :用于日志消息的前缀,默认为空字符串。
def check_cache_ram(self, safety_margin=0.1, prefix=''):
# Check image caching requirements vs available memory 检查图片缓存要求与可用内存。
# 初始化变量 b 用于跟踪缓存图像的总字节数, gb 为每吉字节的字节数(1GB = 2^30 bytes)。
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes 缓存图像的字节数(每 GB 的字节数)。
# 从数据集中随机选择 30 个图像(或者数据集大小和 30 中的较小值)来估算整个数据集的内存需求。
n = min(self.n, 30) # extrapolate from 30 random images 根据 30 张随机图像进行推断。
# 遍历选定的图像数量。
for _ in range(n):
# random.choice(sequence)
# random.choice() 是 Python 标准库 random 模块中的一个函数,用于从序列中随机选择一个元素。
# 参数 :
# sequence : 一个序列,比如列表、元组或字符串,从中随机选择一个元素。
# 返回值 :
# random.choice() 函数返回序列中的一个随机元素。
# 随机选择一个图像文件并使用 OpenCV 的 imread 函数加载它。
im = cv2.imread(random.choice(self.im_files)) # sample image
# 计算加载图像后的目标尺寸与原始尺寸的比例。
ratio = self.img_size / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
# 更新总字节数 b ,考虑到图像尺寸调整后的内存需求。
# 这行代码出现在 check_cache_ram 方法中,用于估算缓存整个数据集到 RAM 所需的内存量。
# b += im.nbytes * ratio ** 2 :
# b 是一个变量,用于累计估算的缓存图像所需的总字节数。
# im.nbytes 是一个属性,表示当前图像 im 的原始字节数。
# ratio 是一个变量,表示图像调整大小后与原始大小的比例,用于估算调整大小后的图像所需的内存。
# ratio ** 2 是将比例平方,因为图像的面积(宽度乘以高度)会按照比例的平方放大或缩小。
# 这行代码通过将原始图像的字节数乘以比例的平方,来估算调整大小后的图像所需的内存量,并累加到 b 中。
# 例如,如果原始图像大小为 1000x1000 像素,调整目标大小为 640x640 像素,则比例为 640 / 1000 = 0.64 ,比例的平方为 0.64 ** 2 = 0.4096 。如果原始图像大小为 3MB(大约 3 * 1024 * 1024 字节),则调整大小后的图像大小大约为 3 * 1024 * 1024 * 0.4096 字节。
# 该代码的目的是通过估算调整大小后的图像所需的内存量,来预测缓存整个数据集到 RAM 所需的内存量,从而决定是否有足够的内存来缓存数据集。
# 这种方法考虑了图像调整大小对内存需求的影响,提供了一个更准确的内存需求估算。
b += im.nbytes * ratio ** 2
# 根据样本图像计算整个数据集缓存到 RAM 中所需的内存。
mem_required = b * self.n / n # GB required to cache dataset into RAM 将数据集缓存到 RAM 中所需的 GB 。
# psutil.virtual_memory()
# psutil.virtual_memory() 是一个函数,属于 psutil 库,用于获取系统虚拟内存(RAM)的使用情况。
# 参数 :无参数。
# 返回值 :
# 该函数返回一个命名元组( psutil._common.smem ),其中包含了以下属性 :
# total :总物理内存大小,单位为字节。
# available :可供分配的内存大小,单位为字节,这个值是系统认为可用的内存,包括缓存和缓冲区占用的内存。
# percent :已使用内存的百分比。
# used :已使用的内存大小,单位为字节。
# free :空闲的内存大小,单位为字节。
# active :当前正在使用或最近使用的内存,单位为字节。
# inactive :标记为未使用的内存,单位为字节。
# buffers :缓存数据,如文件系统元数据,单位为字节。
# cached :缓存数据,单位为字节。
# shared :可由多个进程共享的内存,单位为字节。
# slab :用于内核数据结构的内存,单位为字节。
# 使用 psutil 库获取系统虚拟内存(RAM)的信息。
mem = psutil.virtual_memory()
# 判断是否有足够的可用内存来缓存数据集,考虑到安全边际。
cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question 缓存还是不缓存,这是个问题。
# 如果没有足够的内存来缓存数据集,记录一条日志信息。
if not cache:
# 使用 LOGGER 记录当前的内存需求、可用内存和是否缓存图像的决定。
LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, " # 需要 {prefix}{mem_required / gb:.1f}GB RAM,
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, " # {mem.available / gb:.1f}/{mem.total / gb:.1f}可用GB数,
f"{'caching images ✅' if cache else 'not caching images ⚠️'}") # {'缓存图像✅' if cache else '不缓存图像⚠️'}。
# 返回一个布尔值,指示是否有足够的内存来缓存数据集。
return cache
# 这个方法的目的是确保在尝试将数据集缓存到 RAM 之前,系统有足够的可用内存,以避免内存不足导致的问题。通过估算数据集的内存需求并与系统的可用内存进行比较,这个方法提供了一个预防措施,确保数据加载过程的顺利进行。
# 这段代码定义了一个名为 cache_labels 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法用于缓存数据集的标签,检查图像,并读取图像的形状。
# 这是 cache_labels 方法的定义,它接受两个参数。
# path :缓存文件的路径,默认为当前目录下的 labels.cache 。
# prefix :日志消息的前缀,默认为空字符串。
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes 缓存数据集标签、检查图像并读取形状。
# 初始化一个空字典 x ,用于存储缓存的数据。
x = {} # dict
# 初始化计数器和消息列表,用于跟踪 缺失 、 找到 、 空 、 损坏 的标签数量以及 警告消息 。
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
# 创建一个描述字符串 desc ,用于 tqdm 进度条的描述。
desc = f"{prefix}Scanning {path.parent / path.stem}..." # {prefix} 正在扫描 {path.parent / path.stem}...
# 使用 multiprocessing.Pool 创建一个进程池, NUM_THREADS 是进程数量。
with Pool(NUM_THREADS) as pool:
# pool.imap(func, iterable, chunksize=None)
# imap() 方法用于将一个可迭代的输入序列分块分配到线程池中的线程进行处理,并将结果返回一个迭代器。这个方法特别适用于需要顺序处理输入和输出的场景。
# 参数 :
# func :一个函数,它将被调用并传入 iterable 中的每个项目。
# iterable :一个可迭代对象,其元素将被传递给 func 函数。
# chunksize :(可选)一个整数,指定了每个任务传递给 func 的项目数量。默认值为 1,意味着每个任务只包含一个项目。如果设置为大于 1 的值,那么 func 将接收到一个包含多个项目的列表。
# 返回值 :
# 返回一个 Iterator ,它生成每个输入元素经过 func 处理后的结果。
# 特点 :
# 结果的顺序与输入序列的顺序相同。
# 如果任何一个任务因为异常而终止, imap() 会立即抛出异常。
# 它允许主线程在子线程完成工作之前继续执行,而不是等待所有任务完成。
# ThreadPool.imap() 是处理 I/O 密集型任务或者需要顺序处理结果的并发任务的有用工具。与之相对的是 imap_unordered() ,它同样返回一个迭代器,但是结果的顺序可能与输入序列不同,适用于不在乎结果顺序的场景。
# 使用 tqdm 创建一个进度条 pbar ,显示 verify_image_label 函数的执行进度。
# pool.imap 将 verify_image_label 函数应用于 self.im_files 、 self.label_files 和 prefix 的组合。
# zip(self.im_files, self.label_files, repeat(prefix)) 将图像文件路径、标签文件路径和前缀打包成一个元组列表。
pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
desc=desc,
total=len(self.im_files),
bar_format=TQDM_BAR_FORMAT)
# 这段代码是 cache_labels 方法中的一部分,它遍历由 tqdm 进度条 pbar 提供的迭代器,处理每个图像文件及其对应的标签和元数据。
# 这行代码遍历 pbar 进度条,它是由 pool.imap 生成的迭代器。
# im_file :是当前处理的图像文件路径。
# lb :是对应的标签数据。
# shape :是图像的形状(尺寸)。
# segments :是分割数据(如果有)。
# nm_f , nf_f , ne_f , nc_f :是计数器,分别表示 缺失 ( nm_f ) 、找到 ( nf_f ) 、 空 ( ne_f ) 、 损坏 ( nc_f ) 的标签数量。
# msg :是可能的警告或错误消息。
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
# 将当前迭代的 缺失标签计数 加到总缺失标签计数 nm 上。
nm += nm_f
# 将当前迭代的 找到标签计数 加到总找到标签计数 nf 上。
nf += nf_f
# 将当前迭代的 空标签计数 加到总空标签计数 ne 上。
ne += ne_f
# 将当前迭代的 损坏标签计数 加到总损坏标签计数 nc 上。
nc += nc_f
# 检查 im_file 是否存在(非空)。
if im_file:
# 如果 im_file 存在,将图像文件路径作为键,其对应的标签、形状和分割数据作为值,存储在字典 x 中。
x[im_file] = [lb, shape, segments]
# 检查是否有警告或错误消息。
if msg:
# 如果有消息,将其添加到消息列表 msgs 中。
msgs.append(msg)
# 更新进度条的描述,显示当前 找到的图像数量 ( nf ) 、 背景数量 ( nm + ne ) 和 损坏的图像数量 ( nc ) 。
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" # {desc} {nf} 图像,{nm + ne} 背景,{nc} 损坏。
# 这段代码的目的是收集和更新处理过程中的关键信息,包括标签和图像的统计数据,以及任何警告或错误消息。这些信息对于监控数据处理过程和后续的调试非常有用。通过更新进度条描述,用户可以实时了解数据处理的进度和状态。
# 关闭进度条。
pbar.close()
# 如果有警告消息,使用 LOGGER 记录它们。
if msgs:
LOGGER.info('\n'.join(msgs))
# 如果没有找到标签,记录一条警告日志。
if nf == 0:
LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}') # {prefix}警告 ⚠️ 在 {path} 中未找到标签。{HELP_URL}。
# 计算标签文件和图像文件的哈希值,并将其存储在 x 字典中。
# # def get_hash(paths): -> 用于计算一个包含文件或目录路径列表的哈希值。计算最终的哈希值,并以十六进制字符串的形式返回。 -> return h.hexdigest() # return hash
x['hash'] = get_hash(self.label_files + self.im_files)
# 将统计结果存储在 x 字典中。
x['results'] = nf, nm, ne, nc, len(self.im_files)
# 将警告消息列表存储在 x 字典中。
x['msgs'] = msgs # warnings
# 将缓存版本号存储在 x 字典中。
x['version'] = self.cache_version # cache version
# 尝试将 x 字典保存为缓存文件。
try:
# 使用 numpy.save 函数保存缓存文件。
np.save(path, x) # save cache for next time
# 重命名缓存文件,移除 .npy 后缀。
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
# 记录一条日志。
LOGGER.info(f'{prefix}New cache created: {path}') # {prefix} 创建新缓存:{path}。
# 如果保存缓存文件时发生异常,记录一条警告日志。
except Exception as e:
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable {prefix}警告 ⚠️ 缓存目录 {path.parent} 不可写:{e}。
# 返回缓存的数据。
return x
# 这个方法的目的是创建一个缓存文件,存储数据集的标签和图像信息,以便在下次加载数据集时可以快速访问。通过并行处理和进度条显示,这个方法提供了一个高效且用户友好的方式来缓存大型数据集。
# 这段代码定义了一个名为 __len__ 的特殊方法,它是 Python 类中的一个内置方法,用于返回对象的长度。在 LoadImagesAndLabels 类的上下文中,这个方法返回数据集中图像文件的数量。
# 1.self : 指代类的实例本身。
# 2.self.im_files : 类的一个属性,是一个包含图像文件路径的列表。
def __len__(self):
# 返回值。 __len__ 方法返回 self.im_files 列表的长度,即数据集中图像文件的数量。
return len(self.im_files)
# __len__ 方法的目的是让类的实例能够与 Python 的内置 len() 函数一起使用,这样用户就可以方便地获取数据集的大小,而不需要直接访问类的内部属性。这使得代码更加简洁和易于维护。
# def __iter__(self):
# self.count = -1
# print('ran dataset iter')
# #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
# return self
# 这段代码定义了一个名为 __getitem__ 的特殊方法,它是 Python 类中的一个内置方法,用于实现数据集对象的索引操作。在 LoadImagesAndLabels 类的上下文中,这个方法负责根据给定的索引 index 加载和处理一个图像及其对应的标签,然后返回处理后的图像和标签数据。
# 定义 __getitem__ 方法,它允许对象通过索引访问,类似于列表或数组。
# 1.self :代表类的实例。
# 2.index :是访问的数据项索引。
def __getitem__(self, index):
# 这段代码是 LoadImagesAndLabels 类的 __getitem__ 方法的一部分,它负责根据给定的索引 index 加载和处理一个图像及其对应的标签,然后返回处理后的图像和标签数据。
# 获取实际的索引。如果 self.indices 是一个简单的范围列表,则 index 直接对应于图像文件列表中的位置。如果 self.indices 被随机打乱或根据图像权重调整,则 index 将被映射到相应的位置。
index = self.indices[index] # linear, shuffled, or image_weights
# 获取类的 hyp 属性,它包含了超参数,这些参数控制训练过程中的各种设置。
hyp = self.hyp
# random_number = random.random()
# random.random() 是 Python 标准库 random 模块中的一个函数,用于生成一个随机浮点数。
# 参数 :这个函数不需要任何参数。
# 返回值 :
# random.random() 函数返回一个随机浮点数,这个数在 0.0(包括)和 1.0(不包括)之间 (0.0 <= N < 1.0)。
# 检查是否进行镶嵌(mosaic)增强。 self.mosaic 是一个布尔值,指示是否启用镶嵌增强。 hyp['mosaic'] 是一个概率值,表示选择镶嵌增强的概率。
mosaic = self.mosaic and random.random() < hyp['mosaic']
# 如果条件满足,执行以下代码块。
if mosaic:
# Load mosaic
# 调用 self.load_mosaic 方法,根据索引加载一个镶嵌图像和对应的标签。镶嵌增强是一种数据增强技术,通过将多个图像拼接成一个图像来增加数据多样性。
img, labels = self.load_mosaic(index)
# 初始化 shapes 变量为 None 。在镶嵌增强的情况下,原始图像的形状可能不再适用,因此将其设置为 None 。
shapes = None
# MixUp augmentation
# 检查是否进行 MixUp 增强。 random.random() 生成一个随机数,如果这个数小于 hyp['mixup'] 指定的概率,则执行 MixUp 增强。
if random.random() < hyp['mixup']:
# 如果进行 MixUp 增强,调用 mixup 函数。这个函数将当前图像和标签与另一个随机选择的镶嵌图像及其标签混合。 random.randint(0, self.n - 1) 用于生成一个随机索引, self.load_mosaic 根据这个索引加载另一个镶嵌图像和标签,然后这些数据被传递给 mixup 函数。
# def mixup(im, labels, im2, labels2): -> 实现了一种称为 MixUp 的数据增强技术。返回混合后的图像 im 和新的标签数组 labels 。 -> return im, labels
img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
# 这段代码的目的是实现两种高级数据增强技术:镶嵌(mosaic)和 MixUp。这些技术可以增加数据集的多样性,提高模型的泛化能力。通过随机选择和组合图像,这些方法可以帮助模型学习更鲁棒的特征。
# 这段代码是 LoadImagesAndLabels 类的 __getitem__ 方法的继续部分,它处理非镶嵌(mosaic)情况下的图像加载和增强。
# 这个 else 块与之前的 if mosaic: 块相对应,表示如果不进行镶嵌增强,则执行这里的代码。
else:
# Load image 说明下面的代码块用于加载单个图像。
# 调用 self.load_image 方法,根据索引加载图像,并返回 像数组 img ,以及 图像的原始尺寸 (h0, w0) 和 加载后的尺寸 (h, w) 。
img, (h0, w0), (h, w) = self.load_image(index)
# Letterbox 说明下面的代码块用于对图像进行 Letterbox 调整。
# 根据是否进行矩形训练( self.rect ),确定最终的 Letterbox 形状。如果是矩形训练,使用 self.batch_shapes 中对应的形状;否则,使用 self.img_size 。
shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
# 调用 letterbox 函数,对图像进行 Letterbox 调整,保持原始比例不变。 auto=False 表示不自动调整, scaleup=self.augment 表示在数据增强时允许放大图像。
# def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# -> 用于将输入图像 im 调整大小并填充,以适应新的尺寸 new_shape ,同时保持图像的宽高比,并确保结果图像的尺寸是 stride 的倍数。返回 填充后的图像 、 缩放比例 和 填充尺寸 。
# -> return im, ratio, (dw, dh)
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
# 存储 原始尺寸 和 调整尺寸的比例 以及 填充信息 ,以便后续可能需要的 COCO mAP 重新缩放。
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
# 复制 当前索引 对应的 标签数据 。
labels = self.labels[index].copy()
# 检查 标签数组 是否非空。
if labels.size: # normalized xywh to pixel xyxy format
# 如果标签非空,将归一化的 xywh 格式标签转换为像素 xyxy 格式,并根据比例和填充调整标签坐标。
# def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): -> 用于将边界框的坐标从中心点坐标加宽高( x, y, w, h )的格式转换为左上角和右下角坐标( x1, y1, x2, y2 )的格式。返回转换后的边界框坐标数组 y ,格式为 [x1, y1, x2, y2] 。 -> return y
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
# 检查是否进行数据增强。
if self.augment:
# 如果进行数据增强,调用 random_perspective 函数,对图像和标签进行随机透视变换。这个函数接受多个参数,包括 degrees 、 translate 、 scale 、 shear 和 perspective ,这些参数控制透视变换的程度。
# def random_perspective(im,targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
# -> 用于对图像进行随机透视变换,同时对图像中的标签(如边界框)进行相应的变换。返回结果。这行代码返回变换后的图像 im 和更新后的目标 targets 。
# -> return im, targets
img, labels = random_perspective(img,
labels,
degrees=hyp['degrees'],
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'],
perspective=hyp['perspective'])
# 这段代码展示了在非镶嵌情况下,如何加载图像、进行 Letterbox 调整、转换标签格式以及应用随机透视变换等数据增强技术。这些步骤有助于提高模型的泛化能力,通过创建新的、多样化的训练样本来模拟训练数据集中不存在的情况。
# 这段代码继续处理 LoadImagesAndLabels 类的 __getitem__ 方法中的图像和标签数据,包括标签格式转换、数据增强以及各种图像变换。
# 获取标签数组 labels 的长度,即标签的数量,并存储在变量 nl 中。
nl = len(labels) # number of labels
# 检查是否有标签(即 nl 是否大于0)。
if nl:
# 如果有标签,将标签从 xyxy 格式(四个角点的坐标)转换为 xywh 格式(中心点坐标和宽高),并进行归一化处理。 w 和 h 分别是图像的宽度和高度, clip=True 表示将超出图像边界的坐标限制在边界内, eps=1E-3 用于处理数值稳定性问题。
# def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): -> 将边界框的坐标从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式,并且将坐标归一化到 [0, 1] 范围内。返回转换后的边界框坐标数组。 -> return y
labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
# 检查是否需要进行数据增强。
if self.augment:
# Albumentations
# 如果需要增强,使用 albumentations 库对图像和标签进行增强。这是一个强大的图像增强库,可以进行各种复杂的变换。
img, labels = self.albumentations(img, labels)
# 在应用 albumentations 后更新标签数量。
nl = len(labels) # update after albumentations
# HSV color-space
# 对图像进行 HSV(色调、饱和度、亮度)空间的增强,调整色调、饱和度和亮度的值。
# def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): -> 用于在HSV颜色空间中对图像进行增强。
augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
# Flip up-down
# 以一定的概率(由 hyp['flipud'] 控制)进行上下翻转(up-down flip)。
if random.random() < hyp['flipud']:
# 实际执行上下翻转操作。
img = np.flipud(img)
# 检查是否有标签。
if nl:
# 如果有标签,更新标签的 y 坐标,以反映上下翻转的影响。
labels[:, 2] = 1 - labels[:, 2]
# Flip left-right
# 以一定的概率(由 hyp['fliplr'] 控制)进行左右翻转(left-right flip)。
if random.random() < hyp['fliplr']:
# 实际执行左右翻转操作。
img = np.fliplr(img)
# 检查是否有标签。
if nl:
# 如果有标签,更新标签的 x 坐标,以反映左右翻转的影响。
labels[:, 1] = 1 - labels[:, 1]
# Cutouts
# 这是一个被注释掉的代码行,表示可以在这里添加 CutOut 数据增强,这是一种通过在图像中随机遮盖部分区域来增加模型鲁棒性的方法。
# labels = cutout(img, labels, p=0.5)
# 如果使用了 CutOut 增强,需要更新标签数量。
# nl = len(labels) # update after cutout
# 这段代码展示了在深度学习训练中如何对图像和标签进行预处理和数据增强,以提高模型的性能和泛化能力。通过应用各种图像变换和增强技术,可以模拟训练数据集中不存在的情况,使模型更加健壮。
# 这段代码是 LoadImagesAndLabels 类的 __getitem__ 方法的最后部分,它负责将处理后的图像和标签数据转换为 PyTorch 张量,并返回这些数据。
# 创建一个形状为 (nl, 6) 的零张量 labels_out ,其中 nl 是标签的数量。这个张量用于存储转换后的标签数据。
labels_out = torch.zeros((nl, 6))
# 检查是否有标签(即 nl 是否大于0)。
if nl:
# 如果有标签,将 NumPy 数组格式的 labels 转换为 PyTorch 张量,并存储在 labels_out 的第 2 列及以后的列中。这里假设 labels_out 的第 1 列(索引为 0)用于存储类别标签,而后面的列用于存储其他信息(如边界框坐标)。
labels_out[:, 1:] = torch.from_numpy(labels)
# Convert
# 将图像 img 从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式,以适应 PyTorch 的要求。 [::-1] 用于将 BGR 格式的图像转换为 RGB 格式。
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
# 确保 NumPy 数组 img 在内存中是连续的,这对于后续的 PyTorch 操作是必要的。
img = np.ascontiguousarray(img)
# 将 NumPy 数组 img 转换为 PyTorch 张量,并返回这个 张量 、 标签张量 labels_out 、 当前图像的文件路径 self.im_files[index] ,以及可能的 形状信息 shapes 。
return torch.from_numpy(img), labels_out, self.im_files[index], shapes
# 这段代码的目的是将图像和标签数据转换为 PyTorch 张量,以便可以在 PyTorch 框架中用于训练和推理。通过这种方式,数据可以被有效地传递给神经网络模型,并且可以利用 PyTorch 的 GPU 加速功能。
# 这个方法是数据加载过程中的核心,它确保了图像和标签的加载、预处理和增强操作能够根据索引进行。通过这种方式,深度学习模型可以在训练时逐个或批量地访问数据集中的图像和标签。
# 这段代码定义了一个名为 load_image 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法负责根据数据集的索引 i 加载一个图像,并返回图像本身以及它的原始尺寸和调整后的尺寸。
# 这是 load_image 方法的定义,它接受两个参数。
# 1.self :类的实例本身。
# 2.i :图像在数据集中的索引。
def load_image(self, i):
# Loads 1 image from dataset index 'i', returns (im, original hw, resized hw) 从数据集索引"i"加载 1 个图像,返回 (im, 原始 hw, 调整大小的 hw)。
# 从类的属性中获取第 i 个图像的 缓存数据 im , 图像文件路径 f ,以及对应的 .npy 文件路径 fn 。
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
# 检查缓存的图像数据 im 是否为 None ,即是否已经在 RAM 中缓存。
if im is None: # not cached in RAM
# 如果 .npy 文件存在,则加载该文件中的图像数据。
if fn.exists(): # load npy
# 使用 NumPy 的 load 函数加载 .npy 文件中的图像数据。
im = np.load(fn)
# 如果 .npy 文件不存在,使用 OpenCV 的 imread 函数从图像文件路径 f 读取图像。
else: # read image
# 读取图像文件,注意 OpenCV 默认以 BGR 格式读取图像。
im = cv2.imread(f) # BGR
# 确保图像数据 im 不为 None ,否则抛出异常,表示图像文件未找到。
assert im is not None, f'Image Not Found {f}' # 未找到图片 {f}。
# 获取原始图像的高 h0 和宽 w0 。
h0, w0 = im.shape[:2] # orig hw
# 计算调整图像大小的比例 r ,以确保图像的最大边不超过 self.img_size 。
r = self.img_size / max(h0, w0) # ratio
# 如果比例 r 不等于 1,即图像大小需要调整。
if r != 1: # if sizes are not equal
# 根据是否进行数据增强或比例 r 是否大于 1,选择插值方法。
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
# 使用 OpenCV 的 resize 函数调整图像大小。
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
# 返回调整后的图像 im ,原始尺寸 (h0, w0) ,以及 调整后的尺寸 。
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
# 如果图像已经在 RAM 中缓存,则直接返回缓存的图像数据和尺寸。(图像、原始尺寸、调整后的尺寸)
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
# 这个方法的目的是确保图像数据以正确的尺寸和格式加载,无论是从缓存中直接获取还是从磁盘加载并调整大小。这使得数据集的图像可以被神经网络模型有效地处理。
# 这段代码定义了一个名为 cache_images_to_disk 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法负责将图像数据保存到磁盘上,以 .npy 文件格式存储,以便后续能够更快地加载这些图像。
# 这是 cache_images_to_disk 方法的定义,它接受两个参数。
# 1.self :类的实例本身。
# 2.i :图像在数据集中的索引。
def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading
# 从类的属性 self.npy_files 中获取第 i 个图像对应的 .npy 文件路径。
f = self.npy_files[i]
# 检查对应的 .npy 文件是否已经存在于磁盘上。
if not f.exists():
# 如果文件不存在,则执行以下操作 :
# 使用 OpenCV 的 imread 函数从图像文件路径 self.im_files[i] 读取图像数据。
# 使用 NumPy 的 save 函数将读取的图像数据保存到 .npy 文件中。 f.as_posix() 将路径转换为字符串(如果它还不是字符串)。
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
# 这个方法的目的是在第一次加载图像时将其缓存到磁盘上,这样在后续的训练过程中,图像可以直接从磁盘上的 .npy 文件中快速加载,而不需要每次都从原始图像文件中重新读取和处理。这可以显著提高数据加载的效率,特别是在处理大型数据集时。
# 这段代码定义了一个名为 load_mosaic 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法负责创建一个由四张图像拼接而成的镶嵌图像(mosaic),并处理相应的标签。
# 这是 load_mosaic 方法的定义,它接受两个参数。
# 1.self :类的实例本身。
# 2.index :用于选择四张图像中的第一张。
def load_mosaic(self, index):
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic YOLOv5 4 马赛克加载器。将 1 幅图像 + 3 幅随机图像加载到 4 幅图像马赛克中。
# 初始化两个空列表,用于存储 四张图像的 标签 和 分割数据 。
labels4, segments4 = [], []
# 获取图像的目标尺寸。
s = self.img_size
# 随机生成镶嵌图像的中心坐标 xc 和 yc 。
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
# 选择四张图像的索引,包括传入的 index 和随机选择的三个额外索引。
indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
# 随机打乱索引,以确保镶嵌的随机性。
random.shuffle(indices)
# 遍历索引,加载每张图像,并将其放置在镶嵌图像的正确位置。
for i, index in enumerate(indices):
# Load image
# 加载每张图像,并获取其原始尺寸和调整后的尺寸。
img, _, (h, w) = self.load_image(index)
# place img in img4
if i == 0: # top left
# 创建一个基础镶嵌图像,填充值为 114(通常是背景色)。
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
# 计算大图像的坐标范围。
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
# 计算小图像的坐标范围。
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
# 将小图像放置在大图像的正确位置。
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
# 计算水平填充。
padw = x1a - x1b
# 计算垂直填充。
padh = y1a - y1b
# 这段代码处理从数据集中获取的图像对应的标签和分割数据,并将它们从归一化坐标转换为像素坐标。
# Labels
# 获取索引 index 对应的 标签 和 分割数据 ,并复制它们。这是为了防止在原始数据上进行修改。
labels, segments = self.labels[index].copy(), self.segments[index].copy()
# 检查标签数组 labels 是否非空。
if labels.size:
# 如果标签非空,将标签中的边界框坐标从归一化的 xywh 格式(中心点坐标加宽高)转换为像素级的 xyxy 格式(两个对角点坐标)。这里 w 和 h 是调整后的图像宽度和高度, padw 和 padh 是水平和垂直填充。
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
# 对于每个分割数据 x ,将其从归一化的 xyn 格式(中心点坐标加宽高)转换为像素级的 xy 格式(两个端点坐标)。这里同样使用 w 、 h 、 padw 和 padh 进行转换。
segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
# 将转换后的标签 labels 添加到 labels4 列表中,这个列表用于存储所有四张图像的标签。
labels4.append(labels)
# 将转换后的分割数据 segments 添加到 segments4 列表中,这个列表用于存储所有四张图像的分割数据。
segments4.extend(segments)
# 这段代码的目的是将标签和分割数据从归一化坐标转换为像素坐标,以便它们可以被正确地应用到实际的图像尺寸上。这对于目标检测和图像分割等任务至关重要,因为模型需要知道边界框和分割的确切像素位置。通过这种方式,可以确保在图像增强和变换过程中标签和分割数据的准确性。
# 这段代码继续处理标签数据,将它们合并并进行裁剪,以确保标签坐标在图像的有效范围内。
# Concat/clip labels
# 使用 NumPy 的 concatenate 函数将 labels4 列表中的所有标签数组沿着第一个维度(0维,即行)合并成一个大的数组。这通常是在处理多张图像的标签时,将它们堆叠成一个数组以便于批量处理。
labels4 = np.concatenate(labels4, 0)
# 遍历 labels4 中除了第一列之外的所有列(即所有边界框坐标)以及 segments4 中的所有分割数据。 labels4[:, 1:] 表示 labels4 数组中除了第一列(通常是类别标签)之外的所有列, *segments4 是将 segments4 列表解包为独立的参数。
for x in (labels4[:, 1:], *segments4):
# 对于每个 x (边界框坐标或分割数据),使用 NumPy 的 clip 函数将其中的值限制在 [0, 2 * s] 的范围内。 s 是图像的尺寸, 2 * s 表示图像尺寸的两倍,这通常用于处理随机透视变换时坐标可能超出图像边界的情况。 out=x 表示在原地修改 x 。
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
# 这是一个被注释掉的代码行,如果取消注释,它将执行一个复制粘贴增强操作。这个操作会随机选择图像的一部分并将其复制到图像的另一个位置,同时更新对应的标签数据。这种增强技术可以增加数据的多样性并提高模型的泛化能力。
# img4, labels4 = replicate(img4, labels4) # replicate
# 这段代码的目的是确保所有的标签和分割数据在进行数据增强和变换后仍然保持有效,即它们的坐标不会超出图像的边界。这对于训练一个鲁棒的模型非常重要,因为模型需要能够处理各种边界情况。
# Augment
# 执行复制粘贴增强。
# def copy_paste(im, labels, segments, p=0.5): -> 用于在图像增强过程中执行复制粘贴操作。这个函数从原始图像中随机选择对象,将其复制到新的位置,并更新相应的标签和线段。返回经过复制粘贴操作后的 图像 、更新后的 标签 和 线段 。 -> return im, labels, segments
img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
# 执行随机透视变换。
# def random_perspective(im,targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
# -> 用于对图像进行随机透视变换,同时对图像中的标签(如边界框)进行相应的变换。返回结果。这行代码返回变换后的图像 im 和更新后的目标 targets 。
# -> return im, targets
img4, labels4 = random_perspective(img4,
labels4,
segments4,
degrees=self.hyp['degrees'],
translate=self.hyp['translate'],
scale=self.hyp['scale'],
shear=self.hyp['shear'],
perspective=self.hyp['perspective'],
border=self.mosaic_border) # border to remove
# 返回镶嵌图像和对应的标签。
return img4, labels4
# 这个方法的目的是创建一个由四张图像拼接而成的镶嵌图像,并处理相应的标签,以增加数据的多样性并提高模型的泛化能力。通过这种方式,可以在训练过程中模拟不同的场景和图像组合,从而使模型更加健壮。
# 这段代码定义了一个名为 load_mosaic9 的方法,它是 LoadImagesAndLabels 类的一部分。这个方法负责创建一个由九张图像拼接而成的镶嵌图像(mosaic),并处理相应的标签和分割数据。
# 这是 load_mosaic9 方法的定义,它接受两个参数。
# 1.self :类的实例本身。
# 2.index :用于选择九张图像中的第一张。
def load_mosaic9(self, index):
# YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
# 初始化两个空列表,用于存储九张图像的 标签 和 分割数据 。
labels9, segments9 = [], []
# 获取图像的目标尺寸。
s = self.img_size
# 选择九张图像的索引,包括传入的 index 和随机选择的八个额外索引。
indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
# 随机打乱索引,以确保镶嵌的随机性。
random.shuffle(indices)
# 初始化前一张图像的 高度 和 宽度 为 -1。
hp, wp = -1, -1 # height, width previous
# 遍历索引,加载每张图像,并将其放置在镶嵌图像的正确位置。
for i, index in enumerate(indices):
# Load image
# 加载每张图像,并获取其原始尺寸和调整后的尺寸。
img, _, (h, w) = self.load_image(index)
# place img in img9
if i == 0: # center
# 创建一个基础镶嵌图像,填充值为 114(通常是背景色)。
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
# 存储当前图像的高度和宽度。
h0, w0 = h, w
# 计算基础图像的坐标范围。
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
elif i == 1: # top
c = s, s - h, s + w, s
elif i == 2: # top right
c = s + wp, s - h, s + wp + w, s
elif i == 3: # right
c = s + w0, s, s + w0 + w, s + h
elif i == 4: # bottom right
c = s + w0, s + hp, s + w0 + w, s + hp + h
elif i == 5: # bottom
c = s + w0 - w, s + h0, s + w0, s + h0 + h
elif i == 6: # bottom left
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
elif i == 7: # left
c = s - w, s + h0 - h, s, s + h0
elif i == 8: # top left
c = s - w, s + h0 - hp - h, s, s + h0 - hp
# 计算当前镶嵌区域的水平和垂直填充。
padx, pady = c[:2]
# 计算分配的坐标范围,确保它们非负。
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# Labels
# 复制当前索引的标签和分割数据。
labels, segments = self.labels[index].copy(), self.segments[index].copy()
if labels.size:
# 将标签从归一化的 xywh 格式转换为像素 xyxy 格式,并根据填充调整坐标。
labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
# 将分割数据从归一化的 xyn 格式转换为像素 xy 格式,并根据填充调整坐标。
segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
# 将转换后的标签添加到 labels9 列表中。
labels9.append(labels)
# 将转换后的分割数据添加到 segments9 列表中。
segments9.extend(segments)
# Image
# 将当前图像放置在基础镶嵌图像的正确位置。
img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w # height, width previous
# Offset
# 随机生成镶嵌图像的中心坐标 xc 和 yc 。
yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
# 从基础镶嵌图像中裁剪出一个中心区域,作为最终的镶嵌图像。
img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
# 这段代码处理在创建9宫格镶嵌图像(mosaic)后标签的合并、调整和裁剪操作。
# Concat/clip labels
# 使用 NumPy 的 concatenate 函数将 labels9 列表中的所有标签数组沿着第一个维度(0维,即行)合并成一个大的数组。
labels9 = np.concatenate(labels9, 0)
# 调整合并后的标签数组中所有边界框的 x 坐标(第1列和第3列),减去镶嵌图像中心的 x 坐标 xc ,以确保坐标相对于镶嵌图像的中心。
labels9[:, [1, 3]] -= xc
# 调整合并后的标签数组中所有边界框的 y 坐标(第2列和第4列),减去镶嵌图像中心的 y 坐标 yc ,以确保坐标相对于镶嵌图像的中心。
labels9[:, [2, 4]] -= yc
# 创建一个包含镶嵌图像中心坐标的 NumPy 数组。
c = np.array([xc, yc]) # centers
# 对于 segments9 列表中的每个分割数据点,减去中心坐标 c ,以调整分割数据点的位置。
segments9 = [x - c for x in segments9]
# 遍历 labels9 中除了类别标签之外的所有列(即所有边界框坐标)以及 segments9 中的所有分割数据。
for x in (labels9[:, 1:], *segments9):
# 对于每个 x (边界框坐标或分割数据),使用 NumPy 的 clip 函数将其中的值限制在 [0, 2 * s] 的范围内。 s 是图像的尺寸, 2 * s 表示图像尺寸的两倍,这通常用于处理随机透视变换时坐标可能超出图像边界的情况。 out=x 表示在原地修改 x 。
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
# 这是一个被注释掉的代码行,如果取消注释,它将执行一个复制粘贴增强操作。这个操作会随机选择图像的一部分并将其复制到图像的另一个位置,同时更新对应的标签数据。这种增强技术可以增加数据的多样性并提高模型的泛化能力。
# img9, labels9 = replicate(img9, labels9) # replicate
# 这段代码的目的是确保所有的标签和分割数据在进行数据增强和变换后仍然保持有效,即它们的坐标不会超出图像的边界。这对于训练一个鲁棒的模型非常重要,因为模型需要能够处理各种边界情况。
# Augment
# 执行复制粘贴增强。
img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
# 执行随机透视变换。
img9, labels9 = random_perspective(img9,
labels9,
segments9,
degrees=self.hyp['degrees'],
translate=self.hyp['translate'],
scale=self.hyp['scale'],
shear=self.hyp['shear'],
perspective=self.hyp['perspective'],
border=self.mosaic_border) # border to remove
# 返回最终的镶嵌图像和对应的标签。
return img9, labels9
# 这个方法的目的是创建一个由九张图像拼接而成的镶嵌图像,并处理相应的标签和分割数据,以增加数据的多样性并提高模型的泛化能力。通过这种方式,可以在训练过程中模拟不同的场景和图像组合,从而使模型更加健壮。
# 这段代码定义了一个名为 collate_fn 的静态方法,通常用在 PyTorch 的 DataLoader 中,以自定义数据批量的组合方式。
# 这个装饰器表示 collate_fn 是一个静态方法,它不需要访问类的实例( self )。
@staticmethod
# 这是 collate_fn 方法的定义,它接受一个参数。
# 1.batch :这个参数是一个列表,包含了从数据集中获取的多个数据项。
def collate_fn(batch):
# 在目标检测任务中, label 数组的形状和每列的含义可能会根据具体的数据集和模型设计而有所不同。但是,通常情况下, label 数组是一个二维数组,其中每一行代表一个边界框的标签信息,每一列代表不同的属性。以下是一些常见的含义 :
# 类别标签(Class Label) :
# 通常, label 数组的第一列表示边界框所属的类别标签。这个值通常是整数,从0开始编码,每个整数对应一个特定的类别。
# 边界框坐标(Bounding Box Coordinates) :
# 接下来的四列通常表示边界框的坐标。这些坐标可以以不同的格式表示,例如 :(x_center, y_center, width, height) :边界框的中心点坐标以及宽度和高度。 (x_min, y_min, x_max, y_max) :边界框的左上角和右下角坐标。 这些坐标可能是归一化的(即相对于图像宽度和高度的比例),也可能是像素坐标。
# 置信度得分(Confidence Score) :
# 在某些情况下, label 数组可能还会包含一个表示置信度得分的列,用于表示边界框包含目标对象的概率。
# 其他属性(Other Attributes) :
# 根据任务的不同, label 数组还可能包含其他属性,例如目标对象的分割掩码、关键点坐标等。
# 例如,如果 label 数组的形状是 (N, 5) ,其中 N 是边界框的数量,那么每行可能包含以下信息 : label[:, 0] 类别标签。 label[:, 1:3] 边界框的 x_center 和 y_center 坐标。 label[:, 3:5] :边界框的宽度和高度。
# 如果 label 数组包含置信度得分,那么它可能的形状是 (N, 6) ,其中最后一列是置信度得分。需要注意的是,这些只是常见的约定,具体的 label 数组的形状和含义应根据实际使用的数据集和模型进行确定。
# 使用 zip 函数和星号表达式( * )对 batch 列表中的每个数据项进行解包和转置。每个数据项为一个四元组 (image, label, path, shapes) ,这样 im 将包含所有的图像数据, label 将包含所有的标签数据,依此类推。
# 这行代码使用 Python 的内置函数 zip 和星号表达式( * )来解包和重组 batch 中的数据。以下是详细解释 :
# batch :
# batch 是一个列表,其中每个元素都是一个四元组 (image, label, path, shapes) 。这个列表由 PyTorch 的 DataLoader 在每个迭代中自动生成,包含了从数据集中按顺序取出的多个数据项。
# zip(*batch) :
# zip 函数通常用于将多个可迭代对象(如列表或元组)的对应元素打包成一个个元组,然后返回由这些元组组成的列表。
# 在这里, *batch 是星号表达式,它将 batch 列表中的每个元素(即四元组)解包为独立的参数传递给 zip 函数。
# zip 函数将这些四元组的对应元素分别打包成四个列表:一个包含所有的 图像数据 im ,一个包含所有的 标签数据 label ,一个包含所有的 路径数据 path ,和一个包含所有的 形状数据 shapes 。
# im, label, path, shapes = zip(*batch) :
# 这行代码将 zip 函数返回的四个列表分别赋值给四个变量 : im 、 label 、 path 和 shapes 。
# im 将包含 batch 中所有的图像数据。
# label 将包含 batch 中所有的标签数据。
# path 将包含 batch 中所有的图像文件路径。
# shapes 将包含 batch 中所有的图像形状信息。
# 这种数据重组方式使得可以分别对图像、标签、路径和形状进行批量处理,例如,将所有的图像数据堆叠成一个批次输入到神经网络中,或者将所有的标签数据连接起来进行批量处理。这是深度学习中常用的数据预处理步骤,有助于提高数据加载和处理的效率。
im, label, path, shapes = zip(*batch) # transposed
# 遍历 label 列表, i 是索引, lb 是当前的标签数据。
for i, lb in enumerate(label):
# 执行 lb[:, 0] = i 这行代码之后, lb 数组(即 label 数组)中第一列原有的类别标签会被覆盖掉。
# 这行代码的作用是将 lb 数组(标签数组)的第一列(索引为 0 的列)的所有值设置为当前图像的索引 i 。这样做的目的通常是为每个边界框添加一个唯一的图像索引,以便于在训练过程中区分来自不同图像的边界框。
# 如果 lb 数组原本的第一列包含的是类别标签,那么执行这行代码后,这些类别标签将被替换为图像索引。这意味着类别标签信息将丢失,因此在执行这行代码之前,需要确保已经保存或不需要这些类别标签信息。
# 在某些目标检测框架中,可能需要为每个边界框添加一个图像索引,以帮助在训练时构建目标或进行损失计算。这种情况下,类别标签通常存储在其他位置,或者在训练过程中不再需要使用原始的类别标签。如果需要保留类别标签,可能需要对代码进行相应的调整,以确保类别标签不会被覆盖。
# 如果在执行 lb[:, 0] = i 这行代码后,原本第一列的类别标签没有在其他地方保存,那么这确实会影响后续的模型训练。原因如下 :
# 类别信息丢失 :第一列原本存储的类别标签被覆盖后,每个边界框的类别信息就丢失了。这意味着在训练过程中,模型将无法知道每个边界框对应的具体类别,这对于分类任务来说是必要的信息。
# 影响损失计算 :在目标检测模型中,损失函数通常需要类别标签来计算分类损失(例如,交叉熵损失)。如果类别标签丢失,模型将无法正确计算分类损失,从而影响模型的学习效果。
# 影响评估指标 :在模型评估阶段,需要使用类别标签来计算各种评估指标,如精确度(precision)、召回率(recall)和平均精度均值(mean average precision, mAP)。没有类别标签,这些评估指标将无法计算,从而无法评估模型的性能。
# 数据不一致性 :如果在数据预处理阶段覆盖了类别标签,但在模型训练和评估时又需要这些信息,将导致数据不一致性问题。这可能会导致模型训练不稳定,甚至导致训练失败。
# 为了避免这些问题,可以采取以下措施 :
# 保留类别标签 :在覆盖第一列之前,确保已经将类别标签保存在其他地方,例如在单独的变量或文件中。
# 使用额外的变量或数据结构 :如果需要在标签数组中添加额外的信息(如图像索引),可以考虑使用额外的变量或数据结构来存储这些信息,而不是直接覆盖原有的类别标签。
# 修改数据预处理流程 :重新设计数据预处理流程,确保类别标签不会被无意中覆盖,同时满足模型训练的需求。
# 总之,类别标签对于目标检测模型的训练和评估至关重要。在数据预处理阶段,应确保类别标签的完整性和准确性,以避免对模型训练造成不利影响。
# 对于每个标签数据 lb ,将其第一列(索引为 0 的列)的所有值设置为当前的索引 i 。这样做是为了在训练时为目标检测添加目标图像的索引,这在某些目标检测框架中是必要的。
lb[:, 0] = i # add target image index for build_targets()
# 返回处理后的数据。
# torch.stack(im, 0) :将所有图像数据堆叠成一个四维张量,第一个维度是批次大小。 torch.cat(label, 0) :将所有标签数据连接成一个二维张量,第一个维度是批次大小。 path :包含所有图像文件路径的列表。 shapes :包含所有图像形状信息的列表。
return torch.stack(im, 0), torch.cat(label, 0), path, shapes
# 这个方法的目的是将从数据集中获取的多个数据项组合成一个批次,以便可以批量地输入到模型中进行训练。通过这种方式,可以提高数据加载的效率,并使数据批量处理更加灵活。
# 这段代码定义了一个名为 collate_fn4 的静态方法,它是用于处理数据集中的一批数据,并将它们组合成一个四宫格(4-mosaic)图像以及对应的标签。
# 这个装饰器表示 collate_fn4 是一个静态方法,它不需要访问类的实例( self )。
@staticmethod
# 这是 collate_fn4 方法的定义,它接受一个参数。
# 1.batch :这个参数是一个列表,包含了从数据集中获取的多个数据项。
def collate_fn4(batch):
# 使用 zip 函数和星号表达式( * )对 batch 列表中的每个数据项进行解包和转置。每个数据项为一个四元组 (image, label, path, shapes) 。
im, label, path, shapes = zip(*batch) # transposed
# 计算每组四宫格图像的数量。
n = len(shapes) // 4
# 初始化空列表 im4 和 label4 用于存储四宫格图像和标签, path4 和 shapes4 用于存储对应的路径和形状信息。
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
# 创建一个表示水平偏移的张量。
ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
# 创建一个表示垂直偏移的张量。
wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
# 创建一个表示缩放比例的张量。
s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
# 遍历每组四宫格图像。
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
# 将索引 i 乘以 4,因为每组四宫格包含 4 张图像。
i *= 4
# 以 50% 的概率决定是进行双倍插值还是拼接四张图像。
if random.random() < 0.5:
# 如果随机数小于 0.5,对第一张图像进行双倍插值,然后转换回原始数据类型。
im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
align_corners=False)[0].type(im[i].type())
# 如果进行双倍插值,只使用第一张图像的标签。
lb = label[i]
# 如果随机数大于或等于 0.5,将四张图像拼接成一个大图像,并更新标签。
else:
# 将四张图像水平和垂直拼接成一个大图像。
im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
# 更新标签,包括偏移和缩放。
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
# 将处理后的图像添加到 im4 列表中。
im4.append(im1)
# 将更新后的标签添加到 label4 列表中。
label4.append(lb)
# 遍历 label4 列表,为每个标签添加目标图像索引。
for i, lb in enumerate(label4):
# 在标签的第一列添加目标图像索引。
lb[:, 0] = i # add target image index for build_targets()
# 返回处理后的四宫格图像张量、标签张量、路径列表和形状列表。
return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
# 这个方法的目的是将一批数据组合成一个四宫格图像和对应的标签,以增加数据的多样性并提高模型的泛化能力。通过这种方式,可以在训练过程中模拟不同的场景和图像组合,从而使模型更加健壮。
14.def flatten_recursive(path=DATASETS_DIR / 'coco128'):
python
# Ancillary functions --------------------------------------------------------------------------------------------------
# 这段代码定义了一个名为 flatten_recursive 的函数,其目的是将一个递归目录(即包含子目录的目录)中的所有文件移动到该目录的顶级,从而"扁平化"目录结构。
# 这是 flatten_recursive 函数的定义,它接受一个参数。
# 1.path :该参数默认为 DATASETS_DIR / 'coco128' ,表示要扁平化的目录路径。
def flatten_recursive(path=DATASETS_DIR / 'coco128'):
# Flatten a recursive directory by bringing all files to top level 通过将所有文件移至顶层来展平递归目录。
# 创建一个新的 Path 对象 new_path ,表示扁平化后的目录路径。这个新路径是原路径后加上 _flat 后缀。
new_path = Path(f'{str(path)}_flat')
# 检查扁平化后的目录是否已经存在。
if os.path.exists(new_path):
# shutil.rmtree(path, ignore_errors=False, onerror=None)
# shutil.rmtree 是 Python 标准库 shutil 模块中的一个函数,用于递归地删除一个目录以及其中的所有内容。这个函数会删除指定的目录,包括其中的所有子目录和文件,以及这些子目录和文件的所有内容。
# 参数 :
# path : 要删除的目录的路径。这个路径可以是字符串或者 Path 对象。
# ignore_errors (可选): 如果设置为 True ,则在删除文件或目录时,如果遇到任何错误(例如权限错误), shutil.rmtree 会忽略这些错误。默认值为 False ,即在遇到错误时会抛出异常。
# onerror (可选): 一个回调函数,用于处理在删除过程中遇到的错误。
# 这个回调函数会接收三个参数 :函数(即 onerror 本身)、路径和异常信息。如果提供了 onerror 回调函数,那么即使 ignore_errors 设置为 False ,也不会抛出异常,而是调用回调函数来处理错误。
# 注意事项 :
# 使用 shutil.rmtree 时要非常小心,因为它会永久删除文件和目录,这个操作是不可逆的。
# 确保在调用 shutil.rmtree 之前,你确实希望删除指定的目录及其所有内容。
# 如果你只想删除空目录,可以使用 os.rmdir 或 pathlib.Path.rmdir 方法。 shutil.rmtree 是一个强大的工具,但也需要谨慎使用,以避免意外删除重要数据。
# 如果存在,使用 shutil.rmtree 函数删除这个目录,以确保在创建新目录之前没有旧的目录。
shutil.rmtree(new_path) # delete output folder
# 使用 os.makedirs 函数创建新的扁平化目录。
os.makedirs(new_path) # make new output folder
# 使用 glob.glob 函数递归地搜索 path 目录下的所有文件,并使用 tqdm 显示进度条。
for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
# shutil.copyfile(src, dst)
# shutil.copyfile() 是 Python 标准库 shutil 模块中的一个函数,用于将一个文件的内容复制到另一个文件。这个函数会打开源文件和目标文件,读取源文件的内容,并将其写入目标文件。
# 参数 :
# src : 源文件的路径。
# dst : 目标文件的路径。如果目标文件已存在,则会被覆盖。
# 返回值 :
# 无返回值。
# 异常 :
# 如果源文件不存在或无法读取,或者目标文件无法写入,可能会抛出 FileNotFoundError 或 PermissionError 。
# 对于每个找到的文件,使用 shutil.copyfile 函数将其复制到新的扁平化目录中。复制的文件名保持不变。
shutil.copyfile(file, new_path / Path(file).name)
# 这个函数的目的是简化目录结构,使得所有文件都位于一个单一的目录中,而不是分散在多个子目录中。这在某些情况下很有用,比如当你需要处理一个包含大量子目录的数据集时,扁平化目录可以简化文件的访问和管理。
15.def extract_boxes(path=DATASETS_DIR / 'coco128'):
python
# 这段代码定义了一个名为 extract_boxes 的函数,它将目标检测数据集转换为分类数据集,为每个类别创建一个单独的目录,并保存每个边界框中的图像。
# 这是 extract_boxes 函数的定义,它接受一个参数。
# 1.path :默认指向 DATASETS_DIR / 'coco128' 路径。
def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
# Convert detection dataset into classification dataset, with one directory per class 将检测数据集转换为分类数据集,每个类一个目录。
# 将输入路径转换为 Path 对象,以便使用路径操作。
path = Path(path) # images dir
# 如果存在名为 classification 的目录,则删除它。
shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
# Path.rglob(pattern)
# rglob() 是 Python pathlib 模块中 Path 类的一个方法,用于递归地搜索与给定模式匹配的所有文件路径。这个方法会遍历给定路径下的所有子目录,寻找匹配指定模式的文件。
# 参数 :
# pattern :一个字符串,表示要匹配的文件名模式。这个模式遵循 Unix shell 的规则,其中 * 匹配任意数量的字符(除了路径分隔符),而 ** 用于表示任意深度的目录。
# 返回值 :
# 返回一个生成器(generator),生成所有匹配模式的 Path 对象。
# rglob() 方法是递归的,因此它会搜索所有子目录,而不仅仅是当前目录。这使得它非常适合于在大型项目中查找特定类型的文件。
# 使用 rglob 方法递归地获取路径下的所有文件,并存储在 files 列表中。
files = list(path.rglob('*.*'))
# 获取文件总数。
n = len(files) # number of files
# 遍历 files 列表中的每个文件,并使用 tqdm 显示进度条。
for im_file in tqdm(files, total=n):
# 检查文件扩展名是否为支持的图像格式。
if im_file.suffix[1:] in IMG_FORMATS:
# image
# 使用 OpenCV 读取图像文件,并将其从 BGR 格式转换为 RGB 格式。
im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
# 获取图像的高度和宽度。
h, w = im.shape[:2]
# labels
# 获取与当前图像文件对应的标签文件路径。
# def img2label_paths(img_paths): -> 将包含图像文件路径的列表转换为对应的标签文件路径列表。这个列表推导式返回一个新的列表,其中包含每个图像文件对应的标签文件路径。 -> return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
lb_file = Path(img2label_paths([str(im_file)])[0])
# 检查标签文件是否存在。
if Path(lb_file).exists():
# 打开标签文件。
with open(lb_file) as f:
# 读取标签文件内容,将其转换为 NumPy 数组。
lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
# 遍历每个标签。
for j, x in enumerate(lb):
# 获取类别标签。
c = int(x[0]) # class
# 构建新图像文件的路径,该路径位于以类别命名的子目录中。
f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
# 检查新图像文件的父目录是否存在。
if not f.parent.is_dir():
# 如果父目录不存在,则创建它。
f.parent.mkdir(parents=True)
# 将归一化的边界框坐标转换为像素坐标。
b = x[1:] * [w, h, w, h] # box
# b[2:] = b[2:].max() # rectangle to square
# 对边界框进行扩展和填充。
b[2:] = b[2:] * 1.2 + 3 # pad
# 将边界框坐标从 xywh 格式转换为 xyxy 格式。
b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)
# 将边界框的 x 坐标限制在图像宽度内。
b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
# 将边界框的 y 坐标限制在图像高度内。
b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
# 裁剪图像到边界框区域,并保存新图像。使用 assert 语句确保写入操作成功。
assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}' # {f} 中的边界框故障。
# 这个函数的目的是将目标检测数据集中的图像转换为分类数据集,每个类别一个目录,每个目录包含该类别的所有图像。这对于训练分类模型非常有用,因为它提供了一个更简单的数据结构,使得模型可以专注于学习每个类别的特征。
16.def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
python
# 这段代码定义了一个名为 autosplit 的函数,其目的是自动将一个图像数据集分割成 训练集 、 验证集 和 测试集 。
# 这是 autosplit 函数的定义,它接受三个参数。
# 1.path :图像数据集的路径,默认为 DATASETS_DIR / 'coco128/images' 。
# 2.weights :分割比例,默认为 (0.9, 0.1, 0.0) ,分别对应训练集、验证集和测试集的比例。
# 3.annotated_only :一个布尔值,指示是否只使用有标注的图像,默认为 False 。
def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
# 自动将数据集拆分为训练/验证/测试拆分并保存 path/autosplit_*.txt 文件。
""" Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
Usage: from utils.dataloaders import *; autosplit()
Arguments
path: Path to images directory
weights: Train, val, test weights (list, tuple)
annotated_only: Only use images with an annotated txt file
"""
# 将输入路径转换为 Path 对象。
path = Path(path) # images dir
# 使用 rglob 方法递归地获取路径下的所有图像文件,并排序。
files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
# 获取图像文件总数。
n = len(files) # number of files
# 设置随机种子以确保结果可复现。
random.seed(0) # for reproducibility
# 使用 random.choices 函数根据提供的权重随机分配每个图像到三个分割中的一个。
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
# 定义三个文本文件名,用于存储分割后的图像路径。
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
# 遍历三个文本文件名。
for x in txt:
# 检查对应的文本文件是否已经存在。
if (path.parent / x).exists():
# 如果存在,则删除旧的文本文件。
(path.parent / x).unlink() # remove existing
# 打印自动分割的信息,如果 annotated_only 为 True ,则附加使用标注图像的信息。
print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only) # 自动分割来自 {path}' + ' 的图像,仅使用 *.txt 标记图像' * annotated_only 。
# 遍历图像文件和对应的分割索引,并使用 tqdm 显示进度条。
for i, img in tqdm(zip(indices, files), total=n):
# 如果 annotated_only 为 False 或者对应的标签文件存在,则执行以下操作。
# def img2label_paths(img_paths): -> 将包含图像文件路径的列表转换为对应的标签文件路径列表。这个列表推导式返回一个新的列表,其中包含每个图像文件对应的标签文件路径。 -> return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
# 打开对应的文本文件以追加模式。
with open(path.parent / txt[i], 'a') as f:
# f.write(string)
# 在Python中, f.write() 是一个用于写入文件的方法,它属于文件对象。当你打开一个文件用于写入('w')或追加('a')时,可以使用 write() 方法将字符串写入该文件。
# 参数 :
# f : 文件对象,必须是以写入模式('w')或追加模式('a')打开的。
# string : 要写入文件的字符串。
# 返回值 : write() 方法没有返回值,它直接将数据写入文件。如果你需要写入多行数据,可以使用换行符 \n 。
# 注意事项 :
# 如果在写入过程中文件被其他程序锁定, write() 方法可能会抛出异常。
# write() 方法不会自动添加换行符,如果你需要写入多行,必须手动添加 \n 。
# 在使用 write() 方法时,确保文件以正确的模式打开,否则可能会遇到权限错误。
# path.relative_to(anchor)
# relative_to() 是 Python pathlib 模块中的 Path 类的一个方法,它用于返回一个路径对象相对于另一个路径的相对路径。
# 参数 :
# path : Path 对象,表示要获取相对路径的路径。
# anchor : Path 对象或字符串,表示用作参考的路径。
# 返回值 :
# relative_to() 方法返回一个 Path 对象,它表示 path 相对于 anchor 的相对路径。如果 path 不在 anchor 下,会抛出 ValueError 异常。
# 注意事项 :
# 如果 path 不是 anchor 的子路径, relative_to() 方法会抛出 ValueError 异常。
# relative_to() 方法返回的是一个 Path 对象,如果需要字符串形式的路径,可以使用 str() 函数进行转换。
# 这个方法在处理文件和目录路径时非常有用,特别是在你需要将路径转换为相对于某个特定目录的相对路径时。
# path.as_posix()
# 在 Python 的 pathlib 模块中, Path 类的 .as_posix() 方法用于将 Path 对象表示的路径转换为 POSIX 风格的字符串。POSIX 是一个操作系统标准,它规定了文件路径应该使用正斜杠( / )作为目录分隔符。
# path : Path 类的实例。
# 返回值 :
# 返回一个字符串,表示 Path 对象的路径,其中所有的路径分隔符都被替换为正斜杠( / )。
# 方法功能 :
# .as_posix() 方法将 Path 对象中的路径转换为一个字符串,这个字符串使用正斜杠( / )作为所有目录的分隔符,无论在原始路径中使用的是哪种操作系统的路径分隔符(例如,在 Windows 中可能是反斜杠 \ )。
# 此外,该方法还会处理路径中的一些特殊情况,例如,将相对路径(如 ./ 或 ../ )转换为简化形式,但不改变它们的相对性。 去除路径中多余的分隔符。
# 注意事项 :
# .as_posix() 方法不检查路径的实际存在性,它仅仅进行字符串层面的转换。
# 如果你需要在不同的操作系统之间移植代码,或者与期望 POSIX 路径风格的外部工具或库交互,使用 .as_posix() 方法可以帮助确保路径的兼容性。
# 将图像的相对路径写入文本文件,并添加换行符。
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
# 这个函数的目的是简化数据集的分割过程,使得用户可以轻松地将图像数据集分割成训练集、验证集和测试集,而不需要手动创建和维护这些分割。这对于机器学习和深度学习项目中的实验和模型评估非常有用。
17.def verify_image_label(args):
python
# 这段代码定义了一个名为 verify_image_label 的函数,它用于验证图像和对应的标签文件是否有效。
# 这是 verify_image_label 函数的定义,它接受一个参数。
# 1.args :这个参数是一个包含图像文件路径、标签文件路径和前缀的元组。
def verify_image_label(args):
# Verify one image-label pair 验证一个图像-标签对。
# 这段代码是 verify_image_label 函数的一部分,它负责验证图像文件的有效性,并在必要时修复损坏的JPEG图像。
# 解包 args 元组,获取 图像文件路径 im_file 、 标签文件路径 lb_file 和 前缀 prefix 。
im_file, lb_file, prefix = args
# 初始化计数器和消息列表,用于跟踪 缺失 nm 、 找到 nf 、 空 ne 、 损坏 nc 的标签数量以及 警告消息 msg 和 分割数据 segments 。
nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
# 开始一个 try 块,用于捕获在验证过程中可能发生的任何异常。
try:
# verify images
# 使用 PIL(Python Imaging Library)打开图像文件。
im = Image.open(im_file)
# im.verify()
# 在 Python 的 PIL(Python Imaging Library)库中, .verify() 方法用于验证图像文件的完整性。当处理图像文件时,这个方法尝试确认文件是否未损坏并且可以被正确解码。
# im :一个 PIL Image 对象。
# 功能 :
# .verify() 方法检查图像文件是否完整且未损坏。如果文件损坏或无法被识别,这个方法会抛出一个 IOError (输入/输出错误)异常。
# 注意事项 :
# .verify() 方法只适用于某些图像格式,特别是那些 PIL 支持的格式。
# 这个方法不会检查图像的元数据或内容,只检查图像文件的完整性和可读性。
# 在处理大量图像文件时,使用 .verify() 方法可以帮助识别和排除损坏的文件,以避免在后续处理中出现问题。
# 验证图像文件是否完整,没有损坏。
im.verify() # PIL verify
# 获取图像的尺寸, exif_size 函数从图像的 EXIF 数据中获取尺寸。
# def exif_size(img): -> 用于返回考虑了EXIF数据(特别是旋转信息)的PIL图像的实际尺寸。返回调整后的图像尺寸。 -> return s
shape = exif_size(im) # image size
# 确保图像尺寸大于 9x9 像素,否则抛出异常。
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' # 图像大小{形状} <10像素。
# 确保图像格式是支持的格式之一,否则抛出异常。
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' # 无效的图像格式 {im.format}。
# 如果图像格式是 JPG 或 JPEG,检查图像是否损坏。
if im.format.lower() in ('jpg', 'jpeg'):
# 以二进制读取模式打开图像文件。
with open(im_file, 'rb') as f:
# file_object.seek(offset, whence=0)
# seek() 函数是 Python 中文件对象的一个方法,它用于改变当前文件的读写位置。这个方法通常用于二进制文件操作,尤其是在需要随机访问文件内容时。
# file_object :文件对象,它必须是打开的文件,并且具有读写能力。
# 参数 :
# offset :偏移量,表示从 whence 指定的位置开始移动的字节数。正值表示向前(文件末尾方向),负值表示向后(文件开头方向)。
# whence :(可选)起始位置,指定 offset 从何处开始计算,默认值为0。它的值可以是 :
# 0 :文件开头(默认值), offset 表示从文件开头开始的字节数。
# 1 :当前位置, offset 表示从当前文件位置开始的字节数。
# 2 :文件末尾, offset 表示从文件末尾开始的字节数,通常用于移动到文件末尾之后的位置。
# 返回值 : seek() 方法没有返回值(返回 None )。
# seek() 方法是文件随机访问的关键,它允许程序在文件中快速定位到任意位置,而不需要从头开始读取或写入。
# 移动文件指针到文件末尾前2个字节。
f.seek(-2, 2)
# 读取文件末尾的2个字节,检查是否为 JPEG 文件结束标志 \xff\xd9 。
if f.read() != b'\xff\xd9': # corrupt JPEG
# PIL.ImageOps.exif_transpose(image, *, in_place=False)
# ImageOps.exif_transpose() 函数是 Python Imaging Library (PIL) 的一个扩展库 Pillow 中的一个函数,它用于根据图像的 EXIF 定向标签来调整图像的方向,使得图像按照 EXIF 标签中指定的方向进行正确的显示。如果图像没有 EXIF 定向标签或者标签值为 1(表示图像已经是正确的方向),则返回图像的一个副本。
# 参数 :
# image :要调整方向的 PIL 图像对象。
# in_place :(关键字参数)布尔值,如果设置为 True ,则在原图像对象上进行修改,并返回 None ;如果设置为 False (默认值),则返回一个新的图像对象,原图像对象不变。
# 返回值 :
# 如果 in_place 参数为 False (默认),返回一个新的图像对象,该对象根据 EXIF 定向标签调整了方向。
# 如果 in_place 参数为 True ,则原图像对象被修改,函数返回 None 。
# 功能 :
# 读取图像的 EXIF 信息,特别是 Orientation 标签。
# 根据 Orientation 标签的值,确定如何调整图像的方向。
# 应用相应的变换(例如旋转、翻转)来调整图像的方向。
# 如果 in_place 为 False ,则返回调整方向后的新图像对象;否则,修改原图像对象。
# 如果 JPEG 文件损坏,使用 PIL 的 ImageOps.exif_transpose 函数修复图像的方向,然后保存修复后的图像。
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
# 创建一条警告消息,提示 JPEG 图像已损坏并已修复。
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' # {prefix}警告⚠️{im_file}:损坏的 JPEG 已恢复并保存。
# 这段代码的目的是确保图像文件的完整性和格式正确性,并在检测到损坏的 JPEG 图像时进行修复。这对于维护数据集的质量非常重要,尤其是在处理大量图像数据时。
# 这段代码是 verify_image_label 函数的一部分,它负责验证与图像文件关联的标签文件是否有效,并处理标签数据。
# verify labels
# 检查标签文件 lb_file 是否存在。
if os.path.isfile(lb_file):
# 如果标签文件存在,设置 nf (找到的标签数量)为 1。
nf = 1 # label found
# 打开标签文件以读取。
with open(lb_file) as f:
# 读取标签文件的每一行,去除空白行,分割每行的数据,并创建一个列表 lb 。
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
# 检查 lb 列表中的任何一项是否包含超过 6 个元素,这通常表示存在分割(segmentation)标签。
if any(len(x) > 6 for x in lb): # is segment
# 提取类别标签并转换为 NumPy 数组。
classes = np.array([x[0] for x in lb], dtype=np.float32)
# 提取分割数据并转换为 NumPy 数组,每个分割数据表示为一系列点的坐标。
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
# 将 类别标签 和 分割数据 转换为边界框格式(xywh)。
# def segments2boxes(segments): -> 分割标签(segment labels)转换为边界框标签(box labels)。将 boxes 列表转换为 NumPy 数组,并使用 xyxy2xywh 函数将其从 xyxy 格式转换为 xywh 格式。 -> return xyxy2xywh(np.array(boxes)) # cls, xywh
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
# 将 lb 列表转换为 NumPy 数组。
lb = np.array(lb, dtype=np.float32)
# 获取标签数量。
nl = len(lb)
# 如果存在标签,执行以下检查。
if nl:
# 确保标签数组的列数为 5。
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected' # 标签需要 5 列,检测到 {lb.shape[1]} 列。
# 确保标签数组中没有负值。
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}' # 负标签值 {lb[lb < 0]}。
# 确保标签数组中的坐标值在 0 到 1 之间。
assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}' # 非规范化或超出范围的坐标 {lb[:, 1:][lb[:, 1:] > 1]}。
# 找出标签数组中的唯一行及其索引。
_, i = np.unique(lb, axis=0, return_index=True)
# 检查通过 np.unique 函数返回的索引数组 i 的长度是否小于原始标签数组 lb 的长度 nl 。如果是,说明存在重复的标签行。
if len(i) < nl: # duplicate row check
# 使用索引数组 i 从原始标签数组 lb 中选择唯一的标签,从而移除重复的标签。
lb = lb[i] # remove duplicates
# 检查是否存在分割数据 segments 。
if segments:
# 如果存在分割数据,同样使用索引数组 i 来选择与唯一标签对应的分割数据,从而移除重复的分割数据。
segments = [segments[x] for x in i]
# 构造一条警告消息 msg ,指出在图像文件 im_file 中移除了 nl - len(i) 个重复的标签。 prefix 是一个前缀字符串,用于在日志消息前添加上下文信息。
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed' # {prefix}警告 ⚠️ {im_file}: {nl - len(i)} 重复标签已删除。
# 如果标签文件为空或不存在,设置 ne (空标签数量)为 1,并创建一个空的 NumPy 数组作为 lb 。
else:
# 设置空标签计数器。
ne = 1 # label empty
# 创建一个空的标签数组。
lb = np.zeros((0, 5), dtype=np.float32)
# 这段代码的目的是确保标签文件的内容有效,并将其转换为适合模型训练的格式。它还检查标签文件是否存在、是否为空,并在必要时移除重复的标签。这对于确保数据集的质量和一致性至关重要。
# 这段代码是 verify_image_label 函数的结尾部分,它处理了标签文件不存在的情况以及可能发生的任何异常。
# 这个 else 块与前面的 if os.path.isfile(lb_file): 相对应,表示如果标签文件不存在,则执行以下操作。
else:
# 设置 nm (缺失的标签数量)为 1,表示当前处理的图像缺少标签文件。
nm = 1 # label missing
# 创建一个形状为 (0, 5) 的零数组 lb ,表示没有标签数据。这个数组将被用作标签数据的占位符。
lb = np.zeros((0, 5), dtype=np.float32)
# 返回处理结果,包括图像文件路径 im_file 、 标签数据 lb 、 图像尺寸 shape 、 分割数据 segments 、 缺失标签计数 nm 、 找到标签计数 nf 、 空标签计数 ne 、 损坏标签计数 nc 和 警告消息 msg 。
return im_file, lb, shape, segments, nm, nf, ne, nc, msg
# 捕获 try 块中可能发生的任何异常。
except Exception as e:
# 设置 nc (损坏的标签数量)为 1,表示当前处理的图像或标签存在问题。
nc = 1
# 创建一条警告消息 msg ,指出由于存在损坏的图像或标签,程序将忽略当前的图像或标签,并提供异常信息 e 。
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' # {prefix}警告⚠️{im_file}:忽略损坏的图像/标签:{e}。
# 返回一个包含 None 和计数器的列表,以及警告消息。这个列表将被用作函数的返回值,指示处理过程中发生了异常。
return [None, None, None, None, nm, nf, ne, nc, msg]
# 这段代码确保了即使在标签文件不存在或图像标签损坏的情况下,函数也能返回一个一致的结果,同时提供了足够的信息来诊断问题。这对于数据清洗和数据集准备阶段非常重要,因为它有助于识别和处理数据集中的不良数据。
# 这个函数的目的是确保图像和标签文件的有效性,包括图像尺寸、格式和标签的格式和内容。如果发现任何问题,它会记录警告消息并返回相应的计数器。这对于数据清洗和准备阶段非常重要,可以确保训练数据的质量和一致性。
18.class HUBDatasetStats():
python
# 这段代码定义了一个名为 HUBDatasetStats 的类,其目的是处理与 YOLOv5 模型训练相关的数据集统计信息。
# 定义了 HUBDatasetStats 类。
class HUBDatasetStats():
# 用于生成 HUB 数据集 JSON 和 `-hub` 数据集目录的类。
""" Class for generating HUB dataset JSON and `-hub` dataset directory
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally
Usage
from utils.dataloaders import HUBDatasetStats
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
stats.get_json(save=False)
stats.process_images()
"""
# 这段代码定义了一个类的构造函数 __init__ ,用于初始化一个处理数据集统计信息的类( HUBDatasetStats )实例。
# 这是类的构造函数定义,包含两个参数。
# 1.path :指定数据集配置文件的路径,默认为 'coco128.yaml' 。
# 2.autodownload :一个布尔值,指示是否在数据集缺失时自动下载,默认为 False 。
def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
# 调用类的一个方法 _unzip ,传入 path 参数的 Path 对象,返回三个值 :数据集是否为压缩格式 zipped 、数据目录 data_dir 和 YAML 文件路径 yaml_path 。
zipped, data_dir, yaml_path = self._unzip(Path(path))
# 开始一个 try 块,用于捕获在加载 YAML 文件时可能发生的任何异常。
try:
# 使用 check_yaml 函数检查 YAML 文件路径,并以忽略错误的方式打开文件。
# def check_yaml(file, suffix=('.yaml', '.yml')):
# -> 检查一个 YAML 文件是否存在,如果不存在且文件是一个网址,则下载该文件。调用 check_file 函数,并传入 file 和 suffix 参数。 check_file 函数会检查文件是否存在,如果不存在且是一个网址,则下载文件,并确保文件后缀符合传入的 suffix 参数。
# -> return check_file(file, suffix)
with open(check_yaml(yaml_path), errors='ignore') as f:
# 使用 yaml.safe_load 函数加载 YAML 文件内容,并将其存储在字典 data 中。
data = yaml.safe_load(f) # data dict
# 如果数据集是压缩的,则更新 data 字典中的路径信息。
if zipped:
data['path'] = data_dir
# 如果在加载 YAML 文件时发生异常,捕获异常。
except Exception as e:
# 抛出一个异常,指示在加载 YAML 文件时出错,并提供原始异常的上下文。
raise Exception("error/HUB/dataset_stats/yaml_load") from e # 错误/HUB/dataset_stats/yaml_load 。
# 调用 check_dataset 函数,检查数据集是否存在,如果缺失且 autodownload 为 True ,则下载数据集。
# def check_dataset(data, autodownload=True): -> 检查、下载(如果需要)并解压数据集,确保数据集在本地可用。返回包含数据集信息的字典。 -> return data # dictionary
check_dataset(data, autodownload) # download dataset if missing
# 创建一个 Path 对象,表示数据集的 HUB 目录路径,并附加 -hub 后缀。
self.hub_dir = Path(data['path'] + '-hub')
# 创建一个 Path 对象,表示图像目录路径。
self.im_dir = self.hub_dir / 'images'
# 创建图像目录,如果目录已存在则不抛出异常。
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
# 创建一个统计信息字典 self.stats ,包含类别数量 nc 和类别名称列表 names 。
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
# 将加载的数据集信息存储在 self.data 中。
self.data = data
# 这个类的目的是初始化数据集统计信息,包括类别数量、类别名称等,并确保数据集文件的存在。它通过加载 YAML 配置文件来获取数据集信息,并在必要时下载数据集。这个类是 YOLOv5 模型训练流程中数据准备阶段的一部分。
# 这段代码定义了一个名为 _find_yaml 的静态方法,其目的是在给定的目录中查找并返回一个 .yaml 文件的路径。
# 这个装饰器表示 _find_yaml 是一个静态方法,它不需要访问类的实例或类变量。
@staticmethod
# 这是 _find_yaml 方法的定义,它接受一个参数。
# 1.dir :这是一个 pathlib.Path 对象,表示要搜索 .yaml 文件的目录。
def _find_yaml(dir):
# Return data.yaml file
# 首先尝试在 dir 目录的顶级查找所有以 .yaml 结尾的文件。如果没有找到,然后递归地在整个目录树中搜索 .yaml 文件。
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
# 确保至少找到了一个 .yaml 文件,如果没有找到,抛出一个 AssertionError 。
assert files, f'No *.yaml file found in {dir}' # 在 {dir} 中未找到 *.yaml 文件。
# 如果找到了多个 .yaml 文件,执行以下操作来选择一个首选的文件。
if len(files) > 1:
# 从找到的文件中选择文件名(不包括扩展名)与目录名相同的 .yaml 文件。
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
# 确保在上一步中只找到了一个匹配的 .yaml 文件,如果没有找到,抛出一个 AssertionError 。
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed' # 在 {dir} 中找到多个 *.yaml 文件,仅允许 1 个 *.yaml 文件。
# 确保在整个搜索过程中只找到了一个 .yaml 文件,如果有多个,抛出一个 AssertionError 。
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}' # 找到多个 *.yaml 文件:{files},{dir} 中仅允许 1 个 *.yaml 文件。
# 返回找到的 .yaml 文件的路径。
return files[0]
# 这个方法的目的是确保在给定目录中只存在一个 .yaml 文件,并且在多个文件的情况下,优先选择与目录名相同的 .yaml 文件。这在处理配置文件时非常有用,因为它提供了一种一致的方式来定位和使用配置文件。
# 这段代码定义了一个名为 _unzip 的方法,它是用于处理数据集压缩文件(通常是 .zip 格式)的辅助函数。
# 这是 _unzip 方法的定义,它接受一个参数。
# 1.path :表示数据集压缩文件的路径。
def _unzip(self, path):
# Unzip data.zip
# 检查 path 是否以 .zip 结尾,如果不是,说明 path 可能直接指向 data.yaml 文件。
if not str(path).endswith('.zip'): # path is data.yaml
# 如果 path 不是 .zip 文件,返回一个元组,其中包含三个值 : False (表示不是压缩文件), None (没有解压目录),以及原始路径 path 。
return False, None, path
# 确保 path 指向的文件存在,如果文件不存在,抛出一个 AssertionError 。
assert Path(path).is_file(), f'Error unzipping {path}, file not found' # 解压 {path} 时出错,找不到文件。
# 调用 unzip_file 函数来解压 .zip 文件,解压目标目录设置为 .zip 文件的父目录。
# def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): -> 解压缩一个 .zip 文件到指定的路径,并排除包含特定字符串的文件。
unzip_file(path, path=path.parent)
# 创建一个新的 Path 对象 dir ,表示解压后的目录,即 .zip 文件名(不包括扩展名)。
dir = path.with_suffix('') # dataset directory == zip name
# 确保解压后的目录 dir 存在,如果目录不存在,抛出一个 AssertionError 。
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/' # 解压 {path} 时出错,未找到 {dir}。path/to/abc.zip 必须解压至 path/to/abc/ 。
# 返回一个元组,其中包含三个值 : True (表示文件已解压),解压后的目录路径 str(dir) ,以及调用 _find_yaml 方法找到的 .yaml 文件路径。
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
# 这个方法的目的是处理数据集的压缩文件,如果提供的路径是一个 .zip 文件,它将被解压,并且方法会返回解压后的数据目录和 .yaml 配置文件的路径。如果提供的路径不是 .zip 文件,它假定是一个 .yaml 文件,并直接返回该文件的路径。这个函数有助于自动化数据集的解压和配置文件的查找过程。
# 这段代码定义了一个名为 _hub_ops 的方法,它是用于处理单个图像的辅助函数,目的是将图像调整大小并保存在较低质量下,以便在 Web 或 App 中查看。
# 这是 _hub_ops 方法的定义,它接受两个参数。
# 1.f :要处理的图像文件的路径。
# 2.max_dim :图像的最大尺寸,默认为 1920。
def _hub_ops(self, f, max_dim=1920):
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing 1 张图像"f"的 HUB 操作:调整大小并以降低的质量保存在 /dataset-hub 中以供网络/应用程序查看。
# 创建一个新的 Path 对象 f_new ,表示调整大小后的图像文件名,该文件将保存在 self.im_dir 目录下。
f_new = self.im_dir / Path(f).name # dataset-hub image filename
# 开始一个 try 块,用于尝试使用 PIL 处理图像。
try: # use PIL
# 使用 PIL 的 Image.open 函数打开图像文件。
im = Image.open(f)
# 计算调整大小的比例 r ,使得图像的最大维度不超过 max_dim 。
r = max_dim / max(im.height, im.width) # ratio
# 如果比例 r 小于 1,说明图像太大,需要调整大小。
if r < 1.0: # image too large
# 使用 PIL 的 resize 方法调整图像大小。
im = im.resize((int(im.width * r), int(im.height * r)))
# 将调整大小后的图像保存为 JPEG 格式,质量设置为 50,并优化文件大小。
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
# 如果在 PIL 处理过程中发生异常,捕获异常。
except Exception as e: # use OpenCV
# 记录一条日志信息,提示 PIL 处理失败。
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}') # 警告 ⚠️ HUB 操作 PIL 失败 {f}:{e} 。
# 使用 OpenCV 的 imread 函数读取图像文件。
im = cv2.imread(f)
# 获取图像的高度和宽度。
im_height, im_width = im.shape[:2]
# 重新计算调整大小的比例 r 。
r = max_dim / max(im_height, im_width) # ratio
# 如果比例 r 小于 1,说明图像太大,需要调整大小。
if r < 1.0: # image too large
# 使用 OpenCV 的 resize 函数调整图像大小。
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
# 将调整大小后的图像保存为 JPEG 格式。
cv2.imwrite(str(f_new), im)
# 这个方法的目的是确保图像在 Web 或 App 中查看时不会过大,同时降低图像质量以减少加载时间。它首先尝试使用 PIL 来处理图像,如果 PIL 处理失败,则使用 OpenCV 作为备选方案。这种方法提高了代码的健壮性,并确保了在不同环境下都能正确处理图像。
# 这段代码定义了一个名为 get_json 的方法,它用于生成并返回一个包含数据集统计信息的 JSON 对象,这些信息可以用于 Ultralytics HUB。 Ultralytics HUB 是一个由 Ultralytics 团队开发的在线平台,旨在为用户提供一个集中的地方来可视化、训练和部署 YOLO(You Only Look Once)系列模型,包括 YOLOv5、YOLOv8 和 YOLO11。
# 这是 get_json 方法的定义,它接受三个参数。
# 1.save :一个布尔值,指示是否保存统计信息到文件,默认为 False 。
# 2.verbose :一个布尔值,指示是否打印统计信息,默认为 False 。
def get_json(self, save=False, verbose=False):
# Return dataset JSON for Ultralytics HUB 返回 Ultralytics HUB 的数据集 JSON。
# 定义一个内部函数 _round ,用于将标签数据中的类别索引转换为整数,并将坐标值四舍五入到六位小数。
def _round(labels):
# Update labels to integer class and 6 decimal place floats 将标签更新为整数类和 6 位小数浮点数。
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
# 遍历 'train' 、 'val' 和 'test' 三个数据集分割。
for split in 'train', 'val', 'test':
# 检查当前分割是否存在于 self.data 字典中。
if self.data.get(split) is None:
# 如果当前分割不存在,设置对应的统计信息为 None 。
self.stats[split] = None # i.e. no test set
continue
# 如果当前分割存在,使用 LoadImagesAndLabels 类加载数据集。
# class LoadImagesAndLabels(Dataset):
# -> 用于加载图像和对应的标签,并提供数据增强、缓存等功能,以便于在深度学习模型训练中使用。
# -> def __init__(self,path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, cache_images=False, single_cls=False, stride=32, pad=0.0, min_items=0, prefix=''):
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
# 计算每个类别的实例统计信息,并存储在 NumPy 数组 x 中。
# 这段代码是用于计算数据集中每个类别的实例数量的。
# x = np.array([...]) :创建一个 NumPy 数组 x ,其中包含计算得到的每个类别的实例数量。
# np.bincount(label[:, 0].astype(int), minlength=self.data['nc']) :对于 dataset.labels 中的每个标签数组 label ,使用 np.bincount 函数计算每个类别的实例数量。
# label[:, 0] :提取每个标签数组中的第一个元素,即类别索引。
# .astype(int) :确保类别索引是整数类型。
# minlength=self.data['nc'] :确保 np.bincount 的输出长度至少为 self.data['nc'] ,即数据集中的类别数量。
# for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')) :使用 tqdm 库遍历 dataset.labels 中的所有标签数组。
# total=dataset.n :设置 tqdm 的总进度, dataset.n 是数据集中的图像数量。
# desc='Statistics' :设置 tqdm 的描述,显示当前正在进行的统计操作。
# np.array([...]) :将上述计算得到的类别实例数量的列表转换为一个 NumPy 数组。
# 这段代码的结果是一个形状为 (128, 80) 的 NumPy 数组 x ,其中每一行代表一个分割(例如训练集、验证集或测试集),每一列代表一个类别的实例数量。这个数组可以用来分析数据集中类别的分布情况,例如哪些类别有很多实例,哪些类别实例较少。
x = np.array([
# np.bincount(x, minlength=None)
# np.bincount 是 NumPy 库中的一个函数,它用于计算非负整数数组中每个值的出现次数。
# 参数 :
# x :输入数组,其中的元素必须是非负整数。
# minlength (可选) :输出数组的最小长度。如果提供,数组 x 中小于 minlength 的值将被忽略,而 x 中等于或大于 minlength 的值将导致数组被扩展以包含这些值。如果未提供或为 None ,则输出数组的长度将与 x 中的最大值加一相匹配。
# 返回值 :
# 返回一个数组,其中第 i 个元素代表输入数组 x 中值 i 出现的次数。
# 功能 :
# np.bincount 函数对输入数组 x 中的每个值进行计数,返回一个一维数组,其长度至少与 x 中的最大值一样大。
# 如果 x 中的某个值没有出现,那么在返回的数组中对应的位置将为 0。
# 例:
# x = np.array([1, 2, 3, 3, 0, 1, 4])
# np.bincount(x)
# '''array([1, 2, 1, 2, 1], dtype=int64)'''
# 输出 : [1 2 1 2 1]。统计索引出现次数:索引0出现1次,1出现2次,2出现1次,3出现2次,4出现1次。
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
# 构建包含实例统计信息和图像统计信息的字典,并将其存储在 self.stats 字典中。
# 这段代码是在构建一个名为 self.stats 的字典,用于存储数据集的统计信息,并将这些信息组织成一个结构化的格式。这个字典包含了实例统计、图像统计和标签信息。
# self.stats[split] = { ... } : split 是当前处理的数据集分割(例如 'train', 'val', 'test')。 在 self.stats 字典中为当前分割创建一个条目,并赋予一个包含统计信息的新字典。
self.stats[split] = {
# 'instance_stats' :包含实例级别的统计信息。
'instance_stats': {
# 'total' :数据集中所有实例的总数,通过 x.sum() 计算得到, x 是一个数组,其中每个元素代表一个类别的实例数量。
'total': int(x.sum()),
# 'per_class' :每个类别的实例数量,通过 x.sum(0) 计算得到,即对数组 x 的每一列求和,得到每个类别的总数。
'per_class': x.sum(0).tolist()},
# image_stats' :包含图像级别的统计信息。
'image_stats': {
# 'total' :数据集中的图像总数,由 dataset.n 提供。
'total': dataset.n,
# 'unlabelled' :未标记的图像数量,通过 np.all(x == 0, 1).sum() 计算得到,即统计 x 数组中所有元素都为0的行数。
'unlabelled': int(np.all(x == 0, 1).sum()),
# 'per_class' :每个类别至少有一个实例的图像数量,通过 (x > 0).sum(0) 计算得到,即统计 x 数组中每列大于0的行数。
'per_class': (x > 0).sum(0).tolist()},
# 'labels' :包含每个图像的标签信息。
'labels': [{
# 使用列表推导式,对于 dataset.im_files 中的每个图像文件路径 k 和对应的标签 v ,创建一个字典条目。
# str(Path(k).name) 获取图像文件的名称作为键。
# _round(v.tolist()) 对标签数据进行处理,将类别索引转换为整数,并将坐标值四舍五入到六位小数,结果作为值。
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
# 这个字典结构的设计旨在提供一种清晰、易于理解的方式来查看和分析数据集的统计信息,包括每个类别的实例数量、图像数量以及每个图像的具体标签。这些信息对于数据集的管理和模型训练非常有用。
# Save, print and return
# 如果 save 参数为 True ,则保存统计信息到 JSON 文件。
if save:
# 确定保存统计信息的路径。
stats_path = self.hub_dir / 'stats.json'
print(f'Saving {stats_path.resolve()}...') # 正在保存 {stats_path.resolve()}...
# 打开文件以写入。
with open(stats_path, 'w') as f:
# 将 self.stats 字典转换为 JSON 格式并保存到文件。
json.dump(self.stats, f) # save stats.json
# 如果 verbose 参数为 True ,则打印统计信息。
if verbose:
# 将 self.stats 字典转换为格式化的 JSON 字符串并打印。
print(json.dumps(self.stats, indent=2, sort_keys=False))
# 返回包含统计信息的 self.stats 字典。
return self.stats
# 这个方法的目的是收集数据集的统计信息,如每个类别的实例数量和图像数量,并将这些信息保存为 JSON 文件,以便在 Ultralytics HUB 中使用。通过这种方式,可以方便地查看和分析数据集的分布情况。
# 这段代码定义了一个名为 process_images 的方法,它是用于压缩图像以供 Ultralytics HUB 使用的函数。
# 这是 process_images 方法的定义,它不接受除了 self 之外的任何参数。
def process_images(self):
# Compress images for Ultralytics HUB
# 遍历 'train' 、 'val' 和 'test' 三个数据集分割。
for split in 'train', 'val', 'test':
# 检查 self.data 字典中是否存在当前分割的数据。
if self.data.get(split) is None:
# 如果当前分割的数据不存在,则跳过当前迭代。
continue
# 如果当前分割的数据存在,使用 LoadImagesAndLabels 类加载数据集。
# class LoadImagesAndLabels(Dataset):
# -> 用于加载图像和对应的标签,并提供数据增强、缓存等功能,以便于在深度学习模型训练中使用。
# -> def __init__(self,path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, cache_images=False, single_cls=False, stride=32, pad=0.0, min_items=0, prefix=''):
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
# 创建一个描述字符串 desc ,用于 tqdm 进度条的描述。
desc = f'{split} images'
# 使用 ThreadPool 创建一个线程池, NUM_THREADS 是线程数量。
# ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files) 将 self._hub_ops 函数映射到 dataset.im_files 列表中的每个图像文件路径。
# tqdm 用于显示进度条, total=dataset.n 设置进度条的总进度, desc=desc 设置进度条的描述。
for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
# 这个 for 循环的目的是迭代进度条, pass 表示循环体中不执行任何操作。
pass
# 处理完所有图像后,打印一条消息,指示所有图像已保存到 self.im_dir 目录。
print(f'Done. All images saved to {self.im_dir}') # 完成。所有图像已保存至 {self.im_dir} 。
# 返回图像保存的目录路径。
return self.im_dir
# 这个方法的目的是批量处理数据集中的图像,将它们压缩并保存到指定目录,以便在 Ultralytics HUB 中使用。通过使用多线程,可以加速图像处理过程,提高效率。
19.class ClassificationDataset(torchvision.datasets.ImageFolder):
python
# Classification dataloaders -------------------------------------------------------------------------------------------
# 这段代码定义了一个名为 ClassificationDataset 的类,它继承自 torchvision.datasets.ImageFolder 。这个类用于加载和处理图像分类数据集,支持数据增强、缓存到 RAM 或磁盘等功能。
# 定义了一个名为 ClassificationDataset 的类,它继承自 torchvision.datasets.ImageFolder 。
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLOv5 Classification Dataset.
Arguments
root: Dataset path
transform: torchvision transforms, used by default
album_transform: Albumentations transforms, used if installed
"""
# 这段代码是 ClassificationDataset 类的构造函数 __init__ ,它初始化类实例并设置图像分类数据集的配置。
# 这是类的构造函数定义,包含四个参数。
# 1.root :数据集的根目录路径。
# 2.augment :一个布尔值,指示是否应用数据增强。
# 3.imgsz :图像的目标尺寸。
# 4.cache :一个布尔值或字符串,指示是否缓存图像数据,可以是 True 、 False 、 'ram' 或 'disk' ,默认为 False 。
def __init__(self, root, augment, imgsz, cache=False):
# 调用父类 torchvision.datasets.ImageFolder 的构造函数,初始化数据集。
super().__init__(root=root)
# 创建一个 PyTorch 变换对象 torch_transforms ,用于对图像进行预处理。 classify_transforms 函数应返回适用于分类任务的变换。
# def classify_transforms(size=224): -> 用于创建一组图像预处理转换操作,这些操作通常用于深度学习模型中,特别是在图像分类任务中。返回一个由 T.Compose 创建的转换序列。 -> return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
self.torch_transforms = classify_transforms(imgsz)
# 如果启用数据增强( augment 为 True ),则创建一个 Albumentations 变换对象 album_transforms ,用于数据增强。 classify_albumentations 函数应返回适用于分类任务的数据增强变换。
# def classify_albumentations(augment=True, size=224, scale=(0.08, 1.0), ratio=(0.75, 1.0 / 0.75), hflip=0.5, vflip=0.0, jitter=0.4, mean=IMAGENET_MEAN, std=IMAGENET_STD, auto_aug=False):
# -> 用于创建一组用于图像分类任务的图像增强变换,这些变换基于 albumentations 库实现。返回一个 A.Compose 对象,该对象包含所有的变换。
# -> return A.Compose(T)
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
# 根据 cache 参数的值,设置 cache_ram 属性,指示是否将图像缓存到 RAM 中。
self.cache_ram = cache is True or cache == 'ram'
# 根据 cache 参数的值,设置 cache_disk 属性,指示是否将图像缓存到磁盘上。
self.cache_disk = cache == 'disk'
# 扩展 self.samples 列表,为每个样本添加两个新元素 : .npy 文件路径 和 图像数据 。这里假设 self.samples 已经包含了原始的样本信息,如文件路径和索引。 Path(x[0]).with_suffix('.npy') 生成对应的 .npy 文件路径。
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
# 这段代码的目的是初始化一个用于图像分类的数据集对象,配置数据增强和缓存选项,并准备样本列表以包含新的缓存信息。这样,当数据集被用于训练或验证时,可以高效地加载和处理图像数据。
# 这段代码定义了 ClassificationDataset 类的 __getitem__ 方法,它用于按索引 i 获取数据集中的一个样本,并对其进行处理。
# 这是 __getitem__ 方法的定义,它接受一个参数。
# 1.i :表示要获取的样本的索引。
def __getitem__(self, i):
# 从 self.samples 列表中获取第 i 个样本,该样本包含四个元素 :文件名 f 、索引 j 、 .npy 文件路径 fn 和图像数据 im 。
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
# 如果启用了 RAM 缓存( self.cache_ram 为 True )且图像数据 im 为 None ,则从文件中读取图像数据并缓存。
if self.cache_ram and im is None:
# 使用 OpenCV 的 imread 函数从文件路径 f 读取图像数据,并将其存储在 self.samples[i][3] 中。
im = self.samples[i][3] = cv2.imread(f)
# 如果启用了磁盘缓存( self.cache_disk 为 True )。
elif self.cache_disk:
# 检查 .npy 文件是否存在。
if not fn.exists(): # load npy
# 如果 .npy 文件不存在,则读取图像文件并保存为 .npy 文件。
np.save(fn.as_posix(), cv2.imread(f))
# 从 .npy 文件中加载图像数据。
im = np.load(fn)
# 如果没有启用缓存,直接从文件中读取图像数据。
else: # read image
# 使用 OpenCV 的 imread 函数从文件路径 f 读取图像数据。
im = cv2.imread(f) # BGR
# 如果定义了 Albumentations 变换( self.album_transforms )。
if self.album_transforms:
# 将图像从 BGR 格式转换为 RGB 格式,并应用 Albumentations 变换。
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
# 如果没有定义 Albumentations 变换。
else:
# 应用 PyTorch 变换 self.torch_transforms 到图像数据。
sample = self.torch_transforms(im)
# 返回处理后的样本 sample 和索引 j 。
return sample, j
# 这个方法的目的是提供一种灵活的方式来加载和处理图像数据,支持缓存和数据增强。通过重写 __getitem__ 方法,可以自定义数据加载和预处理流程,适应不同的训练需求。
# 这个类的目的是提供一个灵活的数据加载器,支持数据增强和缓存,以提高图像分类任务的训练效率。通过重写 __getitem__ 方法,可以自定义数据加载和预处理流程,适应不同的训练需求。
20.def create_classification_dataloader(path,imgsz=224, batch_size=16, augment=True, cache=False, rank=-1, workers=8, shuffle=True):
python
# 这段代码定义了一个名为 create_classification_dataloader 的函数,它用于创建并返回一个用于图像分类任务的数据加载器(Dataloader)。
# 这是 create_classification_dataloader 函数的定义,它接受多个参数。
# 1.path :数据集的根目录路径。
# 2.imgsz :图像的目标尺寸,默认为 224。
# 3.batch_size :每个批次的样本数量,默认为 16。
# 4.augment :一个布尔值,指示是否应用数据增强,默认为 True 。
# 5.cache :一个布尔值或字符串,指示是否缓存图像数据,默认为 False 。
# 6.rank :用于分布式训练的排名,默认为 -1 。
# 7.workers :用于数据加载的工作线程数量,默认为 8。
# 8.shuffle :一个布尔值,指示是否在每个epoch开始时打乱数据,默认为 True 。
def create_classification_dataloader(path,
imgsz=224,
batch_size=16,
augment=True,
cache=False,
rank=-1,
workers=8,
shuffle=True):
# Returns Dataloader object to be used with YOLOv5 Classifier 返回与 YOLOv5 分类器一起使用的 Dataloader 对象。
# 使用 torch_distributed_zero_first 函数确保在分布式训练中只有一个进程初始化数据集的缓存。
# def torch_distributed_zero_first(local_rank: int): -> 用于在分布式训练中同步所有进程,确保每个进程在执行特定操作之前等待主进程(通常是 local_rank 为 0 的进程)完成某些任务。
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
# 创建一个 ClassificationDataset 实例,用于加载和处理图像分类数据集。
dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
# 确保 batch_size 不超过数据集的大小。
batch_size = min(batch_size, len(dataset))
# 获取可用的 GPU 数量。
nd = torch.cuda.device_count()
# cpu_count = os.cpu_count()
# os.cpu_count() 是 Python 标准库 os 模块中的一个函数,用于返回当前机器上可用的 CPU 核心数。
# 参数 :这个函数不需要任何参数。
# 返回值 :
# os.cpu_count() 函数返回一个整数,表示 CPU 核心的数量。如果返回 None ,则表示无法确定 CPU 核心数。
# os.cpu_count() 函数在多线程和多进程编程中非常有用,可以帮助开发者了解系统的并发处理能力,从而更好地规划任务分配和资源利用。
# 计算用于数据加载的工作线程数量,考虑到 CPU 核心数和 GPU 数量。
# batch_size if batch_size > 1 else 0 :检查 batch_size 是否大于 1 ,如果是,则使用 batch_size 作为候选值;如果不是(即 batch_size 为 1 或 0 ),则使用 0 作为候选值。这个条件确保了当批量大小不合理时不会分配工作线程。
# 这行代码的目的是平衡 CPU 和 GPU 资源的使用,确保数据加载过程不会超过系统的处理能力,同时考虑到用户设置的工作线程数量。通过这种方式,可以优化数据加载的性能,避免因为过多的工作线程而导致的资源竞争和潜在的性能下降。
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
# torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)
# torch.utils.data.distributed.DistributedSampler 类的构造函数用于创建一个新的分布式采样器实例,它主要用于分布式训练环境中,以确保每个进程只处理数据集的一部分,从而实现数据的均匀分配。
# 参数 :
# dataset ( Dataset ) :要采样的数据集对象。
# num_replicas ( int ,可选) :分布式环境中的总副本(进程)数量。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 world_size 。
# rank ( int ,可选) :当前进程的排名或ID。默认值为 None ,在这种情况下,它会尝试从当前的分布式环境变量中获取 rank 。
# shuffle ( bool ) :是否在每个epoch开始时打乱数据集的采样顺序。默认值为 True 。
# seed ( int ) :用于打乱数据集的随机种子。确保在所有进程中使用相同的种子以获得一致的打乱结果。默认值为 0 。
# drop_last ( bool ) :如果为 True ,则在数据集不能被均匀分配时,丢弃最后一部分数据以确保每个进程处理相同数量的数据。如果为 False ,则可能有些进程会处理更多的数据。默认值为 False 。
# 返回值 :
# 返回一个新的 DistributedSampler 实例。
# DistributedSampler 类在 PyTorch 中用于分布式训练,以下是它的一些常用属性和方法 :
# 属性 :
# dataset : 返回与采样器关联的数据集。
# num_replicas : 返回分布式环境中的总副本(进程)数量。
# rank : 返回当前进程的排名或ID。
# epoch : 返回当前的epoch数。这个属性在每个epoch开始时通过调用 set_epoch() 方法更新。
# 方法 :
# set_epoch(epoch) : 设置当前的epoch数。这对于确保在每个epoch中数据被打乱是必要的,特别是在 shuffle=True 时。
# __iter__() : 返回一个迭代器,该迭代器产生当前epoch中被采样器选中的数据集索引。
# __len__() : 返回当前epoch中被采样器选中的数据集索引的数量。
# update() : 更新采样器的状态,这个方法在 PyTorch 的某些版本中存在,用于重新配置采样器的参数。
# DistributedSampler 的主要作用是确保在分布式训练中,每个进程都能够处理数据集的不同部分,从而提高数据加载的效率和训练的可扩展性。通过在每个epoch开始时调用 set_epoch() 方法,可以确保数据在每个epoch中都被重新打乱,这对于模型的训练是非常重要的。
# 如果不是分布式训练( rank 为 -1 ),则不使用 sampler ;否则,创建一个 DistributedSampler 实例。
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
# 创建一个 PyTorch 随机数生成器实例。
generator = torch.Generator()
# 设置随机数生成器的种子,确保在分布式训练中每个进程的随机性是一致的。
generator.manual_seed(6148914691236517205 + RANK)
# 返回一个 InfiniteDataLoader 实例,它是一个自定义的数据加载器,用于无限循环地加载数据。 InfiniteDataLoader 的参数包括 :
# dataset :数据集实例。
# batch_size :批次大小。
# shuffle :是否打乱数据。
# num_workers :工作线程数量。
# sampler :用于分布式训练的采样器。
# pin_memory :是否将数据加载到 GPU 内存中。
# worker_init_fn :工作线程的初始化函数。
# generator :随机数生成器。
return InfiniteDataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True)
# 这个函数的目的是提供一个灵活且高效的数据加载器,支持数据增强、缓存和分布式训练。通过使用 InfiniteDataLoader ,可以确保数据集在训练过程中无限循环,这对于深度学习训练非常有用。