CV党福音:YOLOv8实现实例分割(三)之损失函数

在上一篇博文中,我们讲解了YOLOv8实例分割的训练过程,已将前向传播过程分析完毕,那么,接下来便是损失计算过程了。

文章目录

训练整体流程

获得预测结果与真值后,即可计算损失。

整体流程如下:

预测结果preds如下:

真值batch如下:

分割整体损失函数

v8SegmentationLoss的计算过程如下,从最终的结果来看,其计算了四个损失,分别是目标预测框损失、mask损失、类别损失以及DEL损失,博主已将每段代码的结果标注在对应的代码位置。同时,在损失计算过程中不可避免的需要使用其他方法,博主将一些较为重要的方法也罗列出来了。

python 复制代码
def __call__(self, preds, batch):
        """Calculate and return the loss for the YOLO model."""
        loss = torch.zeros(4, device=self.device) # box, cls, dfl,mask,共 4 个
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] #feature是目标检测头输出的三个特征图list类型,pred_masks为torch.Size([4, 32, 8400]) ,oroto为torch.Size([4, 32, 160, 160])
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )#拆分目标检测特征图结果,并融合在一起,得到(4,64,8400)与(4,80,8400)

        # B, grids, ..   维度转换
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()#(4,8400,80)
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()#(4,8400,64)
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()#(4,8400,32)

        dtype = pred_scores.dtype #torch.float16
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)#torch.Size([22, 6]) 即batchid 类别id  x y w h 共6个数
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])#(batch,最大目标数量,5) 5=1+4
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy (4,7,1)(4,7,4)
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# 这行代码的作用是创建一个掩码张量 mask_gt,用于标识哪些目标是有效的(即有实际的边界框),哪些目标是无效的(即填充的零)。
        # sum(2, keepdim=True):沿着第 2 维(即每个边界框的坐标维度)进行求和,保持维度不变。
        # 这样会将每个边界框的 4 个坐标值相加,如果该目标是有效的边界框,和将大于 0;如果该目标是填充的零,和将等于 0。
        except RuntimeError as e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,#(4,7,1)7指这个四个batch中
            gt_bboxes,#(4,7,4)
            mask_gt,
        )
"""
        self.assigner即根据分类与回归的分数加权的分数选择正样本 输入(共6个输入值)
        1.pred_scores:表示模型预测的每个锚点位置的分类分数 形状为:(batch_size, 8400, 80)
        2.(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype)
            pred_bboxes:解码后的边界框坐标,形状为 [batch_size, 8400, 4]。
            detach():从计算图中分离出预测的边界框,避免反向传播时更新它们。
            * stride_tensor:将边界框坐标乘以步幅张量 stride_tensor,恢复到原图尺度。
            type(gt_bboxes.dtype):将边界框的类型转换为与 gt_bboxes 相同的类型。
        3.anchor_points * stride_tensor
            anchor_points:锚点位置,形状为 [num_anchors, 2]
            将锚点位置乘以步幅张量 stride_tensor,恢复到原图尺度
        4.gt_labels:真实目标的类别标签
        5.gt_bboxes:真实目标的边界框坐标
        6.mask_gt:掩码张量,标识哪些目标是有效的
        """
        """
        输出(共5个返回值)
        target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签
        target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框
        target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。
        fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。
        fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。
            正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。
            负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标
        target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。
        """


        target_scores_sum = max(target_scores.sum(), 1)
        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE 即求分类损失

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(  #求Box损失和DEL损失
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )
        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

生成网格

make_anchor方法,其作用是利用特征图生成网格的形式来创建预测框,这种产方式自YOLOv3起便一直沿用。

这里生成网格使用的方法为:x,y=torch.meshgrid(a,b)

torch.meshgrid(a,b)的功能是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数,当两个输入张量数据类型不同或维度不是一维时会报错。

其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素,各列元素相同。
a tensor([1, 2, 3, 4, 5, 6])
b tensor([ 7, 8, 9, 10])
x tensor([[1, 1, 1, 1],

[2, 2, 2, 2],

[3, 3, 3, 3],

[4, 4, 4, 4],

[5, 5, 5, 5],

[6, 6, 6, 6]])
y tensor([[4, 5, 6],

[4, 5, 6],

[4, 5, 6],

[4, 5, 6]])

python 复制代码
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

我们以第一层特征图生成的网格为例,得到的sysx如下:

python 复制代码
sy: tensor([[ 0.5000,  0.5000,  0.5000,  ...,  0.5000,  0.5000,  0.5000],
        [ 1.5000,  1.5000,  1.5000,  ...,  1.5000,  1.5000,  1.5000],
        [ 2.5000,  2.5000,  2.5000,  ...,  2.5000,  2.5000,  2.5000],
        ...,
        [77.5000, 77.5000, 77.5000,  ..., 77.5000, 77.5000, 77.5000],
        [78.5000, 78.5000, 78.5000,  ..., 78.5000, 78.5000, 78.5000],
        [79.5000, 79.5000, 79.5000,  ..., 79.5000, 79.5000, 79.5000]], device='cuda:0', dtype=torch.float16)
        
