CV党福音:YOLOv8实现实例分割(三)之损失函数

在上一篇博文中,我们讲解了YOLOv8实例分割的训练过程,已将前向传播过程分析完毕,那么,接下来便是损失计算过程了。

文章目录

训练整体流程

获得预测结果与真值后,即可计算损失。

整体流程如下:

预测结果preds如下:

真值batch如下:

分割整体损失函数

v8SegmentationLoss的计算过程如下,从最终的结果来看,其计算了四个损失,分别是目标预测框损失、mask损失、类别损失以及DEL损失,博主已将每段代码的结果标注在对应的代码位置。同时,在损失计算过程中不可避免的需要使用其他方法,博主将一些较为重要的方法也罗列出来了。

python 复制代码
def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(4, device=self.device) # box, cls, dfl,mask,共 4 个
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] #feature是目标检测头输出的三个特征图list类型,pred_masks为torch.Size([4, 32, 8400]) ,oroto为torch.Size([4, 32, 160, 160])
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )#拆分目标检测特征图结果,并融合在一起,得到(4,64,8400)与(4,80,8400)

        # B, grids, ..   维度转换
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()#(4,8400,80)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()#(4,8400,64)
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()#(4,8400,32)

        dtype = pred_scores.dtype #torch.float16
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)#torch.Size([22, 6]) 即batchid 类别id  x y w h 共6个数
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])#(batch,最大目标数量,5) 5=1+4
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy (4,7,1)(4,7,4)
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# 这行代码的作用是创建一个掩码张量 mask_gt,用于标识哪些目标是有效的(即有实际的边界框),哪些目标是无效的(即填充的零)。
        # sum(2, keepdim=True):沿着第 2 维(即每个边界框的坐标维度)进行求和,保持维度不变。
        # 这样会将每个边界框的 4 个坐标值相加,如果该目标是有效的边界框,和将大于 0;如果该目标是填充的零,和将等于 0。
        except RuntimeError as e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,#(4,7,1)7指这个四个batch中
            gt_bboxes,#(4,7,4)
            mask_gt,
        )
"""
        self.assigner即根据分类与回归的分数加权的分数选择正样本 输入(共6个输入值)
        1.pred_scores:表示模型预测的每个锚点位置的分类分数 形状为:(batch_size, 8400, 80)
        2.(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype)
            pred_bboxes:解码后的边界框坐标,形状为 [batch_size, 8400, 4]。
            detach():从计算图中分离出预测的边界框,避免反向传播时更新它们。
            * stride_tensor:将边界框坐标乘以步幅张量 stride_tensor,恢复到原图尺度。
            type(gt_bboxes.dtype):将边界框的类型转换为与 gt_bboxes 相同的类型。
        3.anchor_points * stride_tensor
            anchor_points:锚点位置,形状为 [num_anchors, 2]
            将锚点位置乘以步幅张量 stride_tensor,恢复到原图尺度
        4.gt_labels:真实目标的类别标签
        5.gt_bboxes:真实目标的边界框坐标
        6.mask_gt:掩码张量,标识哪些目标是有效的
        """
        """
        输出(共5个返回值)
        target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签
        target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框
        target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。
        fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。
        fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。
            正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。
            负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标
        target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。
        """


        target_scores_sum = max(target_scores.sum(), 1)
        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE 即求分类损失

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(  #求Box损失和DEL损失
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )
        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

生成网格

make_anchor方法,其作用是利用特征图生成网格的形式来创建预测框,这种产方式自YOLOv3起便一直沿用。

这里生成网格使用的方法为:x,y=torch.meshgrid(a,b)

torch.meshgrid(a,b)的功能是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数,当两个输入张量数据类型不同或维度不是一维时会报错。

其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素,各列元素相同。
a tensor([1, 2, 3, 4, 5, 6])
b tensor([ 7, 8, 9, 10])
x tensor([[1, 1, 1, 1],

2, 2, 2, 2\], \[3, 3, 3, 3\], \[4, 4, 4, 4\], \[5, 5, 5, 5\], \[6, 6, 6, 6\]\]) `y` tensor(\[\[4, 5, 6\], \[4, 5, 6\], \[4, 5, 6\], \[4, 5, 6\]\])

python 复制代码
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

我们以第一层特征图生成的网格为例,得到的sysx如下:

