predictor.py
ultralytics\engine\predictor.py
目录
[2.class BasePredictor:](#2.class BasePredictor:)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
# 对图像、视频、目录、glob、YouTube、网络摄像头、流等运行预测。
# 使用 - 格式:
# $ yolo mode=predict model=yolov8n.pt # PyTorch
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/LNwODJXcvt4' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
Usage - formats:
$ yolo mode=predict model=yolov8n.pt # PyTorch
yolov8n.torchscript # TorchScript
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolov8n_openvino_model # OpenVINO
yolov8n.engine # TensorRT
yolov8n.mlpackage # CoreML (macOS-only)
yolov8n_saved_model # TensorFlow SavedModel
yolov8n.pb # TensorFlow GraphDef
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
yolov8n_ncnn_model # NCNN
"""
import platform
import re
import threading
from pathlib import Path
import cv2
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
# 警告 ⚠️ 除非传递了 `stream=True`,否则推理结果将累积在 RAM 中,这可能会导致大型源或长时间运行的流和视频出现内存不足错误。请参阅 https://docs.ultralytics.com/modes/predict/ 获取帮助。
# 示例:
# results = model(source=..., stream=True) # Results 对象的生成器
# for r in results:
# boxes = r.boxes # 用于 bbox 输出的 Boxes 对象
# mask = r.masks # 用于段掩码输出的 Masks 对象
# probs = r.probs # 用于分类输出的类概率
STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
Example:
results = model(source=..., stream=True) # generator of Results objects
for r in results:
boxes = r.boxes # Boxes object for bbox outputs
masks = r.masks # Masks object for segment masks outputs
probs = r.probs # Class probabilities for classification outputs
"""
2.class BasePredictor:
python
# 这段代码定义了一个名为 BasePredictor 的类,用于实现图像或视频的推理(inference)功能。它提供了一个完整的框架,用于加载模型、预处理输入数据、执行推理、后处理结果,并将结果保存或显示出来。
# 定义了一个名为 BasePredictor 的类,它是推理功能的核心类。
class BasePredictor:
# 用于创建预测器的基类。
"""
BasePredictor.
A base class for creating predictors.
Attributes:
args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results.
done_warmup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction.
data (dict): Data configuration.
device (torch.device): Device used for prediction.
dataset (Dataset): Dataset used for prediction.
vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
"""
# 这段代码定义了 BasePredictor 类的初始化方法 __init__ ,用于设置类的基本属性和初始化状态。
# 定义了 BasePredictor 类的初始化方法。它接受三个参数。
# 1.cfg :配置文件路径,默认值为 DEFAULT_CFG 。
# 2.overrides :配置覆盖参数,用于覆盖默认配置。
# 3._callbacks :回调函数集合,用于扩展功能。
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
# 初始化 BasePredictor 类。
"""
Initializes the BasePredictor class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
# 通过 get_cfg 函数加载配置文件,并应用 overrides 中的覆盖参数。结果存储在 self.args 中, self.args 是一个 包含所有配置参数的对象 。
# def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
# -> 用于处理和返回一个配置对象。它支持从多种输入类型(如字符串、路径、字典或 SimpleNamespace )加载配置,并允许用户通过覆盖( overrides )来修改默认配置。将最终的配置字典转换为 IterableSimpleNamespace 对象并返回。
# -> return IterableSimpleNamespace(**cfg)
self.args = get_cfg(cfg, overrides)
# 调用 get_save_dir 函数,根据配置 生成保存结果的目录路径 ,并存储在 self.save_dir 中。
# def get_save_dir(args, name=None): -> 根据训练、验证或预测的参数返回一个保存目录( save_dir )。这个函数处理了多种情况,确保保存目录的路径是唯一且有效的。返回生成的保存目录路径,确保返回值是一个 Path 对象。 -> return Path(save_dir)
self.save_dir = get_save_dir(self.args)
# 如果配置中未指定置信度阈值( conf )。
if self.args.conf is None:
# 则默认设置为 0.25。置信度阈值用于过滤低置信度的预测结果。
self.args.conf = 0.25 # default conf=0.25
# 初始化一个标志变量 done_warmup ,用于标记模型是否已完成预热(warmup)。预热是指在正式推理之前对模型进行一次空运行,以确保模型加载到 GPU 并准备好。
self.done_warmup = False
# 如果配置中启用了显示结果的功能( self.args.show )。
if self.args.show:
# 则调用 check_imshow 函数检查是否支持显示(例如是否安装了必要的库)。如果环境不支持显示,则发出警告。
# def check_imshow(warn=False): -> 用于检查当前环境是否支持使用 cv2.imshow() 显示图像。如果上述操作成功执行,则返回 True ,表示当前环境支持 cv2.imshow() 。返回 False ,表示当前环境不支持 cv2.imshow() 。 -> return True / return False
self.args.show = check_imshow(warn=True)
# Usable if setup is done
# 初始化 模型对象 为 None ,模型将在后续通过 setup_model 方法加载。
self.model = None
# 将配置中的 数据字典 ( data )存储在 self.data 中。这个字典通常包含数据集的路径、类别信息等。
self.data = self.args.data # data_dict
# 初始化类的各种属性。
# 图像大小。
self.imgsz = None
# 推理设备(如 GPU 或 CPU)。
self.device = None
# 数据集对象。
self.dataset = None
# 一个字典,用于存储视频写入器( cv2.VideoWriter )。
self.vid_writer = {} # dict of {save_path: video_writer, ...}
# 绘制了预测结果的图像。
self.plotted_img = None
# 数据源类型(如图像、视频或流)。
self.source_type = None
# 已处理的图像数量。
self.seen = 0
# 用于显示图像的窗口列表。
self.windows = []
# 当前批次的数据。
self.batch = None
# 推理结果。
self.results = None
# 数据预处理转换。
self.transforms = None
# 初始化回调函数集合。如果传入了 _callbacks ,则使用传入的回调;否则,调用 callbacks.get_default_callbacks() 获取默认回调函数。
self.callbacks = _callbacks or callbacks.get_default_callbacks()
# 初始化保存预测结果的文本路径为 None 。
self.txt_path = None
# lock = threading.Lock()
# threading.Lock() 是 Python 标准库 threading 模块中提供的一个同步原语,用于在多线程环境中实现线程安全的互斥锁(Mutex)。它确保在任何时刻只有一个线程可以访问共享资源,从而避免多线程并发访问导致的竞争条件(race condition)。
# 通过调用 threading.Lock() 创建一个锁对象。锁对象初始状态为 未锁定(unlocked)。
# 主要方法互斥锁对象提供了以下主要方法,用于控制锁的状态 :
# acquire(blocking=True, timeout=-1) :尝试获取锁。如果锁已被其他线程占用,则当前线程会阻塞,直到锁被释放。 如果成功获取锁,返回 True 。 如果未获取锁(如超时或非阻塞模式下锁已被占用),返回 False 。
# release() :释放锁。如果当前线程未持有锁,则会引发 RuntimeError 。 将锁的状态从 锁定(locked)变为 未锁定(unlocked)。 如果有其他线程正在等待锁,其中一个线程会获取锁并继续执行。
# locked() :检查锁是否已被占用。 如果锁已被占用,返回 True 。 如果锁未被占用,返回 False 。
# threading.Lock() 是一个简单的互斥锁机制,用于保护共享资源,确保多线程环境下的线程安全。它提供了 acquire 、 release 和 locked 方法,分别用于获取锁、释放锁和检查锁状态。通过使用锁,可以有效避免多线程并发访问导致的竞争条件和数据不一致问题。
# 初始化一个线程锁( threading.Lock ),用于确保推理过程的线程安全性,特别是在多线程环境中。
self._lock = threading.Lock() # for automatic thread-safe inference
# 调用 callbacks.add_integration_callbacks 方法,将一些集成回调函数添加到当前实例中。
callbacks.add_integration_callbacks(self)
# 这段代码的核心功能是初始化一个用于推理的预测器实例。它加载配置文件,设置保存路径,检查显示功能是否可用,并初始化模型、数据集、设备等关键属性。此外,它还设置了线程锁以确保推理过程的线程安全性,并注册了默认的回调函数。这些初始化操作为后续的推理流程(如预处理、模型加载、推理执行、结果保存等)奠定了基础。
# 这段代码定义了 BasePredictor 类中的 preprocess 方法,用于对输入图像进行预处理,以便将其转换为适合模型推理的格式。
# 定义了 preprocess 方法,它接受一个参数。
# 1.im :即待处理的图像数据。
def preprocess(self, im):
# 在推理之前准备输入图像。
"""
Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
"""
# 检查输入 im 是否为 PyTorch 张量。如果不是张量,则 not_tensor 为 True ,表示需要对输入进行格式转换。
not_tensor = not isinstance(im, torch.Tensor)
# 如果输入不是张量。
if not_tensor:
# 调用 self.pre_transform(im) 方法对输入图像进行预处理(例如调整大小、填充等)。使用 np.stack 将处理后的图像列表堆叠为一个 NumPy 数组。
im = np.stack(self.pre_transform(im))
# 对 NumPy 数组进行以下转换。
# m[..., ::-1] :将图像从 BGR 格式转换为 RGB 格式(假设输入是 BGR 格式)。
# .transpose((0, 3, 1, 2)) :将图像从 BHWC(批量、高度、宽度、通道)格式转换为 BCHW(批量、通道、高度、宽度)格式,以适配 PyTorch 的输入格式。
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
# 确保 NumPy 数组是连续的(contiguous)。这一步是为了优化内存布局,避免在后续转换为 PyTorch 张量时出现性能问题。
im = np.ascontiguousarray(im) # contiguous
# 将 NumPy 数组转换为 PyTorch 张量。
im = torch.from_numpy(im)
# 将张量移动到指定的设备(如 GPU 或 CPU),设备由 self.device 指定。
im = im.to(self.device)
# 根据模型是否支持半精度(FP16)推理,将张量的数据类型从 uint8 转换为 float16 或 float32 。如果模型支持 FP16,则调用 im.half() 。 否则,调用 im.float() 。
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
# 如果输入不是张量(即输入是原始图像数据),将像素值从 [0, 255] 范围归一化到 [0.0, 1.0]。
if not_tensor:
im /= 255 # 0 - 255 to 0.0 - 1.0
# 返回预处理后的图像张量。
return im
# 这段代码的核心功能是对输入图像进行预处理,使其符合模型推理所需的格式。格式转换:将输入从 NumPy 数组转换为 PyTorch 张量。颜色空间转换:将图像从 BGR 格式转换为 RGB 格式。内存布局调整:将图像从 BHWC 格式转换为 BCHW 格式。数据类型转换:根据模型是否支持 FP16 推理,将数据类型转换为 float16 或 float32 。归一化:将像素值从 [0, 255] 范围归一化到 [0.0, 1.0]。这些步骤确保了输入图像能够被模型正确处理,同时优化了推理性能。
# 这段代码定义了 BasePredictor 类中的 inference 方法,用于执行模型的推理过程。
# 定义了 inference 方法,输入参数包括 :
# 1.im :预处理后的输入图像张量。
# 2.*args 和 3.**kwargs :可选的额外参数,用于传递给模型的推理方法。
def inference(self, im, *args, **kwargs):
# 使用指定的模型和参数对给定的图像运行推理。
"""Runs inference on a given image using the specified model and arguments."""
# 这一行代码的作用是根据配置决定是否启用可视化,并设置可视化路径。如果满足条件, visualize 被设置为 生成的路径 。 否则, visualize 被设置为 False ,表示不启用可视化。
visualize = (
# 路径生成。
# 如果满足上述条件,则调用 increment_path 函数生成可视化路径。路径基于以下内容构建 :
# self.save_dir :保存结果的根目录。
# Path(self.batch[0][0]).stem :当前批次中第一个图像的文件名(不包含扩展名)。
# mkdir=True :确保目标目录存在,如果不存在则创建。
# def increment_path(path, exist_ok=False, sep="", mkdir=False): -> 用于处理路径的增量创建,即当路径已存在时,自动创建一个增量的路径版本。返回处理后的路径 path 。 -> return path
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
# 条件判断。
# self.args.visualize :检查是否启用了可视化功能。
# not self.source_type.tensor :确保数据源不是张量(即不是直接从内存传递的张量数据)。
if self.args.visualize and (not self.source_type.tensor)
else False
)
# 调用模型的推理方法,传入以下参数。
# im :预处理后的输入图像张量。
# augment=self.args.augment :是否启用数据增强(如多尺度推理)。
# visualize=visualize :是否启用可视化,以及可视化的路径(如果 visualize 为 False ,则不启用可视化)。
# embed=self.args.embed :是否启用嵌入(embedding)功能。
# *args 和 **kwargs :传递给模型的其他参数。
# 最后,返回模型的推理结果。
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
# 这段代码的核心功能是执行模型的推理过程,并根据配置动态调整推理行为。可视化路径生成:根据配置和数据源类型决定是否启用可视化。如果启用可视化,则生成保存可视化结果的路径。模型推理调用:将输入图像和相关配置参数传递给模型的推理方法。支持数据增强、可视化和嵌入等可选功能。返回模型的推理结果。通过这种方式, inference 方法能够灵活地适应不同的推理场景和配置需求,同时确保推理过程的高效性和可扩展性。
# 在深度学习和计算机视觉中,嵌入(embedding)功能通常指的是将输入数据(如图像、文本或其他结构化数据)映射到一个低维向量空间的过程。这个低维向量空间通常被称为嵌入空间(embedding space)。嵌入功能的核心目的是将复杂的输入数据转换为一个紧凑的、具有语义意义的向量表示,以便后续任务(如分类、聚类、检索等)能够更高效地进行。
# 嵌入功能的作用 :
# 嵌入功能在不同的应用场景中有不同的具体含义,但其核心目标是提取数据的特征表示,同时保留数据的重要语义信息。以下是嵌入功能在一些常见任务中的具体应用 :
# 图像检索(Image Retrieval) :
# 在图像检索任务中,嵌入功能用于将图像映射到一个低维特征空间。通过计算嵌入向量之间的相似度(如欧氏距离、余弦相似度等),可以快速找到与查询图像相似的图像。例如, CLIP 模型可以将图像和文本映射到同一个嵌入空间,从而实现图像和文本的跨模态检索。
# 分类任务(Classification) :
# 在分类任务中,嵌入功能可以用于提取图像或数据的特征表示,然后将这些特征输入到分类器(如全连接层、支持向量机等)中进行分类。嵌入向量通常能够更好地捕捉数据的内在结构,从而提高分类性能。
# 聚类(Clustering) :
# 嵌入功能可以将数据映射到一个低维空间,在这个空间中,相似的数据点会更接近。通过聚类算法(如 K-Means、DBSCAN 等),可以对嵌入向量进行聚类,从而发现数据的内在结构或类别。
# 无监督学习(Unsupervised Learning) :
# 在无监督学习中,嵌入功能用于将数据映射到一个低维空间,以便后续任务(如降维、可视化等)能够更高效地进行。例如, t-SNE 和 UMAP 是两种流行的嵌入方法,用于将高维数据嵌入到二维或三维空间中,以便可视化。
# 特征提取(Feature Extraction) :
# 嵌入功能可以看作是一种特征提取方法,它将原始数据转换为更紧凑、更有信息量的特征表示。这些特征可以用于后续的各种任务,如分类、回归、聚类等。
# 嵌入功能在代码中的体现 :
# 在提到的代码中, embed=self.args.embed 是传递给模型的一个参数,用于控制是否启用嵌入功能。具体来说 :
# 如果 self.args.embed 为 True ,模型可能会输出一个嵌入向量(而不是直接的分类结果或其他输出)。
# 这个嵌入向量可以用于后续的相似性计算、聚类或其他任务。
# 例如,如果模型是一个用于图像检索的网络,启用嵌入功能后,模型会输出每个图像的嵌入向量,而不是直接的分类标签。
# 总结 :嵌入功能是一种将输入数据映射到低维向量空间的技术,其目的是提取数据的特征表示,同时保留重要的语义信息。嵌入功能在图像检索、分类、聚类、无监督学习等任务中都有广泛的应用。在代码中,嵌入功能通常通过模型的特定参数(如 embed )启用,并输出嵌入向量供后续任务使用。
# 这段代码定义了 BasePredictor 类中的 pre_transform 方法,用于对输入图像进行预处理,确保它们符合模型推理所需的格式。预处理的核心是使用 LetterBox 类对图像进行调整,以满足模型对输入尺寸的要求。
# 定义了 pre_transform 方法,它接受一个参数。
# 1.im :它是一个包含多个图像的列表(每个图像都是一个 NumPy 数组)。
def pre_transform(self, im):
# 在推理之前对输入图像进行预转换。
# 参数:
# im (List(np.ndarray)):张量为 (N, 3, h, w),列表为 [(h, w, 3) x N]。
"""
Pre-transform input image before inference.
Args:
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
Returns:
(list): A list of transformed images.
"""
# 检查输入图像是否具有相同的形状。使用集合推导式 {x.shape for x in im} 获取所有图像的形状。 如果所有图像的形状相同,则集合的长度为 1, same_shapes 为 True ;否则为 False 。
same_shapes = len({x.shape for x in im}) == 1
# 创建一个 LetterBox 对象,用于对图像进行调整。
# self.imgsz :目标图像大小(通常是模型的输入尺寸)。
# auto :是否自动调整图像大小。如果所有图像形状相同且模型支持 PyTorch( self.model.pt 为 True ),则设置为 True 。
# stride :模型的步幅(stride),用于确保图像尺寸与模型的步幅对齐。
# class LetterBox:
# -> 用于将图像调整为指定的尺寸,同时保持图像的纵横比。这种变换常用于数据预处理,特别是在图像分类和目标检测任务中。
# -> def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
# 对输入图像列表中的每个图像应用 LetterBox 转换. letterbox(image=x) 将每个图像调整为目标尺寸,同时保持宽高比。 返回一个包含转换后图像的列表。
return [letterbox(image=x) for x in im]
# 这段代码的核心功能是对输入图像进行预处理,确保它们符合模型推理所需的格式。检查图像形状是否一致:通过集合推导式判断所有输入图像是否具有相同的形状。创建 LetterBox 对象:根据目标尺寸、是否自动调整和模型步幅初始化 LetterBox 。应用 LetterBox 转换:对每个输入图像进行调整,保持宽高比并添加填充,最终返回调整后的图像列表。这种预处理方式特别适用于计算机视觉任务中,尤其是当模型对输入尺寸有严格要求时(如 YOLO、YOLOv5 等目标检测模型)。
# 这段代码定义了 BasePredictor 类中的 postprocess 方法,用于对模型的推理结果进行后处理。然而,当前实现非常简单,只是直接返回了模型的预测结果 preds ,而没有进行任何实际的后处理操作。
# 定义了 postprocess 方法,它接收以下参数 :
# 1.preds :模型的预测结果。
# 2.img :经过预处理的输入图像张量(通常用于推理)。
# 3.orig_imgs :原始输入图像(未经过预处理的图像列表)。
def postprocess(self, preds, img, orig_imgs):
# 对图像进行后处理预测并返回它们。
"""Post-processes predictions for an image and returns them."""
# 直接返回模型的预测结果 preds ,没有任何后处理操作。
return preds
# 后处理的潜在用途 :
# 尽管当前实现没有进行任何操作,但后处理通常是一个非常重要的步骤,尤其是在实际应用中。后处理可以包括以下内容 :
# 解码预测结果 :将模型输出的原始预测(如边界框坐标、类别置信度等)转换为更易理解的格式。 例如,在目标检测任务中,模型可能输出的是归一化的边界框坐标,后处理可以将其转换为原始图像坐标。
# 应用置信度阈值 :过滤掉置信度低于某个阈值的预测结果。例如,只保留置信度大于 0.5 的预测。
# 非极大值抑制(NMS) :在目标检测任务中,模型可能会输出多个重叠的边界框。通过非极大值抑制,可以去除冗余的边界框,只保留最可靠的预测。
# 调整预测结果 :根据原始图像的尺寸和预处理步骤(如填充、缩放等),调整预测结果的坐标,使其与原始图像对齐。
# 格式化输出 :将预测结果格式化为特定的格式,以便后续任务(如可视化、保存结果等)能够更方便地使用。
# 当前的 postprocess 方法只是一个占位符实现,直接返回了模型的预测结果 preds ,而没有进行任何后处理操作。在实际应用中,后处理是一个非常重要的步骤,通常需要根据具体任务和模型的输出格式进行定制化实现。例如,在目标检测任务中,后处理可能包括解码边界框、应用置信度阈值和非极大值抑制等操作。
# 这段代码定义了 BasePredictor 类的 __call__ 方法,它允许用户通过直接调用类实例的方式来启动推理流程。这个方法的核心功能是根据是否启用流式推理( stream 参数)来决定推理的执行方式。
# 定义了 __call__ 方法,允许类实例像函数一样被调用。它接收以下参数 :
# 1.source :输入源,可以是图像路径、视频路径或摄像头流。
# 2.model :模型路径或模型对象,用于推理。
# 3.stream :布尔值,表示是否启用流式推理模式。
# 4.*args 和 5.**kwargs :额外的参数,用于传递给推理方法。
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
# 对图像或流执行推理。
"""Performs inference on an image or stream."""
# 将传入的 stream 参数值赋给实例变量 self.stream ,用于 记录当前是否为流式推理模式 。
self.stream = stream
# 如果启用了流式推理模式( stream=True )。
if stream:
# 则直接调用 self.stream_inference 方法,并将输入源、模型和其他参数传递给该方法。 stream_inference 方法通常用于实时处理视频流或摄像头数据,并逐帧返回推理结果。
return self.stream_inference(source, model, *args, **kwargs)
# 如果未启用流式推理模式( stream=False )。
else:
# 则调用 self.stream_inference 方法,但将其返回的结果(通常是生成器)转换为一个列表。这样可以将所有推理结果合并为一个列表并一次性返回,而不是逐帧返回。
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
# 这段代码的核心功能是提供一个统一的接口,用于启动推理流程。它根据是否启用流式推理模式来决定推理的执行方式:流式推理模式( stream=True ):逐帧处理输入源(如视频流或摄像头),并逐帧返回推理结果。适用于实时处理场景,例如视频监控或实时目标检测。非流式推理模式( stream=False ):将所有推理结果合并为一个列表并一次性返回。适用于处理静态图像或需要一次性获取所有结果的场景。通过这种方式, __call__ 方法提供了一个灵活的接口,能够适应不同的推理需求。
# 这段代码定义了 BasePredictor 类中的 predict_cli 方法,用于在命令行界面(CLI)模式下执行推理任务。它的主要功能是运行推理流程,但不保存或积累任何输出结果。这通常用于场景,比如只需要执行推理而不关心具体输出内容,或者输出内容已经在其他地方被处理。
# 定义了 predict_cli 方法,接收以下参数 :
# 1.source :输入源,可以是图像路径、视频路径或摄像头流。
# 2.model :模型路径或模型对象,用于推理。
def predict_cli(self, source=None, model=None):
# 用于 CLI 预测的方法。
# 它始终使用生成器作为输出,因为 CLI 模式不需要。
"""
Method used for CLI prediction.
It uses always generator as outputs as not required by CLI mode.
"""
# 调用 self.stream_inference 方法,传入 source 和 model 参数,返回一个生成器 gen 。这个 生成器逐帧生成推理结果 。
gen = self.stream_inference(source, model)
# 通过一个简单的 for 循环遍历生成器 gen ,逐帧执行推理任务。由于循环体中没有任何操作( pass ),这意味着推理结果不会被保存或进一步处理。 noqa 注释:这是一个常见的注释,用于告诉代码检查工具(如 flake8 )忽略当前行的警告或错误。在这里,它可能是为了避免"未使用的变量"等警告。
for _ in gen: # noqa, running CLI inference without accumulating any outputs (do not modify)
pass
# 作用和意义 :
# predict_cli 方法的主要目的是在命令行模式下运行推理任务,但不保存任何输出结果。这种设计可能适用于以下场景 :
# 性能测试 :仅测试模型的推理速度,而不关心具体的推理结果。
# 调试 :在开发过程中快速验证模型是否能够正常运行,而不必处理输出结果。
# 实时处理 :在某些实时应用场景中,推理结果可能已经在推理过程中被实时处理(如显示在屏幕上),因此不需要额外保存。
# predict_cli 方法是一个轻量级的推理接口,专门用于命令行模式下的推理任务。它的核心功能是通过 stream_inference 方法逐帧执行推理,但不保存或处理任何输出结果。这种设计使得该方法非常适合用于快速验证模型、性能测试或实时处理场景。
# 这段代码定义了 BasePredictor 类中的 setup_source 方法,用于设置输入源(如图像、视频或流)并初始化与之相关的属性。这个方法的核心功能是根据输入源的类型和配置参数,加载数据集并准备推理所需的环境。
# 定义了 setup_source 方法,它接受一个参数。
# 1.source :表示输入源的路径或来源(如图像路径、视频路径、摄像头流等)。
def setup_source(self, source):
# 设置源和推理模式。
"""Sets up source and inference mode."""
# 调用 check_imgsz 函数,检查并设置模型输入图像的尺寸。
# self.args.imgsz :从配置中获取目标图像尺寸。
# stride=self.model.stride :模型的步幅(stride),用于确保图像尺寸与模型的步幅对齐。
# min_dim=2 :设置图像的最小维度。
# 返回值存储在 self.imgsz 中,表示 最终确定的图像尺寸 。
# def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): -> 用于检查和调整图像尺寸( imgsz ),以确保其符合特定的约束条件。返回调整后的图像尺寸 sz 。 -> return sz
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
# 根据任务类型设置 数据转换 ( transforms )。
self.transforms = (
# 首先尝试从模型对象中获取 transforms 属性( getattr(self.model.model, "transforms", ...) )。如果模型对象中没有定义 transforms ,则使用默认的分类转换函数 classify_transforms ,传入图像尺寸和裁剪比例( crop_fraction )。
getattr(
self.model.model,
"transforms",
# def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, crop_fraction: float = DEFAULT_CROP_FTACTION,):
# -> 用于生成图像分类任务的数据预处理流程。使用 T.Compose 将所有预处理步骤组合成一个完整的预处理流程,并返回该流程。
# -> return T.Compose(tfl)
classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
)
# 如果任务是分类( self.args.task == "classify" )
if self.args.task == "classify"
# 如果任务不是分类,则将 self.transforms 设置为 None 。
else None
)
# 调用 load_inference_source 函数加载输入源。返回值是一个 数据集对象 ,存储在 self.dataset 中。
# def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): -> 根据给定的输入源 source 加载相应的数据集,并返回一个数据加载器。返回构建好的数据集对象 dataset 。 -> return dataset
self.dataset = load_inference_source(
# 输入源路径或来源。
source=source,
# 批量大小。
batch=self.args.batch,
# 视频帧的步幅(用于视频或流)。
vid_stride=self.args.vid_stride,
# 流式数据的缓冲区大小。
buffer=self.args.stream_buffer,
)
# 从数据集对象中获取输入源的类型(如图像、视频、流等),并存储在 self.source_type 中。
self.source_type = self.dataset.source_type
# 检查是否需要发出警告。
# 如果当前实例没有启用流式推理模式( self.stream 为 False ),并且满足以下任一条件。
if not getattr(self, "stream", True) and (
# 输入源是流式数据( self.source_type.stream )。
self.source_type.stream
# 输入源是截图( self.source_type.screenshot )。
or self.source_type.screenshot
# 数据集包含大量图像( len(self.dataset) > 1000 )。
or len(self.dataset) > 1000 # many images
# 数据集中包含视频( any(getattr(self.dataset, "video_flag", [False])) )。
or any(getattr(self.dataset, "video_flag", [False]))
): # videos
# 如果满足上述条件,则通过日志记录器( LOGGER )发出警告,内容为 STREAM_WARNING 。
LOGGER.warning(STREAM_WARNING)
# 初始化一个空字典 self.vid_writer ,用于 存储视频写入器 ( cv2.VideoWriter )对象。这通常用于保存推理结果为视频文件。
self.vid_writer = {}
# setup_source 方法的核心功能是根据输入源的类型和配置参数,加载数据集并初始化推理所需的环境。主要步骤包括:检查图像尺寸:通过 check_imgsz 确保图像尺寸符合模型要求。设置数据转换:根据任务类型(如分类)加载或定义数据转换。加载输入源:通过 load_inference_source 加载数据集,并获取输入源的类型。检查并发出警告:如果输入源不适合当前的推理模式(如非流式推理但输入源是流或视频),则发出警告。初始化视频写入器:准备用于保存推理结果的视频写入器。这个方法为后续的推理流程(如预处理、推理、后处理和结果保存)奠定了基础,确保输入数据能够被正确处理并适配模型的要求。
# 这段代码定义了 BasePredictor 类中的 stream_inference 方法,用于执行流式推理(streaming inference)。它支持实时处理输入源(如视频流、摄像头或图像序列),并逐帧生成推理结果。这个方法还集成了多种功能,包括模型预热、性能分析、结果保存和回调机制。
# @smart_inference_mode() 是一个装饰器,用于优化推理模式(如自动选择推理后端)。
# def smart_inference_mode():
# -> 根据PyTorch的版本号来选择性地应用 torch.inference_mode() 装饰器或 torch.no_grad() 装饰器,以优化模型的推理模式。
# -> 如果上述条件判断为真,说明当前已经处于推理模式,无需再应用装饰器,因此直接返回原函数 fn ,相当于一个通过操作。根据 TORCH_1_9 变量的值来选择装饰器。如果 TORCH_1_9 为真,则选择 torch.inference_mode 装饰器;否则选择 torch.no_grad 装饰器。然后立即调用所选装饰器,并将 fn 作为参数传递给装饰器,最终返回装饰后的函数。
# -> 返回内部定义的 decorate 函数。这样,当 smart_inference_mode 被用作装饰器时,实际上返回的是 decorate 函数, decorate 函数会根据PyTorch的版本号来决定应用哪个装饰器。
# -> return decorate
@smart_inference_mode()
# 定义了 stream_inference 方法,接收以下参数 :
# 1.source :输入源路径或来源(如图像路径、视频路径、摄像头流等)。
# 2.model :模型路径或模型对象。如果为 None ,则使用默认模型。
# 3.*args 和 4.**kwargs :额外的参数,用于传递给推理方法。
def stream_inference(self, source=None, model=None, *args, **kwargs):
# 在摄像头馈送上进行实时推理并将结果保存到文件中。
"""Streams real-time inference on camera feed and saves results to file."""
# 如果启用了详细模式( self.args.verbose ),则在日志中输出一个空行,用于分隔日志信息。
if self.args.verbose:
LOGGER.info("")
# Setup model
# 如果当前实例尚未加载模型,则调用 self.setup_model 方法加载模型。这一步确保模型准备好进行推理。
if not self.model:
self.setup_model(model)
# 这段代码是 stream_inference 方法的核心部分,主要用于设置推理环境、初始化必要的资源,并开始逐批次处理数据集。
# 使用线程锁 self._lock 确保推理过程的线程安全性。这在多线程环境中非常关键,避免多个线程同时访问模型导致冲突。
with self._lock: # for thread-safe inference
# Setup source every time predict is called
# 调用 self.setup_source 方法设置输入源。如果 source 参数不为 None ,则使用传入的 source 。 否则,使用默认输入源( self.args.source )。 确保每次调用推理时,输入源都被正确设置。
self.setup_source(source if source is not None else self.args.source)
# Check if save_dir/ label file exists
# 检查并创建保存目录。
# 如果启用了保存图像( self.args.save )或保存文本文件( self.args.save_txt )。
if self.args.save or self.args.save_txt:
# 如果保存文本文件,则创建 self.save_dir / "labels" 目录。 否则,创建 self.save_dir 目录。 parents=True :如果需要,创建所有父目录。 exist_ok=True :如果目录已存在,不会抛出错误。
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# Warmup model
# 如果模型尚未预热( self.done_warmup 为 False )。
if not self.done_warmup:
# 则调用 self.model.warmup 方法进行预热。预热时使用的输入尺寸由 self.imgsz 决定。 如果模型是 PyTorch 模型或使用 Triton 后端,则批量大小为 1;否则,使用数据集的批量大小( self.dataset.bs )。
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
# 预热完成后,将 self.done_warmup 设置为 True 。
self.done_warmup = True
# 初始化一些变量。
# self.seen :记录 已处理的图像数量 ,初始值为 0 。
# self.windows :记录 显示窗口的列表 ,初始为空列表。
# self.batch : 当前批次 的数据,初始值为 None 。
self.seen, self.windows, self.batch = 0, [], None
# 创建三个性能分析器( ops.Profile ),用于记录 预处理 、 推理 和 后处理 的耗时。每个分析器绑定到当前设备( self.device )。
profilers = (
# class Profile(contextlib.ContextDecorator):
# -> 用于测量代码块的执行时间,支持CPU和CUDA设备。
# -> def __init__(self, t=0.0, device: torch.device = None):
ops.Profile(device=self.device),
ops.Profile(device=self.device),
ops.Profile(device=self.device),
)
# 调用 self.run_callbacks 方法,触发 on_predict_start 事件的回调函数。这允许用户在推理开始时插入自定义逻辑。
self.run_callbacks("on_predict_start")
# 遍历数据集中的 每个批次 ( self.dataset 是一个可迭代对象,通常是一个生成器)。
for self.batch in self.dataset:
# 在每个批次开始时,触发 on_predict_batch_start 事件的回调函数。
self.run_callbacks("on_predict_batch_start")
# 然后解包当前批次的数据。 paths :图像路径列表。 im0s :原始图像列表。 s :附加信息(如帧信息)。
paths, im0s, s = self.batch
# 这段代码的核心功能是初始化推理环境,并开始逐批次处理数据集。设置输入源:通过 self.setup_source 方法设置输入源。检查并创建保存目录:根据配置选项创建保存图像或文本文件的目录。模型预热:如果模型尚未预热,则调用 self.model.warmup 方法进行预热。初始化变量:初始化已处理图像数量、显示窗口列表和当前批次数据。创建性能分析器:创建三个性能分析器,分别用于记录预处理、推理和后处理的耗时。触发回调函数:在推理开始时和每个批次开始时,触发相应的回调函数。逐批次处理数据集:遍历数据集中的每个批次,并解包批次数据。这些步骤为后续的预处理、推理和后处理奠定了基础,确保推理流程能够高效、安全地运行。
# 这段代码是 stream_inference 方法的核心部分,负责执行推理流程中的预处理、推理和后处理步骤。它还通过性能分析器记录每个阶段的耗时,并根据配置决定是否直接返回嵌入向量。
# Preprocess
# 预处理阶段。
# 使用第一个性能分析器 profilers[0] 记录预处理阶段的耗时。
with profilers[0]:
# 调用 self.preprocess 方法对原始图像列表 im0s 进行预处理,返回 预处理后的图像张量 im 。预处理通常包括调整图像大小、归一化像素值、转换数据格式等操作。
im = self.preprocess(im0s)
# Inference
# 推理阶段。
# 使用第二个性能分析器 profilers[1] 记录推理阶段的耗时。
with profilers[1]:
# 调用 self.inference 方法,将预处理后的图像张量 im 传递给模型进行推理,返回 预测结果 preds 。额外的参数 *args 和 **kwargs 传递给推理方法,用于支持不同的推理配置。
preds = self.inference(im, *args, **kwargs)
# 嵌入模式。
# 如果启用了嵌入模式( self.args.embed 为 True ),则直接返回预测结果 preds 。
if self.args.embed:
# 如果 preds 是一个 torch.Tensor ,则将其包装为列表 [preds] 并生成。
# 如果 preds 已经是一个可迭代对象(如列表或生成器),则直接生成。
# 使用 yield from 逐个生成嵌入向量。
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
# 使用 continue 跳过后续的后处理步骤,直接进入下一个批次的处理。
continue
# Postprocess
# 后处理阶段。
# 使用第三个性能分析器 profilers[2] 记录后处理阶段的耗时。
with profilers[2]:
# 调用 self.postprocess 方法,对预测结果 preds 进行后处理。 preds :模型的原始预测结果。 im :预处理后的图像张量。 im0s :原始图像列表。后处理通常包括解码预测结果、过滤低置信度结果、应用非极大值抑制(NMS)等操作。
self.results = self.postprocess(preds, im, im0s)
# 回调机制。在后处理结束后,触发 on_predict_postprocess_end 事件的回调函数。 这允许用户在后处理完成后插入自定义逻辑,例如记录日志、保存中间结果等。
self.run_callbacks("on_predict_postprocess_end")
# 这段代码的核心功能是执行推理流程中的预处理、推理和后处理步骤,并通过性能分析器记录每个阶段的耗时。主要逻辑包括。预处理:将原始图像转换为模型所需的格式。推理:将预处理后的图像传递给模型,获取预测结果。嵌入模式:如果启用嵌入模式,则直接返回预测结果,跳过后处理。后处理:对预测结果进行进一步处理,以适配实际应用需求。回调机制:在后处理结束后触发回调函数,提供扩展点。这种设计使得推理流程既高效又灵活,能够适应不同的应用场景和需求。
# 这段代码是 stream_inference 方法的一部分,负责处理每个批次的推理结果,包括可视化、保存、写入结果,并记录日志。
# Visualize, save, write results
# 获取当前批次中原始图像的数量 n , im0s 是原始图像列表。
n = len(im0s)
# 遍历当前批次中的每张图像。
for i in range(n):
# 记录已处理的图像总数。
self.seen += 1
# 为每张图像的推理结果记录每个阶段的平均耗时(单位为毫秒)。每个阶段的耗时除以图像数量 n ,得到每张图像的平均耗时。
self.results[i].speed = {
# 预处理阶段的耗时。
"preprocess": profilers[0].dt * 1e3 / n,
# 推理阶段的耗时。
"inference": profilers[1].dt * 1e3 / n,
# 后处理阶段的耗时。
"postprocess": profilers[2].dt * 1e3 / n,
}
# 如果启用了详细模式( self.args.verbose )、保存图像( self.args.save )、保存文本文件( self.args.save_txt )或显示图像( self.args.show )
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
# 则调用 self.write_results 方法。
# i :当前图像的索引。
# Path(paths[i]) :当前图像的路径。
# im :预处理后的图像张量。
# s :附加信息(如帧信息)。
# self.write_results 方法会处理和保存推理结果,并返回一个 包含详细信息的字符串 ,追加到 s[i] 中。
s[i] += self.write_results(i, Path(paths[i]), im, s)
# Print batch results
# 如果启用了详细模式( self.args.verbose )。
if self.args.verbose:
# 则将当前批次的结果记录到日志中。使用 LOGGER.info 输出每张图像的详细信息( s 是一个字符串列表,每项对应一张图像)。
LOGGER.info("\n".join(s))
# 在当前批次结束后,触发 on_predict_batch_end 事件的回调函数。这允许用户在批次结束时插入自定义逻辑,例如记录日志、保存中间结果等。
self.run_callbacks("on_predict_batch_end")
# 生成当前批次的推理结果。 yield from 逐个返回 self.results 中的每个结果,供外部使用(例如保存或显示)。
yield from self.results
# 这段代码的核心功能是处理每个批次的推理结果,包括以下步骤。记录已处理图像数量:通过 self.seen 记录已处理的图像总数。记录每个阶段的耗时:为每张图像记录预处理、推理和后处理的平均耗时。保存和可视化结果:如果启用了相关选项,则调用 self.write_results 方法处理和保存结果(如保存图像、保存文本文件、显示图像)。记录日志:如果启用了详细模式,则将每张图像的详细信息记录到日志中。触发回调函数:在批次结束时触发回调,允许用户插入自定义逻辑。生成结果:逐个返回当前批次的推理结果,供外部使用。这种设计使得推理流程能够灵活地处理和保存结果,同时支持详细的日志记录和用户自定义的回调逻辑。
# 这段代码是 stream_inference 方法的最后部分,负责在推理流程结束后释放资源、记录最终结果,并触发结束时的回调函数。
# Release assets
# 释放视频写入器资源。
# 遍历 self.vid_writer 字典中的所有值。
for v in self.vid_writer.values():
# 如果某个值是 cv2.VideoWriter 对象(用于保存视频)。
if isinstance(v, cv2.VideoWriter):
# 则调用 v.release() 释放该资源。确保在推理结束后正确关闭所有视频文件,避免资源泄漏。
v.release()
# Print final results
# 如果启用了详细模式( self.args.verbose )并且至少处理了一张图像( self.seen > 0 ),则记录最终的性能统计信息。
if self.args.verbose and self.seen:
# 计算每个阶段(预处理、推理、后处理)的平均耗时(单位为毫秒)。 x.t :每个性能分析器记录的总耗时。 self.seen :已处理的图像总数。 * 1e3 :将秒转换为毫秒。 结果存储在元组 t 中。
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
# 记录每个阶段的平均耗时和输入图像的形状。使用 LOGGER.info 输出日志信息。
LOGGER.info(
# 格式化字符串中包含。
# 预处理、推理和后处理的平均耗时(单位为毫秒)。
# 输入图像的形状(批量大小、通道数、高度、宽度)。
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " # 速度:%.1fms 预处理,%.1fms 推理,%.1fms 后处理,每幅图像形状。
# min(self.args.batch, self.seen) :确保批量大小不超过实际处理的图像数量。
f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t
)
# 如果启用了保存图像、保存文本文件或裁剪保存目标区域,则记录保存结果的信息。
if self.args.save or self.args.save_txt or self.args.save_crop:
# 统计保存的标签文件数量。使用 self.save_dir.glob("labels/*.txt") 获取 labels 文件夹中的所有 .txt 文件。 使用 len 计算文件数量。
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
# 根据是否保存文本文件,构造保存结果的字符串。如果保存了文本文件,则记录保存的标签文件数量和路径。 如果没有保存文本文件,则字符串为空。
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" # {nl} 标签{'s' * (nl > 1)} 已保存至 {self.save_dir / 'labels'}。
# 录保存结果的路径和数量。使用 LOGGER.info 输出日志信息。 使用 colorstr('bold', self.save_dir) 将保存路径以加粗形式显示。
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") # 结果保存至 {colorstr('bold', self.save_dir)}{s}。
# 在推理结束后,触发 on_predict_end 事件的回调函数。这允许用户在推理结束时插入自定义逻辑,例如清理资源、记录最终结果等。
self.run_callbacks("on_predict_end")
# 这段代码的核心功能是。释放资源:确保所有视频写入器资源被正确释放。记录最终结果:如果启用了详细模式,则记录每个阶段的平均耗时和输入图像的形状。如果启用了保存功能,则记录保存结果的路径和数量。触发回调函数:在推理结束时触发回调,允许用户插入自定义逻辑。这种设计确保了推理流程的完整性和灵活性,同时提供了详细的性能统计信息和用户自定义的扩展点。
# stream_inference 方法是 BasePredictor 类的核心功能,用于执行流式推理任务。它支持实时处理输入源(如视频流、图像序列或摄像头数据),并逐帧生成推理结果。该方法通过线程锁确保线程安全,支持模型预热、性能分析、结果保存和可视化,并通过回调机制提供扩展点,允许用户在推理的不同阶段插入自定义逻辑。它还能够根据配置动态调整行为,例如保存预测结果为文本文件、裁剪并保存检测目标、显示图像或保存绘制了预测结果的图像。最终,它释放所有资源,并记录详细的推理性能统计信息,确保推理流程高效、灵活且易于扩展。
# 这段代码定义了 BasePredictor 类中的 setup_model 方法,用于初始化和配置模型,使其准备好进行推理。这个方法的核心功能是加载模型权重、设置设备(CPU/GPU)、配置模型的运行模式(如半精度推理),并将其切换到评估模式。
# 定义了 setup_model 方法,接收以下参数 :
# 1.model :模型权重的路径或模型对象。如果传入 None ,则使用 self.args.model 中的默认值。
# 2.verbose :布尔值,控制是否打印详细信息(如设备选择信息)。默认为 True 。
def setup_model(self, model, verbose=True):
# 使用给定的参数初始化 YOLO 模型并将其设置为评估模式。
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
# 通过 AutoBackend 类初始化模型对象,传入以下参数。
# class AutoBackend(nn.Module):
# -> 是用于运行 Ultralytics YOLO 模型推理的动态后端选择器。这个类提供了一个抽象层,用于支持多种推理引擎,并根据输入模型的格式支持动态后端切换,使得在不同平台上部署模型变得更加容易。
# -> def __init__(self, weights="yolov8n.pt", device=torch.device("cpu"), dnn=False, data=None, fp16=False, batch=1, fuse=True, verbose=True,):
self.model = AutoBackend(
# 模型权重路径。如果 model 参数为 None ,则使用 self.args.model 中的默认值。
weights=model or self.args.model,
# 通过 select_device 函数选择设备(CPU 或 GPU)。 self.args.device 指定了设备偏好(如 "cpu" 或 "cuda" ), verbose 参数控制是否打印设备选择信息。
# def select_device(device="", batch=0, newline=False, verbose=True): -> 用于选择和配置设备(如 CPU、GPU 或 MPS),以便在 PyTorch 框架中运行模型。返回一个 torch.device 对象,表示最终选择的设备。 -> return torch.device(arg)
device=select_device(self.args.device, verbose=verbose),
# 是否启用深度神经网络(DNN)相关的优化,从 self.args.dnn 获取。
dnn=self.args.dnn,
# 数据集配置,从 self.args.data 获取。
data=self.args.data,
# 是否启用半精度(FP16)推理,从 self.args.half 获取。
fp16=self.args.half,
# 推理的批量大小,从 self.args.batch 获取。
batch=self.args.batch,
# 是否启用模型融合(如卷积层和批量归一化层的融合),这里固定为 True 。
fuse=True,
# 控制是否打印详细信息。
verbose=verbose,
)
# 从模型对象中获取实际使用的设备(如 CPU 或 GPU),并更新 self.device 属性。这确保了后续操作(如数据移动)使用正确的设备。
self.device = self.model.device # update device
# 从模型对象中获取实际使用的推理精度(FP16 或 FP32),并更新 self.args.half 属性。这确保了后续操作(如数据类型转换)使用正确的精度。
self.args.half = self.model.fp16 # update half
# 将模型切换到评估模式( eval )。这会关闭某些训练时特有的操作(如 Dropout 和 BatchNorm 的训练模式),确保模型在推理时的行为是确定性的。
self.model.eval()
# setup_model 方法的核心功能是初始化和配置模型,使其准备好进行推理。加载模型权重:通过 AutoBackend 初始化模型,传入权重路径、设备选择、DNN 配置、数据集配置、推理精度等参数。更新设备和精度:从模型对象中获取实际使用的设备和推理精度,并更新相关属性。切换到评估模式:通过调用 self.model.eval() ,确保模型在推理时的行为是确定性的。这个方法为后续的推理流程(如预处理、推理和后处理)提供了必要的模型支持,确保模型能够高效、准确地运行。
# 这段代码定义了 BasePredictor 类中的 write_results 方法,用于处理和保存推理结果。它将预测结果写入图像、保存为文本文件、裁剪保存特定区域,或者显示在屏幕上。
# 定义了 write_results 方法,接收以下参数 :
# 1.i :当前处理的索引(例如帧编号或图像编号)。
# 2.p :当前处理的路径( Path 对象)。
# 3.im :当前处理的图像张量。
# 4.s :包含附加信息的字符串列表(例如帧信息)。
def write_results(self, i, p, im, s):
# 将推理结果写入文件或目录。
"""Write inference results to a file or directory."""
# 这段代码是 write_results 方法的一部分,主要用于初始化一些变量,并根据输入数据的类型(如流、图像或张量)和当前处理的索引,提取或生成帧编号( frame )。
# 初始化一个空字符串 string ,用于记录当前处理的详细信息。这个字符串后续可能会被打印到日志或控制台中。
string = "" # print string
# 检查输入图像 im 的形状。
if len(im.shape) == 3:
# 如果 im 是三维张量(形状为 [C, H, W] ,表示单张图像),则通过 im[None] 增加一个批量维度,使其形状变为 [1, C, H, W] 。 为了确保图像张量的形状与模型的输入要求一致(模型通常期望批量维度)。
im = im[None] # expand for batch dim
# 判断当前数据源的类型。如果数据源是流( self.source_type.stream )、图像文件( self.source_type.from_img )或张量( self.source_type.tensor ),则假设批量大小至少为 1。
if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
# 如果数据源是流、图像文件或张量。
# 在 string 中添加 当前索引 i ,格式为 "i: " 。
string += f"{i}: "
# 获取 当前帧编号 frame ,从 self.dataset.count 中获取。 self.dataset.count 通常是一个计数器,记录当前处理的帧或图像编号。
frame = self.dataset.count
# 如果数据源不是流、图像文件或张量,则尝试从字符串 s[i] 中提取帧编号。
else:
# 使用正则表达式 re.search(r"frame (\d+)/", s[i]) 匹配形如 "frame 123/" 的模式。
match = re.search(r"frame (\d+)/", s[i])
# 如果匹配成功,则提取帧编号( match.group(1) )并将其转换为整数。
# 如果匹配失败,则将 frame 设置为 None (表示帧编号无法确定)。
frame = int(match.group(1)) if match else None # 0 if frame undetermined
# 这段代码的核心功能是初始化一些变量,并根据输入数据的类型和当前处理的索引,提取或生成帧编号。检查输入图像的形状:如果输入是单张图像(三维张量),则增加一个批量维度。根据数据源类型处理帧编号:如果数据源是流、图像文件或张量,直接从数据集计数器中获取帧编号。如果数据源是其他类型,尝试从字符串中提取帧编号。这种设计使得 write_results 方法能够灵活处理不同类型的输入数据,并为后续的保存和显示操作提供必要的信息(如帧编号)。
# 这段代码是 write_results 方法的一部分,用于构造保存预测结果的文本路径、记录图像尺寸信息、获取推理结果并更新保存目录。
# 构造保存预测结果的文本路径。
# self.save_dir / "labels" :指定保存目录为 self.save_dir 下的 labels 文件夹。
# p.stem :获取当前处理文件的名称(不包含扩展名)。
# ("" if self.dataset.mode == "image" else f"_{frame}") : 如果数据集模式是 "image" (处理单张图像),则不添加帧编号。 否则(处理视频或流),在文件名后添加帧编号(如 _123 )。
# 最终路径 :构造的路径格式为 self.save_dir/labels/<filename>[_frame].txt 。
self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
# 将图像的 宽度 和 高度 信息添加到 string 中。
# im.shape[2:] :获取图像的宽度和高度(假设 im 是 [batch, channel, height, width] 格式)。
# "%gx%g " :格式化为字符串,例如 "1280x720 " 。
# string += ... :将格式化的字符串追加到 string 中。
string += "%gx%g " % im.shape[2:]
# 从 self.results 列表中获取当前索引 i 对应的推理结果 。
result = self.results[i]
# 将保存目录路径( self.save_dir )转换为字符串,并赋值给 result.save_dir 。这可能 用于在其他地方引用保存路径 。
result.save_dir = self.save_dir.__str__() # used in other locations
# 将 推理结果的详细信息 和 推理速度 添加到 string 中。
# result.verbose() :调用 result 的 verbose 方法,获取推理结果的详细信息(如检测到的类别、置信度等)。
# result.speed['inference'] :获取推理速度(单位为毫秒)。
# f"{...:.1f}ms" :将推理速度格式化为保留一位小数的字符串,例如 "12.3ms" 。
# string += ... :将详细信息和推理速度追加到 string 中。
string += result.verbose() + f"{result.speed['inference']:.1f}ms"
# 这段代码的核心功能是构造保存预测结果的路径,并记录当前处理的详细信息(包括图像尺寸、推理结果和推理速度)。构造文本路径:根据数据集模式(图像或视频/流)动态生成保存路径。记录图像尺寸:将图像的宽度和高度格式化为字符串并追加到记录信息中。获取推理结果:从 self.results 中提取当前索引的推理结果,并更新保存目录。记录推理结果和速度:将推理结果的详细信息和推理速度格式化为字符串并追加到记录信息中。这些操作为后续保存预测结果(如保存到文本文件或打印到日志)提供了必要的信息。
# 这段代码是 write_results 方法的一部分,用于将预测结果(如检测框、置信度、类别标签等)绘制到图像上。这一步通常用于可视化推理结果,以便后续保存或显示图像。
# Add predictions to image
# 判断是否需要将预测结果绘制到图像上。如果配置中启用了保存图像( self.args.save )或显示图像( self.args.show ),则执行绘制操作。
if self.args.save or self.args.show:
# 调用 result.plot() 方法,将预测结果绘制到图像上,并将绘制后的图像存储在 self.plotted_img 中。
self.plotted_img = result.plot(
# 设置绘制边界框的线宽,值从配置参数 self.args.line_width 中获取。
line_width=self.args.line_width,
# 设置是否显示边界框,值从配置参数 self.args.show_boxes 中获取。 如果为 True ,则绘制边界框。 如果为 False ,则不绘制边界框。
boxes=self.args.show_boxes,
# 设置是否显示置信度,值从配置参数 self.args.show_conf 中获取。 如果为 True ,则在边界框旁边显示置信度。 如果为 False ,则不显示置信度。
conf=self.args.show_conf,
# 设置是否显示类别标签,值从配置参数 self.args.show_labels 中获取。 如果为 True ,则在边界框旁边显示类别标签。 如果为 False ,则不显示类别标签。
labels=self.args.show_labels,
# 设置绘制时使用的图像张量,如果启用了 RetinaMasks ( self.args.retina_masks 为 True ),则传递 None ,表示使用默认图像。 否则,传递当前图像张量 im[i] ,用于绘制预测结果。
im_gpu=None if self.args.retina_masks else im[i],
)
# 这段代码的核心功能是将预测结果绘制到图像上,以便后续保存或显示。判断是否需要绘制:根据配置参数( self.args.save 或 self.args.show )决定是否执行绘制操作。调用绘制方法:通过 result.plot() 方法将预测结果(如边界框、置信度、类别标签等)绘制到图像上。配置绘制参数:设置绘制边界框的线宽。控制是否显示边界框、置信度和类别标签。根据是否启用 RetinaMasks 选择合适的图像张量。这一步是可视化推理结果的关键环节,使得用户能够直观地看到模型的预测效果。
# 这段代码是 write_results 方法的一部分,用于根据配置选项保存推理结果。它支持将预测结果保存为文本文件、裁剪并保存检测到的目标区域、显示图像,以及保存绘制了预测结果的图像。
# Save results
# 如果配置中启用了保存文本文件( self.args.save_txt )。
if self.args.save_txt:
# 则调用 result.save_txt() 方法将预测结果保存为文本文件。文件路径 :使用之前构造的 self.txt_path ,并添加 .txt 扩展名。 保存置信度 :通过 save_conf=self.args.save_conf 参数决定是否保存预测的置信度。
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
# 如果配置中启用了裁剪并保存检测到的目标区域( self.args.save_crop )。
if self.args.save_crop:
# 则调用 result.save_crop() 方法。保存目录 : self.save_dir / "crops" ,表示裁剪后的图像将保存到 crops 文件夹中。 文件名 :使用 self.txt_path.stem (即不包含扩展名的文件名)作为裁剪图像的文件名。
result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
# 如果配置中启用了显示图像( self.args.show )。
if self.args.show:
# 则调用 self.show() 方法显示绘制了预测结果的图像。窗口标题 使用 str(p) ,即当前处理的路径( p )作为窗口标题。
self.show(str(p))
# 如果配置中启用了保存绘制了预测结果的图像( self.args.save )。
if self.args.save:
# 则调用 self.save_predicted_images() 方法。保存路径 : str(self.save_dir / p.name) ,表示保存路径为 self.save_dir 下的原始文件名。 帧编号 : frame ,用于视频或流模式,指定当前帧编号。
self.save_predicted_images(str(self.save_dir / p.name), frame)
# 这段代码的核心功能是根据配置选项保存推理结果。它支持以下操作。保存为文本文件:将预测结果(如边界框、置信度、类别标签等)保存为 .txt 文件。裁剪并保存目标区域:将检测到的目标区域裁剪出来并保存为单独的图像。显示图像:在屏幕上显示绘制了预测结果的图像。保存绘制了预测结果的图像:将绘制了预测结果的图像保存为文件(支持图像、视频或流模式)。这种设计使得用户可以根据需要选择保存和显示的方式,提供了高度的灵活性。
# 返回记录的详细信息字符串 string 。
return string
# write_results 方法的核心功能是处理和保存推理结果。它根据配置选项执行以下操作。记录详细信息:生成包含图像尺寸、推理速度和预测结果的字符串。绘制预测结果:将预测结果(如边界框、标签)绘制到图像上。保存结果:保存为文本文件( save_txt )。裁剪并保存感兴趣区域( save_crop )。显示绘制了预测结果的图像( show )。保存绘制了预测结果的图像到文件( save_predicted_images )。这个方法通过灵活的配置选项,支持多种保存和显示方式,适用于不同的应用场景。
# 这段代码定义了 BasePredictor 类中的 save_predicted_images 方法,用于保存推理后的图像或视频结果。它根据数据集的模式(图像、视频或流)决定保存方式。
# 定义了 save_predicted_images 方法,接收以下参数 :
# 1.save_path :保存结果的路径。
# 2.frame :当前帧编号,用于视频或流模式。
def save_predicted_images(self, save_path="", frame=0):
# 将视频预测保存为 mp4 到指定路径。
"""Save video predictions as mp4 at specified path."""
# 获取绘制了预测结果的图像( self.plotted_img )。这通常是经过标注(如绘制边界框、类别标签等)后的图像。
im = self.plotted_img
# Save videos and streams
# 如果数据集模式是视频或流( self.dataset.mode 为 "stream" 或 "video" ),则按视频方式保存结果。
if self.dataset.mode in {"stream", "video"}:
# 根据数据集模式获取帧率( fps )。如果是视频模式,帧率从 self.dataset.fps 获取。 如果是流模式,默认帧率为 30 。
fps = self.dataset.fps if self.dataset.mode == "video" else 30
# 定义保存帧图像的路径(如果启用了保存帧的功能)。路径格式为 <save_path_without_extension>_frames/ 。
frames_path = f'{save_path.split(".", 1)[0]}_frames/'
# 检查当前保存路径是否已经创建了视频写入器( cv2.VideoWriter )。如果没有,则初始化一个新的视频写入器。
if save_path not in self.vid_writer: # new video
# 如果启用了保存帧的功能( self.args.save_frames 为 True ),则创建保存帧图像的目录。
if self.args.save_frames:
Path(frames_path).mkdir(parents=True, exist_ok=True)
# 根据操作系统选择视频文件的扩展名和编解码器。
# macOS :使用 .mp4 和 avc1 编解码器。
# Windows :使用 .avi 和 WMV2 编解码器。
# 其他系统 :使用 .avi 和 MJPG 编解码器。
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
# 初始化 cv2.VideoWriter 对象,用于保存视频。
self.vid_writer[save_path] = cv2.VideoWriter(
# 保存路径,带有适当的扩展名。
filename=str(Path(save_path).with_suffix(suffix)),
# 编解码器。
fourcc=cv2.VideoWriter_fourcc(*fourcc),
# 帧率。
fps=fps, # integer required, floats produce error in MP4 codec
# 图像的宽度和高度。
frameSize=(im.shape[1], im.shape[0]), # (width, height)
)
# Save video
# 将当前帧写入视频文件。
self.vid_writer[save_path].write(im)
# 如果启用了保存帧的功能。
if self.args.save_frames:
# 则将当前帧保存为单独的图像文件。
cv2.imwrite(f"{frames_path}{frame}.jpg", im)
# Save images
# 如果数据集模式不是视频或流(即处理单张图像)。
else:
# 则直接将图像保存到指定路径。
cv2.imwrite(save_path, im)
# save_predicted_images 方法的核心功能是根据数据集的模式保存推理后的图像或视频结果。视频或流模式:初始化视频写入器(如果尚未初始化)。将当前帧写入视频文件。如果启用了保存帧的功能,将当前帧保存为单独的图像文件。图像模式:直接将图像保存到指定路径。这个方法确保了推理结果能够以适当的方式保存,适用于不同的输入源(图像、视频或流)。
# 这段代码定义了 BasePredictor 类中的 show 方法,用于在屏幕上显示绘制了预测结果的图像。它主要用于实时可视化,例如在开发阶段调试模型或在实时视频流中显示检测结果。
# 定义了 show 方法,接收一个参数。
# 1.p :表示显示窗口的标题。默认值为空字符串。
def show(self, p=""):
# 使用 OpenCV imshow() 在窗口中显示图像。
"""Display an image in a window using OpenCV imshow()."""
# 获取绘制了预测结果的图像( self.plotted_img )。这通常是经过标注(如绘制边界框、类别标签等)后的图像。
im = self.plotted_img
# 检查当前操作系统是否为 Linux,并且窗口标题 p 是否尚未创建过。 platform.system() :获取当前操作系统的名称。 p not in self.windows :检查窗口标题是否已经存在于 self.windows 列表中。
if platform.system() == "Linux" and p not in self.windows:
# 如果窗口尚未创建,则将 窗口标题 p 添加到 self.windows 列表中,以避免重复创建窗口。
self.windows.append(p)
# 创建一个 OpenCV 窗口,窗口标题为 p ,并设置窗口属性。 cv2.WINDOW_NORMAL :允许用户调整窗口大小。 cv2.WINDOW_KEEPRATIO :保持图像的宽高比。
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
# 将窗口大小调整为图像的宽度和高度,以确保图像完整显示。
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
# 在窗口 p 中显示图像 im 。
cv2.imshow(p, im)
# 调用 cv2.waitKey 等待用户输入。如果数据集模式为 "image" (处理单张图像),则等待 300 毫秒。 否则(处理视频或流),等待 1 毫秒。 这一步是为了确保图像能够正确显示,并且用户可以按任意键退出显示。
cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond
# show 方法的核心功能是在屏幕上显示绘制了预测结果的图像。检查操作系统和窗口状态:如果在 Linux 系统下且窗口尚未创建,则初始化窗口并设置属性。显示图像:使用 cv2.imshow 在指定窗口中显示图像。等待用户输入:根据数据集模式(图像或视频/流),设置合适的等待时间。这个方法适用于实时可视化场景,例如在开发阶段调试模型或在实时视频流中显示检测结果。它通过 OpenCV 的窗口功能实现了简单而高效的图像显示。
# 这段代码定义了 BasePredictor 类中的 run_callbacks 方法,用于在特定事件发生时运行注册的回调函数。回调机制是一种常见的编程模式,用于在特定事件发生时执行用户定义的代码。
# 定义了 run_callbacks 方法,接收一个参数。
# 1.event :表示当前发生的事件名称(如 "on_predict_start" 或 "on_predict_end" )。事件名称通常是一个字符串,用于标识特定的回调触发点。
def run_callbacks(self, event: str):
# 针对特定事件运行所有已注册的回调。
"""Runs all registered callbacks for a specific event."""
# 从 self.callbacks 字典中获取与当前事件名称 event 对应的回调函数列表。 self.callbacks 是一个字典,键为事件名称,值为回调函数列表。 如果字典中没有找到对应的事件名称,则返回一个空列表 [] ,以避免抛出 KeyError 。 遍历回调函数列表,逐个执行回调函数。
for callback in self.callbacks.get(event, []):
# 调用当前回调函数,并将当前实例( self )作为参数传递给回调函数。这允许回调函数访问 BasePredictor 实例的属性和方法。
callback(self)
# 回调机制的作用 :
# 回调机制在许多框架和工具中被广泛使用,特别是在需要扩展或自定义行为的场景中。在 BasePredictor 类中,回调机制可以用于以下场景 :
# 扩展功能 :用户可以通过注册回调函数,在特定事件发生时插入自定义逻辑,而无需修改核心代码。 例如,在推理开始时记录日志、在推理结束时保存结果等。
# 监控和调试 :在推理的各个阶段(如预处理、推理、后处理)插入监控代码,用于性能分析或调试。
# 与其他系统集成 :通过回调函数,可以将推理结果传递给其他系统(如消息队列、数据库等)。
# run_callbacks 方法的核心功能是在特定事件发生时运行注册的回调函数。从 self.callbacks 字典中获取与事件名称对应的回调函数列表。遍历回调函数列表,逐个调用回调函数,并将当前实例( self )作为参数传递。回调机制为 BasePredictor 类提供了高度的灵活性和可扩展性,允许用户在不修改核心代码的情况下,插入自定义逻辑、监控推理过程或与其他系统集成。
# 这段代码定义了 BasePredictor 类中的 add_callback 方法,用于向指定事件添加回调函数。这个方法是回调机制的核心部分,允许用户动态地注册回调函数,以便在特定事件发生时执行自定义逻辑。
# 定义了 add_callback 方法,接收以下参数 :
# 1.event ( str ) :事件名称,标识回调函数触发的时机(如 "on_predict_start" 、 "on_predict_end" 等)。
# 2.func :回调函数,用户定义的函数,将在指定事件发生时被调用。
def add_callback(self, event: str, func):
# 添加回调。
"""Add callback."""
# 将回调函数 func 添加到 self.callbacks 字典中对应事件的列表中。 self.callbacks 是一个字典,键为事件名称,值为回调函数列表。如果事件名称 event 已经存在于字典中,则将 func 添加到对应的列表中。如果事件名称尚未存在,Python 字典会自动为该键创建一个空列表,并将 func 添加到其中。
self.callbacks[event].append(func)
# 作用和意义 :
# add_callback 方法的核心功能是允许用户动态注册回调函数,以便在特定事件发生时执行自定义逻辑。通过这种方式,用户可以在不修改核心代码的情况下,扩展或自定义 BasePredictor 的行为。这使得类更加灵活和可扩展。
# add_callback 方法的核心功能是向指定事件添加回调函数。将回调函数 func 添加到 self.callbacks 字典中对应事件的列表中。如果事件名称尚未存在,Python 字典会自动创建一个空列表,并将回调函数添加到其中。通过 add_callback 方法,用户可以动态地注册回调函数,从而在特定事件发生时执行自定义逻辑。这为 BasePredictor 类提供了高度的灵活性和可扩展性,允许用户在不修改核心代码的情况下,插入自定义逻辑、监控推理过程或与其他系统集成。
# BasePredictor 类是一个灵活且功能丰富的推理框架,设计用于高效执行图像或视频的实时推理任务。它通过集成预处理、模型推理、后处理和结果保存等功能,提供了一个完整的推理流程。类中支持多种输入源(如图像、视频流、摄像头)和多种输出方式(如保存图像、文本文件或显示结果),并利用回调机制和线程安全设计,确保了高度的可扩展性和灵活性。此外, BasePredictor 还提供了详细的性能分析和日志记录功能,使得用户能够轻松监控和优化推理过程,适用于从开发调试到生产部署的各种场景。