【YOLOv3】源码(train.py)

概述

主要模块分析

  • 参数解析与初始化
    • 功能:解析命令行参数,设置训练配置
    • 项目经理制定详细的施工计划和资源分配
  • 日志记录与监控
    • 功能:初始化日志记录器,配置监控系统
    • 项目经理使用监控和记录工具,实时跟踪施工进度和质量
  • 模型与数据加载
    • 功能:加载模型权重和配置文件,准备训练数据
    • 项目经理选择建筑设计方案,准备施工材料和组织施工队伍
  • 优化器与学习率调度器设置
    • 功能:设置优化器和学习率调度器,指导模型参数更新
    • 项目经理分配施工资源,制定施工进度计划
  • 训练循环
    • 功能:执行模型的前向传播、损失计算、反向传播和参数更新
    • 施工队每日执行施工任务,项目经理监控进度和质量
  • 验证与评估
    • 功能:定期验证模型性能,评估训练效果
    • 项目经理进行阶段性质量检查,评估施工质量和进度
  • 模型保存与早停机制
    • 功能:保存模型状态,应用早停机制优化训练过程
    • 项目经理记录施工进度和质量,决定是否调整或终止施工计划

主要模块

参数解析与初始化

一般在训练模型的时候,需要在这里调整相应的参数,这类似于建筑项目经理制定详细的施工计划和资源分配

常用设置参数

  • **--weights:**模型初始权重路径,通常设置为预训练模型路径,例如YOLOve.pt
  • **--cfg:**模型结构的 YAML 配置文件路径,例如yolov3.yaml
  • **--data:**数据集配置文件路径,定义训练/验证数据集的路径和类别等信息
  • **--hyp:**超参数配置文件路径,控制训练的优化器、学习率等超参数
  • **--epochs:**训练的总轮数,决定训练时长
  • **--batch-size:**批量大小,影响内存占用和训练速度
  • **--imgsz:**输入图像的尺寸
  • **--device:**指定使用的设备,0就表示GPU0
  • **--adam:**是否使用 Adam 优化器(默认使用 SGD)

W&B 参数(类似项目中的监控和记录工具)

  • --entity:设置 W&B 的实体名称,用于项目关联
  • --upload_dataset:是否将数据集上传到 W&B Artifact Table
  • --bbox_interval:控制目标框日志记录的间隔
