【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(下,结构化剪枝)

文章目录

      • [0. 前言](#0. 前言)
      • [1. 结构化剪枝 vs 非结构化剪枝](#1. 结构化剪枝 vs 非结构化剪枝)
        • [1.1 非结构化剪枝的特征](#1.1 非结构化剪枝的特征)
        • [1.2 结构化剪枝](#1.2 结构化剪枝)
        • [1.3 结构化剪枝的好处:](#1.3 结构化剪枝的好处:)
      • [2. `torch.nn.utils.prune`中的结构化剪枝方法](#2. torch.nn.utils.prune中的结构化剪枝方法)
      • [3. PyTorch实例](#3. PyTorch实例)
        • [3.1 `random_structured`](#3.1 random_structured)
        • [3.2 `prune.ln_structured`](#3.2 prune.ln_structured)
      • [4. 总结](#4. 总结)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在前文:【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)中介绍了PyTorch中的prune模型剪枝模块中的非结构化剪枝。本文将通过实例说明utils.prune中的结构化剪枝方法。

1. 结构化剪枝 vs 非结构化剪枝

1.1 非结构化剪枝的特征

非结构化剪枝是指在神经网络的权重矩阵中随机地移除一些权重,而不考虑这些权重在矩阵中的位置或它们是否构成某种结构(如一个完整的通道或过滤器)。这种剪枝方式通常会导致权重矩阵变得非常稀疏,但同时也破坏了权重的原有结构,这可能会对模型的并行计算效率产生负面影响,因为现代硬件(如GPU)在处理密集矩阵时更有效率。

1.2 结构化剪枝

相比之下,结构化剪枝则是在保持(部分)网络层的结构完整性的同时进行剪枝,即它会移除整个的神经元、通道、过滤器或其他结构单元,而不是某个完整结构中的单个权重。例如,在卷积层中,结构化剪枝可能涉及移除整个过滤器或输入/输出通道,而在全连接层中,则可能移除整行或整列的权重。

1.3 结构化剪枝的好处:
  1. **硬件友好性:**由于结构化剪枝保持了权重矩阵的结构,因此它更易于在现代硬件上实现高效的并行计算,不会像非结构化剪枝那样引入大量零元素,导致计算效率下降。
  2. **加速推理:**结构化剪枝通过移除整个的结构单元,可以直接减少模型的计算量和内存占用,从而显著加速推理过程。
  3. **易于部署:**结构化剪枝后的模型仍然保持原有的结构,这使得模型更容易被优化过的推理引擎(如TensorRT)所支持,便于在边缘设备或移动设备上部署。
  4. **更好的可解释性:**移除某些结构单元有时可以帮助理解哪些特征或信息对于模型的决策是不重要的,从而提高了模型的可解释性。

2. torch.nn.utils.prune中的结构化剪枝方法

本文将介绍2种结构化剪枝方法:

  • prune.random_structured: 随机结构化剪枝,按照给定维度移除随机通道。
  • prune.ln_structured: Ln范数结构化剪枝,沿着给定维度移除具有最低n范数的通道。

3. PyTorch实例

首先建立一个简单的模型:

python 复制代码
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(888)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # 创建一个简单的卷积层
        self.conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)

model = SimpleModel()

通过print(model.conv.weight)可以打印出权重为:

python 复制代码
Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]],


        [[[ 0.0799,  0.1399, -0.0084],
          [ 0.2013, -0.0352, -0.1027],
          [-0.1724, -0.3094, -0.2382]]],


        [[[ 0.0419,  0.2224, -0.1558],
          [ 0.2084,  0.0543,  0.0647],
          [ 0.1493,  0.2011,  0.0310]]]], requires_grad=True)
3.1 random_structured

这个方法会在指定维度dim(默认为-1)上剪枝一个随机通道:

python 复制代码
prune.random_structured(model.conv, name="weight", amount=0.33)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.0000],
          [ 0.2107,  0.1799,  0.0000],
          [ 0.1887, -0.0527,  0.0000]]],


        [[[ 0.0799,  0.1399, -0.0000],
          [ 0.2013, -0.0352, -0.0000],
          [-0.1724, -0.3094, -0.0000]]],


        [[[ 0.0419,  0.2224, -0.0000],
          [ 0.2084,  0.0543,  0.0000],
          [ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

由于权重的维度为[3, 1, 3, 3],我们也可以试试在其他维度(dim=0dim=2)上进行剪枝:

  • dim=0
python 复制代码
prune.random_structured(model.conv, name="weight", amount=0.33,dim=0)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]],


        [[[ 0.0799,  0.1399, -0.0084],
          [ 0.2013, -0.0352, -0.1027],
          [-0.1724, -0.3094, -0.2382]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python 复制代码
prune.random_structured(model.conv, name="weight", amount=0.33,dim=2)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[ 0.0799,  0.1399, -0.0084],
          [ 0.2013, -0.0352, -0.1027],
          [-0.0000, -0.0000, -0.0000]]],


        [[[ 0.0419,  0.2224, -0.1558],
          [ 0.2084,  0.0543,  0.0647],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
3.2 prune.ln_structured

这个方法会在指定维度dim(默认为-1)上按最小n范数剪枝一个通道:

python 复制代码
prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=-1)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.0000],
          [ 0.2107,  0.1799,  0.0000],
          [ 0.1887, -0.0527,  0.0000]]],


        [[[ 0.0799,  0.1399, -0.0000],
          [ 0.2013, -0.0352, -0.0000],
          [-0.1724, -0.3094, -0.0000]]],


        [[[ 0.0419,  0.2224, -0.0000],
          [ 0.2084,  0.0543,  0.0000],
          [ 0.1493,  0.2011,  0.0000]]]], grad_fn=<MulBackward0>)