python 复制代码
sy: tensor([[ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
        [ 1.5000,  1.5000,  1.5000,  ...,  1.5000,  1.5000,  1.5000],
        [ 2.5000,  2.5000,  2.5000,  ...,  2.5000,  2.5000,  2.5000],
        ...,
        [77.5000, 77.5000, 77.5000,  ..., 77.5000, 77.5000, 77.5000],
        [78.5000, 78.5000, 78.5000,  ..., 78.5000, 78.5000, 78.5000],
        [79.5000, 79.5000, 79.5000,  ..., 79.5000, 79.5000, 79.5000]], device='cuda:0', dtype=torch.float16)
        
sx: tensor([[ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        ...,
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000]], device='cuda:0', dtype=torch.float16)

生成的anchor-point即为两者合并的结果,维度为(6400,2)

最终,将三个尺度的特征图产生的anchor合并,得到:

python 复制代码
anchor_point  :  torch.Size([8400, 2])       
stride_tensor   torch.Size([8400, 1]),其结果如下:
tensor([[ 8.], 8,16,32代表放大比例
        [ 8.],
        [ 8.],
        ...,
        [32.],
        [32.],
        [32.]], device='cuda:0', dtype=torch.float16)

前处理过程(真值转换)

这个方法是将target进行处理,初始的target(22,6)转换为(batch,最大目标数量,5)5即类别id+xywh,当然最后还会将xywh转换为坐标形式

python 复制代码
def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        nl, ne = targets.shape
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)#统计每张图像中目标的个数
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)#根据最大数量最大的图像设置维度,即(batch,最大目标数量,5)5即类别id+xywh
            for j in range(batch_size):  # 遍历每个批次
                matches = i == j  # 匹配当前批次的目标
                n = matches.sum()  # 目标数量
                if n:
                    out[j, :n] = targets[matches, 1:]  # 填充目标
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))  # 转换坐标
        return out

随后拆分得到clsbox

python 复制代码
cls  tensor([[[22.],
         [22.],
         [22.],
         [22.],
         [ 0.],
         [ 0.],
         [ 0.]],

        [[45.],
         [50.],
         [45.],
         [45.],
         [49.],
         [49.],
         [49.]],

        [[22.],
         [58.],
         [75.],
         [58.],
         [58.],
         [75.],
         [75.]],

        [[58.],
         [75.],
         [23.],
         [23.],
         [ 0.],
         [ 0.],
         [ 0.]]], device='cuda:0')