sx: tensor([[ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        ...,
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000],
        [ 0.5000,  1.5000,  2.5000,  ..., 77.5000, 78.5000, 79.5000]], device='cuda:0', dtype=torch.float16)

生成的anchor-point即为两者合并的结果,维度为(6400,2)

最终,将三个尺度的特征图产生的anchor合并,得到:

python 复制代码
anchor_point  :  torch.Size([8400, 2])       
stride_tensor   torch.Size([8400, 1]),其结果如下:
tensor([[ 8.], 8,16,32代表放大比例
        [ 8.],
        [ 8.],
        ...,
        [32.],
        [32.],
        [32.]], device='cuda:0', dtype=torch.float16)

前处理过程(真值转换)

这个方法是将target进行处理,初始的target(22,6)转换为(batch,最大目标数量,5)5即类别id+xywh,当然最后还会将xywh转换为坐标形式

python 复制代码
def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
        nl, ne = targets.shape
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)#统计每张图像中目标的个数
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)#根据最大数量最大的图像设置维度,即(batch,最大目标数量,5)5即类别id+xywh
            for j in range(batch_size):  # 遍历每个批次
                matches = i == j  # 匹配当前批次的目标
                n = matches.sum()  # 目标数量
                if n:
                    out[j, :n] = targets[matches, 1:]  # 填充目标
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))  # 转换坐标
        return out

随后拆分得到clsbox

python 复制代码
cls  tensor([[[22.],
         [22.],
         [22.],
         [22.],
         [ 0.],
         [ 0.],
         [ 0.]],

        [[45.],
         [50.],
         [45.],
         [45.],
         [49.],
         [49.],
         [49.]],

        [[22.],
         [58.],
         [75.],
         [58.],
         [58.],
         [75.],
         [75.]],

        [[58.],
         [75.],
         [23.],
         [23.],
         [ 0.],
         [ 0.],
         [ 0.]]], device='cuda:0')
