【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上)

文章目录

      • [0. 前言](#0. 前言)
      • [1. 剪枝`prune`主要功能分类](#1. 剪枝prune主要功能分类)
      • [2. `torch.nn.utils.prune`中的方法介绍](#2. torch.nn.utils.prune中的方法介绍)
      • [3. PyTorch实例](#3. PyTorch实例)
      • [4. 总结](#4. 总结)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

PyTorch中的torch.nn.utils.prune模块是一个专门用于神经网络模型剪枝的工具集。模型剪枝是一种减少神经网络参数数量的技术,其目标是在保持模型性能的同时减少计算成本内存占用。这对于部署模型到资源受限的设备(如移动设备或嵌入式系统)特别有用。

本文将通过实例介绍torch.nn.utils.prune模块中的各个方法,由于内容较多分为上、下两篇。

1. 剪枝prune主要功能分类

torch.nn.utils.prune模块提供了一系列的剪枝方法,包括但不限于:

  1. 无结构剪枝 :这种剪枝方法可以独立地移除网络中的权重,而不考虑权重之间的结构关系。例如,L1UnstructuredRandomUnstructured 就是两种无结构剪枝方法,它们分别根据权重的绝对值大小和随机选择的方式移除权重。

  2. 结构化剪枝 :与无结构剪枝相反,结构化剪枝会移除整个的结构单位(如整个神经元或通道),而不是单独的权重。RandomStructuredLnStructured 就是这样的例子,它们可以移除整个的通道。

  3. 自定义剪枝CustomFromMask 方法允许用户自定义剪枝策略,通过提供一个掩码来指定哪些权重应该被保留或移除。

  4. 剪枝管理 :除了剪枝方法本身,torch.nn.utils.prune还提供了工具来管理和应用剪枝,例如,prune.global_unstructuredprune.remove 方法。前者允许跨多个层执行全局剪枝,而后者则用于移除剪枝操作,恢复原始权重或应用剪枝掩码。

2. torch.nn.utils.prune中的方法介绍

下面是本文将介绍的torch.nn.utils.prune中的方法:

  • BasePruningMethod: 抽象基类,用于创建新的剪枝类。
  • PruningContainer:允许组合多种不同的剪枝策略,并按顺序应用这些策略。
  • Identity: 实现了一个不剪枝任何单元仅生成一个全为一的掩码的实用剪枝方法。
  • RandomUnstructured: 随机剪枝张量中的单元。
  • L1Unstructured: 根据L1范数(绝对值)剪枝张量中的单元。

3. PyTorch实例

为了介绍这些剪枝方法,我们将首先定义一个简单的模型,并使用torch.nn.utils.prune模块中的各种剪枝方法来处理这个模型的权重。我们将以一个简单的卷积层为例,然后应用上述提到的每种剪枝方法。

首先,让我们导入必要的库并定义一个包含单个卷积层的模型:

python 复制代码
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(888)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # 创建一个简单的卷积层
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

model = SimpleModel()

这里有一个值得注意的地方就是prune的导入:如果不写from torch.nn.utils import prune,而直接在代码中使用torch.nn.utils.prune.xxxx(),会报错↓

这个报错我不太能理解,不知道会不会在后续版本中更正。

接下来,我们将逐一介绍并应用每种剪枝方法:

3.1 BasePruningMethod

这是一个抽象类,可以理解为自定义剪枝的类。

python 复制代码
class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """
3.2PruningContainer

一开始我觉得这个方法和nn.Sequential差不多,但是实际并不是!

PruningContainer通常不会直接由用户实例化,而是作为torch.nn.utils.prune中其他剪枝方法的基础。当调用如L1UnstructuredRandomUnstructuredLnStructured等剪枝方法时,内部会创建一个PruningContainer实例,并且将特定的剪枝方法添加到容器中。

3.3 Identity

这个方法不会剪枝(改变)任何权重,它只会生成一个全为1的掩码。

python 复制代码
print("Weight before Identity pruning:")
print(model.conv.weight)
prune.identity(model.conv, name="weight")
print("Weight after Identity pruning:")
print(model.conv.weight)
print("mask:")
print(model.conv.weight_mask)

输出为:

python 复制代码
Weight before Identity pruning:
Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], requires_grad=True)
Weight after Identity pruning:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], grad_fn=<MulBackward0>)
mask:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
3.4RandomUnstructured

这个方法会随机选择权重值进行剪枝:

python 复制代码
prune.random_unstructured(model.conv, name="weight", amount=0.5) 
#amount参数指定的是要被剪枝(即置零)的权重比例。
print("Weight after RandomUnstructured pruning (50%):")
print(model.conv.weight)

输出为:

python 复制代码
Weight after RandomUnstructured pruning (50%):
tensor([[[[-0.0000,  0.0000, -0.0000],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.0000]]]], grad_fn=<MulBackward0>)

可以明显看出,对比3.3节的输出结果,有4个(50%)参数被剪枝(置零)了。

3.5L1Unstructured

这个方法会根据权重的L1范数选择要剪枝的权重。

python 复制代码
prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):")
print(model.conv.weight)

输出为:

python 复制代码
Weight after L1Unstructured pruning (50%):
tensor([[[[-0.3017,  0.0000, -0.2468],
          [ 0.2107,  0.0000,  0.1923],
          [ 0.1887, -0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

Process finished with exit code 0

对比3.3输出的结果,可以看出L1范数(绝对值)最小的4个(50%)参数被剪枝(置零)了。

4. 总结

本文介绍了PyTorch中的prune模型剪枝模块的作用及部分方法的用法,后续将把所有方法的用法补齐!

相关推荐
Watermelo617几秒前
通过MongoDB Atlas 实现语义搜索与 RAG——迈向AI的搜索机制
人工智能·深度学习·神经网络·mongodb·机器学习·自然语言处理·数据挖掘
敲代码不忘补水7 分钟前
生成式GPT商品推荐:精准满足用户需求
开发语言·python·gpt·产品运营·产品经理
孤客网络科技工作室12 分钟前
Python Plotly 库使用教程
python·信息可视化·plotly
悟解了12 分钟前
《数据可视化技术》上机报告
python·信息可视化·数据分析
AI算法-图哥12 分钟前
pytorch量化训练
人工智能·pytorch·深度学习·文生图·模型压缩·量化
机器学习之心16 分钟前
时序预测 | 改进图卷积+informer时间序列预测,pytorch架构
人工智能·pytorch·python·时间序列预测·informer·改进图卷积
糊涂君-Q37 分钟前
Python小白学习教程从入门到入坑------第三十一课 迭代器(语法进阶)
python·学习·程序人生·考研·职场和发展·学习方法·改行学it
天飓43 分钟前
基于OpenCV的自制Python访客识别程序
人工智能·python·opencv
取个名字真难呐1 小时前
矩阵乘法实现获取第i行,第j列值,矩阵大小不变
python·线性代数·矩阵·numpy
技术仔QAQ1 小时前
【tokenization分词】WordPiece, Byte-Pair Encoding(BPE), Byte-level BPE(BBPE)的原理和代码
人工智能·python·gpt·语言模型·自然语言处理·开源·nlp