YOLOX源码之 Label Assignment

网络结构没什么好讲的,backbone、neck、head组成,backbone采用的cspdarknet,neck采用的pafpn,head是decoupled head结构。这里主要讲一下label assignment的具体实现,yolox中采用了simota,是ota的简化版本。

实现中标签分配以及计算损失部分是在yolo_head.py中,连带着head的网络层一起的,这里也顺带一起讲了。

首先forward函数的输入xin是neck的输出,当输入shape为(4,3,416,416)时,xin的shape为[(4,128,52,52),(4,256,26,26),(4,512,13,13)],对应8,16,32三种不同stride的输出特征图。

接下里的for循环是分别对三个特征图进行head部分网络层的forward,并计算对应的grids,grids具体是什么后面会讲。以stride=8对应的大小为(4,128,52,52)的特征图为例,self.stems[k]是一层1x1卷积,然后分类分支cls_conv和回归分支reg_conv都是2层3x3卷积,self.cls_preds[k]得到最终的分类输出shape为(b,num_classes,52,52),self.reg_preds[k]得到最终的回归输出shape为(b,4,52,52),self.obj_preds[k]得到最终的objectiveness输出shape为(b,1,52,52)。这里b=4,num_classes=16。

接下来将三个输出torch.cat得到输出shape为(4,21,52,52)。接下来函数self.get_output_and_grid()得到网格坐标grid和解码后的输出output。

代码如下

python 复制代码
def get_output_and_grid(self, output, k, stride, dtype):
    # (4,21,52,52)
    grid = self.grids[k]

    batch_size = output.shape[0]
    n_ch = 5 + self.num_classes
    hsize, wsize = output.shape[-2:]
    if grid.shape[2:4] != output.shape[2:4]:
        yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
        grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)  # (1,1,52,52,2), 先按行后按列,每个像素点的坐标
        self.grids[k] = grid

    output = output.view(batch_size, 1, n_ch, hsize, wsize)  # (4,1,21,52,52)
    output = output.permute(0, 1, 3, 4, 2).reshape(
        batch_size, hsize * wsize, -1
    )  # (4,2704,21)
    grid = grid.view(1, -1, 2)  # (1,2704,2)
    output[..., :2] = (output[..., :2] + grid) * stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
    return output, grid

self.grids是三个torch.Size([1])的列表,所以会进入到line8的if中。hsize和wsize分别是特征图的高和宽这里都是52,meshgrid返回的yv和xv分别是特征图每个像素点对应的y坐标和x坐标,如下所示

python 复制代码
tensor([[ 0,  0,  0,  ...,  0,  0,  0],                                                                                                                                
        [ 1,  1,  1,  ...,  1,  1,  1],                                                                                                                                
        [ 2,  2,  2,  ...,  2,  2,  2],                                                                                                                                
        ...,                                                                                                                                                           
        [49, 49, 49,  ..., 49, 49, 49],                                                                                                                                
        [50, 50, 50,  ..., 50, 50, 50],                                                                                                                                
        [51, 51, 51,  ..., 51, 51, 51]])                                                                                                                               
tensor([[ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        ...,                                                                                                                                                           
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51],                                                                                                                                
        [ 0,  1,  2,  ..., 49, 50, 51]])

然后将xy坐标stack得到每个点的xy坐标,shape为(1,1,52,52,2),按先行后列的顺序,如下

python 复制代码
tensor([[[[[ 0.,  0.],                                                                                                                                                 
           [ 1.,  0.],                                                                                                                                                 
           [ 2.,  0.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49.,  0.],                                                                                                                                                 
           [50.,  0.],                                                                                                                                                 
           [51.,  0.]],                                                                                                                                                
                                                                                                                                                                       
          [[ 0.,  1.],                                                                                                                                                 
           [ 1.,  1.],                                                                                                                                                 
           [ 2.,  1.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49.,  1.],                                                                                                                                                 
           [50.,  1.],                                                                                                                                                 
           [51.,  1.]],                                                                                                                                                                                                                                                                                            
                                                                                                                                                                       
          ...,                                                                                                                                                                                                                                                                                                      
                                                                                                                                                                       
          [[ 0., 50.],                                                                                                                                                 
           [ 1., 50.],                                                                                                                                                 
           [ 2., 50.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49., 50.],                                                                                                                                                 
           [50., 50.],                                                                                                                                                 
           [51., 50.]],                                                                                                                                                
                                                                                                                                                                       
          [[ 0., 51.],                                                                                                                                                 
           [ 1., 51.],                                                                                                                                                 
           [ 2., 51.],                                                                                                                                                 
           ...,                                                                                                                                                        
           [49., 51.],                                                                                                                                                 
           [50., 51.],                                                                                                                                                 
           [51., 51.]]]]], device='cuda:0', dtype=torch.float16)

