train.py
ultralytics\models\yolo\detect\train.py
目录
[2.class DetectionTrainer(BaseTrainer):](#2.class DetectionTrainer(BaseTrainer):)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
import random
from copy import copy
import numpy as np
import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
2.class DetectionTrainer(BaseTrainer):
python
# 这段代码是一个用于目标检测训练的类 DetectionTrainer ,它继承自 BaseTrainer 类。这个类包含了构建数据集、获取数据加载器、预处理批次数据、设置模型属性、获取模型、获取验证器、标签损失项、进度字符串、绘制训练样本、绘制指标和绘制训练标签等方法。
class DetectionTrainer(BaseTrainer):
# 扩展 BaseTrainer 类的类,用于基于检测模型进行训练。
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model="yolov8n.pt", data="coco8.yaml", epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
# 这段代码定义了 DetectionTrainer 类中的 build_dataset 方法,该方法用于构建一个数据集,这个数据集是用于训练或验证 YOLO 模型的。
# 这行定义了 build_dataset 方法,它接受三个参数。
# 1.self : 类实例的引用,允许访问类的属性和方法。
# 2.img_path : 传递给方法的参数,表示图像文件的路径。
# 3.mode : 传递给方法的参数,表示数据集的模式,可以是 "train"(训练)或 "val"(验证)。默认值为 "train"。
# 4.batch : 可选参数,表示批次大小。如果提供,将用于确定数据集的批次大小。
def build_dataset(self, img_path, mode="train", batch=None):
# 构建 YOLO 数据集。
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
# 计算全局 stride ( gs ),这是模型中最大的 stride 值,用于确定数据集中图像的尺寸。 de_parallel 函数是用于移除模型的分布式数据并行包装器,以便访问模型的实际stride属性。
# 如果模型存在,它计算模型的最大 stride 值;如果模型不存在,则使用 0 。然后,这个值与 32 进行比较,取两者中较大的值作为全局 stride 。
# def de_parallel(model):
# -> 将一个可能处于数据并行(DataParallel)或分布式数据并行(DistributedDataParallel,简称 DDP)状态的模型转换回单GPU模型。
# -> return model.module if is_parallel(model) else model
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
# 调用 build_yolo_dataset 函数来构建 YOLO 数据集。传递给这个函数的参数包括 :
# self.args : 包含训练参数的对象或字典。
# img_path : 图像文件的路径。
# batch : 批次大小。
# self.data : 包含数据集相关信息的对象或字典。
# mode : 数据集模式,"train" 或 "val"。
# rect : 一个布尔值,指示是否使用矩形训练。如果模式是 "val",则为 True。
# stride : 上一步计算的全局stride值。
# def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
# -> 用于构建 YOLO(You Only Look Once)目标检测模型的数据集。这个函数根据提供的配置和参数,初始化并返回一个 YOLODataset 或 YOLOMultiModalDataset 实例。
# -> return dataset(img_path=img_path, imgsz=cfg.imgsz, batch_size=batch, augment=mode == "train", hyp=cfg, rect=cfg.rect or rect, cache=cfg.cache or None, single_cls=cfg.single_cls or False, stride=int(stride),
# pad=0.0 if mode == "train" else 0.5, prefix=colorstr(f"{mode}: "),task=cfg.task, classes=cfg.classes, data=data, fraction=cfg.fraction if mode == "train" else 1.0,)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
# 这个方法的目的是根据模型的配置和数据集的路径来构建一个适用于YOLO目标检测模型的数据集。YOLO需要特定格式的数据集来进行训练和验证。
# 这段代码定义了一个名为 get_dataloader 的方法,它用于构建并返回一个数据加载器(dataloader)。这个方法是 DetectionTrainer 类的一部分,该类用于训练目标检测模型。
# 定义了一个名为 get_dataloader 的方法,它接受四个参数。
# 1.self :类的实例自身。
# 2.dataset_path :数据集的路径。
# 3.batch_size :批处理大小,默认为16。
# 4.rank :在分布式训练中的排名,默认为0。
# 5.mode :数据集模式,默认为"train"。
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
# 构造并返回数据加载器。
"""Construct and return dataloader."""
# 这是一个断言语句,用于确保 mode 参数只能是"train"或"val"。如果不是这两个值,将抛出一个异常。
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}." # 模式必须是"train"或"val",而不是{mode}。
# 这个上下文管理器用于在分布式训练中确保只有 rank 为0的进程首先执行某些操作。这是为了减少数据加载时的冗余操作。
# def torch_distributed_zero_first(local_rank: int): -> 用于在分布式训练中确保所有进程等待主进程(通常是 rank 0)完成特定任务后再继续执行。
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
# 在 with 块内部,调用 self.build_dataset 方法来构建数据集。这个方法需要数据集路径、模式和批处理大小作为参数。
dataset = self.build_dataset(dataset_path, mode, batch_size)
# 定义一个变量 shuffle ,如果模式是"train",则设置为True,否则为False。这通常用于决定是否在训练时打乱数据。
shuffle = mode == "train"
# 检查数据集对象是否有 rect 属性,并且是否设置为True。同时检查 shuffle 是否为True。
if getattr(dataset, "rect", False) and shuffle:
# 如果 rect 为True且 shuffle 为True,则记录一个警告,说明 rect=True 与 DataLoader 的 shuffle 不兼容,并将 shuffle 设置为False。
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") # 警告 ⚠️ 'rect=True' 与 DataLoader shuffle 不兼容,请设置 shuffle=False。
# 将 shuffle 设置为False,以解决不兼容性问题。
shuffle = False
# 根据模式设置工作线程数。如果是训练模式("train"),则使用 self.args.workers 指定的线程数;如果是验证模式("val"),则使用训练模式的两倍线程数。
workers = self.args.workers if mode == "train" else self.args.workers * 2
# 调用 build_dataloader 函数来构建并返回一个数据加载器。这个函数接受 数据集对象 、 批处理大小 、 工作线程数 、 是否打乱数据 和 分布式训练的排名 作为参数。
# def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
# -> 用于构建 PyTorch 的 DataLoader 或 InfiniteDataLoader 对象。
# -> return InfiniteDataLoader(dataset=dataset, batch_size=batch, shuffle=shuffle and sampler is None, num_workers=nw, sampler=sampler, pin_memory=PIN_MEMORY,
# collate_fn=getattr(dataset, "collate_fn", None), worker_init_fn=seed_worker, generator=generator,)
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
# 这个方法的作用是在训练或验证模式下构建一个适合目标检测模型的数据加载器,它考虑了分布式训练、数据打乱和多进程加载等因素,以确保数据的有效加载和处理。
# 这段代码定义了一个名为 preprocess_batch 的方法,它用于预处理一批图像数据。
# 定义了一个名为 preprocess_batch 的方法,它接受两个参数。
# 1.self :类的实例自身。
# 2.batch :一批图像数据。
def preprocess_batch(self, batch):
# 通过缩放和转换为浮点数来预处理一批图像。
"""Preprocesses a batch of images by scaling and converting to float."""
# 这行代码执行了三个操作 :
# batch["img"].to(self.device, non_blocking=True) :将图像数据移动到指定的设备(如GPU), non_blocking=True 参数表示如果数据已经在设备上,则不会阻塞等待。
# .float() :将图像数据转换为浮点数类型。
# / 255 :将图像数据的像素值从[0, 255]缩放到[0, 1]。
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
# 检查是否启用了多尺度训练,这是通过 self.args.multi_scale 参数控制的。
if self.args.multi_scale:
# 如果启用了多尺度训练,将图像数据赋值给 imgs 变量。
imgs = batch["img"]
# 计算新的图像尺寸 sz 。这个尺寸是在原始图像尺寸的50%到150%之间随机选择的,并且是 self.stride 的整数倍。
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
) # size
# 计算缩放因子 sf ,即新尺寸 sz 与图像的最大维度的比值。 imgs.shape[2:] 获取的是图像的高度和宽度。
sf = sz / max(imgs.shape[2:]) # scale factor
# 检查缩放因子是否不等于1,即是否需要缩放图像。
if sf != 1:
# 如果需要缩放,计算新的图像尺寸 ns ,确保新尺寸是 self.stride 的整数倍。
ns = [
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
] # new shape (stretched to gs-multiple)
# 使用双线性插值方法将图像缩放到新的尺寸 ns 。
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
# 将缩放后的图像数据赋值回 batch["img"] 。
batch["img"] = imgs
# 返回预处理后的批次数据。
return batch
# 这个方法的主要功能是对一批图像数据进行预处理,包括将像素值缩放到[0, 1],以及在启用多尺度训练时对图像进行随机缩放。这些预处理步骤对于训练深度学习模型是常见的,有助于提高模型的泛化能力。
# 这段代码定义了一个名为 set_model_attributes 的方法,它的作用是为模型设置一些属性,这些属性通常与数据集和模型的配置有关。
# 定义了一个名为 set_model_attributes 的方法,它不接受任何参数,只使用类的实例自身( self )。
def set_model_attributes(self):
# Nl = de_parallel(self.model).model[-1].nl # 检测层的数量(以缩放 hyps)。
"""Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# 这是一个被注释掉的代码行,它表示将 self.args.box 乘以 3/nl ,以根据 检测层的数量 进行缩放。
# self.args.box *= 3 / nl # scale to layers
# 这是另一行被注释掉的代码,它表示将 self.args.cls 乘以 (self.data["nc"] / 80) * 3 / nl ,以根据 类别数量 和 检测层的数量 进行缩放。
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# 这行被注释掉的代码进一步调整 self.args.cls ,这次是根据 图像尺寸 和 检测层的数量 进行缩放。
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
# 将数据集中的类别数量( self.data["nc"] )设置到模型的 nc 属性中,这样模型就知道它需要处理多少个类别。
self.model.nc = self.data["nc"] # attach number of classes to model
# 将数据集中的类别名称( self.data["names"] )设置到模型的 names 属性中,这样模型就可以访问每个类别的名称。
self.model.names = self.data["names"] # attach class names to model
# 将超参数( self.args )设置到模型的 args 属性中,这样模型就可以访问训练过程中使用的配置参数。
self.model.args = self.args # attach hyperparameters to model
# 这是一个待办事项(TODO)注释,提示开发者将来可能需要实现一个功能,即根据数据集中的 标签 和 类别数量 计算类别权重,并将这些权重设置到模型的 class_weights 属性中。这个属性可以帮助模型在训练时对不平衡的数据集进行调整。
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
# set_model_attributes 方法用于将数据集的相关信息和超参数配置到模型中,以便模型在训练时能够正确地使用这些信息。
# 这段代码定义了一个名为 get_model 的方法,它用于创建并返回一个YOLO目标检测模型。
# 定义了一个名为 get_model 的方法,它接受四个参数。
# 1.self :类的实例自身。
# 2.cfg :模型的配置文件路径,默认为 None 。
# 3.weights :模型的预训练权重文件路径,默认为 None 。
# 4.verbose :是否打印详细信息,默认为 True 。
def get_model(self, cfg=None, weights=None, verbose=True):
# 返回 YOLO 检测模型。
"""Return a YOLO detection model."""
# 创建一个 DetectionModel 实例,它是YOLO模型的一个实现。这个实例化过程需要三个参数 :
# cfg :模型的配置文件路径。
# nc :类别数量,从 self.data["nc"] 中获取,表示数据集中的类别数。
# verbose :是否打印详细信息,这里使用了一个条件表达式 verbose and RANK == -1 ,意味着只有当 verbose 为 True 且 RANK 变量等于 -1 时,才会打印详细信息。 RANK 通常用于分布式训练中标识进程的排名,当 RANK 为 -1 时,表示不是在分布式环境中运行。
# class DetectionModel(BaseModel):
# -> 用于构建和初始化 YOLOv8 检测模型。
# -> def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
# 检查是否提供了 weights 参数。
if weights:
# 如果提供了 weights 参数,调用模型的 load 方法来加载预训练权重。
model.load(weights)
# 返回创建的YOLO模型实例。
return model
# 这个方法的主要功能是初始化一个YOLO模型,并根据提供的配置和权重进行设置。如果提供了权重路径,它会加载这些权重到模型中。这个方法可以用于创建一个新的模型实例,或者在训练前加载一个预训练的模型。
# 这段代码定义了一个名为 get_validator 的方法,它用于创建并返回一个 DetectionValidator 实例,用于YOLO模型的验证过程。
# 定义了一个名为 get_validator 的方法,它不接受任何额外参数,只使用类的实例自身( self )。
def get_validator(self):
# 返回用于 YOLO 模型验证的 DetectionValidator。
"""Returns a DetectionValidator for YOLO model validation."""
# 设置了类的 loss_names 属性,它包含了在验证过程中需要计算的损失函数名称的元组。这些损失函数通常包括 边界框损失( box_loss )、 类别损失( cls_loss )和 置信度损失( dfl_loss )。
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
# 返回一个 DetectionValidator 实例,它是从 yolo.detect 模块导入的。
return yolo.detect.DetectionValidator(
# self.test_loader :是一个数据加载器,用于在验证过程中提供测试数据。
# save_dir=self.save_dir :指定了保存验证结果的目录路径。
# args=copy(self.args) :它接受一个复制的 self.args 对象。 self.args 包含了模型训练和验证过程中使用的超参数。使用 copy 函数确保传递给 DetectionValidator 的是 args 的一个副本,而不是原始对象。
# _callbacks=self.callbacks :它接受 self.callbacks ,这是一个回调函数列表,用于在验证过程中执行特定的操作,如打印日志、保存模型等。
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
# 这个方法的主要功能是初始化一个 DetectionValidator 对象,用于在YOLO模型的验证阶段计算损失、评估模型性能,并可能执行其他回调操作。通过传递测试数据加载器、保存目录、超参数和回调函数, DetectionValidator 能够完成模型验证的整个流程。
# 这段代码定义了一个名为 label_loss_items 的方法,它用于处理和格式化损失项(loss items)的数据。
# 定义了一个名为 label_loss_items 的方法,它接受三个参数。
# 1.self :类的实例自身。
# 2.loss_items :损失项的数值列表,默认为 None 。
# 3.prefix :损失项名称的前缀,默认为 "train" 。
def label_loss_items(self, loss_items=None, prefix="train"):
# 返回带有标记训练损失项张量的损失字典。
# 分类不需要,但分割和检测需要。
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection
"""
# 创建一个列表 keys ,其中包含格式化的损失项名称。 self.loss_names 是一个包含损失项名称的元组, prefix 是前缀(例如 "train" 或 "val" ),用于区分训练和验证阶段的损失项。
keys = [f"{prefix}/{x}" for x in self.loss_names]
# 检查 loss_items 参数是否被提供且不为 None 。
if loss_items is not None:
# 如果 loss_items 不为 None ,则将列表中的每个元素转换为浮点数,并四舍五入到小数点后五位。这个步骤确保了损失项的数值是精确的,并且格式化为易于阅读的格式。
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
# 使用 zip 函数将 keys 列表和 loss_items 列表合并成一个字典,并返回这个字典。这个字典将每个损失项的名称(例如 "train/box_loss" )映射到其对应的数值。
return dict(zip(keys, loss_items))
# 如果 loss_items 为 None ,则方法只返回 keys 列表,即只包含损失项名称的列表。
else:
return keys
# 这个方法的主要功能是将损失项的数值与它们的名称关联起来,并返回一个字典或名称列表。如果提供了损失项的数值,它会将这些数值格式化并创建一个字典;如果没有提供数值,它只返回损失项的名称列表。这使得在训练和验证模型时可以轻松地跟踪和报告损失项。
# 这段代码定义了一个名为 progress_string 的方法,它用于生成一个格式化的字符串,显示训练进度的信息。
# 定义了一个名为 progress_string 的方法,它不接受任何额外参数,只使用类的实例自身( self )。
def progress_string(self):
# 返回带有纪元、GPU 内存、损失、实例和大小的训练进度的格式化字符串。
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
# 构建了一个格式化字符串。 "%11s" 是一个格式化占位符,表示一个宽度为11个字符的字符串。 (4 + len(self.loss_names)) 计算总共需要多少个这样的占位符,其中4是固定的(对应于"Epoch"、"GPU_mem"、"Instances"、"Size"), len(self.loss_names) 是损失项的数量。
# % 操作符用于将这些参数填充到格式化字符串中。 "\n" 在字符串的开始添加了一个新的行,确保这个进度信息在新的一行显示。
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
# 这是一个参数列表,用于填充上面提到的格式化占位符。参数包括 :
# "Epoch" :表示当前的训练周期。
# "GPU_mem" :表示当前的GPU内存使用情况。
# *self.loss_names :使用星号操作符将 self.loss_names 列表中的所有损失项名称作为参数展开。
# "Instances" :表示当前处理的实例数。
# "Size" :表示当前处理的数据批次的大小。
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
# progress_string 方法生成了一个标题行,包含了训练进度的所有关键信息的标签。这个字符串可以用来在训练过程中打印进度,方便用户跟踪训练的状态。
# 这段代码定义了一个名为 plot_training_samples 的方法,它用于绘制训练样本及其标注信息。
# 定义了一个名为 plot_training_samples 的方法,它接受三个参数。
# 1.self :类的实例自身。
# 2.batch :一批图像数据及其标注信息。
# 3.ni :当前批次的索引或编号。
def plot_training_samples(self, batch, ni):
# 绘制训练样本及其注释。
"""Plots training samples with their annotations."""
# 调用 plot_images 函数,它用于绘制图像及其标注。
# def plot_images(images: Union[torch.Tensor, np.ndarray], batch_idx: Union[torch.Tensor, np.ndarray], cls: Union[torch.Tensor, np.ndarray], bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
# confs: Optional[Union[torch.Tensor, np.ndarray]] = None, masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
# kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), paths: Optional[List[str]] = None, fname: str = "images.jpg",
# names: Optional[Dict[int, str]] = None, on_plot: Optional[Callable] = None, max_size: int = 1920, max_subplots: int = 16, save: bool = True, conf_thres: float = 0.25,) -> Optional[np.ndarray]:
# -> 用于将一系列图像和相关的边界框、类别、置信度等信息绘制成一幅图像马赛克(mosaic)。返回图像数据。返回 Annotator 对象中的图像数据,以 NumPy 数组的形式。这样,即使不保存到文件系统,也可以在内存中使用或进一步处理图像。
# -> return np.asarray(annotator.im)
plot_images(
# 从 batch 字典中获取 图像数据 ,赋值给 images 参数。
images=batch["img"],
# 从 batch 字典中获取 批次索引 ,用于标识每个图像在批次中的位置。
batch_idx=batch["batch_idx"],
# 从 batch 字典中获取 类别信息 ,并使用 .squeeze(-1) 方法去除最后一个维度,通常用于去除长度为1的维度。
cls=batch["cls"].squeeze(-1),
# 从 batch 字典中获取 边界框信息 。
bboxes=batch["bboxes"],
# 从 batch 字典中获取 图像文件的路径 。
paths=batch["im_file"],
# 指定了保存绘制图像的文件名。这里使用 self.save_dir 指定保存目录,并以 train_batch{ni}.jpg 格式命名文件,其中 ni 是批次索引。
fname=self.save_dir / f"train_batch{ni}.jpg",
# 接受一个回调函数(如果有的话),用于在绘制图像时执行额外的操作。
on_plot=self.on_plot,
)
# plot_training_samples 方法用于将一批训练样本及其标注信息绘制成图像,并保存到指定的目录。这可以帮助开发者和研究人员直观地检查模型的训练数据和标注质量。
# 这段代码定义了一个名为 plot_metrics 的方法,它用于从一个CSV文件中绘制度量指标(metrics)。
# 定义了一个名为 plot_metrics 的方法,它不接受任何额外参数,只使用类的实例自身( self )。
def plot_metrics(self):
# 根据 CSV 文件绘制指标。
"""Plots metrics from a CSV file."""
# 调用了 plot_results 函数,它用于绘制存储在CSV文件中的度量指标数据。
# file=self.csv : 它指定了包含度量指标数据的CSV文件。这里使用 self.csv ,意味着CSV文件的路径或名称存储在类的 csv 属性中。
# on_plot=self.on_plot : 它接受一个回调函数(如果有的话),用于在绘制图像时执行额外的操作,比如添加特定的标签或注释。这里使用 self.on_plot ,意味着回调函数存储在类的 on_plot 属性中。
# def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None): -> 用于绘制存储在 CSV 文件中的结果数据。
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
# 这个方法的主要功能是将存储在CSV文件中的度量指标数据绘制成图像,并可能保存到文件中。这可以帮助开发者和研究人员可视化模型的性能变化,比如损失函数值、准确率等关键指标。通过使用回调函数,还可以自定义图像的样式和内容。
# 这段代码定义了一个名为 plot_training_labels 的方法,它用于创建一个标记了训练数据标签的YOLO模型训练图。
# 定义了一个名为 plot_training_labels 的方法,它不接受任何额外参数,只使用类的实例自身( self )。
def plot_training_labels(self):
# 创建 YOLO 模型的标记训练图。
"""Create a labeled training plot of the YOLO model."""
# 创建一个包含所有边界框数据的NumPy数组。它通过列表推导式从 self.train_loader.dataset.labels 中提取每个标签的 "bboxes" ,然后使用 np.concatenate 将它们沿着第一个维度(0)连接起来,形成一个大的数组。
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
# 创建一个包含所有类别数据的NumPy数组。它通过列表推导式从 self.train_loader.dataset.labels 中提取每个标签的 "cls" ,然后使用 np.concatenate 将它们沿着第一个维度(0)连接起来,形成一个大的数组。
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
# 调用 plot_labels 函数,它用于绘制带有类别标签的边界框。
# boxes : 包含了所有边界框的数据。
# cls.squeeze() : 它是一个包含类别数据的数组。 .squeeze() 方法用于去除数组中长度为1的维度。
# names=self.data["names"] : 它包含了类别的名称,从 self.data["names"] 中获取。
# save_dir=self.save_dir : 它指定了保存绘制图像的目录,从 self.save_dir 中获取。
# on_plot=self.on_plot : 它接受一个回调函数(如果有的话),用于在绘制图像时执行额外的操作,比如添加特定的标签或注释,从 self.on_plot 中获取。
# def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): -> 绘制训练标签,包括类别直方图和框统计信息。
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
# 这个方法的主要功能是将训练数据中的边界框和类别标签绘制成图像,并保存到指定的目录。这可以帮助开发者和研究人员直观地检查模型的训练数据和标注质量。通过使用回调函数,还可以自定义图像的样式和内容。
# 这个类显然是为深度学习框架(如PyTorch)设计的,用于训练YOLO(You Only Look Once)模型,这是一种流行的实时目标检测算法。代码中使用了分布式训练、多尺度训练和数据增强等高级特性。此外,代码中还包含了一些待完成的TODO项,例如计算类别权重。