YOLOX源码之 Label Assignment

网络结构没什么好讲的,backbone、neck、head组成,backbone采用的cspdarknet,neck采用的pafpn,head是decoupled head结构。这里主要讲一下label assignment的具体实现,yolox中采用了simota,是ota的简化版本。

实现中标签分配以及计算损失部分是在yolo_head.py中,连带着head的网络层一起的,这里也顺带一起讲了。

首先forward函数的输入xin是neck的输出,当输入shape为(4,3,416,416)时,xin的shape为[(4,128,52,52),(4,256,26,26),(4,512,13,13)],对应8,16,32三种不同stride的输出特征图。

接下里的for循环是分别对三个特征图进行head部分网络层的forward,并计算对应的grids,grids具体是什么后面会讲。以stride=8对应的大小为(4,128,52,52)的特征图为例,self.stems[k]是一层1x1卷积,然后分类分支cls_conv和回归分支reg_conv都是2层3x3卷积,self.cls_preds[k]得到最终的分类输出shape为(b,num_classes,52,52),self.reg_preds[k]得到最终的回归输出shape为(b,4,52,52),self.obj_preds[k]得到最终的objectiveness输出shape为(b,1,52,52)。这里b=4,num_classes=16。

接下来将三个输出torch.cat得到输出shape为(4,21,52,52)。接下来函数self.get_output_and_grid()得到网格坐标grid和解码后的输出output。

代码如下

python 复制代码
def get_output_and_grid(self, output, k, stride, dtype):
    # (4,21,52,52)
    grid = self.grids[k]

    batch_size = output.shape[0]
    n_ch = 5 + self.num_classes
    hsize, wsize = output.shape[-2:]
    if grid.shape[2:4] != output.shape[2:4]:
        yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
        grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)  # (1,1,52,52,2), 先按行后按列,每个像素点的坐标
        self.grids[k] = grid

    output = output.view(batch_size, 1, n_ch, hsize, wsize)  # (4,1,21,52,52)
    output = output.permute(0, 1, 3, 4, 2).reshape(
        batch_size, hsize * wsize, -1
    )  # (4,2704,21)
    grid = grid.view(1, -1, 2)  # (1,2704,2)
    output[..., :2] = (output[..., :2] + grid) * stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
    return output, grid

self.grids是三个torch.Size([1])的列表,所以会进入到line8的if中。hsize和wsize分别是特征图的高和宽这里都是52,meshgrid返回的yv和xv分别是特征图每个像素点对应的y坐标和x坐标,如下所示

python 复制代码
tensor([[ 0,  0,  0,  ...,  0,  0,  0],                                                                                                                                
        [ 1,  1,  1,  ...,  1,  1,  1],                                                                                                                                
        [ 2,  2,  2,  ...,  2,  2,  2],                                                                                                                                
        ...,                                                                                                                                                           
        [49, 49, 49,  ..., 49, 49, 49],                                                                                                                                
        [50, 50, 50,  ..., 50, 50, 50],                                                                                                                                
        [51, 51, 51,  ..., 51, 51, 51]])                                                                                                                               
tensor([[ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        ...,                                                                                                                                                           
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51]])

然后将xy坐标stack得到每个点的xy坐标,shape为(1,1,52,52,2),按先行后列的顺序,如下

python 复制代码
tensor([[[[[ 0.,  0.],                                                                                                                                                 
           [ 1.,  0.],                                                                                                                                                 
           [ 2.,  0.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49.,  0.],                                                                                                                                                 
           [50.,  0.],                                                                                                                                                 
           [51.,  0.]],                                                                                                                                                
                                                                                                                                                                       
          [[ 0.,  1.],                                                                                                                                                 
           [ 1.,  1.],                                                                                                                                                 
           [ 2.,  1.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49.,  1.],                                                                                                                                                 
           [50.,  1.],                                                                                                                                                 
           [51.,  1.]],                                                                                                                                                                                                                                                                                            
                                                                                                                                                                       
          ...,                                                                                                                                                                                                                                                                                                      
                                                                                                                                                                       
          [[ 0., 50.],                                                                                                                                                 
           [ 1., 50.],                                                                                                                                                 
           [ 2., 50.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49., 50.],                                                                                                                                                 
           [50., 50.],                                                                                                                                                 
           [51., 50.]],                                                                                                                                                
                                                                                                                                                                       
          [[ 0., 51.],                                                                                                                                                 
           [ 1., 51.],                                                                                                                                                 
           [ 2., 51.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49., 51.],                                                                                                                                                 
           [50., 51.],                                                                                                                                                 
           [51., 51.]]]]], device='cuda:0', dtype=torch.float16)

