13.5. 多尺度目标检测

这里是对那一节代码的通俗注释,希望对各位学习有帮助。

值得注意的是,multibox_prior 函数中锚框宽高的计算方式在网上存在争议,此处我仍认为作者的写法是正确的:宽度额外乘以 in_height / in_width 之后,锚框在像素坐标下的宽为 s·√r·H、高为 s·H/√r(H、W 为输入的高和宽),实际宽高比恰好为 r。如果读者有不同想法,欢迎在评论区留言,我们一起讨论。

Python 代码如下:
import torch
from d2l import torch as d2l

# Print tensors with 2 digits of precision (the first positional argument of
# torch.set_printoptions is `precision`).
torch.set_printoptions(2)


# 定义一个函数,用于生成以每个像素为中心具有不同形状的锚框
def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes centred on every pixel of a feature map.

    For each pixel, ``len(sizes) + len(ratios) - 1`` anchors are produced:
    every size paired with ``ratios[0]``, followed by ``sizes[0]`` paired with
    each ratio in ``ratios[1:]``.

    Args:
        data: tensor whose last two dimensions are (height, width); only its
            shape and device are used.
        sizes: anchor scales.
        ratios: anchor width/height ratios.

    Returns:
        Tensor of shape (1, height * width * boxes_per_pixel, 4) holding
        corner coordinates (xmin, ymin, xmax, ymax) normalised by the input
        width/height. Values are not clipped, so they may fall outside [0, 1].
    """
    height, width = data.shape[-2:]
    dev = data.device
    per_pixel = len(sizes) + len(ratios) - 1
    scale = torch.tensor(sizes, device=dev)
    aspect = torch.tensor(ratios, device=dev)

    # Pixel centres in normalised coordinates: (index + 0.5) * (1 / extent).
    ys = (torch.arange(height, device=dev) + 0.5) * (1.0 / height)
    xs = (torch.arange(width, device=dev) + 0.5) * (1.0 / width)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid_y = grid_y.reshape(-1)
    grid_x = grid_x.reshape(-1)

    root_first = torch.sqrt(aspect[0])
    root_rest = torch.sqrt(aspect[1:])
    # x and y are normalised by different extents; multiplying the width by
    # height / width makes the anchor's aspect ratio in pixels equal to r.
    widths = torch.cat((scale * root_first, sizes[0] * root_rest)) * height / width
    heights = torch.cat((scale / root_first, sizes[0] / root_rest))

    # Half-extent offsets for each anchor shape, tiled once per pixel.
    offsets = torch.stack((-widths, -heights, widths, heights)).T
    offsets = offsets.repeat(height * width, 1) / 2
    # Each pixel centre repeated once per anchor shape (consecutive rows).
    centres = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=1)
    centres = centres.repeat_interleave(per_pixel, dim=0)
    return (centres + offsets).unsqueeze(0)


# 显示所有边界框
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw bounding boxes, with optional text labels, on matplotlib axes.

    Args:
        axes: axes object the rectangles and texts are added to.
        bboxes: iterable of tensors, each (xmin, ymin, xmax, ymax); presumably
            already in pixel units (callers multiply by the image size).
        labels: optional single label or list/tuple of labels; label i is
            drawn at the anchor point ``rect.xy`` of box i.
        colors: optional single colour or list/tuple of colours, cycled over
            the boxes; defaults to ['b', 'g', 'r', 'm', 'c'].
    """
    def _as_list(value, fallback=None):
        # None -> fallback; a lone value -> one-element list; lists/tuples as-is.
        if value is None:
            return fallback
        if isinstance(value, (list, tuple)):
            return value
        return [value]

    labels = _as_list(labels)
    colors = _as_list(colors, ['b', 'g', 'r', 'm', 'c'])

    for idx, box in enumerate(bboxes):
        edge = colors[idx % len(colors)]
        patch = d2l.bbox_to_rect(box.detach().numpy(), edge)
        axes.add_patch(patch)
        if not labels or idx >= len(labels):
            continue
        # Black text on a white box, white text on any other colour.
        fg = 'w' if edge != 'w' else 'k'
        axes.text(patch.xy[0], patch.xy[1], labels[idx],
                  va='center', ha='center', fontsize=9, color=fg,
                  bbox=dict(facecolor=edge, lw=0))