python 复制代码
def parse_opt(known=False):
    """
    函数功能:
        用于解析命令行参数,设置训练、验证和测试时的超参数及其他相关配置。

    参数:
        known (bool): 是否只解析已知的命令行参数。如果为 True,则返回已知参数,忽略其他参数。
    
    返回:
        argparse.Namespace: 包含解析后参数的命名空间对象 `opt`。
    """
    import argparse

    # 创建 ArgumentParser 对象
    parser = argparse.ArgumentParser()

    # ---------------------------- 常用参数配置 ----------------------------------
    # 权重文件路径
    parser.add_argument('--weights', type=str, default=ROOT / 'weight/yolov3.pt',
                        help='initial weights path (初始权重文件路径)')

    # 模型配置文件路径
    parser.add_argument('--cfg', type=str, default='models/yolov3.yaml',
                        help='model.yaml path (模型结构配置文件路径)')

    # 数据集配置文件路径
    parser.add_argument('--data', type=str, default=ROOT / 'data/you.yaml',
                        help='dataset.yaml path (数据集配置文件路径)')

    # 超参数配置文件路径
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml',
                        help='hyperparameters path (超参数配置文件路径)')

    # 训练周期数
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of epochs to train (训练的总轮数)')

    # 批量大小
    parser.add_argument('--batch-size', type=int, default=4,
                        help='Total batch size for all GPUs, -1 for autobatch (总的批量大小)')

    # 图像大小
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=416,
                        help='train, val image size (pixels) (训练和验证时图像的输入尺寸)')

    # 是否使用矩形训练
    parser.add_argument('--rect', action='store_true', default=True,
                        help='rectangular training (是否使用矩形训练)')

    # 是否恢复最近一次的训练
    parser.add_argument('--resume', nargs='?', const=True, default="",
                        help='resume most recent training (恢复最近的训练检查点)')

    # 仅保存最终的检查点
    parser.add_argument('--nosave', action='store_true',
                        help='only save final checkpoint (只保存最终检查点)')

    # 仅验证最终周期
    parser.add_argument('--noval', action='store_true',
                        help='only validate final epoch (只在最后一轮进行验证)')

    # 是否禁用自动生成 anchors
    parser.add_argument('--noautoanchor', action='store_true',
                        help='disable autoanchor check (禁用自动生成 anchor 的功能)')

    # 超参数进化
    parser.add_argument('--evolve', type=int, nargs='?', const=300,
                        help='evolve hyperparameters for x generations (超参数进化代数)')

    # Google Cloud Bucket
    parser.add_argument('--bucket', type=str, default='',
                        help='gsutil bucket (Google 云存储桶路径)')

    # 是否缓存数据集到 RAM 或磁盘
    parser.add_argument('--cache', type=str, nargs='?', const='ram', default=True,
                        help='--cache images in "ram" (default) or "disk" (缓存数据集)')

    # 是否使用加权的图像选择训练
    parser.add_argument('--image-weights', action='store_true',
                        help='use weighted image selection for training (训练时使用加权图像选择)')

    # 指定训练的设备
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu (指定训练的设备)')

    # 是否启用多尺度训练
    parser.add_argument('--multi-scale', action='store_true',
                        help='vary img-size +/- 50%% (多尺度训练)')

    # 将多类别数据作为单类别训练
    parser.add_argument('--single-cls', action='store_true',
                        help='train multi-class data as single-class (单类别训练)')

    # 是否使用 Adam 优化器
    parser.add_argument('--adam', action='store_true',
                        help='use torch.optim.Adam() optimizer (使用 Adam 优化器)')

    # 是否启用同步 BatchNorm
    parser.add_argument('--sync-bn', action='store_true',
                        help='use SyncBatchNorm, only available in DDP mode (同步 BatchNorm,仅适用于 DDP 模式)')

    # 数据加载器的最大工作线程数
    parser.add_argument('--workers', type=int, default=1,
                        help='max dataloader workers (per RANK in DDP mode) (最大数据加载线程数)')

    # 项目保存目录
    parser.add_argument('--project', default=ROOT / 'runs/train',
                        help='save to project/name (项目保存路径)')

    # 保存的实验名称
    parser.add_argument('--name', default='exp',
                        help='save to project/name (实验保存名称)')

    # 是否允许覆盖现有项目
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment (允许覆盖现有项目名称)')

    # 是否使用四元数据加载器
    parser.add_argument('--quad', action='store_true',
                        help='quad dataloader (启用四元数据加载器)')

    # 是否使用线性学习率
    parser.add_argument('--linear-lr', action='store_true',
                        help='linear LR (启用线性学习率)')

    # 标签平滑参数
    parser.add_argument('--label-smoothing', type=float, default=0.0,
                        help='Label smoothing epsilon (标签平滑参数 epsilon)')

    # 提前停止的容忍轮数
    parser.add_argument('--patience', type=int, default=1000,
                        help='EarlyStopping patience (epochs without improvement) (提前停止的容忍轮数)')

    # 冻结的层数
    parser.add_argument('--freeze', type=int, default=0,
                        help='Number of layers to freeze. backbone=10, all=24 (冻结层数)')

    # 检查点保存间隔
    parser.add_argument('--save-period', type=int, default=-1,
                        help='Save checkpoint every x epochs (disabled if < 1) (每隔几轮保存一次检查点)')

    # 本地进程排名(DDP 模式用)
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='DDP parameter, do not modify (DDP 模式的进程排名)')

    # ---------------------------- W&B(Weights & Biases)参数配置 ----------------------------
    parser.add_argument('--entity', default=None,
                        help='W&B: Entity (W&B 实体名称)')
    parser.add_argument('--upload_dataset', action='store_true',
                        help='W&B: Upload dataset as artifact table (上传数据集到 W&B Artifact Table)')
    parser.add_argument('--bbox_interval', type=int, default=-1,
                        help='W&B: Set bounding-box image logging interval (设置目标框日志记录间隔)')
    parser.add_argument('--artifact_alias', type=str, default='latest',
                        help='W&B: Version of dataset artifact to use (使用的数据集版本别名)')

    # ---------------------------- 参数解析 ----------------------------
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

日志记录与监控

