yolov11剪枝

思路:yolov11中的C3k2与yolov8的c2f的不同,所以与之前yolov8剪枝有稍许不同;

后续:会将剪枝流程写全,以及增加蒸馏、注意力、改loss;

注意:

1.在代码105行修改pruning.get_threshold(yolo.model, 0.65),可以获得不同的剪枝率;

2.改代码放在训练代码同一页面下即可;

3.在最后修改文件夹地址来获得剪枝后的模型;

python 复制代码
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os


# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## a. 根据BN中的参数,获取需要保留的index================
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()

        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        # scale = len(idxs) / n

        ## b. 利用index对BN进行剪枝============================
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n

        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        ## c. 利用index对conv1进行剪枝=========================
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## d. 利用index对conv2进行剪枝=========================
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                conv1 = item[0]
                conv = item[1].conv
                conv1.conv.in_channels = n
                conv1.conv.out_channels = n
                conv1.conv.groups = n
                conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]
                conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]
                conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]
                conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
                conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
                conv1.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        if isinstance(m1, C3k2):  # C2f as a top conv
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1

        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.65)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。

    ### 1. 剪枝c2f 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]
    detect: Detect = seq[-1]
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True

    yolo.val(data='data.yaml', batch=2, device=0, workers=0)
    torch.save(yolo.ckpt, savepath)
    # yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))
    # yolo.export(format="onnx")
    #
    # ## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小
    # yolo = YOLO(modelpath)  # build a new model from scratch
    # yolo.export(format="onnx")


if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
相关推荐
likerhood1 小时前
java中`==`和`.equals()`区别
java·开发语言·python
qq_283720052 小时前
Python Celery + FastAPI + Vue 全栈异步任务实战
vue.js·python·fastapi
2401_885885042 小时前
营销推广短信接口集成:结合营销策略实现的API接口动态变量填充方案
前端·python
telllong3 小时前
Python异步编程从入门到不懵:asyncio实战踩坑7连发
开发语言·python
lulu12165440785 小时前
Claude Code Harness架构技术深度解析:生产级AI Agent工程化实践
java·人工智能·python·ai编程
碧海银沙音频科技研究院5 小时前
1-1杰理蓝牙SOC的UI配置开发方法
人工智能·深度学习·算法
7年前端辞职转AI6 小时前
Python 文件操作
python·编程语言
龙文浩_7 小时前
AI梯度下降与PyTorch张量操作技术指南
人工智能·pytorch·python·深度学习·神经网络·机器学习·自然语言处理
呱牛do it7 小时前
企业级绩效考核系统设计与实现:基于FastAPI + Vue3的全栈解决方案
python·fastapi
清空mega7 小时前
动手学深度学习——样式迁移
人工智能·深度学习