box  tensor([[[2.7307e+02, 2.7907e+02, 5.4232e+02, 5.1043e+02],
         [2.7307e+02, 1.9674e+01, 5.4232e+02, 2.5103e+02],
         [6.8230e-01, 1.9674e+01, 1.5169e+02, 2.4743e+02],
         [6.8230e-01, 2.7907e+02, 1.5169e+02, 5.0683e+02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[1.0442e+00, 1.5163e+02, 6.3999e+02, 5.4066e+02],
         [1.8152e-01, 2.0824e+02, 3.4580e+02, 5.4177e+02],
         [9.4192e+01, 8.0139e-02, 6.3958e+02, 3.8079e+02],
         [3.3554e-01, 4.0424e-01, 1.9339e+02, 2.1323e+02],
         [4.6191e+01, 2.7344e-02, 1.6055e+02, 9.2354e+01],
         [1.2538e+02, 2.2156e-02, 1.6390e+02, 1.4374e+01],
         [1.2584e+01, 2.2949e-02, 5.1544e+01, 1.2688e+01]],

        [[1.5334e+02, 1.5542e+02, 4.2644e+02, 3.9008e+02],
         [2.7954e+02, 4.2535e+02, 4.3731e+02, 6.2589e+02],
         [2.9995e+02, 5.0253e+02, 4.0279e+02, 6.2340e+02],
         [2.1976e-01, 4.3853e+02, 4.1101e+01, 5.4790e+02],
         [2.1976e-01, 1.7356e+02, 4.1101e+01, 2.8293e+02],
         [4.0710e-02, 2.3757e+02, 6.5855e+00, 2.8233e+02],
         [4.0710e-02, 5.0253e+02, 6.5855e+00, 5.4730e+02]],

        [[2.2732e+02, 1.2739e+02, 5.0236e+02, 4.7698e+02],
         [2.8749e+02, 2.6194e+02, 4.6677e+02, 4.7264e+02],
         [5.2375e+02, 1.8790e+01, 6.3991e+02, 7.6174e+01],
         [1.2816e+02, 4.9530e-01, 2.5127e+02, 1.9529e+01],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]], device='cuda:0')

预测框解码

bbox_decode函数的目的是从预测的物体边界框坐标分布(pred_dist)和参考点(anchor_points)解码出实际的边界框坐标xyxy。(此时pred_dist4,8400,64)进行解码,得到(4,8400,4)

matmul方法为矩阵乘法,pred_dist在matmul之前, shape为[b, a, 4, 16],

self.proj的shape为[16], 最终的pred_dist的shape为[b, a, 4] 如果不理解可以直接使用 a =

torch.ones((1, 3, 4, 16)), 与b=torch.rand(16)进行matmul

python 复制代码
def bbox_decode(self, anchor_points, pred_dist):
        """从锚点和分布预测中解码出预测的目标边界框坐标。
        参数:
            anchor_points (torch.Tensor): 锚点坐标,形状为 [num_anchors, 2]。
            pred_dist (torch.Tensor): 预测的边界框分布,形状为 [batch_size, num_anchors, num_channels]。
        返回:
            torch.Tensor: 解码后的边界框坐标,形状为 [batch_size, num_anchors, 4]。
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # 获取 batch 大小、锚点数量和通道数
            # 将预测分布变形为 [batch_size, num_anchors, 4, num_channels // 4],并在通道维度上应用 softmax
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))#self.proj的值为tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.], device='cuda:0')
        return dist2bbox(pred_dist, anchor_points, xywh=False)

TaskAlignedAssigner方法

该方法用于分配正负样本,即哪些锚点负责预测那个目标,没有分配到目标的锚点则认为预测的是背景。

TaskAlignedAssigner 的匹配策略简单总结为:根据分类与回归的分数加权的分数选择正样本。

  1. 计算真实框和预测框的匹配程度(分类与回归的分数加权的分数)。

其中,s 是预测类别分值,u 是预测框和真实框的ciou值,αβ 为权重超参数,两者相乘就可以衡量匹配程度,当分类的分值越高且ciou越高时,align_metric的值就越接近于1,此时预测框就与真实框越匹配,就越符合正样本的标准。

  1. 对于每个真实框,直接对align_metric匹配程度排序,选取topK个预测框作为正样本。
  2. 对一个预测框与多个真实框匹配测情况进行处理,保留ciou值最大的真实框。

得到的结果如下:

  • target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签

  • target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框

  • target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。

  • fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。

    fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。

    正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。

    负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标

  • target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。

python 复制代码
target_gt_idx的值如下:其行为8400个
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 6,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

Box损失

Box损失中包含IOU损失和DEL损失

传入的参数如下:

这里的pred_bboxes是预测的检测框,pred_dist则是用于DEL损失计算,

python 复制代码
 def BoxLoss(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)#(209,1)这个209代表匹配上的锚点数量
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)#也是(209,1)只不过该值求了
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.dfl_loss:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
            loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

bbox2dist方法

在求dfl损失时,有个bbox2dist方法,传入的参数如下:

python 复制代码
def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1) #将box的坐标拆分,x1y1为左上角坐标,x2y2为右上角坐标维度均为(4,8400,2)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)

其将中心点坐标anchor_points与左上、右下坐标做差并重写拼接,同时将值限制在0到15(准确的说是14.99)以内

python 复制代码
torch.clamp(input, min, max, out=None)

作用:限幅。将input的值限制在[min, max]之间,并返回结果。

举例如下:

python 复制代码
import torch
 
a = torch.arange(9).reshape(3, 3)   # 创建3*3的tensor
b = torch.clamp(a, 3, 6)     # 对a的值进行限幅,限制在[3, 6]
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
 
 
'''   输出结果   '''
a: tensor([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
shape of a: torch.Size([3, 3])
 
b: tensor([[3, 3, 3],
           [3, 4, 5],
           [6, 6, 6]])
shape of b: torch.Size([3, 3])

那么,为何要这样做呢,我们知道将中心点坐标与左上右下的坐标(真值坐标)做差后得到的是宽高,将其限制在15以内就是说明其只负责30x30以内的目标,而我们的特征图此时最大的为80x80,这其实已将不小了,即认为图像中的目标应当都在这个范围内(大多数)毕竟其作用在DEL损失中。

经过该方法处理后得到target_ltrb,其维度仍为(4,8400,4)这个是真值的数据

随后便是计算DEL损失了,其定义如下:

DFL损失

DFL,全称Distribution Focal Loss(分布焦点损失),很多人一听到Focal Loss就立马想到分类,这没错,但DFL却是用在边框回归中。这个损失用于求中心点坐标到上下左右四条边的距离。

其中,y为真值坐标,yi与yi+1是预测出的距离值,S是其对应的概率

由于DEL模块输出的数据为(4,8400,64)这个64=4x16,4即4条边,16则可认为是距离,这是一个分布,通过softmax函数求出的是概率,即距离为1到16的概率。

具体过程如下:

1、 模型先生成一个reg_max(默认为16)的概率统计分布,其对应得是{ 0 , 1 , . . . , 7 , 8 , . . . , 14 , 15 } ,该值就是模型预测出来的anchor points到bbox边的距离。假设模型预测出的结果pred_dist{ 0.01 , 0.05 , . . . , 0.12 , 0.23 , . . . , 0.01 , 0.34 } 也就是anchor pointsbbox边的距离为0的概率是0.01,距离为15的概率为0.34。

2、 然后使用上面介绍的检测头代码中的self.dfl求出anchor pointsbbox的距离的期望y

就是模型预测的最终的anchor pointsbbox边的距离,这个期望最大是15,也就是说模型预测出的anchor pointsbbox边的距离最大是15

该距离的求法如下:

那么在预测过程中,这里是直接预测了,不需要再计算多个分布值了

而在训练时,可以按照如下理解:

如下图所示:

调用DELoss方法,传入的参数如下:

python 复制代码
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight

fg_mask代表在TaskAlignedAssigner方法中匹配上的正样本,其维度为(4,8400),其值为FalseTrue

因此其会提取pred_dit(4,8400,64)target_ltrb(4,8400,4)对应坐标内的数据,即在(4,8400)这个维度提取,由于fg_mask共有209个,因此取出的值有209个,同时通过viewpred_dist进行维度转换,即209x4=836

随后开始DFLoss计算:

python 复制代码
class DFLoss(nn.Module):
    """Criterion class for computing DFL losses during training."""
    def __init__(self, reg_max=16) -> None:
        """Initialize the DFL module."""
        super().__init__()
        self.reg_max = reg_max

    def __call__(self, pred_dist, target):
        target = target.clamp_(0, self.reg_max - 1 - 0.01)#将target限制在0到14.99之间(这个已经做过了,只是为保险起见)
        tl = target.long()  # target left转换为整数
        tr = tl + 1  # target right  就相当于y-yi了,因为已经将其整数化了
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (
            F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
            + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
        ).mean(-1, keepdim=True)

掩码损失(分割损失)

将masks从batch中取出,并

python 复制代码
masks = batch["masks"].to(self.device).float()#(4,160,160)
if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample 但不执行
          masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

          loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

匹配上的锚点(正样本)

掩膜(真值)这个值是经过处理下采样后

每个batch中真值对应的类别

目标所属batch

真值预测框,这个是处理后的,为方便计算(转为了4,8400维)

预测的mask

single_mask_loss方法

该法用于求单张图片的分割损失。传入的参数如下:

这里的40指的是匹配上的锚点数量(四个batch的锚点数量为40,69,100,40)。可以看到,此时的pred_mask(预测的mask)与真值mask都已经转换为相同的维度,即(40,160,160)

这里求两者损失使用的是binary_cross_entropy_with_logits方法,即交叉熵损失函数,这里带着_with_logits的原因是不需要将数据传入前使用sigmoid/softmax映射到(0,1)之间了。

python 复制代码
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 160, 160) -> (n, 160, 160)  n即数量
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

至此,完成了实例分割的损失计算

crop_mask方法,这个方法是为了让mask不要超界(不要超出box)

python 复制代码
def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

最终的总损失如下:

python 复制代码
		loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

模型参数更新

至此完成了损失计算过程。

随后,便是开始反向传播(这是个黑盒)

python 复制代码
self.scaler.scale(self.loss).backward()

完成后,边进行模型参数的更新即可:

python 复制代码
		self.trainer.train()#这里是训练完成
		# Update model and cfg after training
        if RANK in {-1, 0}:
            ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
            self.model, _ = attempt_load_one_weight(ckpt)
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
        return self.metrics

总结

本章梳理了预测结果与真值的损失计算过程,可以加深我们对模型训练的理解。

相关推荐
Suyuoa6 小时前
附录2-pytorch yolov5目标检测
python·深度学习·yolo
红色的山茶花1 天前
YOLOv8-ultralytics-8.2.103部分代码阅读笔记-block.py
笔记·深度学习·yolo
unix2linux1 天前
YOLO v5 Series - Image & Video Storage ( Openresty + Lua)
yolo·lua·openresty
菠菠萝宝1 天前
【YOLOv8】安卓端部署-1-项目介绍
android·java·c++·yolo·目标检测·目标跟踪·kotlin
ZZZZ_Y_1 天前
YOLOv5指定标签框背景颜色和标签字
yolo
红色的山茶花2 天前
YOLOv8-ultralytics-8.2.103部分代码阅读笔记-conv.py
笔记·yolo
Eric.Lee20212 天前
数据集-目标检测系列- 花卉 鸡蛋花 检测数据集 frangipani >> DataBall
人工智能·python·yolo·目标检测·计算机视觉·鸡蛋花检查
阿_旭2 天前
【模型级联】YOLO-World与SAM2通过文本实现指定目标的零样本分割
yolo·yolo-world·sam2
CSBLOG2 天前
OpenCV、YOLO、VOC、COCO之间的关系和区别
人工智能·opencv·yolo
2zcode2 天前
基于YOLOv8深度学习的医学影像骨折检测诊断系统研究与实现(PyQt5界面+数据集+训练代码)
人工智能·深度学习·yolo