box  tensor([[[2.7307e+02, 2.7907e+02, 5.4232e+02, 5.1043e+02],
         [2.7307e+02, 1.9674e+01, 5.4232e+02, 2.5103e+02],
         [6.8230e-01, 1.9674e+01, 1.5169e+02, 2.4743e+02],
         [6.8230e-01, 2.7907e+02, 1.5169e+02, 5.0683e+02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[1.0442e+00, 1.5163e+02, 6.3999e+02, 5.4066e+02],
         [1.8152e-01, 2.0824e+02, 3.4580e+02, 5.4177e+02],
         [9.4192e+01, 8.0139e-02, 6.3958e+02, 3.8079e+02],
         [3.3554e-01, 4.0424e-01, 1.9339e+02, 2.1323e+02],
         [4.6191e+01, 2.7344e-02, 1.6055e+02, 9.2354e+01],
         [1.2538e+02, 2.2156e-02, 1.6390e+02, 1.4374e+01],
         [1.2584e+01, 2.2949e-02, 5.1544e+01, 1.2688e+01]],

        [[1.5334e+02, 1.5542e+02, 4.2644e+02, 3.9008e+02],
         [2.7954e+02, 4.2535e+02, 4.3731e+02, 6.2589e+02],
         [2.9995e+02, 5.0253e+02, 4.0279e+02, 6.2340e+02],
         [2.1976e-01, 4.3853e+02, 4.1101e+01, 5.4790e+02],
         [2.1976e-01, 1.7356e+02, 4.1101e+01, 2.8293e+02],
         [4.0710e-02, 2.3757e+02, 6.5855e+00, 2.8233e+02],
         [4.0710e-02, 5.0253e+02, 6.5855e+00, 5.4730e+02]],

        [[2.2732e+02, 1.2739e+02, 5.0236e+02, 4.7698e+02],
         [2.8749e+02, 2.6194e+02, 4.6677e+02, 4.7264e+02],
         [5.2375e+02, 1.8790e+01, 6.3991e+02, 7.6174e+01],
         [1.2816e+02, 4.9530e-01, 2.5127e+02, 1.9529e+01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]], device='cuda:0')

预测框解码

bbox_decode函数的目的是从预测的物体边界框坐标分布(pred_dist)和参考点(anchor_points)解码出实际的边界框坐标xyxy。(此时pred_dist4,8400,64)进行解码,得到(4,8400,4)

matmul方法为矩阵乘法,pred_dist在matmul之前, shape为[b, a, 4, 16],

self.proj的shape为[16], 最终的pred_dist的shape为[b, a, 4] 如果不理解可以直接使用 a =

torch.ones((1, 3, 4, 16)), 与b=torch.rand(16)进行matmul

python 复制代码
def bbox_decode(self, anchor_points, pred_dist):
        """从锚点和分布预测中解码出预测的目标边界框坐标。
        参数:
            anchor_points (torch.Tensor): 锚点坐标,形状为 [num_anchors, 2]。
            pred_dist (torch.Tensor): 预测的边界框分布,形状为 [batch_size, num_anchors, num_channels]。
        返回:
            torch.Tensor: 解码后的边界框坐标,形状为 [batch_size, num_anchors, 4]。
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # 获取 batch 大小、锚点数量和通道数
            # 将预测分布变形为 [batch_size, num_anchors, 4, num_channels // 4],并在通道维度上应用 softmax
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))#self.proj的值为tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.], device='cuda:0')
        return dist2bbox(pred_dist, anchor_points, xywh=False)

TaskAlignedAssigner方法

该方法用于分配正负样本,即哪些锚点负责预测那个目标,没有分配到目标的锚点则认为预测的是背景。

TaskAlignedAssigner 的匹配策略简单总结为:根据分类与回归的分数加权的分数选择正样本。

  1. 计算真实框和预测框的匹配程度(分类与回归的分数加权的分数)。

其中,s 是预测类别分值,u 是预测框和真实框的ciou值,αβ 为权重超参数,两者相乘就可以衡量匹配程度,当分类的分值越高且ciou越高时,align_metric的值就越接近于1,此时预测框就与真实框越匹配,就越符合正样本的标准。

  1. 对于每个真实框,直接对align_metric匹配程度排序,选取topK个预测框作为正样本。
  2. 对一个预测框与多个真实框匹配测情况进行处理,保留ciou值最大的真实框。

得到的结果如下:

  • target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签

  • target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框

  • target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。

  • fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。

    fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。

    正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。

    负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标

  • target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。

python 复制代码
target_gt_idx的值如下:其行为8400个
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 6,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

Box损失

Box损失中包含IOU损失和DEL损失

传入的参数如下:

这里的pred_bboxes是预测的检测框,pred_dist则是用于DEL损失计算,

python 复制代码
 def BoxLoss(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)#(209,1)这个209代表匹配上的锚点数量
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)#也是(209,1)只不过该值求了
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

bbox2dist方法

在求dfl损失时,有个bbox2dist方法,传入的参数如下:

python 复制代码
def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1) #将box的坐标拆分,x1y1为左上角坐标,x2y2为右上角坐标维度均为(4,8400,2)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)

其将中心点坐标anchor_points与左上、右下坐标做差并重写拼接,同时将值限制在0到15(准确的说是14.99)以内

python 复制代码
torch.clamp(input, min, max, out=None)

作用:限幅。将input的值限制在[min, max]之间,并返回结果。

举例如下:

python 复制代码
import torch
 
a = torch.arange(9).reshape(3, 3)   # 创建3*3的tensor
b = torch.clamp(a, 3, 6)     # 对a的值进行限幅,限制在[3, 6]
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
 
 
'''   输出结果   '''
a: tensor([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
shape of a: torch.Size([3, 3])
 
b: tensor([[3, 3, 3],
           [3, 4, 5],
           [6, 6, 6]])
shape of b: torch.Size([3, 3])

那么,为何要这样做呢,我们知道将中心点坐标与左上右下的坐标(真值坐标)做差后得到的是宽高,将其限制在15以内就是说明其只负责30x30以内的目标,而我们的特征图此时最大的为80x80,这其实已将不小了,即认为图像中的目标应当都在这个范围内(大多数)毕竟其作用在DEL损失中。

经过该方法处理后得到target_ltrb,其维度仍为(4,8400,4)这个是真值的数据

随后便是计算DEL损失了,其定义如下:

DFL损失

DFL,全称Distribution Focal Loss(分布焦点损失),很多人一听到Focal Loss就立马想到分类,这没错,但DFL却是用在边框回归中。这个损失用于求中心点坐标到上下左右四条边的距离。

其中,y为真值坐标,yi与yi+1是预测出的距离值,S是其对应的概率

由于DEL模块输出的数据为(4,8400,64)这个64=4x16,4即4条边,16则可认为是距离,这是一个分布,通过softmax函数求出的是概率,即距离为1到16的概率。

具体过程如下:

1、 模型先生成一个reg_max(默认为16)的概率统计分布,其对应得是{ 0 , 1 , . . . , 7 , 8 , . . . , 14 , 15 } ,该值就是模型预测出来的anchor points到bbox边的距离。假设模型预测出的结果pred_dist{ 0.01 , 0.05 , . . . , 0.12 , 0.23 , . . . , 0.01 , 0.34 } 也就是anchor pointsbbox边的距离为0的概率是0.01,距离为15的概率为0.34。

2、 然后使用上面介绍的检测头代码中的self.dfl求出anchor pointsbbox的距离的期望y

就是模型预测的最终的anchor pointsbbox边的距离,这个期望最大是15,也就是说模型预测出的anchor pointsbbox边的距离最大是15

该距离的求法如下:

那么在预测过程中,这里是直接预测了,不需要再计算多个分布值了

而在训练时,可以按照如下理解:

如下图所示:

调用DELoss方法,传入的参数如下:

python 复制代码
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight

fg_mask代表在TaskAlignedAssigner方法中匹配上的正样本,其维度为(4,8400),其值为FalseTrue

因此其会提取pred_dit(4,8400,64)target_ltrb(4,8400,4)对应坐标内的数据,即在(4,8400)这个维度提取,由于fg_mask共有209个,因此取出的值有209个,同时通过viewpred_dist进行维度转换,即209x4=836

随后开始DFLoss计算:

python 复制代码
class DFLoss(nn.Module):
    """Criterion class for computing DFL losses during training."""
    def __init__(self, reg_max=16) -> None:
        """Initialize the DFL module."""
        super().__init__()
        self.reg_max = reg_max

    def __call__(self, pred_dist, target):
        target = target.clamp_(0, self.reg_max - 1 - 0.01)#将target限制在0到14.99之间(这个已经做过了,只是为保险起见)
        tl = target.long()  # target left转换为整数
        tr = tl + 1  # target right  就相当于y-yi了,因为已经将其整数化了
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)

掩码损失(分割损失)

将masks从batch中取出,并

python 复制代码
masks = batch["masks"].to(self.device).float()#(4,160,160)
if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample 但不执行
          masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

          loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

匹配上的锚点(正样本)

掩膜(真值)这个值是经过处理下采样后

每个batch中真值对应的类别

目标所属batch

真值预测框,这个是处理后的,为方便计算(转为了4,8400维)

预测的mask

single_mask_loss方法

该法用于求单张图片的分割损失。传入的参数如下:

这里的40指的是匹配上的锚点数量(四个batch的锚点数量为40,69,100,40)。可以看到,此时的pred_mask(预测的mask)与真值mask都已经转换为相同的维度,即(40,160,160)

这里求两者损失使用的是binary_cross_entropy_with_logits方法,即交叉熵损失函数,这里带着_with_logits的原因是不需要将数据传入前使用sigmoid/softmax映射到(0,1)之间了。

python 复制代码
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 160, 160) -> (n, 160, 160)  n即数量
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

至此,完成了实例分割的损失计算

crop_mask方法,这个方法是为了让mask不要超界(不要超出box)

python 复制代码
def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

最终的总损失如下:

python 复制代码
		loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

模型参数更新

至此完成了损失计算过程。

随后,便是开始反向传播(这是个黑盒)

python 复制代码
self.scaler.scale(self.loss).backward()

完成后,边进行模型参数的更新即可:

python 复制代码
		self.trainer.train()#这里是训练完成
		# Update model and cfg after training
        if RANK in {-1, 0}:
            ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
            self.model, _ = attempt_load_one_weight(ckpt)
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
        return self.metrics

总结

本章梳理了预测结果与真值的损失计算过程,可以加深我们对模型训练的理解。

相关推荐
智驱力人工智能9 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
工程师老罗9 小时前
YOLOv1 核心结构解析
yolo
Lun3866buzha9 小时前
YOLOv10-BiFPN融合:危险物体检测与识别的革新方案,从模型架构到实战部署全解析
yolo
Katecat9966310 小时前
YOLOv8-MambaOut在电子元器件缺陷检测中的应用与实践_1
yolo
工程师老罗11 小时前
YOLOv1 核心知识点笔记
笔记·yolo
工程师老罗16 小时前
基于Pytorch的YOLOv1 的网络结构代码
人工智能·pytorch·yolo
学习3人组19 小时前
YOLO模型集成到Label Studio的MODEL服务
yolo
孤狼warrior19 小时前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
水中加点糖21 小时前
小白都能看懂的——车牌检测与识别(最新版YOLO26快速入门)
人工智能·yolo·目标检测·计算机视觉·ai·车牌识别·lprnet
前端摸鱼匠1 天前
YOLOv8 环境配置全攻略:Python、PyTorch 与 CUDA 的和谐共生
人工智能·pytorch·python·yolo·目标检测