augment.py
ultralytics\data\augment.py
目录
[2.class BaseTransform:](#2.class BaseTransform:)
[3.class Compose:](#3.class Compose:)
[4.class BaseMixTransform:](#4.class BaseMixTransform:)
[5.class Mosaic(BaseMixTransform):](#5.class Mosaic(BaseMixTransform):)
[6.class MixUp(BaseMixTransform):](#6.class MixUp(BaseMixTransform):)
[7.class RandomPerspective:](#7.class RandomPerspective:)
[8.class RandomHSV:](#8.class RandomHSV:)
[9.class RandomFlip:](#9.class RandomFlip:)
[10.class LetterBox:](#10.class LetterBox:)
[11.class CopyPaste:](#11.class CopyPaste:)
[12.class Albumentations:](#12.class Albumentations:)
[13.class Format:](#13.class Format:)
[14.class RandomLoadText:](#14.class RandomLoadText:)
[15.def v8_transforms(dataset, imgsz, hyp, stretch=False):](#15.def v8_transforms(dataset, imgsz, hyp, stretch=False):)
[16.def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION,):](#16.def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION,):)
[17.def classify_augmentations(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, scale=None, ratio=None, hflip=0.5, vflip=0.0, auto_augment=None, hsv_h=0.015, hsv_s=0.4, hsv_v=0.4, force_color_jitter=False, erasing=0.0, interpolation="BILINEAR",):](#17.def classify_augmentations(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, scale=None, ratio=None, hflip=0.5, vflip=0.0, auto_augment=None, hsv_h=0.015, hsv_s=0.4, hsv_v=0.4, force_color_jitter=False, erasing=0.0, interpolation="BILINEAR",):)
[18.class ClassifyLetterBox:](#18.class ClassifyLetterBox:)
[19.class CenterCrop:](#19.class CenterCrop:)
[20.class ToTensor:](#20.class ToTensor:)
1.所需的库和模块
python
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
import random
from copy import deepcopy
from typing import Tuple, Union
import cv2
import numpy as np
import torch
from PIL import Image
from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0)
DEFAULT_CROP_FRACTION = 1.0
# labels 都是 Instances 类的实例。 ultralytics\utils\instance.py
# dataset 都是 BaseDataset 类的实例。 ultralytics\data\base.py
2.class BaseTransform:
python
# 这段代码定义了一个名为 BaseTransform 的 Python 类,用于图像处理或数据增强的基类。
# 这个类提供了一个初始化方法 __init__ 和三个抽象的方法 apply_image 、 apply_instances 和 apply_semantic ,以及一个 __call__ 方法,它是一个特殊的方法,允许类的实例像函数一样被调用。
class BaseTransform:
# Ultralytics 库中图像转换的基类。
# 此类是实现各种图像处理操作的基础,旨在兼容分类和语义分割任务。
# 方法:
# apply_image :将图像转换应用于标签。
# apply_instances :将转换应用于标签中的对象实例。
# apply_semantic :将语义分割应用于图像。
# __call__ :将所有标签转换应用于图像、实例和语义掩码。
"""
Base class for image transformations in the Ultralytics library.
This class serves as a foundation for implementing various image processing operations, designed to be
compatible with both classification and semantic segmentation tasks.
Methods:
apply_image: Applies image transformations to labels.
apply_instances: Applies transformations to object instances in labels.
apply_semantic: Applies semantic segmentation to an image.
__call__: Applies all label transformations to an image, instances, and semantic masks.
Examples:
>>> transform = BaseTransform()
>>> labels = {"image": np.array(...), "instances": [...], "semantic": np.array(...)}
>>> transformed_labels = transform(labels)
"""
# 这是类的构造函数,用于初始化类的实例。在这个基类中,它什么也不做( pass 是一个空操作,表示什么也不执行)。
def __init__(self) -> None:
# 初始化 BaseTransform 对象。
# 此构造函数设置基础转换对象,该对象可以针对特定图像处理任务进行扩展。它旨在兼容分类和语义分割。
"""
Initializes the BaseTransform object.
This constructor sets up the base transformation object, which can be extended for specific image
processing tasks. It is designed to be compatible with both classification and semantic segmentation.
Examples:
>>> transform = BaseTransform()
"""
pass
# 这个方法应该在子类中被重写,用于对图像数据应用某种变换。
def apply_image(self, labels):
# 将图像转换应用于标签。
# 此方法旨在由子类重写以实现特定的图像转换逻辑。在其基本形式中,它返回未更改的输入标签。
# 参数:
# labels(任意):要转换的输入标签。标签的确切类型和结构可能因具体实现而异。
# 返回:
# (任意):转换后的标签。在基本实现中,这与输入相同。
"""
Applies image transformations to labels.
This method is intended to be overridden by subclasses to implement specific image transformation
logic. In its base form, it returns the input labels unchanged.
Args:
labels (Any): The input labels to be transformed. The exact type and structure of labels may
vary depending on the specific implementation.
Returns:
(Any): The transformed labels. In the base implementation, this is identical to the input.
Examples:
>>> transform = BaseTransform()
>>> original_labels = [1, 2, 3]
>>> transformed_labels = transform.apply_image(original_labels)
>>> print(transformed_labels)
[1, 2, 3]
"""
pass
# 这个方法也应该在子类中被重写,用于对实例数据(图像中的物体标注)应用变换。
def apply_instances(self, labels):
# 将转换应用于标签中的对象实例。
# 此方法负责将各种转换应用于给定标签内的对象实例。它旨在由子类重写以实现特定的实例转换逻辑。
# 参数:
# labels (Dict):包含标签信息的字典,包括对象实例。
# 返回:
# (Dict):带有转换后的对象实例的修改后的标签字典。
"""
Applies transformations to object instances in labels.
This method is responsible for applying various transformations to object instances within the given
labels. It is designed to be overridden by subclasses to implement specific instance transformation
logic.
Args:
labels (Dict): A dictionary containing label information, including object instances.
Returns:
(Dict): The modified labels dictionary with transformed object instances.
Examples:
>>> transform = BaseTransform()
>>> labels = {"instances": Instances(xyxy=torch.rand(5, 4), cls=torch.randint(0, 80, (5,)))}
>>> transformed_labels = transform.apply_instances(labels)
"""
pass
# 这个方法同样需要在子类中被重写,用于对语义数据(可能是图像的语义分割信息)应用变换。
def apply_semantic(self, labels):
# 将语义分割转换应用于图像。
# 此方法旨在由子类重写以实现特定的语义分割转换。 在其基本形式中,它不执行任何操作。
# 参数:
# 标签 (Any):要转换的输入标签或语义分割掩码。
# 返回:
# (Any):转换后的语义分割掩码或标签。
"""
Applies semantic segmentation transformations to an image.
This method is intended to be overridden by subclasses to implement specific semantic segmentation
transformations. In its base form, it does not perform any operations.
Args:
labels (Any): The input labels or semantic segmentation mask to be transformed.
Returns:
(Any): The transformed semantic segmentation mask or labels.
Examples:
>>> transform = BaseTransform()
>>> semantic_mask = np.zeros((100, 100), dtype=np.uint8)
>>> transformed_mask = transform.apply_semantic(semantic_mask)
"""
pass
# 这个方法使得类的实例可以像函数一样被调用。当实例被调用时,它会依次执行 apply_image 、 apply_instances 和 apply_semantic 方法。
def __call__(self, labels):
# 将所有标签转换应用于图像、实例和语义掩码。
# 此方法协调将 BaseTransform 类中定义的各种转换应用于输入标签。它依次调用 apply_image 和 apply_instances 方法分别处理图像和对象实例。
# 参数:
# labels (Dict):包含图像数据和注释的字典。预期键包括图像数据的"img"和对象实例的"instances"。
# 返回:
# (Dict):包含转换后的图像和实例的输入标签字典。
"""
Applies all label transformations to an image, instances, and semantic masks.
This method orchestrates the application of various transformations defined in the BaseTransform class
to the input labels. It sequentially calls the apply_image and apply_instances methods to process the
image and object instances, respectively.
Args:
labels (Dict): A dictionary containing image data and annotations. Expected keys include 'img' for
the image data, and 'instances' for object instances.
Returns:
(Dict): The input labels dictionary with transformed image and instances.
Examples:
>>> transform = BaseTransform()
>>> labels = {"img": np.random.rand(640, 640, 3), "instances": []}
>>> transformed_labels = transform(labels)
"""
self.apply_image(labels)
self.apply_instances(labels)
self.apply_semantic(labels)
# 这个基类的设计意图是让子类继承 BaseTransform 并实现具体的变换逻辑。这样,任何继承自 BaseTransform 的子类都可以确保在处理图像、实例和语义数据时遵循相同的调用顺序。
3.class Compose:
python
# 这段代码定义了一个名为 Compose 的 Python 类,它用于组合多个图像变换操作。这个类允许你将多个变换(如 BaseTransform 的子类实例)串联起来,以便按顺序对数据(如图像)应用这些变换。
class Compose:
# 用于组合多个图像转换的类。
# 属性:
# transforms (List[Callable]):要按顺序应用的转换函数列表。
# 方法:
# __call__:将一系列转换应用于输入数据。
# append:将新转换附加到现有转换列表。
# insert:在转换列表中的指定索引处插入新转换。
# __getitem__:使用索引检索特定转换或一组转换。
# __setitem__:使用索引设置特定转换或一组转换。
# tolist:将转换列表转换为标准 Python 列表。
"""
A class for composing multiple image transformations.
Attributes:
transforms (List[Callable]): A list of transformation functions to be applied sequentially.
Methods:
__call__: Applies a series of transformations to input data.
append: Appends a new transform to the existing list of transforms.
insert: Inserts a new transform at a specified index in the list of transforms.
__getitem__: Retrieves a specific transform or a set of transforms using indexing.
__setitem__: Sets a specific transform or a set of transforms using indexing.
tolist: Converts the list of transforms to a standard Python list.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(30)]
>>> compose = Compose(transforms)
>>> transformed_data = compose(data)
>>> compose.append(CenterCrop((224, 224)))
>>> compose.insert(0, RandomFlip())
"""
# 这段代码是一个类的构造函数 __init__ 的定义,它的作用是初始化类的实例。这个构造函数接收一个参数。
# 1.transforms :这个参数预期是一个变换的集合,可以是一个单独的变换对象或者一个包含多个变换对象的列表。
def __init__(self, transforms):
# 使用转换列表初始化 Compose 对象。
# 参数:
# transforms (List[Callable]):要按顺序应用的可调用转换对象列表。
"""
Initializes the Compose object with a list of transforms.
Args:
transforms (List[Callable]): A list of callable transform objects to be applied sequentially.
Examples:
>>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip
>>> transforms = [RandomHSV(), RandomFlip()]
>>> compose = Compose(transforms)
"""
# self.transforms = transforms :如果传入的 transforms 参数已经是一个列表,那么直接将其赋值给实例变量 self.transforms 。
# transforms if isinstance(transforms, list) else [transforms] :这是一个条件表达式(也称为三元操作符),它检查 transforms 是否是一个列表。
# 如果是列表,就返回 transforms 本身;如果不是列表,就将 transforms 包装在一个新列表中,确保 self.transforms 总是一个列表。
self.transforms = transforms if isinstance(transforms, list) else [transforms]
# 这样做的好处是,这个构造函数可以接受单个变换对象或者变换对象的列表作为输入,使得类的使用更加灵活。无论是单个变换还是多个变换,都可以被这个类所处理。
# 这段代码定义了 __call__ 方法,它是一个特殊方法,允许类的实例表现得像函数一样,即可以被"调用"。在这个上下文中, __call__ 方法被用来将一系列的变换(transformations)应用到输入数据 data 上。
def __call__(self, data):
# 对输入数据应用一系列转换。此方法按顺序将 Compose 对象的转换列表中的每个转换应用于输入数据。
# 参数:
# data (Any):要转换的输入数据。这可以是任何类型,具体取决于列表中的转换。
# 返回:
# (Any):按顺序应用所有转换后的转换数据。
"""
Applies a series of transformations to input data. This method sequentially applies each transformation in the
Compose object's list of transforms to the input data.
Args:
data (Any): The input data to be transformed. This can be of any type, depending on the
transformations in the list.
Returns:
(Any): The transformed data after applying all transformations in sequence.
Examples:
>>> transforms = [Transform1(), Transform2(), Transform3()]
>>> compose = Compose(transforms)
>>> transformed_data = compose(input_data)
"""
# 这个循环遍历存储在 self.transforms 中的所有变换。 self.transforms 是一个包含变换对象的列表,这些对象预期都有一个 __call__ 方法,使得它们可以被调用。
for t in self.transforms:
# 在循环中,每个变换对象 t 被调用,并将输入数据 data 作为参数传递。变换对象 t 应用其变换逻辑,并返回变换后的数据。这个返回值被赋值回 data 变量,这样下一个变换就可以在已经变换过的数据上工作。
data = t(data)
# 循环完成后,所有变换都被应用到了原始数据上,最终变换后的数据通过 return 语句返回。
return data
# 这种方法的优雅之处在于它的链式调用能力,允许你将多个变换组合起来,形成一个处理流水线。每个变换只需要关心它自己的工作,而 __call__ 方法处理整个流程的协调。
# 这段代码定义了 append 方法,它是 Compose 类的一个成员方法,用于向 Compose 实例的 transforms 列表中添加一个新的变换。
def append(self, transform):
# 将新变换附加到现有变换列表。
# 参数:
# transform (BaseTransform):要添加到合成中的变换。
"""
Appends a new transform to the existing list of transforms.
Args:
transform (BaseTransform): The transformation to be added to the composition.
Examples:
>>> compose = Compose([RandomFlip(), RandomPerspective()])
>>> compose.append(RandomHSV())
"""
# 这个方法调用 Python 列表的 append 方法,将传入的 transform 参数添加到 self.transforms 列表的末尾。 self.transforms 是一个包含变换对象的列表,这些对象通常是 BaseTransform 的子类实例或者其他具有 __call__ 方法的可调用对象。
self.transforms.append(transform)
# 这个方法允许用户动态地向 Compose 实例中添加变换,使得变换序列更加灵活和可扩展。
# 这段代码定义了 insert 方法,它是 Compose 类的一个成员方法,用于在 Compose 实例的 transforms 列表中的指定位置插入一个新的变换。
# 1.self :这是一个指向类实例本身的引用,它允许访问类的属性和方法。在类的任何方法中, self 都是第一个参数,并且是隐式传递的,不需要显式提供。
# 2.index :这是指定插入变换的索引位置的参数。索引是一个整数,表示在 self.transforms 列表中插入新变换 transform 的位置。如果提供的索引等于列表的长度, transform 将被添加到列表的末尾。
# 3.transform :这是要插入到 self.transforms 列表中的变换对象。这个对象应该是一个可以被调用的实例(即具有 __call__ 方法的对象),这样它才能在后续的变换过程中被正确执行。
def insert(self, index, transform):
# 在现有变换列表中的指定索引处插入新变换。
# 参数:
# index (int):插入新变换的索引。
# transform (BaseTransform):要插入的变换对象。
"""
Inserts a new transform at a specified index in the existing list of transforms.
Args:
index (int): The index at which to insert the new transform.
transform (BaseTransform): The transform object to be inserted.
Examples:
>>> compose = Compose([Transform1(), Transform2()])
>>> compose.insert(1, Transform3())
>>> len(compose.transforms)
3
"""
# list.insert(index, item)
# insert() 函数通常是指 Python 中列表(list)类型的一个方法,它用于在列表中的特定位置插入一个元素。这个方法的第一个参数是索引(index),指定了元素应该被插入的位置,第二个参数是要插入的元素本身。
# 参数 :
# index : 一个整数,表示元素应该被插入的位置。如果索引超出了列表的范围,将会抛出一个 IndexError 。
# item : 要插入到列表中的元素。
# 这个方法调用 Python 列表的 insert 方法,将传入的 transform 参数插入到 self.transforms 列表的指定位置 index 。如果 index 等于列表的长度, transform 将被添加到列表的末尾。
self.transforms.insert(index, transform)
# 这个方法允许用户在变换序列中的特定位置插入新的变换,提供了比 append 方法更灵活的控制。
# 这段代码定义了 Compose 类的 __getitem__ 方法,这个方法使得 Compose 类的实例可以通过索引或索引列表来获取子集,并且返回一个新的 Compose 实例,包含原始列表中指定的变换。
# 这个方法接受一个参数。
# 1.index :它可以是一个整数或一个整数列表。 Union 表示参数可以是这两种类型中的任意一种。
def __getitem__(self, index: Union[list, int]) -> "Compose":
# 使用索引检索特定变换或一组变换。
# 参数:
# index (int | List[int]):要检索的变换的索引或索引列表。
# 返回:
# (Compose):包含所选变换的新 Compose 对象。
# 引发:
# AssertionError:如果索引不是 int 或 list 类型。
"""
Retrieves a specific transform or a set of transforms using indexing.
Args:
index (int | List[int]): Index or list of indices of the transforms to retrieve.
Returns:
(Compose): A new Compose object containing the selected transform(s).
Raises:
AssertionError: If the index is not of type int or list.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(10), RandomHSV(0.5, 0.5, 0.5)]
>>> compose = Compose(transforms)
>>> single_transform = compose[1] # Returns a Compose object with only RandomPerspective
>>> multiple_transforms = compose[0:2] # Returns a Compose object with RandomFlip and RandomPerspective
"""
# 这是一个断言语句,用于确保传入的 index 参数是一个整数或列表。如果不是,将抛出一个 AssertionError ,并显示一条错误消息,说明期望的参数类型。
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" # 索引应该是列表或 int 类型,但得到了 {type(index)。
# 这行代码处理 index 参数,如果 index 是一个整数( int ),则将其转换为只包含该整数的列表,这样无论是单个整数还是整数列表,都可以统一处理。
index = [index] if isinstance(index, int) else index
# 这行代码创建并返回一个新的 Compose 实例。它通过列表推导式从原始的 self.transforms 列表中取出指定索引位置的变换,并传递给新 Compose 实例的构造函数。
return Compose([self.transforms[i] for i in index])
# 这个方法允许用户通过索引或索引列表来提取 Compose 实例中的变换子集。
# 这段代码定义了 Compose 类的 __setitem__ 方法,它允许你通过索引或索引列表来设置 Compose 实例中的变换。这个方法类似于 Python 列表的 __getitem__ 方法,但是用于赋值操作。
# 1.index :它可以是一个整数或一个整数列表。 Union 表示参数可以是这两种类型中的任意一种。
# 2.value :它可以是一个变换对象、一个变换对象列表,或者是与 index 类型相匹配的单个值。
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
# 使用索引设置组合中的一个或多个变换。
"""
Sets one or more transforms in the composition using indexing.
Args:
index (int | List[int]): Index or list of indices to set transforms at.
value (Any | List[Any]): Transform or list of transforms to set at the specified index(es).
Raises:
AssertionError: If index type is invalid, value type doesn't match index type, or index is out of range.
Examples:
>>> compose = Compose([Transform1(), Transform2(), Transform3()])
>>> compose[1] = NewTransform() # Replace second transform
>>> compose[0:2] = [NewTransform1(), NewTransform2()] # Replace first two transforms
"""
# 这是一个断言语句,用于确保传入的 index 参数是一个整数或列表。如果不是,将抛出一个 AssertionError ,并显示一条错误消息,说明期望的参数类型。
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
# 如果 index 是一个列表,那么这个方法会检查 value 是否也是一个列表。如果不是,将抛出一个 AssertionError 。
if isinstance(index, list):
assert isinstance(
value, list
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
# 如果 index 是一个整数,那么这个方法会将 index 和 value 都转换为列表,以便统一处理。
if isinstance(index, int):
index, value = [index], [value]
# 这个循环遍历 index 和 value (它们现在是相同的类型,并且长度相同),并使用 zip 函数将它们配对。
for i, v in zip(index, value):
# 这是一个断言语句,用于确保索引 i 在 self.transforms 列表的范围内。如果不是,将抛出一个 AssertionError ,并显示一条错误消息,说明索引超出范围。
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
# 这行代码将 self.transforms 列表中索引为 i 的元素设置为 v ,即更新变换列表中的元素。
self.transforms[i] = v
# 这个方法允许用户通过索引或索引列表来更新 Compose 实例中的变换。
# 这段代码定义了 Compose 类的 tolist 方法,其作用是将 Compose 实例中包含的变换列表返回给调用者。
# 这是方法的定义, tolist 是方法的名称, self 参数代表类的实例本身。
def tolist(self):
# 将变换列表转换为标准 Python 列表。
"""
Converts the list of transforms to a standard Python list.
Returns:
(List): A list containing all the transform objects in the Compose instance.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(10), CenterCrop()]
>>> compose = Compose(transforms)
>>> transform_list = compose.tolist()
>>> print(len(transform_list))
3
"""
# 这行代码返回 self.transforms ,它是存储在 Compose 类实例中的一个列表,包含了所有的变换对象。
return self.transforms
# 这个方法非常简单,它没有参数(除了 self ),也不进行任何复杂的操作,仅仅是返回存储变换的列表。
# 这段代码定义了 Compose 类的 __repr__ 方法,它是一个特殊方法,用于返回对象的官方字符串表示形式,通常用于调试和开发。 __repr__ 方法的目的是提供一个明确的、精确的方式来表示对象,以便开发者能够理解对象的状态。
# 这是方法的定义, __repr__ 是方法的名称, self 参数代表类的实例本身。
def __repr__(self):
# 返回 Compose 对象的字符串表示形式。
"""
Returns a string representation of the Compose object.
Returns:
(str): A string representation of the Compose object, including the list of transforms.
Examples:
>>> transforms = [RandomFlip(), RandomPerspective(degrees=10, translate=0.1, scale=0.1)]
>>> compose = Compose(transforms)
>>> print(compose)
Compose([
RandomFlip(),
RandomPerspective(degrees=10, translate=0.1, scale=0.1)
])
"""
# 这行代码构建并返回一个字符串,该字符串表示 Compose 实例的官方字符串表示形式。
# self.__class__.__name__ :这是获取当前类名称的方式,它将返回类的名字,例如 "Compose" 。
# ', '.join([f'{t}' for t in self.transforms]) :这是一个列表推导式,它遍历 self.transforms 列表中的每个变换对象 t ,并为每个对象创建一个字符串表示 f'{t}' 。然后, join 方法将这些字符串连接成一个单一的字符串,各个字符串之间用逗号和空格分隔。
# f"{...}" :这是一个格式化字符串(也称为 f-string),它允许你插入表达式的值来构建字符串。
return f"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})"
# 这个方法允许 Compose 实例在被打印或者在调试时显示一个清晰、易于理解的字符串表示。
4.class BaseMixTransform:
python
# 这段代码定义了一个名为 BaseMixTransform 的 Python 类,它是一个用于数据增强的基类,专门用于实现 Mosaic 和 MixUp 这两种技术。这个类提供了一个框架,子类需要实现具体的 _mix_transform 和 get_indexes 方法。
class BaseMixTransform:
# MixUp 和 Mosaic 等混合转换的基类。
# 此类为在数据集上实现混合转换提供了基础。它处理基于概率的转换应用并管理多个图像和标签的混合。
# 方法:
# __call__ :将混合转换应用于输入标签。
# _mix_transform :由子类为特定混合操作实现的抽象方法。
# get_indexes :获取要混合的图像索引的抽象方法。
# _update_label_text :更新混合图像的标签文本。
"""
Base class for mix transformations like MixUp and Mosaic.
This class provides a foundation for implementing mix transformations on datasets. It handles the
probability-based application of transforms and manages the mixing of multiple images and labels.
Attributes:
dataset (Any): The dataset object containing images and labels.
pre_transform (Callable | None): Optional transform to apply before mixing.
p (float): Probability of applying the mix transformation.
Methods:
__call__: Applies the mix transformation to the input labels.
_mix_transform: Abstract method to be implemented by subclasses for specific mix operations.
get_indexes: Abstract method to get indexes of images to be mixed.
_update_label_text: Updates label text for mixed images.
Examples:
>>> class CustomMixTransform(BaseMixTransform):
... def _mix_transform(self, labels):
... # Implement custom mix logic here
... return labels
...
... def get_indexes(self):
... return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
>>> dataset = YourDataset()
>>> transform = CustomMixTransform(dataset, p=0.5)
>>> mixed_labels = transform(original_labels)
"""
# 这段代码是 BaseMixTransform 类的构造函数 __init__ 的定义。这个构造函数用于初始化类的实例,设置实例的属性。
# 1.self :这是一个对当前对象实例的引用,它允许访问类的其他成员和方法。
# 2.dataset :它代表要进行变换的数据集。这个数据集对象应该提供了获取图像和标签的方法,例如 get_image_and_label(i) ,它用于根据索引 i 获取对应的图像和标签。
# 3.pre_transform :它是一个可选的预变换。预变换是一个在应用主要的混合变换(如 Mosaic 或 MixUp)之前执行的变换。如果提供了 pre_transform ,则它应该是一个可以被调用的对象(即具有 __call__ 方法的对象),这样它就可以在获取图像和标签后立即应用。
# 4.p :它是一个概率值,用于确定是否应用混合变换。如果随机数生成器产生的值大于 p ,则不应用混合变换,直接返回原始标签。
# -> None :这是 Python 3 函数注解的一部分,表示这个函数不返回任何值(即返回 None )。
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
# 初始化 BaseMixTransform 对象,用于 MixUp 和 Mosaic 等混合转换。
# 此类用作在图像处理管道中实现混合转换的基础。
"""
Initializes the BaseMixTransform object for mix transformations like MixUp and Mosaic.
This class serves as a base for implementing mix transformations in image processing pipelines.
Args:
dataset (Any): The dataset object containing images and labels for mixing.
pre_transform (Callable | None): Optional transform to apply before mixing.
p (float): Probability of applying the mix transformation. Should be in the range [0.0, 1.0].
Examples:
>>> dataset = YOLODataset("path/to/data")
>>> pre_transform = Compose([RandomFlip(), RandomPerspective()])
>>> mix_transform = BaseMixTransform(dataset, pre_transform, p=0.5)
"""
self.dataset = dataset
self.pre_transform = pre_transform
self.p = p
# 构造函数的主体部分简单地将传入的参数赋值给 self 的属性,这些属性将在类的其他方法中使用。这样,类的实例就可以访问数据集、预变换和概率值了。
# 这段代码定义了 BaseMixTransform 类的 __call__ 方法,它是一个特殊方法,允许类的实例像函数一样被调用。这个方法实现了一个条件数据增强流程,用于在一定概率 p 下对图像进行 Mosaic 或 MixUp 变换。
# 1.labels :它代表当前要处理的数据标签。这个参数预期是一个字典或类似的数据结构,包含了图像的标签信息,这些信息可能会在 Mosaic 或 MixUp 变换过程中被修改或扩展。
def __call__(self, labels):
# 将预处理转换和混合/马赛克转换应用于标签数据。
# 此方法根据概率因子确定是否应用混合转换。 如果应用,它将选择其他图像,应用预转换(如果指定),然后执行混合转换。
# 参数:
# 标签 (Dict):包含图像标签数据的字典。
"""
Applies pre-processing transforms and mixup/mosaic transforms to labels data.
This method determines whether to apply the mix transform based on a probability factor. If applied, it
selects additional images, applies pre-transforms if specified, and then performs the mix transform.
Args:
labels (Dict): A dictionary containing label data for an image.
Returns:
(Dict): The transformed labels dictionary, which may include mixed data from other images.
Examples:
>>> transform = BaseMixTransform(dataset, pre_transform=None, p=0.5)
>>> result = transform({"image": img, "bboxes": boxes, "cls": classes})
"""
# 这行代码生成一个0到1之间的随机数,并与概率 p 比较。如果随机数大于 p ,则不进行变换,直接返回原始的 labels 。
if random.uniform(0, 1) > self.p:
return labels
# Get index of one or three other images
# 调用 get_indexes 方法获取一个或多个其他图像的索引。这个方法需要在子类中实现。
indexes = self.get_indexes()
# 确保 indexes 是一个列表,无论 get_indexes 返回单个索引还是索引列表。
if isinstance(indexes, int):
indexes = [indexes]
# Get images information will be used for Mosaic or MixUp
# 使用数据集的 get_image_and_label 方法获取索引对应的图像和标签信息,并存储在 mix_labels 列表中。
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
# 如果提供了预变换 pre_transform ,则对每个 mix_labels 中的数据应用这个预变换。
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
# 将混合图像的信息添加到 labels 字典中,键为 "mix_labels" 。
labels["mix_labels"] = mix_labels
# Update cls and texts
# 调用 _update_label_text 方法更新 labels 中的文本标签和类别标签。
labels = self._update_label_text(labels)
# Mosaic or MixUp
# 调用 _mix_transform 方法执行实际的 Mosaic 或 MixUp 变换。这个方法需要在子类中实现。
labels = self._mix_transform(labels)
# 从 labels 中移除 "mix_labels" 键及其对应的值。
labels.pop("mix_labels", None)
# 返回变换后的 labels 。
return labels
# 这个方法的关键在于它提供了一个框架,使得子类可以轻松实现具体的 Mosaic 和 MixUp 变换逻辑,同时处理变换前后的标签更新和预变换应用。
# 这段代码定义了 BaseMixTransform 类中的 _mix_transform 方法,其目的是提供一个抽象方法的框架,要求任何继承自 BaseMixTransform 的子类都必须实现这个方法。
# 这是方法的定义, _mix_transform 是方法的名称, self 参数代表类的实例本身。
# 1.labels :是传递给方法的参数,代表当前要进行混合变换的标签数据。
def _mix_transform(self, labels):
# 将 MixUp 或 Mosaic 增强应用于标签字典。
# 此方法应由子类实现,以执行特定的混合转换,如 MixUp 或 Mosaic。它使用增强数据就地修改输入标签字典。
"""
Applies MixUp or Mosaic augmentation to the label dictionary.
This method should be implemented by subclasses to perform specific mix transformations like MixUp or
Mosaic. It modifies the input label dictionary in-place with the augmented data.
Args:
labels (Dict): A dictionary containing image and label data. Expected to have a 'mix_labels' key
with a list of additional image and label data for mixing.
Returns:
(Dict): The modified labels dictionary with augmented data after applying the mix transform.
Examples:
>>> transform = BaseMixTransform(dataset)
>>> labels = {"image": img, "bboxes": boxes, "mix_labels": [{"image": img2, "bboxes": boxes2}]}
>>> augmented_labels = transform._mix_transform(labels)
"""
# NotImplementedError
# NotImplementedError 是 Python 中的一个内置异常,用于指示一个方法或函数应该在子类中被实现,但在当前类中尚未提供具体的实现。这个异常通常在抽象基类(Abstract Base Classes,简称 ABCs)中使用,以确保子类覆盖了所有必要的方法。
# 这行代码引发了一个 NotImplementedError 异常。在 Python 中,这是一种常见的做法,用于指示一个方法应该在子类中被重写,但当前基类中并未提供具体的实现。
raise NotImplementedError
# _mix_transform 方法是一个抽象方法,它预期在子类中被实现,以提供具体的混合变换逻辑。这个变换可以是 Mosaic、MixUp 或其他任何类型的图像混合技术。由于 BaseMixTransform 类是一个基类,它不包含具体的变换实现,而是提供了一个框架和一些通用的逻辑,真正的变换逻辑需要在子类中根据具体的应用场景来实现。
# 这段代码定义了 BaseMixTransform 类中的 get_indexes 方法,其目的是提供一个抽象方法的框架,要求任何继承自 BaseMixTransform 的子类都必须实现这个方法。
# 这是方法的定义, get_indexes 是方法的名称, self 参数代表类的实例本身。
def get_indexes(self):
# 获取用于马赛克增强的混洗索引列表。
"""
Gets a list of shuffled indexes for mosaic augmentation.
Returns:
(List[int]): A list of shuffled indexes from the dataset.
Examples:
>>> transform = BaseMixTransform(dataset)
>>> indexes = transform.get_indexes()
>>> print(indexes) # [3, 18, 7, 2]
"""
# 这行代码引发了一个 NotImplementedError 异常。在 Python 中,这是一种常见的做法,用于指示一个方法应该在子类中被重写,但当前基类中并未提供具体的实现。
raise NotImplementedError
# get_indexes 方法是一个抽象方法,它预期在子类中被实现,以提供获取用于混合变换的图像索引的逻辑。这个方法对于实现 Mosaic 或 MixUp 等数据增强技术至关重要,因为它决定了哪些图像将被组合在一起。
# 这段代码定义了 BaseMixTransform 类中的 _update_label_text 方法,它用于更新混合变换后的标签文本。这个方法处理了标签中文本的合并和更新,确保混合后的标签与原始标签一致,并且更新了类别标签以匹配新的文本标签集合。
def _update_label_text(self, labels):
# 更新图像增强中混合标签的标签文本和类别 ID。
# 此方法处理输入标签字典和任何混合标签的"texts"和"cls"字段,创建一组统一的文本标签并相应地更新类别 ID。
"""
Updates label text and class IDs for mixed labels in image augmentation.
This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels,
creating a unified set of text labels and updating class IDs accordingly.
Args:
labels (Dict): A dictionary containing label information, including 'texts' and 'cls' fields,
and optionally a 'mix_labels' field with additional label dictionaries.
Returns:
(Dict): The updated labels dictionary with unified text labels and updated class IDs.
Examples:
>>> labels = {
... "texts": [["cat"], ["dog"]],
... "cls": torch.tensor([[0], [1]]),
... "mix_labels": [{"texts": [["bird"], ["fish"]], "cls": torch.tensor([[0], [1]])}],
... }
>>> updated_labels = self._update_label_text(labels)
>>> print(updated_labels["texts"])
[['cat'], ['dog'], ['bird'], ['fish']]
>>> print(updated_labels["cls"])
tensor([[0],
[1]])
>>> print(updated_labels["mix_labels"][0]["cls"])
tensor([[2],
[3]])
"""
# 检查 labels 字典中是否存在键 "texts" 。如果不存在,意味着没有文本标签需要更新,因此直接返回原始的 labels 。
if "texts" not in labels:
return labels
# 这行代码创建一个新的文本列表,包含原始标签中的文本和所有混合标签中的文本。 sum 函数将多个列表合并成一个列表。
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
# 在 Python 中,集合(set)是一个无序的、不包含重复元素的数据结构。你可以使用大括号 {} 或者 set() 函数来定义一个集合。
# 将合并后的文本列表转换为一个集合,以去除重复的文本项,然后将集合转换回列表。这里使用 tuple(x) 是因为集合只能包含不可变类型,而字符串是不可变的,所以每个字符串项被转换为元组以确保可哈希。
# {tuple(x) for x in mix_texts} :这是一个集合推导式,它遍历 mix_texts 列表中的每个元素 x 。由于集合(set)只能包含不可变(hashable)的元素,而列表是可变的,所以这里将每个文本条目 x 转换为元组(tuple),使其成为不可变类型,以便可以被添加到集合中。
# list(...) :将集合转换回列表。集合中的元素是唯一的,因为集合会自动去除重复的元素。将集合转换回列表是为了后续操作,因为列表提供了更多的灵活性和操作选项。
# mix_texts = list({tuple(x) for x in mix_texts}) :这行代码的最终结果是 mix_texts 变量被更新为一个新列表,其中包含去重后的文本条目。
# 这个步骤的目的是为了创建一个不包含重复项的文本标签列表,因为同一个文本标签可能在原始标签和混合标签中多次出现。通过将每个文本条目转换为元组并使用集合去重,然后再转换回列表,可以确保每个文本条目在最终的 mix_texts 列表中只出现一次。
mix_texts = list({tuple(x) for x in mix_texts})
# 创建一个映射,将每个唯一的文本标签映射到一个唯一的索引。
text2id = {text: i for i, text in enumerate(mix_texts)}
# 遍历原始标签和所有混合标签。
for label in [labels] + labels["mix_labels"]:
# 遍历每个标签中的类别索引。
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
# 根据类别索引找到对应的文本标签。
text = label["texts"][int(cls)]
# 更新类别索引,使其指向新的文本标签集合中的索引。
label["cls"][i] = text2id[tuple(text)]
# 更新标签中的文本列表,使其指向新的合并后的文本标签集合。
label["texts"] = mix_texts
# 返回更新后的 labels 。
return labels
# 这个方法确保了在混合变换后,所有标签的文本信息都是最新的,并且类别标签与新的文本集合保持一致。这对于保持数据的一致性和准确性至关重要,特别是在进行数据增强时。
5.class Mosaic(BaseMixTransform):
python
# 这段代码是一个Python类的定义,名为 Mosaic ,它继承自 BaseMixTransform 。这个类用来处理图像数据增强的一种方法,特别是在目标检测任务中常用的"mosaic"(马赛克)数据增强技术。
# 定义了一个名为 Mosaic 的新类,它继承自 BaseMixTransform 类。
class Mosaic(BaseMixTransform):
# 图像数据集的马赛克增强。
# 此类通过将多张(4 张或 9 张)图像组合成一张马赛克图像来执行马赛克增强。增强以给定的概率应用于数据集。
# 方法:
# get_indexes :返回来自数据集的随机索引列表。
# _mix_transform :将混合转换应用于输入图像和标签。
# _mosaic3 :创建 1x3 图像马赛克。
# _mosaic4 :创建 2x2 图像马赛克。
# _mosaic9 :创建 3x3 图像马赛克。
# _update_labels :使用填充更新标签。
# _cat_labels :连接标签并剪辑马赛克边框实例。
"""
Mosaic augmentation for image datasets.
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
The augmentation is applied to a dataset with a given probability.
Attributes:
dataset: The dataset on which the mosaic augmentation is applied.
imgsz (int): Image size (height and width) after mosaic pipeline of a single image.
p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.
n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).
border (Tuple[int, int]): Border size for width and height.
Methods:
get_indexes: Returns a list of random indexes from the dataset.
_mix_transform: Applies mixup transformation to the input image and labels.
_mosaic3: Creates a 1x3 image mosaic.
_mosaic4: Creates a 2x2 image mosaic.
_mosaic9: Creates a 3x3 image mosaic.
_update_labels: Updates labels with padding.
_cat_labels: Concatenates labels and clips mosaic border instances.
Examples:
>>> from ultralytics.data.augment import Mosaic
>>> dataset = YourDataset(...) # Your image dataset
>>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)
>>> augmented_labels = mosaic_aug(original_labels)
"""
# 这是 Mosaic 类的构造函数,它接受四个参数。
# 1.dataset :数据集,可能是包含图像和标签的数据结构。
# 2.imgsz :图像的大小,默认为640。
# 3.p :进行mosaic变换的概率,默认为1.0,意味着总是进行变换。
# 4.n :马赛克中的图像数量,默认为4,可以是4或9。
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
# 初始化马赛克增强对象。
# 此类通过将多张(4 张或 9 张)图像组合成一张马赛克图像来执行马赛克增强。增强以给定的概率应用于数据集。
"""
Initializes the Mosaic augmentation object.
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
The augmentation is applied to a dataset with a given probability.
Args:
dataset (Any): The dataset on which the mosaic augmentation is applied.
imgsz (int): Image size (height and width) after mosaic pipeline of a single image.
p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1.
n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3).
Examples:
>>> from ultralytics.data.augment import Mosaic
>>> dataset = YourDataset(...)
>>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4)
"""
# 这是一个断言语句,确保 p 的值在0到1之间。
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." # 概率应该在 [0, 1] 范围内,但得到的是 {p}。
# 这是另一个断言语句,确保 n 的值只能是4或9,这对应于马赛克布局中的图像数量。
assert n in {4, 9}, "grid must be equal to 4 or 9." # 网格必须等于 4 或 9。
# 调用父类 BaseMixTransform 的构造函数,传递 dataset 和 p 参数。
super().__init__(dataset=dataset, p=p)
# 将传入的 imgsz 参数赋值给实例变量 self.imgsz ,用于存储图像的大小。
self.imgsz = imgsz
# 设置马赛克图像的边界,这里设置为图像大小的一半的负值,是为了在拼接图像时保持中心对齐。
self.border = (-imgsz // 2, -imgsz // 2) # width, height self.border = (-320,-320)
# 将传入的 n 参数赋值给实例变量 self.n ,用于存储马赛克布局中的图像数量。
self.n = n
# 这段代码是 Mosaic 类的一个方法定义,名为 get_indexes ,它用于从数据集中选择图像的索引,以便进行马赛克数据增强。这个方法接受一个参数 buffer ,它决定了选择图像的方式。
# 定义了一个名为 get_indexes 的方法,它接受两个参数。
# 1.self :指向类的实例。
# 2.buffer :一个布尔值,默认为True。
def get_indexes(self, buffer=True):
# 返回用于马赛克增强的数据集随机索引列表。
# 此方法根据"buffer"参数从缓冲区或整个数据集中选择随机图像索引。它用于选择用于创建马赛克增强的图像。
"""
Returns a list of random indexes from the dataset for mosaic augmentation.
This method selects random image indexes either from a buffer or from the entire dataset, depending on
the 'buffer' parameter. It is used to choose images for creating mosaic augmentations.
Args:
buffer (bool): If True, selects images from the dataset buffer. If False, selects from the entire
dataset.
Returns:
(List[int]): A list of random image indexes. The length of the list is n-1, where n is the number
of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9).
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
>>> indexes = mosaic.get_indexes()
>>> print(len(indexes)) # Output: 3
"""
# 这是一个条件判断,如果 buffer 为True,则从缓冲区中选择图像。
if buffer: # select images from buffer
# 如果 buffer 为True,使用 random.choices 函数从 self.dataset.buffer 中随机选择 self.n - 1 个图像的索引。 self.dataset.buffer 是一个包含图像索引的列表, k=self.n - 1 是因为马赛克布局中已经有一张图像(通常是主图像),所以只需要额外选择 self.n - 1 张图像。
return random.choices(list(self.dataset.buffer), k=self.n - 1)
# 如果 buffer 为False,则从整个数据集中随机选择图像。
else: # select any images
# random.randint(a, b)
# random.randint() 是 Python 中 random 模块提供的一个函数,用于生成一个指定范围内的随机整数。
# 参数 :
# a :范围的下限(包含),必须是一个整数。
# b :范围的上限(包含),必须是一个整数,且 b 必须大于或等于 a 。
# 返回值 :
# 函数返回一个随机整数 N ,满足 a <= N <= b
# 如果 buffer 为False,使用列表推导式生成一个包含 self.n - 1 个随机索引的列表。 random.randint(0, len(self.dataset) - 1) 用于生成一个随机索引,范围从0到数据集长度减1。
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
# 这个方法的作用是从数据集中随机选择图像的索引,以便进行马赛克数据增强。选择图像的方式取决于 buffer 参数的值,如果为True,则从缓冲区中选择;如果为False,则从整个数据集中选择。选择的图像数量取决于 self.n 的值,减1是因为马赛克布局中已经有一张图像。
# 这段代码是 Mosaic 类中的一个私有方法 _mix_transform 的定义,它用于执行马赛克数据增强的变换。这个方法接受一个参数 labels ,它包含了图像的标签信息。
# 定义了一个名为 _mix_transform 的方法,它接受两个参数。
# 1.self :指向类的实例。
# 2.labels :包含图像标签的字典。
def _mix_transform(self, labels):
# 将马赛克增强应用于输入图像和标签。
# 此方法根据"n"属性将多幅图像(3、4 或 9)组合成单个马赛克图像。它确保不存在矩形注释,并且有其他图像可用于马赛克增强。
"""
Applies mosaic augmentation to the input image and labels.
This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute.
It ensures that rectangular annotations are not present and that there are other images available for
mosaic augmentation.
Args:
labels (Dict): A dictionary containing image data and annotations. Expected keys include:
- 'rect_shape': Should be None as rect and mosaic are mutually exclusive.
- 'mix_labels': A list of dictionaries containing data for other images to be used in the mosaic.
Returns:
(Dict): A dictionary containing the mosaic-augmented image and updated annotations.
Raises:
AssertionError: If 'rect_shape' is not None or if 'mix_labels' is empty.
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
>>> augmented_data = mosaic._mix_transform(labels)
"""
# 这是一个断言语句,检查 labels 字典中是否包含键 "rect_shape" 。如果包含,则断言失败,并抛出错误信息,表明矩形变换和马赛克变换是互斥的,不能同时使用。
assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive." # rect 和 mosaic 是互斥的。
# 这是另一个断言语句,检查 labels 字典中 "mix_labels" 键对应的值(默认为空列表)是否非空。如果为空,则断言失败,并抛出错误信息,表明没有其他图像用于马赛克增强。
assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
# 这是一个条件表达式,根据 self.n 的值来决定调用哪个具体的马赛克变换方法 :
# 如果 self.n 等于3,调用 self._mosaic3(labels) 方法。
# 如果 self.n 等于4,调用 self._mosaic4(labels) 方法。
# 如果 self.n 等于9,调用 self._mosaic9(labels) 方法。
return (
self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
) # This code is modified for mosaic3 method.
# 这个方法的作用是根据马赛克布局中的图像数量( self.n ),选择相应的马赛克变换方法来执行数据增强。 _mosaic3 、 _mosaic4 和 _mosaic9 这三个方法,分别对应3、4和9张图像的马赛克变换逻辑。
# 这段代码是 Mosaic 类中的一个私有方法 _mosaic3 的定义,它用于执行3张图像的马赛克数据增强。这个方法接受一个参数 labels ,它包含了图像的标签信息。
# 定义了一个名为 _mosaic3 的方法,它接受两个参数。
# 1.self :指向类的实例。
# 2.labels :包含图像标签的字典。
def _mosaic3(self, labels):
# 通过组合三幅图像创建 1x3 图像马赛克。
# 此方法将三幅图像以水平布局排列,主图像位于中心,两幅附加图像位于两侧。它是对象检测中使用的马赛克增强技术的一部分。
"""
Creates a 1x3 image mosaic by combining three images.
This method arranges three images in a horizontal layout, with the main image in the center and two
additional images on either side. It's part of the Mosaic augmentation technique used in object detection.
Args:
labels (Dict): A dictionary containing image and label information for the main (center) image.
Must include 'img' key with the image array, and 'mix_labels' key with a list of two
dictionaries containing information for the side images.
Returns:
(Dict): A dictionary with the mosaic image and updated labels. Keys include:
- 'img' (np.ndarray): The mosaic image array with shape (H, W, C).
- Other keys from the input labels, updated to reflect the new image dimensions.
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=3)
>>> labels = {
... "img": np.random.rand(480, 640, 3),
... "mix_labels": [{"img": np.random.rand(480, 640, 3)} for _ in range(2)],
... }
>>> result = mosaic._mosaic3(labels)
>>> print(result["img"].shape)
(640, 640, 3)
"""
# 初始化一个空列表,用于存储每张图像处理后的标签。
mosaic_labels = []
# 获取图像的大小。
s = self.imgsz
# 循环3次,分别处理3张图像。
for i in range(3):
# 选择当前要处理的图像的标签,如果是第一张图像,则使用 labels ,否则从 labels["mix_labels"] 中选择。
# 如果 i 不等于0,即当前处理的是第二张或第三张图像,则 labels_patch 指向 labels["mix_labels"] 列表中对应的元素。 i - 1 是因为列表索引是从0开始的,所以需要减1来获取正确的索引。
# labels["mix_labels"] 是一个列表,其中包含了除主图像外,用于创建马赛克的其他图像的标签。
# 例如,如果 self.n 是3,那么 labels["mix_labels"] 将包含两个额外图像的标签,这些图像将被用来与主图像一起创建3张图像的马赛克效果。
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
# 加载当前图像。
img = labels_patch["img"]
# 获取当前图像的尺寸,并从 labels_patch 中移除这个键值对。
h, w = labels_patch.pop("resized_shape")
# Place img in img3
# 如果是第一张图像(中心图像),则创建一个基础图像 img3 ,并设置中心图像的坐标 c 。
if i == 0: # center
img3 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 3 tiles
h0, w0 = h, w
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
# 如果是第二张图像(右侧图像),则设置右侧图像的坐标 c 。
elif i == 1: # right
c = s + w0, s, s + w0 + w, s + h
# 如果是第三张图像(左侧图像),则设置左侧图像的坐标 c 。
elif i == 2: # left
c = s - w, s + h0 - h, s, s + h0
# 获取 c 中的前两个值,即 x 和 y 坐标的偏移量。
padw, padh = c[:2]
# 计算分配的坐标,确保它们不小于0。
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# 将当前图像放置到 img3 中对应的位置。
img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]
# hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
# 更新当前图像的标签,考虑边界偏移。
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1]) # ()
# 将更新后的标签添加到 mosaic_labels 列表中。
mosaic_labels.append(labels_patch)
# 将所有图像的标签合并为最终的标签。
final_labels = self._cat_labels(mosaic_labels)
# 从 img3 中裁剪出最终的图像,去除边界。
# 这行代码的作用是从 img3 中裁剪出一个区域,并将这个区域的图像赋值给 final_labels 字典中的 "img" 键。这个裁剪操作是为了去除 img3 边缘的边界区域,通常这些边界区域是黑色或其他颜色的填充,不是我们感兴趣的图像内容。
# final_labels["img"] :这是要赋值的目标,即 final_labels 字典中的 "img" 键。
# img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] :这是用于裁剪 img3 的切片表达式。
# self.border[0] 和 self.border[1] 分别是水平和垂直方向上的边界宽度。由于 self.border 被定义为 (-imgsz // 2, -imgsz // 2) ,这意味着边界宽度是图像尺寸的一半的负值,实际上是从图像中心向边缘的偏移量。
# -self.border[0] 和 -self.border[1] 表示从 img3 的右边界和底部边界开始裁剪,裁剪掉边界区域。
# self.border[0] 和 self.border[1] 表示从 img3 的左边界和顶部边界开始裁剪,裁剪掉边界区域。
# 综合来看, img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] 这个切片表达式的作用是从 img3 的四个边缘裁剪掉宽度为 self.border[0] 和 self.border[1] 的区域,只保留中间的核心区域。
# 将裁剪后的图像区域赋值给 final_labels["img"] ,这样 final_labels 字典中的 "img" 键就包含了去除了边界的最终图像。
# 这个裁剪操作确保了最终的图像只包含有用的内容,去除了由于马赛克拼接而产生的不必要的边缘区域。
final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] # (1280,1280)
# 返回最终的标签。
return final_labels
# 这个方法的作用是将3张图像合并为一个3x3的马赛克图像,并更新标签信息。它首先创建一个基础图像,然后将其他两张图像放置在基础图像的两侧。最后,它更新标签信息,并返回包含最终图像和标签的字典。
# 这段代码是 Mosaic 类中的一个方法 _mosaic4 的定义,它用于执行4张图像的马赛克数据增强。这个方法接受一个参数 labels ,它包含了图像的标签信息。
# 定义了一个名为 _mosaic4 的方法,它接受两个参数。
# 1.self :类的实例。
# 2.labels :包含图像标签的字典。
def _mosaic4(self, labels):
# 从四幅输入图像创建 2x2 图像马赛克。
# 此方法将四幅图像放在 2x2 网格中,将它们组合成一幅马赛克图像。它还会更新马赛克中每幅图像的相应标签。
"""
Creates a 2x2 image mosaic from four input images.
This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also
updates the corresponding labels for each image in the mosaic.
Args:
labels (Dict): A dictionary containing image data and labels for the base image (index 0) and three
additional images (indices 1-3) in the 'mix_labels' key.
Returns:
(Dict): A dictionary containing the mosaic image and updated labels. The 'img' key contains the mosaic
image as a numpy array, and other keys contain the combined and adjusted labels for all four images.
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4)
>>> labels = {
... "img": np.random.rand(480, 640, 3),
... "mix_labels": [{"img": np.random.rand(480, 640, 3)} for _ in range(3)],
... }
>>> result = mosaic._mosaic4(labels)
>>> assert result["img"].shape == (1280, 1280, 3)
"""
# 初始化一个空列表,用于存储每张图像处理后的标签。
mosaic_labels = []
# 获取图像的大小。
s = self.imgsz
# 随机生成马赛克中心的 x 和 y 坐标。这里使用 random.uniform 函数在 [-border, 2*imgsz + border] 范围内均匀随机选择一个值。
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y [320,960]
# 循环4次,分别处理4张图像。
for i in range(4):
# 选择当前要处理的图像的标签,如果是第一张图像,则使用 labels ,否则从 labels["mix_labels"] 中选择。
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
# 加载当前图像。
img = labels_patch["img"]
# 获取当前图像的尺寸,并从 labels_patch 中移除这个键值对。
h, w = labels_patch.pop("resized_shape")
# Place img in img4
# 如果是第一张图像(左上角),则创建一个基础图像 img4 ,并设置左上角图像的坐标 x1a, y1a, x2a, y2a 和 x1b, y1b, x2b, y2b 。
if i == 0: # top left
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
# 如果是第二张图像(右上角),则设置右上角图像的坐标。
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
# 如果是第三张图像(左下角),则设置左下角图像的坐标。
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
# 如果是第四张图像(右下角),则设置右下角图像的坐标。
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
# 将当前图像放置到 img4 中对应的位置。
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
# 计算水平和垂直方向上的填充宽度。
padw = x1a - x1b
padh = y1a - y1b
# 更新当前图像的标签,考虑边界偏移。
labels_patch = self._update_labels(labels_patch, padw, padh)
# 将更新后的标签添加到 mosaic_labels 列表中。
mosaic_labels.append(labels_patch)
# 将所有图像的标签合并为最终的标签。
final_labels = self._cat_labels(mosaic_labels)
# img4 将最终的马赛克图像赋值给 final_labels 字典中的 "img" 键。
final_labels["img"] = img4
# 返回最终的标签。
return final_labels
# 这个方法的作用是将4张图像合并为一个2x2的马赛克图像,并更新标签信息。它首先创建一个基础图像,然后将其他三张图像放置在基础图像的周围。最后,它更新标签信息,并返回包含最终图像和标签的字典。
# 这段代码是 Mosaic 类中的一个方法 _mosaic9 的定义,它用于执行9张图像的马赛克数据增强。这个方法接受一个参数 labels ,它包含了图像的标签信息。
# 定义了一个名为 _mosaic9 的方法,它接受两个参数.
# 1.self :类的实例。
# 2.labels :包含图像标签的字典。
def _mosaic9(self, labels):
# 从输入图像和另外八幅图像创建 3x3 图像马赛克。
# 此方法将九幅图像组合成一幅马赛克图像。输入图像位于中心,来自数据集的另外八幅图像以 3x3 网格模式放置在其周围。
"""
Creates a 3x3 image mosaic from the input image and eight additional images.
This method combines nine images into a single mosaic image. The input image is placed at the center,
and eight additional images from the dataset are placed around it in a 3x3 grid pattern.
Args:
labels (Dict): A dictionary containing the input image and its associated labels. It should have
the following keys:
- 'img' (numpy.ndarray): The input image.
- 'resized_shape' (Tuple[int, int]): The shape of the resized image (height, width).
- 'mix_labels' (List[Dict]): A list of dictionaries containing information for the additional
eight images, each with the same structure as the input labels.
Returns:
(Dict): A dictionary containing the mosaic image and updated labels. It includes the following keys:
- 'img' (numpy.ndarray): The final mosaic image.
- Other keys from the input labels, updated to reflect the new mosaic arrangement.
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=9)
>>> input_labels = dataset[0]
>>> mosaic_result = mosaic._mosaic9(input_labels)
>>> mosaic_image = mosaic_result["img"]
"""
# 初始化一个空列表,用于存储每张图像处理后的标签。
mosaic_labels = []
# 获取图像的大小。
s = self.imgsz
# 初始化前一次迭代的高和宽为-1。
hp, wp = -1, -1 # height, width previous
# 循环9次,分别处理9张图像。
for i in range(9):
# 选择当前要处理的图像的标签,如果是第一张图像,则使用 labels ,否则从 labels["mix_labels"] 中选择。
labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
# Load image
# 加载当前图像。
img = labels_patch["img"]
# 获取当前图像的尺寸,并从 labels_patch 中移除这个键值对。
h, w = labels_patch.pop("resized_shape")
# Place img in img9
# 如果是第一张图像(中心图像),则创建一个基础图像 img9 ,并设置中心图像的坐标 c 。
if i == 0: # center
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
h0, w0 = h, w
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
# elif i in [1, 2, 3, 4, 5, 6, 7, 8] : 对于其他8张图像,根据它们在3x3网格中的位置设置坐标 c 。
elif i == 1: # top
c = s, s - h, s + w, s
elif i == 2: # top right
c = s + wp, s - h, s + wp + w, s
elif i == 3: # right
c = s + w0, s, s + w0 + w, s + h
elif i == 4: # bottom right
c = s + w0, s + hp, s + w0 + w, s + hp + h
elif i == 5: # bottom
c = s + w0 - w, s + h0, s + w0, s + h0 + h
elif i == 6: # bottom left
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
elif i == 7: # left
c = s - w, s + h0 - h, s, s + h0
elif i == 8: # top left
c = s - w, s + h0 - hp - h, s, s + h0 - hp
# 获取 c 中的前两个值,即 x 和 y 坐标的偏移量。
padw, padh = c[:2]
# 计算分配的坐标,确保它们不小于0。
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
# Image
# 将当前图像放置到 img9 中对应的位置。
img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]
# 更新前一次迭代的高和宽。
hp, wp = h, w # height, width previous for next iteration
# Labels assuming imgsz*2 mosaic size
# 更新当前图像的标签,考虑边界偏移。
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
# 将更新后的标签添加到 mosaic_labels 列表中。
mosaic_labels.append(labels_patch)
# 将所有图像的标签合并为最终的标签。
final_labels = self._cat_labels(mosaic_labels)
# 从 img9 中裁剪出最终的图像,去除边界。
final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
# 返回最终的标签。
return final_labels
# 这个方法的作用是将9张图像合并为一个3x3的马赛克图像,并更新标签信息。它首先创建一个基础图像,然后将其他八张图像放置在基础图像的周围。最后,它更新标签信息,并返回包含最终图像和标签的字典。
# 这段代码定义了一个名为 _update_labels 的静态方法,它用于更新图像标签,考虑到图像的填充(padding)。这个方法不依赖于类的实例(因为它是静态的)。
# @staticmethod 这是一个装饰器,表示 _update_labels 是一个静态方法,它不会自动接收类实例( self )或类( cls )作为第一个参数。
@staticmethod
# 定义了这个方法,接受三个参数。
# 1.labels :包含图像标签信息的字典。
# 2.padw :水平方向上的填充宽度。
# 3.padh :垂直方向上的填充高度。
def _update_labels(labels, padw, padh):
# 使用填充值更新标签坐标。
# 此方法通过添加填充值来调整标签中对象实例的边界框坐标。如果坐标之前已标准化,它还会对其进行反标准化。
"""
Updates label coordinates with padding values.
This method adjusts the bounding box coordinates of object instances in the labels by adding padding
values. It also denormalizes the coordinates if they were previously normalized.
Args:
labels (Dict): A dictionary containing image and instance information.
padw (int): Padding width to be added to the x-coordinates.
padh (int): Padding height to be added to the y-coordinates.
Returns:
(Dict): Updated labels dictionary with adjusted instance coordinates.
Examples:
>>> labels = {"img": np.zeros((100, 100, 3)), "instances": Instances(...)}
>>> padw, padh = 50, 50
>>> updated_labels = Mosaic._update_labels(labels, padw, padh)
"""
# 从 labels 字典中获取图像的高度( nh )和宽度( nw )。
nh, nw = labels["img"].shape[:2]
# 将标签中的边界框(bounding boxes)转换为 xyxy 格式,即每个边界框由四个值表示:(x1, y1, x2, y2),分别代表边界框左上角和右下角的坐标。
labels["instances"].convert_bbox(format="xyxy")
# 将边界框的坐标从归一化值转换为实际的像素坐标。这里, nw 和 nh 分别代表图像的宽度和高度,用于将 归一化的坐标 转换回相对于图像尺寸的 实际坐标 。denormalize 方法需要宽度和高度的顺序。
labels["instances"].denormalize(nw, nh)
# 为边界框添加填充。这意味着边界框的坐标将根据图像边缘的填充进行调整,以确保边界框仍然准确地描述了图像中对象的位置。
labels["instances"].add_padding(padw, padh)
# 返回更新后的 labels 字典,其中包含了调整后的边界框坐标。
return labels
# 这个方法的作用是确保在图像进行填充操作后,边界框的坐标能够正确地反映对象在图像中的位置。这对于目标检测任务尤其重要,因为模型需要准确地定位和识别图像中的对象。通过这个方法,可以确保即使图像尺寸发生变化(例如,添加填充),边界框的坐标也能相应地进行调整。
# 这段代码定义了一个名为 _cat_labels 的方法,它用于合并多个图像的标签信息,以创建一个包含所有标签的单一字典。这个方法接受一个参数 mosaic_labels ,这是一个包含多个图像标签的列表。
# 定义了这个方法,接受两个参数。
# 1.self :类的实例。
# 2.mosaic_labels :包含多个图像标签的列表。
def _cat_labels(self, mosaic_labels):
# 连接并处理马赛克增强的标签。
# 此方法组合马赛克增强中使用的多幅图像的标签,将实例剪辑到马赛克边框,并删除零面积框。
"""
Concatenates and processes labels for mosaic augmentation.
This method combines labels from multiple images used in mosaic augmentation, clips instances to the
mosaic border, and removes zero-area boxes.
Args:
mosaic_labels (List[Dict]): A list of label dictionaries for each image in the mosaic.
Returns:
(Dict): A dictionary containing concatenated and processed labels for the mosaic image, including:
- im_file (str): File path of the first image in the mosaic.
- ori_shape (Tuple[int, int]): Original shape of the first image.
- resized_shape (Tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2).
- cls (np.ndarray): Concatenated class labels.
- instances (Instances): Concatenated instance annotations.
- mosaic_border (Tuple[int, int]): Mosaic border size.
- texts (List[str], optional): Text labels if present in the original labels.
Examples:
>>> mosaic = Mosaic(dataset, imgsz=640)
>>> mosaic_labels = [{"cls": np.array([0, 1]), "instances": Instances(...)} for _ in range(4)]
>>> result = mosaic._cat_labels(mosaic_labels)
>>> print(result.keys())
dict_keys(['im_file', 'ori_shape', 'resized_shape', 'cls', 'instances', 'mosaic_border'])
"""
# 如果 mosaic_labels 列表为空,则返回一个空字典。
if len(mosaic_labels) == 0:
return {}
# 初始化两个空列表,用于存储所有图像的 类别标签 和 实例标签 。
cls = []
instances = []
# 计算马赛克图像的新尺寸,这里假设马赛克图像的尺寸是单个图像尺寸的两倍。
imgsz = self.imgsz * 2 # mosaic imgsz
# 遍历 mosaic_labels 列表中的每个标签字典。
for labels in mosaic_labels:
# 将每个图像的 类别标签 和 实例标签 添加到相应的列表中。
cls.append(labels["cls"])
instances.append(labels["instances"])
# Final labels
# 创建一个新的字典 final_labels ,用于存储合并后的标签信息。
final_labels = {
# 将第一个图像的文件名作为最终的文件名。
"im_file": mosaic_labels[0]["im_file"],
# 将第一个图像的原始尺寸作为最终的原始尺寸。
"ori_shape": mosaic_labels[0]["ori_shape"],
# 设置最终的尺寸为马赛克图像的尺寸。
"resized_shape": (imgsz, imgsz),
# 使用 numpy 的 concatenate 函数将所有类别标签合并为一个数组。
"cls": np.concatenate(cls, 0),
# 使用 Instances 类的 concatenate 方法将所有实例标签合并为一个数组。
"instances": Instances.concatenate(instances, axis=0),
# 将马赛克的边界信息添加到最终标签中。
"mosaic_border": self.border,
}
# 将所有实例标签的边界框裁剪到马赛克图像的尺寸内。
final_labels["instances"].clip(imgsz, imgsz)
# 移除所有面积为零的边界框,并返回一个布尔数组,表示哪些边界框是有效的。
good = final_labels["instances"].remove_zero_area_boxes()
# 根据 good 数组,只保留有效的类别标签。
final_labels["cls"] = final_labels["cls"][good]
# 如果第一个图像的标签中包含文本信息,则将这些文本信息添加到最终标签中。
if "texts" in mosaic_labels[0]:
final_labels["texts"] = mosaic_labels[0]["texts"]
# 返回包含所有合并标签的 final_labels 字典。
return final_labels
# 这个方法的作用是将多个图像的标签信息合并为一个统一的标签字典,这对于处理马赛克图像特别有用,因为马赛克图像是由多个图像拼接而成的,需要将这些图像的标签合并以供后续处理。
6.class MixUp(BaseMixTransform):
python
# 这段代码定义了一个名为 MixUp 的类,它继承自 BaseMixTransform 。 MixUp 类用于实现一种数据增强技术,称为MixUp,它通过混合两张图像的像素来创建新的训练样本。
# 定义了一个名为 MixUp 的新类,它继承自 BaseMixTransform 类。
class MixUp(BaseMixTransform):
# 将 MixUp 增强应用于图像数据集。
# 此类实现了论文"mixup:超越经验风险最小化"(https://arxiv.org/abs/1710.09412) 中描述的 MixUp 增强技术。MixUp 使用随机权重组合两幅图像及其标签。
"""
Applies MixUp augmentation to image datasets.
This class implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk
Minimization" (https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight.
Attributes:
dataset (Any): The dataset to which MixUp augmentation will be applied.
pre_transform (Callable | None): Optional transform to apply before MixUp.
p (float): Probability of applying MixUp augmentation.
Methods:
get_indexes: Returns a random index from the dataset.
_mix_transform: Applies MixUp augmentation to the input labels.
Examples:
>>> from ultralytics.data.augment import MixUp
>>> dataset = YourDataset(...) # Your image dataset
>>> mixup = MixUp(dataset, p=0.5)
>>> augmented_labels = mixup(original_labels)
"""
# 这是 MixUp 类的构造函数,它接受三个参数。
# 1.dataset :数据集,可能是包含图像和标签的数据结构。
# 2.pre_transform :预处理变换,可以是None或一个变换函数。
# 3.p :应用MixUp变换的概率,默认为0.0。
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
"""
Initializes the MixUp augmentation object.
MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel
values and labels. This implementation is designed for use with the Ultralytics YOLO framework.
Args:
dataset (Any): The dataset to which MixUp augmentation will be applied.
pre_transform (Callable | None): Optional transform to apply to images before MixUp.
p (float): Probability of applying MixUp augmentation to an image. Must be in the range [0, 1].
Examples:
>>> from ultralytics.data.dataset import YOLODataset
>>> dataset = YOLODataset("path/to/data.yaml")
>>> mixup = MixUp(dataset, pre_transform=None, p=0.5)
"""
# 调用父类 BaseMixTransform 的构造函数,传递 dataset 、 pre_transform 和 p 参数。
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
# 定义了一个名为 get_indexes 的方法,用于随机选择一个索引用于MixUp。
def get_indexes(self):
"""
Get a random index from the dataset.
This method returns a single random index from the dataset, which is used to select an image for MixUp
augmentation.
Returns:
(int): A random integer index within the range of the dataset length.
Examples:
>>> mixup = MixUp(dataset)
>>> index = mixup.get_indexes()
>>> print(index)
42
"""
# 返回一个随机索引,用于从数据集中选择第二张图像。
return random.randint(0, len(self.dataset) - 1)
# 定义了一个名为 _mix_transform 的方法,用于执行MixUp变换。
def _mix_transform(self, labels):
"""
Applies MixUp augmentation to the input labels.
This method implements the MixUp augmentation technique as described in the paper
"mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412).
Args:
labels (Dict): A dictionary containing the original image and label information.
Returns:
(Dict): A dictionary containing the mixed-up image and combined label information.
Examples:
>>> mixer = MixUp(dataset)
>>> mixed_labels = mixer._mix_transform(labels)
"""
# 生成一个Beta分布的随机数作为MixUp比例,其中 alpha=beta=32.0 。
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
# 选择用于混合的第二张图像的标签。
labels2 = labels["mix_labels"][0]
# 将两张图像的像素按比例混合,并确保结果为无符号整型。
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
# 合并两张图像的实例标签。
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
# 合并两张图像的类别标签。
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
# 返回包含混合图像和合并标签的字典。
return labels
# MixUp 类的作用是通过对两张图像的像素进行加权平均来创建新的图像样本,同时合并这两张图像的标签,以增加数据集的多样性并提高模型的泛化能力。这种方法特别适用于深度学习中的目标检测和分类任务。
7.class RandomPerspective:
python
# 这段代码定义了一个名为 RandomPerspective 的类,它用于在图像上应用随机透视变换,这是一种数据增强技术,可以模拟从不同视角观察对象的效果。
# 定义了一个名为 RandomPerspective 的新类。
class RandomPerspective:
# 在图像和相应的注释上实现随机透视和仿射变换。
# 此类将随机旋转、平移、缩放、剪切和透视变换应用于图像及其相关的边界框、段和关键点。它可以用作对象检测和实例分割任务的增强管道的一部分。
# 方法:
# affine_transform :将仿射变换应用于输入图像。
# apply_bboxes :使用仿射矩阵变换边界框。
# apply_segments :变换段并生成新的边界框。
# apply_keypoints :使用仿射矩阵变换关键点。
# __call__ :将随机透视变换应用于图像和注释。
# box_candidates :根据大小和纵横比过滤变换后的边界框。
"""
Implements random perspective and affine transformations on images and corresponding annotations.
This class applies random rotations, translations, scaling, shearing, and perspective transformations
to images and their associated bounding boxes, segments, and keypoints. It can be used as part of an
augmentation pipeline for object detection and instance segmentation tasks.
Attributes:
degrees (float): Maximum absolute degree range for random rotations.
translate (float): Maximum translation as a fraction of the image size.
scale (float): Scaling factor range, e.g., scale=0.1 means 0.9-1.1.
shear (float): Maximum shear angle in degrees.
perspective (float): Perspective distortion factor.
border (Tuple[int, int]): Mosaic border size as (x, y).
pre_transform (Callable | None): Optional transform to apply before the random perspective.
Methods:
affine_transform: Applies affine transformations to the input image.
apply_bboxes: Transforms bounding boxes using the affine matrix.
apply_segments: Transforms segments and generates new bounding boxes.
apply_keypoints: Transforms keypoints using the affine matrix.
__call__: Applies the random perspective transformation to images and annotations.
box_candidates: Filters transformed bounding boxes based on size and aspect ratio.
Examples:
>>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10)
>>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> labels = {"img": image, "cls": np.array([0, 1]), "instances": Instances(...)}
>>> result = transform(labels)
>>> transformed_image = result["img"]
>>> transformed_instances = result["instances"]
"""
# 这是 RandomPerspective 类的构造函数,它接受多个参数来配置透视变换的参数。
# 1.degrees :随机旋转的角度范围,默认为0.0,表示不旋转。
# 2.translate :随机平移的范围,以图像宽度和高度的比例表示,默认为0.1。
# 3.scale :随机缩放的范围,默认为0.5,表示图像可以被缩放到原始尺寸的50%。
# 4.shear :随机剪切变换的范围,默认为0.0,表示不剪切。
# 5.perspective :随机透视变换的范围,默认为0.0,表示不应用透视变换。
# 6.border :应用变换后图像的边界填充,默认为(0, 0),表示不填充。
# 7.pre_transform :预处理变换,可以是None或一个变换函数。
def __init__(
self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
):
# 使用变换参数初始化 RandomPerspective 对象。
# 此类对图像和相应的边界框、线段和关键点实现随机透视和仿射变换。变换包括旋转、平移、缩放和剪切。
"""
Initializes RandomPerspective object with transformation parameters.
This class implements random perspective and affine transformations on images and corresponding bounding boxes,
segments, and keypoints. Transformations include rotation, translation, scaling, and shearing.
Args:
degrees (float): Degree range for random rotations.
translate (float): Fraction of total width and height for random translation.
scale (float): Scaling factor interval, e.g., a scale factor of 0.5 allows a resize between 50%-150%.
shear (float): Shear intensity (angle in degrees).
perspective (float): Perspective distortion factor.
border (Tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right).
pre_transform (Callable | None): Function/transform to apply to the image before starting the random
transformation.
Examples:
>>> transform = RandomPerspective(degrees=10.0, translate=0.1, scale=0.5, shear=5.0)
>>> result = transform(labels) # Apply random perspective to labels
"""
# 将传入的 degrees 参数赋值给实例变量 self.degrees ,用于存储旋转角度范围。
self.degrees = degrees
# 将传入的 translate 参数赋值给实例变量 self.translate ,用于存储平移范围。
self.translate = translate
# 将传入的 scale 参数赋值给实例变量 self.scale ,用于存储缩放范围。
self.scale = scale
# 将传入的 shear 参数赋值给实例变量 self.shear ,用于存储剪切变换范围。
self.shear = shear
# 将传入的 perspective 参数赋值给实例变量 self.perspective ,用于存储透视变换范围。
self.perspective = perspective
# 将传入的 border 参数赋值给实例变量 self.border ,用于存储边界填充。
self.border = border # mosaic border
# 将传入的 pre_transform 参数赋值给实例变量 self.pre_transform ,用于存储预处理变换。
self.pre_transform = pre_transform
# RandomPerspective 类的作用是提供一种方法来随机应用透视变换,包括旋转、平移、缩放、剪切和透视变换,以增加图像数据集的多样性。这种数据增强技术可以帮助训练更鲁棒的机器学习模型,使模型能够更好地泛化到不同的视角和变换。
# 这段代码是 RandomPerspective 类中的一个方法 affine_transform 的定义,它用于对图像应用仿射变换和透视变换。
# 定义了一个名为 affine_transform 的方法,它接受两个参数。
# 1.self :类的实例。
# 2.img :要变换的图像。
# 3.border :变换后图像的边界。
def affine_transform(self, img, border):
# 应用以图像中心为中心的一系列仿射变换。
# 此函数对输入图像执行一系列几何变换,包括平移、透视变化、旋转、缩放和剪切。变换以特定顺序应用以保持一致性。
"""
Applies a sequence of affine transformations centered around the image center.
This function performs a series of geometric transformations on the input image, including
translation, perspective change, rotation, scaling, and shearing. The transformations are
applied in a specific order to maintain consistency.
Args:
img (np.ndarray): Input image to be transformed.
border (Tuple[int, int]): Border dimensions for the transformed image.
Returns:
(Tuple[np.ndarray, np.ndarray, float]): A tuple containing:
- np.ndarray: Transformed image.
- np.ndarray: 3x3 transformation matrix.
- float: Scale factor applied during the transformation.
Examples:
>>> import numpy as np
>>> img = np.random.rand(100, 100, 3)
>>> border = (10, 10)
>>> transformed_img, matrix, scale = affine_transform(img, border)
"""
# Center
# 创建一个3x3的单位矩阵 C ,用于存储中心点平移矩阵。
C = np.eye(3, dtype=np.float32)
# 将图像中心移动到原点。
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
# Perspective
# 创建一个3x3的单位矩阵 P ,用于存储透视变换矩阵。
P = np.eye(3, dtype=np.float32)
# 随机生成透视变换参数。
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
# Rotation and Scale
# 创建一个3x3的单位矩阵 R ,用于存储旋转和缩放矩阵。
R = np.eye(3, dtype=np.float32)
# 随机生成旋转角度。
a = random.uniform(-self.degrees, self.degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
# 随机生成缩放比例。
s = random.uniform(1 - self.scale, 1 + self.scale)
# s = 2 ** random.uniform(-scale, scale)
# cv2.getRotationMatrix2D(center, angle, scale)
# cv2.getRotationMatrix2D() 是 OpenCV 库中的一个函数,用于生成二维旋转矩阵。这个函数在图像处理和计算机视觉任务中非常有用,尤其是在需要旋转图像时。
# 参数 :
# center :旋转中心点,通常是一个 (x, y) 的二元组,表示图像中旋转轴的中心点。
# angle :旋转角度,单位是度。正数表示逆时针旋转,负数表示顺时针旋转。
# scale :缩放因子。当 scale 等于 1 时,表示没有缩放,图像大小不变;大于 1 表示放大;小于 1 表示缩小。
# 返回值 :
# 函数返回一个 2x3 的仿射变换矩阵,该矩阵可以用于 cv2.warpAffine() 或 cv2.warpPerspective() 函数来对图像进行旋转。
# 矩阵形式 :
# 返回的旋转矩阵 M 的形式如下 :
# [ cos(θ) -sin(θ) tx ]
# [ sin(θ) cos(θ) ty ]
# 其中, θ 是旋转角度, tx 和 ty 是平移量。
# 由于 cv2.getRotationMatrix2D() 生成的 tx 和 ty 实际上是为了确保图像围绕指定的中心点旋转而不超出边界,它们并不是固定的,而是根据旋转中心和图像尺寸动态计算的。
# 使用OpenCV的 getRotationMatrix2D 函数生成旋转矩阵。
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
# 创建一个3x3的单位矩阵 S ,用于存储剪切变换矩阵。
S = np.eye(3, dtype=np.float32)
# 随机生成剪切变换参数。
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
# Translation
# 创建一个3x3的单位矩阵 T ,用于存储平移矩阵。
T = np.eye(3, dtype=np.float32)
# 随机生成平移参数。
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
# Combined rotation matrix
# 将所有变换矩阵相乘,得到最终的变换矩阵 M 。
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
# Affine image
# 检查是否需要应用变换。
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
if self.perspective:
# 应用透视变换,并设置边界值为(114, 114, 114)。
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
else: # affine
# # 应用仿射变换,并设置边界值为(114, 114, 114)。
img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
# 返回变换后的图像、变换矩阵和缩放比例。
return img, M, s
# 这个方法的作用是对图像进行随机的仿射变换和透视变换,包括旋转、缩放、剪切和平移,以增加图像数据集的多样性。这种数据增强技术可以帮助训练更鲁棒的机器学习模型,使模型能够更好地泛化到不同的视角和变换。
# 仿射变换(Affine Transformation)和透视变换(Perspective Transformation)是两种不同的几何变换,它们在图像处理和计算机视觉中有着广泛的应用。以下是它们之间的主要区别 :
# 仿射变换 :
# 仿射变换是一种线性变换,它保持了图形的"直性"和"平行性"。这意味着,经过仿射变换后,直线仍然保持为直线,且平行线依然保持平行。仿射变换可以包括旋转、平移、缩放和剪切等操作。
# 线性 :仿射变换是线性变换,可以通过矩阵乘法来表示。
# 保持直线 :仿射变换下,直线映射为直线。
# 保持平行性 :平行线在仿射变换后仍然保持平行。
# 二维矩阵 :仿射变换可以用一个2x3的矩阵来表示。
# 透视变换 :
# 透视变换是一种更为复杂的变换,它模拟了三维空间中的透视效果。在透视变换中,平行线可能会相交于一点(消失点),这种变换可以模拟相机镜头的透视效果,使得远处的物体看起来更小。
# 非线性 :透视变换是一种非线性变换,涉及到除法操作。
# 不保持平行性 :平行线在透视变换后可能会相交。
# 三维矩阵 :透视变换需要一个3x3的矩阵来表示。
# 消失点 :在透视变换中,平行线可以相交于一点,这是由于视角的原因造成的。
# 应用场景 :
# 仿射变换 :常用于图像的校正、对齐、旋转和缩放等操作,因为它可以保持物体的形状和方向。
# 透视变换 :常用于模拟相机视角的变换,如在街景地图中校正图像,或者在计算机图形学中创建三维效果。
# 总结 :
# 仿射变换是一类特殊的线性变换,它保持了图形的线性结构,而透视变换则是一种更一般的变换,可以模拟真实的透视效果。在实际应用中,选择哪种变换取决于具体的需求和预期的效果。在图像处理中,仿射变换通常更容易实现,而透视变换则提供了更丰富的视觉效果。
# 这段代码定义了一个名为 apply_bboxes 的方法,它用于将仿射变换或透视变换应用到边界框(bounding boxes)上。
# 定义了一个名为 apply_bboxes 的方法,它接受两个参数。
# 1.self :类的实例。
# 2.bboxes :边界框的数组。
# 3. :变换矩阵。
def apply_bboxes(self, bboxes, M):
# 将仿射变换应用于边界框。
# 此函数使用提供的变换矩阵将仿射变换应用于一组边界框。
"""
Apply affine transformation to bounding boxes.
This function applies an affine transformation to a set of bounding boxes using the provided
transformation matrix.
Args:
bboxes (torch.Tensor): Bounding boxes in xyxy format with shape (N, 4), where N is the number
of bounding boxes.
M (torch.Tensor): Affine transformation matrix with shape (3, 3).
Returns:
(torch.Tensor): Transformed bounding boxes in xyxy format with shape (N, 4).
Examples:
>>> bboxes = torch.tensor([[10, 10, 20, 20], [30, 30, 40, 40]])
>>> M = torch.eye(3)
>>> transformed_bboxes = apply_bboxes(bboxes, M)
"""
# 获取边界框的数量。
n = len(bboxes)
# 如果没有边界框,则直接返回。
if n == 0:
return bboxes
# 初始化一个形状为 (n * 4, 3) 的数组,用于存储变换后的坐标点。每个边界框有4个角点,因此总共有 n * 4 个点。
# 在这段代码中, xy = np.ones((n * 4, 3), dtype=bboxes.dtype) 创建了一个形状为 (n * 4, 3) 的 NumPy 数组,其中 n 是边界框的数量。
# 这个数组的每一行代表一个点的坐标,而每个点有三个坐标值,分别是 x 、 y 和一个额外的值,通常用于表示齐次坐标(homogeneous coordinates)。
# 齐次坐标是一种在二维或三维空间中表示点的方法,它通过添加一个额外的维度来简化某些几何变换的计算。
# 在二维空间中,一个点 (x, y) 可以表示为齐次坐标 (x, y, 1) 。这样做的好处是可以将旋转、缩放、平移等变换统一到一个线性变换矩阵中,通过矩阵乘法来实现。
# 在这个上下文中, xy 数组的第三列被初始化为 1,用于表示齐次坐标中的最后一个分量。这样,当应用仿射变换或透视变换时,可以通过矩阵乘法直接对这些点进行变换,而不需要对每个点的 x 和 y 坐标分别进行缩放。
# 变换后的坐标点将保持在第三列的值(在这个初始化中是 1)与原始值成比例,这允许在透视变换中正确地处理点的深度。
xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
# 将边界框的坐标复制到 xy 数组中。每个边界框的坐标顺序是 (x1, y1, x2, y2) ,然后重复一次以形成 (x1, y1, x2, y2, x1, y2, x2, y1) 。
# .reshape(n * 4, 2) :由于 bboxes 数组的形状是 (n, 4) ,其中 n 是边界框的数量,每个边界框有四个坐标值。通过上述操作,得到了一个形状为 (n, 8) 的数组。使用 reshape(n * 4, 2) 将其重塑为 (n * 4, 2) 的形状,其中 n * 4 表示所有边界框的角点总数, 2 表示每个角点有两个坐标值(x和y)。
# xy[:, :2] 表示取出 xy 数组中每个点的前两个坐标值(即x和y),并将它们替换为从 bboxes 中提取并重塑后的坐标值。
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
# 将变换矩阵 M 应用到坐标点上。 @ 是矩阵乘法运算符, M.T 是 M 的转置。
xy = xy @ M.T # transform
# 如果是透视变换,则对坐标点进行缩放;如果是仿射变换,则直接使用坐标点。然后,将结果重塑为 (n, 8) 的形状。
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
# Create new boxes
# 提取变换后的 x 和 y 坐标。
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
# 计算每个边界框的新坐标,即每个边界框的 最小 x 、 最小 y 、 最大 x 和 最大 y ,然后将它们连接起来形成新的边界框数组。
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
# 这个方法的作用是将给定的变换矩阵 M 应用到边界框上,返回变换后的新边界框数组。这个过程考虑了透视变换和仿射变换的不同,能够正确地处理边界框的变换。这对于图像处理和计算机视觉任务中的对象跟踪、目标检测等应用非常重要。
# 这段代码定义了一个名为 apply_segments 的方法,它用于将仿射变换或透视变换应用到图像的分割区域(segments)上。
# 定义了一个名为 apply_segments 的方法,它接受两个参数。
# 1.self :类的实例。
# 2.segments :分割区域的数组。
# 3.M :变换矩阵。
def apply_segments(self, segments, M):
"""
Apply affine transformations to segments and generate new bounding boxes.
This function applies affine transformations to input segments and generates new bounding boxes based on
the transformed segments. It clips the transformed segments to fit within the new bounding boxes.
Args:
segments (np.ndarray): Input segments with shape (N, M, 2), where N is the number of segments and M is the
number of points in each segment.
M (np.ndarray): Affine transformation matrix with shape (3, 3).
Returns:
(Tuple[np.ndarray, np.ndarray]): A tuple containing:
- New bounding boxes with shape (N, 4) in xyxy format.
- Transformed and clipped segments with shape (N, M, 2).
Examples:
>>> segments = np.random.rand(10, 500, 2) # 10 segments with 500 points each
>>> M = np.eye(3) # Identity transformation matrix
>>> new_bboxes, new_segments = apply_segments(segments, M)
"""
# 获取分割区域的数量 n 和每个分割区域中的点数 num 。
n, num = segments.shape[:2]
# 如果没有分割区域,则直接返回空列表和原始的 segments 。
if n == 0:
return [], segments
# 初始化一个形状为 (n * num, 3) 的数组,用于存储变换后的坐标点。每个点有三个坐标值,分别是 x 、 y 和一个额外的值,通常用于表示齐次坐标。
xy = np.ones((n * num, 3), dtype=segments.dtype)
# 将 segments 数组重塑,以便于处理。
segments = segments.reshape(-1, 2)
# segments 将分割区域的坐标复制到 xy 数组中。
xy[:, :2] = segments
# 将变换矩阵 M 应用到坐标点上。
xy = xy @ M.T # transform
# 如果是透视变换,则对坐标点进行缩放;如果是仿射变换,则直接使用坐标点。
xy = xy[:, :2] / xy[:, 2:3]
# 将变换后的坐标点重塑回原来的分割区域格式。
segments = xy.reshape(n, -1, 2)
# 对于每个变换后的分割区域,计算其边界框,并堆叠成一个数组。
# def segment2box(segment, width=640, height=640):
# -> 它将一个分割掩码(segment)转换为一个边界框(bounding box)。计算边界框的坐标,如果存在有效的 x 坐标(即 x 数组不为空),则返回一个包含 (x_min, y_min, x_max, y_max) 的数组;否则返回一个长度为 4 的零数组。
# -> return (np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(4, dtype=segment.dtype)) # xyxy
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
# 将变换后的分割区域的坐标限制在对应的边界框内。
segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3])
segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4])
# 返回变换后的边界框和分割区域。
return bboxes, segments
# 这个方法的作用是将给定的变换矩阵 M 应用到分割区域上,返回变换后的新边界框数组和新的分割区域数组。这个过程考虑了透视变换和仿射变换的不同,能够正确地处理分割区域的变换。这对于图像处理和计算机视觉任务中的图像分割、目标跟踪等应用非常重要。
# 这段代码定义了一个名为 apply_keypoints 的方法,它用于将仿射变换或透视变换应用到图像的关键点(keypoints)上。
# 定义了一个名为 apply_keypoints 的方法,它接受两个参数。
# 1.self :类的实例。
# 2.keypoints :关键点的数组。
# 3.M :变换矩阵。
def apply_keypoints(self, keypoints, M):
# 将仿射变换应用于关键点。
# 此方法使用提供的仿射变换矩阵变换输入关键点。它会在必要时处理透视重新缩放,并更新变换后超出图像边界的关键点的可见性。
"""
Applies affine transformation to keypoints.
This method transforms the input keypoints using the provided affine transformation matrix. It handles
perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image
boundaries after transformation.
Args:
keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances,
17 is the number of keypoints per instance, and 3 represents (x, y, visibility).
M (np.ndarray): 3x3 affine transformation matrix.
Returns:
(np.ndarray): Transformed keypoints array with the same shape as input (N, 17, 3).
Examples:
>>> random_perspective = RandomPerspective()
>>> keypoints = np.random.rand(5, 17, 3) # 5 instances, 17 keypoints each
>>> M = np.eye(3) # Identity transformation
>>> transformed_keypoints = random_perspective.apply_keypoints(keypoints, M)
"""
# 获取关键点的数量 n 和每个关键点的维度 nkpt 。
n, nkpt = keypoints.shape[:2]
# 如果没有关键点,则直接返回原始的 keypoints 。
if n == 0:
return keypoints
# 初始化一个形状为 (n * nkpt, 3) 的数组,用于存储变换后的坐标点。每个点有三个坐标值,分别是 x 、 y 和一个额外的值,通常用于表示齐次坐标。
xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
# 提取关键点的可见性标志,并重塑为 (n * nkpt, 1) 的形状。
visible = keypoints[..., 2].reshape(n * nkpt, 1)
# 将关键点的坐标复制到 xy 数组中。
xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
# 将变换矩阵 M 应用到坐标点上。
xy = xy @ M.T # transform
# 如果是透视变换,则对坐标点进行缩放;如果是仿射变换,则直接使用坐标点。
xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
# 创建一个掩码 out_mask ,用于标记那些移出图像边界的关键点。
out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
# 将移出图像边界的关键点的可见性标志设置为0,表示它们不再可见。
visible[out_mask] = 0
# 将变换后的关键点坐标和可见性标志合并,并重塑回原来的格式。
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
# 这个方法的作用是将给定的变换矩阵 M 应用到关键点上,返回变换后的新关键点数组。这个过程考虑了透视变换和仿射变换的不同,能够正确地处理关键点的变换,并且能够处理关键点移出图像边界的情况。这对于图像处理和计算机视觉任务中的人脸关键点检测、姿态估计等应用非常重要。
# 这段代码定义了一个类的 __call__ 方法,它通常用于实现类的实例作为函数调用时的行为。在这个特定的例子中, __call__ 方法被用来应用一系列的图像变换和更新操作到图像及其对应的标签上。
# 这个方法接受一个参数。
# 1.labels :它包含了图像的标签信息。
def __call__(self, labels):
# 将随机透视和仿射变换应用于图像及其相关标签。
# 此方法对输入图像执行一系列变换,包括旋转、平移、缩放、剪切和透视失真,并相应地调整相应的边界框、段和关键点。
"""
Applies random perspective and affine transformations to an image and its associated labels.
This method performs a series of transformations including rotation, translation, scaling, shearing,
and perspective distortion on the input image and adjusts the corresponding bounding boxes, segments,
and keypoints accordingly.
Args:
labels (Dict): A dictionary containing image data and annotations.
Must include:
'img' (ndarray): The input image.
'cls' (ndarray): Class labels.
'instances' (Instances): Object instances with bounding boxes, segments, and keypoints.
May include:
'mosaic_border' (Tuple[int, int]): Border size for mosaic augmentation.
Returns:
(Dict): Transformed labels dictionary containing:
- 'img' (np.ndarray): The transformed image.
- 'cls' (np.ndarray): Updated class labels.
- 'instances' (Instances): Updated object instances.
- 'resized_shape' (Tuple[int, int]): New image shape after transformation.
Examples:
>>> transform = RandomPerspective()
>>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> labels = {
... "img": image,
... "cls": np.array([0, 1, 2]),
... "instances": Instances(bboxes=np.array([[10, 10, 50, 50], [100, 100, 150, 150]])),
... }
>>> result = transform(labels)
>>> assert result["img"].shape[:2] == result["resized_shape"]
"""
# 检查是否存在预处理变换,并且标签中没有 mosaic_border 时,应用预处理变换。
if self.pre_transform and "mosaic_border" not in labels:
# 应用预处理变换。
labels = self.pre_transform(labels)
# 从标签中移除 ratio_pad 键,如果它存在。
labels.pop("ratio_pad", None) # do not need ratio pad
# 获取图像数据。
img = labels["img"]
# 获取类别标签。
cls = labels["cls"]
# 获取实例标签,并从 labels 字典中移除。
instances = labels.pop("instances")
# Make sure the coord formats are right
# 确保边界框格式是 xyxy 。
instances.convert_bbox(format="xyxy")
# 将边界框坐标从归一化值转换为实际像素坐标。denormalize 方法需要宽度和高度的顺序。
instances.denormalize(*img.shape[:2][::-1])
# 获取马赛克边界,如果不存在则使用默认值。
border = labels.pop("mosaic_border", self.border)
# 计算变换后图像的新尺寸。
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
# M is affine matrix
# Scale for func:`box_candidates`
# 应用仿射变换,获取变换后的图像、变换矩阵和缩放比例。
img, M, scale = self.affine_transform(img, border)
# 应用变换到边界框上。
bboxes = self.apply_bboxes(instances.bboxes, M)
# 获取分割区域。
segments = instances.segments
# 获取关键点。
keypoints = instances.keypoints
# Update bboxes if there are segments.
# 如果存在分割区域,应用变换。
if len(segments):
# 应用变换到分割区域。
bboxes, segments = self.apply_segments(segments, M)
# 如果存在关键点,应用变换。
if keypoints is not None:
# 应用变换到关键点。
keypoints = self.apply_keypoints(keypoints, M)
# 创建新的实例对象。
# class Instances:
# -> def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
# Clip
# 裁剪新实例以适应图像尺寸。
new_instances.clip(*self.size)
# Filter instances
# 根据缩放比例调整旧实例的边界框。
# def scale(self, scale_w, scale_h, bbox_only=False): -> scale 的方法,它是 Instances 类的一个成员函数。这个方法用于对实例中的边界框、分割和关键点进行缩放操作。
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# Make the bboxes have the same scale with new_bboxes
# 找到与新边界框重叠面积超过阈值的旧边界框。
i = self.box_candidates(
box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
)
# 更新标签中的实例。
labels["instances"] = new_instances[i]
# 更新类别标签。
labels["cls"] = cls[i]
# 更新图像数据。
labels["img"] = img
# 更新调整后的图像尺寸。
labels["resized_shape"] = img.shape[:2]
# 返回更新后的标签。
return labels
# 这个方法的作用是将一系列变换应用到图像及其标签上,包括预处理、仿射变换、边界框、分割区域和关键点的更新。这些操作通常用于数据增强,以提高模型的泛化能力。
# 这段代码定义了一个名为 box_candidates 的方法,它用于筛选满足特定条件的边界框候选。
# 定义了一个名为 box_candidates 的方法,个方法接受四个参数。
# 1.box1 和 2.box2 :是两个边界框的坐标。
# 3.wh_thr :是宽度和高度的阈值。
# 4.ar_thr :是宽高比的阈值。
# 5.area_thr :是面积的阈值。
# 6.eps :是用于数值稳定性的小常数。
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
# 根据大小和纵横比标准计算候选框以供进一步处理。
# 此方法比较增强前后的框,以确定它们是否满足指定的宽度、高度、纵横比和面积阈值。它用于过滤掉在增强过程中过度扭曲或缩小的框。
"""
Compute candidate boxes for further processing based on size and aspect ratio criteria.
This method compares boxes before and after augmentation to determine if they meet specified
thresholds for width, height, aspect ratio, and area. It's used to filter out boxes that have
been overly distorted or reduced by the augmentation process.
Args:
box1 (numpy.ndarray): Original boxes before augmentation, shape (4, N) where n is the
number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates.
box2 (numpy.ndarray): Augmented boxes after transformation, shape (4, N). Format is
[x1, y1, x2, y2] in absolute coordinates.
wh_thr (float): Width and height threshold in pixels. Boxes smaller than this in either
dimension are rejected.
ar_thr (float): Aspect ratio threshold. Boxes with an aspect ratio greater than this
value are rejected.
area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than
this value are rejected.
eps (float): Small epsilon value to prevent division by zero.
Returns:
(numpy.ndarray): Boolean array of shape (n,) indicating which boxes are candidates.
True values correspond to boxes that meet all criteria.
Examples:
>>> random_perspective = RandomPerspective()
>>> box1 = np.array([[0, 0, 100, 100], [0, 0, 50, 50]]).T
>>> box2 = np.array([[10, 10, 90, 90], [5, 5, 45, 45]]).T
>>> candidates = random_perspective.box_candidates(box1, box2)
>>> print(candidates)
[True True]
"""
# 计算第一个边界框 box1 的宽度和高度。
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
# 计算第二个边界框 box2 的宽度和高度。
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
# 计算 box2 的宽高比(aspect ratio),使用 np.maximum 来确保即使在 w2 或 h2 接近零时也不会导致除以零的错误, eps 是一个很小的常数,用于防止除以零。
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
# 返回一个布尔掩码,用于筛选满足以下条件的 box2 :
# w2 > wh_thr 和 h2 > wh_thr : box2 的宽度和高度都必须大于 wh_thr 。
# w2 * h2 / (w1 * h1 + eps) > area_thr : box2 与 box1 的面积比必须大于 area_thr 。
# ar < ar_thr : box2 的宽高比必须小于 ar_thr 。
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
# 这个方法的作用是在一组边界框中筛选出满足特定尺寸、面积和宽高比条件的候选框。这在目标检测和图像分割任务中非常有用,特别是在非极大值抑制(NMS)或匹配预测边界框与真实边界框时。通过这些条件,可以过滤掉那些与参考框 box1 差异过大或过小的候选框 box2 。
8.class RandomHSV:
python
# 这段代码定义了一个名为 RandomHSV 的类,它用于对图像进行随机HSV颜色空间变换,这是一种数据增强技术,可以增加图像的色彩多样性。
# 定义了一个名为 RandomHSV 的新类。
class RandomHSV:
# 随机调整图像的色调、饱和度和值 (HSV) 通道。
# 此类将随机 HSV 增强应用于由 hgain、sgain 和 vgain 设置的预定义限制内的图像。
# 方法:
# __call__ :将随机 HSV 增强应用于图像。
"""
Randomly adjusts the Hue, Saturation, and Value (HSV) channels of an image.
This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain.
Attributes:
hgain (float): Maximum variation for hue. Range is typically [0, 1].
sgain (float): Maximum variation for saturation. Range is typically [0, 1].
vgain (float): Maximum variation for value. Range is typically [0, 1].
Methods:
__call__: Applies random HSV augmentation to an image.
Examples:
>>> import numpy as np
>>> from ultralytics.data.augment import RandomHSV
>>> augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
>>> image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
>>> labels = {"img": image}
>>> augmented_labels = augmenter(labels)
>>> augmented_image = augmented_labels["img"]
"""
# 这是 RandomHSV 类的构造函数,它接受三个参数。
# 1.hgain :色调(Hue)变换的增益,默认为0.5。
# 2.sgain :饱和度(Saturation)变换的增益,默认为0.5。
# 3.vgain :亮度(Value)变换的增益,默认为0.5。
def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
# 初始化 RandomHSV 对象以进行随机 HSV(色调、饱和度、值)增强。
# 此类在指定限制内对图像的 HSV 通道应用随机调整。
"""
Initializes the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
This class applies random adjustments to the HSV channels of an image within specified limits.
Args:
hgain (float): Maximum variation for hue. Should be in the range [0, 1].
sgain (float): Maximum variation for saturation. Should be in the range [0, 1].
vgain (float): Maximum variation for value. Should be in the range [0, 1].
Examples:
>>> hsv_aug = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
>>> augmented_image = hsv_aug(image)
"""
# 将传入的参数赋值给类的实例变量。
self.hgain = hgain
self.sgain = sgain
self.vgain = vgain
# 定义了一个 __call__ 方法,允许类的实例像函数一样被调用。
def __call__(self, labels):
# 在预定义的限制内对图像应用随机 HSV 增强。
# 此方法通过随机调整其色调、饱和度和值 (HSV) 通道来修改输入图像。调整是在初始化期间由 hgain、sgain 和 vgain 设置的限制内进行的。
"""
Applies random HSV augmentation to an image within predefined limits.
This method modifies the input image by randomly adjusting its Hue, Saturation, and Value (HSV) channels.
The adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
Args:
labels (Dict): A dictionary containing image data and metadata. Must include an 'img' key with
the image as a numpy array.
Returns:
(None): The function modifies the input 'labels' dictionary in-place, updating the 'img' key
with the HSV-augmented image.
Examples:
>>> hsv_augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
>>> labels = {"img": np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)}
>>> hsv_augmenter(labels)
>>> augmented_img = labels["img"]
"""
# 从 labels 字典中获取图像数据。
img = labels["img"]
# 检查是否有增益不为零,如果有,则进行HSV变换。
if self.hgain or self.sgain or self.vgain:
# 生成三个随机数,分别对应色调、饱和度和亮度的增益,并调整范围到[0, 2]。
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
# 将图像从BGR颜色空间转换到HSV颜色空间,并分离出色调、饱和度和亮度通道。
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
# 获取图像的数据类型。
dtype = img.dtype # uint8
# 创建一个0到255的数组,用于构建查找表(LUT)。
x = np.arange(0, 256, dtype=r.dtype)
# 创建三个查找表,分别对应 色调 、 饱和度 和 亮度 的变换。
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
# cv2.LUT(src, lut[, dst])
# cv2.LUT() 是 OpenCV 库中的一个函数,用于执行查找表(Look-Up Table,LUT)变换。这个函数将输入图像的每个像素值替换为查找表中对应的新像素值。
# 参数 :
# src :输入数组,必须是8位元素的数组,即数据类型为 np.uint8 。
# lut :查找表,包含256个元素。对于多通道输入数组,查找表可以是单通道的(在这种情况下,相同的表用于所有通道),或者与输入数组具有相同数量的通道。
# dst :输出数组,其大小和通道数与 src 相同,深度与 lut 相同。如果未指定 dst ,则函数会自动创建一个与 src 大小和通道数相同的输出数组。
# 工作原理 :
# cv2.LUT() 函数通过查找表中的值填充输出数组。
# 使用查找表应用色调、饱和度和亮度的变换。
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
# 将变换后的HSV图像转换回BGR颜色空间,直接修改 img 数据。
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
# 返回更新后的 labels 字典。
return labels
# 这个方法的作用是对图像进行随机的色调、饱和度和亮度调整,以增加图像的色彩多样性,这有助于模型在训练时对颜色变化的鲁棒性。通过调整HSV值,可以模拟不同光照和色彩条件下的图像,从而增强模型的泛化能力。
9.class RandomFlip:
python
# 这段代码定义了一个名为 RandomFlip 的类,它用于以一定的概率对图像进行随机翻转,可以是水平翻转或垂直翻转。这种数据增强技术有助于模型学习时对图像方向的不变性。
# 定义了一个名为 RandomFlip 的新类。
class RandomFlip:
# 以给定概率对图像应用随机水平或垂直翻转。
# 此类执行随机图像翻转并更新相应的实例注释,例如边界框和关键点。
# 方法:
# __call__ :将随机翻转变换应用于图像及其注释。
"""
Applies a random horizontal or vertical flip to an image with a given probability.
This class performs random image flipping and updates corresponding instance annotations such as
bounding boxes and keypoints.
Attributes:
p (float): Probability of applying the flip. Must be between 0 and 1.
direction (str): Direction of flip, either 'horizontal' or 'vertical'.
flip_idx (array-like): Index mapping for flipping keypoints, if applicable.
Methods:
__call__: Applies the random flip transformation to an image and its annotations.
Examples:
>>> transform = RandomFlip(p=0.5, direction="horizontal")
>>> result = transform({"img": image, "instances": instances})
>>> flipped_image = result["img"]
>>> flipped_instances = result["instances"]
"""
# 这是 RandomFlip 类的构造函数,它接受三个参数。
# 1.p :翻转的概率,默认为0.5。
# 2.direction :翻转的方向,可以是 "horizontal" 或 "vertical",默认为 "horizontal"。
# 3.flip_idx :用于关键点翻转时的索引映射,如果提供,则在水平翻转时使用。
def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
# 使用概率和方向初始化 RandomFlip 类。
# 此类以给定概率对图像应用随机水平或垂直翻转。它还会相应地更新任何实例(边界框、关键点等)。
"""
Initializes the RandomFlip class with probability and direction.
This class applies a random horizontal or vertical flip to an image with a given probability.
It also updates any instances (bounding boxes, keypoints, etc.) accordingly.
Args:
p (float): The probability of applying the flip. Must be between 0 and 1.
direction (str): The direction to apply the flip. Must be 'horizontal' or 'vertical'.
flip_idx (List[int] | None): Index mapping for flipping keypoints, if any.
Raises:
AssertionError: If direction is not 'horizontal' or 'vertical', or if p is not between 0 and 1.
Examples:
>>> flip = RandomFlip(p=0.5, direction="horizontal")
>>> flip = RandomFlip(p=0.7, direction="vertical", flip_idx=[1, 0, 3, 2, 5, 4])
"""
# 断言语句,确保 direction 参数是支持的方向,且 p 参数在 [0, 1] 范围内。
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}" # 支持方向 `水平` 或 `垂直`,得到{direction}。
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." # 概率应该在 [0, 1] 范围内,但得到的是 {p}。
# 将传入的参数赋值给类的实例变量。
self.p = p
self.direction = direction
self.flip_idx = flip_idx
# 定义了一个 __call__ 方法,允许类的实例像函数一样被调用。
def __call__(self, labels):
# 对图像应用随机翻转,并相应地更新任何实例(如边界框或关键点)。
# 此方法根据初始化的概率和方向随机水平或垂直翻转输入图像。它还会更新相应的实例(边界框、关键点)以匹配翻转的图像。
# 参数:
# labels (Dict):包含以下键的字典:
# 'img' (numpy.ndarray):要翻转的图像。
# 'instances' (ultralytics.utils.instance.Instances):包含边界框和可选关键点的对象。
"""
Applies random flip to an image and updates any instances like bounding boxes or keypoints accordingly.
This method randomly flips the input image either horizontally or vertically based on the initialized
probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to
match the flipped image.
Args:
labels (Dict): A dictionary containing the following keys:
'img' (numpy.ndarray): The image to be flipped.
'instances' (ultralytics.utils.instance.Instances): An object containing bounding boxes and
optionally keypoints.
Returns:
(Dict): The same dictionary with the flipped image and updated instances:
'img' (numpy.ndarray): The flipped image.
'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image.
Examples:
>>> labels = {"img": np.random.rand(640, 640, 3), "instances": Instances(...)}
>>> random_flip = RandomFlip(p=0.5, direction="horizontal")
>>> flipped_labels = random_flip(labels)
"""
# 从 labels 字典中获取图像数据。
img = labels["img"]
# 获取实例标签,并从 labels 字典中移除。
instances = labels.pop("instances")
# 确保边界框格式是 xywh (x坐标,y坐标,宽度,高度)。
instances.convert_bbox(format="xywh")
# 获取图像的高度和宽度。
h, w = img.shape[:2]
# 如果边界框坐标是归一化的,则将 h 和 w 设置为 1,否则使用图像的实际尺寸。
h = 1 if instances.normalized else h
w = 1 if instances.normalized else w
# Flip up-down
# 检查是否需要进行垂直翻转。
if self.direction == "vertical" and random.random() < self.p:
# 对图像进行垂直翻转。
img = np.flipud(img)
# 对边界框进行垂直翻转。
instances.flipud(h)
# 检查是否需要进行水平翻转。
if self.direction == "horizontal" and random.random() < self.p:
# 对图像进行水平翻转。
img = np.fliplr(img)
# 对边界框进行水平翻转。
instances.fliplr(w)
# For keypoints
# 如果提供了 flip_idx 并且存在关键点,则对关键点进行水平翻转。
if self.flip_idx is not None and instances.keypoints is not None:
# 根据 flip_idx 索引映射对关键点进行翻转。
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
# 更新 labels 字典中的图像数据和实例标签。
labels["img"] = np.ascontiguousarray(img)
labels["instances"] = instances
# 返回更新后的 labels 字典。
return labels
# 这个方法的作用是对图像进行随机的水平或垂直翻转,并相应地更新边界框和关键点的位置。这种数据增强技术有助于提高模型对图像方向变化的鲁棒性。
10.class LetterBox:
python
# 这段代码定义了一个名为 LetterBox 的类,它用于将图像调整到指定的尺寸,同时保持图像的长宽比不变,这个过程通常称为"Letterboxing"。这种调整常用于深度学习模型的输入预处理,以确保图像尺寸与模型期望的尺寸一致。
# 定义了一个名为 LetterBox 的新类。
class LetterBox:
# 调整图像大小和填充以进行检测、实例分割和姿势。
# 此类调整图像大小并填充图像以使其达到指定形状,同时保持宽高比。它还会更新相应的标签和边界框。
# 方法:
# __call__ :调整图像大小并填充图像,更新标签和边界框。
"""
Resize image and padding for detection, instance segmentation, pose.
This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates
corresponding labels and bounding boxes.
Attributes:
new_shape (tuple): Target shape (height, width) for resizing.
auto (bool): Whether to use minimum rectangle.
scaleFill (bool): Whether to stretch the image to new_shape.
scaleup (bool): Whether to allow scaling up. If False, only scale down.
stride (int): Stride for rounding padding.
center (bool): Whether to center the image or align to top-left.
Methods:
__call__: Resize and pad image, update labels and bounding boxes.
Examples:
>>> transform = LetterBox(new_shape=(640, 640))
>>> result = transform(labels)
>>> resized_img = result["img"]
>>> updated_instances = result["instances"]
"""
# 这是 LetterBox 类的构造函数,它接受多个参数来配置图像调整的行为。
# 1.new_shape :目标尺寸,是一个元组 (width, height) ,默认为 (640, 640) 。
# 2.auto :自动模式,如果为 True ,则自动调整图像尺寸以适应 new_shape 。
# 3.scaleFill :如果为 True ,则调整图像尺寸以完全填充 new_shape ,可能会改变长宽比。
# 4.scaleup :如果为 True ,则在需要时放大图像以适应 new_shape 。
# 5.center :如果为 True ,则将图像居中放置在 new_shape 中,否则放置在左上角。
# 6.stride :用于确保 new_shape 是 stride 的倍数,通常用于某些深度学习模型。
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
# 初始化 LetterBox 对象以调整图像大小和填充。
# 此类旨在调整图像大小和填充图像,以用于对象检测、实例分割和姿势估计任务。它支持各种调整大小模式,包括自动调整大小、缩放填充和边界化。
"""
Initialize LetterBox object for resizing and padding images.
This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation
tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing.
Args:
new_shape (Tuple[int, int]): Target size (height, width) for the resized image.
auto (bool): If True, use minimum rectangle to resize. If False, use new_shape directly.
scaleFill (bool): If True, stretch the image to new_shape without padding.
scaleup (bool): If True, allow scaling up. If False, only scale down.
center (bool): If True, center the placed image. If False, place image in top-left corner.
stride (int): Stride of the model (e.g., 32 for YOLOv5).
Attributes:
new_shape (Tuple[int, int]): Target size for the resized image.
auto (bool): Flag for using minimum rectangle resizing.
scaleFill (bool): Flag for stretching image without padding.
scaleup (bool): Flag for allowing upscaling.
stride (int): Stride value for ensuring image size is divisible by stride.
Examples:
>>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32)
>>> resized_img = letterbox(original_img)
"""
# 将传入的 new_shape 参数赋值给实例变量 self.new_shape 。
self.new_shape = new_shape
# 将传入的 auto 参数赋值给实例变量 self.auto 。
self.auto = auto
# 将传入的 scaleFill 参数赋值给实例变量 self.scaleFill 。
self.scaleFill = scaleFill
# 将传入的 scaleup 参数赋值给实例变量 self.scaleup 。
self.scaleup = scaleup
# 将传入的 stride 参数赋值给实例变量 self.stride 。
self.stride = stride
# 将传入的 center 参数赋值给实例变量 self.center 。
self.center = center # Put the image in the middle or top-left
# LetterBox 类的作用是提供一种方法来调整图像尺寸,使其适应模型的输入要求,同时尽可能保持图像的原始长宽比。这种调整方式对于确保图像内容不会因为尺寸调整而失真非常重要,特别是在目标检测和图像分类任务中。
# 这段代码定义了 LetterBox 类的 __call__ 方法,它用于将图像调整到指定的尺寸,同时保持图像的长宽比不变。
# 定义了一个 __call__ 方法,允许类的实例像函数一样被调用。这个方法接受两个参数。
# 1.labels :是包含图像标签信息的字典。
# 2.image :是图像数据。
def __call__(self, labels=None, image=None):
# 调整图像大小并填充图像以用于对象检测、实例分割或姿势估计任务。
# 此方法将信箱化应用于输入图像,这涉及调整图像大小,同时保持其纵横比并添加填充以适应新形状。它还会相应地更新任何关联的标签。
"""
Resizes and pads an image for object detection, instance segmentation, or pose estimation tasks.
This method applies letterboxing to the input image, which involves resizing the image while maintaining its
aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly.
Args:
labels (Dict | None): A dictionary containing image data and associated labels, or empty dict if None.
image (np.ndarray | None): The input image as a numpy array. If None, the image is taken from 'labels'.
Returns:
(Dict | Tuple): If 'labels' is provided, returns an updated dictionary with the resized and padded image,
updated labels, and additional metadata. If 'labels' is empty, returns a tuple containing the resized
and padded image, and a tuple of (ratio, (left_pad, top_pad)).
Examples:
>>> letterbox = LetterBox(new_shape=(640, 640))
>>> result = letterbox(labels={"img": np.zeros((480, 640, 3)), "instances": Instances(...)})
>>> resized_img = result["img"]
>>> updated_instances = result["instances"]
"""
# 如果 labels 参数为 None ,则初始化为空字典。
if labels is None:
labels = {}
# 根据 image 参数的值选择图像数据源。
img = labels.get("img") if image is None else image
# 获取图像的当前高度和宽度。
shape = img.shape[:2] # current shape [height, width]
# 从 labels 字典中获取 目标尺寸 ,如果不存在则使用类的实例变量 self.new_shape 。
new_shape = labels.pop("rect_shape", self.new_shape)
# 如果 new_shape 是整数,则将其转换为元组。
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
# 计算缩放比例,确保图像不会超出目标尺寸。
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# 如果 self.scaleup 为 False ,则限制缩放比例不超过 1,即不放大图像。
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
# 计算宽度和高度的缩放比例。
ratio = r, r # width, height ratios
# 计算缩放后的图像尺寸。
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
# 计算需要添加的填充宽度和高度。
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
# 如果是自动模式,则调整填充以适应 self.stride 的倍数。
if self.auto: # minimum rectangle
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
# 如果是填充模式,则不添加额外填充,且调整缩放比例。
elif self.scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
# 如果需要居中,则将填充平分到两侧。
if self.center:
dw /= 2 # divide padding into 2 sides
dh /= 2
# 如果图像需要调整尺寸,则使用 cv2.resize 进行缩放。
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
# 计算上下填充。
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
# 计算左右填充。
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
# 添加边界,使用常数边界值 (114, 114, 114) 。
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
) # add border
# 如果 labels 中存在 ratio_pad ,则更新它。
if labels.get("ratio_pad"):
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
# 如果 labels 不为空,则更新 labels 中的标签信息,并设置调整后的图像尺寸。
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
# 更新 labels 中的图像数据和调整后的尺寸。
labels["img"] = img
labels["resized_shape"] = new_shape
# 返回更新后的 labels 字典。
return labels
else:
# # 返回更新后的图像数据。
return img
# 这个方法的作用是将图像调整到指定的尺寸,同时保持图像的长宽比不变,通过添加边界(Letterboxing)来适应目标尺寸。这种调整方式对于确保图像内容不会因为尺寸调整而失真非常重要,特别是在目标检测和图像分类任务中。
# 这段代码定义了一个名为 _update_labels 的方法,它用于更新图像标签,以适应图像尺寸的变化。
# 这个方法接受四个参数。
# 1.labels :包含图像标签信息的字典。
# 2.ratio :缩放比例。
# 3.padw :水平方向的填充宽度。
# 4.padh :垂直方向的填充高度。
def _update_labels(self, labels, ratio, padw, padh):
# 在将信箱效果应用于图像后更新标签。
# 此方法修改标签中实例的边界框坐标,以考虑信箱效果期间应用的调整大小和填充。
"""
Updates labels after applying letterboxing to an image.
This method modifies the bounding box coordinates of instances in the labels
to account for resizing and padding applied during letterboxing.
Args:
labels (Dict): A dictionary containing image labels and instances.
ratio (Tuple[float, float]): Scaling ratios (width, height) applied to the image.
padw (float): Padding width added to the image.
padh (float): Padding height added to the image.
Returns:
(Dict): Updated labels dictionary with modified instance coordinates.
Examples:
>>> letterbox = LetterBox(new_shape=(640, 640))
>>> labels = {"instances": Instances(...)}
>>> ratio = (0.5, 0.5)
>>> padw, padh = 10, 20
>>> updated_labels = letterbox._update_labels(labels, ratio, padw, padh)
"""
# 确保边界框格式是 xyxy (x1, y1, x2, y2)。
labels["instances"].convert_bbox(format="xyxy")
# 将边界框坐标从归一化值转换为实际像素坐标。这里使用 [::-1] 来交换图像的高度和宽度,因为 denormalize 方法需要宽度和高度的顺序。
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
# 根据缩放比例调整边界框的大小。
labels["instances"].scale(*ratio)
# 根据填充宽度和高度调整边界框的位置。
labels["instances"].add_padding(padw, padh)
# 返回更新后的 labels 字典。
return labels
# 这个方法的作用是更新图像标签,使其与图像尺寸的变化保持一致。这在图像经过缩放和填充后尤为重要,因为边界框的位置和大小需要相应地调整,以确保它们仍然准确地描述图像中对象的位置。这对于目标检测和图像分割任务中的准确标注是必要的。
11.class CopyPaste:
python
# 这段代码定义了一个名为 CopyPaste 的类,它用于实现复制粘贴数据增强技术。这种技术涉及从一张图像中随机选择一个区域,并将其粘贴到同一张图像或另一张图像的随机位置。
# 定义了一个名为 CopyPaste 的新类。
class CopyPaste:
# 实现复制粘贴增强,如 https://arxiv.org/abs/2012.07177 中所述。
# 此类将复制粘贴增强应用于图像及其相应的实例。
# 方法:
# __call__ :将复制粘贴增强应用于给定的图像和实例。
"""
Implements Copy-Paste augmentation as described in https://arxiv.org/abs/2012.07177.
This class applies Copy-Paste augmentation on images and their corresponding instances.
Attributes:
p (float): Probability of applying the Copy-Paste augmentation. Must be between 0 and 1.
Methods:
__call__: Applies Copy-Paste augmentation to given image and instances.
Examples:
>>> copypaste = CopyPaste(p=0.5)
>>> augmented_labels = copypaste(labels)
>>> augmented_image = augmented_labels["img"]
"""
# 这是 CopyPaste 类的构造函数,它接受一个参数。
# 1.p :表示执行复制粘贴操作的概率,默认值为 0.5,即有 50% 的概率应用这种变换。
def __init__(self, p=0.5) -> None:
# 初始化 CopyPaste 增强对象。
# 此类实现了论文"简单的复制粘贴是实例分割的强大数据增强方法"(https://arxiv.org/abs/2012.07177) 中描述的复制粘贴增强。它以给定的概率将复制粘贴增强应用于图像及其相应的实例。
"""
Initializes the CopyPaste augmentation object.
This class implements the Copy-Paste augmentation as described in the paper "Simple Copy-Paste is a Strong Data
Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It applies the Copy-Paste
augmentation on images and their corresponding instances with a given probability.
Args:
p (float): The probability of applying the Copy-Paste augmentation. Must be between 0 and 1.
Attributes:
p (float): Stores the probability of applying the augmentation.
Examples:
>>> augment = CopyPaste(p=0.7)
>>> augmented_data = augment(original_data)
"""
# 将传入的 p 参数赋值给类的实例变量 self.p ,用于存储执行复制粘贴操作的概率。
self.p = p
# 这段代码定义了 CopyPaste 类的 __call__ 方法,它实现了复制粘贴数据增强技术。这个方法接受一个参数 labels ,它包含了图像的标签信息。
# 定义了一个 __call__ 方法,允许类的实例像函数一样被调用。
def __call__(self, labels):
# 将复制粘贴增强应用于图像及其实例。
# 参数:
# labels (Dict) : 包含以下内容的字典:
# - 'img' (np.ndarray) : 要增强的图像。
# - 'cls' (np.ndarray) : 实例的类标签。
# - 'instances' (ultralytics.engine.results.Instances) : 包含边界框、片段等的对象。
"""
Applies Copy-Paste augmentation to an image and its instances.
Args:
labels (Dict): A dictionary containing:
- 'img' (np.ndarray): The image to augment.
- 'cls' (np.ndarray): Class labels for the instances.
- 'instances' (ultralytics.engine.results.Instances): Object containing bounding boxes, segments, etc.
Returns:
(Dict): Dictionary with augmented image and updated instances under 'img', 'cls', and 'instances' keys.
Examples:
>>> labels = {"img": np.random.rand(640, 640, 3), "cls": np.array([0, 1, 2]), "instances": Instances(...)}
>>> augmenter = CopyPaste(p=0.5)
>>> augmented_labels = augmenter(labels)
"""
# 从 labels 字典中获取图像数据。
im = labels["img"]
# 获取类别标签。
cls = labels["cls"]
# 获取图像的高度和宽度。
h, w = im.shape[:2]
# 获取实例标签,并从 labels 字典中移除。
instances = labels.pop("instances")
# 确保边界框格式是 xyxy 。
instances.convert_bbox(format="xyxy")
# 将边界框坐标从归一化值转换为实际像素坐标。 denormalize 方法需要宽度和高度的顺序。
instances.denormalize(w, h)
# 检查是否需要执行复制粘贴操作(基于概率 self.p )以及是否存在分割区域。
if self.p and len(instances.segments):
# 获取图像的通道数。
_, w, _ = im.shape # height, width, channels
# 创建一个新的图像数组,用于存储粘贴的分割区域。
im_new = np.zeros(im.shape, np.uint8)
# Calculate ioa first then select indexes randomly
# 复制实例标签。
ins_flip = deepcopy(instances)
# 将复制的实例标签翻转。
ins_flip.fliplr(w)
# 计算复制的边界框和原始边界框之间的交集面积比(Intersection over Area, IoA)。
# def bbox_ioa(box1, box2, iou=False, eps=1e-7): -> 这个函数可以计算两个边界框的交集区域,并将其与第二个边界框的面积进行比较,以得到IoA值。返回IoA或IoU值, eps 用于防止除以零。 -> return inter_area / (area + eps)
ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
# 选择 IoA 小于 0.30 的索引,即复制的区域与原始区域重叠较少的情况。
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
# 获取满足条件的索引数量。
n = len(indexes)
# 随机选择一些索引进行复制粘贴操作。
for j in random.sample(list(indexes), k=round(self.p * n)):
# 将复制的类别标签添加到原始类别标签中。
cls = np.concatenate((cls, cls[[j]]), axis=0)
# 将复制的实例标签添加到原始实例标签中。
instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
# cv2.drawContours(image, contours, contourIdx, color, thickness=None, lineType=8, hierarchy=None)
# cv2.drawContours 是 OpenCV 库中的一个函数,它用于在图像上绘制轮廓。这个函数可以绘制简单的轮廓或复杂的轮廓,并且可以对轮廓进行填充。
# image :要绘制轮廓的图像。
# contours :轮廓的列表,其中每个轮廓都是一个点集的数组。
# contourIdx :要绘制的轮廓的索引。如果为 -1 ,则绘制所有轮廓。
# color :轮廓的颜色,以 BGR 格式指定。
# thickness :轮廓的线条粗细。如果为负值,则轮廓将被填充。
# lineType :轮廓的线条类型,例如 8 表示 cv2.LINE_8 , cv2.CV_AA 表示抗锯齿线条。
# hierarchy :轮廓的层次结构,用于指定轮廓之间的关系。
# 返回值 :
# 该函数不返回任何值,它直接在输入图像 image 上进行绘制。
# 注意事项 :
# 当 thickness 为负值时,例如 -1 ,轮廓将被填充。
# hierarchy 参数可以用来指定轮廓之间的父子关系,这对于处理具有嵌套结构的轮廓非常有用。
# cv2.drawContours 函数可以一次绘制多个轮廓,只需将它们作为列表传递给 contours 参数。
# 在新图像上绘制复制的分割区域。
# im_new :这是目标图像,轮廓将被绘制在这个图像上。这个图像应该是一个三通道的BGR格式图像。
# instances.segments[[j]].astype(np.int32) :这是要绘制的轮廓点集。 instances.segments 是一个包含多个轮廓的数组,其中每个轮廓都是一个点集。 [j] 表示选择第 j 个轮廓。 astype(np.int32) 将轮廓点的类型转换为32位整数,这是 cv2.drawContours 函数所要求的。
# -1 :这个参数指定绘制哪个轮廓。传入 -1 表示绘制所有轮廓,但如果 instances.segments 只包含一个轮廓,也可以这样使用。
# (1, 1, 1) :这个参数指定绘制轮廓的颜色。在这里, (1, 1, 1) 表示白色,因为 OpenCV 中的颜色格式是BGR,所以这里的 (1, 1, 1) 对应于BGR值都是255的白色。
# cv2.FILLED :这个参数指定轮廓的填充方式。 cv2.FILLED 表示轮廓内部将被填充颜色。如果不希望填充轮廓内部,可以使用 cv2.LINES 或 cv2.LINE_8 。
# 综上所述,这行代码的作用是在 im_new 图像上绘制 instances 中第 j 个轮廓,并将轮廓内部填充为白色。这种操作通常用于可视化检测到的对象,或者在数据增强过程中将一个对象复制并粘贴到图像的不同位置。
cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
# 对原始图像进行翻转。
result = cv2.flip(im, 1) # augment segments (flip left-right)
# 对新图像进行翻转,并转换为布尔数组。
# .astype(bool) :这个操作将 cv2.flip 函数的结果转换为布尔类型。在这种情况下, im_new 图像的像素值通常是 uint8 类型,其中值为 0 的像素点在转换为布尔类型时将变为 False ,非 0 的像素点将变为 True 。
# 这通常用于创建一个掩码,该掩码指示 im_new 中哪些像素是 "开启" 或 "非零" 的,这可以用于后续的图像处理操作,比如将 im_new 中的某些区域复制到另一个图像上。
i = cv2.flip(im_new, 1).astype(bool)
# 将翻转后的新图像中的分割区域复制到翻转后的原始图像中。
im[i] = result[i] # 是不是应该修改为 : im[i] = im_new[i] 。
# 更新 labels 字典中的图像数据。
labels["img"] = im
# 更新类别标签。
labels["cls"] = cls
# 更新实例标签。
labels["instances"] = instances
# 返回更新后的 labels 字典。
return labels
# 这个方法的作用是通过复制粘贴分割区域来增强图像数据,这有助于模型学习时对图像内容变化的鲁棒性,特别是在目标检测和图像分割任务中。通过模拟目标在不同位置出现的情况,可以提高模型的泛化能力。
12.class Albumentations:
python
# 这段代码是一个Python类的定义,名为 Albumentations ,它用于图像增强,特别是与 albumentations 库一起使用。
# 定义了一个名为 Albumentations 的类。
class Albumentations:
# Albumentations 用于图像增强的转换。
# 此类使用 Albumentations 库应用各种图像转换。它包括模糊、中值模糊、转换为灰度、对比度限制自适应直方图均衡 (CLAHE)、亮度和对比度的随机变化、RandomGamma 以及通过压缩降低图像质量等操作。
# 方法:
# __call__ :将 Albumentations 转换应用于输入标签。
"""
Albumentations transformations for image augmentation.
This class applies various image transformations using the Albumentations library. It includes operations such as
Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes
in brightness and contrast, RandomGamma, and image quality reduction through compression.
Attributes:
p (float): Probability of applying the transformations.
transform (albumentations.Compose): Composed Albumentations transforms.
contains_spatial (bool): Indicates if the transforms include spatial operations.
Methods:
__call__: Applies the Albumentations transformations to the input labels.
Examples:
>>> transform = Albumentations(p=0.5)
>>> augmented_labels = transform(labels)
Notes:
- The Albumentations package must be installed to use this class.
- If the package is not installed or an error occurs during initialization, the transform will be set to None.
- Spatial transforms are handled differently and require special processing for bounding boxes.
"""
# 定义了类的构造函数,它接受一个参数。
# 1.p :默认值为1.0。这个参数可能用于控制变换的概率。
def __init__(self, p=1.0):
# 为 YOLO bbox 格式的参数初始化 Albumentations 变换对象。
# 此类使用 Albumentations 库应用各种图像增强,包括模糊、中值模糊、转换为灰度、对比度限制自适应直方图均衡、亮度和对比度的随机变化、RandomGamma 以及通过压缩降低图像质量。
# 注意事项:
# - 需要 Albumentations 1.0.3 或更高版本。
# - 空间变换的处理方式不同,以确保 bbox 兼容性。
# - 默认情况下,某些变换的应用概率非常低(0.01)。
"""
Initialize the Albumentations transform object for YOLO bbox formatted parameters.
This class applies various image augmentations using the Albumentations library, including Blur, Median Blur,
conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and
contrast, RandomGamma, and image quality reduction through compression.
Args:
p (float): Probability of applying the augmentations. Must be between 0 and 1.
Attributes:
p (float): Probability of applying the augmentations.
transform (albumentations.Compose): Composed Albumentations transforms.
contains_spatial (bool): Indicates if the transforms include spatial transformations.
Raises:
ImportError: If the Albumentations package is not installed.
Exception: For any other errors during initialization.
Examples:
>>> transform = Albumentations(p=0.5)
>>> augmented = transform(image=image, bboxes=bboxes, class_labels=classes)
>>> augmented_image = augmented["image"]
>>> augmented_bboxes = augmented["bboxes"]
Notes:
- Requires Albumentations version 1.0.3 or higher.
- Spatial transforms are handled differently to ensure bbox compatibility.
- Some transforms are applied with very low probability (0.01) by default.
"""
# 将传入的参数 p 赋值给类的实例变量 p 。
self.p = p
# 初始化一个名为 transform 的实例变量,并将其设置为 None 。
self.transform = None
# 创建一个前缀字符串,用于日志输出, colorstr 是一个自定义函数,用于给字符串添加颜色。
# def colorstr(*input):
# -> 函数通过遍历 args 中的每个元素(颜色或样式),从 colors 字典中获取对应的ANSI转义序列,并将其与传入的 string 字符串连接起来。最后,它还会添加一个 colors["end"] 序列,用于重置终端的颜色和样式到默认状态。
# -> return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
prefix = colorstr("albumentations: ")
# 开始一个 try 块,用于捕获可能发生的异常。
try:
# 尝试导入 albumentations 库,并将其别名设置为 A 。
import albumentations as A
# 检查 albumentations 库的版本是否为1.0.3,如果不是,则抛出异常。
check_version(A.__version__, "1.0.3", hard=True) # version requirement
# List of possible spatial transforms 可能的空间变换列表。
# 定义了一个名为 spatial_transforms 的集合,包含了所有支持的空间变换(从提供的网页链接中获取)。
spatial_transforms = {
"Affine",
"BBoxSafeRandomCrop",
"CenterCrop",
"CoarseDropout",
"Crop",
"CropAndPad",
"CropNonEmptyMaskIfExists",
"D4",
"ElasticTransform",
"Flip",
"GridDistortion",
"GridDropout",
"HorizontalFlip",
"Lambda",
"LongestMaxSize",
"MaskDropout",
"MixUp",
"Morphological",
"NoOp",
"OpticalDistortion",
"PadIfNeeded",
"Perspective",
"PiecewiseAffine",
"PixelDropout",
"RandomCrop",
"RandomCropFromBorders",
"RandomGridShuffle",
"RandomResizedCrop",
"RandomRotate90",
"RandomScale",
"RandomSizedBBoxSafeCrop",
"RandomSizedCrop",
"Resize",
"Rotate",
"SafeRotate",
"ShiftScaleRotate",
"SmallestMaxSize",
"Transpose",
"VerticalFlip",
"XYMasking",
} # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms
# Transforms
# 定义了一个名为 T 的列表,包含了一系列的图像变换,例如模糊、中值模糊、灰度转换等。
# 这段代码定义了一个名为 T 的列表,其中包含了几个不同的图像增强变换(transforms),这些变换都是 albumentations 库提供的。每个变换后面都有一个 p 参数,它表示应用该变换的概率。
T = [
# 应用一个模糊效果, p=0.01 表示有 0.01 的概率应用这个变换。
A.Blur(p=0.01),
# 应用一个中值模糊效果, p=0.01 表示有 0.01 的概率应用这个变换。
A.MedianBlur(p=0.01),
# 将图像转换为灰度图, p=0.01 表示有 0.01 的概率应用这个变换。
A.ToGray(p=0.01),
# 应用对比度受限的自适应直方图均衡化(CLAHE), p=0.01 表示有 0.01 的概率应用这个变换。
A.CLAHE(p=0.01),
# 随机调整图像的亮度和对比度, p=0.0 表示这个变换不会被应用。
A.RandomBrightnessContrast(p=0.0),
# 随机调整图像的伽马值, p=0.0 表示这个变换不会被应用。
A.RandomGamma(p=0.0),
# 模拟图像压缩的效果, quality_lower=75 表示压缩质量的下限为 75%, p=0.0 表示这个变换不会被应用。
A.ImageCompression(quality_lower=75, p=0.0),
]
# 这些变换被设计为可以随机地应用于图像,以增加数据的多样性,这在机器学习中是一种常见的数据增强技术。通过设置不同的概率,可以控制每个变换被应用的频率。
# 在这个特定的例子中,除了 A.RandomBrightnessContrast 、 A.RandomGamma 和 A.ImageCompression 变换的概率为 0,不会应用外,其他变换都有很小的概率(1%)被应用。
# Compose transforms
# 检查 T 列表中的变换是否包含空间变换,并根据结果决定是否使用 A.Compose 的 bbox_params 参数。
# 这行代码是 Python 类中的一个表达式,用于确定一个名为 T 的变换列表中是否包含任何空间级变换(spatial-level transforms)。如果列表 T 中至少有一个变换是空间级变换,那么 self.contains_spatial 将被设置为 True ,否则为 False 。
# for transform in T : 这是一个生成器表达式,用于遍历列表 T 中的每个元素(变换)。每个元素 transform 都是一个变换实例。
# transform.class.name : 这部分获取当前变换实例 transform 的类名。例如,如果 transform 是 A.Blur 类的一个实例,那么 transform.__class__.__name__ 将返回字符串 "Blur" 。
# in spatial_transforms : 这个表达式检查变换的类名是否在 spatial_transforms 集合中。 spatial_transforms 是一个包含所有空间级变换类名的集合。
# any(...) : any 函数接受一个可迭代对象(在这个例子中是一个生成器表达式),如果可迭代对象中的任何元素为 True ,则返回 True ;如果所有元素都为 False ,则返回 False 。
# 这个表达式的目的是快速检查 T 列表中是否至少有一个变换是空间级变换。这是通过检查每个变换的类名是否在 spatial_transforms 集合中来实现的。如果找到至少一个匹配项, any 函数将返回 True ,并将这个值赋给 self.contains_spatial 。
self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T)
# A.Compose(transforms, bbox_params=None, keypoint_params=None, mask_params=None, p=1.0)
# albumentations.Compose 是 Albumentations 库中的一个函数,它用于组合多个图像变换(transforms)成一个复合变换(composition)。这个复合变换可以同时应用于图像和与之相关的标注(如边界框、关键点、掩码等)。
# 参数说明 :
# transforms : 一个包含多个变换的列表。每个变换可以是一个单独的变换类实例,也可以是一个变换序列。
# bbox_params : 一个 BboxParams 对象,用于定义如何应用变换到边界框。如果为 None ,则不会对边界框应用任何变换。
# keypoint_params : 一个 KeypointParams 对象,用于定义如何应用变换到关键点。如果为 None ,则不会对关键点应用任何变换。
# mask_params : 一个 MaskParams 对象,用于定义如何应用变换到掩码。如果为 None ,则不会对掩码应用任何变换。
# p : 一个浮点数,表示应用整个复合变换的概率。值范围从 0 到 1,其中 1 表示总是应用变换。
# 返回值 :
# 返回一个复合变换对象,可以被用于图像增强流程中。
# A.BboxParams(format=None, label_fields=None, min_area=0.0, min_visibility=0.0, check_each_transform=True)
# A.BboxParams() 是 Albumentations 库中用于定义边界框参数的类,一般用在 A.Compose() 实例的初始化参数中。
# 参数说明 :
# format : 边界框的格式。可以是 "coco" 、 "pascal_voc" 、 "albumentations" 或 "yolo" 。
# "coco" 格式 : [x_min, y_min, width, height] 。
# "pascal_voc" 格式 : [x_min, y_min, x_max, y_max] 。
# "albumentations" 格式 : 与 "pascal_voc" 相同,但归一化到 [0, 1] 范围内。
# "yolo" 格式 : [x, y, width, height] ,其中 x, y 指归一化的边界框中心, width, height 指归一化的边界框宽高。
# label_fields : 与边界框相关联的标签字段列表。如果边界框和类别标签写在一起,则不需要此参数。
# min_area : 边界框的最小面积(以像素为单位)。所有可见面积小于这个值的边界框将被移除。默认为 0.0 。
# min_visibility : 边界框保持在列表中的最小可见面积比。默认为 0.0 。
# check_each_transform : 如果为 True ,在每个双重变换后都会检查边界框是否满足限制。默认为 True 。
# 如果包含空间变换,则创建一个 A.Compose 对象,否则只创建一个没有 bbox_params 的 A.Compose 对象。
self.transform = (
A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
if self.contains_spatial
else A.Compose(T)
)
# 使用日志记录器输出当前应用的变换。
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
# 如果 albumentations 库没有安装,则捕获 ImportError 异常。
except ImportError: # package not installed, skip
# 如果捕获到 ImportError ,则不执行任何操作。
pass
# 捕获其他所有异常,并使用日志记录器记录异常信息。
except Exception as e:
LOGGER.info(f"{prefix}{e}")
# 这个类的主要功能是创建一个图像增强的管道,可以应用于图像数据,包括可能的空间变换。它还处理了 albumentations 库的版本检查和异常处理。
# 在图像处理和计算机视觉中,特别是在数据增强的上下文中,空间级变换(spatial-level transforms)和非空间级变换(pixel-level transforms)是两种不同类型的图像变换 :
# 空间级变换(Spatial-level transforms)
# 空间级变换是指那些会改变图像中物体的空间位置或尺寸的变换。这些变换不仅影响图像的像素值,还会影响图像中的标注信息,如边界框(bounding boxes)、掩码(masks)和关键点(keypoints)。空间级变换包括但不限于 :
# 平移(Translation) :移动图像中的内容。
# 旋转(Rotation) :围绕图像中心或某个点旋转图像。
# 缩放(Scaling) :改变图像的大小。
# 剪切(Shear) :对图像进行倾斜变换。
# 透视变换(Perspective Transformation) :模拟相机视角变化对图像的影响。
# 翻转(Flip) :水平或垂直翻转图像。
# 裁剪(Cropping) :从图像中裁剪出一部分区域。
# 这些变换会改变图像中物体的位置和尺寸,因此需要相应地更新边界框、掩码等标注信息,以保持与图像内容的一致性。
# 非空间级变换(Pixel-level transforms)
# 非空间级变换是指那些只改变图像像素值而不改变物体空间位置或尺寸的变换。这些变换通常只影响图像的外观,而不会影响标注信息。非空间级变换包括但不限于 :
# 颜色抖动(Color Jitter) :随机改变图像的亮度、对比度、饱和度等。
# 模糊(Blur) :对图像应用模糊效果,如高斯模糊。
# 噪声添加(Noise Addition) :向图像中添加随机噪声。
# 直方图均衡化(Histogram Equalization) :增强图像的对比度。
# 像素值缩放(Pixel Value Scaling) :改变像素值的范围,如归一化到 [0, 1]。
# 这些变换不会改变图像中物体的空间位置,因此不需要更新边界框、掩码等标注信息。
# 总结 :
# 空间级变换和非空间级变换的主要区别在于它们对图像中物体空间位置的影响。空间级变换需要更新标注信息以反映物体的新位置,而非空间级变换则不需要。
# 在实际应用中,这两种变换通常结合使用,以生成多样化的训练数据,提高模型的泛化能力。
# 这段代码定义了一个 Python 类的 __call__ 方法,这个方法使得类的实例可以像函数一样被调用。这个方法通常用于实现类的"可调用"协议。在这个特定的例子中,这个方法被用来对包含图像和标签(例如类别标签和边界框)的数据进行变换。
# 这个方法接受 self (类的实例本身)和一个参数。
# 1.labels :是一个字典,包含了图像和相关的标签信息。
def __call__(self, labels):
# 将 Albumentations 转换应用于输入标签。
# 此方法使用 Albumentations 库应用一系列图像增强。它可以对输入图像及其对应的标签执行空间和非空间转换。
# 注意:
# - 该方法以概率 self.p 应用变换。
# - 空间变换更新边界框,而非空间变换仅修改图像。
# - 需要安装 Albumentations 库。
"""
Applies Albumentations transformations to input labels.
This method applies a series of image augmentations using the Albumentations library. It can perform both
spatial and non-spatial transformations on the input image and its corresponding labels.
Args:
labels (Dict): A dictionary containing image data and annotations. Expected keys are:
- 'img': numpy.ndarray representing the image
- 'cls': numpy.ndarray of class labels
- 'instances': object containing bounding boxes and other instance information
Returns:
(Dict): The input dictionary with augmented image and updated annotations.
Examples:
>>> transform = Albumentations(p=0.5)
>>> labels = {
... "img": np.random.rand(640, 640, 3),
... "cls": np.array([0, 1]),
... "instances": Instances(bboxes=np.array([[0, 0, 1, 1], [0.5, 0.5, 0.8, 0.8]])),
... }
>>> augmented = transform(labels)
>>> assert augmented["img"].shape == (640, 640, 3)
Notes:
- The method applies transformations with probability self.p.
- Spatial transforms update bounding boxes, while non-spatial transforms only modify the image.
- Requires the Albumentations library to be installed.
"""
# 检查变换是否为空或随机概率是否未超过设定值。如果 self.transform 是 None 或者生成的随机数大于实例变量 self.p ,则直接返回原始的 labels 。
if self.transform is None or random.random() > self.p:
return labels
# 处理空间变换。
# 如果类实例变量 self.contains_spatial 为 True ,表示需要处理空间变换(如平移、旋转等),这些变换会影响边界框的位置。
if self.contains_spatial:
# 提取类别标签和图像。
# 从 labels 字典中提取类别标签 cls 。
cls = labels["cls"]
if len(cls):
# 从 labels 字典中提取图像 im 。
im = labels["img"]
# 转换和归一化边界框。
# 将边界框格式转换为 xywh (中心点坐标加上宽高),然后根据图像尺寸进行归一化,并提取边界框坐标。
labels["instances"].convert_bbox("xywh")
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
# 应用变换。使用 self.transform 对图像和边界框应用变换, self.transform 是一个由 albumentations 库提供的复合变换。
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
# 更新标签。如果变换后仍有类别标签存在,则更新图像和类别标签,并将边界框坐标转换为 float32 类型,然后更新到 labels 中。
# 这段代码是一个条件语句,用于检查经过变换后是否还有有效的边界框(bboxes),如果有,则更新图像和标签信息。
# 检查是否有有效的类别标签。
# 这行代码检查经过变换后返回的 new 字典中 "class_labels" 的长度是否大于0。如果大于0,说明变换后的图像中至少有一个有效的边界框,可以继续更新图像和标签信息。如果长度为0,说明变换后的图像中没有边界框,因此跳过更新。
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
# 更新图像。
# 如果变换后的图像中至少有一个有效的边界框,这行代码将 new 字典中的 "image" 键对应的值(变换后的图像)赋给 labels 字典中的 "img" 键。
labels["img"] = new["image"]
# 更新类别标签。
# 这行代码将 new 字典中的 "class_labels" 键对应的值(变换后的类别标签)转换为 NumPy 数组,并赋给 labels 字典中的 "cls" 键。
labels["cls"] = np.array(new["class_labels"])
# 更新边界框。
# 这行代码将 new 字典中的 "bboxes" 键对应的值(变换后的边界框)转换为 NumPy 数组,并指定数据类型为 np.float32 。
bboxes = np.array(new["bboxes"], dtype=np.float32)
# 更新实例信息。
# 这行代码调用 labels 字典中的 "instances" 键对应的对象的 update 方法,传入更新后的边界框 bboxes 。
labels["instances"].update(bboxes=bboxes)
# 这段代码的目的是确保只有在变换后的图像中存在有效的边界框时,才更新图像和标签信息。这样可以避免在变换后图像中没有边界框时,错误地更新图像和标签信息。
# 处理非空间变换。
else:
# 如果不包含空间变换,则只对图像应用变换,并更新 labels 中的图像。
labels["img"] = self.transform(image=labels["img"])["image"] # transformed
# 返回更新后的 labels ,其中包含了变换后的图像和标签。
return labels
# 这个方法的设计目的是为了灵活地对图像数据进行增强,同时确保相关的标签信息(如类别标签和边界框)与图像同步更新。
13.class Format:
python
# 这段代码定义了一个名为 Format 的 Python 类,它用于配置和标准化图像数据的格式,特别是在处理边界框、掩码、关键点等信息时。
# 类定义。定义了一个名为 Format 的类。
class Format:
# 用于为对象检测、实例分割和姿势估计任务格式化图像注释的类。
# 此类标准化 PyTorch DataLoader 中的 `collate_fn` 使用的图像和实例注释。
# 方法:
# __call__ :使用图像、类、边界框以及可选的掩码格式化标签字典和关键点。
# _format_img :将图像从 Numpy 数组转换为 PyTorch 张量。
# _format_segments :将多边形点转换为位图蒙版。
"""
A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks.
This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader.
Attributes:
bbox_format (str): Format for bounding boxes. Options are 'xywh' or 'xyxy'.
normalize (bool): Whether to normalize bounding boxes.
return_mask (bool): Whether to return instance masks for segmentation.
return_keypoint (bool): Whether to return keypoints for pose estimation.
return_obb (bool): Whether to return oriented bounding boxes.
mask_ratio (int): Downsample ratio for masks.
mask_overlap (bool): Whether to overlap masks.
batch_idx (bool): Whether to keep batch indexes.
bgr (float): The probability to return BGR images.
Methods:
__call__: Formats labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints.
_format_img: Converts image from Numpy array to PyTorch tensor.
_format_segments: Converts polygon points to bitmap masks.
Examples:
>>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True)
>>> formatted_labels = formatter(labels)
>>> img = formatted_labels["img"]
>>> bboxes = formatted_labels["bboxes"]
>>> masks = formatted_labels["masks"]
"""
# 构造函数。构造函数接受多个参数,用于初始化类的实例。
# 1.bbox_format ( str ) : 边界框的格式,默认为 "xywh" ,表示边界框的坐标以 (x, y, width, height) 的形式表示。其他可能的格式包括 "xyxy" (左上角和右下角的坐标)。
# 2.normalize ( bool ) : 是否对坐标进行归一化处理,默认为 True 。归一化通常是将坐标缩放到 [0, 1] 的范围内。
# 3.return_mask ( bool ): 是否返回掩码,默认为 False 。在只进行目标检测训练时,通常不需要掩码。
# 4.return_keypoint ( bool ): 是否返回关键点,默认为 False 。
# 5.return_obb ( bool ) : 是否返回方向边界框(Oriented Bounding Box),默认为 False 。
# 6.mask_ratio ( int ) : 掩码的比率,默认为 4。这个参数可能用于控制掩码的尺寸或精度。
# 7.mask_overlap ( bool ): 是否允许掩码重叠,默认为 True 。
# 8.batch_idx ( bool ): 是否保留批次索引,默认为 True 。
# 9.bgr ( float ) : 用于处理 BGR 颜色空间的参数,默认为 0.0。如果图像数据是 BGR 格式,这个参数可能用于指示或转换颜色空间。
def __init__(
self,
bbox_format="xywh",
normalize=True,
return_mask=False,
return_keypoint=False,
return_obb=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True,
bgr=0.0,
):
# 使用给定的参数初始化 Format 类,用于图像和实例注释格式化。
# 此类标准化图像和实例注释,用于对象检测、实例分割和姿势估计任务,为 PyTorch DataLoader 的 `collate_fn` 做准备。
"""
Initializes the Format class with given parameters for image and instance annotation formatting.
This class standardizes image and instance annotations for object detection, instance segmentation, and pose
estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`.
Args:
bbox_format (str): Format for bounding boxes. Options are 'xywh', 'xyxy', etc.
normalize (bool): Whether to normalize bounding boxes to [0,1].
return_mask (bool): If True, returns instance masks for segmentation tasks.
return_keypoint (bool): If True, returns keypoints for pose estimation tasks.
return_obb (bool): If True, returns oriented bounding boxes.
mask_ratio (int): Downsample ratio for masks.
mask_overlap (bool): If True, allows mask overlap.
batch_idx (bool): If True, keeps batch indexes.
bgr (float): Probability of returning BGR images instead of RGB.
Attributes:
bbox_format (str): Format for bounding boxes.
normalize (bool): Whether bounding boxes are normalized.
return_mask (bool): Whether to return instance masks.
return_keypoint (bool): Whether to return keypoints.
return_obb (bool): Whether to return oriented bounding boxes.
mask_ratio (int): Downsample ratio for masks.
mask_overlap (bool): Whether masks can overlap.
batch_idx (bool): Whether to keep batch indexes.
bgr (float): The probability to return BGR images.
Examples:
>>> format = Format(bbox_format="xyxy", return_mask=True, return_keypoint=False)
>>> print(format.bbox_format)
xyxy
"""
# 在构造函数中,每个参数都被赋值给同名的实例变量,这些变量可以在类的其他方法中使用。
self.bbox_format = bbox_format
self.normalize = normalize
self.return_mask = return_mask # set False when training detection only
self.return_keypoint = return_keypoint
self.return_obb = return_obb
self.mask_ratio = mask_ratio
self.mask_overlap = mask_overlap
self.batch_idx = batch_idx # keep the batch indexes
self.bgr = bgr
# Format 类用于图像预处理或数据加载阶段,以确保数据的格式一致性,便于后续的处理和模型训练。通过设置不同的参数,可以灵活地配置数据处理流程,以适应不同的任务需求,如目标检测、分割或关键点检测等。
# 这段代码定义了 Format 类的 __call__ 方法,它使得类的实例可以像函数一样被调用。这个方法的主要作用是将输入的标签数据( labels )格式化为统一的输出格式,包括图像、类别标签、边界框、掩码和关键点等。
# 方法定义。这个方法接受 self (类的实例本身)和一个参数。
# 1.labels :是一个字典,包含了图像和相关的标签信息。
def __call__(self, labels):
# 格式化图像注释以用于对象检测、实例分割和姿势估计任务。
# 此方法标准化 PyTorch DataLoader 中的 `collate_fn` 将使用的图像和实例注释。它处理输入标签字典,将注释转换为指定格式,并在需要时应用规范化。
"""
Formats image annotations for object detection, instance segmentation, and pose estimation tasks.
This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch
DataLoader. It processes the input labels dictionary, converting annotations to the specified format and
applying normalization if required.
Args:
labels (Dict): A dictionary containing image and annotation data with the following keys:
- 'img': The input image as a numpy array.
- 'cls': Class labels for instances.
- 'instances': An Instances object containing bounding boxes, segments, and keypoints.
Returns:
(Dict): A dictionary with formatted data, including:
- 'img': Formatted image tensor.
- 'cls': Class labels tensor.
- 'bboxes': Bounding boxes tensor in the specified format.
- 'masks': Instance masks tensor (if return_mask is True).
- 'keypoints': Keypoints tensor (if return_keypoint is True).
- 'batch_idx': Batch index tensor (if batch_idx is True).
Examples:
>>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True)
>>> labels = {"img": np.random.rand(640, 640, 3), "cls": np.array([0, 1]), "instances": Instances(...)}
>>> formatted_labels = formatter(labels)
>>> print(formatted_labels.keys())
"""
# 提取图像和相关标签信息。
# 从 labels 字典中提取 图像 img 及其 高度 h 和 宽度 w , 类别标签 cls 和 实例信息 instances 。
img = labels.pop("img")
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
# 转换边界框格式并去归一化。
# 将 instances 中的边界框转换为指定的格式(由 self.bbox_format 指定),并对边界框坐标进行去归一化处理。
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
# nl 代表 instances 的数量,即边界框的数量。
nl = len(instances)
# 处理掩码。
# 这段代码是 Format 类 __call__ 方法中处理掩码(masks)的部分。它根据 self.return_mask 的值决定是否需要处理和返回掩码。
# 条件判断。这个 if 语句检查 self.return_mask 是否为 True 。如果为 True ,则表示需要处理和返回掩码。
if self.return_mask:
# 检查实例数量。这里 nl 代表 instances 的数量,即边界框的数量。如果 nl 大于0,说明存在实例,需要进一步处理。
if nl:
# 格式化掩码。
# 如果存在实例,调用 self._format_segments 方法来格式化掩码。这个方法接受实例信息、类别标签、图像宽度和高度作为参数,并返回格式化后的掩码、更新后的实例信息和类别标签。
masks, instances, cls = self._format_segments(instances, cls, w, h)
# 将掩码从 NumPy 数组转换为 PyTorch 张量。
masks = torch.from_numpy(masks)
# 处理无实例的情况。
# 如果 nl 为0,即没有实例,根据 self.mask_overlap 的值决定是创建一个全零掩码(如果 self.mask_overlap 为 True )还是根据实例数量创建多个全零掩码(如果 self.mask_overlap 为 False )。
else:
# 这里的 img.shape[0] // self.mask_ratio 和 img.shape[1] // self.mask_ratio 分别计算掩码的 高度 和 宽度 ,根据 mask_ratio 参数进行调整。
masks = torch.zeros(
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
)
# 更新标签字典。最后,将处理好的掩码张量赋值给 labels 字典中的 "masks" 键。
labels["masks"] = masks
# 这段代码的目的是确保在需要掩码的情况下,能够正确地处理和返回掩码信息。通过这种方式, Format 类可以灵活地处理不同类型的图像标注信息,包括掩码,以适应不同的计算机视觉任务,如实例分割。
# 格式化图像。调用 _format_img 方法格式化图像,并将其结果存储在 labels 字典中。
labels["img"] = self._format_img(img)
# 格式化类别标签和边界框。
# 将类别标签 cls 和边界框 instances.bboxes 转换为 PyTorch 张量,并存储在 labels 字典中。
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
# 处理关键点。
# 如果需要返回关键点( self.return_keypoint 为 True ),则将 instances.keypoints 转换为 PyTorch 张量,并在需要时进行归一化处理。
if self.return_keypoint:
labels["keypoints"] = torch.from_numpy(instances.keypoints)
if self.normalize:
labels["keypoints"][..., 0] /= w
labels["keypoints"][..., 1] /= h
# 处理方向边界框(Oriented Bounding Box, OBB)。
# 如果需要返回 OBB( self.return_obb 为 True ),则将 instances.segments 转换为 OBB 格式,并存储在 labels 字典中。
if self.return_obb:
labels["bboxes"] = (
xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
)
# NOTE: need to normalize obb in xywhr format for width-height consistency
# 归一化边界框。
# 如果需要归一化( self.normalize 为 True ),则对边界框的宽度和高度进行归一化处理。
if self.normalize:
labels["bboxes"][:, [0, 2]] /= w
labels["bboxes"][:, [1, 3]] /= h
# Then we can use collate_fn
# 添加批次索引。
# 如果需要保留批次索引( self.batch_idx 为 True ),则在 labels 字典中添加 batch_idx 键。
if self.batch_idx:
# 这行代码的功能是为 labels 字典添加一个 "batch_idx" 键,其值是一个长度为 nl 的张量,所有元素都初始化为0。这个张量可以后续用于存储每个样本在批次中的实际索引,或者用于其他需要批次索引的场景。
# 应用场景 :
# 数据加载 :在数据加载阶段,可以使用批次索引来追踪每个样本的来源,特别是在使用数据增强或打乱数据顺序时。
# 模型训练 :在模型训练过程中,批次索引可以用来在批次内部或批次之间进行特定的操作,如计算批次损失、选择特定的样本进行反向传播等。
# 数据后处理 :在数据后处理阶段,批次索引可以用来将预测结果与原始样本关联起来,特别是在进行评估或分析时。
# 通过这种方式, labels 字典不仅包含了图像的标注信息,还包含了批次索引,使得数据的处理和追踪更加灵活和方便。
labels["batch_idx"] = torch.zeros(nl)
# 返回格式化后的标签。返回格式化后的 labels 字典。
return labels
# 这个方法的设计目的是为了将输入的标签数据格式化为统一的输出格式,以便于后续的处理和模型训练。通过设置不同的参数,可以灵活地配置数据处理流程,以适应不同的任务需求。
# 这段代码定义了 Format 类中的 _format_img 方法,该方法用于将输入的图像 img 格式化为适合深度学习模型处理的格式。
# 方法定义。这个方法接受 self (类的实例本身)和一个参数。
# 1.img :是一个 NumPy 数组,代表要格式化的图像。
def _format_img(self, img):
# 将 YOLO 的图像从 Numpy 数组格式化为 PyTorch 张量。
# 此函数执行以下操作:
# 1. 确保图像有 3 个维度(如果需要,添加一个通道维度)。
# 2. 将图像从 HWC 转置为 CHW 格式。
# 3. 可选地将颜色通道从 RGB 翻转为 BGR。
# 4. 将图像转换为连续数组。
# 5. 将 Numpy 数组转换为 PyTorch 张量。
"""
Formats an image for YOLO from a Numpy array to a PyTorch tensor.
This function performs the following operations:
1. Ensures the image has 3 dimensions (adds a channel dimension if needed).
2. Transposes the image from HWC to CHW format.
3. Optionally flips the color channels from RGB to BGR.
4. Converts the image to a contiguous array.
5. Converts the Numpy array to a PyTorch tensor.
Args:
img (np.ndarray): Input image as a Numpy array with shape (H, W, C) or (H, W).
Returns:
(torch.Tensor): Formatted image as a PyTorch tensor with shape (C, H, W).
Examples:
>>> import numpy as np
>>> img = np.random.rand(100, 100, 3)
>>> formatted_img = self._format_img(img)
>>> print(formatted_img.shape)
torch.Size([3, 100, 100])
"""
# 检查图像维度。
# 这行代码检查图像 img 的维度。如果维度小于3(即不是彩色图像),则在最后一个维度添加一个新的轴,使图像成为单通道的彩色图像。
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
# 转换图像维度顺序。
# 这行代码将图像的维度顺序从 HxWxC (高度x宽度x通道)转换为 CxHxW (通道x高度x宽度),这是因为 PyTorch 通常期望输入的图像通道在前。
img = img.transpose(2, 0, 1)
# 随机调整图像颜色顺序。
# 这行代码使用 random.uniform(0, 1) 生成一个0到1之间的随机数,如果这个数大于 self.bgr 值,则将图像通道顺序反转(例如,从RGB变为BGR或从BGR变为RGB)。 np.ascontiguousarray 确保结果是一个连续的NumPy数组。
img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr else img)
# 将 NumPy 数组转换为 PyTorch 张量。
# 这行代码将格式化后的 NumPy 数组 img 转换为 PyTorch 张量,以便可以在 PyTorch 模型中使用。
img = torch.from_numpy(img)
# 返回格式化后的图像。返回转换后的 PyTorch 张量。
return img
# _format_img 方法的主要作用是确保输入图像具有正确的维度和通道顺序,并且可以随机地调整颜色通道顺序。这对于确保图像数据与 PyTorch 模型的输入要求一致非常重要。通过这种方式,可以灵活地处理不同格式的图像数据,并将其标准化为模型可以接受的格式。
# 这段代码定义了 Format 类中的 _format_segments 方法,该方法用于将实例中的多边形(segments)转换成掩码(masks)。
# 方法定义。
# 1.instances :包含边界框和多边形信息的实例。
# 2.cls :类别标签。
# 3.w 、 4.h :图像的宽度、高度。
def _format_segments(self, instances, cls, w, h):
# 将多边形片段转换为位图掩码。
# 注意事项:
# - 如果 self.mask_overlap 为 True,则掩码会重叠并按面积排序。
# - 如果 self.mask_overlap 为 False,则每个掩码将单独表示。
# - 根据 self.mask_ratio 对掩码进行下采样。
"""
Converts polygon segments to bitmap masks.
Args:
instances (Instances): Object containing segment information.
cls (numpy.ndarray): Class labels for each instance.
w (int): Width of the image.
h (int): Height of the image.
Returns:
(tuple): Tuple containing:
masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
Notes:
- If self.mask_overlap is True, masks are overlapped and sorted by area.
- If self.mask_overlap is False, each mask is represented separately.
- Masks are downsampled according to self.mask_ratio.
"""
# 提取多边形信息。从 instances 中提取多边形信息,这些多边形定义了实例的轮廓。
segments = instances.segments
# 处理掩码重叠。
if self.mask_overlap:
# 如果 self.mask_overlap 为 True ,表示掩码之间可以重叠。使用 polygons2masks_overlap 函数将多边形转换为掩码,这个函数返回掩码和一个索引数组 sorted_idx ,用于重新排序 instances 和 cls 以匹配掩码的顺序。
# def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
# -> 用于将多个多边形(polygons)转换为重叠的掩码(masks)。返回最终的掩码和索引。返回最终的重叠掩码和排序后的索引。
# -> return masks, index
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
# masks[None] 将掩码的形状从 (h, w) 转换为 (1, h, w) ,以符合 PyTorch 的批次维度要求。
masks = masks[None] # (640, 640) -> (1, 640, 640)
# 重新排序 instances 。
# 这行代码的作用是根据 sorted_idx 中的索引对 instances 进行重新排序。这意味着 instances 中的元素将按照 sorted_idx 指定的顺序重新排列。这样做的目的是为了确保 instances 中的边界框(或其他信息)与之前生成的掩码 masks 的顺序相匹配,因为在处理掩码重叠时,掩码的生成顺序可能与原始 instances 的顺序不同。
instances = instances[sorted_idx]
# 重新排序 cls 。
# 这行代码的作用是根据 sorted_idx 中的索引对 cls 数组进行重新排序。这意味着 cls 数组中的元素将按照 sorted_idx 指定的顺序重新排列。这样做的目的是为了确保类别标签与重新排序后的 instances 和掩码 masks 的顺序相匹配。
cls = cls[sorted_idx]
# 处理不重叠的掩码。
else:
# 如果 self.mask_overlap 为 False ,表示掩码之间不能重叠。使用 polygons2masks 函数将多边形转换为掩码,这个函数不返回索引数组,因此不需要重新排序 instances 和 cls 。
# def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
# -> 用于将多个多边形(polygons)转换为对应的掩码(masks)。这个数组包含了所有多边形对应的掩码。
# -> return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
# 返回掩码和排序后的实例及类别标签。返回掩码、排序后的 instances 和 cls 。
return masks, instances, cls
# _format_segments 方法的主要作用是将实例中的多边形信息转换为掩码,这些掩码可以用于实例分割任务。根据 self.mask_overlap 的值,方法会选择不同的函数来处理掩码的生成,并且可能重新排序实例和类别标签以匹配掩码的顺序。这对于确保掩码与相应的实例和类别标签正确对应非常重要。
14.class RandomLoadText:
python
# 这段代码定义了一个名为 RandomLoadText 的 Python 类,用于处理文本数据的加载。
# 类定义。定义了一个名为 RandomLoadText 的类。
class RandomLoadText:
# 随机抽取正文本和负文本并相应地更新类索引。
# 此类负责从给定的一组类文本中抽取文本,包括正样本(存在于图像中)和负样本(不存在于图像中)。它更新类索引以反映抽样的文本,并可以选择将文本列表填充为固定长度。
# 方法:
# __call__ :处理输入标签并返回更新的类和文本。
"""
Randomly samples positive and negative texts and updates class indices accordingly.
This class is responsible for sampling texts from a given set of class texts, including both positive
(present in the image) and negative (not present in the image) samples. It updates the class indices
to reflect the sampled texts and can optionally pad the text list to a fixed length.
Attributes:
prompt_format (str): Format string for text prompts.
neg_samples (Tuple[int, int]): Range for randomly sampling negative texts.
max_samples (int): Maximum number of different text samples in one image.
padding (bool): Whether to pad texts to max_samples.
padding_value (str): The text used for padding when padding is True.
Methods:
__call__: Processes the input labels and returns updated classes and texts.
Examples:
>>> loader = RandomLoadText(prompt_format="Object: {}", neg_samples=(5, 10), max_samples=20)
>>> labels = {"cls": [0, 1, 2], "texts": [["cat"], ["dog"], ["bird"]], "instances": [...]}
>>> updated_labels = loader(labels)
>>> print(updated_labels["texts"])
['Object: cat', 'Object: dog', 'Object: bird', 'Object: elephant', 'Object: car']
"""
# 构造函数。构造函数接受多个参数,用于初始化类的实例。
# 1.prompt_format ( str ) : 用于格式化提示文本的字符串,默认为 "{}" 。这个格式字符串用于插入变量或动态文本。
# 2.neg_samples ( Tuple[int, int] ) : 一个包含两个整数的元组,表示负样本的数量范围,默认为 (80, 80) 。这与机器学习中的正负样本采样有关。
# 3.max_samples ( int ) : 最大样本数量,默认为 80 。这用于限制加载的样本总数。
# 4.padding ( bool ) : 是否进行填充操作,默认为 False 。在自然语言处理中,为了使所有样本具有相同的长度,通常会进行填充或截断操作。
# 5.padding_value ( str ) : 用于填充的值,默认为空字符串 "" 。
# 返回值 :构造函数的返回类型被标注为 -> None ,表示构造函数不返回任何值,这是 Python 类构造函数的常见做法。
def __init__(
self,
prompt_format: str = "{}",
neg_samples: Tuple[int, int] = (80, 80),
max_samples: int = 80,
padding: bool = False,
padding_value: str = "",
) -> None:
# 初始化 RandomLoadText 类,用于随机抽取正文本和负文本样本。
# 此类用于随机抽取正文本和负文本样本,并根据样本数量更新类索引。它可用于基于文本的对象检测任务。
"""
Initializes the RandomLoadText class for randomly sampling positive and negative texts.
This class is designed to randomly sample positive texts and negative texts, and update the class
indices accordingly to the number of samples. It can be used for text-based object detection tasks.
Args:
prompt_format (str): Format string for the prompt. Default is '{}'. The format string should
contain a single pair of curly braces {} where the text will be inserted.
neg_samples (Tuple[int, int]): A range to randomly sample negative texts. The first integer
specifies the minimum number of negative samples, and the second integer specifies the
maximum. Default is (80, 80).
max_samples (int): The maximum number of different text samples in one image. Default is 80.
padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always
be equal to max_samples. Default is False.
padding_value (str): The padding text to use when padding is True. Default is an empty string.
Attributes:
prompt_format (str): The format string for the prompt.
neg_samples (Tuple[int, int]): The range for sampling negative texts.
max_samples (int): The maximum number of text samples.
padding (bool): Whether padding is enabled.
padding_value (str): The value used for padding.
Examples:
>>> random_load_text = RandomLoadText(prompt_format="Object: {}", neg_samples=(50, 100), max_samples=120)
>>> random_load_text.prompt_format
'Object: {}'
>>> random_load_text.neg_samples
(50, 100)
>>> random_load_text.max_samples
120
"""
# 在构造函数中,每个参数都被赋值给同名的实例变量,这些变量可以在类的其他方法中使用。
self.prompt_format = prompt_format
self.neg_samples = neg_samples
self.max_samples = max_samples
self.padding = padding
self.padding_value = padding_value
# RandomLoadText 类用于在文本处理任务中随机加载和格式化文本数据。通过设置不同的参数,可以灵活地配置文本数据的加载和预处理流程,以适应不同的任务需求。
# 这段代码定义了 RandomLoadText 类的 __call__ 方法,该方法用于处理和随机选择文本标签。
# 这个方法接受 self (类的实例本身)和一个参数.
# 1.labels :是一个包含文本和其他标签信息的字典。
def __call__(self, labels: dict) -> dict:
# 随机抽样正文本和负文本并相应地更新类索引。
# 此方法根据 图像中现有的类标签 抽样正文本,并从剩余的类中随机选择负文本。然后,它更新类索引以匹配新的采样文本顺序。
"""
Randomly samples positive and negative texts and updates class indices accordingly.
This method samples positive texts based on the existing class labels in the image, and randomly
selects negative texts from the remaining classes. It then updates the class indices to match the
new sampled text order.
Args:
labels (Dict): A dictionary containing image labels and metadata. Must include 'texts' and 'cls' keys.
Returns:
(Dict): Updated labels dictionary with new 'cls' and 'texts' entries.
Examples:
>>> loader = RandomLoadText(prompt_format="A photo of {}", neg_samples=(5, 10), max_samples=20)
>>> labels = {"cls": np.array([[0], [1], [2]]), "texts": [["dog"], ["cat"], ["bird"]]}
>>> updated_labels = loader(labels)
"""
# 检查文本标签存在。确保 labels 字典中包含键 "texts" 。
assert "texts" in labels, "No texts found in labels." # 标签中未找到文本标签。
# 提取类别文本。
# 从 labels 中提取类别文本。
class_texts = labels["texts"]
# 计算类别数量。
num_classes = len(class_texts)
# 提取并转换类别标签。
# 从 labels 中提取类别标签,转换为 NumPy 数组。
cls = np.asarray(labels.pop("cls"), dtype=int)
# 找到所有唯一的正样本标签。
# np.unique(cls) :这是一个 NumPy 函数,用于找出 cls 数组中的唯一值。 cls 数组是一个包含类别标签的 NumPy 数组,可能包含重复的标签值。 np.unique 函数会返回一个包含所有唯一值的数组,并且这些值会按照升序排序。
# .tolist() :这是一个将 NumPy 数组转换为 Python 列表的方法。 np.unique(cls) 的结果是一个 NumPy 数组, tolist() 方法将这个数组转换成一个 Python 列表,这样得到的 pos_labels 列表就可以用于后续的 Python 操作,如随机抽样等。
pos_labels = np.unique(cls).tolist()
# 随机选择正样本标签。
# 如果正样本标签数量超过 self.max_samples ,则随机选择 self.max_samples 个标签。
if len(pos_labels) > self.max_samples:
# random.sample(population, k)
# random.sample 是 Python 标准库 random 模块中的一个函数,它用于从一个序列中随机选择指定数量的不重复元素,并返回一个新列表。
# population :一个序列,表示可供选择的元素集合。
# k :一个整数,表示需要随机选择的元素数量。
# 返回值 :
# 返回一个新列表,包含从 population 中随机选择的 k 个不重复元素。
# 功能 :
# random.sample 函数可以确保从 population 中选择的 k 个元素是唯一的,不会有重复。如果 population 中的元素数量小于 k ,则抛出 ValueError 。
# 注意事项 :
# population 必须是一个序列,例如列表、元组或字符串。
# k 的值不能大于 population 中元素的数量,否则会抛出 ValueError 。
# 每次调用 random.sample 都会生成一个新的随机选择的列表,因为随机数生成器的状态在每次调用时都会改变。
pos_labels = random.sample(pos_labels, k=self.max_samples)
# 计算负样本数量并选择负样本标签。
# 确定负样本数量。
# min(num_classes, self.max_samples) :这部分代码确定在正样本和最大样本限制之间较小的值。 num_classes 是类别总数, self.max_samples 是允许的最大样本数。
# len(pos_labels) :这是正样本标签的数量。
# random.randint(*self.neg_samples) :这部分代码生成一个介于 self.neg_samples 元组两个值之间的随机整数,表示负样本的可能数量范围。
# min(...) :最终确定的负样本数量是类别总数和最大样本数之差与随机生成的负样本数量之间的较小值。
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
# 确定负样本标签。
# 这行代码通过列表推导式创建一个包含所有非正样本标签的列表。 range(num_classes) 生成一个从 0 到 num_classes - 1 的整数序列,代表所有可能的类别索引。 if i not in pos_labels 确保只有不在 pos_labels 列表中的索引被包括在内,即选择负样本标签。
neg_labels = [i for i in range(num_classes) if i not in pos_labels]
# 随机选择负样本标签。
# 这行代码从 neg_labels 列表中随机选择 neg_samples 个样本。 k=neg_samples 指定了需要选择的样本数量。
neg_labels = random.sample(neg_labels, k=neg_samples)
# 合并并随机打乱样本标签。合并正负样本标签,并随机打乱顺序。
sampled_labels = pos_labels + neg_labels
# random.shuffle(x, random=None)
# random.shuffle() 是 Python 标准库 random 模块中的一个函数,用于随机打乱一个列表的元素顺序。
# 参数说明 :
# x :要被打乱的列表。 random.shuffle() 函数会直接修改这个列表,不返回任何值。
# random :一个可选的 random.Random 实例,用于提供随机数生成器。如果不提供,将使用 random 模块的默认随机数生成器。
# 返回值 :该函数不返回任何值,它直接在输入的列表 x 上进行操作,将列表中的元素随机打乱。
# random.shuffle() 是一个非常方便的函数,用于在需要随机化处理数据时,如在机器学习中的随机抽样或洗牌算法等场景。
random.shuffle(sampled_labels)
# 创建标签到索引的映射。创建一个从标签到索引的映射字典。
label2ids = {label: i for i, label in enumerate(sampled_labels)}
# 过滤和更新有效标签。
# 初始化有效索引数组。这行代码创建了一个布尔类型的 NumPy 数组 valid_idx ,其长度与 labels["instances"] 的长度相同。这个数组将用于标记哪些实例是有效的。
valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
# 初始化新的类别标签列表。这行代码初始化了一个空列表 new_cls ,用于存储有效的类别标签。
new_cls = []
# 遍历类别标签。
# 这个 for 循环遍历 cls 数组中的每个类别标签。
# enumerate(cls.squeeze(-1).tolist()) :这会返回每个标签的索引 i 和值 label 。 cls.squeeze(-1) 用于去除数组末尾的单一维度, tolist() 将 NumPy 数组转换为 Python 列表。
for i, label in enumerate(cls.squeeze(-1).tolist()):
# 如果当前标签 label 不在 label2ids 映射中,表示它不是有效的标签,因此跳过当前循环。
if label not in label2ids:
continue
# 如果标签有效,将对应的 valid_idx 元素设置为 True ,表示对应的实例是有效的。
valid_idx[i] = True
# 将有效的标签索引添加到 new_cls 列表中。
new_cls.append([label2ids[label]])
# 更新 labels 字典。
# 使用布尔索引 valid_idx 来筛选 labels["instances"] ,只保留有效的实例。
labels["instances"] = labels["instances"][valid_idx]
# 将 new_cls 列表转换为 NumPy 数组,并更新 labels 字典中的 "cls" 键。
labels["cls"] = np.array(new_cls)
# Randomly select one prompt when there's more than one prompts
# 随机选择提示文本。
# 初始化文本列表。这行代码初始化了一个空列表 texts ,用于存储为每个样本标签选择的提示文本。
texts = []
# 遍历样本标签。这个 for 循环遍历 sampled_labels 列表中的每个标签, sampled_labels 是之前通过正负样本标签合并并随机打乱得到的列表。
for label in sampled_labels:
# 提取提示列表。对于每个标签 label ,从 class_texts 中提取对应的提示列表。 class_texts 是一个包含每个类别文本提示的数组或列表。
prompts = class_texts[label]
# 确保提示列表非空。这行代码使用 assert 语句确保 prompts 列表不为空。如果列表为空,将抛出 AssertionError 。
assert len(prompts) > 0
# 随机选择提示并格式化。
# random.randrange(stop)
# random.randrange(start, stop[, step])
# random.randrange() 是 Python 标准库 random 模块中的一个函数,它用于生成一个指定范围内的随机整数。
# 参数说明 :
# stop :一个整数,表示随机数生成的上限(不包括 stop ),即随机数的范围是从 0 到 stop-1 。
# start (可选) :一个整数,表示随机数生成的下限(包括 start ),默认值为 0。
# step (可选):一个整数,表示步长,即随机数的增量,默认值为 1。
# 返回值 :
# 该函数返回一个在指定范围内的随机整数。
# random.randrange() 是一个非常方便的函数,用于在需要随机选择一个特定范围内的整数时,如在循环中随机选择元素、生成随机索引或在游戏编程中模拟随机事件等场景。
# random.randrange(len(prompts)) :生成一个随机索引,用于从 prompts 列表中选择一个提示。
# prompts[random.randrange(len(prompts))] :使用随机索引从 prompts 列表中选择一个提示。
# self.prompt_format.format(...) :使用 self.prompt_format 字符串作为格式模板,并将随机选择的提示插入到模板中。 format 方法将替换格式字符串中的占位符 {} 为实际的提示文本。
prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
# 将格式化后的提示添加到列表。将格式化后的提示 prompt 添加到 texts 列表中。
texts.append(prompt)
# 添加填充文本。
# 条件判断是否需要填充。这行代码检查 self.padding 属性是否为 True 。 self.padding 是在类 RandomLoadText 的构造函数中设置的一个参数,用于指定是否需要对文本数据进行填充。
if self.padding:
# 计算有效标签总数。如果需要填充,这行代码计算有效标签的总数,即正样本标签 pos_labels 和负样本标签 neg_labels 的数量之和。
valid_labels = len(pos_labels) + len(neg_labels)
# 计算需要填充的数量。这行代码计算需要填充的文本数量,即最大样本数 self.max_samples 减去有效标签的总数 valid_labels 。
num_padding = self.max_samples - valid_labels
# 判断是否需要进行填充。如果需要填充的文本数量 num_padding 大于0,表示实际的有效文本数量少于 self.max_samples 指定的最大样本数,需要进行填充。
if num_padding > 0:
# 执行填充操作。
# 这行代码将 self.padding_value 填充到 texts 列表中,填充的数量为 num_padding 。 self.padding_value 是在类 RandomLoadText 的构造函数中设置的一个参数,用于指定用于填充的值。
texts += [self.padding_value] * num_padding
# 更新文本标签。更新 labels 字典中的 "texts" 。
labels["texts"] = texts
# 返回更新后的标签字典。返回更新后的 labels 字典。
return labels
# __call__ 方法的主要作用是从 labels 字典中提取和处理文本标签,随机选择正负样本,格式化提示文本,并根据需要进行填充。这个方法使得文本数据的处理更加灵活和可控,适用于需要随机采样和格式化文本的场景。
15.def v8_transforms(dataset, imgsz, hyp, stretch=False):
python
# 这段代码定义了一个名为 v8_transforms 的函数,它用于构建一个图像增强的流程,通常用于计算机视觉任务,如目标检测或图像分类。这个函数使用了 Compose 类来组合多个图像变换,这些变换来自于 Albumentations 库,它是一个用于图像增强的库。
# 函数定义。
# 1.dataset :数据集对象,包含了数据集的信息。
# 2.imgsz :目标图像的尺寸。
# 3.hyp :超参数对象,包含了各种图像增强的概率和参数。
# 4.stretch :一个布尔值,表示是否应用拉伸变换。
def v8_transforms(dataset, imgsz, hyp, stretch=False):
# 应用一系列图像变换以进行 YOLOv8 训练。
# 此函数创建图像增强技术的组合,以准备用于 YOLOv8 训练的图像。它包括马赛克、复制粘贴、随机透视、混合和各种颜色调整等操作。
"""
Applies a series of image transformations for YOLOv8 training.
This function creates a composition of image augmentation techniques to prepare images for YOLOv8 training.
It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments.
Args:
dataset (Dataset): The dataset object containing image data and annotations.
imgsz (int): The target image size for resizing.
hyp (Dict): A dictionary of hyperparameters controlling various aspects of the transformations.
stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing.
Returns:
(Compose): A composition of image transformations to be applied to the dataset.
Examples:
>>> from ultralytics.data.dataset import YOLODataset
>>> dataset = YOLODataset(img_path="path/to/images", imgsz=640)
>>> hyp = {"mosaic": 1.0, "copy_paste": 0.5, "degrees": 10.0, "translate": 0.2, "scale": 0.9}
>>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp)
>>> augmented_data = transforms(dataset[0])
"""
# 预变换。
# A.Compose(transforms, bbox_params=None, keypoint_params=None, mask_params=None, p=1.0)
# albumentations.Compose 是 Albumentations 库中的一个函数,它用于组合多个图像变换(transforms)成一个复合变换(composition)。这个复合变换可以同时应用于图像和与之相关的标注(如边界框、关键点、掩码等)。
# 参数说明 :
# transforms : 一个包含多个变换的列表。每个变换可以是一个单独的变换类实例,也可以是一个变换序列。
# bbox_params : 一个 BboxParams 对象,用于定义如何应用变换到边界框。如果为 None ,则不会对边界框应用任何变换。
# keypoint_params : 一个 KeypointParams 对象,用于定义如何应用变换到关键点。如果为 None ,则不会对关键点应用任何变换。
# mask_params : 一个 MaskParams 对象,用于定义如何应用变换到掩码。如果为 None ,则不会对掩码应用任何变换。
# p : 一个浮点数,表示应用整个复合变换的概率。值范围从 0 到 1,其中 1 表示总是应用变换。
# 返回值 :
# 返回一个复合变换对象,可以被用于图像增强流程中。
# 这部分代码创建了一个预变换流程,包括马 赛克增强(Mosaic) 、 复制粘贴增强(CopyPaste) 和 随机透视变换(RandomPerspective) 。
# 这段代码定义了一个名为 pre_transform 的图像预处理流程,它使用 Compose 函数来组合多个图像增强变换。这些变换通常用于计算机视觉任务中,以增加数据集的多样性并提高模型的泛化能力。
# Compose 是 Albumentations 库中的一个函数,它允许你将多个图像变换组合成一个流程。这样,你可以按顺序应用一系列变换,而不是一次只应用一个。
# Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic) :
# Mosaic 是一种图像增强技术,它将四个不同的图像拼接成一个图像,通常用于目标检测任务中增加样本多样性。
# dataset 参数是数据集对象,包含了数据集的信息。 imgsz 是目标图像的尺寸。 p=hyp.mosaic 表示执行马赛克增强的概率,由超参数 hyp 控制。
# CopyPaste(p=hyp.copy_paste) :
# CopyPaste 是另一种图像增强技术,它将图像的一部分复制并粘贴到图像的另一个位置。
# p=hyp.copy_paste 表示执行复制粘贴增强的概率。
# RandomPerspective(...) :
# RandomPerspective 是一个随机透视变换,它可以模拟图像在三维空间中的视角变化。
# degrees 、 translate 、 scale 、 shear 和 perspective 参数控制透视变换的程度,这些参数都由超参数 hyp 控制。
# pre_transform 参数决定了在应用透视变换之前是否需要进行其他变换。如果 stretch 参数为 False ,则使用 LetterBox 来保持图像的长宽比;如果 stretch 为 True ,则不进行任何预变换。
# LetterBox 是一个变换,它将图像调整到指定的尺寸,同时保持图像的长宽比。如果 stretch 为 False ,则使用 LetterBox 将图像调整到 imgsz x imgsz 的尺寸。
# 这段代码定义了一个预处理流程,它包括马赛克增强、复制粘贴增强和随机透视变换。这些变换可以显著增加数据集的多样性,提高模型的泛化能力。通过使用 Compose 函数,这些变换可以按顺序应用到图像上,形成一个完整的预处理流程。
pre_transform = Compose(
[
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
# 如果 stretch 为 False ,则在透视变换之前使用 LetterBox 来保持图像的长宽比。
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
),
]
)
# 关键点增强。
# 获取翻转索引。这行代码尝试从 dataset.data 中获取一个名为 "flip_idx" 的数组,该数组包含了在水平翻转时需要特别处理的关键点索引。如果 "flip_idx" 不存在,则默认为一个空列表。
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
# 检查数据集是否使用关键点。
# 这行代码检查 dataset 是否使用关键点数据增强。如果 dataset.use_keypoints 为 True ,则表示数据集中包含了关键点信息,需要进行特殊处理。
if dataset.use_keypoints:
# 获取关键点形状。如果数据集使用关键点,这行代码尝试获取关键点的形状 kpt_shape ,它描述了关键点数据的结构。
kpt_shape = dataset.data.get("kpt_shape", None)
# 检查翻转索引和关键点形状。
# 这段代码检查 flip_idx 是否为空,并且水平翻转的概率 hyp.fliplr 是否大于 0。如果 flip_idx 为空,表示没有定义翻转索引,因此将 hyp.fliplr 设置为 0,并记录一条警告日志,表示由于没有定义翻转索引,水平翻转增强被禁用。
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") # 警告⚠️ data.yaml 中未定义"flip_idx"数组,设置增强"fliplr=0.0"。
# 如果 flip_idx 不为空,这段代码检查 flip_idx 的长度是否与 kpt_shape[0] 相等。如果不等,表示定义的翻转索引与关键点的数量不匹配,因此抛出一个 ValueError 异常。
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") # data.yaml flip_idx={flip_idx} 长度必须等于 kpt_shape[0]={kpt_shape[0]}。
# 这段代码的目的是确保在进行关键点数据增强时,翻转索引 flip_idx 被正确定义,并且与关键点的数量相匹配。这是必要的,因为水平翻转可能会影响关键点的位置,特别是在人体姿态估计等任务中,正确的关键点处理对于模型性能至关重要。通过这些检查,可以避免在数据增强过程中出现错误,确保数据的一致性和模型训练的有效性。
# 最终变换流程。
# 这部分代码创建了最终的变换流程,包括预变换、混合增强(MixUp)、随机HSV颜色变换、垂直翻转和水平翻转。水平翻转使用了 flip_idx 来确定哪些关键点需要翻转。
# 这段代码使用 Albumentations 库中的 Compose 函数来组合一系列的图像增强变换,创建一个完整的图像增强流程。
# 变换流程。
# pre_transform :
# pre_transform 是之前定义的预处理流程,包含了如马赛克、复制粘贴和随机透视变换等增强技术。
# MixUp :
# MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup) :MixUp 是一种图像增强技术,它通过混合两张图像来创建新的训练样本。
# dataset 参数是数据集对象。 pre_transform 是应用在每张图像上的预处理流程。 p=hyp.mixup 表示执行 MixUp 增强的概率,由超参数 hyp 控制。
# Albumentations :
# Albumentations(p=1.0) :这个变换应用了 Albumentations 库提供的一系列预定义变换,每个变换的概率都设置为 1.0,即一定会被应用。
# RandomHSV :
# RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v) :这个变换随机调整图像的色调(Hue)、饱和度(Saturation)和亮度(Value)。
# hgain 、 sgain 和 vgain 参数控制色调、饱和度和亮度的变化范围,这些参数都由超参数 hyp 控制。
# RandomFlip (垂直翻转) :
# RandomFlip(direction="vertical", p=hyp.flipud) :这个变换以一定的概率 p=hyp.flipud 垂直翻转图像。
# RandomFlip (水平翻转) :
# RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx) :这个变换以一定的概率 p=hyp.fliplr 水平翻转图像,并使用 flip_idx 来确定哪些关键点需要特别处理。
# 这段代码定义了一个完整的图像增强流程,它包括预处理、MixUp 增强、随机 HSV 调整、垂直翻转和水平翻转。这些变换可以显著增加数据集的多样性,提高模型的泛化能力。通过使用 Compose 函数,这些变换可以按顺序应用到图像上,形成一个完整的预处理和增强流程。
return Compose(
[
pre_transform,
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
]
) # transforms
# v8_transforms 函数构建了一个复杂的图像增强流程,包括多种增强技术,如马赛克、复制粘贴、透视变换、颜色变换和翻转。这些技术可以显著增加数据集的多样性,提高模型的泛化能力。这个函数的设计考虑到了关键点增强的特殊需求,并提供了灵活的配置选项。
16.def classify_transforms(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION,):
python
# Classification augmentations -----------------------------------------------------------------------------------------
# 这段代码定义了一个名为 classify_transforms 的函数,它用于创建一组图像变换,这些变换通常用于图像分类任务。这个函数使用了 torchvision.transforms 模块来构建变换流程。
# 函数定义。
# 1.size :目标图像的尺寸,可以是整数或元组。
# 2.mean :图像数据的均值,用于归一化。
# 3.std :图像数据的标准差,用于归一化。
# 4.interpolation :插值方法,用于调整图像尺寸。
# 5.crop_fraction :裁剪比例,用于确定裁剪尺寸。
def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation="BILINEAR",
crop_fraction: float = DEFAULT_CROP_FRACTION,
):
# 为分类任务创建图像变换组合。
# 此函数生成一系列 torchvision 变换,适用于在评估或推理期间对分类模型的图像进行预处理。变换包括调整大小、中心裁剪、转换为张量和规范化。
"""
Creates a composition of image transforms for classification tasks.
This function generates a sequence of torchvision transforms suitable for preprocessing images
for classification models during evaluation or inference. The transforms include resizing,
center cropping, conversion to tensor, and normalization.
Args:
size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. If a
tuple, it defines (height, width).
mean (tuple): Mean values for each RGB channel used in normalization.
std (tuple): Standard deviation values for each RGB channel used in normalization.
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
crop_fraction (float): Fraction of the image to be cropped.
Returns:
(torchvision.transforms.Compose): A composition of torchvision transforms.
Examples:
>>> transforms = classify_transforms(size=224)
>>> img = Image.open("path/to/image.jpg")
>>> transformed_img = transforms(img)
"""
# 导入模块。这行代码导入了 torchvision.transforms 模块,并将其别名设置为 T 。
import torchvision.transforms as T # scope for faster 'import ultralytics'
# 计算缩放尺寸。
# 这段代码根据 size 参数和 crop_fraction 参数计算缩放尺寸 scale_size 。
# 如果 size 是元组或列表,它会检查长度是否为2,并计算每个维度的缩放尺寸。
if isinstance(size, (tuple, list)):
assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}" # size' 元组的长度必须是 2,而不是长度 {len(size)}。
scale_size = tuple(math.floor(x / crop_fraction) for x in size)
# 如果 size 是整数,它会计算缩放尺寸并确保结果是元组。
else:
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost
# 构建变换流程。
# 这段代码根据 scale_size 的值决定如何调整图像尺寸。
# T.Resize(size, interpolation=InterpolationMode.BILINEAR)
# T.Resize() 是 PyTorch 的 torchvision.transforms 模块中的一个函数,它用于调整图像的大小。这个函数可以接收一个整数或一个元组作为参数,以指定输出图像的大小。
# 参数说明 :
# size :目标图像的大小,可以是一个整数或一个元组 (width, height) 。
# 如果是整数,表示将图像的最短边缩放到指定长度,同时保持长宽比。
# 如果是元组,表示将图像的宽度和高度分别调整为指定的尺寸。
# interpolation :插值方法,用于指定图像缩放时采用的插值算法。默认值为 InterpolationMode.BILINEAR ,即双线性插值。其他可选的插值方法包括 :
# InterpolationMode.NEAREST :最近邻插值。
# InterpolationMode.BILINEAR :双线性插值。
# InterpolationMode.BICUBIC :双三次插值。
# InterpolationMode.BOX :盒式插值。
# InterpolationMode.HAMMING :汉明窗插值。
# 返回值 :
# 该函数返回一个 Resize 变换对象,可以被用于 T.Compose() 之中,或者直接应用于图像。
# T.Resize() 是图像预处理中常用的操作之一,它可以帮助标准化图像输入到深度学习模型中,或者在数据增强过程中改变图像的尺寸。
# 如果 scale_size 的两个维度相等,它使用 T.Resize 调整图像尺寸,并指定插值方法。
if scale_size[0] == scale_size[1]:
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]
# 如果不相等,它只调整最短边到目标尺寸。
else:
# Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
# 这段代码将其他变换添加到流程中,包括中心裁剪、转换为张量和归一化。
# 首先,通过中心裁剪来减少图像尺寸;然后,将裁剪后的图像转换为张量,使其可以被 PyTorch 模型处理;最后,通过归一化来调整图像的像素值,使其具有统一的分布,这有助于模型的训练和收敛。
tfl.extend(
[
# 中心裁剪 ( T.CenterCrop ) 。
# T.CenterCrop(size)
# T.CenterCrop 是一个变换,它从图像中裁剪出中心区域,裁剪后的图像尺寸由 size 参数指定。
# size 可以是一个整数或者一个元组 (width, height) 。如果 size 是整数,那么裁剪出的区域将是正方形;如果是元组,则裁剪出的区域将具有指定的宽度和高度。
T.CenterCrop(size),
# 转换为张量 ( T.ToTensor ) 。
# T.ToTensor()
# T.ToTensor 是一个变换,它将 PIL Image 或 NumPy ndarray 转换为 torch.Tensor 。
# 这个变换还会自动将图像的像素值从 [0, 255] 范围缩放到 [0.0, 1.0] 范围。
T.ToTensor(),
# 归一化 ( T.Normalize ) 。
# T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
# T.Normalize 是一个变换,它对图像数据进行归一化,即减去均值( mean )并除以标准差( std )。
# mean 和 std 分别是归一化过程中使用的均值和标准差,它们可以是单个数值或者相同长度的数值列表(对于多通道图像)。
# torch.tensor(mean) 和 torch.tensor(std) 将均值和标准差转换为 PyTorch 张量,以便与图像张量进行广播操作。
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
]
)
# T.Compose(transforms)
# T.Compose() 是 PyTorch 的 torchvision.transforms 模块中的一个函数,它用于将多个图像变换组合成一个流程。这个函数允许你按顺序应用一系列变换,从而创建复杂的数据预处理或数据增强管道。
# 参数说明 :
# transforms :一个包含多个变换对象的列表或元组。这些变换对象可以是 torchvision.transforms 模块中定义的任何变换,如 Resize 、 CenterCrop 、 ToTensor 、 Normalize 等。
# 返回值 :
# 该函数返回一个组合变换对象,这个对象可以被用于图像数据的预处理或增强。
# T.Compose() 是构建深度学习模型训练和测试管道中不可或缺的工具,它使得数据预处理和增强步骤变得简洁和高效。通过组合不同的变换,你可以轻松地创建复杂的图像处理流程。
# 返回变换流程。这行代码返回一个由 T.Compose 创建的变换流程,它将按顺序应用所有定义的变换。
return T.Compose(tfl)
# classify_transforms 函数创建了一个图像变换流程,包括尺寸调整、中心裁剪、转换为张量和归一化。这个流程是图像分类任务中常用的预处理步骤,确保图像数据适合输入到深度学习模型中。通过调整 size 、 mean 、 std 、 interpolation 和 crop_fraction 参数,可以灵活地配置变换流程以适应不同的任务需求。
17.def classify_augmentations(size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, scale=None, ratio=None, hflip=0.5, vflip=0.0, auto_augment=None, hsv_h=0.015, hsv_s=0.4, hsv_v=0.4, force_color_jitter=False, erasing=0.0, interpolation="BILINEAR",):
python
# Classification training augmentations --------------------------------------------------------------------------------
# 这段代码定义了一个名为 classify_augmentations 的函数,它用于创建一组图像增强变换,这些变换通常用于图像分类任务。这个函数使用了 torchvision.transforms 模块来构建变换流程。
# 函数定义。
# 1.size :目标图像的尺寸。
# 2.mean 和 3.std :用于归一化的均值和标准差。
# 4.scale 和 5.ratio :用于随机裁剪的尺寸比例和宽高比范围。
# 6.hflip 和 7.vflip :水平和垂直翻转的概率。
# 8.auto_augment :自动增强策略的名称。
# 9.hsv_h , 10.hsv_s , 11.hsv_v :HSV颜色空间中的颜色抖动参数。
# 12.force_color_jitter :是否强制应用颜色抖动。
# 13.erasing :随机擦除的概率。
# 14.interpolation :插值方法。
def classify_augmentations(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.0,
auto_augment=None,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.0,
interpolation="BILINEAR",
):
# 为分类任务创建图像增强变换组合。
# 此函数生成一组适合训练分类模型的图像变换。它包括调整大小、翻转、颜色抖动、自动增强和随机擦除的选项。
"""
Creates a composition of image augmentation transforms for classification tasks.
This function generates a set of image transformations suitable for training classification models. It includes
options for resizing, flipping, color jittering, auto augmentation, and random erasing.
Args:
size (int): Target size for the image after transformations.
mean (tuple): Mean values for normalization, one per channel.
std (tuple): Standard deviation values for normalization, one per channel.
scale (tuple | None): Range of size of the origin size cropped.
ratio (tuple | None): Range of aspect ratio of the origin aspect ratio cropped.
hflip (float): Probability of horizontal flip.
vflip (float): Probability of vertical flip.
auto_augment (str | None): Auto augmentation policy. Can be 'randaugment', 'augmix', 'autoaugment' or None.
hsv_h (float): Image HSV-Hue augmentation factor.
hsv_s (float): Image HSV-Saturation augmentation factor.
hsv_v (float): Image HSV-Value augmentation factor.
force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled.
erasing (float): Probability of random erasing.
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
Returns:
(torchvision.transforms.Compose): A composition of image augmentation transforms.
Examples:
>>> transforms = classify_augmentations(size=224, auto_augment="randaugment")
>>> augmented_image = transforms(original_image)
"""
# Transforms to apply if Albumentations not installed
# 导入模块。这行代码导入了 torchvision.transforms 模块,并将其别名设置为 T 。
import torchvision.transforms as T # scope for faster 'import ultralytics'
# 构建主要变换流程。
# 这段代码是 classify_augmentations 函数的一部分,它负责设置图像分类任务中的主要数据增强步骤。
# 检查 size 类型。
# 这行代码检查 size 参数是否为整数类型。如果不是,将抛出一个 TypeError 异常,因为 size 应该是一个整数,表示目标图像的尺寸。
if not isinstance(size, int):
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") # classify_transforms() 大小 {size} 必须是整数,而不是 (列表,元组)。
# 设置默认缩放比例和宽高比。
# 这两行代码设置图像随机裁剪时的缩放比例 scale 和宽高比 ratio 。如果用户没有提供这些参数,将使用默认值。
# scale 的默认值是 (0.08, 1.0) ,表示图像尺寸可以在 8% 到 100% 之间随机缩放。
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
# ratio 的默认值是 (3.0 / 4.0, 4.0 / 3.0) ,表示宽高比可以在 0.75 到 1.33 之间随机变化。
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
# 设置插值方法。
# 这行代码根据 interpolation 参数的值获取对应的插值方法。 T.InterpolationMode 是 torchvision.transforms 中定义的插值方法枚举类。
interpolation = getattr(T.InterpolationMode, interpolation)
# 构建主要变换列表。
# T.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.33), interpolation=InterpolationMode.BILINEAR)
# T.RandomResizedCrop() 是 PyTorch 的 torchvision.transforms 模块中的一个函数,它用于随机裁剪图像到不同的大小和宽高比,然后调整裁剪的图像到指定的大小。这个变换特别有用,因为它可以帮助模型学习从图像的不同部分和不同的尺度中识别对象,从而提高模型的泛化能力。
# 参数说明 :
# size :目标图像的大小,可以是一个整数或一个元组 (height, width) 。如果是一个整数,那么裁剪后的图像将是正方形。
# scale :一个元组,表示随机裁剪区域面积与原图面积的比例范围。例如, (0.08, 1.0) 表示裁剪区域的面积将在原图面积的 8% 到 100% 之间。
# ratio :一个元组,表示随机裁剪区域的宽高比范围。例如, (0.75, 1.33) 表示裁剪区域的宽高比将在 3:4 到 4:3 之间。
# interpolation :插值方法,用于调整裁剪后的图像到指定大小。默认是 InterpolationMode.BILINEAR ,即双线性插值。其他可选的插值方法包括 InterpolationMode.NEAREST 、 InterpolationMode.BICUBIC 等。
# 返回值 :
# 该函数返回一个 RandomResizedCrop 变换对象,可以被用于 T.Compose() 之中,或者直接应用于图像。
# T.RandomResizedCrop() 是数据增强中常用的一个变换,它通过随机裁剪和调整图像大小,增加了数据集的多样性,有助于模型学习到更加鲁棒的特征。
# 这行代码创建了一个主要变换列表 primary_tfl ,首先添加了 RandomResizedCrop 变换,它将图像随机裁剪到指定的尺寸 size ,同时考虑缩放比例 scale 和宽高比 ratio 。
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
# 添加水平和垂直翻转。
# 这两行代码检查水平翻转 hflip 和垂直翻转 vflip 的概率。如果这些概率大于 0,将对应的翻转变换添加到 primary_tfl 列表中。这些变换将以指定的概率随机水平或垂直翻转图像。
if hflip > 0.0:
# T.RandomHorizontalFlip(p=0.5)
# T.RandomHorizontalFlip() 是 PyTorch 的 torchvision.transforms 模块中的一个函数,它用于以一定的概率对 PIL 图像进行水平翻转。这个变换可以用于数据增强,以提高模型对图像水平翻转变化的鲁棒性。
# 参数说明 :
# p :一个浮点数,表示图像被水平翻转的概率。默认值为 0.5 ,即有 50% 的概率进行水平翻转。
# 返回值 :
# 该函数返回一个 RandomHorizontalFlip 变换对象,可以被用于 T.Compose() 之中,或者直接应用于图像。
# T.RandomHorizontalFlip() 是数据增强中常用的一个变换,它通过随机水平翻转图像,增加了数据集的多样性,有助于模型学习到更加鲁棒的特征。通过调整参数 p ,可以灵活地控制翻转的概率。
primary_tfl.append(T.RandomHorizontalFlip(p=hflip))
if vflip > 0.0:
# T.RandomVerticalFlip(p=0.5)
# T.RandomVerticalFlip() 是 PyTorch 的 torchvision.transforms 模块中的一个函数,它用于以一定的概率对 PIL 图像进行垂直翻转。这个变换可以用于数据增强,以提高模型对图像垂直翻转变化的鲁棒性。
# 参数说明 :
# p :一个浮点数,表示图像被垂直翻转的概率。默认值为 0.5 ,即有 50% 的概率进行垂直翻转。
# 返回值 :
# 该函数返回一个 RandomVerticalFlip 变换对象,可以被用于 T.Compose() 之中,或者直接应用于图像。
# T.RandomVerticalFlip() 是数据增强中常用的一个变换,它通过随机垂直翻转图像,增加了数据集的多样性,有助于模型学习到更加鲁棒的特征。通过调整参数 p ,可以灵活地控制翻转的概率。
primary_tfl.append(T.RandomVerticalFlip(p=vflip))
# 这段代码负责设置图像分类任务中的主要数据增强步骤,包括随机裁剪、水平翻转和垂直翻转。这些步骤有助于增加数据集的多样性,提高模型的泛化能力。通过调整参数,可以灵活地配置变换流程以适应不同的任务需求。
# 构建次要变换流程。
# 这段代码是 classify_augmentations 函数的一部分,它负责根据 auto_augment 参数的值来决定是否应用特定的自动增强策略,并构建次要的变换列表 secondary_tfl 。
# 初始化次要变换列表。这行代码初始化了一个空列表 secondary_tfl ,用于存储次要的图像增强变换。
secondary_tfl = []
# 设置颜色抖动禁用标志。这行代码初始化了一个标志 disable_color_jitter ,默认为 False ,表示颜色抖动(ColorJitter)变换默认是启用的。
disable_color_jitter = False
# 检查自动增强策略。
# 这行代码检查 auto_augment 参数是否为字符串类型,如果不是,则抛出 AssertionError 异常。
if auto_augment:
assert isinstance(auto_augment, str), f"Provided argument should be string, but got type {type(auto_augment)}" # 提供的参数应该是字符串,但得到的类型是 {type(auto_augment)。
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter
# 根据自动增强策略添加变换。
# 这段代码检查 auto_augment 参数是否为 "randaugment" ,如果是,则检查 TORCHVISION_0_11 标志(表示 torchvision 版本是否大于等于 0.11.0)。
if auto_augment == "randaugment":
# 如果是,则添加 RandAugment 变换到 secondary_tfl 列表中;如果不是,则记录一条警告日志,并禁用该增强策略。
if TORCHVISION_0_11:
secondary_tfl.append(T.RandAugment(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
# 类似地,这段代码检查 auto_augment 参数是否为 "augmix" ,如果是,则检查 TORCHVISION_0_13 标志(表示 torchvision 版本是否大于等于 0.13.0)。
elif auto_augment == "augmix":
# 如果是,则添加 AugMix 变换到 secondary_tfl 列表中;如果不是,则记录一条警告日志,并禁用该增强策略。
if TORCHVISION_0_13:
secondary_tfl.append(T.AugMix(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
# 这段代码检查 auto_augment 参数是否为 "autoaugment" ,如果是,则检查 TORCHVISION_0_10 标志(表示 torchvision 版本是否大于等于 0.10.0)。
elif auto_augment == "autoaugment":
# 如果是,则添加 AutoAugment 变换到 secondary_tfl 列表中;如果不是,则记录一条警告日志,并禁用该增强策略。
if TORCHVISION_0_10:
secondary_tfl.append(T.AutoAugment(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
# 处理无效的自动增强策略。
# 如果 auto_augment 参数不是上述任何一个有效值,则抛出 ValueError 异常。
else:
raise ValueError(
f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None'
)
# 这段代码负责根据 auto_augment 参数的值来决定是否应用特定的自动增强策略,并构建次要的变换列表 secondary_tfl 。这些自动增强策略包括 RandAugment、AugMix 和 AutoAugment,它们都是数据增强技术,可以提高模型的泛化能力。
# 通过调整 auto_augment 参数,可以灵活地配置变换流程以适应不同的任务需求。
# 应用颜色抖动。如果 disable_color_jitter 为 False ,则应用颜色抖动。
if not disable_color_jitter:
secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h))
# 构建最终变换流程。这段代码创建了最终的变换列表,包括转换为张量、归一化和随机擦除。
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
T.RandomErasing(p=erasing, inplace=True),
]
# 返回组合变换。这行代码返回一个由 T.Compose 创建的变换流程,它将按顺序应用所有定义的变换。
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
# classify_augmentations 函数创建了一个图像增强流程,包括随机裁剪、翻转、自动增强策略、颜色抖动、归一化和随机擦除。这些变换可以显著增加数据集的多样性,提高模型的泛化能力。通过调整参数,可以灵活地配置变换流程以适应不同的任务需求。
18.class ClassifyLetterBox:
python
# NOTE: keep this class for backward compatibility
# 这段代码定义了一个名为 ClassifyLetterBox 的类,它用于将图像进行 Letterbox 风格的填充,以便于进行目标检测或图像分类任务。Letterbox 填充是一种保持图像长宽比的同时,将图像缩放到目标尺寸的方法,通常会在图像周围添加填充以达到所需的尺寸。
# 类定义。定义了一个名为 ClassifyLetterBox 的类。
class ClassifyLetterBox:
# 用于调整图像大小和填充图像以进行分类任务的类。
# 此类旨在成为转换管道的一部分,例如 T.Compose([LetterBox(size), ToTensor()])。它将图像调整大小并填充到指定大小,同时保持原始纵横比。
# 方法:
# __call__ :将信箱转换应用于输入图像。
"""
A class for resizing and padding images for classification tasks.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
It resizes and pads images to a specified size while maintaining the original aspect ratio.
Attributes:
h (int): Target height of the image.
w (int): Target width of the image.
auto (bool): If True, automatically calculates the short side using stride.
stride (int): The stride value, used when 'auto' is True.
Methods:
__call__: Applies the letterbox transformation to an input image.
Examples:
>>> transform = ClassifyLetterBox(size=(640, 640), auto=False, stride=32)
>>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
>>> result = transform(img)
>>> print(result.shape)
(640, 640, 3)
"""
# 构造函数。构造函数接受以下参数。
# 1.size :目标图像的尺寸,可以是整数或元组,表示目标图像的宽度和高度。
# 2.auto :布尔值,表示是否自动计算短边尺寸以满足步长(stride)要求。
# 3.stride :整数,表示步长,用于确保输出图像的宽度和高度是步长的整数倍。
def __init__(self, size=(640, 640), auto=False, stride=32):
# 初始化 ClassifyLetterBox 对象以进行图像预处理。
# 此类旨在成为图像分类任务转换管道的一部分。它将图像调整大小并填充到指定大小,同时保持原始纵横比。
"""
Initializes the ClassifyLetterBox object for image preprocessing.
This class is designed to be part of a transformation pipeline for image classification tasks. It resizes and
pads images to a specified size while maintaining the original aspect ratio.
Args:
size (int | Tuple[int, int]): Target size for the letterboxed image. If an int, a square image of
(size, size) is created. If a tuple, it should be (height, width).
auto (bool): If True, automatically calculates the short side based on stride. Default is False.
stride (int): The stride value, used when 'auto' is True. Default is 32.
Attributes:
h (int): Target height of the letterboxed image.
w (int): Target width of the letterboxed image.
auto (bool): Flag indicating whether to automatically calculate short side.
stride (int): Stride value for automatic short side calculation.
Examples:
>>> transform = ClassifyLetterBox(size=224)
>>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
>>> result = transform(img)
>>> print(result.shape)
(224, 224, 3)
"""
super().__init__()
# 这些行代码初始化实例变量 h 和 w 为 size 参数的值, auto 为 auto 参数的值, stride 为 stride 参数的值。
self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride
self.stride = stride # used with auto
# 这个方法使得类的实例可以像函数一样被调用,接受一个参数。
# 1.im :表示要进行 Letterbox 填充的图像。
def __call__(self, im):
# 使用 letterbox 方法调整图像大小并填充图像。
# 此方法调整输入图像的大小以适应指定的尺寸,同时保持其纵横比,然后填充调整大小后的图像以匹配目标大小。
"""
Resizes and pads an image using the letterbox method.
This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio,
then pads the resized image to match the target size.
Args:
im (numpy.ndarray): Input image as a numpy array with shape (H, W, C).
Returns:
(numpy.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are
the target height and width respectively.
Examples:
>>> letterbox = ClassifyLetterBox(size=(640, 640))
>>> image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
>>> resized_image = letterbox(image)
>>> print(resized_image.shape)
(640, 640, 3)
"""
# 计算图像尺寸和比例。
# 这些行代码计算原始图像的尺寸,计算缩放比例 r ,以及缩放后的图像尺寸 h 和 w 。
imh, imw = im.shape[:2]
r = min(self.h / imh, self.w / imw) # ratio of new/old dimensions
h, w = round(imh * r), round(imw * r) # resized image dimensions
# Calculate padding dimensions
# 计算填充尺寸。
# 这些行代码计算填充后的图像尺寸 hs 和 ws ,以及填充的顶部和左侧位置 top 和 left 。
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
# Create padded image
# 创建填充图像。
# 这些行代码创建一个新的填充图像 im_out ,其尺寸为 hs x ws ,填充值为 114(通常用于表示背景的均值)。
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
# 将原始图像缩放到 w x h 的尺寸,并放置在填充图像的中心位置。
im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
# 返回填充后的图像。返回填充后的图像 im_out 。
return im_out
# ClassifyLetterBox 类提供了一种将图像进行 Letterbox 填充的方法,保持了图像的长宽比,同时满足了目标检测或图像分类任务中对图像尺寸的要求。通过调整 size 、 auto 和 stride 参数,可以灵活地配置填充流程以适应不同的任务需求。
19.class CenterCrop:
python
# NOTE: keep this class for backward compatibility
# 这段代码定义了一个名为 CenterCrop 的类,它用于对图像进行中心裁剪。
# 类定义。定义了一个名为 CenterCrop 的类。
class CenterCrop:
# 对图像应用中心裁剪以进行分类任务。
# 此类对输入图像执行中心裁剪,将其调整为指定大小,同时保持纵横比。它旨在成为转换管道的一部分,例如 T.Compose([CenterCrop(size), ToTensor()])。
# 方法:
# __call__ :将中心裁剪转换应用于输入图像。
"""
Applies center cropping to images for classification tasks.
This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect
ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
Attributes:
h (int): Target height of the cropped image.
w (int): Target width of the cropped image.
Methods:
__call__: Applies the center crop transformation to an input image.
Examples:
>>> transform = CenterCrop(640)
>>> image = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8)
>>> cropped_image = transform(image)
>>> print(cropped_image.shape)
(640, 640, 3)
"""
# 构造函数。构造函数接受一个参数。
# 1.size :表示裁剪后的图像尺寸。
def __init__(self, size=640):
# 初始化 CenterCrop 对象以进行图像预处理。
# 此类旨在成为转换管道的一部分,例如 T.Compose([CenterCrop(size), ToTensor()])。它对输入图像执行中心裁剪以达到指定大小。
"""
Initializes the CenterCrop object for image preprocessing.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]).
It performs a center crop on input images to a specified size.
Args:
size (int | Tuple[int, int]): The desired output size of the crop. If size is an int, a square crop
(size, size) is made. If size is a sequence like (h, w), it is used as the output size.
Returns:
(None): This method initializes the object and does not return anything.
Examples:
>>> transform = CenterCrop(224)
>>> img = np.random.rand(300, 300, 3)
>>> cropped_img = transform(img)
>>> print(cropped_img.shape)
(224, 224, 3)
"""
super().__init__()
# 实例变量初始化。
# 这行代码初始化实例变量 h 和 w 为 size 参数的值。如果 size 是整数,那么裁剪后的图像将是正方形;如果是元组,那么裁剪后的图像将具有指定的宽度和高度。
self.h, self.w = (size, size) if isinstance(size, int) else size
# 这个方法使得类的实例可以像函数一样被调用,接受一个参数。
# 1.im :表示要进行中心裁剪的图像。
def __call__(self, im):
# 对输入图像应用中心裁剪。
# 此方法使用信箱方法调整图像大小并裁剪图像中心。它在将原始图像调整到指定尺寸的同时保持原始图像的纵横比。
"""
Applies center cropping to an input image.
This method resizes and crops the center of the image using a letterbox method. It maintains the aspect
ratio of the original image while fitting it into the specified dimensions.
Args:
im (numpy.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a
PIL Image object.
Returns:
(numpy.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C).
Examples:
>>> transform = CenterCrop(size=224)
>>> image = np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8)
>>> cropped_image = transform(image)
>>> assert cropped_image.shape == (224, 224, 3)
"""
# 转换图像格式。
# 如果输入的图像 im 是 PIL Image 对象,这行代码将其转换为 NumPy 数组。
if isinstance(im, Image.Image): # convert from PIL to numpy array if required
im = np.asarray(im)
# 获取图像尺寸。这行代码获取图像 im 的 高度 和 宽度 。
imh, imw = im.shape[:2]
# 计算最小维度。这行代码计算图像的最小维度,用于确定裁剪的大小。
m = min(imh, imw) # min dimension
# 裁剪图像。这行代码计算裁剪的起始位置,确保裁剪区域位于图像的中心。
top, left = (imh - m) // 2, (imw - m) // 2
# 裁剪图像。
# 这行代码裁剪图像的中心区域,并将其缩放到指定的尺寸 self.w x self.h 。
return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
# CenterCrop 类提供了一种将图像进行中心裁剪并缩放到指定尺寸的方法。这种变换通常用于图像预处理,以确保输入到模型的图像具有一致的尺寸。通过调整 size 参数,可以灵活地配置裁剪流程以适应不同的任务需求。
20.class ToTensor:
python
# NOTE: keep this class for backward compatibility 注意:保留此类是为了向后兼容。
# 这段代码定义了一个名为 ToTensor 的类,它用于将 NumPy 数组格式的图像转换为 PyTorch 张量(tensor),并进行一些预处理步骤。
# 类定义。定义了一个名为 ToTensor 的类。
class ToTensor:
# 将图像从 numpy 数组转换为 PyTorch 张量。
# 此类旨在成为转换管道的一部分,例如 T.Compose([LetterBox(size), ToTensor()])。
# 方法:
# __call__ :将张量转换应用于输入图像。
# 注释:
# 输入图像应为 BGR 格式,形状为 (H, W, C)。
# 输出张量将为 RGB 格式,形状为 (C, H, W),归一化为 [0, 1]。
"""
Converts an image from a numpy array to a PyTorch tensor.
This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]).
Attributes:
half (bool): If True, converts the image to half precision (float16).
Methods:
__call__: Applies the tensor conversion to an input image.
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> tensor_img = transform(img)
>>> print(tensor_img.shape, tensor_img.dtype)
torch.Size([3, 640, 640]) torch.float16
Notes:
The input image is expected to be in BGR format with shape (H, W, C).
The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1].
"""
# 构造函数。构造函数接受一个参数。
# 1.half :它是一个布尔值,表示是否将张量转换为半精度浮点数(float16)。
def __init__(self, half=False):
# 初始化 ToTensor 对象以将图像转换为 PyTorch 张量。
# 此类旨在用作 Ultralytics YOLO 框架中图像预处理的转换管道的一部分。它将 numpy 数组或 PIL 图像转换为 PyTorch 张量,并提供半精度 (float16) 转换选项。
"""
Initializes the ToTensor object for converting images to PyTorch tensors.
This class is designed to be used as part of a transformation pipeline for image preprocessing in the
Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option
for half-precision (float16) conversion.
Args:
half (bool): If True, converts the tensor to half precision (float16). Default is False.
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.rand(640, 640, 3)
>>> tensor_img = transform(img)
>>> print(tensor_img.dtype)
torch.float16
"""
super().__init__()
# 实例变量初始化。这行代码初始化实例变量 half 为传入的 half 参数值。
self.half = half
# 这个方法使得类的实例可以像函数一样被调用,接受一个参数。
# 1.im :表示要转换的图像。
def __call__(self, im):
# 将图像从 numpy 数组转换为 PyTorch 张量。
# 此方法将输入图像从 numpy 数组转换为 PyTorch 张量,应用可选的半精度转换和规范化。图像从 HWC 转置为 CHW 格式,颜色通道从 BGR 反转为 RGB。
"""
Transforms an image from a numpy array to a PyTorch tensor.
This method converts the input image from a numpy array to a PyTorch tensor, applying optional
half-precision conversion and normalization. The image is transposed from HWC to CHW format and
the color channels are reversed from BGR to RGB.
Args:
im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order.
Returns:
(torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized
to [0, 1] with shape (C, H, W) in RGB order.
Examples:
>>> transform = ToTensor(half=True)
>>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
>>> tensor_img = transform(img)
>>> print(tensor_img.shape, tensor_img.dtype)
torch.Size([3, 640, 640]) torch.float16
"""
# 转换图像维度和颜色通道顺序。
# 这行代码首先使用 transpose((2, 0, 1)) 将图像从高度x宽度x通道(HWC)格式转换为通道x高度x宽度(CHW)格式,然后将 BGR 颜色通道顺序反转为 RGB。
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
# 将 NumPy 数组转换为 PyTorch 张量。这行代码将 NumPy 数组转换为 PyTorch 张量。
im = torch.from_numpy(im) # to torch
# 转换数据类型。如果 self.half 为 True ,则将张量转换为半精度浮点数(float16),否则转换为单精度浮点数(float32)。
im = im.half() if self.half else im.float() # uint8 to fp16/32
# 像素值归一化是图像预处理中的一个重要步骤,它有以下几个作用:
# 缩放数据范围 :
# 归一化将像素值从整数范围 [0, 255] 转换为浮点数范围 [0.0, 1.0]。这种缩放有助于数值稳定性,尤其是在后续的计算中,如梯度下降等优化算法,可以减少计算过程中的数值问题。
# 加快收敛 :
# 归一化可以加快深度学习模型训练过程中的收敛速度。如果不同特征(像素值)的数据范围差异很大,那么模型可能需要更多的迭代次数来学习。通过归一化,可以使得不同特征的数据范围统一,从而加速模型的学习过程。
# 提高模型性能 :
# 归一化有助于提高模型的性能。当输入数据的分布接近标准正态分布时,许多机器学习算法的表现会更好。归一化可以使数据的分布更加接近标准正态分布,从而提高模型的性能。
# 避免梯度爆炸或消失 :
# 在深度学习中,如果输入数据的值范围很大,可能会导致梯度爆炸或消失的问题。归一化可以减少这种情况的发生,因为它限制了数据的值范围,使得梯度更加稳定。
# 提高数值计算的精度 :
# 对于浮点数计算,如果数值范围过大,可能会导致精度损失。归一化可以减少这种损失,因为它将数据范围限制在了一个较小的区间内。
# 使得不同特征具有可比性 :
# 在某些情况下,不同的特征可能有不同的量纲和数据范围。归一化可以使这些特征具有可比性,这对于某些算法(如基于距离的算法)来说是必要的。
# 数据预处理的标准化 :
# 归一化是数据预处理的一部分,它使得数据在不同的训练周期和不同的数据集之间保持一致性,这对于模型的泛化能力是有益的。
# 总之,像素值归一化是图像处理和机器学习中常用的技术,它有助于提高模型的训练效率和性能,同时减少数值计算中的问题。
# 归一化像素值。这行代码将像素值从 [0, 255] 范围归一化到 [0.0, 1.0] 范围。
im /= 255.0 # 0-255 to 0.0-1.0
# 返回转换后的张量。返回转换后的 PyTorch 张量。
return im
# ToTensor 类提供了一种将图像从 NumPy 数组格式转换为 PyTorch 张量的方法,包括颜色通道顺序的转换、数据类型的转换以及像素值的归一化。这种变换通常用于图像预处理,以确保输入到模型的图像具有一致的数据类型和值范围。通过调整 half 参数,可以灵活地配置数据类型的转换,以适应不同的计算需求和硬件限制。