# 计算两个锚框或边界框列表中成对的交并比
def box_iou(boxes1, boxes2):
    """Compute pairwise intersection-over-union between two sets of boxes.

    Args:
        boxes1: tensor (N, 4), rows are (xmin, ymin, xmax, ymax).
        boxes2: tensor (M, 4), same layout.

    Returns:
        Tensor (N, M) whose [i, j] entry is IoU(boxes1[i], boxes2[j]).
    """
    def area(b):
        return (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    # Broadcast (N, 1, 2) against (M, 2) -> (N, M, 2) intersection corners.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    # Negative extents mean the boxes do not overlap; clamp those to zero.
    extent = (bottom_right - top_left).clamp(min=0)
    overlap = extent[:, :, 0] * extent[:, :, 1]
    union = area(boxes1)[:, None] + area(boxes2) - overlap
    return overlap / union


# 将最接近的真实边界框分配给锚框
def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):
    """Assign the closest ground-truth bounding box to each anchor.

    Two passes:
    1. Each anchor whose best IoU with any ground-truth box is at least
       ``iou_threshold`` is assigned that box.
    2. Greedy matching: the largest remaining IoU entry (anchor i, box j) is
       taken, anchor i is (re)assigned box j, and row i and column j are
       discarded. This repeats until anchors or boxes run out.

    Args:
        ground_truth: tensor (num_gt_boxes, 4) of corner-format boxes.
        anchors: tensor (num_anchors, 4) of corner-format boxes.
        device: device on which the result tensor is allocated.
        iou_threshold: minimum IoU for the threshold pass.

    Returns:
        Long tensor (num_anchors,) with the assigned ground-truth index per
        anchor, or -1 for anchors that were not assigned.
    """
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]
    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,
                                  device=device)
    if num_anchors == 0 or num_gt_boxes == 0:
        # Nothing to match; max/argmax over an empty dimension would raise.
        return anchors_bbox_map
    # jaccard[i, j] is the IoU of anchor i with ground-truth box j.
    jaccard = box_iou(anchors, ground_truth)

    # Pass 1: assign each anchor its best box if that IoU clears the threshold.
    max_ious, indices = torch.max(jaccard, dim=1)
    above = max_ious >= iou_threshold
    anc_i = torch.nonzero(above).reshape(-1)  # indices of qualifying anchors
    anchors_bbox_map[anc_i] = indices[above]

    # Pass 2: greedy matching. Real IoUs are non-negative, so -1 marks a
    # discarded row/column. A scalar fill also works regardless of device.
    # Bounding the loop by min(...) avoids reusing a stale argmax (and
    # overwriting a prior assignment) once every row has been discarded.
    for _ in range(min(num_anchors, num_gt_boxes)):
        max_idx = torch.argmax(jaccard)  # index into the flattened matrix
        box_idx = (max_idx % num_gt_boxes).long()
        # Exact integer floor division. Float division loses precision once
        # the flattened index exceeds 2**24 and can pick the wrong row.
        anc_idx = torch.div(max_idx, num_gt_boxes, rounding_mode='floor')
        anchors_bbox_map[anc_idx] = box_idx
        jaccard[:, box_idx] = -1
        jaccard[anc_idx, :] = -1
    return anchors_bbox_map


def offset_boxes(anchors, assigned_bb, eps=1e-6):
    """Encode assigned boxes as offsets relative to their anchors.

    Both inputs are corner-format box tensors of shape (N, 4). The result is
    an (N, 4) tensor of (dx, dy, dw, dh), where
    dx, dy = 10 * (box centre - anchor centre) / anchor size, and
    dw, dh = 5 * log(eps + box size / anchor size).
    10 and 5 are fixed scale factors; ``eps`` keeps the log finite when a box
    has zero size (e.g. the all-zero rows used for unassigned anchors).
    """
    anc_c = d2l.box_corner_to_center(anchors)
    bb_c = d2l.box_corner_to_center(assigned_bb)
    anc_xy, anc_wh = anc_c[:, :2], anc_c[:, 2:]
    bb_xy, bb_wh = bb_c[:, :2], bb_c[:, 2:]
    d_xy = 10 * (bb_xy - anc_xy) / anc_wh
    d_wh = 5 * torch.log(eps + bb_wh / anc_wh)
    return torch.cat([d_xy, d_wh], dim=1)