更改dim也是同样的效果:

  • dim=0
python 复制代码
prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=0)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]],


        [[[ 0.0799,  0.1399, -0.0084],
          [ 0.2013, -0.0352, -0.1027],
          [-0.1724, -0.3094, -0.2382]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], grad_fn=<MulBackward0>)
  • dim=2
python 复制代码
prune.ln_structured(model.conv, name="weight",amount=0.33,n=1,dim=2)
print(model.conv.weight)

输出为:

python 复制代码
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.1887, -0.0527,  0.1403]]],


        [[[ 0.0799,  0.1399, -0.0084],
          [ 0.0000, -0.0000, -0.0000],
          [-0.1724, -0.3094, -0.2382]]],


        [[[ 0.0419,  0.2224, -0.1558],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.1493,  0.2011,  0.0310]]]], grad_fn=<MulBackward0>)

4. 总结

至此,prune中的非结构化剪枝和结构化剪枝介绍完毕!

相关推荐
仙人掌_lz几秒前
深度理解用于多智能体强化学习的单调价值函数分解QMIX算法:基于python从零实现
python·算法·强化学习·rl·价值函数
小白学大数据6 分钟前
Python+Selenium爬虫:豆瓣登录反反爬策略解析
分布式·爬虫·python·selenium
未来之窗软件服务8 分钟前
人体肢体渲染-一步几个脚印从头设计数字生命——仙盟创梦IDE
开发语言·ide·人工智能·python·pygame·仙盟创梦ide
戌崂石15 分钟前
最优化方法Python计算:有约束优化应用——线性不可分问题支持向量机
python·机器学习·支持向量机·最优化方法
Echo``15 分钟前
40:相机与镜头选型
开发语言·人工智能·深度学习·计算机视觉·视觉检测
玉笥寻珍19 分钟前
Web安全渗透测试基础知识之内存动态分配异常篇
网络·python·安全·web安全·网络安全
Christo323 分钟前
关于在深度聚类中Representation Collapse现象
人工智能·深度学习·算法·机器学习·数据挖掘·embedding·聚类
Channing Lewis25 分钟前
如何判断一个网站后端是用什么语言写的
前端·数据库·python
noravinsc32 分钟前
InforSuite RDS 与django结合
后端·python·django
依然易冷43 分钟前
Manus AI 原理深度解析第三篇:Tools
人工智能·深度学习·机器学习