初始化日志记录器,配置日志系统,并注册回调函数。这类似于建筑项目中的监控和记录系统,用于实时跟踪施工进度和质量

python 复制代码
# 判断是否是主进程(RANK == -1 表示单机训练,RANK == 0 表示分布式训练的主进程)
if RANK in [-1, 0]:  
    # **Step 1: 初始化日志记录器**
    # 创建 Loggers 对象,用于管理训练过程的日志(包括本地日志和 W&B 日志)。
    # 参数说明:
    # - `save_dir`: 日志文件和模型保存的路径。
    # - `weights`: 模型权重文件的路径。
    # - `opt`: 训练过程中所有配置的参数。
    # - `hyp`: 超参数配置。
    # - `LOGGER`: 用于打印日志到控制台的日志记录器。
    loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)

    # **Step 2: W&B 特定处理**
    # 如果启用了 W&B(Weights & Biases)日志记录功能
    if loggers.wandb:
        # 获取 W&B 数据字典(用于记录训练数据相关信息)
        data_dict = loggers.wandb.data_dict

        # 如果恢复训练(`resume` 参数为 True)
        if resume:
            # 使用恢复的权重、训练轮数和超参数,覆盖当前的 opt 配置
            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp

    # **Step 3: 注册回调函数**
    # 遍历 `loggers` 中所有的方法(`methods(loggers)` 返回可用方法的列表)
    for k in methods(loggers):
        # 将每个方法作为回调函数注册到 `callbacks` 中
        # 参数说明:
        # - `k`: 回调方法的名称(如 `on_train_start`, `on_epoch_end` 等)。
        # - `callback`: 对应的回调函数(通过 `getattr` 获取 `loggers` 中的方法)。
        callbacks.register_action(k, callback=getattr(loggers, k))

模型与数据加载

加载模型权重和配置文件,设置模型参数,加载训练数据。这类似于建筑项目中选择建筑设计方案、准备施工材料和组织施工队伍

分析:该部分代码属于yolov3结构中的哪个阶段?

主要发生在训练前的准备工作,也就是还没有进入模型的前向传播或者反向传播阶段

运行逻辑分析

  • 模型加载与构建
    • 如果提供预训练权重,加载模型并初始化参数
    • 如果没有提供权重,则根据配置文件构建新模型
  • 冻结层设置
    • 固定部分参数(如 Backbone 层),以适应迁移学习或微调场景
  • 训练数据准备
    • 创建数据加载器,支持多线程加载、数据增强和分布式训练

分析:冻结层的使用场景

  • 冻结Backbone
    • 例如之前已经从大规模的数据中学到了一些通用特则会给你,那么通过冻结Backbone的参数,仅仅训练Detection Head用于适配新的任务和类别即可
  • 冻结全部层
    • 这种场景仅仅适合在微调检测头的时候使用
    • 对于小规模数据集(如只有少量的目标类别),可以冻结所有 Backbone 层,只训练最后的预测头
  • 根据自己的需求进行冻结,冻结就是利用已经预训练的特征提取能力,然后让其在新的数据集上可以实现高效的训练
python 复制代码
# ------------------------------- 模型部分 -------------------------------

# 检查权重文件的后缀是否为 .pt(PyTorch 模型格式)
check_suffix(weights, '.pt')  

# 判断是否加载预训练模型
pretrained = weights.endswith('.pt')  

if pretrained:  # 如果加载的是预训练模型
    # 确保在分布式训练中只由一个进程下载权重文件,避免冲突
    with torch_distributed_zero_first(LOCAL_RANK):
        weights = attempt_download(weights)  # 下载或加载指定的权重文件

    # 加载权重文件到内存,并指定加载到的设备(如 GPU 或 CPU)
    ckpt = torch.load(weights, map_location=device)  

    # 创建模型对象
    # - 如果提供了 cfg 文件,则使用 cfg 文件构建模型
    # - 否则,使用权重文件中的模型配置(`ckpt['model'].yaml`)
    model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  

    # 定义要排除的参数(如 anchor 参数):
    # - 如果提供了 cfg 文件或 hyperparameters 中指定了 anchor 配置,并且不是恢复训练模式,则排除 anchor 参数。
    exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []

    # 加载预训练模型的参数
    csd = ckpt['model'].float().state_dict()  # 从权重文件中提取模型的状态字典
    csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # 匹配当前模型的参数,并排除指定的参数
    model.load_state_dict(csd, strict=False)  # 将预训练权重加载到模型中,允许部分参数不匹配

    # 打印日志:显示加载的参数数量与模型参数总数量
    LOGGER.info(f'从 {weights} 转移了 {len(csd)}/{len(model.state_dict())} 项')  