然后将output和grid分别view调整维度,output中每个点对应一个预测框,output[..., :2]是预测框中心点相对于每个点的偏移,因此line8加上每个点的坐标grid并乘以stride还原回原图上得到原图上真实预测框的中心点坐标。line9则是通过 \(e^{t}\) 并乘以stride得到原图上真实预测框的宽高。

在得到原图上预测框的坐标以及类别和objectiveness后,接下来就是进行label assignment并计算loss,具体实现都在self.get_losses()中。其中输入outputs是将坐标还原到原图中的三个特征图的输出并concat得到的,shape为(b, 3549, 21),3549=52x52+26x26+13x13,21=4+1+16。

在函数get_losses()中,调用self.get_assignments进行标签分配,这里使用的方法是simota。关于simota和ota的原理可参考OTA: Optimal Transport Assignment for Object Detection 原理与代码解读-CSDN博客https://blog.csdn.net/ooooocj/article/details/136569249。get_assignments()的完整实现如下

python 复制代码
@torch.no_grad()
def get_assignments(
    self,
    batch_idx,
    num_gt,
    gt_bboxes_per_image,  # (17,4)
    gt_classes,  # (17)
    bboxes_preds_per_image,  # (3549,4)
    expanded_strides,  # (1,3549)
    x_shifts,  # (1,3549)
    y_shifts,  # (1,3549)
    cls_preds,  # (4,3549,16)
    obj_preds,  # (4,3549,1)
    mode="gpu",
):

    if mode == "cpu":
        print("-----------Using CPU for the Current Batch-------------")
        gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
        gt_classes = gt_classes.cpu().float()
        expanded_strides = expanded_strides.cpu().float()
        x_shifts = x_shifts.cpu()
        y_shifts = y_shifts.cpu()

    fg_mask, geometry_relation = self.get_geometry_constraint(
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
    )  # (3549), (17,357)
    # fg_mask中True位置的anchor point至少在一个gt box的center area内,后续会用来进行label assignment。而不在fg_mask中False位置的anchor point
    # 不在任意一个gt box的center area内。

    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]  # (357,4)
    cls_preds_ = cls_preds[batch_idx][fg_mask]  # (357,16)
    obj_preds_ = obj_preds[batch_idx][fg_mask]  # (357,1)
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]  # 357

    if mode == "cpu":
        gt_bboxes_per_image = gt_bboxes_per_image.cpu()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu()

    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)  # (17,357)

    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
    )  # (17,16)
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)  # (17,357)

    if mode == "cpu":
        cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
        ).sqrt()
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),  # (357,16)->(1,357,16)->(17,357,16)
            gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),  # (17,16)->(17,1,16)->(17,357,16)
            reduction="none"
        ).sum(-1)  # (17,357), 共16个类别,每个类单独计算bce
    del cls_preds_

    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + float(1e6) * (~geometry_relation)  # center area之外的anchor point对应的cost加上一个很大的值来过滤
    )  # (17,357)

    (
        num_fg,  # 22
        gt_matched_classes,  # (22)
        pred_ious_this_matching,  # (22)
        matched_gt_inds,  # (22)
    ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

        if mode == "cpu":
            gt_matched_classes = gt_matched_classes.cuda()
            fg_mask = fg_mask.cuda()
            pred_ious_this_matching = pred_ious_this_matching.cuda()
            matched_gt_inds = matched_gt_inds.cuda()

        return (
            gt_matched_classes,
            fg_mask,  # (3549)
            pred_ious_this_matching,
            matched_gt_inds,
            num_fg,
        )

在ota中使用了center prior,即只有gt box中心有限区域内的anchor point作为正样本的candidate,而不是整个gt box内所有的anchor point都作为正样本的候选。函数get_geometry_constraint就是实现这个的

python 复制代码
def get_geometry_constraint(
    self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
):
    """
    Calculate whether the center of an object is located in a fixed range of
    an anchor. This is used to avert inappropriate matching. It can also reduce
    the number of candidate anchors so that the GPU memory is saved.
    """
    expanded_strides_per_image = expanded_strides[0]  # (3549)
    x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)  # (1,3549)
    y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)  # (1,3549)

    # in fixed center
    center_radius = 1.5  # 这里有可能center area区域比原目标还大
    center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius  # (1,3549)
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist  # (17,1) -> (17,3549)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist  # (17,3549)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist  # (17,3549)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist  # (17,3549)

    c_l = x_centers_per_image - gt_bboxes_per_image_l  # (1,3549)-(17,3549) -> (17,3549)
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)  # (17,3549,4)
    is_in_centers = center_deltas.min(dim=-1).values > 0.0  # (17,3549)
    anchor_filter = is_in_centers.sum(dim=0) > 0  # (3549), 一共3549个anchor point, 对应位置为False, 说明这个anchor point不在任意一个gt box的center area内
    geometry_relation = is_in_centers[:, anchor_filter]  # (17,357), anchor_filter.sum()==357,表明某个anchor point至少在一个gt box的center area内

    return anchor_filter, geometry_relation