然后将output和grid分别view调整维度,output中每个点对应一个预测框,output[..., :2]是预测框中心点相对于每个点的偏移,因此line8加上每个点的坐标grid并乘以stride还原回原图上得到原图上真实预测框的中心点坐标。line9则是通过 \(e^{t}\) 并乘以stride得到原图上真实预测框的宽高。

在得到原图上预测框的坐标以及类别和objectiveness后,接下来就是进行label assignment并计算loss,具体实现都在self.get_losses()中。其中输入outputs是将坐标还原到原图中的三个特征图的输出并concat得到的,shape为(b, 3549, 21),3549=52x52+26x26+13x13,21=4+1+16。

在函数get_losses()中,调用self.get_assignments进行标签分配,这里使用的方法是simota。关于simota和ota的原理可参考OTA: Optimal Transport Assignment for Object Detection 原理与代码解读-CSDN博客https://blog.csdn.net/ooooocj/article/details/136569249。get_assignments()的完整实现如下

python 复制代码
@torch.no_grad()
def get_assignments(
    self,
    batch_idx,
    num_gt,
    gt_bboxes_per_image,  # (17,4)
    gt_classes,  # (17)
    bboxes_preds_per_image,  # (3549,4)
    expanded_strides,  # (1,3549)
    x_shifts,  # (1,3549)
    y_shifts,  # (1,3549)
    cls_preds,  # (4,3549,16)
    obj_preds,  # (4,3549,1)
    mode="gpu",
):

    if mode == "cpu":
        print("-----------Using CPU for the Current Batch-------------")
        gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
        gt_classes = gt_classes.cpu().float()
        expanded_strides = expanded_strides.cpu().float()
        x_shifts = x_shifts.cpu()
        y_shifts = y_shifts.cpu()

    fg_mask, geometry_relation = self.get_geometry_constraint(
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
    )  # (3549), (17,357)
    # fg_mask中True位置的anchor point至少在一个gt box的center area内,后续会用来进行label assignment。而不在fg_mask中False位置的anchor point
    # 不在任意一个gt box的center area内。

    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]  # (357,4)
    cls_preds_ = cls_preds[batch_idx][fg_mask]  # (357,16)
    obj_preds_ = obj_preds[batch_idx][fg_mask]  # (357,1)
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]  # 357

    if mode == "cpu":
        gt_bboxes_per_image = gt_bboxes_per_image.cpu()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu()

    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)  # (17,357)

    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
    )  # (17,16)
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)  # (17,357)

    if mode == "cpu":
        cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
        ).sqrt()
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),  # (357,16)->(1,357,16)->(17,357,16)
            gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),  # (17,16)->(17,1,16)->(17,357,16)
            reduction="none"
        ).sum(-1)  # (17,357), 共16个类别,每个类单独计算bce
    del cls_preds_

    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + float(1e6) * (~geometry_relation)  # center area之外的anchor point对应的cost加上一个很大的值来过滤
    )  # (17,357)

    (
        num_fg,  # 22
        gt_matched_classes,  # (22)
        pred_ious_this_matching,  # (22)
        matched_gt_inds,  # (22)
    ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

        if mode == "cpu":
            gt_matched_classes = gt_matched_classes.cuda()
            fg_mask = fg_mask.cuda()
            pred_ious_this_matching = pred_ious_this_matching.cuda()
            matched_gt_inds = matched_gt_inds.cuda()

        return (
            gt_matched_classes,
            fg_mask,  # (3549)
            pred_ious_this_matching,
            matched_gt_inds,
            num_fg,
        )

在ota中使用了center prior,即只有gt box中心有限区域内的anchor point作为正样本的candidate,而不是整个gt box内所有的anchor point都作为正样本的候选。函数get_geometry_constraint就是实现这个的

python 复制代码
def get_geometry_constraint(
    self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
):
    """
    Calculate whether the center of an object is located in a fixed range of
    an anchor. This is used to avert inappropriate matching. It can also reduce
    the number of candidate anchors so that the GPU memory is saved.
    """
    expanded_strides_per_image = expanded_strides[0]  # (3549)
    x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)  # (1,3549)
    y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)  # (1,3549)

    # in fixed center
    center_radius = 1.5  # 这里有可能center area区域比原目标还大
    center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius  # (1,3549)
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist  # (17,1) -> (17,3549)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist  # (17,3549)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist  # (17,3549)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist  # (17,3549)

    c_l = x_centers_per_image - gt_bboxes_per_image_l  # (1,3549)-(17,3549) -> (17,3549)
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)  # (17,3549,4)
    is_in_centers = center_deltas.min(dim=-1).values > 0.0  # (17,3549)
    anchor_filter = is_in_centers.sum(dim=0) > 0  # (3549), 一共3549个anchor point, 对应位置为False, 说明这个anchor point不在任意一个gt box的center area内
    geometry_relation = is_in_centers[:, anchor_filter]  # (17,357), anchor_filter.sum()==357,表明某个anchor point至少在一个gt box的center area内

    return anchor_filter, geometry_relation