else:  # 如果没有加载预训练模型,则从头构建一个新模型
    model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  


# ------------------------------- 冻结层部分 -------------------------------

# 冻结部分模型层的参数,以防止它们在训练过程中更新
# 生成冻结层的层名前缀列表,例如:['model.0.', 'model.1.', ..., 'model.N.']
freeze = [f'model.{x}.' for x in range(freeze)]  

# 遍历模型中的所有参数(键值对:参数名称和参数值)
for k, v in model.named_parameters():
    v.requires_grad = True  # 默认所有参数可训练
    # 如果当前参数的名称包含在冻结层列表中
    if any(x in k for x in freeze):
        LOGGER.info(f'冻结 {k}')  # 打印冻结的参数名称
        v.requires_grad = False  # 禁止该参数的梯度更新(冻结参数)


# ------------------------------- 数据加载部分 -------------------------------

# 创建训练数据加载器(train_loader)和数据集对象(dataset)
train_loader, dataset = create_dataloader(
    train_path,          # 训练数据的路径
    imgsz,               # 输入图像的大小
    batch_size // WORLD_SIZE,  # 每个 GPU 的批量大小(在分布式训练中,批量大小会被划分)
    gs,                  # 网格大小(grid size),用于确保图像大小是网格的倍数
    single_cls,          # 是否将多类数据当作单类数据处理
    hyp=hyp,             # 超参数配置
    augment=True,        # 是否进行数据增强
    cache=opt.cache,     # 是否缓存数据到内存或磁盘
    rect=opt.rect,       # 是否使用矩形训练
    rank=LOCAL_RANK,     # 分布式训练时的本地进程编号
    workers=workers,     # 数据加载线程数
    image_weights=opt.image_weights,  # 是否加权选择图像
    quad=opt.quad,       # 是否启用四元数据加载器
    prefix=colorstr('train: '),  # 日志前缀
    shuffle=True         # 是否对数据进行随机打乱
)

优化器与学习率调度器设置

理解优化器和学习率调度器

  • 优化器设置:类似于项目经理分配施工资源(如劳动力、设备),选择适当的施工方法(如快速建造或精细施工)
  • 学习率调度器:对应于施工进度计划,决定资源的使用速度和调整施工节奏,以确保建筑按时完成且质量达标
python 复制代码
# ------------------------------- 优化器设置 -------------------------------

# 计算梯度累积步数(Accumulate Step)
# nbs: 基准批量大小(64,是一个参考值)
nbs = 64  
# 计算当前的累积步数,公式为:基准批量大小 / 当前批量大小(最小值为 1)
accumulate = max(round(nbs / batch_size), 1)  

# 根据批量大小和累积步数调整权重衰减(Weight Decay)
# 如果批量大小变大,适当放大权重衰减;反之则缩小权重衰减。
hyp['weight_decay'] *= batch_size * accumulate / nbs  
LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")  # 打印调整后的权重衰减值

# ------------------------------- 参数分组 -------------------------------

# 将模型的参数分为三类:
# g0: BatchNorm 的权重
# g1: 卷积层或全连接层的权重
# g2: 偏置(bias)
g0, g1, g2 = [], [], []  # 初始化三个参数组
for v in model.modules():  # 遍历模型中的每个模块
    # 如果模块有偏置参数(bias),并且是 nn.Parameter 类型,则将其加入 g2
    if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
        g2.append(v.bias)
    # 如果模块是 BatchNorm2d,则将其权重加入 g0
    if isinstance(v, nn.BatchNorm2d):
        g0.append(v.weight)
    # 如果模块有权重参数(weight),并且是 nn.Parameter 类型,则将其加入 g1
    elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
        g1.append(v.weight)

# ------------------------------- 优化器设置 -------------------------------

# 如果使用 Adam 优化器
if opt.adam:
    # 创建 Adam 优化器
    optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  
else:
    # 否则使用 SGD 优化器
    optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)  # 使用 Nesterov 动量

