downloads.py
ultralytics\utils\downloads.py
目录
[2.def is_url(url, check=False):](#2.def is_url(url, check=False):)
[3.def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):](#3.def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):)
[4.def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):](#4.def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):)
[5.def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):](#5.def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):)
[6.def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):](#6.def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):)
[7.def get_google_drive_file_info(link):](#7.def get_google_drive_file_info(link):)
[8.def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,):](#8.def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,):)
[9.def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):](#9.def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):)
[10.def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):](#10.def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):)
[11.def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):](#11.def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import re
import shutil
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib import parse, request
import requests
import torch
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# 这段代码定义了与 Ultralytics GitHub 资产仓库 https://github.com/ultralytics/assets 相关的一些变量,这些资产包括预训练模型和其他资源。
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
# 定义了一个变量 GITHUB_ASSETS_REPO ,它存储了 Ultralytics 资产的 GitHub 仓库名称。
GITHUB_ASSETS_REPO = "ultralytics/assets"
# 定义了一个变量 GITHUB_ASSETS_NAMES ,它是一个元组,包含了 Ultralytics 仓库中所有资产的文件名。这个元组是通过多个列表推导式构建的,每个列表推导式生成一系列相关的文件名。
GITHUB_ASSETS_NAMES = (
# 这个列表推导式生成了一系列 YOLOv8 模型的文件名,包括不同大小(n、s、m、l、x)和后缀(无、-cls、-seg、-pose、-obb)的变体。
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
# 类似的列表推导式为 YOLOv5 、 YOLOv3 、 YOLOv8-world 模型、YOLOv8-worldv2 模型、 YOLOv9 、 YOLO NAS 、 SAM 、 FastSAM 和 RTDETR 模型生成文件名。
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
# 这个列表添加了移动版 SAM 模型的文件名。
+ ["mobile_sam.pt"]
# 这个列表添加了一个校准图像样本数据的文件名。
+ ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
# 定义了一个变量 GITHUB_ASSETS_STEMS ,它是一个列表,包含了 GITHUB_ASSETS_NAMES 中每个文件名的干(stem),即不带扩展名的文件名。这是通过列表推导式和 Path 对象的 stem 属性实现的。
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
# 这段代码的主要作用是定义 Ultralytics GitHub 资产仓库中的资产文件名和它们的干名称。这些信息可能用于识别和下载特定的资产,例如在设置项目环境或执行模型推理时。通过将这些名称存储在变量中,代码提高了可维护性和可读性。
2.def is_url(url, check=False):
python
# 这段代码定义了一个名为 is_url 的函数,用于检查给定的字符串是否是一个有效的 URL,并可选地检查该 URL 是否在线可访问。
# 定义了函数 is_url ,它接受两个参数。
# 1.url :要检查的 URL 字符串。
# 2.check :一个布尔值,默认为 False 。如果为 True ,则会检查 URL 是否在线可访问。
def is_url(url, check=False):
# 验证给定的字符串是否为 URL,并可选择检查 URL 是否在线存在。
"""
Validates if the given string is a URL and optionally checks if the URL exists online.
Args:
url (str): The string to be validated as a URL.
check (bool, optional): If True, performs an additional check to see if the URL exists online.
Defaults to True.
Returns:
(bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online.
Returns False otherwise.
Example:
```python
valid = is_url("https://www.example.com")
```
"""
# 使用 contextlib.suppress(Exception) 上下文管理器来抑制任何异常。这意味着在执行代码块中的代码时,即使发生异常,也不会中断程序的执行,而是会继续执行后续代码。这在某些情况下可以用于忽略不重要的错误,但需要注意可能会隐藏一些重要的问题。
with contextlib.suppress(Exception):
# 将输入的 url 转换为字符串类型。这一步确保即使输入的 url 是其他类型(如列表或字典中的元素),也可以正常处理。
url = str(url)
# result = urlparse(urlstring, scheme='', allow_fragments=True)
# urlparse() 函数是 Python 标准库 urllib.parse 模块中的一个函数,用于解析 URL(统一资源定位符)并将其分解为组件。这个函数在处理网络地址时非常有用,因为它可以将复杂的 URL 分解成易于管理的部分。
# 参数 :
# urlstring : 要解析的 URL 字符串。
# scheme : (可选)如果提供,将用于覆盖 URL 中的方案部分。
# allow_fragments : (可选)一个布尔值,指示是否允许解析 URL 的片段部分(即 # 后面的部分)。默认为 True 。
# 返回值 :
# urlparse() 函数返回一个 ParseResult 对象,该对象包含以下属性 :
# scheme : URL 的方案部分(例如 http 、 https )。
# netloc : 网络位置部分(例如域名和端口)。
# path : URL 的路径部分。
# params : URL 的参数部分( ? 后面的部分)。
# query : URL 的查询部分( ? 后面的部分,不包括 # )。
# fragment : URL 的片段部分( # 后面的部分)。
# urlparse() 函数是处理 URL 的基础工具,常用于网络编程、Web 开发和任何需要解析或构造 URL 的场景。
# 使用 urllib.parse.urlparse 函数解析 url 字符串。 urlparse 函数将 URL 分解为多个部分,如协议(scheme)、网络位置(netloc)、路径(path)等,并返回一个 ParseResult 对象,其中包含了这些部分的信息。
result = parse.urlparse(url)
# 使用 assert 语句检查解析后的 URL 是否包含有效的协议和网络位置。
# result.scheme :协议部分,如 http 、 https 等。如果为空,则表示 URL 没有指定协议。
# result.netloc :网络位置部分,如域名或 IP 地址。如果为空,则表示 URL 没有指定网络位置。
# all([result.scheme, result.netloc]) :确保协议和网络位置都不为空。如果为空,则 assert 语句会抛出 AssertionError 异常,表示 URL 无效。
assert all([result.scheme, result.netloc]) # check if is url
# 如果 check 参数为 True ,则执行以下操作来检查 URL 是否在线可访问。
if check:
# 使用 urllib.request.urlopen 函数打开 URL 并获取响应对象。 urlopen 函数会发送一个 HTTP 请求到指定的 URL,并返回一个响应对象,其中包含了请求的结果。
with request.urlopen(url) as response:
# 检查响应对象的状态码是否为 200。状态码 200 表示请求成功,URL 在线可访问。如果状态码为 200,则返回 True ,表示 URL 在线可访问;否则返回 False 。
return response.getcode() == 200 # check if exists online
# 如果 check 参数为 False 或者 URL 检查通过,则返回 True ,表示 URL 是有效的。
return True
# 如果在执行过程中发生异常(例如 URL 格式不正确或网络请求失败),则返回 False ,表示 URL 无效或无法访问。
return False
# 这个 is_url 函数的作用是检查给定的字符串是否是一个有效的 URL,并可选地检查该 URL 是否在线可访问。它首先解析 URL 字符串,检查是否包含有效的协议和网络位置。如果 check 参数为 True ,则发送 HTTP 请求检查 URL 是否在线可访问。函数使用了异常抑制来简化错误处理,确保在任何情况下都能返回一个布尔值表示 URL 的有效性。这个函数可以用于验证用户输入的 URL 或在程序中处理 URL 时确保其有效性。
3.def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
python
# 这段代码定义了一个名为 delete_dsstore 的函数,用于删除指定路径下特定文件。
# 定义函数 delete_dsstore ,它有两个参数。
# 1.path :要删除文件的目录路径。
# 2.files_to_delete :一个元组,默认包含两个元素 .DS_Store 和 __MACOSX ,这两个文件通常是Mac系统在文件传输等过程中生成的,一般不需要在其他系统或场景中保留。
def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
# 删除指定目录下的所有".DS_store"文件。
# 注意:
# ".DS_store"文件由 Apple 操作系统创建,包含有关文件夹和文件的元数据。它们是隐藏的系统文件,在不同操作系统之间传输文件时可能会导致问题。
"""
Deletes all ".DS_store" files under a specified directory.
Args:
path (str, optional): The directory path where the ".DS_store" files should be deleted.
files_to_delete (tuple): The files to be deleted.
Example:
```python
from ultralytics.utils.downloads import delete_dsstore
delete_dsstore('path/to/dir')
```
Note:
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
are hidden system files and can cause issues when transferring files between different operating systems.
"""
# 遍历 files_to_delete 元组中的每个文件名。
for file in files_to_delete:
# Path.rglob(pattern)
# rglob() 是 Python pathlib 模块中 Path 类的一个方法,用于递归地搜索与给定模式匹配的所有文件路径。这个方法会遍历给定路径下的所有子目录,寻找匹配指定模式的文件。
# 参数 :
# pattern :一个字符串,表示要匹配的文件名模式。这个模式遵循 Unix shell 的规则,其中 * 匹配任意数量的字符(除了路径分隔符),而 ** 用于表示任意深度的目录。
# 返回值 :
# 返回一个生成器(generator),生成所有匹配模式的 Path 对象。
# rglob() 方法是递归的,因此它会搜索所有子目录,而不仅仅是当前目录。这使得它非常适合于在大型项目中查找特定类型的文件。
# 使用 pathlib 模块中的 Path 类的 rglob 方法,递归查找 path 路径下所有匹配当前 file 名称的文件,将结果转换为列表存储在 matches 变量中。
matches = list(Path(path).rglob(file))
# 通过日志记录器 LOGGER ,记录正在删除的文件名及这些文件的路径列表,方便跟踪操作过程。
LOGGER.info(f"Deleting {file} files: {matches}") # 删除 {file} 个文件:{matches} 。
# 遍历 matches 列表中的每个文件路径对象。
for f in matches:
# Path.unlink(missing_ok=False)
# unlink() 函数是 pathlib 模块中 Path 类的一个方法,用于删除文件系统中的一个文件。
# Path : 这是 pathlib 模块中的 Path 类,用于表示文件系统路径。
# unlink() : 这是 Path 类的方法,用于删除路径所指向的文件。
# 参数 :
# missing_ok : 这是一个可选参数,默认值为 False 。如果设置为 True ,则在文件不存在时不会抛出异常,而是静默地忽略这个错误。
# 功能 :
# unlink() 方法用于删除文件系统中的一个文件。如果文件不存在,并且 missing_ok 参数为 False (默认值),则会引发一个 FileNotFoundError 。
# 注意事项 :
# 使用 unlink() 方法时要小心,因为一旦文件被删除,就无法恢复。
# 确保在删除文件之前有适当的错误处理和文件存在性检查,除非你确定文件存在,或者你不在乎文件是否实际存在。
# 在多线程或多进程环境中,文件可能会在不同的执行线程或进程中被访问或修改,因此在使用 unlink() 时要特别注意同步和竞态条件。
# 调用文件路径对象的 unlink 方法,删除该文件。
f.unlink()
# 这段代码通过定义 delete_dsstore 函数,实现了在指定路径下递归查找并删除 .DS_Store 和 __MACOSX 等特定文件的功能,同时利用日志记录了删除操作的详细信息,方便调试和追踪。
4.def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
python
# 这段代码定义了一个名为 zip_directory 的函数,用于将指定目录下的文件压缩成一个zip文件。
# 定义函数 zip_directory ,它有四个参数。
# 1.directory :要压缩的目录路径。
# 2.compress :一个布尔值,默认为 True ,表示是否对文件进行压缩。
# 3.exclude :一个元组,默认包含 .DS_Store 和 __MACOSX ,表示在压缩过程中需要排除的文件名。
# 4.progress :一个布尔值,默认为 True ,表示是否显示压缩进度条。
def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
# 压缩目录内容,排除包含排除列表中的字符串的文件。生成的 zip 文件以目录命名并放在旁边。
"""
Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
named after the directory and placed alongside it.
Args:
directory (str | Path): The path to the directory to be zipped.
compress (bool): Whether to compress the files while zipping. Default is True.
exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
progress (bool, optional): Whether to display a progress bar. Defaults to True.
Returns:
(Path): The path to the resulting zip file.
Example:
```python
from ultralytics.utils.downloads import zip_directory
file = zip_directory('path/to/dir')
```
"""
# 从 zipfile 模块导入常量 ZIP_DEFLATED 和 ZIP_STORED 以及 ZipFile 类。 ZIP_DEFLATED 表示使用压缩算法, ZIP_STORED 表示不压缩直接存储。
from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
# 调用前面定义的 delete_dsstore 函数,删除 directory 目录下的 .DS_Store 和 __MACOSX 文件。
delete_dsstore(directory)
# 将 directory 参数转换为 pathlib 模块中的 Path 对象,方便后续操作。
directory = Path(directory)
# 判断 directory 是否是一个目录。
if not directory.is_dir():
# 如果 directory 不是目录,则抛出 FileNotFoundError 异常,提示目录不存在。
raise FileNotFoundError(f"Directory '{directory}' does not exist.") # 目录"{directory}"不存在。
# Unzip with progress bar
# 使用列表推导式,递归查找 directory 目录下所有文件( rglob("*") ),并且过滤掉文件名中包含 exclude 元组中任一元素的文件,得到需要压缩的文件列表。
files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
# 将 directory 路径的后缀改为 .zip ,得到压缩后的zip文件路径。
zip_file = directory.with_suffix(".zip")
# 根据 compress 参数的值,选择压缩算法。如果 compress 为 True ,则使用 ZIP_DEFLATED 压缩算法;否则使用 ZIP_STORED 不压缩。
compression = ZIP_DEFLATED if compress else ZIP_STORED
# ZipFile(file, mode='r', compression=ZIP_STORED, allowZip64=True, compresslevel=None, *, strict_timestamps=True)
# ZipFile 类是Python标准库 zipfile 模块中的一个类,用于读取和写入ZIP文件。
# 参数解释 :
# file :可以是一个文件名(字符串),也可以是一个类文件对象(如 io.BytesIO )。如果是一个文件名, ZipFile 将打开这个文件进行读写操作;如果是一个类文件对象, ZipFile 将使用这个对象进行操作。
# mode :指定打开文件的模式,默认为 'r' ,表示只读模式。其他常用模式包括 : 'w' 写入模式,如果文件已存在,将被覆盖。 'a' 追加模式,如果文件已存在,将在文件末尾追加内容。 'r+' 读写模式,文件必须已存在。
# compression :指定压缩算法,默认为 ZIP_STORED ,表示不压缩。常用的压缩算法包括 : ZIP_STORED 不压缩,直接存储文件。 ZIP_DEFLATED 使用DEFLATE算法压缩文件,这是最常用的压缩算法。 ZIP_BZIP2 使用BZIP2算法压缩文件。 ZIP_LZMA 使用LZMA算法压缩文件。
# allowZip64 :布尔值,默认为 True ,表示是否允许使用ZIP64扩展。ZIP64扩展允许创建和读取大于4GB的ZIP文件。
# compresslevel :指定压缩级别,仅在使用 ZIP_DEFLATED 、 ZIP_BZIP2 或 ZIP_LZMA 压缩算法时有效。取值范围通常为0(无压缩)到9(最高压缩)。
# strict_timestamps :布尔值,默认为 True ,表示是否严格处理文件时间戳。如果为 True ,将确保文件时间戳在ZIP文件中准确表示。
# 常用方法 :
# write(filename, arcname=None, compress_type=None, compresslevel=None) :将文件 filename 写入ZIP文件中。 arcname 可以指定文件在ZIP文件中的名称, compress_type 可以指定压缩算法, compresslevel 可以指定压缩级别。
# writestr(zinfo_or_arcname, data, compress_type=None, compresslevel=None) :将字符串 data 写入ZIP文件中。 zinfo_or_arcname 可以是一个 ZipInfo 对象或文件名, compress_type 和 compresslevel 的含义与 write 方法相同。
# extract(member, path=None, pwd=None) :从ZIP文件中提取一个成员文件到指定路径 path 。 member 可以是一个文件名或 ZipInfo 对象, pwd 可以指定解压密码。
# extractall(path=None, members=None, pwd=None) :从ZIP文件中提取所有成员文件到指定路径 path 。 members 可以是一个文件名列表, pwd 可以指定解压密码。
# close() :关闭ZIP文件,释放资源。
# ZipFile 类提供了丰富的功能来处理ZIP文件,包括创建、写入、读取和提取文件。通过合理使用其参数和方法,可以方便地进行文件的压缩和解压操作。
# 使用 with 语句创建一个 ZipFile 对象,以写入模式打开 zip_file 文件,并设置压缩算法。这样可以确保文件操作完成后自动关闭文件。
with ZipFile(zip_file, "w", compression) as f:
# 遍历 files_to_zip 列表中的每个文件。如果 progress 为 True ,则使用 tqdm 库的 TQDM 函数显示进度条,描述信息为正在将 directory 压缩到 zip_file ,单位为文件。
for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): # 正在将 {directory} 压缩至 {zip_file}...
# 将当前文件 file 写入zip文件中,并设置其在zip文件中的相对路径为 file 相对于 directory 的路径。
f.write(file, file.relative_to(directory))
# 返回压缩后的zip文件路径。
return zip_file # return path to zip file
# 这段代码通过定义 zip_directory 函数,实现了将指定目录下的文件(排除特定文件)压缩成一个zip文件的功能,并且可以根据需要选择是否压缩以及是否显示进度条。在压缩前,还会先删除目录下的 .DS_Store 和 __MACOSX 文件。
5.def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
python
# 这段代码定义了一个名为 unzip_file 的函数,用于解压ZIP文件到指定路径,并提供了多种选项来控制解压过程。
# 定义函数 unzip_file ,它有五个参数。
# 1.file :要解压的ZIP文件路径。
# 2.path :解压目标路径,默认为 None ,表示使用ZIP文件所在目录。
# 3.exclude :一个元组,包含需要排除的文件名,默认为 .DS_Store 和 __MACOSX 。
# 4.exist_ok :布尔值,默认为 False ,表示如果目标目录已存在且不为空,是否跳过解压。
# 5.progress :布尔值,默认为 True ,表示是否显示解压进度条。
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
# 将 *.zip 文件解压到指定路径,排除包含排除列表中的字符串的文件。
# 如果 zip 文件不包含单个顶级目录,则该函数将创建一个新的。
# 目录,其名称与 zip 文件相同(不带扩展名)以提取其内容。
# 如果未提供路径,则该函数将使用 zip 文件的父目录作为默认路径。
# 引发:
# BadZipFile:如果提供的文件不存在或不是有效的 zip 文件。
"""
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
If the zipfile does not contain a single top-level directory, the function will create a new
directory with the same name as the zipfile (without the extension) to extract its contents.
If a path is not provided, the function will use the parent directory of the zipfile as the default path.
Args:
file (str): The path to the zipfile to be extracted.
path (str, optional): The path to extract the zipfile to. Defaults to None.
exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
progress (bool, optional): Whether to display a progress bar. Defaults to True.
Raises:
BadZipFile: If the provided file does not exist or is not a valid zipfile.
Returns:
(Path): The path to the directory where the zipfile was extracted.
Example:
```python
from ultralytics.utils.downloads import unzip_file
dir = unzip_file('path/to/file.zip')
```
"""
# 从 zipfile 模块导入 BadZipFile 、 ZipFile 和 is_zipfile 函数。
from zipfile import BadZipFile, ZipFile, is_zipfile
# is_zip = zipfile.is_zipfile(filename)
# is_zipfile() 函数是 Python zipfile 模块中的一个函数,用于检查一个文件是否是有效的 ZIP 文件格式。
# 参数 :
# filename : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_zipfile() 函数返回一个布尔值。 如果文件是有效的 ZIP 文件,则返回 True 。 如果文件不是有效的 ZIP 文件或文件不存在,则返回 False 。
# is_zipfile() 函数的实现依赖于文件的"魔术数字"(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。ZIP 文件的魔术数字是 PK ( 0x50 0x4B ),这个序列出现在所有 ZIP 文件的开头。
# 如果一个文件以这个序列开头, is_zipfile() 函数就会返回 True ,表明该文件是一个 ZIP 文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# 检查 file 是否存在且是一个有效的ZIP文件。
if not (Path(file).exists() and is_zipfile(file)):
# 如果文件不存在或不是一个有效的ZIP文件,抛出 BadZipFile 异常。
raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") # 文件"{file}"不存在或是一个坏的 zip 文件。
# 如果 path 为 None ,则使用ZIP文件的父目录作为默认解压路径。
if path is None:
# 设置默认解压路径为ZIP文件的父目录。
path = Path(file).parent # default path
# Unzip the file contents
# 使用 with 语句打开ZIP文件,创建一个 ZipFile 对象 zipObj ,确保文件操作完成后自动关闭文件。
with ZipFile(file) as zipObj:
# 使用列表推导式,从ZIP文件中获取所有文件名,并排除 exclude 元组中指定的文件。
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
# 获取ZIP文件中所有文件的顶级目录名称,存储在集合 top_level_dirs 中。
top_level_dirs = {Path(f).parts[0] for f in files}
# 判断ZIP文件中是否有多个顶级目录或多个文件在顶级目录中。
if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
# Zip has multiple files at top level
# 如果有多个顶级目录,设置 解压路径 为 path 加上ZIP文件的名称(不包括扩展名)。
path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
# 如果只有一个顶级目录。
else:
# Zip has 1 top-level directory
# 设置解压路径为 path 。
extract_path = path # i.e. ../datasets
# 设置最终路径为 path 加上顶级目录名称。
path = Path(path) / list(top_level_dirs)[0] # i.e. ../datasets/coco8
# Check if destination directory already exists and contains files
# path.iterdir()
# path.iterdir() 是 Python pathlib 模块中 Path 类的一个方法,它用于遍历指定路径下的目录内容。这个方法返回一个迭代器,该迭代器产生路径下的所有文件和子目录的 Path 对象。
# 参数 :
# path :一个 Path 对象,表示你想要遍历的目录。
# 返回值 :
# 返回一个迭代器,产生路径下每个文件和子目录的 Path 对象。
# iterdir() 方法是处理文件系统时非常有用的工具,它提供了一种简洁的方式来访问目录内容,并且能够以面向对象的方式操作路径。
# 检查 目标目录 是否已存在且不为空,且 exist_ok 为 False 。
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
# 如果目标目录已存在且不为空,记录警告信息并返回路径,不进行解压。
LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") # 警告 ⚠️ 跳过 {file} 解压缩,因为目标目录 {path} 不为空。
# 返回目标目录路径。
return path
# 遍历需要解压的文件列表,显示解压进度条。
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): # 正在将 {file} 解压至 {Path(path).resolve()}...
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
# 检查文件路径中是否包含 .. ,以防止路径遍历安全漏洞。
if ".." in Path(f).parts:
# 如果路径不安全,记录警告信息并跳过解压。
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") # 可能不安全的文件路径:{f},跳过提取。
# 跳过当前文件的解压。
continue
# 将文件 f 从ZIP文件中提取到 extract_path 目录。
zipObj.extract(f, extract_path)
# 返回解压后的目录路径。
return path # return unzip dir
# 这段代码通过定义 unzip_file 函数,实现了将ZIP文件解压到指定路径的功能。它提供了多种选项,如排除特定文件、处理目标目录已存在的情况、显示解压进度条等。此外,代码还进行了安全检查,防止路径遍历攻击。
6.def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
python
# 这段代码定义了一个名为 check_disk_space 的函数,用于检查下载文件前磁盘空间是否充足。
# 定义了一个名为 check_disk_space 的函数,接受四个参数。
# 1.url :文件的URL。
# 2.path :检查磁盘空间的路径,默认为当前工作目录。
# 3.sf :安全因子,默认为1.5。
# 4.hard :布尔值,指示是否抛出异常,默认为 True 。
def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", path=Path.cwd(), sf=1.5, hard=True):
# 检查是否有足够的磁盘空间来下载和存储文件。
"""
Check if there is sufficient disk space to download and store a file.
Args:
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
path (str | Path, optional): The path or drive to check the available free space on.
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0.
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
Returns:
(bool): True if there is sufficient disk space, False otherwise.
"""
# 开始尝试执行块。
try:
# 使用 requests 库发送HEAD请求到提供的URL,以获取文件的大小。
r = requests.head(url) # response
# 断言响应状态码小于400,即请求成功。如果失败,抛出异常。
assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response {url} 的 URL 错误:{r.status_code} {r.reason}。
# 如果HEAD请求失败,捕获异常。
except Exception:
# 如果请求失败,函数返回 True ,表示不进行磁盘空间检查。
return True # requests issue, default to True
# Check file size
# 定义1 GiB的字节数。
gib = 1 << 30 # bytes per GiB
# 从响应头中获取文件大小(Content-Length),并转换为GB。
data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
# shutil.disk_usage(path)
# shutil.disk_usage(path) 是 Python 标准库 shutil 模块中的一个函数,用于获取指定路径的磁盘使用情况统计信息。
# 参数 :
# path :一个表示文件系统路径的字符串或路径对象。在 Windows 上,这个路径必须代表一个目录;在 Unix 系统上,它可以是文件或目录。
# 返回值 :
# 该函数返回一个命名元组(namedtuple),包含以下属性 :
# total :表示文件系统的总空间量,单位为字节。
# used :表示已使用的磁盘空间量,单位为字节。
# free :表示可用的磁盘空间量,单位为字节。
# 注意事项 :
# 在 Unix 文件系统中, path 必须指向一个已挂载文件系统分区中的路径。在这些平台上,CPython 不会尝试从未挂载的文件系统中获取磁盘使用信息。
# 从 Python 3.8 开始,在 Windows 上, path 可以是一个文件或目录。 shutil.disk_usage() 提供的信息可以帮助开发者监控磁盘空间使用情况,或者在需要时优化磁盘使用。
# 使用 shutil.disk_usage 函数获取指定路径的磁盘使用情况,并转换为GB。
total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes
# 检查文件大小乘以安全因子是否小于可用磁盘空间。
if data * sf < free:
# 如果磁盘空间充足,返回 True 。
return True # sufficient space
# Insufficient space
# 构造一个警告信息,说明磁盘空间不足。
text = (
f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
f"Please free {data * sf - free:.1f} GB additional disk space and try again." # 警告 ⚠️ 可用磁盘空间不足 {free:.1f} GB < {data * sf:.3f} GB required,请释放 {data * sf - free:.1f} GB 额外磁盘空间并重试。
)
# 如果 hard 为 True ,执行以下操作。
if hard:
# 抛出 MemoryError 异常,指示磁盘空间不足。
raise MemoryError(text)
# 记录磁盘空间不足的警告信息。
LOGGER.warning(text)
# 返回 False ,表示磁盘空间不足。
return False
# check_disk_space 函数用于在下载文件前检查磁盘空间是否足够。它首先通过HEAD请求获取文件大小,然后检查指定路径的可用磁盘空间是否满足文件大小乘以安全因子。如果磁盘空间不足,函数会根据 hard 参数决定是抛出异常还是记录警告信息。这个函数有助于避免因磁盘空间不足而导致的下载失败。
7.def get_google_drive_file_info(link):
python
# 这段代码定义了一个名为 get_google_drive_file_info 的函数,用于从Google Drive分享链接中提取文件信息,包括文件ID和文件名。
# 定义了一个名为 get_google_drive_file_info 的函数,它接受一个参数。
# 1.link :即Google Drive的分享链接。
def get_google_drive_file_info(link):
# 检索可共享 Google Drive 文件链接的直接下载链接和文件名。
"""
Retrieves the direct download link and filename for a shareable Google Drive file link.
Args:
link (str): The shareable link of the Google Drive file.
Returns:
(str): Direct download URL for the Google Drive file.
(str): Original filename of the Google Drive file. If filename extraction fails, returns None.
Example:
```python
from ultralytics.utils.downloads import get_google_drive_file_info
link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
url, filename = get_google_drive_file_info(link)
```
"""
# 从链接中提取文件ID。首先以 "/d/" 为分隔符分割链接,然后取第二部分;再以 "/view" 为分隔符分割,取第一部分,即文件ID。
file_id = link.split("/d/")[1].split("/view")[0]
# 构造用于下载文件的Google Drive URL,使用提取到的文件ID。
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
# 初始化 filename 变量,用于存储文件名。
filename = None
# Start session
# 使用 requests 库创建一个会话对象,用于发送HTTP请求。
with requests.Session() as session:
# 发送GET请求到构造的下载URL,并设置 stream=True 以流式传输数据。
response = session.get(drive_url, stream=True)
# 检查响应内容中是否包含"quota exceeded",即检查是否超出了下载配额。
if "quota exceeded" in str(response.content.lower()):
# 如果超出配额,抛出 ConnectionError 异常,并提示用户稍后再试或手动下载文件。
raise ConnectionError(
emojis(
f"❌ Google Drive file download quota exceeded. " # ❌ Google Drive 文件下载配额已超出。
f"Please try again later or download this file manually at {link}." # 请稍后重试或从 {link} 手动下载此文件。
)
)
# 遍历响应中的cookies。
for k, v in response.cookies.items():
# 检查cookie的键是否以"download_warning"开头。
if k.startswith("download_warning"):
# 如果找到对应的cookie,则将确认令牌添加到下载URL中。
drive_url += f"&confirm={v}" # v is token
# 从响应头中获取 content-disposition 字段,该字段包含文件名信息。
cd = response.headers.get("content-disposition")
# 如果 content-disposition 存在,执行以下操作。
if cd:
# 使用正则表达式从 content-disposition 中提取文件名。
filename = re.findall('filename="(.+)"', cd)[0]
# 返回更新后的 下载URL 和 文件名 。
return drive_url, filename
# get_google_drive_file_info 函数的作用是处理Google Drive的分享链接,提取文件ID和文件名。它首先构造一个用于下载文件的URL,然后通过发送GET请求并检查响应来获取文件名。如果响应中包含下载配额超出的信息,则抛出异常。最后,函数返回更新后的下载URL和文件名,这些信息可以用于后续的文件下载操作。这个函数对于处理Google Drive链接并准备下载文件非常有用。
8.def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,):
python
# 这段代码定义了一个名为 safe_download 的函数,用于安全地下载文件,并在需要时解压。
# 定义了一个名为 safe_download 的函数,接受多个参数用于控制下载行为。
# 1.url :需要下载的文件的URL。
# 2.file :下载文件的目标文件名。如果为 None ,则从URL中提取文件名。
# 3.dir :下载文件的目标目录。如果为 None ,则使用当前工作目录。
# 4.unzip :指示是否在下载后解压文件(如果文件是压缩格式)。
# 5.delete :指示是否在解压后删除原始压缩文件。
# 6.curl :指示是否使用 curl 命令行工具进行下载,而不是Python的 urllib 库。
# 7.retry :下载失败时重试的次数。
# 8.min_bytes :下载文件的最小字节数,用于检查下载是否成功。
# 9.exist_ok :如果目标文件已存在,设置为 True 则不会覆盖现有文件。
# 10.progress :指示是否显示下载进度条。
def safe_download(
url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1e0,
exist_ok=False,
progress=True,
):
# 从 URL 下载文件,带有重试、解压和删除已下载文件的选项。
"""
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url (str): The URL of the file to be downloaded.
file (str, optional): The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir (str, optional): The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
Example:
```python
from ultralytics.utils.downloads import safe_download
link = "https://ultralytics.com/assets/bus.jpg"
path = safe_download(link)
```
"""
# 检查URL是否是Google Drive链接。
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
# 如果URL是Google Drive链接,执行以下操作。
if gdrive:
# 调用 get_google_drive_file_info 函数来获取文件信息。
url, file = get_google_drive_file_info(url)
# 根据提供的 dir 和 file 参数或从URL转换得到的文件名,构造一个路径对象。
# def url2file(url): -> 用于从URL中提取文件名。调用 clean_url 函数,然后使用 Path 类(来自 pathlib 模块)获取清理后的URL的路径对象,最后返回该路径的基本文件名。 -> return Path(clean_url(url)).name
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
# 检查URL是否指向一个本地文件。
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
# 如果URL指向本地文件,更新 f 为该本地路径。
f = Path(url) # filename
# 如果URL和文件都不存在,执行以下操作。
elif not f.is_file(): # URL and file do not exist
# 构造一个描述下载行为的字符串。
# def clean_url(url): -> 用于清理和规范化URL。使用 split("?") 将URL分割为查询参数前的部分和查询参数部分,取第一部分,即不包含查询参数的URL。 -> return urllib.parse.unquote(url).split("?")[0]
desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'" # 正在将 {url if gdrive else clean_url(url)} 下载至"{f}"。
# 记录下载开始的日志信息。
LOGGER.info(f"{desc}...")
# 确保文件的父目录存在。
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
# 检查磁盘空间是否足够。
check_disk_space(url, path=f.parent)
# 进行最多 retry 次的下载尝试。
for i in range(retry + 1):
# 尝试下载文件。
try:
# 如果使用curl或已经是重试,执行以下操作。
if curl or i > 0: # curl download with retry, continue
# 根据 progress 参数设置curl的静默模式。
s = "sS" * (not progress) # silent
# 使用curl命令下载文件。
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
# 确保curl命令成功执行。
assert r == 0, f"Curl return value {r}" # Curl 返回值 {r}。
# 如果不使用curl,使用urllib下载。
else: # urllib download
# 设置下载方法为torch。
method = "torch"
# 如果使用torch方法,执行以下操作。
if method == "torch":
# 使用torch的函数下载文件。
torch.hub.download_url_to_file(url, f, progress=progress)
# 如果不使用torch方法,执行以下操作。
else:
# 使用urllib打开URL并创建一个进度条。
with request.urlopen(url) as response, TQDM(
total=int(response.getheader("Content-Length", 0)),
desc=desc,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
# 以二进制写模式打开文件。
with open(f, "wb") as f_opened:
# 遍历响应数据。
for data in response:
# 写入文件。
f_opened.write(data)
# 更新进度条。
pbar.update(len(data))
# 检查文件是否存在。
if f.exists():
# 检查文件大小是否大于最小字节数。
if f.stat().st_size > min_bytes:
# 如果成功,跳出循环。
break # success
# 如果下载不成功,删除部分下载的文件。
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not is_online():
# def emojis(string=""): -> 处理字符串中的 emoji 字符。 -> return string.encode().decode("ascii", "ignore") if WINDOWS else string
raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e # ❌ {url} 下载失败。环境不在线。
elif i >= retry:
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e # ❌ {url} 下载失败。已达到重试限制。
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...") # ⚠️ 下载失败,重试 {i + 1}/{retry} {url}...
# 如果需要解压且文件存在且后缀为特定值,执行以下操作。
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
# 导入is_zipfile函数。
from zipfile import is_zipfile
# 设置解压目录。
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
# is_zip = zipfile.is_zipfile(filename)
# is_zipfile() 函数是 Python zipfile 模块中的一个函数,用于检查一个文件是否是有效的 ZIP 文件格式。
# 参数 :
# filename : 要检查的文件的路径,可以是字符串、文件对象或路径对象。
# 返回值 :
# is_zipfile() 函数返回一个布尔值。 如果文件是有效的 ZIP 文件,则返回 True 。 如果文件不是有效的 ZIP 文件或文件不存在,则返回 False 。
# is_zipfile() 函数的实现依赖于文件的"魔术数字"(文件开头的字节序列),这是许多文件格式用来标识自己的一种方式。ZIP 文件的魔术数字是 PK ( 0x50 0x4B ),这个序列出现在所有 ZIP 文件的开头。
# 如果一个文件以这个序列开头, is_zipfile() 函数就会返回 True ,表明该文件是一个 ZIP 文件。这个函数在处理文件上传、归档和解压缩任务时非常有用,因为它可以帮助程序确定如何处理特定的文件。
# 如果文件是zip文件,执行以下操作。
if is_zipfile(f):
# 调用 unzip_file 函数解压文件。
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
# 如果文件是tar或gz文件,执行以下操作。
elif f.suffix in (".tar", ".gz"):
LOGGER.info(f"Unzipping {f} to {unzip_dir}...") # 正在将 {f} 解压缩至 {unzip_dir}...
# 使用tar命令解压文件。
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
# 如果需要删除压缩文件,执行以下操作。
if delete:
# 删除压缩文件。
f.unlink() # remove zip
# 返回解压目录。
return unzip_dir
# safe_download 函数用于安全地下载文件,并在需要时解压。它首先检查URL是否指向一个Google Drive链接,然后构造文件路径,检查文件是否存在,如果不存在则尝试下载。下载过程中,它会检查磁盘空间,最多重试 retry 次,并记录日志信息。如果下载成功,它会检查文件大小是否大于最小字节数。如果需要解压,它会根据文件后缀调用相应的解压函数。最后,如果需要,它会删除压缩文件,并返回解压目录。这个函数提供了一个健壮的下载和解压机制,适用于多种文件类型和下载场景。
# def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,):
# -> 用于安全地下载文件,并在需要时解压。返回解压目录。
# -> return unzip_dir
9.def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
python
# 这段代码定义了一个名为 get_github_assets 的函数,它用于从 GitHub API 获取指定仓库的资产信息。
# 这行代码定义了 get_github_assets 函数,它接受以下参数 :
# 1.repo : GitHub 仓库名称,默认为 "ultralytics/assets" 。
# 2.version : 版本标签,默认为 "latest" 。
# 3.retry : 是否重试请求,默认为 False 。
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
# 从 GitHub 存储库检索指定版本的标签和资产。如果未指定版本,该函数将获取最新版本的资产。
"""
Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
function fetches the latest release assets.
Args:
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
version (str, optional): The release version to fetch assets from. Defaults to 'latest'.
retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False.
Returns:
(tuple): A tuple containing the release tag and a list of asset names.
Example:
```python
tag, assets = get_github_assets(repo='ultralytics/assets', version='latest')
```
"""
# 检查版本是否不是 "latest" 。
if version != "latest":
# 如果版本不是 "latest" ,将版本转换为 GitHub API 所需的格式,例如 tags/v6.2 。
version = f"tags/{version}" # i.e. tags/v6.2
# 构建请求 URL。
url = f"https://api.github.com/repos/{repo}/releases/{version}"
# 使用 requests 库发送 GET 请求到 GitHub API。
r = requests.get(url) # github api
# 检查响应状态码是否不是 200,且原因不是 "rate limit exceeded" 并且 retry 参数为 True 。
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded 403 超出速率限制。
# 如果满足重试条件,再次发送请求。
r = requests.get(url) # try again
# 检查响应状态码是否不是 200。
if r.status_code != 200:
# 如果状态码不是 200,记录一条警告日志,提示 GitHub 资产检查失败。
LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") # ⚠️ GitHub 资产检查 {url} 失败:{r.status_code} {r.reason}。
# 如果请求失败,返回空字符串和空列表。
return "", []
# 如果请求成功,将响应内容解析为 JSON 数据。
data = r.json()
# 提取标签名称和资产列表,并返回它们。
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
# get_github_assets 函数用于从 GitHub API 获取指定仓库和版本的资产信息。它处理了版本格式的转换、请求发送、重试逻辑以及错误处理。如果请求成功,函数返回标签名称和资产列表;如果请求失败,函数返回空字符串和空列表,并记录警告日志。这个函数为获取 GitHub 资产信息提供了一个健壮的方法。
10.def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
python
# 这段代码定义了一个名为 attempt_download_asset 的函数,其目的是尝试下载一个指定的资产文件(如模型权重文件)。这个函数处理从 GitHub releases 页面或其他 URL 下载文件的逻辑。
# 这行代码定义了 attempt_download_asset 函数,它接受以下参数 :
# 1.file : 要下载的文件名或 URL。
# 2.repo : GitHub 仓库名称,默认为 "ultralytics/assets" 。
# 3.release : GitHub 释放版本标签,默认为 "v8.1.0" 。
# 4.**kwargs : 其他关键字参数,用于传递给下载函数。
def attempt_download_asset(file, repo="ultralytics/assets", release="v8.1.0", **kwargs):
# 如果在本地找不到文件,则尝试从 GitHub 发布资产中下载文件。该函数首先在本地检查文件,然后尝试从指定的 GitHub 存储库发布中下载。
"""
Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
locally first, then tries to download it from the specified GitHub repository release.
Args:
file (str | Path): The filename or file path to be downloaded.
repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'.
release (str, optional): The specific release version to be downloaded. Defaults to 'v8.1.0'.
**kwargs (any): Additional keyword arguments for the download process.
Returns:
(str): The path to the downloaded file.
Example:
```python
file_path = attempt_download_asset('yolov5s.pt', repo='ultralytics/assets', release='latest')
```
"""
# 从 ultralytics.utils 模块导入 SETTINGS ,这是一个存储配置设置的模块。
from ultralytics.utils import SETTINGS # scoped for circular import
# YOLOv3/5u updates
# 确保 file 参数是字符串类型。
file = str(file)
# 调用 checks 模块中的 check_yolov5u_filename 函数来验证和标准化 YOLOv5u 的文件名。
# def check_yolov5u_filename(file: str, verbose: bool = True): -> 用于检查和修改包含YOLOv5模型文件名的字符串,确保文件名符合特定的命名约定。返回修改后的文件名。 -> return file
file = checks.check_yolov5u_filename(file)
# 去除文件名字符串两端的空白字符和单引号,然后转换为 Path 对象。
file = Path(file.strip().replace("'", ""))
# 检查本地是否存在该文件。
if file.exists():
# 如果文件存在,返回文件的路径字符串。
return str(file)
# SETTINGS -> 初始化一个 SettingsManager 实例,这个类是用来管理项目设置的。
# 检查在配置设置中指定的权重目录下是否存在该文件。
elif (SETTINGS["weights_dir"] / file).exists():
# 如果文件在权重目录下存在,返回文件的路径字符串。
return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
# parse.unquote(s, encoding='utf-8', errors='replace')
# parse.unquote() 是 Python 标准库 urllib.parse 模块中的一个函数,用于对 URL 编码的字符串进行解码,将百分号编码(%XX)转换回普通字符。
# 参数 :
# s :要解码的 URL 编码字符串。
# encoding :(可选)用于解码的字符编码,默认为 'utf-8' 。
# errors :(可选)指定如何处理解码错误,默认为 'replace' ,意味着将无法解码的字符替换为一个替代字符(通常是 ? )。
# 返回值 :
# 返回解码后的字符串。
# 函数逻辑 :
# 解码百分号编码 :将字符串中的 %XX 序列转换为对应的字符。
# 字符编码转换 :将原始的百分比编码字符串(通常为 ASCII)转换为指定的编码。
# 在例子中, unquote 函数将 URL 编码的字符串 "Hello%2C%20World%21" 解码为普通字符串 "Hello, World!" 。
# 注意事项 :
# 当处理来自用户的 URL 编码数据时,使用 unquote 函数可以确保正确地解释这些数据。
# 如果 URL 包含非 ASCII 字符,确保指定正确的 encoding 参数,否则解码可能会失败或产生意外结果。
# 如果遇到无法解码的百分号序列, errors 参数决定了如何处理这些错误。常见的选项包括 'strict' (抛出 UnicodeDecodeError )、 'replace' (用替代字符替换无法解码的字符)和 'ignore' (忽略无法解码的字符)。
# 获取文件名,并解码 URL 编码的字符。
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
# 构建 GitHub releases 页面的下载 URL。
download_url = f"https://github.com/{repo}/releases/download"
# 检查文件路径是否以 http:/ 或 https:/ 开头。
if str(file).startswith(("http:/", "https:/")): # download
# 修复 URL 格式,将 :/ 替换为正确的 :// 。
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
# 调用 url2file 函数将 URL 转换为文件名。
# def url2file(url): -> 用于从URL中提取文件名。调用 clean_url 函数,然后使用 Path 类(来自 pathlib 模块)获取清理后的URL的路径对象,最后返回该路径的基本文件名。 -> return Path(clean_url(url)).name
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
# 检查转换后的文件名是否指向一个已存在的文件。
if Path(file).is_file():
# 如果文件已存在,记录一条信息日志。
# def clean_url(url): -> 用于清理和规范化URL。使用 split("?") 将URL分割为查询参数前的部分和查询参数部分,取第一部分,即不包含查询参数的URL。 -> return urllib.parse.unquote(url).split("?")[0]
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists 在本地的 {file} 处找到了 {clean_url(url)}。
# 如果文件不存在,执行下载操作。
else:
# 调用 safe_download 函数下载文件。
# def safe_download(url, file=None, dir=None, unzip=True, delete=False, curl=False, retry=3, min_bytes=1e0, exist_ok=False, progress=True,):
# -> 用于安全地下载文件,并在需要时解压。返回解压目录。
# -> return unzip_dir
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
# 检查仓库是否是特定的 GitHub 资产仓库,并且文件名是否在已知资产列表中。
# GITHUB_ASSETS_REPO -> 定义了一个变量 GITHUB_ASSETS_REPO ,它存储了 Ultralytics 资产的 GitHub 仓库名称。
# GITHUB_ASSETS_NAMES -> 定义了一个变量 GITHUB_ASSETS_NAMES ,它是一个元组,包含了 Ultralytics 仓库中所有资产的文件名。
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
# 如果是,使用 releases 页面的 URL 下载文件。
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
# 如果文件不是特定 GitHub 资产仓库的一部分,尝试获取最新的 GitHub 资产。
else:
# 调用 get_github_assets 函数获取仓库的资产和标签。
# def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
# -> 用于从 GitHub API 获取指定仓库的资产信息。如果状态码不是 200,记录一条警告日志,提示 GitHub 资产检查失败。如果请求成功,将响应内容解析为 JSON 数据。提取标签名称和资产列表,并返回它们。
# -> return "", [] / return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
tag, assets = get_github_assets(repo, release)
# 如果没有找到资产,尝试获取最新释放版本的资产。
if not assets:
# 调用 get_github_assets 函数获取最新释放版本的资产和标签。
tag, assets = get_github_assets(repo) # latest release
# 检查文件名是否在资产列表中。
if name in assets:
# 如果文件在资产列表中,使用对应的 URL 下载文件。
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
# 返回文件的路径字符串。
return str(file)
# attempt_download_asset 函数是一个下载助手,它尝试从本地路径、GitHub releases 页面或其他 URL 下载指定的资产文件。这个函数处理文件名的标准化、本地文件的检查、URL 的构建和下载逻辑。如果文件已经存在,它会返回文件的路径;如果文件不存在,它会尝试下载文件并返回文件的路径。这个函数通过处理不同的下载场景和错误情况,提高了代码的健壮性。
11.def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
python
# 这段代码定义了一个名为 download 的函数,用于从指定URL下载文件,并提供了多种选项来控制下载过程,包括解压、删除源文件、使用 curl 下载、多线程下载等。
# 定义函数 download ,它有八个参数。
# 1.url :要下载的文件的URL,可以是一个字符串或字符串列表。
# 2.dir :下载目标目录,默认为当前工作目录。
# 3.unzip :布尔值,默认为 True ,表示是否解压下载的文件。
# 4.delete :布尔值,默认为 False ,表示是否删除下载的源文件。
# 5.curl :布尔值,默认为 False ,表示是否使用 curl 命令下载文件。
# 6.threads :整数,默认为 1 ,表示下载时使用的线程数。
# 7.retry :整数,默认为 3 ,表示下载失败时的重试次数。
# 8.exist_ok :布尔值,默认为 False ,表示如果目标目录已存在且不为空,是否跳过下载。
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
# 从指定的 URL 下载文件到给定目录。如果指定多个线程,则支持并发下载。
"""
Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
specified.
Args:
url (str | list): The URL or list of URLs of the files to be downloaded.
dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory.
unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
curl (bool, optional): Flag to use curl for downloading. Defaults to False.
threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
retry (int, optional): Number of retries in case of download failure. Defaults to 3.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
Example:
```python
download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
```
"""
# 将 dir 参数转换为 pathlib 模块中的 Path 对象,方便后续操作。
dir = Path(dir)
# 创建目标目录,如果目录已存在则不报错。
dir.mkdir(parents=True, exist_ok=True) # make directory
# 如果 threads 大于1,使用多线程下载。
if threads > 1:
# pool = ThreadPool(processes=5)
# 在Python中, multiprocessing.pool.ThreadPool 是 multiprocessing 模块中的 pool 子模块提供的类之一,它用于创建一个线程池,以便并行执行多个线程任务。这个类是 multiprocessing 库的一部分,该库提供了一种方式来并行化程序,利用多核处理器的能力。
# 参数 :
# processes :参数指定线程池中的线程数量。
# 返回 :
# pool :创建一个线程池实例。
# 方法 ThreadPool 提供了以下方法 :
# apply_async(func, args=(), kwds={}) :异步地将一个函数 func 应用到 args 和 kwds 参数上,并返回一个 AsyncResult 对象。
# map(func, iterable, chunksize=1) :将一个函数 func 映射到一个迭代器 iterable 的所有元素上,并返回一个迭代器。
# close() :关闭线程池,不再接受新的任务。
# join() :等待线程池中的所有任务完成。
# terminate() :立即终止线程池中的所有任务。
# ThreadPool 是 multiprocessing 库中用于并行执行 I/O 密集型任务的工具,它允许程序利用多核处理器的能力来提高性能。需要注意的是, multiprocessing 库中的进程是系统级的进程,与线程相比,它们有更大的开销,但也能提供更好的并行性能。
# 使用 concurrent.futures.ThreadPoolExecutor 创建一个线程池,线程数为 threads 。
with ThreadPool(threads) as pool:
# 使用 map 方法并行执行下载任务。
pool.map(
# 定义一个匿名函数,调用 safe_download 函数。
lambda x: safe_download(
# 传递URL。
url=x[0],
# 传递目标目录。
dir=x[1],
# 传递是否解压的参数。
unzip=unzip,
# 传递是否删除源文件的参数。
delete=delete,
# 传递是否使用 curl 的参数。
curl=curl,
# 传递重试次数的参数。
retry=retry,
# 传递是否跳过已存在目录的参数。
exist_ok=exist_ok,
# 传递是否显示进度条的参数,如果线程数为1则显示进度条。
progress=threads <= 1,
),
# itertools.repeat(object[, times])
# repeat() 函数是 Python 标准库 itertools 模块中的一个函数,它用于创建一个迭代器,该迭代器会无限次重复给定的值。
# 参数 :
# object :要重复的值。
# times :(可选)重复的次数。如果不提供或为 None ,则迭代器将无限重复给定的值。
# 返回值 :
# 返回一个迭代器,该迭代器重复给定的值。
# repeat() 函数常用于需要固定值的场景,例如在 zip() 函数中为每个元素对提供相同的参数,或者在其他需要重复值的迭代处理中。
# zip(iter1, iter2, ..., iterN)
# zip() 函数是 Python 内置的一个函数,它接受任意数量的可迭代对象(如列表、元组、字符串等)作为参数,并将这些可迭代对象中对应的元素打包成一个个元组,然后返回由这些元组组成的迭代器。当输入的可迭代对象中最短的一个耗尽时, zip() 函数就会停止工作。
# 参数 :
# iter1, iter2, ..., iterN :一个或多个可迭代对象。
# 返回值 :
# 一个迭代器,其元素是元组,每个元组包含来自输入可迭代对象的对应元素。
# zip() 函数在处理并行数据时非常有用,比如当你有两个列表,并且需要将它们的元素成对处理时。此外, zip() 还可以与 dict() 函数结合使用,快速创建字典。
# 将 url 和重复的目标目录 dir 组合成一个元组列表,用于多线程下载。
zip(url, repeat(dir)),
)
# 关闭线程池,不再接受新的任务。
pool.close()
# 等待所有线程完成。
pool.join()
# 如果 threads 为1,使用单线程下载。
else:
# 如果 url 是一个字符串或 Path 对象,将其转换为列表;如果 url 是一个列表,直接遍历。
for u in [url] if isinstance(url, (str, Path)) else url:
# 调用 safe_download 函数进行下载。
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
# 这段代码通过定义 download 函数,实现了从指定URL下载文件的功能,并提供了多种选项来控制下载过程,包括解压、删除源文件、使用 curl 下载、多线程下载等。函数首先创建目标目录,然后根据 threads 参数决定是否使用多线程下载。多线程下载时,使用 ThreadPool 并行执行下载任务;单线程下载时,逐个调用 safe_download 函数进行下载。通过这些参数和逻辑, download 函数可以灵活地处理各种下载场景。