YOLOv9-0.1部分代码阅读笔记-detect.py

detect.py

detect.py

目录

detect.py

1.所需的库和模块

[2.def run(weights=ROOT / 'yolo.pt', source=ROOT / 'data/images', data=ROOT / 'data/coco.yaml', imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45, max_det=1000, device='', view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=ROOT / 'runs/detect', name='exp', exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False, vid_stride=1,):](#2.def run(weights=ROOT / 'yolo.pt', source=ROOT / 'data/images', data=ROOT / 'data/coco.yaml', imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45, max_det=1000, device='', view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=ROOT / 'runs/detect', name='exp', exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False, vid_stride=1,):)

[3.def parse_opt():](#3.def parse_opt():)

[4.def main(opt):](#4.def main(opt):)

[5.if name == "main":](#5.if name == "main":)


1.所需的库和模块

python 复制代码
import argparse
import os
import platform
import sys
from pathlib import Path

import torch

# 这段代码涉及到Python中的文件路径操作和系统路径管理。
# 创建了一个 Path 对象,它代表当前执行文件(脚本)的路径,并使用 resolve() 方法解析为绝对路径。 __file__ 是Python中的一个特殊变量,它包含了当前文件的路径。
FILE = Path(__file__).resolve()
# 获取 FILE 的父目录,即脚本所在的目录的上一级目录,并将其赋值给 ROOT 变量。注释中的 YOLO root directory 表明这个目录是YOLO项目的根目录。
ROOT = FILE.parents[0]  # YOLO root directory
# 检查 ROOT 目录的字符串表示是否已经在 sys.path 中。 sys.path 是一个列表,包含了Python解释器搜索模块的路径。
if str(ROOT) not in sys.path:
    # 如果 ROOT 不在 sys.path 中,将 ROOT 的字符串表示添加到 sys.path 列表的末尾。这样做的目的是确保Python解释器能够找到YOLO项目根目录下的模块。
    sys.path.append(str(ROOT))  # add ROOT to PATH
# 重新设置 ROOT 变量,使其成为相对于当前工作目录( Path.cwd() )的相对路径。 os.path.relpath() 函数计算两个路径之间的相对路径。
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
# 这段代码的主要目的是确保YOLO项目的根目录被添加到Python的模块搜索路径中,以便可以正确地导入项目中的模块。同时,它还将根目录的路径转换为相对于当前工作目录的路径,这样做可以使得路径在不同环境中更加灵活和可移植。

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode

2.def run(weights=ROOT / 'yolo.pt', source=ROOT / 'data/images', data=ROOT / 'data/coco.yaml', imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45, max_det=1000, device='', view_img=False, save_txt=False, save_conf=False, save_crop=False, nosave=False, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=ROOT / 'runs/detect', name='exp', exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False, vid_stride=1,):

python 复制代码
# 这段代码定义了一个名为 run 的函数,它用于运行一个深度学习模型的推理过程。这个函数接受多个参数,用于配置模型的权重、输入源、数据集配置、推理大小、置信度阈值、NMS IOU阈值等。
# 这个装饰器,它用于将函数运行在智能推理模式下,这种模式可能会优化推理过程。
# def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
# -> 它是一个装饰器工厂,用于根据 PyTorch 版本应用不同的装饰器。根据 torch_1_9 的值,选择 torch.inference_mode 或 torch.no_grad 装饰器,并将其应用于函数 fn 。
# -> return decorate
# -> return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
@smart_inference_mode()
# 定义了一个名为 run 的函数,它接受多个参数,用于配置模型的权重、输入源、数据集配置等。
# 1.weights :指定了模型权重的路径或者是Triton服务的URL,默认是项目根目录下的 yolo.pt 文件。
# 2.source :指定了输入源的路径,可以是文件、目录、URL、glob模式、屏幕捕获或摄像头,默认是项目根目录下的 data/images 目录。
# 3.data :指定了数据集配置文件的路径,默认是项目根目录下的 data/coco.yaml 文件。
# 4.imgsz :指定了推理时图像的尺寸(高度和宽度),默认为640x640。
# 5.conf_thres :指定了置信度阈值,默认为0.25。
# 6.iou_thres :指定了非最大抑制(NMS)的IOU(交并比)阈值,默认为0.45。
# 7.max_det :指定了每张图像的最大检测数量,默认为1000。
# 8.device :指定了使用的设备,如GPU编号或CPU,默认为空,意味着使用系统默认设备。
# 9.view_img :指定了是否显示结果图像,默认为False。
# 10.save_txt :指定了是否将结果保存到文本文件,默认为False。
# 11.save_conf :指定了在保存文本文件时是否包含置信度,默认为False。
# 12.save_crop :指定了是否保存裁剪的预测框,默认为False。
# 13.nosave :指定了是否不保存图像或视频,默认为False。
# 14.classes :允许通过类别过滤检测结果。
# 15.agnostic_nms :指定了是否使用类别无关的NMS,默认为False。
# 16.augment :指定了是否进行增强推理,默认为False。
# 17.visualize :指定了是否可视化特征,默认为False。
# 18.update :指定了是否更新所有模型,默认为False。
# 19.project :指定了保存结果的项目目录,默认是项目根目录下的 runs/detect 目录。
# 20.name :指定了保存结果的名称,默认为 exp 。
# 21.exist_ok :指定了如果项目/名称已存在是否允许,默认为False,意味着不允许。
# 22.line_thickness :指定了边界框的线宽,默认为3像素。
# 23.hide_labels :指定了是否隐藏标签,默认为False。
# 24.hide_conf :指定了是否隐藏置信度,默认为False。
# 25.half :指定了是否使用FP16半精度推理,默认为False。
# 26.dnn :指定了是否使用OpenCV DNN进行ONNX推理,默认为False。
# 27.vid_stride :指定了视频帧率步长,默认为1。
def run(
        weights=ROOT / 'yolo.pt',  # model path or triton URL
        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
        data=ROOT / 'data/coco.yaml',  # dataset.yaml path
        imgsz=(640, 640),  # inference size (height, width)
        conf_thres=0.25,  # confidence threshold
        iou_thres=0.45,  # NMS IOU threshold
        max_det=1000,  # maximum detections per image
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        view_img=False,  # show results
        save_txt=False,  # save results to *.txt
        save_conf=False,  # save confidences in --save-txt labels
        save_crop=False,  # save cropped prediction boxes
        nosave=False,  # do not save images/videos
        classes=None,  # filter by class: --class 0, or --class 0 2 3
        agnostic_nms=False,  # class-agnostic NMS
        augment=False,  # augmented inference
        visualize=False,  # visualize features
        update=False,  # update all models
        project=ROOT / 'runs/detect',  # save results to project/name
        name='exp',  # save results to project/name
        exist_ok=False,  # existing project/name ok, do not increment
        line_thickness=3,  # bounding box thickness (pixels)
        hide_labels=False,  # hide labels
        hide_conf=False,  # hide confidences
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        vid_stride=1,  # video frame-rate stride
):
    # 这段代码是 run 函数的一部分,用于处理输入源,并确定如何处理它。
    # 将 source 参数转换为字符串类型。这是因为后续的操作需要 source 作为一个字符串来处理。
    source = str(source)
    # 设置了一个布尔值 save_img ,它决定了是否保存推理后的图像。如果 nosave 参数为 False 且 source 不是以 .txt 结尾的文件,则 save_img 为 True ,表示会保存图像。
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    # 检查 source 是否是一个文件,并且该文件的扩展名是否在支持的图像或视频格式列表 IMG_FORMATS 和 VID_FORMATS 中。如果是, is_file 为 True 。
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    # 检查 source 是否是一个URL。它通过检查字符串是否以 rtsp:// 、 rtmp:// 、 http:// 或 https:// 开头来判断。如果是, is_url 为 True 。
    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
    # 确定 source 是否表示一个摄像头输入。如果 source 是数字(可能表示摄像头索引),以 .txt 结尾(可能包含摄像头输出的文件),或者是一个URL但不是一个文件(如流媒体),则 webcam 为 True 。
    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
    # 检查 source 是否表示一个屏幕截图。如果 source 以 "screen" 开头(不区分大小写),则 screenshot 为 True 。
    screenshot = source.lower().startswith('screen')
    # 一个条件语句,如果 source 是一个URL且是一个文件(即在线文件),则执行下面的代码。
    if is_url and is_file:
        # 如果 source 是一个在线文件,这行代码会调用 check_file 函数来下载该文件。
        # def check_file(file, suffix=''):
        # -> 检查一个文件是否存在,如果不存在且文件是一个网址,则下载该文件;如果文件是一个 ClearML 数据集 ID,则检查 ClearML 是否已安装;如果都不是,则在指定的目录中搜索文件。
        # -> return file / return files[0]  # return file
        source = check_file(source)  # download
    # 这段代码的主要目的是确定输入源的类型(文件、URL、摄像头、屏幕截图)并进行相应的处理。它检查 source 是否是文件、URL,并根据这些信息设置保存图像的标志,以及是否需要下载在线文件。这些信息对于后续的图像加载和推理过程至关重要。

    # 这两行代码处理了结果保存目录的创建和设置。
    # Directories
    # 定义了保存推理结果的目录路径 save_dir 。它使用 increment_path 函数,该函数用于在目录名中添加一个数字后缀,以确保目录名是唯一的,避免覆盖已有的结果。如果 exist_ok 参数为 True ,则即使目录已存在也不会增加后缀,否则会递增后缀直到找到一个不存在的目录名。
    # Path(project) / name 构建了基本的路径,即项目目录下的特定实验或检测名称的路径。
    # def increment_path(path, exist_ok=False, sep='', mkdir=False): -> 为文件或目录生成一个新的路径名,如果原始路径已经存在,则通过在路径后面添加一个数字(默认从2开始递增)来创建一个新的路径。函数返回最终的路径 path 。 -> return path
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    # 根据 save_txt 参数的值决定是否在 save_dir 下创建一个名为 labels 的子目录。如果 save_txt 为 True ,则创建 save_dir/labels 目录;如果为 False ,则直接使用 save_dir 。
    # mkdir(parents=True, exist_ok=True) 函数创建这个目录。 parents=True 参数表示如果父目录不存在,则一并创建。 exist_ok=True 参数表示如果目录已存在,则不抛出异常。
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
    # 这段代码的目的是确保有一个合适的目录来保存推理结果。如果设置了保存文本文件,它还会创建一个子目录来存放这些文本文件。这个目录是唯一的,因为 increment_path 函数会处理同名目录的冲突。这样的设计可以组织和管理不同推理运行的结果,避免数据混淆。

    # 这段代码是 run 函数中用于加载模型的部分。
    # Load model
    # 调用 select_device 函数来确定模型将运行在哪个设备上。 device 参数可以是GPU编号(如 '0' 或 '0,1,2,3' ),也可以是 'cpu' 。如果 device 参数为空字符串,则 select_device 函数将自动选择一个设备,通常是默认的GPU(如果有的话),或者CPU。
    # def select_device(device='', batch_size=0, newline=True):
    # -> 根据用户提供的参数选择使用 CPU、单个 GPU 或多个 GPU,并返回一个对应的 PyTorch 设备对象。返回一个 PyTorch 设备对象,用于指定后续计算应该在哪个设备上执行。
    # -> return torch.device(arg)
    device = select_device(device)
    # 初始化 DetectMultiBackend 类的实例,这是一个用于检测的模型,支持多种后端。 weights 参数指定了模型权重的路径, device 是模型运行的设备, dnn 表示是否使用OpenCV DNN进行ONNX模型推理, data 是数据集配置文件的路径, fp16 表示是否使用半精度(FP16)进行推理,这可以加快推理速度,但需要硬件支持。
    # class DetectMultiBackend(nn.Module):
    # -> DetectMultiBackend 类实现了一个多后端检测模型,能够支持多种不同的模型格式和推理引擎。
    # -> def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    # 从模型实例中提取三个属性 : stride ( 模型的步长 ) , names ( 类别名称列表 ) , pt (一个布尔值,表示模型 是否使用PyTorch Tensor )。这些属性将在后续的推理过程中使用。
    stride, names, pt = model.stride, model.names, model.pt
    # 调用 check_img_size 函数来验证或调整输入图像的尺寸。 imgsz 参数是推理时图像的尺寸(高度和宽度), s 是模型的步长。这个函数确保输入图像的尺寸与模型的步长兼容,可能需要对图像尺寸进行调整以满足模型的要求。
    imgsz = check_img_size(imgsz, s=stride)  # check image size
    # 这段代码负责加载模型并确保其在正确的设备上运行。它还提取了模型的关键属性,并验证了输入图像的尺寸是否适合模型。这些步骤是进行有效推理的必要前提,确保了模型可以接收适当尺寸的输入,并在预期的设备上运行。

    # 这段代码负责根据输入源的类型创建相应的数据加载器(dataloader),并为视频推理准备必要的变量。
    # Dataloader
    # 设置了一个变量 bs 代表批处理大小(batch size),这里初始化为1,意味着默认情况下每次处理一张图像。
    bs = 1  # batch_size
    # 一个条件判断,如果 webcam 为 True ,表示输入源是摄像头。
    if webcam:
        # 如果输入源是摄像头,调用 check_imshow 函数来检查是否可以在屏幕上显示图像。 warn=True 参数表示如果无法显示图像,将打印警告信息。
        # def check_imshow(warn=False): -> 检查当前环境是否支持图像显示,特别是在使用 OpenCV 的 cv2.imshow() 函数时。如果以上代码都成功执行,函数返回 True ,表示环境支持图像显示。在异常情况下,函数返回 False ,表示环境不支持图像显示。 -> return True / return False
        view_img = check_imshow(warn=True)
        # 创建了一个 LoadStreams 数据加载器实例,用于从摄像头或视频流中加载图像。 source 是输入源, img_size 是图像尺寸, stride 是模型步长, auto 是一个布尔值表示是否自动调整图像尺寸, vid_stride 是视频帧率步长。
        # class LoadStreams:
        # -> 用于从多个视频流(包括IP摄像头、视频文件或YouTube视频)中加载图像数据。
        # -> def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        # 如果数据加载器 dataset 被成功创建,更新 bs 为数据集中的流数量,这可能大于1,如果视频流中有多个图像源。
        bs = len(dataset)
    # 另一个条件判断,如果 screenshot 为 True ,表示输入源是屏幕截图。
    elif screenshot:
        # 如果输入源是屏幕截图,这行代码创建了一个 LoadScreenshots 数据加载器实例,用于加载屏幕截图图像。
        # class LoadScreenshots:
        # -> 用于从屏幕截图中加载图像数据,通常用于实时数据流或视频流的处理。
        # -> def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    # 如果输入源既不是摄像头也不是屏幕截图,执行这个分支。
    else:
        # 创建了一个 LoadImages 数据加载器实例,用于从文件或目录中加载图像。参数与 LoadStreams 类似。
        # class LoadImages:
        # -> 用于加载图像和视频文件,准备它们以供深度学习模型使用。
        # -> def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    # 初始化两个列表 vid_path 和 vid_writer ,它们分别用于存储 视频保存路径 和 视频写入对象 。列表的长度由 bs 决定,即批处理大小。 None 值表示这些变量在初始化时尚未被分配具体的路径或写入对象。
    vid_path, vid_writer = [None] * bs, [None] * bs
    # 这段代码根据不同的输入源类型(摄像头、屏幕截图或普通图像/视频文件)创建相应的数据加载器,并准备视频推理所需的变量。通过这种方式,函数可以灵活地处理不同类型的输入源,并为后续的推理过程提供必要的数据。

    # 这段代码是 run 函数中用于执行模型推理的部分。
    # Run inference
    # 调用模型的 warmup 方法,用于预热模型。传递给 warmup 的 imgsz 参数是一个元组,表示推理时的图像尺寸。如果模型使用PyTorch Tensor ( pt ) 或是Triton后端 ( model.triton ),则批处理大小为1,否则使用之前确定的 bs 。这个预热步骤有助于初始化模型,使其在实际推理之前达到最佳状态。
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    # 初始化三个变量 : seen 用于记录处理的图像数量, windows 用于存储窗口句柄(可能用于显示图像), dt 是一个包含三个 Profile 实例的元组,用于性能分析,分别对应 预处理 、 推理 和 NMS 的时间。
    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
    # 一个循环,遍历数据加载器 dataset 返回的每个项目。每个项目包含 路径 path 、图像 im 、 原始图像 im0s 、 视频捕获对象 vid_cap 和一个 字符串 s 。
    for path, im, im0s, vid_cap, s in dataset:
        #  这个 with 语句块用于测量 预处理 步骤的时间,它使用 dt 元组中的第一个 Profile 实例。
        with dt[0]:
            # 将 NumPy 数组 im 转换为 PyTorch 张量,并将其移动到模型所在的设备(GPU或CPU)。
            im = torch.from_numpy(im).to(model.device)
            # 如果模型使用半精度(FP16),这行代码将张量 im 转换为半精度浮点数;否则,转换为全精度浮点数。这将图像数据从 uint8 格式转换为浮点数格式。
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            # 将图像数据从 [0, 255] 的范围归一化到 [0.0, 1.0] 的范围。
            im /= 255  # 0 - 255 to 0.0 - 1.0
            # 条件判断检查张量 im 的维度。
            if len(im.shape) == 3:
                # 如果 im 只有三个维度(表示单个图像),在第一个维度添加一个额外的维度,以适应模型的批处理要求。
                im = im[None]  # expand for batch dim
    # 这段代码负责准备模型进行推理,包括预热模型、初始化性能分析工具、遍历数据加载器中的图像,并执行必要的预处理步骤,如数据类型转换、归一化和批处理维度扩展。这些步骤确保了图像数据以正确的格式输入到模型中,以便进行有效的推理。

        # 这段代码继续处理模型推理的过程,并且包含了可视化推理结果的选项。
        # Inference
        # 这个 with 语句块用于测量模型推理步骤的时间,它使用 dt 元组中的第二个 Profile 实例。
        with dt[1]:
            # 处理可视化推理结果的路径。如果 visualize 参数为 True ,则使用 increment_path 函数来创建一个唯一的路径,用于保存可视化结果。 save_dir / Path(path).stem 构建了基本的路径,即在 save_dir 下使用图像路径的基本名称(无扩展名)。 mkdir=True 参数指示 increment_path 函数创建必要的目录。如果 visualize 为 False ,则 visualize 变量设置为 False ,表示不保存可视化结果。
            # def increment_path(path, exist_ok=False, sep='', mkdir=False): -> 为文件或目录生成一个新的路径名,如果原始路径已经存在,则通过在路径后面添加一个数字(默认从2开始递增)来创建一个新的路径。函数返回最终的路径 path 。 -> return path
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            # 执行模型的推理。 model 是之前加载的模型实例, im 是预处理后的图像张量。 augment 参数指示是否进行增强推理,这通常用于提高模型的鲁棒性或准确性。 visualize 参数传递给模型,指示是否需要在推理过程中保存中间层的可视化结果。
            pred = model(im, augment=augment, visualize=visualize)
        # 这段代码负责执行模型的推理,并根据 visualize 参数决定是否保存中间层的可视化结果。通过使用 Profile 实例,它还测量并记录推理步骤所需的时间,这对于性能分析和优化非常有用。推理结果 pred 将包含模型对输入图像的检测结果,这些结果将在后面的处理步骤中进一步处理,如应用NMS和保存结果。

        # 这段代码处理非最大抑制(NMS)和可选的第二阶段分类。
        # NMS
        # 这个 with 语句块用于测量非最大抑制(NMS)步骤的时间,它使用 dt 元组中的第三个 Profile 实例。
        with dt[2]:
            # 对模型的推理结果应用NMS。 non_max_suppression 函数接受以下参数 : pred 模型输出的未经NMS处理的预测结果。 conf_thres 置信度阈值,只有高于此阈值的预测结果会被保留。 iou_thres 交并比(IOU)阈值,用于确定何时抑制重叠的检测框。 classes 可选的类别过滤器,只保留特定类别的检测结果。 agnostic_nms 是否执行类别无关的NMS。 max_det 每张图像最大检测数量的上限。 经过NMS处理后, pred 将只包含最终的检测结果。
            # def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300, nm=0,):
            # -> 用于在目标检测任务中执行非极大值抑制(Non-Maximum Suppression, NMS),以去除多余的边界框,只保留最佳的检测结果。函数返回最终的输出列表 output ,其中包含了批量中每个图像经过NMS处理后的检测结果。
            # -> return output
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Second-stage classifier (optional)
        # 这是一个被注释掉的代码行,表示可选的第二阶段分类步骤。如果取消注释,这行代码将使用一个额外的分类器模型( classifier_model )来进一步分类NMS后的检测结果。
        # utils.general.apply_classifier 函数接受以下参数 : pred NMS后的预测结果。 classifier_model 第二阶段分类器模型。 im 预处理后的图像张量。 im0s 原始图像。 这个步骤可以用于提高特定类别的检测精度,特别是在模型对某些类别的检测不够确定时。
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
        # 这段代码负责对模型的推理结果执行NMS,以去除重叠的检测框,并保留最有可能的检测结果。可选的第二阶段分类可以进一步提高检测的准确性,尤其是在需要对检测结果进行更细致分类的场景中。通过测量NMS步骤的时间,代码还有助于分析整个推理过程的性能。

        # 这段代码处理模型推理后的预测结果,并对每张图像进行后处理。
        # Process predictions
        # 这个 for 循环遍历 pred ,它包含了模型对每张图像的预测结果。 enumerate 函数同时提供索引 i 和值 det 。
        for i, det in enumerate(pred):  # per image
            # 变量 seen 用于计数处理过的图像数量,每处理一张图像, seen 就增加1。
            seen += 1
            # 条件判断检查输入源是否为摄像头。如果是摄像头, batch_size 至少为1,因为摄像头会连续提供图像帧。
            if webcam:  # batch_size >= 1
                # 如果是摄像头输入, path 和 im0s 都是列表,包含了每帧图像的路径和原始图像。这里分别获取第 i 帧的路径 p ,复制第 i 帧的原始图像 im0 ,以及从数据加载器 dataset 获取当前的帧计数 frame 。
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                # 如果是摄像头输入,将索引 i 添加到字符串 s 中,用于后续打印或记录每帧的结果。
                s += f'{i}: '
            # 如果输入源不是摄像头,执行这个分支。
            else:
                # 如果不是摄像头输入,将整个路径 path 和原始图像 im0s 赋值给 p 和 im0 。 getattr(dataset, 'frame', 0) 尝试从数据加载器 dataset 获取当前帧计数 frame ,如果 dataset 没有 frame 属性,则默认为0。
                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
        # 这段代码负责对每张图像的预测结果进行后处理。它区分了摄像头输入和非摄像头输入,分别处理每张图像的路径、原始图像和帧计数。这些信息对于后续的结果展示、保存和分析是必要的。通过累加 seen 变量,代码还跟踪了处理过的图像总数。

            # 这段代码继续处理模型推理后的预测结果,并准备将结果保存到文件中。
            # 将变量 p (图像路径)转换为 Path 对象,这是 pathlib 模块中的一个类,用于处理文件系统路径。
            p = Path(p)  # to Path
            # 构建了保存推理后图像的路径 save_path 。它将 save_dir (保存目录)与 p.name (图像文件名)组合在一起,并转换为字符串。
            save_path = str(save_dir / p.name)  # im.jpg
            # 构建了保存推理结果文本文件的路径 txt_path 。它将 save_dir (保存目录)、 'labels' (标签子目录)与 p.stem (图像文件的基本名称,无扩展名)组合在一起,并根据 dataset.mode 是否为 'image' 来决定是否添加帧编号 frame 。
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
            # 将图像的尺寸(高度和宽度)添加到字符串 s 中,用于打印或记录图像的尺寸信息。
            s += '%gx%g ' % im.shape[2:]  # print string
            # 创建了一个 PyTorch 张量 gn ,它包含了原始图像 im0 的宽度和高度,用于后续的归一化增益计算。
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            # 如果设置了 save_crop (保存裁剪的预测框),复制 im0 到 imc ,否则直接使用 im0 。复制图像是为了在保存裁剪的预测框时保留原始图像。
            imc = im0.copy() if save_crop else im0  # for save_crop
            # 创建了一个 Annotator 实例,用于在图像 im0 上绘制边界框和标签。 line_width 参数设置边界框线的宽度, example 参数设置用于标注的类别名称。
            # class Annotator:
            # -> 用于目标检测任务中绘制边界框和标签。
            # -> def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            # 条件判断检查预测结果 det 是否非空。
            if len(det):
                # Rescale boxes from img_size to im0 size    说明接下来的代码将检测框从推理时的图像尺寸 im.shape[2:] 重新缩放到原始图像尺寸 im0.shape 。
                # 对检测框进行重新缩放。 scale_boxes 函数接受推理图像的尺寸、检测框坐标和原始图像的尺寸,返回重新缩放后的检测框坐标。 .round() 方法将坐标值四舍五入到最接近的整数。
                # def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None): -> 用于目标检测任务中,当图像尺寸改变时,需要对检测到的边界框进行相应的缩放。返回重新缩放后的边界框数组。 -> return boxes
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            # 这段代码负责将模型的预测结果转换为适用于原始图像尺寸的检测框,并准备将结果保存到文件中。它还创建了一个 Annotator 实例,用于在图像上绘制边界框和标签,并将结果路径和尺寸信息添加到字符串 s 中。这些步骤是将推理结果可视化和持久化到文件中的关键环节。

                # 这段代码负责打印每张图像的检测结果,特别是按类别统计检测到的对象数量。
                # Print results
                # 这个 for 循环遍历 det 数组中第6列(索引为5,因为索引从0开始)的唯一值。这一列通常包含了检测到的对象的 类别ID 。
                for c in det[:, 5].unique():
                    # 对于每个唯一的类别ID c ,计算检测到该类别对象的数量。它通过比较 det 数组第6列的每个值是否等于 c 来实现,返回一个布尔数组,然后使用 sum() 函数计算 True 的数量,即检测到的对象数量。
                    n = (det[:, 5] == c).sum()  # detections per class
                    # 将检测结果添加到字符串 s 中。对于每个类别,它将检测到的 对象数量 n 、 类别名称 names[int(c)] (从类别ID转换为类别名称) 以及根据数量确定的复数后缀(如果 n 大于1,则添加's')组合成一个字符串,并追加到 s 。
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
                # 这段代码通过迭代检测结果中的唯一类别ID,计算每个类别的检测数量,并将这些信息格式化为一个字符串。这个字符串 s 可以用来打印或记录每张图像的检测结果,显示每个类别检测到的对象数量和类别名称。这样的信息对于理解模型的推理输出非常有用,尤其是在需要按类别分析检测性能时。

                # 这段代码负责将推理结果写入文件,并在图像上绘制边界框和标签。
                # Write results

                # reversed(seq)
                # reversed() 是 Python 内置的一个函数,用于返回一个反向迭代器。它可以对任何可迭代对象(如列表、元组、字符串等)进行反向迭代。
                # 参数 :
                # seq :这是要反转的可迭代对象。可以是列表、元组、字符串等。
                # 返回值 :
                # 返回一个反向迭代器,可以通过 for 循环或其他迭代方式访问。
                # 注意事项 :
                # reversed() 不会修改原始序列,而是返回一个新的迭代器。
                # 如果传入的对象不支持反向迭代(例如,整数),将会引发 TypeError 。
                # 总结 :
                # reversed() 是一个非常实用的函数,特别是在需要以相反的顺序处理可迭代对象时。它提供了一种简单而有效的方法来反转迭代顺序,而不需要手动创建新的列表或其他数据结构。
                # 示例 :
                # 反转列表
                # list_reversed = reversed([1, 2, 3, 4])
                # print(list(reversed))  # 输出: [4, 3, 2, 1]
                # # 反转字符串
                # str_reversed = reversed("hello")
                # print(''.join(list(str_reversed)))  # 输出: "olleh"
                # # 反转元组
                # tuple_reversed = reversed((1, 2, 3, 4))
                # print(tuple(tuple_reversed))  # 输出: (4, 3, 2, 1)

                #  这个 for 循环遍历 det 数组的逆序,每次迭代获取一个检测结果。 *xyxy 表示边界框的四个坐标值(x1, y1, x2, y2), conf 是置信度, cls 是类别ID。
                for *xyxy, conf, cls in reversed(det):
                    # 如果设置了 save_txt (保存文本结果),执行以下步骤将结果写入文本文件。
                    if save_txt:  # Write to file
                        # 将边界框的坐标从 xyxy 格式转换为 xywh 格式(x中心点,y中心点,宽度,高度),并除以 gn (归一化增益)进行归一化,最后转换为列表。
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        # 构建一个元组 line ,包含 类别ID 、 归一化的 xywh 坐标 和 置信度 (如果设置了 save_conf )。如果不保存置信度,则不包括 conf 。
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        # 打开文本文件路径 txt_path 附加模式( 'a' ),用于写入结果。
                        with open(f'{txt_path}.txt', 'a') as f:
                            # 将格式化后的检测结果写入文件。 '%g ' * len(line) 创建一个格式化字符串, .rstrip() 去除尾部空格, % line 插入数据, \n 添加换行符。
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    # 如果设置了 save_img (保存图像结果)、 save_crop (保存裁剪的边界框)或 view_img (显示图像结果),执行以下步骤在图像上绘制边界框和标签。
                    if save_img or save_crop or view_img:  # Add bbox to image
                        # 将类别ID cls 转换为整数。
                        c = int(cls)  # integer class
                        # 根据是否隐藏标签 hide_labels 和置信度 hide_conf ,设置标签文本。如果不隐藏标签且不隐藏置信度,则显示标签和置信度。
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        # 使用 Annotator 实例在图像上绘制边界框和标签, xyxy 是边界框坐标, label 是标签文本, colors(c, True) 是根据类别ID获取的颜色。
                        # def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)): -> 用于在图像上绘制边界框并添加文本标签。这个函数根据类是否使用 PIL 或 OpenCV 来绘制不同的图形。
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    # 如果设置了 save_crop (保存裁剪的边界框),执行以下步骤。
                    if save_crop:
                        # 调用 save_one_box 函数保存边界框裁剪的图像。 xyxy 是边界框坐标, imc 是用于裁剪的图像, file 是保存路径, BGR=True 表示图像是BGR格式。
                        # def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True): -> 用于保存图像裁剪。返回裁剪后的图像。 -> return crop
                        save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
                    # 这段代码负责将检测结果写入文本文件,并在图像上绘制边界框和标签。它还处理了保存裁剪的边界框图像的逻辑。这些步骤是将推理结果可视化和持久化到文件中的关键环节,对于后续的结果分析和验证非常重要。

            # 这段代码负责将推理结果流式传输到屏幕显示,如果启用了 view_img 选项。
            # Stream results
            # 调用 Annotator 实例的 result 方法,获取绘制了边界框和标签的图像 im0 。
            # def result(self): -> 用于将经过注释的图像转换为 NumPy 数组并返回。将 Annotator 类的 self.im 属性转换为 NumPy 数组。 -> return np.asarray(self.im)
            im0 = annotator.result()
            # 条件判断检查是否设置了 view_img (显示图像结果)。
            if view_img:
                # 在Linux系统上,如果路径 p (图像路径)不在 windows 列表中,执行以下步骤。
                if platform.system() == 'Linux' and p not in windows:
                    # 将路径 p 添加到 windows 列表中,这个列表可能用于跟踪已经创建的窗口。
                    windows.append(p)

                    # cv2.namedWindow(winname, flags=8)
                    # cv2.namedWindow() 函数是 OpenCV 库中用于创建一个窗口的函数。这个函数允许你在屏幕上创建一个窗口,并且可以对这个窗口进行一些配置。
                    # 参数 :
                    # winname :窗口的名称,创建的窗口将以此名称标识。
                    # flags :窗口的属性,用于指定窗口的一些特性。 flags 参数是一个标志的组合,可以是以下值之一或几个值的组合 :
                    # cv2.WINDOW_NORMAL :创建一个可调整大小的窗口。如果没有设置此标志,窗口将是固定的。
                    # cv2.WINDOW_AUTOSIZE :创建一个自动调整大小的窗口,这是默认值。
                    # cv2.WINDOW_OPENGL :创建一个使用OpenGL的窗口。
                    # cv2.WND_PROP_FULLSCREEN :窗口属性标志,用于在窗口中设置全屏属性。
                    # cv2.WND_PROP_AUTOSIZE :窗口属性标志,用于在窗口中设置自动大小属性。
                    # cv2.WND_PROP_ASPECT_RATIO :窗口属性标志,用于在窗口中设置保持宽高比属性。
                    # cv2.WND_PROP_TOPMOST :窗口属性标志,用于在窗口中设置窗口始终处于最顶层的属性。
                    # 注意事项 :
                    # 在调用 cv2.namedWindow() 创建窗口后,通常需要使用 cv2.imshow() 函数将图像显示在窗口中。
                    # 在程序结束时,可以使用 cv2.destroyWindow() 或 cv2.destroyAllWindows() 来关闭窗口。
                    # cv2.namedWindow() 函数需要在调用 cv2.imshow() 之前调用,因为 cv2.imshow() 会使用 cv2.WINDOW_AUTOSIZE 标志创建一个窗口,除非你已经为该窗口名创建了一个窗口。
                    # 通过使用 cv2.namedWindow() ,可以在OpenCV应用程序中创建和管理多个窗口,这对于显示多个图像或结果非常有用。

                    # 使用 cv2.namedWindow 创建一个命名窗口,允许用户调整窗口大小,并保持窗口的宽高比。
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)

                    # cv2.resizeWindow(winname, width, height)
                    # cv2.resizeWindow() 函数是 OpenCV 库中用于调整窗口大小的函数。这个函数允许在程序运行时动态地改变窗口的尺寸。
                    # 参数 :
                    # winname :窗口的名称,要调整大小的窗口以此名称标识。
                    # width :新的窗口宽度,以像素为单位。
                    # height :新的窗口高度,以像素为单位。
                    # 注意事项 :
                    # 确保在调用 cv2.resizeWindow() 之前已经使用 cv2.namedWindow() 创建了窗口,并且指定了窗口名称。
                    # cv2.resizeWindow() 函数只改变窗口的客户区域大小,不会影响窗口的外部框架或标题栏。
                    # 如果窗口是通过 cv2.WINDOW_AUTOSIZE 标志创建的,那么图像会根据窗口大小自动缩放。如果窗口是通过 cv2.WINDOW_NORMAL 标志创建的,那么图像不会自动缩放,需要手动调整图像大小以适应新窗口。
                    # 在某些操作系统和OpenCV版本中,频繁地调整窗口大小可能会导致性能问题或窗口闪烁。
                    # 附加信息 :
                    # 如果需要在窗口创建时就指定初始大小,可以在调用 cv2.namedWindow() 时使用 cv2.WINDOW_NORMAL 或 cv2.WINDOW_AUTOSIZE 标志,并在创建窗口后立即调用 cv2.resizeWindow() 。
                    # 在某些情况下,可能需要在调整窗口大小时保持窗口的宽高比,这时需要根据窗口的新尺寸和图像的宽高比来计算图像的新尺寸,并使用 cv2.resize() 函数调整图像大小。
                    # 通过使用 cv2.resizeWindow() ,可以为用户提供更灵活的界面,允许他们根据需要调整窗口大小,以便更好地查看图像或视频内容。

                    # 使用 cv2.resizeWindow 设置窗口的大小,宽度为 im0.shape[1] ,高度为 im0.shape[0] 。
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                #  使用 cv2.imshow 将绘制了边界框和标签的图像 im0 显示在命名窗口中。
                cv2.imshow(str(p), im0)
                # 使用 cv2.waitKey 等待1毫秒,以便处理任何窗口事件,如键盘输入或窗口关闭事件。这个短暂的等待时间允许图像窗口更新显示,同时不会阻塞程序的执行。
                cv2.waitKey(1)  # 1 millisecond
            # 这段代码负责在启用 view_img 时,将推理结果实时显示在屏幕上。它特别处理了Linux系统上的窗口创建和大小调整,确保图像能够正确显示。通过 cv2.imshow 和 cv2.waitKey 的循环调用,代码实现了图像结果的流式传输和实时显示,这对于实时监控和验证推理结果非常有用。

            # 这段代码负责将包含检测结果的图像或视频帧保存到文件。
            # Save results (image with detections)
            # 条件判断检查是否设置了 save_img (保存图像结果)。
            if save_img:
                # 如果数据加载器 dataset 的模式是 'image' ,表示处理的是单张图像。
                if dataset.mode == 'image':
                    # 使用 cv2.imwrite 函数将绘制了边界框和标签的图像 im0 保存到 save_path 指定的路径。
                    cv2.imwrite(save_path, im0)
                # 如果数据加载器 dataset 的模式不是 'image' ,那么处理的是视频或流媒体。
                else:  # 'video' or 'stream'
                    # 条件判断检查当前索引 i 对应的视频路径 vid_path[i] 是否与新的保存路径 save_path 不同,如果是,则表示需要处理一个新的视频文件。
                    if vid_path[i] != save_path:  # new video
                        # 更新当前索引 i 对应的视频路径为新的保存路径。
                        vid_path[i] = save_path
                        # 条件判断检查 vid_writer[i] 是否是一个 cv2.VideoWriter 实例。
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            # 如果是,调用 release 方法释放之前的 VideoWriter 实例,以便创建新的实例。
                            vid_writer[i].release()  # release previous video writer
                        # 如果 vid_cap (视频捕获对象)存在,表示处理的是视频文件。
                        if vid_cap:  # video
                            # 获取视频的帧率。
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            # 获取视频的宽度。
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            # 获取视频的高度。
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        # 如果不是处理视频文件,而是流媒体。
                        else:  # stream
                            # 设置默认的帧率、宽度和高度。
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        # 强制将结果视频的保存路径 save_path 的后缀设置为 .mp4 。
                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
                        # 创建一个新的 cv2.VideoWriter 实例,用于将视频帧写入文件。
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    # 将绘制了边界框和标签的图像 im0 写入视频文件。
                    vid_writer[i].write(im0)
            # 这段代码负责将包含检测结果的图像或视频帧保存到文件。对于图像,它直接使用 cv2.imwrite 函数保存。对于视频或流媒体,它检查是否需要创建新的 VideoWriter 实例,并使用 cv2.VideoWriter 将帧写入视频文件。这个过程确保了推理结果可以被持久化存储,便于后续的分析和验证。

        # 这段代码负责在推理过程中打印出处理每张图像所需的时间,以及检测结果的概要。
        # Print time (inference-only)
        # 使用 LOGGER 对象的 info 方法来记录信息级别的日志。
        # f 表示这是一个格式化字符串(f-string),用于构建最终的字符串。
        # s 是之前构建的包含检测结果信息的字符串。
        # len(det) 检查 det 数组(包含检测结果)的长度,即检测到的对象数量。
        # 如果 det 数组非空(即检测到了对象),则在字符串 s 后直接添加推理时间;如果 det 数组为空(即没有检测到对象),则在 s 后添加 '(no detections), ' 。
        # dt[1].dt 是 dt 元组中第二个 Profile 实例记录的时间,它对应于 推理步骤 的时间。
        # dt[1].dt * 1E3 将时间从秒转换为毫秒(乘以1000)。
        # .1f 指定浮点数的格式化方式,保留一位小数。
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
        # 这段代码在处理每张图像后打印出检测结果的概要和推理时间,这对于性能分析和监控模型的推理效率非常有用。如果检测到了对象,它会显示检测结果的详细信息和推理时间;如果没有检测到对象,它会显示"(no detections)"和推理时间。这样的日志信息有助于调试和优化模型的推理过程。

    # 这段代码负责在推理过程结束后打印性能统计和结果保存信息。
    # Print results
    # 计算每个阶段( 预处理 、 推理 、 NMS )的速度,单位是毫秒(ms)。 x.t 是 Profile 对象记录的时间, seen 是处理的图像数量。 列表推导式计算每个阶段的时间,并将其转换为毫秒。 tuple 将结果转换为元组。
    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
    # 使用 LOGGER 对象的 info 方法记录每个阶段的速度。 %t 是字符串格式化,用于插入计算出的速度。 (1, 3, *imgsz) 表示推理时的图像形状,其中 1 是批处理大小, 3 是颜色通道数, *imgsz 是图像的尺寸(宽度和高度)。
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)    # 速度:%.1fms 预处理,%.1fms 推理,形状为 {(1, 3, *imgsz)} 的每个图像 %.1fms NMS 。
    # 条件判断检查是否设置了保存文本结果或图像结果。
    if save_txt or save_img:
        # 如果设置了保存文本结果,计算并构建一个字符串 s ,显示 保存的 标签文件数量 和 保存路径 。 save_dir.glob('labels/*.txt') 获取 save_dir/labels 目录下所有的 .txt 文件。 len(list(...)) 计算这些文件的数量。 save_dir / 'labels' 构建保存路径。
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''    # {len(list(save_dir.glob('labels/*.txt')))} 标签保存至 {save_dir / 'labels' 。
        # 使用 LOGGER 对象的 info 方法记录结果保存的路径和额外信息 s 。 colorstr('bold', save_dir) 用来给保存路径 save_dir 添加粗体格式的函数。
        # def colorstr(*input): -> 构建并返回最终的字符串。它首先通过列表推导式和 join 函数将所有颜色和样式的 ANSI 代码连接起来,然后加上要着色的字符串 string ,最后加上 colors['end'] 来重置样式,确保之后的输出不会受到颜色代码的影响。 -> return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")    # 结果保存至 {colorstr('bold', save_dir)}{s} 。
    # 条件判断检查是否设置了更新模型。
    if update:
        # 如果设置了更新模型,调用 strip_optimizer 函数来移除模型权重文件中的优化器状态,这通常用于在模型权重文件被复制或移动后避免 PyTorch 抛出 SourceChangeWarning 警告。
        # def strip_optimizer(f='best.pt', s=''): -> 从训练好的模型文件中移除优化器(optimizer)和其他一些不需要的键,以便在不继续训练的情况下使用模型进行推理。这通常用于模型的部署阶段,因为优化器在推理时是不需要的。
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)
    # 这段代码在推理过程结束后提供了性能反馈和结果保存的确认。它显示了每个阶段的处理速度,并在有结果保存时提供了保存位置和数量的信息。如果设置了模型更新,它还会对模型权重文件进行处理,以避免潜在的警告。这些信息对于评估模型性能和确保结果被正确保存非常重要。
# 这个函数是一个完整的推理流程,包括模型加载、数据加载、推理、NMS、结果处理和保存。它支持多种输入源和输出格式,以及多种推理配置选项。通过这个函数,用户可以灵活地对图像或视频进行推理,并根据需要保存和显示结果。

3.def parse_opt():

python 复制代码
# 这段代码定义了一个名为 parse_opt 的函数,它使用 argparse 库来解析命令行参数。这个函数用于设置和获取运行模型推理所需的配置参数。
def parse_opt():
    # 创建一个新的参数解析器对象。
    parser = argparse.ArgumentParser()
    # parser.add_argument(...) 为解析器添加命令行参数。每个 add_argument 调用都定义了一个参数,包括它的类型、默认值、帮助信息等。
    # --weights :模型权重文件的路径或Triton URL。
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
    # --source :输入源,可以是文件、目录、URL、glob模式、屏幕捕获或摄像头。
    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
    # --data :数据集配置文件的路径(可选)。
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
    # --imgsz :推理时的图像尺寸(高度和宽度)。
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    # --conf-thres :置信度阈值。
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    # --iou-thres :NMS的IOU阈值。
    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
    # --max-det :每张图像的最大检测数量。
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    # --device :指定使用的设备,如GPU编号或CPU。
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    # --view-img :显示结果图像。
    parser.add_argument('--view-img', action='store_true', help='show results')
    # --save-txt :将结果保存到文本文件。
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    # --save-conf :在保存的文本文件中包含置信度。
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    # --save-crop :保存裁剪的预测框。
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    # --nosave :不保存图像或视频。
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    # --classes :按类别过滤检测结果。
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
    # --agnostic-nms :执行类别无关的NMS。
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    # --augment :进行增强推理。
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    # --visualize :可视化特征。
    parser.add_argument('--visualize', action='store_true', help='visualize features')
    # --update :更新所有模型。
    parser.add_argument('--update', action='store_true', help='update all models')
    # --project :保存结果的项目目录。
    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
    # --name :保存结果的名称。
    parser.add_argument('--name', default='exp', help='save results to project/name')
    # --exist-ok :如果项目/名称已存在,则不增加后缀。
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    # --line-thickness :边界框的线宽。
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    # --hide-labels :隐藏标签。
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    # --hide-conf :隐藏置信度。
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    # --half :使用FP16半精度推理。
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    # --dnn :使用OpenCV DNN进行ONNX推理。
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    # --vid-stride :视频帧率步长。
    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
    # 解析命令行参数,并将解析结果存储在 opt 变量中。
    opt = parser.parse_args()
    # 如果 --imgsz 参数只提供了一个值,将其乘以2以扩展为高度和宽度(即 [640] 变为 [640, 640] )。
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    # 打印所有解析后的参数及其值。
    print_args(vars(opt))
    # 返回包含所有参数的 opt 对象。
    return opt
# parse_opt 函数是一个参数解析器,它定义了运行模型推理所需的所有配置参数,并从命令行获取这些参数。这个函数使得用户可以通过命令行灵活地配置推理过程,并且可以轻松地获取和使用这些参数。

4.def main(opt):

python 复制代码
# 这段代码定义了一个名为 main 的函数,它负责执行模型推理流程的主要步骤。
# 定义一个名为 main 的函数,它接受一个参数。
# 1.opt :这个参数是一个包含所有命令行参数的对象。
def main(opt):
    # 调用 check_requirements 函数来检查是否安装了运行模型所需的依赖库。 exclude 参数指定了不需要检查的依赖库列表,这里排除了 tensorboard 和 thop (Tensor Layer的运算复杂度分析工具)。
    check_requirements(exclude=('tensorboard', 'thop'))
    # 调用 run 函数来执行模型推理。 vars(opt) 将 opt 对象转换为字典,这样可以通过关键字参数的方式将所有的命令行参数传递给 run 函数。
    run(**vars(opt))
# main 函数是程序的入口点,它首先检查必要的依赖库是否已经安装(除了指定排除的库),然后调用 run 函数并传递所有解析后的命令行参数。这样的设计使得程序的结构清晰,易于维护,并且可以灵活地处理不同的配置参数。

5.if name == "main":

python 复制代码
# 这段代码是Python脚本中常用的模式,用于判断当前脚本是否作为主程序运行。
# 这是一个特殊的块,当Python文件被直接运行时(而不是被导入到另一个文件中), __name__ 变量的值会被设置为 "__main__" 。
if __name__ == "__main__":
    # 调用 parse_opt 函数来解析命令行参数,并将返回的参数对象赋值给变量 opt 。
    opt = parse_opt()
    # 将 parse_opt 函数返回的参数对象 opt 传递给 main 函数,启动整个推理流程。
    main(opt)
# 这段代码确保了当脚本被直接执行时,会解析命令行参数并调用 main 函数来运行程序。如果脚本是被其他脚本导入的,则不会执行这些代码,这有助于避免不必要的代码执行和潜在的命名空间冲突。这是一种常见的实践,用于在Python中定义可执行的脚本和模块。
相关推荐
没学上了2 分钟前
逻辑回归机器学习
人工智能·深度学习·逻辑回归
!!!5252 分钟前
Sentinel 笔记
笔记·sentinel
CITY_OF_MO_GY20 分钟前
Spark-TTS:基于大模型的文本语音合成工具
人工智能·深度学习·语音识别
阿丢是丢心心26 分钟前
【从0到1搞懂大模型】神经网络的实现:数据策略、模型调优与评估体系(3)
人工智能·深度学习·神经网络
何大春34 分钟前
【对话推荐系统综述】Broadening the View: Demonstration-augmented Prompt Learning for CR
论文阅读·人工智能·深度学习·语言模型·prompt·论文笔记
CoovallyAIHub40 分钟前
一码难求的Manus,又对计算机视觉产生冲击?复刻开源版已在路上!
人工智能·深度学习·计算机视觉
是理不是里_41 分钟前
人工智能里的深度学习指的是什么?
人工智能·深度学习
银河小铁骑plus1 小时前
Go学习笔记:基础语法6
笔记·学习·golang
321Leo1231 小时前
Kaggle 经典比赛 Shopee - Price Match Guarantee(Shopee商品匹配大赛) 高分方案解析
人工智能·深度学习
机器懒得学习2 小时前
基于深度学习的恶意软件检测系统:设计与实现
人工智能·深度学习