anchor_generator.py
utils\tal\anchor_generator.py
目录
[2.def make_anchors(feats, strides, grid_cell_offset=0.5):](#2.def make_anchors(feats, strides, grid_cell_offset=0.5):)
[3.def dist2bbox(distance, anchor_points, xywh=True, dim=-1):](#3.def dist2bbox(distance, anchor_points, xywh=True, dim=-1):)
[4.def bbox2dist(anchor_points, bbox, reg_max):](#4.def bbox2dist(anchor_points, bbox, reg_max):)
1.所需的库和模块
python
import torch
from utils.general import check_version
# 使用 check_version 函数来比较当前安装的 PyTorch 版本和字符串 '1.10.0' 指定的版本。 TORCH_1_10 是一个布尔值,表示当前 PyTorch 版本是否大于等于 '1.10.0' 。
# def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
# -> 它用于检查当前安装的软件包版本是否满足指定的最低版本要求。函数返回 result ,即版本检查的结果。
# -> return result
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
2.def make_anchors(feats, strides, grid_cell_offset=0.5):
python
# 这段代码定义了一个名为 make_anchors 的函数,它用于生成YOLO模型中使用的锚点(anchor points)。
# 这是 make_anchors 函数的定义,它接受三个参数。
# 1.feats :一个包含特征图的列表。
# 2.strides :一个包含对应特征图的步长的列表。
# 3.grid_cell_offset :一个浮点数,表示网格单元的偏移量,默认值为0.5。
def make_anchors(feats, strides, grid_cell_offset=0.5):
# 根据特征生成锚点。
"""Generate anchors from features."""
# 初始化两个空列表, anchor_points 用于存储锚点坐标, stride_tensor 用于存储步长。
anchor_points, stride_tensor = [], []
# 断言语句,确保传入的特征图列表 feats 不是 None 。
assert feats is not None
# 获取 特征图的数据类型 和 设备 (CPU或GPU),这将用于创建新的张量。
dtype, device = feats[0].dtype, feats[0].device
# 遍历步长列表, i 是索引, stride 是当前步长。
for i, stride in enumerate(strides):
# 获取第 i 个特征图的形状, h 是高度, w 是宽度。
_, _, h, w = feats[i].shape
# 在创建 sx 和 sy 时加上偏移量 grid_cell_offset (通常设置为0.5)的原因是为了将锚点(anchor points)的中心对准每个网格单元的中心。以下是详细解释 :
# 网格单元的中心 :
# 在目标检测中,特征图通常被划分为一个个网格单元,每个网格单元对应于图像中的一个区域。
# 通过在网格单元的索引上加上偏移量,我们可以将锚点的中心从网格单元的角落移动到网格单元的中心。
# 提高定位精度 :
# 目标检测模型需要精确地定位目标的边界框。将锚点中心对准网格单元的中心可以提高模型预测边界框位置的精度。
# 避免边界效应 :
# 如果锚点位于网格单元的角落,那么在网格边缘的目标可能会被不良地预测,因为锚点距离目标的真实中心较远。
# 通过将锚点中心移动到网格单元的中心,可以减少这种边界效应,提高模型对于边缘目标的检测能力。
# 与YOLO方法一致 :
# YOLO模型直接预测相对于网格单元左上角的边界框坐标。加上偏移量是为了使得预测的坐标与YOLO方法保持一致。
# 简化模型学习 :
# 随机初始化模型会需要很长一段时间才能稳定产生可靠的偏移量(offsets)。通过直接在网格单元中心初始化锚点,可以简化模型的学习过程。
# 提高模型的泛化能力 :
# 将锚点中心对准网格单元的中心有助于模型更好地泛化到不同尺寸和位置的目标。
# 综上所述,加上偏移量 grid_cell_offset 是为了确保锚点位于网格单元的中心,这有助于提高目标检测模型的定位精度和泛化能力。
# 创建两个张量 sx 和 sy ,分别表示特征图上每个单元格的x和y坐标,并加上偏移量 grid_cell_offset 。
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
# 使用 torch.meshgrid 函数生成坐标网格。 indexing='ij' 参数指定了网格的索引方式, 'ij' 表示矩阵索引方式,这是PyTorch 1.10及以上版本的写法。对于旧版本的PyTorch,不需要指定 indexing 参数。
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
# 将 sx 和 sy 堆叠起来,并调整形状,形成锚点坐标,然后添加到 anchor_points 列表中。
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
# 创建一个填充了步长的张量,并添加到 stride_tensor 列表中。
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
# 使用 torch.cat 函数将 anchor_points 和 stride_tensor 列表中的所有张量连接起来,并返回结果。
return torch.cat(anchor_points), torch.cat(stride_tensor)
# 这个函数的作用是为每个特征图上的每个单元格生成锚点坐标,并为每个锚点分配相应的步长。这些锚点将用于目标检测模型中的边界框预测。
3.def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
python
# 这段代码定义了一个名为 dist2bbox 的函数,它用于将从锚点(anchor points)出发的边界距离(通常表示为左上角和右下角的坐标)转换为边界框(bounding box)的坐标。这个转换可以输出两种格式的边界框:中心点加宽高(xywh)格式和左上角与右下角坐标(xyxy)格式。
# 这是 dist2bbox 函数的定义,它接受四个参数。
# 1.distance :一个包含边界距离的张量,通常表示为左上角(left, top)和右下角(right, bottom)的坐标。
# 2.anchor_points :锚点的坐标,通常是边界框的中心点或者左上角点。
# 3.xywh :一个布尔值,指示输出的边界框格式,默认为True,表示输出中心点加宽高的格式(xywh)。
# 4.dim :指定在哪个维度上进行操作,默认为-1,即最后一个维度。
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
# 将距离(ltrb)转换为边界框(xywh或xyxy)。
"""Transform distance(ltrb) to box(xywh or xyxy)."""
# 使用 torch.split 函数将 distance 张量在指定维度 dim 上分割成两个部分,分别表示左上角(lt)和右下角(rb)的坐标。
lt, rb = torch.split(distance, 2, dim)
# 计算边界框的左上角(x1y1)和右下角(x2y2)坐标。 x1y1 锚点坐标减去左上角的距离。 x2y2 锚点坐标加上右下角的距离。
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
# 判断是否需要将边界框格式转换为 xywh 格式。
if xywh:
# 如果需要转换为xywh格式,则计算边界框的中心点(c_xy)和宽高(wh)。 c_xy 左上角和右下角坐标的平均值,表示边界框的中心点。
c_xy = (x1y1 + x2y2) / 2
# wh 右下角和左上角坐标的差值,表示边界框的宽度和高度。
wh = x2y2 - x1y1
# 将中心点和宽高在指定维度 dim 上拼接起来,返回xywh格式的边界框。
return torch.cat((c_xy, wh), dim) # xywh bbox
# 如果不转换为xywh格式,则直接将左上角和右下角坐标在指定维度 dim 上拼接起来,返回xyxy格式的边界框。
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
# 这个函数的作用是将从锚点出发的边界距离转换为边界框的坐标,这是目标检测模型中常见的操作,用于从模型的输出中恢复出边界框的位置。
4.def bbox2dist(anchor_points, bbox, reg_max):
python
# 这段代码定义了一个名为 bbox2dist 的函数,它将边界框从 xyxy (左上角和右下角的坐标)格式转换为 ltrb (左、上、右、下的偏移量)格式。
# 定义 bbox2dist 函数,接受三个参数。
# 1.anchor_points :锚点的坐标,通常是边界框的中心点。
# 2.bbox :边界框的坐标。
# 3.reg_max :回归任务中的最大值,用于限制距离的范围。
def bbox2dist(anchor_points, bbox, reg_max):
# 将 bbox(xyxy) 转换为 dist(ltrb)。
"""Transform bbox(xyxy) to dist(ltrb)."""
# torch.split(tensor, split_size_or_sections, dim=0)
# torch.split 函数在PyTorch中用于将一个张量(Tensor)分割成多个较小的张量,这些张量在指定的维度上具有相等或不同的大小。这个函数非常灵活,可以根据需要分割张量。
# tensor :要分割的输入张量。
# split_size_or_sections :一个整数或张量大小的序列。 如果是一个整数,表示每个分割块的大小(除了可能的最后一块)。 如果是一个序列,表示每个分割块的大小。
# dim :要沿哪个维度进行分割。默认是0。
# 返回值:
# 返回一个张量元组,包含分割后的各个张量。
# 使用 torch.split 函数将 bbox 张量沿着最后一个维度(即坐标维度)分割成两个部分,分别包含边界框的左上角坐标 x1y1 和右下角坐标 x2y2 。
x1y1, x2y2 = torch.split(bbox, 2, -1)
# torch.clamp(input, min=None, max=None)
# torch.clamp() 是 PyTorch 库中的一个函数,用于将张量中的元素限制在指定的范围内。如果元素超出了这个范围,它们将被设置为范围的上限或下限。
# 参数 :
# input :要进行裁剪的输入张量。
# min :元素的最小值。默认为 None ,表示不设置下界。
# max :元素的最大值。默认为 None ,表示不设置上界。
# 返回值 :
# 返回一个新的张量,其中的元素被限制在 [min, max] 范围内。
# 注意事项 :
# torch.clamp() 函数返回的是新张量,原始输入张量不会被修改。
# 如果需要在原地修改张量,可以使用 clamped_() 方法,例如 tensor.clamp_(0, 3) 。
# torch.clamp() 可以用于多维张量,并且可以指定不同的 min 和 max 值用于不同的维度。
# min 和 max 参数也可以是标量值,或者与输入张量形状相同的张量,用于对不同元素应用不同的限制。
# 计算边界框相对于锚点的距离。
# anchor_points - x1y1 :计算锚点与边界框左上角之间的距离,这给出了边界框左上角相对于锚点的偏移量。
# x2y2 - anchor_points :计算边界框右下角与锚点之间的距离,这给出了边界框右下角相对于锚点的偏移量。
# 使用 torch.cat 将这两个偏移量沿着最后一个维度(即坐标维度)拼接起来,形成 ltrb 格式的距离。然后使用 clamp 函数将距离限制在 [0, reg_max - 0.01] 的范围内,以避免超出回归任务的最大值。
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb)
# bbox2dist 函数将边界框的坐标从 xyxy 格式转换为相对于锚点的 ltrb 格式的距离,这种转换在目标检测模型中常用于回归任务,其中模型需要预测边界框相对于锚点的位置偏移量。通过限制距离的范围,可以防止模型预测超出合理的边界框位置。