最终返回的anchor_filter是一个shape为(3549, )的tensor,值全为True或False。前面说过三个特征图一共3549个anchor point,值为False对应的anchor point不在任意一个gt box的center area内,后续进行标签分配时只从值为True的anchor point中挑选。当我用自己的数据调试时,另一个输出geometry_relation的shape为(17, 357),17是图中gt的数量,357是anchor_filter中值为True的anchor point的数量,geometry_relation表示每个gt的中心区域内对应的anchor point。

然后用fg_mask也就是anchor_filter挑选出候选的正样本,然后计算ota的cost matrix,cost矩阵包括分类损失以及回归损失,注意分类的预测要取sigmoid后并与obj预测相乘再与gt计算交叉熵损失,最后加上float(1e6) * (~geometry_relation)是对每个gt中心区域外的anchor加上一个特别大的cost,从而过滤它们。

在得到cost矩阵后,就是通过simota进行标签分配的过程了,具体实现在函数simota_matching中

python 复制代码
def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # (17,357),(17,357),(17),17,(3549)
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)  # (17,357)

    n_candidate_k = min(10, pair_wise_ious.size(1))  # 这里10就是文章中dynamic_k中的q
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # (17,10)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # (17)
    # tensor([3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0', dtype=torch.int32)
    # 每个gt选择q个最大iou值,相加取整作为为该gt分配的anchor point的个数
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )  # 选择cost最小的dynamic k个anchor point作为分配的正样本,代替原始OTA中的sinkhorn算法
        matching_matrix[gt_idx][pos_idx] = 1

    del topk_ious, dynamic_ks, pos_idx

    anchor_matching_gt = matching_matrix.sum(0)  # (357)
    # deal with the case that one anchor matches multiple ground-truths
    if anchor_matching_gt.max() > 1:
        multiple_match_mask = anchor_matching_gt > 1
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)  # 当一个anchor point匹配多个gt时,选择cost最小的gt作为匹配的结果
        matching_matrix[:, multiple_match_mask] *= 0
        matching_matrix[cost_argmin, multiple_match_mask] = 1
    fg_mask_inboxes = anchor_matching_gt > 0  # (357), pos anchor point的index
    num_fg = fg_mask_inboxes.sum().item()
    # num_fg==22, anchor_matching_gt.sum()==tensor(22, device='cuda:0')
    # 当if anchor_matching_gt.max() > 1成立时,num_fg > matching_matrix.sum().item()

    # fg_mask.sum().item() == 357
    fg_mask[fg_mask.clone()] = fg_mask_inboxes
    # fg_mask.sum().item() == 22
    # 更新fg_mask,本来fg_mask中有357个anchor point初步过滤后再gt center area内,然后经过simota第二次匹配找到pos anchor point
    # 注意这里[]内fg_mask.clone()的作用,是找到那357个的值,然后用fg_mask_inboxes替换
    # 这里fg_mask更新后,不用return,外面的fg_mask也更新了
    # 这里的fg_mask就是所有3549个anchor中哪几个anchor是正样本,正样本处的值为1

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    # matched_gt_inds == tensor([7, 7, 8, 2, 1, 9, 4, 4, 10, 10, 3, 13, 5, 14, 12, 6, 15, 16, 11, 0, 0, 0], device='cuda:0')
    # 每个pos anchor匹配到了第几个gt的index
    # print(gt_classes) == tensor([6, 5, 12, 12, 12, 5, 12, 12, 5, 12, 12, 5, 5, 5, 5, 12, 12], device='cuda:0', dtype=torch.float16)
    gt_matched_classes = gt_classes[matched_gt_inds]  # 每个pos anchor匹配到的gt的实际类别索引
    # print(gt_matched_classes) == tensor([12, 12, 5, 12, 5, 12, 12, 12, 12, 12, 12, 5, 5, 5, 5, 12, 12, 12, 5, 6, 6, 6.], device='cuda:0', dtype=torch.float16)

    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    # 这里sum(0)沿列求和,一列只有1个值大于0,因为上面处理完后,一个anchor只能匹配一个gt。但一行可以有多个大于0的值,即1个gt可以和多个anchor匹配
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