# 为优化器添加参数组
# 添加 g1 参数组,并设置 weight_decay 为超参数中的值
optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  
# 添加 g2 参数组(偏置),不使用权重衰减
optimizer.add_param_group({'params': g2})  

# ------------------------------- 学习率调度器 -------------------------------

# 定义学习率调度器的变化方式(scheduler)
if opt.linear_lr:
    # 如果使用线性学习率,定义线性衰减函数
    # 公式:初始学习率从 (1 - x) 减少到 lrf(最低学习率比例)
    lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  
else:
    # 否则使用 One-Cycle 学习率调度器
    lf = one_cycle(1, hyp['lrf'], epochs)  

# 创建学习率调度器,基于上述的学习率变化函数 lf
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  

训练循环(核心)

主要阶段总结

  • 训练循环 :类似于施工队每日的施工任务,包括材料的使用、施工进度的监控、质量的检查以及资源的优化配置
  • 热身阶段 :类似于施工队初期的准备工作,逐步适应施工环境和进度
  • 多尺度训练 :类似于根据不同施工需求调整施工方法和材料,以适应不同的建筑部分和设计要求
  • 前向传播与损失计算 :对应于施工过程中的质量检查和评估,确保每一步施工符合设计标准
  • 反向传播与优化 :类似于根据质量检查结果调整施工方法和资源分配,以提高施工效率和建筑质量
  • 日志记录 :类似于施工日志和进度报告,实时记录施工进展和遇到的问题
  • 学习率调度 :对应于施工进度的动态调整和优化,根据施工进展和质量要求调整施工节奏
  • 验证与评估 :类似于阶段性质量检查和最终验收,确保建筑物的整体质量和功能