最终返回的anchor_filter是一个shape为(3549, )的tensor,值全为True或False。前面说过三个特征图一共3549个anchor point,值为False对应的anchor point不在任意一个gt box的center area内,后续进行标签分配时只从值为True的anchor point中挑选。当我用自己的数据调试时,另一个输出geometry_relation的shape为(17, 357),17是图中gt的数量,357是anchor_filter中值为True的anchor point的数量,geometry_relation表示每个gt的中心区域内对应的anchor point。

然后用fg_mask也就是anchor_filter挑选出候选的正样本,然后计算ota的cost matrix,cost矩阵包括分类损失以及回归损失,注意分类的预测要取sigmoid后并与obj预测相乘再与gt计算交叉熵损失,最后加上float(1e6) * (~geometry_relation)是对每个gt中心区域外的anchor加上一个特别大的cost,从而过滤它们。

在得到cost矩阵后,就是通过simota进行标签分配的过程了,具体实现在函数simota_matching中

python 复制代码
def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # (17,357),(17,357),(17),17,(3549)
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)  # (17,357)

    n_candidate_k = min(10, pair_wise_ious.size(1))  # 这里10就是文章中dynamic_k中的q
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # (17,10)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # (17)
    # tensor([3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0', dtype=torch.int32)
    # 每个gt选择q个最大iou值,相加取整作为为该gt分配的anchor point的个数
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )  # 选择cost最小的dynamic k个anchor point作为分配的正样本,代替原始OTA中的sinkhorn算法
        matching_matrix[gt_idx][pos_idx] = 1

    del topk_ious, dynamic_ks, pos_idx

    anchor_matching_gt = matching_matrix.sum(0)  # (357)
    # deal with the case that one anchor matches multiple ground-truths
    if anchor_matching_gt.max() > 1:
        multiple_match_mask = anchor_matching_gt > 1
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)  # 当一个anchor point匹配多个gt时,选择cost最小的gt作为匹配的结果
        matching_matrix[:, multiple_match_mask] *= 0
        matching_matrix[cost_argmin, multiple_match_mask] = 1
    fg_mask_inboxes = anchor_matching_gt > 0  # (357), pos anchor point的index
    num_fg = fg_mask_inboxes.sum().item()
    # num_fg==22, anchor_matching_gt.sum()==tensor(22, device='cuda:0')
    # 当if anchor_matching_gt.max() > 1成立时,num_fg > matching_matrix.sum().item()

    # fg_mask.sum().item() == 357
    fg_mask[fg_mask.clone()] = fg_mask_inboxes
    # fg_mask.sum().item() == 22
    # 更新fg_mask,本来fg_mask中有357个anchor point初步过滤后再gt center area内,然后经过simota第二次匹配找到pos anchor point
    # 注意这里[]内fg_mask.clone()的作用,是找到那357个的值,然后用fg_mask_inboxes替换
    # 这里fg_mask更新后,不用return,外面的fg_mask也更新了
    # 这里的fg_mask就是所有3549个anchor中哪几个anchor是正样本,正样本处的值为1

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    # matched_gt_inds == tensor([7, 7, 8, 2, 1, 9, 4, 4, 10, 10, 3, 13, 5, 14, 12, 6, 15, 16, 11, 0, 0, 0], device='cuda:0')
    # 每个pos anchor匹配到了第几个gt的index
    # print(gt_classes) == tensor([6, 5, 12, 12, 12, 5, 12, 12, 5, 12, 12, 5, 5, 5, 5, 12, 12], device='cuda:0', dtype=torch.float16)
    gt_matched_classes = gt_classes[matched_gt_inds]  # 每个pos anchor匹配到的gt的实际类别索引
    # print(gt_matched_classes) == tensor([12, 12, 5, 12, 5, 12, 12, 12, 12, 12, 12, 5, 5, 5, 5, 12, 12, 12, 5, 6, 6, 6.], device='cuda:0', dtype=torch.float16)

    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    # 这里sum(0)沿列求和,一列只有1个值大于0,因为上面处理完后,一个anchor只能匹配一个gt。但一行可以有多个大于0的值,即1个gt可以和多个anchor匹配
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

