autobackend.py
ultralytics\nn\autobackend.py
目录
[2.def check_class_names(names):](#2.def check_class_names(names):)
[3.def default_class_names(data=None):](#3.def default_class_names(data=None):)
[4.class AutoBackend(nn.Module):](#4.class AutoBackend(nn.Module):)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
import ast
import contextlib
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from ultralytics.utils import ARM64, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, ROOT, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml
from ultralytics.utils.downloads import attempt_download_asset, is_url
2.def check_class_names(names):
python
# 这段代码定义了一个名为 check_class_names 的函数,其主要目的是验证和处理类别名称。这个函数在机器学习模型训练和部署中非常有用,尤其是在处理数据集的类别标签时。
# 函数签名。
# 1.names :一个包含类别名称的列表或字典。
def check_class_names(names):
# 检查类名。
# 如果需要,将 imagenet 类代码映射到人类可读的名称。将列表转换为字典。
"""
Check class names.
Map imagenet class codes to human-readable names if required. Convert lists to dicts.
"""
# 检查是否为列表并转换为字典。
if isinstance(names, list): # names is a list
# 如果 names 是一个列表,将其转换为字典。这是通过 enumerate 函数实现的,它将列表中的每个元素与其索引配对,从而创建一个从索引到类别名称的映射。
names = dict(enumerate(names)) # convert to dict
# 检查是否为字典并处理键和值。
# 如果 names 是一个字典,执行以下操作。
if isinstance(names, dict):
# Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
# 将所有字符串键转换为整数键。将所有非字符串值转换为字符串。
names = {int(k): str(v) for k, v in names.items()}
# 检查字典键的范围。
# 获取字典中键的数量。
n = len(names)
# 检查字典中的最大键值是否大于或等于 n 。如果是,抛出 KeyError 异常,因为这意味着类别索引超出了预期的范围(0 到 n-1 )。
if max(names.keys()) >= n:
raise KeyError(
f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices " # {n} 类数据集需要类索引 0-{n - 1},但您的类索引无效。
f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." # min(names.keys())}-{max(names.keys())} 在数据集 YAML 中定义。
)
# 处理 ImageNet 类别代码。
# 如果字典中的第一个值是一个以 "n0" 开头的字符串,这通常表示 ImageNet 数据集的类别代码。
if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 'n01440764'
# 从配置文件中加载 ImageNet 的类别代码到人类可读名称的映射。
# def yaml_load(file="data.yaml", append_filename=False):
# -> 从 YAML 文件中加载数据,并根据需要将文件名附加到数据字典中。返函数返回一个字典,包含从 YAML 文件中加载的数据。如果 append_filename 为 True ,则字典中还包括一个键 "yaml_file" ,其值为 YAML 文件的路径。
# -> return data
names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names
# 使用映射将所有类别代码转换为人类可读的名称。
names = {k: names_map[v] for k, v in names.items()}
# 返回值。函数返回处理后的 names 字典。
return names
# 这个函数确保输入的类别名称是有效的,并且如果需要,将它们转换为人类可读的名称。这对于确保模型训练和评估时类别标签的一致性和可读性非常重要。通过处理不同的输入格式(列表或字典),并确保类别索引的正确性,这个函数提高了代码的健壮性和灵活性。
3.def default_class_names(data=None):
python
# 这段代码定义了一个名为 default_class_names 的函数,其目的是从一个输入的 YAML 文件中加载默认的类别名称,或者在加载过程中出现错误时返回一组默认的数值类别名称。
# 函数签名。
# 1.data :一个可选参数,可以是 YAML 文件的路径或内容。
def default_class_names(data=None):
# 将默认类名应用于输入 YAML 文件或返回数字类名。
"""Applies default class names to an input YAML file or returns numerical class names."""
# 检查输入数据。如果 data 参数不为空,即用户提供了输入数据。
if data:
# contextlib.suppress(*exceptions)
# contextlib.suppress 是 Python 标准库 contextlib 模块中的一个上下文管理器,它用于临时忽略指定的异常。
# 当在 with 语句中使用时 (例 : with contextlib.suppress(ValueError): ) ,如果在代码块中发生了指定的异常, contextlib.suppress 会捕获这些异常并阻止它们传播,允许程序继续执行。
# 参数 :
# *exceptions :一个或多个异常类型,这些类型的异常将被忽略。
# 等效的 try-except 用法 :
# 使用 contextlib.suppress 可以达到与使用 try-except 块并配合 pass 关键字相同的效果,但通常更加简洁和明确。
# contextlib.suppress 是在 Python 3.4 中引入的,它提供了一种更优雅和 Pythonic 的方式来忽略特定的异常。
# 尝试加载 YAML 文件。
# 使用 contextlib.suppress 上下文管理器来捕获并忽略加载 YAML 文件时可能发生的任何异常。
with contextlib.suppress(Exception):
# 在异常被忽略的上下文中,调用 check_yaml 函数处理输入数据,然后使用 yaml_load 函数加载 YAML 文件,并返回其中的 "names" 键对应的值。
return yaml_load(check_yaml(data))["names"]
# 返回默认类别名称。如果 data 参数为空或在加载 YAML 文件时出现异常,函数将返回一个默认的类别名称字典。这个字典包含从 0 到 998 的整数键,每个键对应一个格式化的字符串 "class{i}" ,其中 i 是类别的索引。
# 返回值。函数返回一个包含类别名称的字典。如果成功加载了 YAML 文件并找到了 "names" 键,则返回该键对应的值;否则,返回一个包含默认类别名称的字典。
return {i: f"class{i}" for i in range(999)} # return default if above errors
# 这个函数提供了一种灵活的方式来处理类别名称,允许用户通过 YAML 文件自定义类别名称,同时提供了一个默认的后备方案,以防 YAML 文件加载失败或未提供 YAML 文件。这种设计使得函数更加健壮,能够适应不同的使用场景。
4.class AutoBackend(nn.Module):
python
# 这段代码定义了一个名为 AutoBackend 的类,它是 torch.nn.Module 的子类。
# 类定义。
class AutoBackend(nn.Module):
# 处理使用 Ultralytics YOLO 模型运行推理的动态后端选择。
# AutoBackend 类旨在为各种推理引擎提供抽象层。它支持多种格式,每种格式都有特定的命名约定,如下所述:
# | PyTorch | *.pt |
# 该类提供基于输入模型格式的动态后端切换功能,从而更轻松地跨各种平台部署模型。
"""
Handles dynamic backend selection for running inference using Ultralytics YOLO models.
The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
range of formats, each with specific naming conventions as outlined below:
Supported Formats and Naming Conventions:
| Format | File Suffix |
|-----------------------|------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx (dnn=True)|
| OpenVINO | *openvino_model/ |
| CoreML | *.mlpackage |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
| NCNN | *_ncnn_model |
This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy
models across various platforms.
"""
# AutoBackend 类的构造函数 __init__ 初始化一个自动选择最优计算后端的模块。
# 构造函数签名。
# @torch.no_grad()
# @torch.no_grad() 是 PyTorch 中的一个装饰器,用于包裹一个函数,使得在该函数内部执行的所有操作都不会跟踪梯度,也就是说,这些操作不会计算梯度,也不会消耗计算图。
# 这通常用于推理(inference)或者评估(evaluation)阶段,因为在这些阶段我们不需要进行反向传播,因此不需要计算梯度,这样可以减少内存消耗并提高计算效率。
# 说明 :
# 当你使用 @torch.no_grad() 装饰器时,它会暂时将 PyTorch 设置为评估模式,在这个模式下,所有的 Variable 对象都会禁用梯度计算。
# 装饰器的作用域仅限于它所装饰的函数内部。
# 如果你需要在代码的某个特定区域临时禁用梯度计算,而不是整个函数,可以使用 torch.no_grad() 作为一个上下文管理器: with torch.no_grad(): # 在这个代码块中,梯度计算被禁用 result = model(input_data)
# 在示例中, evaluate_model 函数被 @torch.no_grad() 装饰,这意味着在函数内部执行的所有操作都不会跟踪梯度。这在模型评估时非常有用,因为评估时通常不需要进行梯度更新。
@torch.no_grad()
# 1.weights :模型权重文件的路径或模型权重列表。
# 2.device :模型运行的设备,默认为 CPU。
# 3.dnn :是否使用 DNN(深度神经网络)加速。
# 4.data :数据配置,可能包含数据集信息。
# 5.fp16 :是否使用半精度(16位浮点数)计算。
# 6.batch :批处理大小,默认为 1。
# 7.fuse :是否融合模型中的某些层(Conv2D + BatchNorm 层)以提高效率。
# 8.verbose :是否打印详细信息。
def __init__(
self,
weights="yolov8n.pt",
device=torch.device("cpu"),
dnn=False,
data=None,
fp16=False,
batch=1,
fuse=True,
verbose=True,
):
# 初始化 AutoBackend 进行推理。
"""
Initialize the AutoBackend for inference.
Args:
weights (str): Path to the model weights file. Defaults to 'yolov8n.pt'.
device (torch.device): Device to run the model on. Defaults to CPU.
dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
batch (int): Batch-size to assume for inference.
fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
verbose (bool): Enable verbose logging. Defaults to True.
"""
# 初始化父类。调用父类 nn.Module 的构造函数。
super().__init__()
# 处理权重文件。如果 weights 是列表,则取第一个元素;如果不是列表,则直接使用 weights 。
w = str(weights[0] if isinstance(weights, list) else weights)
# 检查权重类型。检查 weights 是否是 PyTorch 模型模块。
nn_module = isinstance(weights, torch.nn.Module)
# 模型类型判断。调用 _model_type 方法来判断模型的类型,返回不同类型的布尔值。
# def _model_type(p="path/to/model.pt"): -> 确定给定模型路径 p 所指向的模型文件类型。方法返回一个布尔值列表和一个额外的布尔值,表示模型是否为特定的格式和是否为 Triton 服务格式。 -> return types + [triton]
(
pt,
jit,
onnx,
xml,
engine,
coreml,
saved_model,
pb,
tflite,
edgetpu,
tfjs,
paddle,
ncnn,
triton,
) = self._model_type(w)
# FP16 和 NHWC 支持。
# 如果模型是 PyTorch 权重、TorchScript、ONNX、OpenVINO XML、OpenVINO Engine、CoreML、SavedModel、TensorRT、TensorFlow Lite、EdgeTPU、TensorFlow.js、PaddlePaddle、NCNN 或 Triton 格式,则启用 FP16。
# &= 是一个复合赋值运算符,它用于按位与(AND)操作。这个运算符会将左侧变量与右侧表达式的结果进行按位与操作,并将结果赋值回左侧变量。按位与操作意味着两个比特位进行逻辑与操作,只有当两个位都是 1 时,结果才为 1,否则为 0。
# &= 运算符是原地操作,这意味着它直接修改左侧变量的值,而不是创建一个新的变量。
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
# 如果模型是 CoreML、SavedModel、Protocol Buffers、TensorFlow Lite 或 EdgeTPU 格式,则使用 NHWC 数据格式(Batch Height Width Channel)。
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
# 默认 stride。设置默认的 stride(步长)。
stride = 32 # default stride
# 模型和元数据初始化。初始化模型和元数据变量。
model, metadata = None, None
# AutoBackend 类旨在根据提供的权重文件和其他参数自动选择最优的计算后端。它处理不同的模型格式,并根据模型类型和设备能力设置 FP16 和 NHWC 支持。这个类的设计目的是为了提高模型运行的灵活性和效率,特别是在多设备和多框架环境中。
# 这段代码继续处理 AutoBackend 类的构造函数中的一部分逻辑,用于设置计算设备和下载非本地模型。
# Set device 设置设备。
# 这行代码检查是否有可用的 CUDA 设备(GPU),并且设备类型不是 CPU。如果这两个条件都满足, cuda 变量将被设置为 True ,表示可以使用 GPU 进行计算。
cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA
# 检查 GPU 数据加载格式。
# 如果 cuda 为 True 但模型不是 PyTorch 模型( nn_module )、PyTorch 权重( pt )、TorchScript( jit )、OpenVINO 引擎( engine )或 ONNX 模型( onnx ),则将设备设置为 CPU,并将 cuda 设置为 False 。+
# 这意味着如果模型不支持 GPU 加速,那么即使有可用的 GPU,也会强制使用 CPU。
if cuda and not any([nn_module, pt, jit, engine, onnx]): # GPU dataloader formats
device = torch.device("cpu")
cuda = False
# Download if not local 如果非本地则下载。
# 如果模型不是 PyTorch 权重、Triton 模型或 PyTorch 模型模块,则意味着模型可能位于远程位置。
if not (pt or triton or nn_module):
# 在这种情况下,调用 attempt_download_asset 函数尝试下载模型权重文件。 w 是模型权重的路径或 URL, attempt_download_asset 函数会尝试从该路径或 URL 下载文件,并返回下载后的本地路径。
# def attempt_download_asset(file, repo="ultralytics/assets", release="v8.2.0", **kwargs):
# -> 其目的是尝试下载一个资产(如模型权重文件),如果该资产在本地不存在的话。这个函数处理了从 GitHub 发布页面下载文件的逻辑,并且能够处理不同的文件路径和 URL。返回下载后的文件路径。
# -> return str(file)
w = attempt_download_asset(w)
# 这段代码的目的是确保模型运行在合适的设备上,并且如果模型不在本地,能够自动下载模型。这样的设计使得模型的加载和运行更加灵活和方便。
# 这段代码是 AutoBackend 类构造函数的一部分,它处理两种情况:当模型已经是一个 PyTorch 模型模块( nn_module )时,以及当模型是 PyTorch 权重文件( .pt 文件)时。
# In-memory PyTorch model 内存中的 PyTorch 模型。
# 如果 weights 是一个 PyTorch 模型模块( nn_module ),则将其移动到指定的 device 。
if nn_module:
model = weights.to(device)
# 如果设置了 fuse ,则调用模型的 fuse 方法来融合某些层,以优化模型结构。
if fuse:
model = model.fuse(verbose=verbose)
# 如果模型有 kpt_shape 属性,表示这是一个姿态估计模型,记录关键点的形状。
if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
# 设置模型的 stride,取模型最大 stride 和 32 的最大值。
stride = max(int(model.stride.max()), 32) # model stride
# 获取模型的类别名称,如果模型有 module 属性,则从 module 获取 names ,否则直接从模型获取。
names = model.module.names if hasattr(model, "module") else model.names # get class names
# 如果设置了 fp16 ,则将模型转换为半精度( half() ),否则保持全精度( float() )。
model.half() if fp16 else model.float()
# 将模型显式分配给 self.model ,以便可以调用 to() , cpu() , cuda() , half() 等方法。
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
# 将 pt 设置为 True ,表示模型是 PyTorch 权重格式。
pt = True
# PyTorch
# PyTorch 权重文件。
# 如果 weights 是 PyTorch 权重文件( pt ),则导入 attempt_load_weights 函数来加载权重。
elif pt:
# def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# -> 加载一个模型集合(ensemble)的权重,或者单个模型的权重。这个函数处理了权重文件的加载、模型的兼容性更新、模型的融合以及模型集合的创建。函数返回创建的模型集合对象。
# -> return ensemble
from ultralytics.nn.tasks import attempt_load_weights
# 使用 attempt_load_weights 函数加载权重,并将加载的模型移动到指定的 device 。
model = attempt_load_weights(
weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
)
# 如果模型有 kpt_shape 属性,表示这是一个姿态估计模型,记录关键点的形状。
if hasattr(model, "kpt_shape"):
kpt_shape = model.kpt_shape # pose-only
# 设置模型的 stride,取模型最大 stride 和 32 的最大值。
stride = max(int(model.stride.max()), 32) # model stride
# 获取模型的类别名称,如果模型有 module 属性,则从 module 获取 names ,否则直接从模型获取。
names = model.module.names if hasattr(model, "module") else model.names # get class names
# 如果设置了 fp16 ,则将模型转换为半精度( half() ),否则保持全精度( float() )。
model.half() if fp16 else model.float()
# 将模型显式分配给 self.model ,以便可以调用 to() , cpu() , cuda() , half() 等方法。
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
# 这段代码的目的是确保模型无论是以 PyTorch 模型模块的形式还是以权重文件的形式提供,都能被正确地加载、移动到指定设备、设置精度,并最终分配给 self.model 。这样的设计使得模型的加载和运行更加灵活和方便。
# 可忽略---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------↓
# TorchScript
elif jit:
LOGGER.info(f"Loading {w} for TorchScript inference...")
extra_files = {"config.txt": ""} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files["config.txt"]: # load metadata dict
metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
# ONNX OpenCV DNN
elif dnn:
LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w)
# ONNX Runtime
elif onnx:
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
if IS_RASPBERRYPI or IS_JETSON:
# Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson
check_requirements("numpy==1.23.5")
import onnxruntime
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map
# OpenVINO
elif xml:
LOGGER.info(f"Loading {w} for OpenVINO inference...")
check_requirements("openvino>=2024.0.0")
import openvino as ov
core = ov.Core()
w = Path(w)
if not w.is_file(): # if not *.xml
w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
if ov_model.get_parameters()[0].get_layout().empty:
ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))
# OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY"
LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...")
ov_compiled_model = core.compile_model(
ov_model,
device_name="AUTO", # AUTO selects best available device, do not modify
config={"PERFORMANCE_HINT": inference_mode},
)
input_name = ov_compiled_model.input().get_any_name()
metadata = w.parent / "metadata.yaml"
# TensorRT
elif engine:
LOGGER.info(f"Loading {w} for TensorRT inference...")
try:
import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
except ImportError:
if LINUX:
check_requirements("tensorrt>7.0.0,<=10.1.0")
import tensorrt as trt # noqa
check_version(trt.__version__, ">=7.0.0", hard=True)
check_version(trt.__version__, "<=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
if device.type == "cpu":
device = torch.device("cuda:0")
Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
logger = trt.Logger(trt.Logger.INFO)
# Read file
with open(w, "rb") as f, trt.Runtime(logger) as runtime:
try:
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
except UnicodeDecodeError:
f.seek(0) # engine file may lack embedded Ultralytics metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
# Model context
try:
context = model.create_execution_context()
except Exception as e: # model is None
LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
raise e
bindings = OrderedDict()
output_names = []
fp16 = False # default updated below
dynamic = False
is_trt10 = not hasattr(model, "num_bindings")
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
for i in num:
if is_trt10:
name = model.get_tensor_name(i)
dtype = trt.nptype(model.get_tensor_dtype(name))
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
if is_input:
if -1 in tuple(model.get_tensor_shape(name)):
dynamic = True
context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_tensor_shape(name))
else: # TensorRT < 10.0
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
is_input = model.binding_is_input(i)
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size
# CoreML
elif coreml:
LOGGER.info(f"Loading {w} for CoreML inference...")
import coremltools as ct
model = ct.models.MLModel(w)
metadata = dict(model.user_defined_metadata)
# TF SavedModel
elif saved_model:
LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
metadata = Path(w) / "metadata.yaml"
# TF GraphDef
elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
import tensorflow as tf
from ultralytics.engine.exporter import gd_outputs
def wrap_frozen_graph(gd, inputs, outputs):
"""Wrap frozen graphs for deployment."""
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
gd = tf.Graph().as_graph_def() # TF GraphDef
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
with contextlib.suppress(StopIteration): # find metadata in SavedModel alongside GraphDef
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
platform.system()
]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# Load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, "r") as model:
meta_file = model.namelist()[0]
metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
# TF.js
elif tfjs:
raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
# PaddlePaddle
elif paddle:
LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
import paddle.inference as pdi # noqa
w = Path(w)
if not w.is_file(): # if not *.pdmodel
w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir
config = pdi.Config(str(w), str(w.with_suffix(".pdiparams")))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
metadata = w.parents[1] / "metadata.yaml"
# NCNN
elif ncnn:
LOGGER.info(f"Loading {w} for NCNN inference...")
check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN
import ncnn as pyncnn
net = pyncnn.Net()
net.opt.use_vulkan_compute = cuda
w = Path(w)
if not w.is_file(): # if not *.param
w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir
net.load_param(str(w))
net.load_model(str(w.with_suffix(".bin")))
metadata = w.parent / "metadata.yaml"
# NVIDIA Triton Inference Server
elif triton:
check_requirements("tritonclient[all]")
from ultralytics.utils.triton import TritonRemoteModel
model = TritonRemoteModel(w)
# Any other format (unsupported)
else:
from ultralytics.engine.exporter import export_formats
raise TypeError(
f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
f"See https://docs.ultralytics.com/modes/predict for help."
)
# 可忽略---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------↑
# Load external metadata YAML 加载外部元数据 YAML。
# 这段代码是用于加载和处理外部元数据的Python代码片段。它使用了 yaml_load 函数,并且处理了元数据中的特定字段。
# 检查 metadata 是否是 字符串 或 Path 对象,并且对应的文件或路径是否存在。
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
# 如果 metadata 是一个有效的文件路径,那么使用 yaml_load 函数加载YAML文件内容到 metadata 变量中。
# def yaml_load(file="data.yaml", append_filename=False):
# -> 从 YAML 文件中加载数据,并根据需要将文件名附加到数据字典中。返函数返回一个字典,包含从 YAML 文件中加载的数据。如果 append_filename 为 True ,则字典中还包括一个键 "yaml_file" ,其值为 YAML 文件的路径。
# -> return data
metadata = yaml_load(metadata)
# 检查 metadata 是否存在且是否是一个字典。
if metadata and isinstance(metadata, dict):
# 遍历 metadata 字典中的所有键值对。
for k, v in metadata.items():
# 如果键 k 是 "stride" 或 "batch" ,则将对应的值 v 转换为整数。
if k in {"stride", "batch"}:
metadata[k] = int(v)
# 如果键 k 是 "imgsz" 、 "names" 或 "kpt_shape" ,并且值 v 是一个字符串,则使用 eval 函数来计算字符串表示的值。
elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
metadata[k] = eval(v)
# 从 metadata 字典中获取 "stride" 的值,并赋值给变量 stride 。
stride = metadata["stride"]
# 从 metadata 字典中获取 "task" 的值,并赋值给变量 task 。
task = metadata["task"]
# 从 metadata 字典中获取 "batch" 的值,并赋值给变量 batch 。
batch = metadata["batch"]
# 从 metadata 字典中获取 "imgsz" 的值,并赋值给变量 imgsz 。
imgsz = metadata["imgsz"]
# 从 metadata 字典中获取 "names" 的值,并赋值给变量 names 。
names = metadata["names"]
# value = dict.get(key, default=None)
# dict.get() 是 Python 字典( dict )类型提供的一个方法,用于从字典中获取指定键(key)对应的值(value)。如果键不存在于字典中,它将返回一个默认值,这个默认值可以是调用方法时指定的,如果没有指定,则默认为 None 。
# 参数说明 :
# key :要检索的键。
# default :如果键不在字典中,返回的默认值。如果未提供此参数,且键不存在时,默认返回 None 。
# 返回值 :
# 返回字典中键 key 对应的值,如果键不存在,则返回 default 指定的值或 None 。
# 注意事项 :
# 使用 dict.get() 方法可以避免在访问字典键时出现 KeyError 异常,当键不存在时,它提供了一种更安全的访问方式。
# 如果需要在键不存在时执行某些操作,可以在 get() 方法中设置一个特定的默认值,或者根据返回的 None 值来决定后续操作。
# 从 metadata 字典中安全地获取 "kpt_shape" 的值(如果存在),并赋值给变量 kpt_shape 。
kpt_shape = metadata.get("kpt_shape")
# 如果 pt 、 triton 或 nn_module 都不为真,则执行以下代码。
elif not (pt or triton or nn_module):
# 使用日志记录器 LOGGER 记录一条警告信息,提示没有找到模型的元数据。
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") # 警告⚠️未找到"model={weights}"的元数据。
# 这段代码的目的是确保元数据被正确加载和处理,以便后续可以使用这些元数据来配置模型或执行其他任务。如果元数据文件不存在或格式不正确,代码将记录一条警告信息。
# Check names
# 这段代码是用于检查和验证 names 变量的Python代码片段。
# 检查当前局部变量中是否没有 names 这个变量。 locals() 函数返回当前局部符号表,它是一个字典,包含了当前局部作用域中的所有变量。
if "names" not in locals(): # names missing
# 如果 names 变量缺失,那么调用 default_class_names 函数,并传入 data 参数,将返回的值赋给 names 变量。这个函数可能是用来生成默认的类别名称列表。
# def default_class_names(data=None):
# -> 从一个输入的 YAML 文件中加载默认的类别名称,或者在加载过程中出现错误时返回一组默认的数值类别名称。函数返回一个包含类别名称的字典。如果成功加载了 YAML 文件并找到了 "names" 键,则返回该键对应的值;否则,返回一个包含默认类别名称的字典。
# -> return {i: f"class{i}" for i in range(999)} # return default if above errors
names = default_class_names(data)
# 无论 names 是否已经定义,都会调用 check_class_names 函数来验证 names 变量,并将其返回值重新赋给 names 变量。这个函数可能是用来确保 names 变量包含了有效的类别名称,并且可能还会进行一些格式上的校验或转换。
# def check_class_names(names): -> 验证和处理类别名称。函数返回处理后的 names 字典。 -> return names
names = check_class_names(names)
# 这段代码的目的是确保 names 变量存在,并且包含有效的类别名称。如果 names 没有在之前的代码中被定义,它会使用 default_class_names 函数生成一个默认的类别名称列表。然后,无论 names 是否是默认生成的,都会通过 check_class_names 函数来确保其有效性。这样的设计可以在不同的上下文中灵活地处理类别名称,确保程序的健壮性。
# Disable gradients 禁止梯度。
# 这是一个条件语句,检查变量 pt 是否为真(即非零、非空或非False)。如果 pt 为真,那么执行下面的代码块。
if pt:
# 遍历 model 对象的所有参数。 model.parameters() 是一个迭代器,它返回模型中所有参数的迭代器。
for p in model.parameters():
# 对于每个参数 p ,将其 requires_grad 属性设置为 False 。在PyTorch中, requires_grad 属性决定了是否需要计算该参数的梯度。将其设置为 False 意味着在反向传播过程中,这个参数的梯度不会被计算,这通常用于冻结模型的某些层,使得这些层在训练过程中不会被更新。
p.requires_grad = False
# 这段代码的目的是冻结模型的所有参数,使得在后续的训练过程中,这些参数的值不会被改变。这在迁移学习或者微调模型时非常有用,当你想要保持模型的某些层不变,只训练其他层时,就可以使用这种方法。如果 pt 为假,那么这个代码块将不会被执行,模型的参数将保持原有的 requires_grad 状态。
# 用于将当前局部变量的值赋给对象的属性。
# self.__dict__ : 这是访问对象属性字典的一种方式。在Python中,每个对象都有一个 __dict__ 属性,它是一个字典,包含了对象的所有属性和对应的值。
# update(locals()) : locals() 函数返回当前局部作用域中的所有变量的字典。 update 方法用于将一个字典的键值对更新到另一个字典中。
# self.__dict__.update(locals()) : 这行代码将当前局部作用域中的所有变量和它们的值添加到 self 对象的属性字典中。这意味着,任何在当前函数或方法中定义的局部变量都会成为 self 对象的属性。
# 这种技术通常用在类的构造函数( __init__ 方法)中,以动态地将传入的参数转换为对象的属性。这样做的好处是可以让对象的属性与构造函数的参数保持一致,而不需要显式地为每个参数编写属性赋值语句。
self.__dict__.update(locals()) # assign all variables to self
# 这段代码定义了一个名为 forward 的方法,这个方法可能是一个深度学习模型类的一部分,用于执行模型的前向传播。
# 定义一个名为 forward 的方法,它接受四个参数。
# 1.self :指向类实例的引用。
# 2.im :输入图像。
# 3.augment :是否进行数据增强,默认为 False 。
# 4.visualize :是否进行可视化,默认为 False 。
# 6.embed :嵌入信息,默认为 None 。
def forward(self, im, augment=False, visualize=False, embed=None):
# 在 YOLOv8 MultiBackend 模型上运行推理。
"""
Runs inference on the YOLOv8 MultiBackend model.
Args:
im (torch.Tensor): The image tensor to perform inference on.
augment (bool): whether to perform data augmentation during inference, defaults to False
visualize (bool): whether to visualize the output predictions, defaults to False
embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
"""
# 从输入图像 im 中提取 批次大小( b ) 、 通道数( ch ) 、 高度( h ) 和 宽度( w )。
b, ch, h, w = im.shape # batch, channel, height, width
# 检查模型是否配置为使用半精度浮点数( self.fp16 为 True ),并且输入图像 im 的数据类型不是半精度浮点数( torch.float16 )。
if self.fp16 and im.dtype != torch.float16:
# 如果上述条件满足,将输入图像 im 转换为半精度浮点数(FP16),这通常用于减少模型的内存使用和加速计算。
im = im.half() # to FP16
# 检查模型是否配置为使用非标准的数据格式( self.nhwc 为 True ),其中数据格式为 批大小 、 高度 、 宽度 、 通道数 (NHWC) 。
if self.nhwc:
# 如果模型使用 NHWC 格式,将输入图像 im 的维度从 PyTorch 的标准格式(批次、通道、高度、宽度,即 BCHW)转换为 NHWC 格式。
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
# PyTorch
# 检查模型是否是基于 PyTorch 的模型( self.pt 为 True )或者是一个神经网络模块( self.nn_module 为 True )。
if self.pt or self.nn_module:
# 如果模型是基于 PyTorch 的模型或神经网络模块,调用模型的 self.model 方法进行前向传播,并传入输入图像 im 以及其他参数。
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
# 这个方法的主要作用是处理输入图像,根据模型的配置调整图像的数据类型和格式,然后执行模型的前向传播。 augment 、 visualize 和 embed 参数提供了额外的控制,以便在前向传播过程中进行数据增强、可视化或嵌入额外信息。
# 可忽略---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------↓
# TorchScript
elif self.jit:
y = self.model(im)
# ONNX OpenCV DNN
elif self.dnn:
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
# ONNX Runtime
elif self.onnx:
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
# OpenVINO
elif self.xml:
im = im.cpu().numpy() # FP32
if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes
n = im.shape[0] # number of images in batch
results = [None] * n # preallocate list with None to match the number of images
def callback(request, userdata):
"""Places result in preallocated list using userdata index."""
results[userdata] = request.results
# Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
async_queue.set_callback(callback)
for i in range(n):
# Start async inference with userdata=i to specify the position in results list
async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW
async_queue.wait_all() # wait for all inference requests to complete
y = np.concatenate([list(r.values())[0] for r in results])
else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
y = list(self.ov_compiled_model(im).values())
# TensorRT
elif self.engine:
if self.dynamic or im.shape != self.bindings["images"].shape:
if self.is_trt10:
self.context.set_input_shape("images", im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
else:
i = self.model.get_binding_index("images")
self.context.set_binding_shape(i, im.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
s = self.bindings["images"].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs["images"] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
# CoreML
elif self.coreml:
im = im[0].cpu().numpy()
im_pil = Image.fromarray((im * 255).astype("uint8"))
# im = im.resize((192, 320), Image.BILINEAR)
y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
if "confidence" in y:
raise TypeError(
"Ultralytics only supports inference of non-pipelined CoreML models exported with "
f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
)
# TODO: CoreML NMS inference handling
# from ultralytics.utils.ops import xywh2xyxy
# box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
# conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
# y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif len(y) == 1: # classification model
y = list(y.values())
elif len(y) == 2: # segmentation model
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
# PaddlePaddle
elif self.paddle:
im = im.cpu().numpy().astype(np.float32)
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
# NCNN
elif self.ncnn:
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
with self.net.create_extractor() as ex:
ex.input(self.net.input_names()[0], mat_in)
# WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130
y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())]
# NVIDIA Triton Inference Server
elif self.triton:
im = im.cpu().numpy() # torch to numpy
y = self.model(im)
# TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
else:
im = im.cpu().numpy()
if self.saved_model: # SavedModel
y = self.model(im, training=False) if self.keras else self.model(im)
if not isinstance(y, list):
y = [y]
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im))
else: # Lite or Edge TPU
details = self.input_details[0]
is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model
if is_int:
scale, zero_point = details["quantization"]
im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
self.interpreter.set_tensor(details["index"], im)
self.interpreter.invoke()
y = []
for output in self.output_details:
x = self.interpreter.get_tensor(output["index"])
if is_int:
scale, zero_point = output["quantization"]
x = (x.astype(np.float32) - zero_point) * scale # re-scale
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
if x.shape[-1] == 6: # end-to-end model
x[:, :, [0, 2]] *= w
x[:, :, [1, 3]] *= h
else:
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
if y[1].shape[-1] == 6: # end-to-end model
y = [y[1]]
else:
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
# 可忽略---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------↑
# 它主要负责检查输出结果 y 的类型,并根据输出结果的类型和任务类型来处理和转换数据。
# 这是一个被注释掉的调试代码块,用于打印输出结果中每个元素的类型和长度或形状。
# for x in y:
# print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
# 检查 y 是否是一个列表或元组。
if isinstance(y, (list, tuple)):
# 如果 self.names 的长度为999(通常表示未定义或默认值),并且任务类型为 "segment" 或者 y 的长度为2(通常表示输出包含两个元素,例如分割掩码和边界框),则执行以下代码。
if len(self.names) == 999 and (self.task == "segment" or len(y) == 2): # segments and names not defined
# 根据 y[0] 的形状维度来确定 protos 和 boxes 的索引。如果 y[0] 是四维的,则假设第一个元素是 protos ,第二个元素是 boxes ;否则相反。
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
# 计算类别数量 nc ,这是通过从 boxes 的第二个维度减去 protos 的第三个维度和4来得到的。
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
# 为每个类别创建一个名称,并将这些名称存储在 self.names 字典中。
self.names = {i: f"class{i}" for i in range(nc)}
# 如果 y 只包含一个元素,那么调用 self.from_numpy 方法将这个元素从NumPy数组转换为所需的格式,并返回结果。如果 y 包含多个元素,那么对每个元素都进行转换,并将结果作为一个列表返回。
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
# 如果 y 不是一个列表或元组。
else:
# 那么直接调用 self.from_numpy 方法将 y 从NumPy数组转换为所需的格式,并返回结果。
return self.from_numpy(y)
# 这段代码的目的是处理模型的输出,根据任务类型和输出结果的类型来动态地处理和转换数据。特别是在处理分割任务时,它能够根据输出结果的形状来推断类别数量,并为每个类别分配一个默认名称。最后,它将输出结果从NumPy数组转换为所需的格式。
# 这段代码定义了一个名为 from_numpy 的方法,它的作用是将 NumPy 数组转换为 PyTorch 张量,并将其移动到指定的设备上。
# 定义一个名为 from_numpy 的方法,接受两个参数。
# 1.self :指向类实例的引用。
# 2.x :要转换的数据。
def from_numpy(self, x):
# 将 numpy 数组转换为张量。
"""
Convert a numpy array to a tensor.
Args:
x (np.ndarray): The array to be converted.
Returns:
(torch.Tensor): The converted tensor
"""
# 这是一个条件表达式,它检查 x 是否是一个 NumPy 数组(通过 isinstance(x, np.ndarray) )。
# 如果 x 是 NumPy 数组,那么执行以下操作 :
# torch.tensor(x) :将 NumPy 数组 x 转换为 PyTorch 张量。
# .to(self.device) :将转换得到的 PyTorch 张量移动到 self.device 指定的设备上,这个设备可以是 CPU 或 GPU。
# 如果 x 不是 NumPy 数组,那么直接返回 x 。
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
# 这个方法的主要作用是确保输入的数据是 PyTorch 张量,并且位于正确的设备上,这对于在 PyTorch 中进行计算是必要的。如果输入数据已经是 PyTorch 张量或其他类型,那么它将不会被转换,直接返回。这样的设计使得方法更加灵活,能够处理不同类型的输入数据。
# 这段代码定义了一个名为 warmup 的方法,它用于在实际推理之前对模型进行预热。预热的目的是减少模型在实际使用时的延迟,特别是在使用特定的模型优化技术时。
# 定义一个名为 warmup 的方法,接受两个参数。
# 1.self :指向类实例的引用。
# 2.imgsz :输入图像的尺寸,默认为 (1, 3, 640, 640) ,即批次大小为1,通道数为3,高度和宽度均为640。
def warmup(self, imgsz=(1, 3, 640, 640)):
# 通过使用虚拟输入运行一次前向传递来预热模型。
"""
Warm up the model by running one forward pass with a dummy input.
Args:
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
"""
# 导入 torchvision 库, # noqa 注释用于忽略某些代码风格检查工具的警告,因为在这里导入 torchvision 是为了确保导入时间不被记录在后续处理时间中。
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
# 创建一个元组 warmup_types ,包含模型的各种类型标识,例如是否是 PyTorch 模型 ( self.pt ),是否是 JIT 编译的 ( self.jit ),是否是 ONNX 模型 ( self.onnx ) 等。
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
# 检查 warmup_types 中是否有任何为真的值,并且设备不是 CPU 或者是 Triton 模型。如果条件为真,则执行预热。
if any(warmup_types) and (self.device.type != "cpu" or self.triton):
# 创建一个空的 PyTorch 张量 im ,其尺寸由 imgsz 指定,数据类型根据 self.fp16 的值决定(如果是半精度则为 torch.half ,否则为 torch.float ),并且位于 self.device 指定的设备上。
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
# 如果模型是 JIT 编译的,则预热两次,否则预热一次。
for _ in range(2 if self.jit else 1):
# 调用 self.forward 方法,传入预热用的输入张量 im ,进行前向传播,以预热模型。
self.forward(im) # warmup
# 这个方法的目的是确保模型在实际推理之前已经准备好,特别是在使用 GPU 或其他加速器时,预热可以帮助减少推理延迟。通过在实际输入之前运行模型几次,可以预热 GPU 缓存,优化 JIT 编译等,从而提高模型的推理效率。
# 这段代码定义了一个名为 _model_type 的静态方法,其作用是确定给定模型路径 p 所指向的模型文件类型。这个方法返回一个布尔值列表,表示模型是否为特定的格式,以及一个额外的布尔值表示模型是否为 Triton 服务格式。
# 方法签名。
# @staticmethod
#在 Python 中, @staticmethod 是一个装饰器,用于将一个方法声明为静态方法。静态方法与类的其他方法不同,它不会接收隐式的 self 参数(即不会自动接收类的实例或类本身作为第一个参数)。
# 这意味着静态方法不依赖于类的实例或类本身的状态,它们可以被看作是仅仅属于类的函数,而不是与特定的实例相关联。
# 静态方法的特点 :
# 静态方法可以通过类直接调用,也可以通过类的实例调用,但它们不会影响类的状态或实例的状态。
# 静态方法通常用于实现与类相关的功能,但不需要访问类的属性或方法。
# 静态方法的应用场景 :
# 当函数逻辑与类相关,但不需要访问类或实例的任何属性时。
# 当函数需要在多个类之间共享时。
# 当函数需要在类的方法中被调用,但不需要类的方法传递 self 参数时。
# 使用 @staticmethod 可以使代码更加清晰和模块化,同时避免了不必要的实例化开销。
@staticmethod
# 1.p :模型文件的路径或URL,默认为 "path/to/model.pt" 。
def _model_type(p="path/to/model.pt"):
# 获取模型文件的路径并返回模型类型。可能的类型有 pt、jit、onnx、xml、engine、coreml、saved_model、pb、tflite、edgetpu、tfjs、ncnn 或 paddle。
"""
Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
Args:
p: path to the model file. Defaults to path/to/model.pt
Examples:
>>> model = AutoBackend(weights="path/to/model.onnx")
>>> model_type = model._model_type() # returns "onnx"
"""
# 导入导出格式。从 ultralytics.engine.exporter 模块导入 export_formats 函数,该函数返回一个包含不同模型导出格式后缀的字典。
from ultralytics.engine.exporter import export_formats
# 获取导出后缀。从 export_formats 函数获取模型导出格式的后缀列表。
sf = export_formats()["Suffix"] # export suffixes
# 检查模型后缀。
# 如果 p 不是 URL 也不是字符串,则检查 p 的后缀。
if not is_url(p) and not isinstance(p, str):
# 调用 check_suffix 函数来检查模型文件 p 的后缀是否有效。
check_suffix(p, sf) # checks
# 获取模型文件名。使用 pathlib.Path 获取模型文件的文件名。
name = Path(p).name
# 检查文件类型。检查文件名 name 是否包含 sf 列表中的后缀,并返回一个布尔值列表。
types = [s in name for s in sf]
# 特殊格式处理。
# 对于 Apple CoreML 模型,保留对旧格式 .mlmodel 的支持。
types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats
# 确保 TensorFlow Lite 模型和 Edge TPU 模型不会同时为真。
types[8] &= not types[9] # tflite &= not edgetpu
# 检查 Triton 服务格式。
# Triton 服务格式是指 NVIDIA Triton Inference Server(之前称为 TensorRT Inference Server)支持的模型部署和推理服务的格式。Triton Inference Server 是一个开源软件解决方案,旨在简化生产环境中大规模部署人工智能模型的过程
# 如果模型是已知的格式,则 triton 设置为 False 。
if any(types):
triton = False
# 如果模型不是已知的格式,则进一步检查是否为 Triton 服务格式。
else:
# 导入 urlsplit 函数用于解析 URL。
from urllib.parse import urlsplit
# 解析 p 为一个 URL 对象。
url = urlsplit(p)
# 如果 URL 有网络位置( netloc )和路径( path ),并且 scheme 是 http 或 grpc ,则认为是一个 Triton 服务。
triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}
# 返回值。方法返回一个布尔值列表和一个额外的布尔值,表示模型是否为特定的格式和是否为 Triton 服务格式。
return types + [triton]
# _model_type 方法用于自动检测模型文件的类型,这对于确定如何加载和处理模型非常重要。这个方法考虑了多种模型格式,并能够识别 Triton 服务,使得模型加载过程更加灵活和健壮。