dataset.py
ultralytics\data\dataset.py
目录
[2.class YOLODataset(BaseDataset):](#2.class YOLODataset(BaseDataset):)
[3.class YOLOMultiModalDataset(YOLODataset):](#3.class YOLOMultiModalDataset(YOLODataset):)
[4.class GroundingDataset(YOLODataset):](#4.class GroundingDataset(YOLODataset):)
[5.class YOLOConcatDataset(ConcatDataset):](#5.class YOLOConcatDataset(ConcatDataset):)
[6.class SemanticDataset(BaseDataset):](#6.class SemanticDataset(BaseDataset):)
[7.class ClassificationDataset:](#7.class ClassificationDataset:)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
load_dataset_cache_file,
save_dataset_cache_file,
verify_image,
verify_image_label,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
2.class YOLODataset(BaseDataset):
python
# 这段代码定义了一个名为 YOLODataset 的类,它继承自 BaseDataset 类。这个类是用于初始化一个与 YOLO 相关的数据集。
# 定义了一个名为 YOLODataset 的新类,它继承自 BaseDataset 。
class YOLODataset(BaseDataset):
# 用于以 YOLO 格式加载对象检测和/或分割标签的数据集类。
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
# 这是 YOLODataset 类的构造函数,它接受任意数量的位置参数( *args ),任意数量的关键字参数( **kwargs ),一个名为 data 的参数(默认值为 None ),以及一个名为 task 的参数(默认值为 "detect" )。
def __init__(self, *args, data=None, task="detect", **kwargs):
# 使用段和关键点的可选配置初始化 YOLODataset。
"""Initializes the YOLODataset with optional configurations for segments and keypoints."""
# 根据传入的 task 参数值,设置 self.use_segments 属性。如果 task 为 "segment" ,则 self.use_segments 为 True ,否则为 False 。
self.use_segments = task == "segment"
# 类似地,根据 task 参数值,设置 self.use_keypoints 属性。如果 task 为 "pose" ,则 self.use_keypoints 为 True ,否则为 False 。
self.use_keypoints = task == "pose"
# 设置 self.use_obb 属性。如果 task 为 "obb" ,则 self.use_obb 为 True ,否则为 False 。
self.use_obb = task == "obb"
# 将传入的 data 参数赋值给 self.data 属性。
self.data = data
# 这是一个断言语句,用于确保 self.use_segments 和 self.use_keypoints 不能同时为 True 。如果同时为 True ,则会抛出异常。
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." # 不能同时使用分割和关键点。
# 调用父类 BaseDataset 的构造函数,并将所有传入的位置参数和关键字参数传递给它。
super().__init__(*args, **kwargs)
# 这个类的主要作用是根据 task 参数的不同值来初始化不同的属性,以支持不同的任务(如目标检测、图像分割、关键点检测等)。同时,它还确保了不能同时使用图像分割和关键点检测两种任务。
# .cache 后缀的文件通常表示这是一个缓存文件,用于存储某些类型的数据,以便在后续的使用中可以快速加载或恢复之前的状态,而不需要重新计算或重新从原始数据源获取数据。缓存文件可以用于多种不同的应用和场景,具体含义取决于它被创建和使用的上下文。
# 这段代码定义了一个名为 cache_labels 的方法,它的作用是扫描图像文件和对应的标签文件,验证它们,并缓存结果。
# 定义了一个方法 cache_labels ,它接受一个参数。
# 1.path :该参数默认为当前目录下的 labels.cache 文件。
def cache_labels(self, path=Path("./labels.cache")):
# 缓存数据集标签,检查图像并读取形状。
"""
Cache dataset labels, check images and read shapes.
Args:
path (Path): Path where to save the cache file. Default is Path('./labels.cache').
Returns:
(dict): labels.
"""
# 初始化一个字典 x ,用于存储标签信息。
x = {"labels": []}
# 初始化一些计数器和消息列表。 nm :缺失的标签数量。 nf :找到的标签数量。 ne :空的标签数量。 nc :损坏的标签数量。 msgs :警告消息列表。
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
# 构造一个描述字符串,用于进度条的描述。
desc = f"{self.prefix}Scanning {path.parent / path.stem}..." # {self.prefix} 正在扫描 {path.parent / path.stem}...
# 获取图像文件的数量,用于进度条的总数。
total = len(self.im_files)
# 从 self.data 中获取关键点的形状信息。 nkpt (关键点的数量), ndim (每个关键点的维度数)。
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
# 这段代码是一个条件检查,用于验证关键点(keypoints)数据的配置是否正确。如果配置不正确,它会抛出一个 ValueError 异常。
# 这一行是一个条件语句,它检查两个条件 :
# self.use_keypoints :这是一个布尔值,表示是否使用关键点数据。如果为 True ,则表示需要检查关键点的配置。
# (nkpt <= 0 or ndim not in {2, 3}) :这是一个复合条件,检查两个子条件 :
# kpt <= 0 :检查 nkpt (关键点的数量)是否小于或等于0,如果是,则表示关键点数量的配置不正确。
# ndim not in {2, 3} :检查 ndim (每个关键点的维度数)是否不是2或3。2表示每个关键点有x和y两个坐标,3表示每个关键点有x、y和可见性三个坐标。如果 ndim 不是这两个值中的任何一个,那么配置也不正确。
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
# 如果上述条件为真,即关键点的使用被启用,但关键点数量配置不正确或维度数不是2或3,那么代码将抛出一个 ValueError 异常。
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " # data.yaml 中的"kpt_shape"缺失或不正确。应为包含
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" # [关键点数量、维度数量(x、y 为 2,x、y、visible 为 3)] 的列表,即"kpt_shape:[17, 3]"。
)
# 使用线程池来并行处理图像和标签的验证工作。
with ThreadPool(NUM_THREADS) as pool:
# pool.imap(func, iterable, chunksize=None)
# imap() 方法用于将一个可迭代的输入序列分块分配到线程池中的线程进行处理,并将结果返回一个迭代器。这个方法特别适用于需要顺序处理输入和输出的场景。
# 参数 :
# func :一个函数,它将被调用并传入 iterable 中的每个项目。
# iterable :一个可迭代对象,其元素将被传递给 func 函数。
# chunksize :(可选)一个整数,指定了每个任务传递给 func 的项目数量。默认值为 1,意味着每个任务只包含一个项目。如果设置为大于 1 的值,那么 func 将接收到一个包含多个项目的列表。
# 返回值 :
# 返回一个 Iterator ,它生成每个输入元素经过 func 处理后的结果。
# 特点 :
# 结果的顺序与输入序列的顺序相同。
# 如果任何一个任务因为异常而终止, imap() 会立即抛出异常。
# 它允许主线程在子线程完成工作之前继续执行,而不是等待所有任务完成。
# ThreadPool.imap() 是处理 I/O 密集型任务或者需要顺序处理结果的并发任务的有用工具。与之相对的是 imap_unordered() ,它同样返回一个迭代器,但是结果的顺序可能与输入序列不同,适用于不在乎结果顺序的场景。
# 使用 imap 方法来映射 verify_image_label 函数到参数列表上,这些参数包括图像文件、标签文件、前缀、是否使用关键点、类别数量、关键点数量和维度。
# 这段代码使用了 concurrent.futures.ThreadPoolExecutor 类的 imap 方法来并行处理图像和标签文件的验证工作。
# results = pool.imap(...) :这行代码调用了线程池 pool 的 imap 方法,它将返回一个迭代器 results ,该迭代器生成 verify_image_label 函数的返回值。
results = pool.imap(
# func=verify_image_label :func 参数指定了要并行执行的函数,这里是 verify_image_label 函数。
# def verify_image_label(args):
# -> 用于验证图像和对应的标签文件是否有效。返回一个元组,包含以下信息。im_file :图像文件路径。 lb :包含类别和边界框坐标的数组。 shape :图像的形状(高度,宽度)。 segments :段落数据(如果有)。
# -> keypoints :包含关键点坐标和掩码的数组。 nm :缺失的标签数量。 nf :找到的标签数量。 ne :空的标签数量。 nc :损坏的标签数量。 msg :任何相关的消息或警告。
# -> return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg / return [None, None, None, None, None, nm, nf, ne, nc, msg]
func=verify_image_label,
# 解包参数。
# im_file :图像文件的路径。
# lb_file :标签文件的路径。
# prefix :一个前缀字符串,通常用于打印消息时标识来源。
# keypoint :一个布尔值,表示是否处理关键点数据。
# num_cls :数据集中类别的数量。
# nkpt :关键点的数量(如果 keypoint 为 True )。
# ndim :每个关键点的维度数(如果 keypoint 为 True )。
# im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
# iterable=zip(...) : iterable 参数指定了一个可迭代对象,它是由 zip 函数创建的。 zip 函数将多个可迭代对象打包成一个元组的迭代器,这里的元组包含了 verify_image_label 函数所需的参数。
iterable=zip(
# 这是一个包含图像文件路径的列表。
self.im_files,
# 这是一个包含对应标签文件路径的列表。
self.label_files,
# itertools.repeat(object[, times])
# repeat() 函数是 Python 标准库 itertools 模块中的一个函数,它用于创建一个迭代器,该迭代器会无限次重复给定的值。
# 参数 :
# object :要重复的值。
# times :(可选)重复的次数。如果不提供或为 None ,则迭代器将无限重复给定的值。
# 返回值 :
# 返回一个迭代器,该迭代器重复给定的值。
# repeat() 函数常用于需要固定值的场景,例如在 zip() 函数中为每个元素对提供相同的参数,或者在其他需要重复值的迭代处理中。
# itertools.repeat 函数创建一个迭代器,它无限重复给定的值,这里是 self.prefix 。这个前缀将被用于每个图像-标签对的验证。
repeat(self.prefix),
# 同样, self.use_keypoints 指示是否处理关键点数据,它将被重复用于每个图像-标签对。
repeat(self.use_keypoints),
# 这个参数指定了数据集中的类别数量,它将被重复用于每个图像-标签对。
repeat(len(self.data["names"])),
# nkpt 是关键点的数量,如果处理关键点数据,它将被重复用于每个图像-标签对。
repeat(nkpt),
# ndim 是每个关键点的维度数,如果处理关键点数据,它将被重复用于每个图像-标签对。
repeat(ndim),
),
)
# zip 函数将这些可迭代对象组合成一个迭代器,每次迭代产生一个包含所有参数的元组,这些元组将被传递给 verify_image_label 函数。 imap 方法将这些元组分配给线程池中的线程,以并行方式执行 verify_image_label 函数。
# 最终, results 迭代器将生成每个图像-标签对验证的结果,这些结果可以被用来更新进度条、记录日志或进行进一步的处理。
# 创建一个进度条对象 pbar ,用于显示进度。
# 这段代码使用了 TQDM 类来创建一个进度条对象 pbar ,用于显示 verify_image_label 函数处理图像和标签对的进度。
# 创建一个 TQDM 进度条对象 pbar 。 results : imap 方法返回的迭代器,它生成 verify_image_label 函数的返回值。 desc :进度条的描述字符串,通常是正在执行的任务的描述。 total :总的任务数量,用于设置进度条的总长度。
pbar = TQDM(results, desc=desc, total=total)
# 遍历 pbar 进度条对象,它将迭代 results 中的每个元素。 每次迭代返回一个元组,包含 verify_image_label 函数的返回值 :
# im_file :图像文件路径。
# lb :包含类别和边界框坐标的数组。
# shape :图像的形状(高度,宽度)。
# segments :段落数据(如果有)。
# keypoints :包含关键点坐标和掩码的数组。
# nm_f :缺失的标签数量。
# nf_f :找到的标签数量。
# ne_f :空的标签数量。
# nc_f :损坏的标签数量。
# msg :任何相关的消息或警告。
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
# 将当前迭代的 缺失标签 数量 nm_f 加到总的缺失标签数量 nm 上。
nm += nm_f
# 将当前迭代的 找到的标签 数量 nf_f 加到总的找到的标签数量 nf 上。
nf += nf_f
# 将当前迭代的 空标签 数量 ne_f 加到总的空标签数量 ne 上。
ne += ne_f
# 将当前迭代的 损坏标签 数量 nc_f 加到总的损坏标签数量 nc 上。
nc += nc_f
# 检查 im_file 是否非空,即是否有有效的图像文件路径。
if im_file:
# 如果有有效的图像文件路径,将标签信息追加到字典 x 的 "labels" 键对应的列表中。
# 标签信息包括 图像文件路径 、 形状 、 类别 、 边界框 、 段落 、 关键点 、 归一化标志 和 边界框格式 。
x["labels"].append(
{
"im_file": im_file,
"shape": shape,
"cls": lb[:, 0:1], # n, 1
"bboxes": lb[:, 1:], # n, 4
"segments": segments,
"keypoints": keypoint,
"normalized": True,
"bbox_format": "xywh",
}
)
# 检查是否有警告消息 msg 。
if msg:
# 如果有警告消息,将其追加到 msgs 列表中。
msgs.append(msg)
# 更新进度条的描述,显示已处理的图像数量、背景数量和损坏的标签数量。
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" # {desc} {nf} 图像,{nm + ne} 背景,{nc} 损坏。
# 这段代码的主要作用是遍历 verify_image_label 函数的处理结果,更新缺失、找到、空和损坏的标签数量,收集标签信息和警告消息,并实时更新进度条的描述。这样可以实现对长时间运行的任务的可视化,让用户了解任务的执行进度。
# 关闭进度条。
pbar.close()
# 如果有警告消息,使用 LOGGER 记录它们。
if msgs:
LOGGER.info("\n".join(msgs))
# 如果没有找到任何标签,记录一个警告。
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") # {self.prefix}警告 ⚠️ 在 {path} 中未找到标签。{HELP_URL}。
# 计算标签文件和图像文件的哈希值,并存储在 x 字典中。
# def get_hash(paths): -> 用于计算一个包含文件或目录路径列表的单一哈希值。计算最终的哈希值,并以十六进制格式返回。 -> return h.hexdigest() # return hash
x["hash"] = get_hash(self.label_files + self.im_files)
# 存储结果统计信息。
# 在 cache_labels 方法中设置 x 字典的 "results" 键的值。 x 字典被用来存储和返回关于图像文件和标签文件验证结果的统计信息。具体来说, "results" 键被赋值为一个包含五个元素的元组 :
# nf :找到的标签数量(number of found labels)。
# nm :缺失的标签数量(number of missing labels)。
# ne :空的标签数量(number of empty labels)。
# nc :损坏的标签数量(number of corrupt labels)。
# len(self.im_files) :图像文件的总数。
# 这个元组提供了一个快照,显示了在扫描和验证过程中发现的不同类别的标签文件的数量,以及处理的图像文件的总数。这样的统计信息对于调试和验证数据集的完整性非常有用。
x["results"] = nf, nm, ne, nc, len(self.im_files)
# 存储警告消息。
x["msgs"] = msgs # warnings
# 保存缓存文件。
# def save_dataset_cache_file(prefix, path, x, version): -> 用于将一个名为 x 的数据集缓存字典保存到指定的路径。
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
# 返回包含标签信息的字典。
return x
# 这个方法的主要作用是验证图像和标签文件的一致性,并将结果缓存起来,以便后续使用。它使用了多线程来加速处理过程,并提供了进度条来显示当前的进度。
# 这段代码定义了一个名为 get_labels 的方法,它用于返回 YOLO 训练所需的标签字典。
# 定义了一个实例方法 get_labels ,没有额外参数。
def get_labels(self):
# 返回 YOLO 训练的标签字典。
"""Returns dictionary of labels for YOLO training."""
# 调用 img2label_paths 函数将图像文件路径转换为对应的标签文件路径,并更新 self.label_files 。
# def img2label_paths(img_paths): -> 它将图像文件路径转换为对应的标签文件路径。用于转换每个图像路径 x 到对应的标签路径。 -> return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
self.label_files = img2label_paths(self.im_files)
# 构造缓存文件路径,基于第一个标签文件的父目录,并添加 .cache 后缀。
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
# 开始一个 try 块,尝试加载缓存文件。
try:
# 尝试从 cache_path 加载缓存文件,并设置 exists 为 True 。
# def load_dataset_cache_file(path): -> 用于从指定路径加载 Ultralytics 的 .cache 字典文件。返回加载的缓存字典。 -> return cache
cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
# 确保缓存的版本与当前版本 DATASET_CACHE_VERSION 匹配。
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
# 确保缓存的哈希值与当前标签文件和图像文件的哈希值匹配。
# def get_hash(paths): -> 用于计算一个包含文件或目录路径列表的单一哈希值。计算最终的哈希值,并以十六进制格式返回。 -> return h.hexdigest() # return hash
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
# 如果在尝试加载缓存时发生任何异常(文件未找到、断言错误或属性错误),则执行 except 块的代码。
except (FileNotFoundError, AssertionError, AttributeError):
# 调用 self.cache_labels 方法生成新的缓存,并设置 exists 为 False 。
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
# 从缓存中提取结果统计信息,并更新缓存字典。
# 这行代码是从一个名为 cache 的字典中提取关键的统计数据,这些数据关于数据集的标签和图像文件的验证结果。 cache 字典在此之前已经被填充了相应的数据,通常在处理或扫描数据集时生成。
# nf :代表 "number found",即 找到的标签 数量。这通常是指向数据集中存在有效标签文件的图像文件的数量。
# nm :代表 "number missing",即 缺失的标签 数量。这表示数据集中没有找到对应标签文件的图像文件的数量。
# ne :代表 "number empty",即 空的标签 数量。这指的是标签文件存在但为空的情况。
# nc :代表 "number corrupt",即 损坏的标签 数量。这表示标签文件存在但损坏或无法读取的情况。
# n :代表总的 图像文件数量 ,通常是 nf 、 nm 、 ne 和 nc 的总和。
# cache.pop("results") 是 Python 字典的一个方法,它移除并返回字典中键为 "results" 的值。这个方法调用后, "results" 键将从 cache 字典中删除。
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
# 如果缓存存在且本地等级为 -1 或 0 (通常用于单GPU或CPU环境),则显示扫描结果。
# LOCAL_RANK -> 用于在分布式训练环境中确定当前进程的本地排名(LOCAL_RANK)。
if exists and LOCAL_RANK in {-1, 0}:
# 构造描述字符串,显示扫描的 文件路径 、 找到的图像数量 、 背景数量 和 损坏的数量 。
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" # 扫描 {cache_path}...{nf} 幅图像、{nm + ne} 幅背景、{nc} 幅损坏图像。
# 使用 TQDM 显示结果进度条。
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
# 如果缓存中有警告消息,则使用 LOGGER 记录它们。
if cache["msgs"]:
# str.join(iterable)
# 在Python中, join() 是字符串( str )对象的一个方法,它用于将序列中的元素连接成一个新的字符串。这个方法将序列中的每个元素连接起来,并可以在它们之间插入一个指定的分隔符。
# 参数 :
# str :这是调用 join() 方法的字符串对象,它将被用作分隔符。
# iterable :这是一个可迭代对象,比如列表( list )、元组( tuple )、字符串( str )等,其元素将被连接。
# 返回值 :
# join() 方法返回一个新的字符串,该字符串是由 iterable 中的元素按照它们在序列中的顺序连接而成的,元素之间用 str 指定的分隔符分隔。
# 注意事项 :
# 如果 iterable 中包含非字符串类型的元素,将会引发 TypeError 。如果需要连接非字符串类型的元素,可以先将它们转换为字符串,然后再使用 join() 方法。
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# Read cache
# 从缓存中移除不需要的键值对。
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
# 从缓存中提取标签列表。
labels = cache["labels"]
# 如果没有标签,记录警告消息。
if not labels:
LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ⚠️ 在 {cache_path} 中未找到图像,训练可能无法正常工作。{HELP_URL}。
# 更新 self.im_files 为缓存中的图像文件路径。
self.im_files = [lb["im_file"] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
# 计算每个标签中的 类别 、 边界框 和 段落 的数量。
lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
# 计算总的 类别 、 边界框 和 段落 数量。
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
# 如果数据集中同时包含 边界框 和 段落 ,但数量不一致,则发出警告并移除所有段落。
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " # 警告 ⚠️ 框和段数应该相等,但得到的 len(segments) = {len_segments},
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " # len(boxes) = {len_boxes}。为了解决这个问题,将只使用边界框,所有段都将被删除。
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." # 为了避免这种情况,请提供检测或分割数据集,而不是检测-分割混合数据集。
)
for lb in labels:
lb["segments"] = []
# 如果没有类别标签,记录警告消息。
if len_cls == 0:
LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}") # 警告 ⚠️ 在 {cache_path} 中未找到标签,训练可能无法正常工作。{HELP_URL}。
# 返回标签列表。
return labels
# 这个方法的主要作用是加载或生成数据集的缓存,并从中提取标签信息,以便用于 YOLO 训练。它还负责检查缓存的版本和完整性,并在必要时更新图像文件路径。
# 这段代码定义了一个名为 build_transforms 的方法,它用于构建图像变换(transforms)并将其追加到一个列表中。这些变换通常用于数据预处理、数据增强或模型推理前的图像处理。
# 定义了一个实例方法 build_transforms ,它接受一个可选参数。
# 1.hyp :这个参数通常包含超参数或配置设置。
def build_transforms(self, hyp=None):
# 构建转换并将其附加到列表。
"""Builds and appends transforms to the list."""
# 检查实例变量 self.augment 是否为 True ,这通常表示是否启用数据增强。
if self.augment:
# 如果启用数据增强且不是矩形训练( self.rect 为 False ),则设置 hyp.mosaic 的值;否则,将其设置为 0.0 。
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
# 类似地,如果启用数据增强且不是矩形训练,设置 hyp.mixup 的值;否则,将其设置为 0.0 。
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
# 如果启用数据增强,调用 v8_transforms 函数来构建一系列变换,这些变换可能包括多种数据增强技术。
# def v8_transforms(dataset, imgsz, hyp, stretch=False):
# -> 用于构建一个图像增强的流程,通常用于计算机视觉任务,如目标检测或图像分类。创建了最终的变换流程,包括预变换、混合增强(MixUp)、随机HSV颜色变换、垂直翻转和水平翻转。水平翻转使用了 flip_idx 来确定哪些关键点需要翻转。
# -> return Compose([pre_transform, MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), Albumentations(p=1.0), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
# RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),]
transforms = v8_transforms(self, self.imgsz, hyp)
# 如果没有启用数据增强,使用 Compose 函数创建一个只包含 LetterBox 变换的变换列表,用于将图像调整到模型输入尺寸。
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
# 向 transforms 列表追加一个 Format 变换,这个变换用于格式化图像和标注数据,包括归一化、返回掩码、返回关键点等。
transforms.append(
# Format 变换的参数解释 :
# bbox_format="xywh" :指定边界框的格式为 (x, y, width, height) 。
# normalize=True :表示需要对图像进行归一化处理。
# return_mask=self.use_segments :如果 self.use_segments 为 True ,则返回掩码。
# return_keypoint=self.use_keypoints :如果 self.use_keypoints 为 True ,则返回关键点。
# return_obb=self.use_obb :如果 self.use_obb 为 True ,则返回方向边界框。
# batch_idx=True :表示是否返回批次索引。
# mask_ratio=hyp.mask_ratio :掩码的比例。
# mask_overlap=hyp.overlap_mask :掩码的重叠度。
# bgr=hyp.bgr if self.augment else 0.0 :如果启用数据增强,则使用 hyp.bgr 值;否则,设置为 0.0 。
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
return_obb=self.use_obb,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
bgr=hyp.bgr if self.augment else 0.0, # only affect training.
)
)
# 返回构建好的变换列表。
return transforms
# 这个方法的主要作用是根据配置和数据增强的需求,构建一个图像变换的流程,以便在训练或推理前对图像进行适当的处理。
# 这段代码定义了一个名为 close_mosaic 的方法,它用于设置特定的数据增强选项(mosaic、copy_paste 和 mixup)为 0.0,即关闭这些数据增强技术,并且重新构建图像变换列表。
# 定义了一个实例方法 close_mosaic ,它接受一个参数。
# 1.hyp :这个参数通常包含超参数或配置设置。
def close_mosaic(self, hyp):
# 将马赛克、复制粘贴和混合选项设置为 0.0 并构建转换。
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
# 设置 hyp 中的 mosaic 选项为 0.0,这意味着关闭 mosaic 数据增强技术。
hyp.mosaic = 0.0 # set mosaic ratio=0.0
# 设置 hyp 中的 copy_paste 选项为 0.0,这意味着关闭 copy_paste 数据增强技术。
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
# 设置 hyp 中的 mixup 选项为 0.0,这意味着关闭 mixup 数据增强技术。
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
# 调用 build_transforms 方法,并传入更新后的 hyp 参数,以重新构建图像变换列表,并将结果赋值给 self.transforms 。
self.transforms = self.build_transforms(hyp)
# 这个方法的主要作用是在不需要使用 mosaic、copy_paste 和 mixup 数据增强技术时,提供一个方便的方式来关闭这些技术,并且确保图像变换列表与当前的配置设置相匹配。
# 这对于实验不同的数据增强策略或者在不同阶段(如训练和验证)使用不同的数据增强设置非常有用。通过关闭这些数据增强技术,可以确保模型在推理或验证阶段不会受到这些增强技术的影响。
# 这段代码定义了一个名为 update_labels_info 的方法,它用于更新标签信息,将提取的 边界框 、 段落 和 关键点数据 整合到一个 Instances 对象中,并返回更新后的标签字典。
# 定义了一个实例方法 update_labels_info ,它接受一个参数。
# 1.label :这是一个包含标签信息的字典。
def update_labels_info(self, label):
# 在此处自定义您的标签格式。
# 注意:
# cls 现在不包含 bboxes ,分类和语义分割需要独立的 cls 标签。
# 还可以通过添加或删除字典键来支持分类和语义分割。
"""
Custom your label format here.
Note:
cls is not with bboxes now, classification and semantic segmentation need an independent cls label
Can also support classification and semantic segmentation by adding or removing dict keys there.
"""
# 从 label 字典中提取边界框数据,并从字典中移除该键值对。
bboxes = label.pop("bboxes")
# 从 label 字典中提取段落数据,默认值为空列表。
segments = label.pop("segments", [])
# 从 label 字典中提取关键点数据,默认值为 None 。
keypoints = label.pop("keypoints", None)
# 从 label 字典中提取边界框格式。
bbox_format = label.pop("bbox_format")
# 从 label 字典中提取归一化标志。
normalized = label.pop("normalized")
# NOTE: do NOT resample oriented boxes
# 根据是否使用 方向边界框 (Oriented Bounding Boxes, OBB),设置段落重采样的数量。
segment_resamples = 100 if self.use_obb else 1000
# 检查是否有段落数据。
if len(segments) > 0:
# list[np.array(1000, 2)] * num_samples
# (N, 1000, 2)
# 如果有段落数据,调用 resample_segments 函数对段落进行重采样,并将结果堆叠成一个 NumPy 数组。
# def resample_segments(segments, n=1000): -> 用于对输入的线段进行重新采样,使得每个线段包含 n 个点。返回重新采样后的线段数组。 -> return segments
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
# 如果没有段落数据,创建一个填充零的 NumPy 数组。
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
# 创建一个 Instances 对象,包含边界框、段落、关键点、边界框格式和归一化标志,并将该对象赋值给 label 字典中的 "instances" 键。
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
# 返回更新后的标签字典。
return label
# 这个方法的主要作用是将标签信息中的边界框、段落和关键点数据整合到一个 Instances 对象中,这样可以更方便地在后续的处理中使用这些数据。 Instances 对象通常用于表示图像中的实例,包括它们的边界框、关键点和可能的其他属性。通过重采样段落数据,可以确保数据的一致性和格式的正确性。
# 这段代码定义了一个名为 collate_fn 的静态方法,它用于将多个数据样本(通常从数据加载器中获取)合并成一批数据,以便进行批量处理。这个方法通常用在 PyTorch 的 DataLoader 中,作为其 collate_fn 参数。
# 表示 collate_fn 是一个静态方法,它不需要访问类的实例变量。
@staticmethod
# 定义了一个方法 collate_fn ,它接受一个参数。
# 1.batch :这是一个列表,包含了多个数据样本(字典)。
def collate_fn(batch):
# 将数据样本整理成批。
"""Collates data samples into batches."""
# 初始化一个空字典 new_batch ,用于存储合并后的数据。
new_batch = {}
# 获取第一个数据样本的键(key),这些键代表了数据样本中的不同字段。
# 在计算机科学和数据处理中,"字段"(Field)通常指的是数据结构中的一个元素,它存储了数据记录中的一个特定信息项。字段可以被视为数据的最小单位,它们组成了更大的数据实体,如记录或对象。
# 字段是数据的基本构建块,它们定义了数据的结构和内容,使得数据可以被有效地存储、访问和处理。在不同的上下文中,字段可能有不同的名称,如属性(Attribute)、列(Column)、参数(Parameter)等。
# 在数据处理和机器学习中,数据样本通常被组织为字典(dictionary)格式,其中每个键(key)对应于一个特定的数据字段。这种组织方式允许我们将不同类型的数据关联在一起,每个字段包含有关样本的不同方面的信息。
# 以下是为什么第一个数据样本的键(key)代表了数据样本中的不同字段的几个原因 :
# 标准化结构 :
# 在许多数据处理流程中,所有数据样本都遵循相同的结构。这意味着每个样本都包含相同的字段集合,无论它们的具体值如何。通过查看第一个样本的键,我们可以了解所有样本的字段结构。
# 一致性 :
# 为了确保数据处理的一致性,所有样本都预期具有相同的字段。这允许处理函数(如 collate_fn )在不额外检查每个样本的情况下,对任何样本执行相同的操作。
# 批处理 :
# 在批处理数据时,我们通常将多个样本组合在一起进行处理。这些样本的键集合提供了一个蓝图,指示如何处理和组织批量数据。例如,在PyTorch的 DataLoader 中, collate_fn 需要知道如何将不同字段的数据堆叠或连接起来。
# 数据加载器的假设 :
# 数据加载器(如PyTorch的 DataLoader )通常假设所有样本都具有相同的键集合。这种假设允许数据加载器在内部优化内存使用和数据处理。
# 错误检查 :
# 如果所有样本都预期具有相同的字段,那么检查第一个样本的键集合可以作为一种错误检查机制。如果后续样本缺少任何预期的键,这可能表明数据存在问题。
# 灵活性 :
# 即使在实际应用中,某些样本可能缺少某些字段或包含额外的字段,处理函数通常设计为只处理存在的键。这意味着即使数据样本在字段上有所不同,只要它们至少包含第一个样本的所有键,处理函数仍然可以正常工作。
# 因此,第一个数据样本的键(key)代表了数据样本中的不同字段,这是数据处理和机器学习中常见的一种假设和实践,它有助于确保数据的一致性和处理流程的标准化。
# 这行代码是在 collate_fn 函数中获取批处理数据中第一个样本的键(key)集合。
# batch :这是一个列表,包含了多个数据样本,每个样本通常是一个字典(dictionary),包含了图像和其对应的标签信息。
# batch[0] :通过索引 [0] 访问列表中的第一个数据样本,这个样本是一个字典。
# .keys() :这是一个字典方法,用于获取字典中所有的键(key),并以视图的形式返回。这个视图显示了字典的键,但不包括值。
# keys = batch[0].keys() :将第一个数据样本的键集合赋值给变量 keys 。
# 这个变量将在后续的代码中用于迭代和访问每个样本中对应的值(value)。这行代码的作用是为后续的数据处理提供一个字段列表,使得 collate_fn 函数能够按字段处理每个样本的数据,例如合并图像张量或连接标签张量。
# 通过首先获取键集合, collate_fn 函数可以确保它能够正确地处理每个样本中的所有数据,无论样本中包含哪些字段。
keys = batch[0].keys()
# 将每个数据样本的值(value)提取出来,并使用 zip 函数将它们按字段分组。
values = list(zip(*[list(b.values()) for b in batch]))
# 遍历每个字段。
for i, k in enumerate(keys):
# 获取当前字段的所有值。
value = values[i]
# 如果字段是图像数据,使用 torch.stack 将它们堆叠成一个多维张量。
if k == "img":
value = torch.stack(value, 0)
# 如果字段是掩码、关键点、边界框、类别、段落或方向边界框,使用 torch.cat 将它们连接成一个张量。
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
# 将处理后的值赋给 new_batch 字典中对应的字段。
new_batch[k] = value
# 将 batch_idx 转换为列表,以便进行索引调整。
new_batch["batch_idx"] = list(new_batch["batch_idx"])
# 遍历 batch_idx 列表,为每个样本添加目标图像索引。
for i in range(len(new_batch["batch_idx"])):
# batch_idx 通常表示"批次索引"(batch index),它是一个用于标识每个样本在当前批次中位置的索引。
# 更新 batch_idx ,为每个样本添加其在批次中的位置索引。
# 这行代码将每个 "batch_idx" 元素的值增加它在列表中的索引 i 。
# new_batch["batch_idx"] :这是 new_batch 字典中的一个键,它对应的值是一个列表,包含了每个样本在原始数据批次中的索引。
# new_batch["batch_idx"][i] :这是访问 "batch_idx" 列表中的第 i 个元素。
# new_batch["batch_idx"][i] += i :这行代码将列表中第 i 个元素的值增加 i 。这样做的目的是为了为每个样本分配一个唯一的索引,这个索引反映了样本在新批次中的相对位置。
# 假设我们有一个包含多个样本的批次,每个样本都有一个与之关联的 batch_idx ,表示该样本在原始数据集中的索引。当我们使用 collate_fn 函数将这些样本合并成一个批次时,我们可能需要保留这些原始索引信息,以便在后续的处理中使用。
# 这行代码的目的是将每个样本的原始索引(存储在 new_batch["batch_idx"] 中)与它在新批次中的偏移量(即 i )相加。这样做的原因是为了调整每个样本的索引,使其反映在新批次中的位置。
# 使用 new_batch["batch_idx"][i] += i 更新 batch_idx 列表并不是绝对必须的,但它是一种常见且有效的方式来为每个样本在新批次中分配一个唯一的索引。这个操作的目的是确保每个样本的索引在新批次中是连续的,并且能够反映它们在原始数据集中的位置。
# 为什么使用 new_batch["batch_idx"][i] += i ? 这种特定的更新方式( += i )确保了每个样本的索引在新批次中是唯一的,并且与它们在原始批次中的位置相对应。
# 使用其他方法,例如 new_batch["batch_idx"][i] += 2i ? 使用 new_batch["batch_idx"][i] += 2i 或其他类似的方法是可能的,但它会改变样本索引的分布。
# 这样,样本的索引将不再是连续的,而是以2的倍数递增。这种更新方式可能会带来一些问题 :
# 非连续索引:样本的索引将不再是连续的,这可能会导致一些依赖于连续索引的操作出现问题。
# 索引间隔:样本的索引间隔将增加,这可能会影响某些算法的性能,特别是那些需要索引连续性或特定间隔的算法。
# 复杂性增加:这种方法增加了代码的复杂性,因为它不再直观地反映样本在原始批次中的位置。
# 结论 :
# 虽然技术上可以使用 new_batch["batch_idx"][i] += 2i 或其他类似的更新方式,但通常不建议这样做,除非你有特定的理由需要这样做。
# 在大多数情况下,使用 new_batch["batch_idx"][i] += i 是更好的选择,因为它保持了样本索引的连续性和直观性,这对于大多数数据处理和机器学习任务来说是有利的。
new_batch["batch_idx"][i] += i # add target image index for build_targets()
# 将更新后的 batch_idx 列表连接成一个张量。
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
# 返回合并后的批次数据。
return new_batch
# 这个方法的主要作用是将多个数据样本合并成一个批次,以便进行批量处理。它处理了不同类型的数据,如图像、掩码、关键点等,并确保它们被正确地堆叠或连接。此外,它还更新了 batch_idx ,这对于后续的目标构建(如在目标检测中)非常重要。
3.class YOLOMultiModalDataset(YOLODataset):
python
# 这段代码定义了一个名为 YOLOMultiModalDataset 的类,它是 YOLODataset 类的子类。这个类继承了 YOLODataset 的功能,并添加了对多模态模型训练的支持,特别是处理文本信息。
# 多模态训练(Multimodal Training)是指在机器学习和人工智能领域中,同时处理和整合来自多种不同类型数据源或模态的信息的训练过程。在现实世界中,信息往往不是单一模态的,而是以多种形式出现,例如视觉(图像)、文本、音频和传感器数据等。
# 多模态训练的目的是要让模型能够理解和利用这些不同类型的数据,以提高其性能和泛化能力。以下是多模态训练的一些关键特点和应用场景 :
# 数据融合 :
# 多模态训练涉及将不同模态的数据融合在一起,以便模型可以从多个角度理解数据。例如,结合图像和相关文本描述来提高图像识别的准确性。
# 特征提取 :
# 在多模态训练中,模型需要从每种模态中提取特征,并将这些特征整合起来,以形成更全面的数据处理。
# 上下文理解 :
# 多模态数据可以提供更丰富的上下文信息。例如,视频理解任务中,结合视觉信息和音频信息可以更好地理解视频内容。
# 鲁棒性提升 :
# 通过整合多种模态的数据,模型可以变得更加鲁棒,因为一种模态的数据缺失或不准确时,其他模态可以提供补充信息。
# 应用场景 :
# 多模态训练在许多领域都有应用,包括自动驾驶(结合视觉和雷达数据)、情感分析(结合语音和文本)、健康医疗(结合医疗影像和电子健康记录)等。
# 挑战 :
# 多模态数据的处理面临一些挑战,如不同模态数据的同步、特征空间的差异、数据标注的复杂性等。
# 模型设计 :
# 多模态训练需要特别设计的模型架构,这些架构能够处理不同类型的数据,并有效地整合这些数据。
# 数据增强 :
# 在多模态训练中,可以通过增强一种模态的数据来提高模型对其他模态数据的泛化能力。
# 多模态训练是人工智能领域的一个重要研究方向,它旨在使模型能够更全面地理解和处理复杂的、多模态的世界。随着技术的发展,多模态训练的方法和应用将越来越多样化。
# 这行代码声明了一个名为 YOLOMultiModalDataset 的新类,它继承自 YOLODataset 类。
class YOLOMultiModalDataset(YOLODataset):
# 用于以 YOLO 格式加载对象检测和/或分割标签的数据集类。
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
# 这是 YOLOMultiModalDataset 类的构造函数,它接受任意数量的位置参数 *args 、关键字参数 **kwargs ,以及两个具有默认值的命名参数: data=None 和 task="detect" 。
# 1.*args :允许传递任意数量的位置参数到父类的构造函数。
# 2.data=None :一个可选参数,用于传递数据集的配置或数据信息。默认值为 None 。
# 3.task="detect" :一个可选参数,用于指定数据集的任务类型,例如目标检测。默认值为字符串 "detect" 。
# 4.**kwargs :允许传递任意数量的关键字参数到父类的构造函数。
def __init__(self, *args, data=None, task="detect", **kwargs):
# 使用可选规范初始化用于对象检测任务的数据集对象。
"""Initializes a dataset object for object detection tasks with optional specifications."""
# 这行代码调用了父类 YOLODataset 的构造函数,并将所有传入的位置参数 *args 、命名参数 data 和 task ,以及关键字参数 **kwargs 传递给它。
# super() 函数用于访问父类(在这种情况下是 YOLODataset )的属性和方法。
super().__init__(*args, data=data, task=task, **kwargs)
# 这个构造函数的作用是初始化 YOLOMultiModalDataset 对象,设置其属性,并准备数据集以供后续的目标检测或其他相关任务使用。通过继承 YOLODataset , YOLOMultiModalDataset 能够利用父类中定义的方法和属性,同时添加或修改特定于多模态任务的功能。
# 这段代码定义了一个名为 update_labels_info 的方法,它是 YOLOMultiModalDataset 类中的一个成员函数。这个方法的作用是扩展父类 YOLODataset 中的 update_labels_info 方法,以添加文本信息,这些信息对于多模态模型训练是必要的。
# 定义了一个实例方法 update_labels_info ,它接受一个参数。
# 1.label :这个参数通常是一个包含标签信息的字典。
def update_labels_info(self, label):
# 添加用于多模态模型训练的文本信息。
"""Add texts information for multi-modal model training."""
# 调用父类 YOLODataset 的 update_labels_info 方法,并传入 label 参数。这行代码的目的是获取父类处理后的标签信息,并将其存储在变量 labels 中。
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`. 注意:一些类别与其同义词通过"/"连接。
# dict.items()
# 在Python中, items() 是一个字典( dict )对象的方法,它用于返回一个包含字典中所有键值对的视图对象。这个视图对象可以被用来迭代字典中的键值对。
# 返回值 :
# 一个包含字典中所有键值对的 dict_items 视图对象。
# 特点 :
# dict_items 视图对象是动态的,这意味着在迭代过程中如果字典被修改,迭代器也会反映这些变化。
# items() 方法返回的视图对象不支持索引操作,它只能被用来迭代。
# 与其他方法的比较 :
# keys() :返回一个包含字典中所有键的视图对象。
# values() :返回一个包含字典中所有值的视图对象。
# items() 方法是处理字典时常用的一个工具,它允许你同时访问键和值,这在需要对字典中的每个元素进行操作时非常有用。
# 这行代码创建一个新的列表,其中包含 self.data["names"] 字典中每个值(代表类别名称)按 / 分隔的同义词。
# self.data["names"] 是一个字典,其中键是类别索引,值是类别名称。
# .items() 方法将字典转换为键值对的列表。
# 列表推导式 [v.split("/") for _, v in self.data["names"].items()] 遍历每个键值对,并将值(类别名称)按 / 分隔成列表。
# 最后,这个列表被赋值给 labels 字典中的 "texts" 键。
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
# 返回更新后的 labels 字典,其中包含了原始标签信息以及新增的文本信息。
return labels
# 这个方法的主要作用是为多模态模型训练提供额外的文本信息,这些信息可能包括类别名称及其同义词。这对于模型理解不同的类别表示和上下文信息非常重要,尤其是在处理图像和文本数据时。通过这种方式, YOLOMultiModalDataset 类能够支持更复杂的数据集,这些数据集不仅包含视觉信息,还包含文本信息。
# 这段代码定义了一个名为 build_transforms 的方法,它是 YOLOMultiModalDataset 类中的一个成员函数。这个方法的作用是构建和增强数据变换,特别是为多模态训练添加可选的文本增强。
# 定义了一个实例方法 build_transforms ,它接受一个可选参数.
# 1.hyp :这个参数通常包含超参数或配置设置。
def build_transforms(self, hyp=None):
# 通过可选的文本增强来增强数据转换,以进行多模式训练。
"""Enhances data transformations with optional text augmentation for multi-modal training."""
# 调用父类 YOLODataset 的 build_transforms 方法,并传入 hyp 参数。这行代码的目的是获取父类构建的基本数据变换列表,并将其存储在变量 transforms 中。
transforms = super().build_transforms(hyp)
# 检查实例变量 self.augment 是否为 True ,这通常表示是否启用数据增强。
if self.augment:
# NOTE: hard-coded the args for now.
# 如果启用数据增强,使用 insert 方法在 transforms 列表的末尾插入一个新的变换对象 RandomLoadText 。
# RandomLoadText 是一个自定义的变换类,它用于加载文本数据并进行随机采样。
# max_samples 参数设置为 self.data["nc"] (数据集中的类别数量)和 80 之间的最小值,这限制了采样的最大样本数。
# padding 参数设置为 True ,表示需要对文本数据进行填充处理。
# class RandomLoadText:
# -> 用于处理文本数据的加载。
# -> def __init__(self, prompt_format: str = "{}", neg_samples: Tuple[int, int] = (80, 80), max_samples: int = 80, padding: bool = False, padding_value: str = "", ) -> None:
# -> RandomLoadText 类的 __call__ 方法,该方法用于处理和随机选择文本标签。
# -> def __call__(self, labels: dict) -> dict:
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
# 返回更新后的数据变换列表,其中包含了父类的基本变换以及新增的文本增强变换。
return transforms
# 这个方法的主要作用是为多模态模型训练提供数据变换的支持,特别是在需要处理图像和文本数据时。通过添加文本增强,模型可以学习从文本中提取特征,这对于提高模型在多模态任务中的性能是非常重要的。
# 这个类的主要作用是扩展 YOLODataset 的功能,使其能够处理多模态数据,特别是包含文本信息的数据。这对于训练能够同时处理图像和文本的多模态模型非常有用。
4.class GroundingDataset(YOLODataset):
python
# 这段代码定义了一个名为 GroundingDataset 的类,它是 YOLODataset 类的子类。这个类专门用于处理目标检测任务,并且能够从 JSON 格式的注释文件中加载图像和标签数据。
# JSON(JavaScript Object Notation)格式文件是一种轻量级的数据交换格式,它基于JavaScript编程语言的对象表示方法,但是独立于语言,可以被多种编程语言读取和写入。JSON格式易于人阅读和编写,同时也易于机器解析和生成。
# JSON格式的特点 :
# 文本格式 : JSON是一种纯文本格式,可以被人类阅读和编写,也可以被机器轻松解析。
# 键值对 : JSON数据由键值对组成,其中键是字符串,值可以是字符串、数字、布尔值、数组、对象或其中的组合。
# 数组 : 在JSON中,值也可以是数组,数组中的元素可以是任何类型的值。
# 嵌套 : JSON支持嵌套结构,对象可以包含其他对象或数组。
# 语言无关性 : 尽管JSON的名字来源于JavaScript,但它与编程语言无关,可以被任何编程语言处理。
# 数据交换 : JSON广泛用于网络应用之间的数据交换,特别是在Web开发中,用于在客户端和服务器之间传输数据。
# JSON文件的示例:
# {
# "name": "John Doe",
# "age": 30,
# "is_student": false,
# "skills": ["Python", "JavaScript", "Machine Learning"],
# "address": {
# "street": "123 Main St",
# "city": "Anytown",
# "state": "CA"
# }
# }
# 在这个示例中,有一个JSON对象,它包含了一个人的姓名、年龄、是否是学生、技能列表和地址信息。这个对象由多个键值对组成,其中" skills "是一个数组," address "是一个嵌套的对象。
# JSON的使用场景 :
# 配置文件 : 许多应用程序使用JSON作为配置文件的格式。
# Web APIs : RESTful Web服务通常使用JSON作为请求和响应的数据格式。
# 数据存储 : 一些轻量级数据库和文件格式使用JSON来存储数据。
# 缓存 : Web浏览器和其他应用程序可能使用JSON作为缓存数据的格式。
# 移动应用 : 移动应用程序经常使用JSON来交换数据。
# JSON因其简洁性和灵活性而成为数据交换的首选格式之一。
class GroundingDataset(YOLODataset):
# 通过从指定的 JSON 文件加载注释来处理对象检测任务,支持 YOLO 格式。
"""Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
# 这是 GroundingDataset 类的构造函数,它接受任意数量的位置参数 *args 、关键字参数 **kwargs ,以及两个特定的参数: task 和 json_file 。
#1.task :一个字符串参数,默认值为 "detect" ,表示数据集的任务类型。当前类只支持目标检测任务。
# 2.json_file :一个参数,表示包含注释信息的 JSON 文件的路径。
def __init__(self, *args, task="detect", json_file, **kwargs):
# 初始化 GroundingDataset 用于对象检测,从指定的 JSON 文件加载注释。
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
# 这行代码使用 assert 语句确保 task 参数的值是 "detect" 。如果 task 不是 "detect" ,则会抛出一个 AssertionError ,提示当前 GroundingDataset 类只支持目标检测任务。
assert task == "detect", "`GroundingDataset` only support `detect` task for now!" # GroundingDataset` 目前仅支持 `detect` 任务!
# 这行代码将传入的 json_file 参数赋值给实例变量 self.json_file ,以便后续使用。
self.json_file = json_file
# 这行代码调用父类 YOLODataset 的构造函数,并传递所有位置参数 *args 、关键字参数 **kwargs ,以及 task 和 data 参数。
# data={} :一个空字典,表示在 GroundingDataset 类中,数据集的配置或数据信息可以通过其他方式(如 JSON 文件)加载,而不是通过 data 参数传递。
super().__init__(*args, task=task, data={}, **kwargs)
# 构造函数的作用是初始化 GroundingDataset 对象,设置其属性,并准备数据集以供后续的目标检测任务使用。通过继承 YOLODataset , GroundingDataset 能够利用父类中定义的方法和属性,同时添加或修改特定于从 JSON 文件加载注释信息的功能。
# 这段代码定义了一个名为 get_img_files 的方法,它是 GroundingDataset 类中的一个成员函数。这个方法的目的是返回一个包含图像文件路径的列表。不过,在这个特定的实现中,方法直接返回了一个空列表 [] ,这意味着它没有提供任何图像文件路径。
# 定义了一个实例方法 get_img_files ,它接受两个参数。
# 1.self :指向当前实例的引用。
# 2.img_path :一个字符串,表示图像文件的路径或目录。
def get_img_files(self, img_path):
# 图像文件将在" get_labels "函数中读取,在此处返回空列表。
"""The image files would be read in `get_labels` function, return empty list here."""
# 方法返回一个空的列表,这表明在当前的类实现中,没有图像文件被检索或包含在数据集中。
return []
# 这个方法可能是一个占位符或钩子(hook),在子类中被重写以提供具体的图像文件路径。在 GroundingDataset 类的上下文中,图像文件路径在其他方法中被处理,例如在 get_labels 方法中,从 JSON 注释文件中构建标签信息时,会同时检索图像文件路径。
# 在编程中,"钩子"(hook)是一个术语,指的是一个允许用户或开发者介入或修改软件标准行为的机制。钩子通常用于扩展或自定义软件的功能,而不需要修改其核心代码。它们提供了一个接口或回调函数,可以在特定的事件发生时触发,允许用户插入自己的代码来处理事件。
# 在 GroundingDataset 类的上下文中, get_img_files 方法可以被视为一个钩子,因为它提供了一个可以被重写的方法,以便在子类中实现特定的图像文件检索逻辑。
# 这样, YOLODataset 类的使用者可以根据自己的需求定制图像文件的加载方式,而不需要修改 YOLODataset 类的核心代码。通过重写这个方法,用户可以插入自己的代码来指定如何获取图像文件的路径列表。
# 这段代码定义了一个名为 get_labels 的方法,它是 GroundingDataset 类中的一个成员函数。这个方法的作用是从 JSON 格式的注释文件中加载注释数据,并构建标签信息列表。
def get_labels(self):
# 从 JSON 文件加载注释,过滤并规范化每个图像的边界框。
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
# 初始化一个空列表 labels ,用于存储构建的标签信息。
labels = []
# 使用日志记录器记录开始加载注释文件的信息。
LOGGER.info("Loading annotation file...") # 正在加载注释文件...
# 打开实例变量 self.json_file 指定的 JSON 文件。
with open(self.json_file) as f:
# json.load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=False, object_pairs_hook=None)
# json.load() 函数是 Python json 模块中的一个函数,它用于从文件对象中加载 JSON 数据,并将其转换为 Python 的数据类型(通常是字典或列表)。
# 参数 :
# fp :要加载的文件对象,该文件对象应该处于可读模式。
# encoding :(可选)文件的编码,默认为 None ,表示使用文件对象的默认编码。
# cls :(可选)一个自定义的解码器类,用于解析 JSON 数据。
# object_hook :(可选)一个函数,用于处理解码后的对象。
# parse_float :(可选)一个用于解析浮点数的函数。
# parse_int :(可选)一个用于解析整数的函数。
# parse_constant :(可选)一个用于解析其他 JSON 常量(如 null 、 true 、 false )的函数。
# strict :(可选)布尔值,指示是否严格解析 JSON。默认为 False 。
# object_pairs_hook :(可选)一个函数,用于处理 JSON 对象的键值对。
# 返回值 :
# 返回一个 Python 数据类型,通常是字典或列表,这取决于 JSON 数据的结构。
# json.load() 函数是处理 JSON 数据的基本工具之一,它允许你轻松地将 JSON 格式的数据转换为 Python 可以操作的数据类型。这使得从文件中读取配置、数据交换或存储变得非常方便。
# 使用 json.load 函数从文件中加载注释数据,并将其存储在变量 annotations 中。
annotations = json.load(f)
# 从注释数据中提取图像信息,并将其存储在 images 字典中,键为图像的 ID,值为图像数据。
# 这行代码是一个字典推导式(dictionary comprehension),它用于从 annotations 字典中的 "images" 键对应的列表创建一个新的字典 images 。这个新的字典将每个图像的 ID 作为键,图像的详细信息作为值。
# annotations["images"] :
# 这是从 annotations 字典中获取的列表,其中包含了所有图像的注释信息。
# for x in annotations["images"] :
# 这个循环遍历 annotations["images"] 列表中的每个元素 x ,每个 x 代表一个图像的注释信息。
# f'{x["id"]:d}' :
# 这是一个格式化字符串(f-string),用于创建字典的键。
# x["id"] 获取当前图像注释信息中的 ID。
# :d 指定格式化为整数( int ),但在这里实际上是多余的,因为 ID 通常已经是整数。这可能是一个错误或不必要的格式化。
# {f'{x["id"]:d}': x for x in annotations["images"]} :
# 字典推导式创建一个新的字典,其中每个键是图像的 ID,每个值是对应的图像注释信息 x 。
# images = ... :
# 将字典推导式的结果赋值给变量 images 。
# 因此,这行代码的结果是创建了一个名为 images 的字典,其键是图像的 ID,值是包含图像详细信息的字典。这使得后续可以通过图像 ID 快速访问任何图像的注释信息。
images = {f'{x["id"]:d}': x for x in annotations["images"]}
# collections.defaultdict(default_factory)
# defaultdict 是 Python 标准库 collections 模块中的一个类,它用于创建一个带有默认值的字典。与普通字典不同, defaultdict 在访问不存在的键时,不会抛出 KeyError 异常,而是会使用一个用户定义的工厂函数来提供一个默认值。
# 参数 :
# default_factory :一个可调用对象,当访问的键不存在时,会被调用以产生默认值。通常是一个函数或类。
# 返回值 :
# 返回一个 defaultdict 对象。
# 使用示例 :
# d = defaultdict(list)
# d['new_key'].append('value')
# print(d['new_key']) # 输出: ['value']
# 在这个示例中, defaultdict 使用 list 作为默认值工厂,所以当访问不存在的键时,它返回一个空列表,然后我们可以在这个列表上执行操作,如 append 。
# 特点 :
# defaultdict 避免了在访问字典键之前需要检查键是否存在的麻烦。
# 它在处理字典时提供了一种更简洁和直观的方式来处理缺失的键。
# defaultdict 可以与任何可调用对象一起使用,这提供了灵活的默认值生成方式。
# defaultdict 是处理字典时的一个非常有用工具,特别是在需要为缺失的键提供默认值的场景中。
# 初始化一个默认值为列表的字典 img_to_anns ,用于存储每个图像的注释信息。
img_to_anns = defaultdict(list)
# 遍历注释数据中的所有注释。
for ann in annotations["annotations"]:
# 将每个注释添加到对应图像 ID 的列表中。
img_to_anns[ann["image_id"]].append(ann)
# 使用 TQDM 显示读取注释的进度,并遍历每个图像的注释列表。
for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"): # 读取注释{self.json_file}。
# 获取当前图像的信息。
img = images[f"{img_id:d}"]
# 从图像信息中提取高度、宽度和文件名。
h, w, f = img["height"], img["width"], img["file_name"]
# 构建图像文件的完整路径。
im_file = Path(self.img_path) / f
# 检查图像文件是否存在。
if not im_file.exists():
continue
# 如果文件存在,将其路径添加到 self.im_files 列表中。
self.im_files.append(str(im_file))
# 初始化一个空列表 bboxes ,用于存储 边界框信息 。
bboxes = []
# 初始化一个空字典 cat2id ,用于存储 类别名称 到 类别 ID 的映射。
cat2id = {}
# 初始化一个空列表 texts ,用于存储 文本信息 。
texts = []
# 遍历当前图像的所有注释。
for ann in anns:
# 如果注释被标记为 iscrowd ,则跳过。
if ann["iscrowd"]:
continue
# 将注释的 边界框 转换为 NumPy 数组。
box = np.array(ann["bbox"], dtype=np.float32)
# 调整边界框的坐标,使其符合 中心点和宽度/高度 的格式。
box[:2] += box[2:] / 2
# 归一化边界框的 x 坐标。
box[[0, 2]] /= float(w)
# 归一化边界框的 y 坐标。
box[[1, 3]] /= float(h)
# 检查边界框的宽度和高度是否有效。
if box[2] <= 0 or box[3] <= 0:
continue
# 从图像的描述中提取类别名称。
# 这行代码是 Python 中的一个列表推导式,用于从图像的描述(caption)中提取正面的标记(tokens_positive),并将它们连接成一个完整的字符串,即类别名称(cat_name)。
# img["caption"] :
# 这是图像信息中的描述字段,通常是一个字符串,包含了图像的文本描述。
# ann["tokens_positive"] :
# 这是注释信息中的一个字段,包含了正面标记的列表。每个标记是一个元组 t ,其中 t[0] 是标记在描述中开始的索引, t[1] 是标记结束的索引。
# [img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]] :
# 这是一个列表推导式,它遍历 ann["tokens_positive"] 中的每个元组 t ,并使用切片 img["caption"][t[0] : t[1]] 提取描述中相应的文本片段。
# " ".join(...) :
# join 方法用于将列表中的所有字符串元素连接成一个单一的字符串,元素之间用一个空格分隔。
# 这行代码的作用是将图像描述中所有正面标记的文本片段提取出来,并用空格连接成一个完整的类别名称。
# 例如,如果一个图像的描述是 "A cat sits on a mat",并且 ann["tokens_positive"] 包含了 ((0, 3), (10, 13)) ,那么 cat_name 将是 "A cat a mat"。
# 这种方法常用于从图像的文本描述中提取类别信息,特别是在需要将文本描述与图像中的对象关联起来的场景中,例如在目标检测和图像标注任务中。
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
# 如果类别名称不在 cat2id 字典中,则添加它。
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
# 获取类别 ID。
cls = cat2id[cat_name] # class
# 将类别 ID 添加到边界框信息中。
# box.tolist() :
# box 是一个 NumPy 数组,包含了边界框的坐标信息。 tolist() 方法将 NumPy 数组 box 转换为一个 Python 列表。
# [cls] + box.tolist() :
# cls 是类别 ID,它被添加到列表的开始位置。 + 运算符用于连接两个列表,这里将包含类别 ID 的列表 [cls] 和 box 数组转换后的列表连接起来。
# 最终, box 变量变成了一个新的列表,其中第一个元素是类别 ID,后面跟着边界框的坐标信息。这样的格式通常用于表示一个边界框及其对应的类别,例如在目标检测任务中。
# 举个例子,如果 cls 是 1(代表某个类别),而 box 数组是 [50, 50, 100, 100] (代表边界框的 x, y, width, height),那么执行这行代码后, box 将变为: [1, 50, 50, 100, 100] 。
box = [cls] + box.tolist()
# 如果边界框不在 bboxes 列表中,则添加它。
if box not in bboxes:
bboxes.append(box)
# 将边界框信息转换为 NumPy 数组。
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
# 将构建的标签信息添加到 labels 列表中。
# 这行代码是 Python 中的一个字典构造,随后使用 append 方法将这个字典添加到 labels 列表中。这个字典包含了关于一个图像及其标签的详细信息,通常用于机器学习任务,特别是在目标检测中。
labels.append(
{
# 键 "im_file" 对应的值是 im_file ,这是一个字符串,表示图像 文件的路径 。
"im_file": im_file,
# 键 "shape" 对应的值是一个元组 (h, w) ,表示图像的 高度和宽度 。
"shape": (h, w),
# 键 "cls" 对应的值是 lb[:, 0:1] ,这是一个 NumPy 数组切片,表示所有边界框的类别信息。这里 lb 是一个包含边界框信息的 NumPy 数组, [:, 0:1] 表示选择数组中所有行的第一列,即类别 ID。
"cls": lb[:, 0:1], # n, 1
# 键 "bboxes" 对应的值是 lb[:, 1:] ,这是另一个 NumPy 数组切片,表示所有边界框的坐标信息。 [:, 1:] 表示选择数组中所有行,除了第一列之外的所有列,即边界框的坐标。
"bboxes": lb[:, 1:], # n, 4
# 键 "normalized" 对应的值是 True ,表示边界框坐标是归一化的,即它们是相对于图像宽度和高度的比例值。
"normalized": True,
# 键 "bbox_format" 对应的值是 "xywh" ,表示边界框的格式是中心点坐标 (x, y) 加上宽度 w 和高度 h 。
"bbox_format": "xywh",
# 键 "texts" 对应的值是 texts ,这是一个列表,包含了与图像相关的文本信息,例如类别名称或描述。
"texts": texts,
}
)
# 返回标签信息列表。
return labels
# 这个方法的主要作用是从 JSON 格式的注释文件中提取图像和注释信息,构建每个图像的标签信息,并返回一个包含所有标签信息的列表。这些标签信息可以用于训练目标检测模型。
# 这段代码定义了一个名为 build_transforms 的方法,它是 GroundingDataset 类中的一个成员函数。这个方法的作用是配置数据增强(augmentations)功能,这些功能用于训练,并且可以包含可选的文本加载。
# 定义了一个实例方法 build_transforms ,它接受一个可选参数。
# 1.hyp :这个参数通常包含超参数或配置设置,用于调整增强的强度。
def build_transforms(self, hyp=None):
# 配置用于训练的增强功能,并带有可选的文本加载;`hyp` 调整增强强度。
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
# 调用父类 YOLODataset 的 build_transforms 方法,并传入 hyp 参数。这行代码的目的是获取父类构建的基本数据增强列表,并将其存储在变量 transforms 中。
transforms = super().build_transforms(hyp)
# 检查实例变量 self.augment 是否为 True ,这通常表示是否启用数据增强。
if self.augment:
# NOTE: hard-coded the args for now.
# list.insert(index, object)
# list.insert() 是 Python 中列表( list )对象的一个方法,它用于在列表的指定位置插入一个元素。
# 参数 :
# index :要插入元素的索引位置。索引从 0 开始,表示列表中元素的位置。如果索引超出了列表的当前长度,元素将被追加到列表的末尾。
# object :要插入的元素。
# 返回值 :
# insert() 方法没有返回值(返回 None )。
# 特点 :
# insert() 方法可以用于在列表的任何位置插入元素,包括列表的开头和末尾。
# 如果索引大于列表的长度, insert() 方法会将元素追加到列表的末尾,而不是抛出错误。
# 这个方法会改变原列表,而不是创建一个新的列表。
# insert() 方法是列表操作中常用的一个工具,它允许在列表中的特定位置添加元素,这在处理需要特定顺序的数据时非常有用。
# 如果启用数据增强,使用 insert 方法在 transforms 列表的末尾插入一个新的变换对象 RandomLoadText 。
# RandomLoadText 是一个假设的变换类(在标准库中不存在,可能是自定义的),它可能用于加载文本数据并进行随机采样。
# max_samples=80 参数设置为最大样本数为 80,限制了采样的最大数量。
# padding=True 参数表示需要对文本数据进行填充处理,以确保输入的一致性。
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
# 返回更新后的数据增强列表,其中包含了父类的基本增强以及新增的文本增强变换。
return transforms
# 这个方法的主要作用是为训练过程提供数据增强的支持,特别是在需要处理图像和文本数据时。通过添加文本增强,模型可以学习从文本中提取特征,这对于提高模型在多模态任务中的性能是非常重要的。
# GroundingDataset 类的主要作用是处理特定于目标检测任务的数据集,它从 JSON 文件中加载图像和标签数据,并构建适用于目标检测模型训练的标签信息。此外,它还支持文本增强,这在多模态训练中非常有用。
5.class YOLOConcatDataset(ConcatDataset):
python
# 这段代码定义了一个名为 YOLOConcatDataset 的类,它是 ConcatDataset 类的子类。 ConcatDataset 通常用于将多个数据集串联在一起,以便可以像处理单个数据集一样处理它们。 YOLOConcatDataset 类提供了一个静态方法 collate_fn ,用于定制如何将数据样本合并成批次。
# 类定义。这行代码声明了一个名为 YOLOConcatDataset 的新类,它继承自 ConcatDataset 类。
class YOLOConcatDataset(ConcatDataset):
# 数据集为多个数据集的串联。
# 此类可用于组装不同的现有数据集。
"""
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
"""
# collate_fn 静态方法。
# @staticmethod 装饰器表示 collate_fn 是一个静态方法,它不需要访问类的实例或类变量。
@staticmethod
# collate_fn 是一个方法,接受一个参数。
# 1.batch :这个参数是一个列表,包含了来自串联数据集的一个批次的数据样本。
def collate_fn(batch):
# 将数据样本整理成批。
"""Collates data samples into batches."""
# 这行代码调用了 YOLODataset 类的 collate_fn 方法,并传入 batch 参数。
# 这意味着 YOLOConcatDataset 类使用 YOLODataset 类中定义的 collate_fn 方法来处理批次数据的合并。
# YOLODataset.collate_fn 方法负责具体的合并逻辑,例如如何处理图像和标签数据,以及如何应用数据增强等。
return YOLODataset.collate_fn(batch)
# YOLOConcatDataset 类通过继承 ConcatDataset 并提供自定义的 collate_fn 方法,允许用户在处理多个串联的数据集时使用 YOLODataset 的批次合并逻辑。这种方式使得在训练过程中可以灵活地处理来自不同数据集的样本,同时保持 YOLODataset 的数据处理特性。
6.class SemanticDataset(BaseDataset):
python
# TODO: support semantic segmentation
# 这段代码定义了一个名为 SemanticDataset 的类,它是 BaseDataset 类的子类。 SemanticDataset 类的构造函数初始化了一个语义分割数据集对象。
# 类定义。这行代码声明了一个名为 SemanticDataset 的新类,它继承自 BaseDataset 类。
class SemanticDataset(BaseDataset):
# 语义分割数据集。
# 此类负责处理用于语义分割任务的数据集。它从 BaseDataset 类继承功能。
# 注意:
# 此类当前为占位符,需要填充方法和属性以支持语义分割任务。
"""
Semantic Segmentation Dataset.
This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
from the BaseDataset class.
Note:
This class is currently a placeholder and needs to be populated with methods and attributes for supporting
semantic segmentation tasks.
"""
# 构造函数。这是 SemanticDataset 类的构造函数,它没有接受任何参数。
def __init__(self):
# 初始化语义数据集对象。
"""Initialize a SemanticDataset object."""
# 这行代码调用了父类 BaseDataset 的构造函数。由于没有传递任何参数,这意味着 BaseDataset 的构造函数可能有一个无参数的默认实现,或者它可能使用了默认参数值。
super().__init__()
# SemanticDataset 类的构造函数非常简单,它只是简单地调用了父类的构造函数。这表明 SemanticDataset 类可能依赖于 BaseDataset 类的默认初始化行为,或者它可能在其他方法中添加了特定的初始化逻辑。
# 在实际应用中, SemanticDataset 类可能会重写 BaseDataset 类中的其他方法,以提供语义分割任务所需的特定功能,例如加载图像和对应的标签、应用数据增强、处理数据等。
# 构造函数中的 super().__init__() 调用确保了父类被正确初始化,而任何额外的初始化逻辑则可以在 SemanticDataset 类的其他方法中实现。
7.class ClassificationDataset:
python
# 这段代码定义了一个名为 ClassificationDataset 的类,它用于处理图像分类任务的数据集。
class ClassificationDataset:
# 扩展 torchvision ImageFolder 以支持 YOLO 分类任务,提供图像增强、缓存和验证等功能。它旨在高效处理用于训练深度学习模型的大型数据集,并具有可选的图像转换和缓存机制以加快训练速度。
# 此类允许使用 torchvision 和 Albumentations 库进行增强,并支持在 RAM 或磁盘上缓存图像以减少训练期间的 IO 开销。此外,它还实现了强大的验证过程以确保数据的完整性和一致性。
"""
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
learning models, with optional image transformations and caching mechanisms to speed up training.
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
to ensure data integrity and consistency.
Attributes:
cache_ram (bool): Indicates if caching in RAM is enabled.
cache_disk (bool): Indicates if caching on disk is enabled.
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
torch_transforms (callable): PyTorch transforms to be applied to the images.
"""
# 这是 ClassificationDataset 类的构造函数,它接受四个参数。
# 1.root :数据集的根目录路径。
# 2.args :包含训练参数的对象。
# 3.augment :一个布尔值,表示是否应用数据增强,默认为 False 。
# 4.prefix :一个字符串,用于日志消息的前缀,默认为空字符串。
def __init__(self, root, args, augment=False, prefix=""):
# 使用 root、图像大小、增强和缓存设置初始化 YOLO 对象。
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
debugging. Default is an empty string.
"""
# 导入 torchvision 。 导入 torchvision 库,用于处理图像数据集。
import torchvision # scope for faster 'import ultralytics'
# Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
# 这段代码是 ClassificationDataset 类构造函数的一部分,它负责根据 torchvision 库的版本来初始化基础数据集,并设置样本和根目录。
# 这是一个条件判断,检查 TORCHVISION_0_18 变量的值。这个变量是一个布尔值,指示当前环境中的 torchvision 版本是否为 0.18 或更高版本。
if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18
# torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=None, is_valid_file=None)
# torchvision.datasets.ImageFolder 是 PyTorch 的一个类,它提供了一种方便的方式来加载结构化存储的图像数据集。这种结构化存储意味着图像被组织在不同的文件夹中,每个文件夹的名称对应一个类别。
# 参数 :
# root :数据集的根目录路径,其中包含所有类别的子文件夹。
# transform :一个可选的函数或可调用对象,用于对图像进行预处理或数据增强。它在图像加载后、返回前应用于图像。
# target_transform :一个可选的函数或可调用对象,用于对标签进行预处理。它在标签加载后、返回前应用于标签。
# loader :一个函数,用于加载图像文件。默认情况下,使用 PIL 库加载图像。
# is_valid_file :一个函数,用于检查文件名是否有效。如果提供,它将被用于过滤文件。
# 返回值 :
# 返回一个 ImageFolder 实例,该实例包含图像数据集的加载和预处理逻辑。
# ImageFolder 类是 PyTorch 中处理图像分类任务时常用的工具之一,它简化了数据加载和预处理的过程,使得用户可以专注于模型的训练和评估。
# torchvision.datasets.ImageFolder 类的实例通常包含以下常见的属性 :
# root : 字符串,表示数据集的根目录路径。
# samples : 列表,包含数据集中所有图像的元组信息,通常每个元组包含图像的路径和对应的标签索引。
# classes : 列表,包含数据集中所有类别的名称,顺序与 samples 中的标签索引相对应。
# class_to_idx : 字典,映射类别名称到它们在 classes 列表中的索引。
# imgs : 列表,与 samples 类似,包含图像的路径和标签,但在某些版本的 torchvision 中可能不直接提供。
# targets : 列表,包含与 imgs 列表相对应的标签。
# transform : 函数或可调用对象,用于对图像进行预处理或数据增强,如果提供了 transform 参数,则在加载图像时应用。
# target_transform : 函数或可调用对象,用于对标签进行预处理,如果提供了 target_transform 参数,则在加载标签时应用。
# loader : 函数,用于加载图像文件,默认使用 PIL 库。
# is_valid_file : 函数,用于检查文件名是否有效,如果提供了 is_valid_file 参数,则在加载图像时用于过滤文件。
# 这些属性使得 ImageFolder 实例能够方便地访问和操作图像数据集,同时提供了灵活的预处理和数据加载选项。通过这些属性,用户可以轻松地对数据集进行迭代、应用变换、加载图像和标签等操作。
# 如果 TORCHVISION_0_18 为 True ,即 torchvision 版本支持 allow_empty 参数,那么使用 allow_empty=True 初始化 ImageFolder 类。
# ImageFolder 是 torchvision.datasets 模块中的一个类,用于从文件夹中加载图像数据集,其中每个子文件夹代表一个类别,文件夹名称为类别名。
self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
# 如果 TORCHVISION_0_18 为 False ,即 torchvision 版本不支持 allow_empty 参数,则不传递 allow_empty 参数,使用默认的 ImageFolder 初始化方式。
else:
# 在 else 分支中,初始化 ImageFolder 类,并将根目录 root 传递给 ImageFolder 。
self.base = torchvision.datasets.ImageFolder(root=root)
# 将 ImageFolder 实例的 samples 属性赋值给 self.samples 。 samples 属性包含了数据集中所有图像的 路径 和 标签 。
self.samples = self.base.samples
# 将 ImageFolder 实例的 root 属性赋值给 self.root 。 root 属性表示数据集的 目录 。
self.root = self.base.root
# 这段代码的目的是确保 ClassificationDataset 类能够兼容不同版本的 torchvision 。 allow_empty 参数允许 ImageFolder 类在某些类别没有图像时不会抛出错误,这对于某些数据集是有用的。通过这种方式, ClassificationDataset 类可以灵活地处理不同版本的 torchvision ,同时保持对数据集的访问和操作。
# Initialize attributes
# 这段代码是 ClassificationDataset 类中的一部分,它处理数据增强、缓存策略以及数据变换的设置。
# 减少训练样本比例。
# 如果启用了数据增强 ( augment ) 并且训练样本比例 ( args.fraction ) 小于 1.0,则减少样本数量以减少训练数据集的大小。
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
# 设置前缀。 如果提供了 prefix 参数,则设置日志消息的前缀。
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
# 设置 RAM 缓存。根据 args.cache 参数的值,确定是否将图像缓存到 RAM 中。
self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM
# 处理 RAM 缓存的已知内存泄漏问题。
# 如果启用了 RAM 缓存,由于已知的内存泄漏问题(在 Ultralytics 的 GitHub 问题 #9824 中提到),则发出警告并禁用 RAM 缓存。
if self.cache_ram:
LOGGER.warning(
"WARNING ⚠️ Classification `cache_ram` training has known memory leak in " # 警告⚠️分类`cache_ram`训练在https://github.com/ultralytics/ultralytics/issues/9824中存在已知内存泄漏,设置`cache_ram=False`。
"https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
)
self.cache_ram = False
# 设置硬盘缓存。根据 args.cache 参数的值,确定是否将图像缓存到硬盘上。
self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
# 验证图像文件。调用 verify_images 方法来检查图像文件的有效性,并过滤掉损坏的图像。
self.samples = self.verify_images() # filter out bad images
# 扩展样本信息。扩展样本信息,为每个样本添加 .npy 文件路径和初始设置为 None 的图像数据。
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
# 计算缩放比例,其中 args.scale 是一个值,通常用于确定数据增强时的尺度变化范围。例如,如果 args.scale 是 0.08,则 scale 将是 (0.92, 1.0),表示图像可以在 92% 到 100% 之间随机缩放。
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
# 设置数据变换。
# 根据是否启用数据增强,设置分类任务的数据增强或变换序列。这些变换包括大小调整、随机水平翻转、垂直翻转、擦除、自动增强以及 HSV 颜色空间的调整。
# 这段代码在 ClassificationDataset 类中定义了 self.torch_transforms 属性,它根据是否启用数据增强( augment 参数)来设置不同的 PyTorch 数据变换序列。这些变换用于图像的预处理和数据增强,以提高模型的泛化能力。
# 定义了一个条件表达式,根据 augment 参数的值选择使用数据增强变换还是普通变换。
# 如果 augment 为 True ,则使用 classify_augmentations 函数创建一个数据增强变换序列。
# classify_augmentations : 一个函数,用于创建分类任务的数据增强变换序列。它接受以下参数 :
# size :输出图像的大小。
# scale :图像缩放的比例范围。
# hflip :是否进行随机水平翻转。
# vflip :是否进行随机垂直翻转。
# erasing :是否应用随机擦除。
# auto_augment :是否应用自动增强策略。
# hsv_h 、 hsv_s 、 hsv_v :HSV 颜色空间的调整范围。
self.torch_transforms = (
classify_augmentations(
size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v,
)
if augment
# 如果 augment 为 False ,则使用 classify_transforms 函数创建一个普通的变换序列,通常包括缩放和裁剪操作。
# size :输出图像的大小。 crop_fraction :裁剪的比例。
else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
)
# 最终, self.torch_transforms 被设置为上述条件表达式的结果,它是一个 PyTorch 的变换序列,将被应用于数据集中的每个图像。
# 这段代码的目的是为分类任务提供灵活的数据预处理和增强选项,使得可以根据训练配置选择适当的变换策略。通过这种方式,用户可以根据需要调整模型训练的难度和效果。
# 这段代码涵盖了数据集的初始化、数据增强、缓存策略和数据变换等多个方面,为分类任务提供了灵活的数据准备选项。
# 这个类的构造函数初始化了一个分类数据集对象,设置了数据增强、缓存策略,并构建了 PyTorch 的数据变换序列。这些设置使得数据集可以灵活地应用于不同的分类任务和训练需求。
# 这段代码定义了一个名为 __getitem__ 的方法,它是 Python 中数据集类的一个特殊方法,用于实现数据集的索引操作,即当你使用索引访问数据集时,这个方法会被调用。在 PyTorch 的 Dataset 类中, __getitem__ 是一个必须被重写的方法,它允许数据加载器( DataLoader )从数据集中获取单个样本或一批样本。
# 定义了一个实例方法 __getitem__ ,它接受一个参数。
# 1.i :表示要获取的数据样本的索引。
def __getitem__(self, i):
# 返回与给定索引相对应的数据子集和目标。
"""Returns subset of data and targets corresponding to given indices."""
# 从 self.samples 列表中获取索引 i 对应的样本信息,并解包为四个变量。 f (图像文件名)、 j (标签索引)、 fn (图像文件名对应的 .npy 文件名)、 im (图像数据)。
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
# 如果启用了 RAM 缓存( self.cache_ram 为 True ),则检查 im 是否为 None 。
if self.cache_ram:
# 如果 im 为 None ,则使用 cv2.imread(f) 读取图像数据,并更新 self.samples[i] 中的 im 。
if im is None: # Warning: two separate if statements required here, do not combine this with previous line
im = self.samples[i][3] = cv2.imread(f)
# 如果启用了硬盘缓存( self.cache_disk 为 True ),则检查 .npy 文件是否存在。
elif self.cache_disk:
# 如果 .npy 文件不存在,则使用 np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False) 保存 图像数据 为 .npy 文件。
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
# 如果 .npy 文件存在,则使用 np.load(fn) 加载图像数据。
im = np.load(fn)
# 如果没有启用缓存,则直接使用 cv2.imread(f) 读取图像数据。
else: # read image
im = cv2.imread(f) # BGR
# Convert NumPy array to PIL image
# 将图像数据从 BGR 格式转换为 RGB 格式,并使用 Image.fromarray 将 NumPy 数组转换为 PIL 图像。
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
# 将 PIL 图像传递给 self.torch_transforms ,这是一个包含预处理和数据增强操作的 PyTorch 变换序列。
sample = self.torch_transforms(im)
# 返回一个字典,包含处理后的图像样本 sample 和对应的标签索引 j 。
return {"img": sample, "cls": j}
# 这个方法的主要作用是根据索引获取数据集中的单个样本,并应用预处理和数据增强操作,然后返回处理后的样本和对应的标签。这是 PyTorch 数据加载和训练流程中的关键部分,允许模型在训练时动态地访问和处理数据集。
# 这段代码定义了一个名为 __len__ 的方法,它是 Python 中的一个特殊方法(也称为魔术方法或内置方法)。 __len__ 方法用于返回对象的长度,即对象中包含的元素数量。在数据集类中, __len__ 方法通常被重写以返回数据集中的样本总数。
# 定义了一个实例方法 __len__ ,它不接受任何参数,除了 self (指向当前实例的引用)。 方法的返回类型被指定为 int ,表示返回值将是一个整数。
def __len__(self) -> int:
# 返回数据集中的样本总数。
"""Return the total number of samples in the dataset."""
# 返回 self.samples 列表的长度,即数据集中的样本总数。
return len(self.samples)
# __len__ 方法在 PyTorch 的 Dataset 类中经常被重写,因为它允许 DataLoader 知道每个epoch中有多少批次的数据,以及在调试和日志记录时提供有用的信息。当你使用 len(dataset) 时,Python 会自动调用这个 __len__ 方法来获取数据集的长度。
# 这段代码定义了一个名为 verify_images 的方法,它是用于验证数据集中所有图像文件的有效性,并根据需要缓存结果以加快后续操作的速度。
# 方法定义。这是一个实例方法,属于某个数据集类,用于验证数据集中的图像文件。
def verify_images(self):
# 验证数据集中的所有图像。
"""Verify all images in dataset."""
# 扫描描述和缓存文件路径。 desc 是一个字符串,描述了当前正在执行的操作。 path 是缓存文件的路径,它是数据集根目录路径加上 .cache 后缀。
desc = f"{self.prefix}Scanning {self.root}..." # {self.prefix}正在扫描 {self.root}...
path = Path(self.root).with_suffix(".cache") # *.cache file path
# 这段代码是 verify_images 方法的一部分,它尝试从缓存文件中加载数据集的状态,并在成功加载后进行一系列验证和处理。
# 使用 contextlib.suppress 作为一个上下文管理器,它会捕获并抑制指定的异常 :FileNotFoundError (文件未找到错误)、 AssertionError (断言错误)和 AttributeError (属性错误)。这意味着如果这些异常在代码块中发生,它们将被静默处理,不会向外抛出。
with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
# 调用 load_dataset_cache_file 函数尝试从路径 path 加载 .cache 文件。这个文件包含了之前的数据集验证结果,以便快速恢复状态。
# # def load_dataset_cache_file(path): -> 用于从指定路径加载 Ultralytics 的 .cache 字典文件。返回加载的缓存字典。 -> return cache
cache = load_dataset_cache_file(path) # attempt to load a *.cache file
# 使用 assert 语句确保缓存文件的版本与当前版本 DATASET_CACHE_VERSION 匹配。如果不匹配,将抛出 AssertionError 。
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
# 再次使用 assert 语句确保缓存文件的哈希值与当前样本列表的哈希值匹配。如果不匹配,将抛出 AssertionError 。
# def get_hash(paths): -> 用于计算一个包含文件或目录路径列表的单一哈希值。计算最终的哈希值,并以十六进制格式返回。 -> return h.hexdigest() # return hash
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
# 从缓存中提取验证结果,包括 找到的图像 数量 nf 、 损坏的图像 数量 nc 、 空的图像 数量 n 和 样本列表 samples 。
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
# 检查 LOCAL_RANK 变量是否为 -1 或 0 ,这通常用于确定是否在主进程中运行(在分布式训练环境中)。
if LOCAL_RANK in {-1, 0}:
# 构造一个描述字符串,显示已找到的图像数量和损坏的图像数量。
d = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 图像,{nc} 损坏。
# 使用 TQDM 显示进度条,描述字符串为 d ,总进度为 n ,初始进度也为 n 。
TQDM(None, desc=d, total=n, initial=n)
# 检查缓存中是否有警告消息。
if cache["msgs"]:
# 如果有警告消息,使用 LOGGER 记录它们。
LOGGER.info("\n".join(cache["msgs"])) # display warnings
# 返回验证后的样本列表。
return samples
# 这段代码的主要作用是尝试从缓存文件中恢复数据集的状态,如果成功,则验证版本和哈希值的一致性,并显示进度和警告。如果缓存文件不存在或版本/哈希值不匹配,代码将跳过缓存加载,执行完整的图像验证流程。
# 这段代码是 verify_images 方法的一部分,它负责在无法从缓存中加载数据集状态时,执行实际的图像验证过程,并缓存结果以备将来使用。
# Run scan if *.cache retrieval failed
# 初始化变量。
# nf 和 nc 分别初始化为 0,用于计数 找到的 和 损坏的 图像数量。
# msgs 初始化为空列表,用于存储验证过程中的消息。
# samples 初始化为空列表,用于存储验证后的样本。
# x 初始化为空字典,用于存储缓存数据。
nf, nc, msgs, samples, x = 0, 0, [], [], {}
# 使用线程池进行图像验证。
# 使用 ThreadPool 来并行执行图像验证。
with ThreadPool(NUM_THREADS) as pool:
# pool.imap(func, iterable, chunksize=None)
# imap() 方法用于将一个可迭代的输入序列分块分配到线程池中的线程进行处理,并将结果返回一个迭代器。这个方法特别适用于需要顺序处理输入和输出的场景。
# 参数 :
# func :一个函数,它将被调用并传入 iterable 中的每个项目。
# iterable :一个可迭代对象,其元素将被传递给 func 函数。
# chunksize :(可选)一个整数,指定了每个任务传递给 func 的项目数量。默认值为 1,意味着每个任务只包含一个项目。如果设置为大于 1 的值,那么 func 将接收到一个包含多个项目的列表。
# 返回值 :
# 返回一个 Iterator ,它生成每个输入元素经过 func 处理后的结果。
# 特点 :
# 结果的顺序与输入序列的顺序相同。
# 如果任何一个任务因为异常而终止, imap() 会立即抛出异常。
# 它允许主线程在子线程完成工作之前继续执行,而不是等待所有任务完成。
# ThreadPool.imap() 是处理 I/O 密集型任务或者需要顺序处理结果的并发任务的有用工具。与之相对的是 imap_unordered() ,它同样返回一个迭代器,但是结果的顺序可能与输入序列不同,适用于不在乎结果顺序的场景。
# pool.imap 将 verify_image 函数应用于 self.samples 中的每个样本和前缀。
results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
# TQDM 用于显示进度条。
pbar = TQDM(results, desc=desc, total=len(self.samples))
# 处理验证结果。
# 这段代码是 verify_images 方法中的一部分,它处理图像验证过程中的每个样本,并更新进度条描述。
# 遍历 TQDM 对象 pbar 的结果。 pbar 是由 pool.imap 生成的迭代器,它为数据集中的每个样本返回一个元组,包含 :
# sample :当前处理的样本。
# nf_f :一个布尔值,表示当前样本是否有效(找到的)。
# nc_f :一个布尔值,表示当前样本是否损坏。
# msg :一个字符串,包含有关当前样本的任何消息。
for sample, nf_f, nc_f, msg in pbar:
# 如果 nf_f 为 True ,表示当前样本是有效的,将其添加到 samples 列表中。
if nf_f:
samples.append(sample)
# 如果 msg 非空,表示有关于当前样本的消息,将其添加到 msgs 列表中。
if msg:
msgs.append(msg)
# 更新找到的图像数量 nf 。由于 nf_f 是布尔值, True 相当于 1, False 相当于 0。
nf += nf_f
# 更新损坏的图像数量 nc 。由于 nc_f 是布尔值, True 相当于 1, False 相当于 0。
nc += nc_f
# 更新进度条描述 pbar.desc ,显示当前找到的图像数量和损坏的图像数量。
pbar.desc = f"{desc} {nf} images, {nc} corrupt" # {desc} {nf} 图像,{nc} 损坏。
# 关闭 TQDM 进度条。这是一个好习惯,特别是在使用 with 语句或其他上下文管理器时,确保资源被正确释放。
pbar.close()
# 这段代码的主要作用是迭代处理每个样本的结果,更新样本列表和消息列表,并实时更新进度条描述,以便用户可以看到验证过程的进度和状态。这种方法提高了数据处理的透明度,并允许用户监控长时间运行的任务。
# 记录消息。
# 如果有消息,使用 LOGGER 记录它们。
if msgs:
LOGGER.info("\n".join(msgs))
# 构建缓存数据。
# 计算样本的哈希值并存储在 x 中。
x["hash"] = get_hash([x[0] for x in self.samples])
# 存储 验证结果 和 消息 。
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
# 保存缓存文件。调用 save_dataset_cache_file 函数保存缓存文件。
# def save_dataset_cache_file(prefix, path, x, version): -> 用于将一个名为 x 的数据集缓存字典保存到指定的路径。
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
# 返回样本列表。返回验证后的样本列表。
return samples
# 这段代码的主要作用是验证数据集中的图像文件,记录验证过程中的消息,并缓存验证结果以加快后续操作的速度。通过并行处理和缓存机制,它提高了数据集验证的效率。
# 这个方法的主要作用是确保数据集中的图像文件是有效的,并且通过缓存机制提高数据加载的效率。它处理了图像文件的验证、进度显示和缓存保存,是数据预处理流程中的关键步骤。