13.5. 多尺度目标检测

这里是对那一节代码的通俗注释,希望对各位学习有帮助。

值得注意的是,multibox_prior函数的宽高计算网络上有争议,此处我仍认为作者的写法是正确的,如果读者有想法,可以在评论区留言,我们进行讨论。

python 复制代码
import torch
from d2l import torch as d2l

torch.set_printoptions(2)  # 设置张量输出精度


# 定义一个函数,用于生成以每个像素为中心具有不同形状的锚框
def multibox_prior(data, sizes, ratios):
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = (num_sizes + num_ratios - 1)  # 计算以每一个像素为中心要生成多少个锚框
    size_tensor = torch.tensor(sizes, device=device)  # 将这些锚框的大小(缩放比)转换为张量
    ratio_tensor = torch.tensor(ratios, device=device)  # 将这些锚框的宽高比转换为张量
    offset_h, offset_w = 0.5, 0.5  # 设置偏移量,将中心点移动到每一个像素的中心
    steps_h = 1.0 / in_height
    steps_w = 1.0 / in_width
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),  # 保持宽高比不变,遍历所有缩放比,下一行是保持缩放比不变,遍历所有宽高比
                   sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
        * in_height / in_width  # 最终计算得到宽度
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
        in_height * in_width, 1) / 2  # 重复这些锚框为wxh次,因为有这么多像素,除以2是因为将锚框的上下和左右度量均方,以放置中点上
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
                           dim=1).repeat_interleave(boxes_per_pixel, dim=0)  # 得到中心点位置
    output = out_grid + anchor_manipulations  # 两者相加,得到正确的锚框坐标
    return output.unsqueeze(0)


# 显示所有边界框
def show_bboxes(axes, bboxes, labels=None, colors=None):
    def _make_list(obj, default_values=None):
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])

    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))


# 计算两个锚框或边界框列表中成对的交并比
def box_iou(boxes1, boxes2):
    box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *
                              (boxes[:, 3] - boxes[:, 1]))
    # 计算给定框的面积
    areas1 = box_area(boxes1)
    areas2 = box_area(boxes2)
    # 计算交集的左上角和右下角的坐标
    inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    # 计算交集的宽高以及面积
    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)
    inter_areas = inters[:, :, 0] * inters[:, :, 1]
    # 计算并集的面积
    union_areas = areas1[:, None] + areas2 - inter_areas
    # 返回交并比
    return inter_areas / union_areas


# 将最接近的真实边界框分配给锚框
def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]  # 获取锚框的数量和真实边界框的数量
    jaccard = box_iou(anchors, ground_truth)  # 得到交并比矩阵
    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,
                                  device=device)  # 创建真实边界框分配列表,初始用-1填充,表示不分配
    max_ious, indices = torch.max(jaccard, dim=1)  # 求得每一个锚框与所有真实边界框的最大交并比和其索引
    anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)  # 得到满足阈值要求交并比
    box_j = indices[max_ious >= iou_threshold]  # 得到满足阈值要求交并比的索引
    anchors_bbox_map[anc_i] = box_j  # 如果交并比满足阈值要求,将真实边界框索引分配到对应的锚框
    col_discard = torch.full((num_anchors,), -1)  # 列丢弃索引,用来标记交并比矩阵已经丢弃的列
    row_discard = torch.full((num_gt_boxes,), -1)  # 行丢弃索引,用来标记交并比矩阵已经丢弃的行
    for _ in range(num_gt_boxes):
        max_idx = torch.argmax(jaccard)  # 获取整个交并比矩阵中,值最大的索引(矩阵扁平化后的索引)
        box_idx = (max_idx % num_gt_boxes).long()  # 得到该交并比对应的真实边界框的索引
        anc_idx = (max_idx / num_gt_boxes).long()  # 得到该交并比对应的锚框的索引
        anchors_bbox_map[anc_idx] = box_idx  # 分配真实边界框
        jaccard[:, box_idx] = col_discard  # 丢弃对应的列
        jaccard[anc_idx, :] = row_discard  # 丢弃对应的行
    return anchors_bbox_map


def offset_boxes(anchors, assigned_bb, eps=1e-6):
    c_anc = d2l.box_corner_to_center(anchors)  # 获取所有锚框的中心坐标
    c_assigned_bb = d2l.box_corner_to_center(assigned_bb)  # 获取真实边界框的中心坐标
    offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]  # 计算锚框和真实边界框的中心坐标偏移量
    offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])  # 计算宽高缩放的偏移量
    offset = torch.cat([offset_xy, offset_wh], axis=1)  # 将两种偏移量进行连接,排成一行,然后返回
    return offset