python 复制代码
# ---------------------- 训练循环 ----------------------
for epoch in range(start_epoch, epochs):  # 遍历每个 epoch
    model.train()  # 设置模型为训练模式
    if opt.image_weights:  # 如果启用了类别权重调整
        # 根据类别权重和映射关系调整样本权重
        cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  
        iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  
        dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # 根据权重重采样数据集

    mloss = torch.zeros(3, device=device)  # 初始化平均损失记录 (box_loss, obj_loss, cls_loss)
    if RANK != -1:  # 如果是分布式训练模式
        train_loader.sampler.set_epoch(epoch)  # 设置当前 epoch,确保分布式训练的数据加载一致

    # 进度条设置
    pbar = enumerate(train_loader)  # 枚举数据加载器
    print(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))  # 打印标题行
    if RANK in [-1, 0]:  # 如果是主进程
        # 显示训练进度条
        pbar = tqdm(pbar, total=nb, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  

    optimizer.zero_grad()  # 优化器梯度清零
    for i, (imgs, targets, paths, _) in pbar:  # 遍历批次数据
        ni = i + nb * epoch  # 计算全局迭代步数
        imgs = imgs.to(device, non_blocking=True).float() / 255  # 将图像归一化到 [0,1] 并移到设备上

        # -------------------- 热身阶段 --------------------
        if ni <= nw:  # 如果在热身阶段
            xi = [0, nw]  # 热身范围
            # 动态调整累积步数(accumulate)和学习率
            accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())  
            for j, x in enumerate(optimizer.param_groups):  # 遍历优化器的参数组
                # 动态调整学习率
                x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                # 动态调整动量
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

        # -------------------- 多尺度训练 --------------------
        if opt.multi_scale:  # 如果启用了多尺度训练
            # 随机生成新的训练图像尺寸
            sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  
            sf = sz / max(imgs.shape[2:])  # 缩放因子
            if sf != 1:  # 如果需要缩放
                # 计算新的图像尺寸并调整
                ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  
                imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)  

        # -------------------- 前向传播 --------------------
        with amp.autocast(enabled=cuda):  # 混合精度加速
            pred = model(imgs)  # 模型前向传播
            loss, loss_items = compute_loss(pred, targets.to(device))  # 计算损失
            if RANK != -1:  # 如果是分布式训练
                loss *= WORLD_SIZE  # 按照分布式规模调整损失

        # -------------------- 反向传播 --------------------
        scaler.scale(loss).backward()  # 使用梯度缩放反向传播

        # -------------------- 参数更新 --------------------
        if ni - last_opt_step >= accumulate:  # 如果满足累积步数条件
            scaler.step(optimizer)  # 更新优化器参数
            scaler.update()  # 更新梯度缩放比例
            optimizer.zero_grad()  # 清零梯度
            if ema:  # 如果启用了 EMA
                ema.update(model)  # 更新模型的指数移动平均
            last_opt_step = ni  # 更新最后一次优化步数

        # -------------------- 日志记录 --------------------
        if RANK in [-1, 0]:  # 如果是主进程
            # 动态更新平均损失
            mloss = (mloss * i + loss_items) / (i + 1)  
            mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # 显存使用量
            # 更新进度条显示内容
            pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))

    # -------------------- 学习率调度 --------------------
    lr = [x['lr'] for x in optimizer.param_groups]  # 获取当前学习率
    scheduler.step()  # 更新学习率调度器

    # -------------------- 评估与保存 --------------------
    if RANK in [-1, 0]:  # 如果是主进程
        callbacks.run('on_train_epoch_end', epoch=epoch)  # 运行训练结束的回调
        ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])  # 更新 EMA 属性
        final_epoch = (epoch + 1 == epochs) or stopper.possible_stop  # 检查是否是最后一个 epoch
        if not noval or final_epoch:  # 如果需要验证
            # 运行验证,并获取验证结果
            results, maps, _ = val.run(data_dict,
                                       batch_size=batch_size // WORLD_SIZE * 2,
                                       imgsz=imgsz,
                                       model=ema.ema,
                                       single_cls=single_cls,
                                       dataloader=val_loader,
                                       save_dir=save_dir,
                                       plots=False,
                                       callbacks=callbacks,
                                       compute_loss=compute_loss)

        # 更新最佳 mAP
        fi = fitness(np.array(results).reshape(1, -1))  # 计算当前结果的 fitness
        if fi > best_fitness:  # 如果当前 fitness 是最优的
            best_fitness = fi  # 更新最佳 fitness
        log_vals = list(mloss) + list(results) + lr  # 记录日志值
        callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)  # 运行回调

        # 保存模型
        if (not nosave) or (final_epoch and not evolve):  # 如果需要保存模型
            ckpt = {'epoch': epoch,  # 记录当前 epoch
                    'best_fitness': best_fitness,  # 最佳 fitness
                    'model': deepcopy(de_parallel(model)).half(),  # 模型参数
                    'ema': deepcopy(ema.ema).half(),  # EMA 参数
                    'updates': ema.updates,  # EMA 更新次数
                    'optimizer': optimizer.state_dict(),  # 优化器状态
                    'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,  # W&B 运行 ID
                    'date': datetime.now().isoformat()}  # 保存日期

            torch.save(ckpt, last)  # 保存为最后一次权重文件
            if best_fitness == fi:  # 如果当前 fitness 是最优的
                torch.save(ckpt, best)  # 保存为最佳权重文件
            if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):  # 按周期保存权重
                torch.save(ckpt, w / f'epoch{epoch}.pt')  
            del ckpt  # 删除检查点,释放内存
            callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)  # 运行回调

        # 提前停止(仅在单 GPU 模式下)
        if RANK == -1 and stopper(epoch=epoch, fitness=fi):  # 如果满足提前停止条件
            break  # 停止训练

验证与评估

在每个训练周期结束后,对模型进行验证,评估其在验证集上的性能(如mAP)。这类似于建筑项目中的阶段性质量检查和评估,确保施工质量符合设计要求

python 复制代码
if RANK in [-1, 0]:  # 检查当前设备是否为主进程(单GPU模式或主节点)
    # 调用验证函数 val.run(),对当前模型在验证集上的性能进行评估
    results, maps, _ = val.run(
        data_dict,                 # 数据集的配置信息,包含训练、验证和测试数据的路径
        batch_size=batch_size // WORLD_SIZE * 2,  # 验证集的批次大小,调整为全局批量大小(batch_size)除以总进程数 WORLD_SIZE,再乘以 2
        imgsz=imgsz,               # 输入图像的尺寸
        model=ema.ema,             # 使用 EMA(指数移动平均)模型的权重进行评估,以获得更平滑和稳定的验证性能
        single_cls=single_cls,     # 是否将多类别数据视为单类别任务,用于单类别检测
        dataloader=val_loader,     # 验证集的数据加载器
        save_dir=save_dir,         # 保存结果的路径,用于存储验证过程的日志或可视化图表
        plots=False,               # 是否生成验证结果的可视化图表,设置为 False 表示不生成
        callbacks=callbacks,       # 回调函数,用于扩展验证过程,例如记录日志或自定义处理
        compute_loss=compute_loss  # 损失函数,用于计算验证过程中的损失值
    )