simota和原本的ota的区别是,在得到cost矩阵后,ota通过sinkhorn算法进行匹配,而simota则直接选择topk个cost最小的anchor作为正样本,和最早的faster rcnn中的topk相似,只不过那里是选择iou最小,这里是选择cost最小,这里的cost不仅考虑了iou还考虑了分类损失和center prior。另外这里的k不是认为设置的固定值,而是dynamic k,具体是根据先选择q个iou最大的anchor(这里q仍然是人工设定的代码中取10),然后这10个iou求和取整得到k值。

python 复制代码
n_candidate_k = min(10, pair_wise_ious.size(1))  # 这里10就是文章中dynamic_k中的q
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # (17,10)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # (17)

一个gt可以匹配多个anchor,但一个anchor只能匹配一个gt,根据上面的规则选择cost最小的k个anchor后如果存在一个anchor匹配多个gt的情况,选择cost最小对应的gt作为匹配结果。

样本分配完后,就是计算损失了,这里没什么好讲的,回归损失采用的iou loss,分类损失和obj损失都是bce loss。yolox中作者在最后15个epoch关闭了mosaic数据增强,并添加了额外的L1 loss来增加回归的精度,这里L1 loss就是在特征图上计算的,预测就是特征图的原始输出,没有像iou loss一样加上grid并乘以stride映射会原图,这里target是将label反向映射到特征图上。

python 复制代码
def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
    l1_target[:, 0] = gt[:, 0] / stride - x_shifts
    l1_target[:, 1] = gt[:, 1] / stride - y_shifts
    l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
    l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
    return l1_target
相关推荐
四口鲸鱼爱吃盐14 分钟前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗15 分钟前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
bastgia1 小时前
Tokenformer: 下一代Transformer架构
人工智能·机器学习·llm
菜狗woc1 小时前
opencv-python的简单练习
人工智能·python·opencv
15年网络推广青哥1 小时前
国际抖音TikTok矩阵运营的关键要素有哪些?
大数据·人工智能·矩阵
weixin_387545641 小时前
探索 AnythingLLM:借助开源 AI 打造私有化智能知识库
人工智能
engchina2 小时前
如何在 Python 中忽略烦人的警告?
开发语言·人工智能·python
paixiaoxin3 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
OpenCSG3 小时前
CSGHub开源版本v1.2.0更新
人工智能
weixin_515202493 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习