def multibox_target(anchors, labels):  # labels的形状(batchsize,边界框数量,5),后面的5中,第一个元素是真实标签,后面是坐标信息
    batch_size, anchors = labels.shape[0], anchors.squeeze(0)
    batch_offset, batch_mask, batch_class_labels = [], [], []
    device, num_anchors = anchors.device, anchors.shape[0]
    for i in range(batch_size):
        label = labels[i, :, :]  # 获取每一个样本的所有真实边界框的信息(标签和坐标)
        anchors_bbox_map = assign_anchor_to_bbox(  # 获取真实标签对锚框的分配表
            label[:, 1:], anchors, device)
        bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(  # 生成偏移量掩码,为了屏蔽掉未分配的锚框的偏移量
            1, 4)
        class_labels = torch.zeros(num_anchors, dtype=torch.long,
                                   device=device)
        assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,
                                  device=device)
        indices_true = torch.nonzero(anchors_bbox_map >= 0)  # 获取已分配真实边界框的锚框的索引
        bb_idx = anchors_bbox_map[indices_true]  # 获取真实边界框的索引
        class_labels[indices_true] = label[bb_idx, 0].long() + 1  # 获取真实标签的同时,将标签索引改为从1开始
        assigned_bb[indices_true] = label[bb_idx, 1:]  # 获取真实边界框坐标
        offset = offset_boxes(anchors, assigned_bb) * bbox_mask  # 获取锚框与真实边界框的偏移量(已屏蔽未分配真实标签的锚框)
        batch_offset.append(offset.reshape(-1))  # 扁平化
        batch_mask.append(bbox_mask.reshape(-1))
        batch_class_labels.append(class_labels)
    bbox_offset = torch.stack(batch_offset)  # bbox_offset 的形状是 (batch_size, num_anchors * 4)
    bbox_mask = torch.stack(batch_mask)  # bbox_mask 的形状也是 (batch_size, num_anchors * 4)
    class_labels = torch.stack(batch_class_labels)  # class_labels 的形状是 (batch_size, num_anchors)
    return (bbox_offset, bbox_mask, class_labels)


def offset_inverse(anchors, offset_preds):
    """根据带有预测偏移量的锚框来预测边界框"""
    anc = d2l.box_corner_to_center(anchors)
    pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]
    pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]
    pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)
    predicted_bbox = d2l.box_center_to_corner(pred_bbox)
    return predicted_bbox


def nms(boxes, scores, iou_threshold):
    """对预测边界框的置信度进行排序"""
    B = torch.argsort(scores, dim=-1, descending=True)
    keep = []  # 保留预测边界框的指标
    while B.numel() > 0:
        i = B[0]
        keep.append(i)
        if B.numel() == 1: break
        iou = box_iou(boxes[i, :].reshape(-1, 4),  # 将当前边界框与其他所有边界框进行IoU计算
                      boxes[B[1:], :].reshape(-1, 4)).reshape(-1)
        inds = torch.nonzero(iou <= iou_threshold).reshape(-1)  # 获取低于阈值的所有交并比索引
        B = B[inds + 1]  # 获取低于阈值的所有边界框,进行下一轮抑制
    return torch.tensor(keep, device=boxes.device)