模型保存与早停机制

据训练过程中的表现,保存当前模型的状态(如最佳模型、最新模型等)。同时,通过早停机制,在模型性能不再提升时提前终止训练

  • 模型保存:项目经理定期记录施工进度和质量状况,保存关键的施工记录和里程碑
  • 早停机制:如果发现施工质量无法满足要求,或项目进度严重滞后,项目经理决定提前终止或调整施工计划,以避免资源浪费和进一步的问题
python 复制代码
# 保存模型
if (not nosave) or (final_epoch and not evolve):  # 检查是否需要保存模型
    # 构建一个检查点字典,用于保存模型的状态和相关信息
    ckpt = {
        'epoch': epoch,                            # 当前的训练轮次
        'best_fitness': best_fitness,              # 当前训练过程中模型的最佳 fitness(如 mAP)
        'model': deepcopy(de_parallel(model)).half(),  # 深拷贝模型的状态,转换为半精度以减少存储需求
        'ema': deepcopy(ema.ema).half(),           # 深拷贝 EMA(指数移动平均)模型的状态
        'updates': ema.updates,                    # EMA 更新的次数
        'optimizer': optimizer.state_dict(),       # 优化器的状态字典(保存优化器参数和学习率等信息)
        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,  # W&B 运行 ID(如果使用 W&B)
        'date': datetime.now().isoformat()         # 保存当前时间的时间戳,用于记录模型保存时间
    }

    # 保存最新的模型权重到指定路径 `last`
    torch.save(ckpt, last)

    # 如果当前的 fitness 指标是最佳值,则保存模型到 `best` 路径
    if best_fitness == fi:
        torch.save(ckpt, best)

    # 如果训练轮次大于 0 且保存周期 `save_period` 大于 0,并且当前轮次是保存周期的倍数
    # 则保存当前轮次的模型权重到以 `epoch{轮次}.pt` 命名的文件
    if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
        torch.save(ckpt, w / f'epoch{epoch}.pt')  # 保存到指定路径

    # 删除保存的检查点对象,以释放内存
    del ckpt

    # 触发 `on_model_save` 回调函数,通知其他组件模型已保存
    callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

# 单 GPU 模式下的提前停止机制
if RANK == -1 and stopper(epoch=epoch, fitness=fi):  # 如果是单 GPU 模式,并且达到提前停止条件
    break  # 停止训练,结束当前循环
相关推荐
阿拉斯攀登16 分钟前
【无人售货柜・RK+YOLO】篇 6:安卓端落地!RK3576 + 安卓系统,YOLO RKNN 模型实时推理保姆级教程
android·人工智能·yolo·目标跟踪·瑞芯微·嵌入式驱动
Cpsu42 分钟前
EdgeCrafter:实时目标检测任务新SOTA
人工智能·yolo·目标检测·计算机视觉
JicasdC123asd17 小时前
密集残差瓶颈网络改进YOLOv26特征复用与梯度传播双重优化
网络·yolo·目标跟踪
JicasdC123asd20 小时前
密集连接瓶颈模块改进YOLOv26特征复用与梯度流动双重优化
人工智能·yolo·目标跟踪
duyinbi751721 小时前
局部特征提取改进YOLOv26空间移位卷积与轻量化设计双重突破
人工智能·yolo·目标跟踪
张道宁1 天前
基于Spring Boot与Docker的YOLOv8检测服务实战
spring boot·yolo·docker
duyinbi75171 天前
大核瓶颈架构改进YOLOv26扩大感受野与多尺度特征提取双重突破
yolo·架构
孤狼warrior1 天前
YOLO技术架构发展详解(从v1到v8)近万字底层实现逻辑解析
yolo
张张123y1 天前
机器学习与深度学习:从基础概念到YOLOv8全解析
深度学习·yolo·机器学习
hans汉斯2 天前
基于区块链和语义增强的科研诚信智能管控平台
人工智能·算法·yolo·数据挖掘·区块链·汉斯出版社