simota和原本的ota的区别是,在得到cost矩阵后,ota通过sinkhorn算法进行匹配,而simota则直接选择topk个cost最小的anchor作为正样本,和最早的faster rcnn中的topk相似,只不过那里是选择iou最小,这里是选择cost最小,这里的cost不仅考虑了iou还考虑了分类损失和center prior。另外这里的k不是认为设置的固定值,而是dynamic k,具体是根据先选择q个iou最大的anchor(这里q仍然是人工设定的代码中取10),然后这10个iou求和取整得到k值。

python 复制代码
n_candidate_k = min(10, pair_wise_ious.size(1))  # 这里10就是文章中dynamic_k中的q
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # (17,10)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # (17)

一个gt可以匹配多个anchor,但一个anchor只能匹配一个gt,根据上面的规则选择cost最小的k个anchor后如果存在一个anchor匹配多个gt的情况,选择cost最小对应的gt作为匹配结果。

样本分配完后,就是计算损失了,这里没什么好讲的,回归损失采用的iou loss,分类损失和obj损失都是bce loss。yolox中作者在最后15个epoch关闭了mosaic数据增强,并添加了额外的L1 loss来增加回归的精度,这里L1 loss就是在特征图上计算的,预测就是特征图的原始输出,没有像iou loss一样加上grid并乘以stride映射会原图,这里target是将label反向映射到特征图上。

python 复制代码
def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
    l1_target[:, 0] = gt[:, 0] / stride - x_shifts
    l1_target[:, 1] = gt[:, 1] / stride - y_shifts
    l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
    l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
    return l1_target
相关推荐
星期天要睡觉8 小时前
计算机视觉(opencv)实战十八——图像透视转换
人工智能·opencv·计算机视觉
Morning的呀9 小时前
Class48 GRU
人工智能·深度学习·gru
拾零吖11 小时前
李宏毅 Deep Learning
人工智能·深度学习·机器学习
华芯邦11 小时前
广东充电芯片助力新能源汽车车载系统升级
人工智能·科技·车载系统·汽车·制造
时空无限12 小时前
说说transformer 中的掩码矩阵以及为什么能掩盖住词语
人工智能·矩阵·transformer
查里王12 小时前
AI 3D 生成工具知识库:当前产品格局与测评总结
人工智能·3d
武子康12 小时前
AI-调查研究-76-具身智能 当机器人走进生活:具身智能对就业与社会结构的深远影响
人工智能·程序人生·ai·职场和发展·机器人·生活·具身智能
小鹿清扫日记12 小时前
从蛮力清扫到 “会看路”:室外清洁机器人的文明进阶
人工智能·ai·机器人·扫地机器人·具身智能·连合直租·有鹿巡扫机器人
技术小黑12 小时前
Transformer系列 | Pytorch复现Transformer
pytorch·深度学习·transformer
fanstuck13 小时前
Prompt提示工程上手指南(六):AI避免“幻觉”(Hallucination)策略下的Prompt
人工智能·语言模型·自然语言处理·nlp·prompt