def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
                       pos_threshold=0.009999999):
    """使用非极大值抑制来预测边界框"""
    device, batch_size = cls_probs.device, cls_probs.shape[0]
    anchors = anchors.squeeze(0)  # 压缩后的形状(num_anchors,4)
    num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]  # 获得每个样本的类别数量和锚框数量
    out = []  # 存储预测结果
    for i in range(batch_size):
        cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)  # 每次取出一个样本
        conf, class_id = torch.max(cls_prob[1:], 0)  # 获取样本的锚框对于所有类别的置信度
        predicted_bb = offset_inverse(anchors, offset_pred)  # 逆转偏移量计算操作,得到预测边界框的真实坐标
        keep = nms(predicted_bb, conf, nms_threshold)  # 获取通过非最大值抑制操作后保留的预测框的索引

        # 找到所有的non_keep索引,并将类设置为背景
        all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
        combined = torch.cat((keep, all_idx))  # 该混合操作会使最终combined张量有重复元素,便于后边将非重复的设置为背景
        uniques, counts = combined.unique(return_counts=True)
        non_keep = uniques[counts == 1]  # 未重复的就是不保留的
        all_id_sorted = torch.cat((keep, non_keep))  # 将要保留的和不保留的连接在一起
        class_id[non_keep] = -1  # 将不保留预测框的类别索引设置为-1,表示没有
        class_id = class_id[all_id_sorted]  # 重新排列类别索引
        conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]  # 重新排列置信度和预测框
        # pos_threshold是一个用于非背景预测的阈值
        below_min_idx = (conf < pos_threshold)  # 获取置信度小于阈值的预测框的索引
        class_id[below_min_idx] = -1  # 将对应位置类别索引设置为-1
        conf[below_min_idx] = 1 - conf[below_min_idx]  # 将低于阈值的置信度,与背景置信度互换
        pred_info = torch.cat((class_id.unsqueeze(1),  # 重排成列,一行表示一个类别索引
                               conf.unsqueeze(1),  # 重拍成列,一行表示一个类别的置信度
                               predicted_bb), dim=1)
        out.append(pred_info)  # 完成一个样本的处理
    return torch.stack(out)  # 将分开处理样本合并为一个批量


# 读取图片
img = d2l.plt.imread('../img/catdog.jpg')
h, w = img.shape[:2]

# 生成锚框
X = torch.rand(size=(1, 3, h, w))
Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])

# 显示部分锚框及其对应的标签
boxes = Y.reshape(h, w, 5, 4)
d2l.set_figsize()
bbox_scale = torch.tensor((w, h, w, h))
fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,
            ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',
             's=0.75, r=0.5'])
d2l.plt.show()

ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],
                             [1, 0.55, 0.2, 0.9, 0.88]])
anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],
                        [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],
                        [0.57, 0.3, 0.92, 0.9]])

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')
show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);
d2l.plt.show()

labels = multibox_target(anchors.unsqueeze(dim=0),
                         ground_truth.unsqueeze(dim=0))
print(labels[2])
print(labels[1])
print(labels[0])

anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],
                        [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])
offset_preds = torch.tensor([0] * anchors.numel())
cls_probs = torch.tensor([[0] * 4,  # 背景的预测概率
                          [0.9, 0.8, 0.7, 0.1],  # 狗的预测概率
                          [0.1, 0.2, 0.3, 0.9]])  # 猫的预测概率

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, anchors * bbox_scale,
            ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])
d2l.plt.show()

output = multibox_detection(cls_probs.unsqueeze(dim=0),
                            offset_preds.unsqueeze(dim=0),
                            anchors.unsqueeze(dim=0),
                            nms_threshold=0.5)
print(output)

fig = d2l.plt.imshow(img)
for i in output[0].detach().numpy():
    if i[0] == -1:
        continue
    label = ('dog=', 'cat=')[int(i[0])] + str(i[1])
    show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)
d2l.plt.show()
相关推荐
18号房客5 分钟前
计算机视觉-人工智能(AI)入门教程一
人工智能·深度学习·opencv·机器学习·计算机视觉·数据挖掘·语音识别
百家方案7 分钟前
「下载」智慧产业园区-数字孪生建设解决方案:重构产业全景图,打造虚实结合的园区数字化底座
大数据·人工智能·智慧园区·数智化园区
云起无垠13 分钟前
“AI+Security”系列第4期(一)之“洞” 见未来:AI 驱动的漏洞挖掘新范式
人工智能
QQ_77813297432 分钟前
基于深度学习的图像超分辨率重建
人工智能·机器学习·超分辨率重建
清 晨44 分钟前
Web3 生态全景:创新与发展之路
人工智能·web3·去中心化·智能合约
公众号Codewar原创作者1 小时前
R数据分析:工具变量回归的做法和解释,实例解析
开发语言·人工智能·python
IT古董1 小时前
【漫话机器学习系列】020.正则化强度的倒数C(Inverse of regularization strength)
人工智能·机器学习
进击的小小学生1 小时前
机器学习连载
人工智能·机器学习
Trouvaille ~2 小时前
【机器学习】从流动到恒常,无穷中归一:积分的数学诗意
人工智能·python·机器学习·ai·数据分析·matplotlib·微积分
dundunmm2 小时前
论文阅读:Deep Fusion Clustering Network With Reliable Structure Preservation
论文阅读·人工智能·数据挖掘·聚类·深度聚类·图聚类