def multibox_target(anchors, labels):
    """Label anchors with classes and box offsets from ground-truth boxes.

    Args:
        anchors: tensor (1, num_anchors, 4); the leading dim is squeezed away.
        labels: tensor (batch_size, num_boxes, 5); each row is
            (class, xmin, ymin, xmax, ymax).

    Returns:
        (bbox_offset, bbox_mask, class_labels):
        - bbox_offset: (batch_size, num_anchors * 4) offsets, zeroed for
          anchors without an assigned box;
        - bbox_mask: (batch_size, num_anchors * 4), 1.0 for assigned anchors
          and 0.0 otherwise;
        - class_labels: (batch_size, num_anchors), 0 for background,
          ground-truth class + 1 otherwise.
    """
    anchors = anchors.squeeze(0)
    device, num_anchors = anchors.device, anchors.shape[0]
    offsets, masks, classes = [], [], []
    for b in range(labels.shape[0]):
        sample = labels[b, :, :]
        gt_index = assign_anchor_to_bbox(sample[:, 1:], anchors, device)
        assigned = gt_index >= 0
        # One mask value per coordinate, so it can multiply the offsets.
        mask = assigned.float().unsqueeze(-1).repeat(1, 4)
        cls = torch.zeros(num_anchors, dtype=torch.long, device=device)
        matched_boxes = torch.zeros((num_anchors, 4), dtype=torch.float32,
                                    device=device)
        pos = torch.nonzero(assigned)
        src = gt_index[pos]
        # Shift classes by one so that 0 is reserved for background.
        cls[pos] = sample[src, 0].long() + 1
        matched_boxes[pos] = sample[src, 1:]
        offsets.append((offset_boxes(anchors, matched_boxes) * mask).reshape(-1))
        masks.append(mask.reshape(-1))
        classes.append(cls)
    return (torch.stack(offsets), torch.stack(masks), torch.stack(classes))


def offset_inverse(anchors, offset_preds):
    """Decode predicted offsets back into corner-format bounding boxes.

    Mirrors the encoding in ``offset_boxes`` (ignoring its eps): each centre
    is shifted by offset * anchor_size / 10, and each size is scaled by
    exp(offset / 5).

    Args:
        anchors: tensor (N, 4) of corner-format anchors.
        offset_preds: tensor (N, 4) of (dx, dy, dw, dh) offsets.

    Returns:
        Tensor (N, 4) of predicted boxes in corner format.
    """
    anc = d2l.box_corner_to_center(anchors)
    anc_xy, anc_wh = anc[:, :2], anc[:, 2:]
    centre = offset_preds[:, :2] * anc_wh / 10 + anc_xy
    size = torch.exp(offset_preds[:, 2:] / 5) * anc_wh
    return d2l.box_center_to_corner(torch.cat((centre, size), dim=1))


def nms(boxes, scores, iou_threshold):
    """Non-maximum suppression over predicted boxes.

    Boxes are visited in descending score order. Each visited box is kept,
    and every remaining box whose IoU with it exceeds ``iou_threshold`` is
    discarded.

    Args:
        boxes: tensor (N, 4) of corner-format boxes.
        scores: tensor (N,) of confidences.
        iou_threshold: boxes with IoU strictly greater than this are suppressed.

    Returns:
        Long tensor of kept indices into ``boxes``, in descending score order.
        When N == 0 it is empty but still long, so it remains usable as an
        index tensor.
    """
    order = torch.argsort(scores, dim=-1, descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(int(i))
        if order.numel() == 1:
            break
        # IoU of the current top box against every other remaining box.
        iou = box_iou(boxes[i, :].reshape(-1, 4),
                      boxes[order[1:], :].reshape(-1, 4)).reshape(-1)
        # Survivors overlap the kept box by at most the threshold; +1 maps
        # positions in order[1:] back to positions in order.
        inds = torch.nonzero(iou <= iou_threshold).reshape(-1)
        order = order[inds + 1]
    # Explicit dtype: torch.tensor([]) would otherwise default to float32.
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)


def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
                       pos_threshold=0.009999999):
    """Predict bounding boxes using non-maximum suppression.

    Args:
        cls_probs: tensor (batch_size, num_classes, num_anchors); index 0 on
            dim 1 is skipped (background), so class_id k refers to row k + 1.
        offset_preds: tensor whose i-th entry reshapes to (num_anchors, 4)
            offsets for sample i.
        anchors: tensor (1, num_anchors, 4); the leading dim is squeezed.
        nms_threshold: IoU threshold passed to ``nms``.
        pos_threshold: predictions with confidence below this get class -1,
            and their confidence is replaced by 1 - confidence.

    Returns:
        Tensor (batch_size, num_anchors, 6); each row is
        (class_id, confidence, xmin, ymin, xmax, ymax). Boxes kept by NMS
        come first in NMS order, followed by the suppressed ones (class -1).
    """
    device = cls_probs.device
    batch_size, num_anchors = cls_probs.shape[0], cls_probs.shape[2]
    anchors = anchors.squeeze(0)
    results = []
    for b in range(batch_size):
        probs = cls_probs[b]
        deltas = offset_preds[b].reshape(-1, 4)
        # Best non-background class (and its probability) for each anchor.
        conf, class_id = torch.max(probs[1:], 0)
        boxes = offset_inverse(anchors, deltas)
        keep = nms(boxes, conf, nms_threshold)

        # An index that appears only once in cat(keep, all) was not kept.
        everything = torch.arange(num_anchors, dtype=torch.long, device=device)
        uniq, counts = torch.cat((keep, everything)).unique(return_counts=True)
        dropped = uniq[counts == 1]
        order = torch.cat((keep, dropped))
        class_id[dropped] = -1
        class_id, conf, boxes = class_id[order], conf[order], boxes[order]

        # Low-confidence predictions are treated as background.
        low = conf < pos_threshold
        class_id[low] = -1
        conf[low] = 1 - conf[low]
        results.append(torch.cat((class_id.unsqueeze(1),
                                  conf.unsqueeze(1),
                                  boxes), dim=1))
    return torch.stack(results)


