yolov5及yolov7实战之剪枝

之前有讲过一次yolov5的剪枝:yolov5实战之模型剪枝_yolov5模型剪枝-CSDN博客

当时基于的是比较老的yolov5版本,剪枝对整个训练代码的改动也比较多。最近发现一个比较好用的剪枝库,可以在不怎么改动原有训练代码的情况下,实现剪枝的操作,这篇文章就简单介绍一下,剪枝的概念以及为什么要剪枝可以参看上一篇,这里就不赘述了。

Torch-Pruning

VainF/Torch-Pruning: [CVPR 2023] Towards Any Structural Pruning; LLMs / Diffusion / Transformers / YOLOv8 / CNNs (github.com)

今天我们要用到的就是这个剪枝库,这个库集成了很多剪枝的方法,毕竟使用比较简单。

用法

这个剪枝库既有low level的剪枝,也就是手动控制剪枝哪些层,也有high level的剪枝,就是使用预设的剪枝算法,自动选择剪枝的部分。对于我们来说,更适合使用high level剪枝。具体的这里使用和上一篇yolov5里面的剪枝一样的算法,在这个库里叫BNScalePruner。

安装

首先我们需要安装上面提到的库,有两种方式来安装:

pip install torch-pruning

或源码安装(当碰到bug发布版本没修复,源码修复的时候):

pip install git+https://github.com/VainF/Torch-Pruning.git

稀疏化训练

为了更好的剪枝,我们在训练剪枝前的网络时,推荐开启稀疏化训练,利用这个库,我们可以很方便的实现这个操作。

首先在我们的训练代码中定义好剪枝器, 这里的opt.prune是我自己加的来控制是否开启稀疏化训练的标志:

Python 复制代码
# prune
if opt.prune:
	examle_input = torch.randn(1, 3, imgsz, imgsz).to(device)
	imp = tp.importance.BNScaleImportance()
	pruner = tp.pruner.BNScalePruner(model, examle_input, imp,
									 reg=0.0001)

稀疏化训练主要需要设置reg参数,一般设置0.001~1e-6之间。

定义好剪枝器后,在训练代码的scaler.scale(loss).backward()之后,添加如下代码:

Python 复制代码
if opt.prune:
	pruner.regularize(model)

即可实现稀疏化训练。

剪枝

稀疏化训练后(也可以不做稀疏化训练),我们就可以进行剪枝操作了。这个库可以在训练中交互式进行多次剪枝,简单起见,我们这里分离剪枝和训练的代码,只进行剪枝操作。

Python 复制代码
import torch_pruning as tp
from models.experimental import attempt_load
import torch

weights = "yolov7.pt"
model = attempt_load(weights, map_location=torch.device('cuda:0'), fuse=False)
for p in model.parameters():
    p.requires_grad = True
ignored_layers = []
from models.yolo import Detect, IDetect
from models.common import ImplicitA, ImplicitM
for m in model.modules():
    if isinstance(m, (Detect,IDetect)):
        ignored_layers.append(m.m)
unwrapped_parameters = []
for name, m in model.named_parameters():
    if isinstance(m, (ImplicitA,ImplicitM,)):
        unwrapped_parameters.append((name,1)) # pruning 1st dimension of implicit matrix

print(ignored_layers)
example_inputs = torch.rand(1, 3, 416, 416, device='cuda:0')
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(model, example_inputs, imp,
                                   ignored_layers=ignored_layers,
                                   unwrapped_parameters=unwrapped_parameters,
                                   global_pruning=True,
                                   ch_sparsity=0.3,
                                   round_to=8,
                                   )

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
pruned_model = pruner.model
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
print(f"macs: {base_macs} -> {pruned_macs}")
print(f"nparams: {base_nparams} -> {pruned_nparams}")
macs_cutoff_ratio = (base_macs - pruned_macs) / base_macs
nparams_cutoff_ratio = (base_nparams - pruned_nparams) / base_nparams
print(f"macs cutoff ratio: {macs_cutoff_ratio}")
print(f"nparams cutoff ratio: {nparams_cutoff_ratio}")
save_path = weights.replace(".pt", "_pruned_bn_0.3.pt")

torch.save({"model": pruned_model.module if hasattr(pruned_model, 'module') else pruned_model}, save_path)

去掉一些计算剪枝比例的,保存代码等代码外,剪枝操作其实由pruner.step()这一步完成。这里我们主要需要设置的参数是:

  • ch_sparsity: 可以理解成剪枝的比例,越大剪得越多
  • global_pruning: True表示整个模型的权重按一个整体排序后剪枝,False表示按分组内部按比例剪枝
  • round_to: 剪枝后的通道保留为多少的倍数,一般在硬件上,保留8的倍数

微调

经过剪枝的网络,精度是下降比较明显的,需要再在数据上finetune一些epoch才能把精度拉回来。

yolov7默认是通过yaml文件创建模型结构,然后再载入权重进行训练的,而我们剪枝后的模型是没有模型结构文件的,因此需要对训练代码做一定的修改,具体而言,只是对模型的载入进行一点修改。其中opt.finetune是用来控制是否处于finetune模式的标志位。

Python 复制代码
if opt.finetune: # for model without cfg
	new = torch.load(weights, map_location=device)  # create
	model = new["model"]
	print("Finetune Mode...")
elif pretrained:
...

比较简单的改法是这样,从checkpoint中载入结构和权重,还有一种方式则是修改yolov7的Model类,这个在后面讲yolov7剪枝后蒸馏的时候再讲,暂时用上面这种方式就可以了。

评测

我在自己的任务上的效果是yolov7剪枝50%,微调后基本上能达到剪枝前的map,没记错的话这是和稀疏化训练的比,毕竟开启稀疏化训练本身也会掉点。大家可以在自己的任务上尝试一下,总体上精度还是可以的

结语

这篇文章简述了以下yolov7的剪枝,yolov5也可用,希望对大家有帮助。

相关推荐
夏末秋也凉2 分钟前
力扣-数组-704 二分查找
算法·leetcode
玛丽亚后3 分钟前
动态规划(路径问题)
算法·动态规划
qy发大财5 分钟前
平衡二叉树(力扣110)
数据结构·算法·leetcode·职场和发展
AI技术控18 分钟前
计算机视觉算法实战——无人机检测
算法·计算机视觉·无人机
励志去大厂的菜鸟32 分钟前
系统相关类——java.lang.Math (三)(案例详细拆解小白友好)
java·服务器·开发语言·深度学习·学习方法
日日行不惧千万里33 分钟前
如何用YOLOv8训练一个识别安全帽的模型?
python·yolo
siy23331 小时前
【c语言日寄】Vs调试——新手向
c语言·开发语言·学习·算法
liuhui2441 小时前
Pytorch深度学习指南 卷I --编程基础(A Beginner‘s Guide) 第1章 一个简单的回归
pytorch·深度学习·回归
睡不着还睡不醒2 小时前
【深度学习】神经网络实战分类与回归任务
深度学习·神经网络·分类
知识鱼丸2 小时前
machine learning knn算法之使用KNN对鸢尾花数据集进行分类
算法·机器学习·分类