yolov8涨点系列之损失函数替换

文章目录

损失函数对模型的重要性

1.模型训练的导向作用

损失函数是衡量模型预测结果与真实标签之间差异的指标。在 YOLOv8 的训练过程中,它就像是一个 "指南针",引导模型朝着正确的方向更新参数。例如,在目标检测任务中,模型需要同时预测目标的位置(边界框)和类别。损失函数会综合考虑位置预测误差和类别预测误差,促使模型学习到如何准确地定位目标和识别类别。

2.评估模型性能的关键指标

损失函数的值可以直观地反映模型的训练状态。随着训练的进行,如果损失函数的值不断下降,说明模型在不断地优化,逐渐减小与真实情况的差距。反之,如果损失函数的值不再下降甚至上升,可能意味着模型出现了过拟合或者训练过程中存在其他问题。例如,在 YOLOv8 的早期训练阶段,损失函数的值可能会比较高,因为模型还没有学习到足够的特征来进行准确的预测;而在训练后期,损失函数应该逐渐趋近于一个较小的值。

3.影响模型泛化能力

合适的损失函数有助于提高模型的泛化能力。泛化能力是指模型对未见过的数据进行准确预测的能力。通过损失函数的引导,模型能够学习到数据中的一般性规律,而不是仅仅记住训练数据的特征。在 YOLOv8 中,例如对于不同场景下的目标检测,合理的损失函数可以让模型更好地适应新的场景、目标大小和类别分布,从而在实际应用中发挥更好的作用。

合理损失函数的涨点优势

4.提高检测精度

合理的损失函数能够更精准地衡量预测误差。以 YOLOv8 为例,在目标的位置预测方面,一些先进的损失函数可以更好地处理边界框的回归问题。例如,采用 CIoU(Complete Intersection over Union)损失函数代替传统的 IoU(Intersection over Union)损失函数,CIoU 损失函数不仅考虑了边界框的重叠面积,还考虑了边界框中心点之间的距离和宽高比等因素。这使得模型在预测边界框位置时更加准确,从而提高了目标检测的精度,在实验中可能会带来几个百分点的精度提升。

5.加速收敛速度

合适的损失函数可以优化模型的训练过程,使模型更快地收敛到一个较好的状态。例如,一些自适应的损失函数能够根据不同的样本难度自动调整权重。在 YOLOv8 训练中,对于容易预测的样本,损失函数可以降低其权重,而对于难样本(例如目标较小、被遮挡的目标等),则增加其权重。这样可以让模型更快地聚焦于学习难样本的特征,从而加速收敛速度,减少训练时间。

6.增强模型鲁棒性

鲁棒性是指模型在面对数据噪声、数据分布变化等情况时的稳定性。合理的损失函数可以使模型在训练过程中更好地应对这些情况。例如,在 YOLOv8 的训练数据中,如果存在一定比例的标注错误或者模糊的图像,一个好的损失函数可以通过对不同类型的误差进行合理分配权重,降低标注错误等异常情况对模型的负面影响,使模型能够更稳定地学习到正确的特征,从而增强模型的鲁棒性。

损失函数替换步骤

(1)首先打开loss.py和tal.py这两个文件

(2)依次点击"BboxLoss"------>"forward"------>"bbox_iou";点击之后会定位到"metrics.py"------>"bbox_iou"

(3)将bbox_iou函数删除,添加以下函数:

以下包含了:

bash 复制代码
class WIoU_Scale:
    ''' monotonous: {
            None: origin v1
            True: monotonic FM v2
            False: non-monotonic FM v3
        }
        momentum: The momentum of running mean'''

    iou_mean = 1.
    monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)
    _is_train = True

    def __init__(self, iou):
        self.iou = iou
        self._update(self)

    @classmethod
    def _update(cls, self):
        if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                                         cls._momentum * self.iou.detach().mean().item()

    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                return beta / alpha
        return 1


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, Focal=False,
             alpha=1, gamma=0.5, scale=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps
    if scale:
        self = WIoU_Scale(1 - (inter / union))

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter / (union + eps), alpha)  # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (
                        b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),
                                                                                                 gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps),
                                                                                      gamma)  # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(
                        inter / (union + eps), gamma)  # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)  # SIou
            elif WIoU:
                if Focal:
                    raise RuntimeError("WIoU do not support Focal.")
                elif scale:
                    return getattr(WIoU_Scale, '_scaled_loss')(self), (1 - iou) * torch.exp(
                        (rho2 / c2)), iou  # WIoU https://arxiv.org/abs/2301.10051
                else:
                    return iou, torch.exp((rho2 / c2))  # WIoU v1
            if Focal:
                return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps),
                                                                                      gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU

(4)除了要替换bbox_ios,还需要切换回loss.py文件,将红框所在行代码替换为:

bash 复制代码
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, WIoU=True, scale=True)
if type(iou) is tuple:
    if len(iou) == 2:
        loss_iou = ((1.0 - iou[0]) * iou[1].detach() * weight).sum() / target_scores_sum
    else:
        loss_iou = (iou[0] * iou[1] * weight).sum() / target_scores_sum
else:
    loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

(5)指明损失函数:例如使用SIoU

①loss.py中改成SIoU

②tal.py中改成SIoU

相关推荐
新新学长搞科研2 小时前
【自动识别相关会议】第五届机器视觉、自动识别与检测国际学术会议(MVAID 2026)
人工智能·目标检测·计算机视觉·自动化·视觉检测·能源·语音识别
JicasdC123asd18 小时前
密集残差瓶颈网络改进YOLOv26特征复用与梯度传播双重优化
网络·yolo·目标跟踪
JicasdC123asd21 小时前
密集连接瓶颈模块改进YOLOv26特征复用与梯度流动双重优化
人工智能·yolo·目标跟踪
duyinbi75171 天前
局部特征提取改进YOLOv26空间移位卷积与轻量化设计双重突破
人工智能·yolo·目标跟踪
张道宁1 天前
基于Spring Boot与Docker的YOLOv8检测服务实战
spring boot·yolo·docker
duyinbi75171 天前
大核瓶颈架构改进YOLOv26扩大感受野与多尺度特征提取双重突破
yolo·架构
孤狼warrior1 天前
YOLO技术架构发展详解(从v1到v8)近万字底层实现逻辑解析
yolo
Coovally AI模型快速验证1 天前
无人机 RGB+热红外融合检测建筑裂缝与渗漏,34 层高楼约 2 小时
目标检测·计算机视觉·无人机·智慧城市·裂缝检测·渗漏检测
张张123y1 天前
机器学习与深度学习:从基础概念到YOLOv8全解析
深度学习·yolo·机器学习
AI浩2 天前
CollabOD:用于无人机小目标检测的跨尺度视觉协作多骨干网络
人工智能·目标检测·无人机