# 读取图片
img = d2l.plt.imread('../img/catdog.jpg')
h, w = img.shape[:2]

# Generate anchors for a random input with the same spatial size as the image.
sizes, ratios = [0.75, 0.5, 0.25], [1, 2, 0.5]
X = torch.rand(size=(1, 3, h, w))
Y = multibox_prior(X, sizes=sizes, ratios=ratios)

# Show every anchor centred on pixel (250, 250). The per-pixel count and the
# label order follow multibox_prior: all sizes with ratios[0], then sizes[0]
# with each of ratios[1:]. Deriving both avoids a hard-coded anchor count.
boxes_per_pixel = len(sizes) + len(ratios) - 1
boxes = Y.reshape(h, w, boxes_per_pixel, 4)
d2l.set_figsize()
bbox_scale = torch.tensor((w, h, w, h))
anchor_labels = ([f's={s}, r={ratios[0]}' for s in sizes]
                 + [f's={sizes[0]}, r={r}' for r in ratios[1:]])
fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale, anchor_labels)
d2l.plt.show()

ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],
                             [1, 0.55, 0.2, 0.9, 0.88]])
anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],
                        [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],
                        [0.57, 0.3, 0.92, 0.9]])

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')
show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4'])
d2l.plt.show()

# multibox_target returns (bbox_offset, bbox_mask, class_labels).
labels = multibox_target(anchors.unsqueeze(dim=0),
                         ground_truth.unsqueeze(dim=0))
print(labels[2])  # class labels
print(labels[1])  # offset mask
print(labels[0])  # offsets

anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],
                        [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])
# All-zero offsets, created as float so they match the boxes' dtype.
offset_preds = torch.zeros(anchors.numel())
cls_probs = torch.tensor([[0] * 4,  # background probabilities
                          [0.9, 0.8, 0.7, 0.1],  # dog probabilities
                          [0.1, 0.2, 0.3, 0.9]])  # cat probabilities

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, anchors * bbox_scale,
            ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])
d2l.plt.show()

output = multibox_detection(cls_probs.unsqueeze(dim=0),
                            offset_preds.unsqueeze(dim=0),
                            anchors.unsqueeze(dim=0),
                            nms_threshold=0.5)
print(output)

# Draw only detections that were not marked as background (class -1).
fig = d2l.plt.imshow(img)
for row in output[0].detach().numpy():
    if row[0] == -1:
        continue
    label = ('dog=', 'cat=')[int(row[0])] + str(row[1])
    show_bboxes(fig.axes, [torch.tensor(row[2:]) * bbox_scale], label)
d2l.plt.show()
相关推荐
ECT-OS-JiuHuaShan3 分钟前
整体是函数,部分是子函数——范畴论框架下的严格证明
人工智能
柯儿的天空4 分钟前
【OpenClaw 全面解析:从零到精通】第 004 篇:OpenClaw 在 Linux/Ubuntu 上的安装与部署实战
linux·人工智能·ubuntu·elasticsearch·知识图谱
xixixi777775 分钟前
从SQL注入到XSS:实战Web安全渗透测试
人工智能·安全·web安全·网络安全·卫星通信
代码探秘者6 分钟前
【算法篇】1.双指针
java·数据结构·人工智能·后端·python·算法
倦王6 分钟前
Dify的部署(详细步骤一步一步)
人工智能
一水鉴天13 分钟前
整体设计 设计文档修订与重构修改稿 (豆包助手)20260321
人工智能·重构
小马过河R16 分钟前
小白沉浸式本地Mac小龙虾OpenClaw部署安装教程
人工智能·macos·大模型·nlp·agent·openclaw·龙虾
hitgavin23 分钟前
Physical Intelligence RLT
人工智能
xwz小王子23 分钟前
Science Robotics 赋予机器人“类脑”触觉,低成本视觉-触觉预训练攻克灵巧手多任务操作
人工智能·算法·机器人
LONGZETECH28 分钟前
实测职业教育无人机仿真教学软件:架构、功能与落地全解析
人工智能·架构·无人机·无人机仿真教学软件·无人机教学软件·无人机仿真软件