Yolov8 目标检测剪枝学习记录

最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能

这里主要参考了torch-pruning的基本使用v8模型剪枝Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考

首先训练一个base模型用于参考

  • 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
  • 训练代码

参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源

dart 复制代码
from ultralytics import YOLO
import os

root = os.getcwd()
## 配置文件路径
name_yaml             = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain         = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train            = os.path.join(root, "runs/detect/VOC")
name_train            = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before     = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after      = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn         = os.path.join(root, "runs/detect/VOC_finetune")

def else_api():
    path_data = ""
    path_result = ""
    model = YOLO(name_pretrain) 
    metrics = model.val()  # evaluate model performance on the validation set
    model.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)
    model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5)  # 这里的imgsz为高宽

def step1_train():
    model = YOLO(name_pretrain) 
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train)  # train the model

## 2024.3.4添加【amp=False】
def step2_Constraint_train():
    model = YOLO(name_train) 
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train)  # train the model

def step3_pruning():
    from LL_pruning import do_pruning
    do_pruning(name_prune_before, name_prune_after)

def step4_finetune():
    model = YOLO(name_prune_after)     # load a pretrained model (recommended for training)
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn)  # train the model

step1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()

第一步,step1_train()

  • 即训练一个base模型,用于最后性能好坏的重要参考

第二步,step2_Constraint_train()

训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏

  • 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
dart 复制代码
     # Backward
     self.scaler.scale(self.loss).backward()
     ## add new code=============================duj
     ## add l1 regulation for step2_Constraint_train               
     l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
     for k, m in self.model.named_modules():
         if isinstance(m, nn.BatchNorm2d):
             m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
             m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

     # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
     if ni - last_opt_step >= self.accumulate:
         self.optimizer_step()
         last_opt_step = ni
  • 个人理解的稀疏化作用
    • 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
    • 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。

第三步,step3_pruning()剪枝操作

LL_pruning.py

dart 复制代码
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import os


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## a. 根据BN中的参数,获取需要保留的index================
        gamma = conv1.bn.weight.data.detach()
        beta  = conv1.bn.bias.data.detach()
        
        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        # scale = len(idxs) / n

        ## b. 利用index对BN进行剪枝============================
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data   = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n
        
        ## c. 利用index对conv1进行剪枝=========================
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## d. 利用index对conv2进行剪枝=========================
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
     
    def prune(self, m1, m2):
        if isinstance(m1, C2f):      # C2f as a top conv
            m1 = m1.cv2
        if not isinstance(m2, list): # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C2f) or isinstance(item, SPPF):
                m2[i] = item.cv1
        self.prune_conv(m1, m2)
     
def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)                  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.8)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。

    ### 1. 剪枝c2f 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3,5,7,8]: 
        pruning.prune(seq[i], seq[i+1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] 
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] 
    detect:Detect = seq[-1]
    last_inputs   = [seq[15], seq[18], seq[21]]
    colasts       = [seq[16], seq[19], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        pruning.prune(last_input, [colast, cv2[0], cv3[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True
     
    yolo.val()
    torch.save(yolo.ckpt, savepath)
    yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))
    yolo.export(format="onnx")

    ## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小
    yolo = YOLO(modelpath)  # build a new model from scratch
    yolo.export(format="onnx")


if __name__ == "__main__":

    modelpath = "runs/detect1/14_Constraint/weights/last.pt"
    savepath  = "runs/detect1/14_Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
  • 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
  • 查看剪枝前后模型大小 du -sh ./runs/detect/VOC_Constraint/weights/last*yolov8n模型

微调

该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。

dart 复制代码
 def setup_model(self):
        """Load/create/download model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return
  • 例如ultralytics\engine\trainer.py
  • v8...x添加代码:548行 参考这里
dart 复制代码
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
  • 如果是v8.0.x 参考这里

在看到这篇中的修改1启发

  • v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
  • 我尝试在ultralytics\engine\model.py添加如下代码加载模型成功
dart 复制代码
 self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
        if not args.get("resume"):  # manually set model only if not resuming
            # self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            # self.model = self.trainer.model
            # dujiang edit 
            self.trainer.model = self.model.train()

            if SETTINGS["hub"] is True and not self.session:
  • 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
  • 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
  • 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
  • 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。

总结

总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我

相关推荐
Lunar*2 小时前
华为 Ascend 平台 YOLOv5 目标检测推理教程
yolo·目标检测·华为
AI街潜水的八角12 小时前
PyTorch框架——基于深度学习YOLOv11神经网络路面坑洞检测系统
pytorch·深度学习·yolo
量子-Alex19 小时前
【遥感目标检测】【数据集】DOTA:用于航空图像中目标检测的大规模数据集
人工智能·目标检测·目标跟踪
云空1 天前
《探秘火焰目标检测开源模型:智能防火的科技利刃》
科技·目标检测·开源
小李学AI1 天前
基于YOLOv8的卫星图像中船只检测系统
人工智能·深度学习·神经网络·yolo·目标检测·机器学习·计算机视觉
刘争Stanley2 天前
量子计算:从薛定谔的猫到你的生活
人工智能·yolo·搜索引擎·生活·scikit-learn·量子计算·dall·e 2
sagima_sdu2 天前
YOLOv11 OBB 任务介绍与数据集构建要求及训练脚本使用指南
yolo
NiNg_1_2343 天前
YOLOv5训练长方形图像详解
人工智能·yolo·目标跟踪
robin_suli3 天前
穷举vs暴搜vs深搜vs回溯vs剪枝系列一>优美的排列
算法·剪枝·深度优先遍历·回溯