模型压缩——训练后剪枝

1. 引言

前文基于粒度的剪枝中主要是基于一个权重矩阵来介绍不同粒度下的剪枝方法,本文会介绍如何对一个实际的神经网络模型来实施剪枝操作。

剪枝是利用稀疏性来压缩模型的,卷积神经网络(CNN)往往具有较高的参数冗余性,冗余参数被剪除后,往往不会显著影响整体性能。因此,模型剪枝在图像处理领域的应用较为广泛。

相比图像模型来说,语言模型的上下文依赖特性使得模型性能对参数的敏感度较高,某些参数的删除可能会影响到模型的整体表现,所以用稀疏化剪枝在语言模型中的应用相对较少。

本文我们将以一个经典的卷积神经网络LeNet来例,来介绍模型剪枝操作的具体使用。

2. 模型介绍

LeNet 是一种经典的卷积神经网络,由 Yann LeCun 等人于 1998 年提出,主要用于手写数字识别(Minst数据集)。它的网络结构如下:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))  # 14x14x6
        x = self.maxpool(F.relu(self.conv2(x)))  # 5x5x16
        x = x.view(x.size()[0], -1)              # 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

模型结构说明:

  • 卷积层1:输入层接收一个32x32的单通道图像,应用6个5x5的卷积核后,得到一个28x28x6的特征图(每个卷积核都会得到一个28x28的特征图),卷积层有助于捕捉图像中的边缘、纹理等局部特征;
  • 池化1:用ReLu函数对卷积层的输出结果进行激活后,应用一个2x2的最大池化操作,特征图变为14x14x6; 最大池化是一种在小窗口内选择最大值的操作,例如2x2池化就是在2x2的区域内选择一个值最大的元素来代替这个区域,这样就减少特征图的尺寸和参数数量,可以防止过拟合,同时保留最重要的特征。
  • 卷积层2:应用16个5x5的卷积核后,得到一个10x10x16的特征图,进一步提取更高层次的特征和更复杂的模式。
  • 池化2:再次使用ReLu激活,并应用2x2的池化操作,池化后特征图变为5x5x16;
  • 展平:使用view操作将池化后的特征图展平成一维向量,形状变为1x400;
  • 全连接层fc1、fc2、fc3:负责将卷积层提取的局部特征整合起来,形成全局的高级特征;
python 复制代码
import numpy as np
import random

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

def count_parameters(model: nn.Module):
    return sum([param.numel() for param in model.parameters()])

model = LeNet()
print("parameters num:", count_parameters(model))
print("model structure:", model)
parameters num: 44426
model structure:
LeNet(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (fc1): Linear(in_features=256, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )

我们的主要目的是演示剪枝操作,所以就不再做模型的训练操作,直接加载训练好的模型参数。

python 复制代码
model.load_state_dict(torch.load("./checkpoint/model.pt"))

查看此时模型的稀疏度:

python 复制代码
def get_model_sparsity(model: nn.Module) -> float:
    num_nonzeros, num_params = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_params += param.numel()
    return 1 - float(num_nonzeros) / num_params

get_model_sparsity(model)
    0.0

此时还没有剪枝,所以模型的稀疏度为0,表示模型所有参数都是有效的。

3. 加载数据集

Minst数据集包含70000个手写数字(0-9)图像,其中有60000个训练样本和10000个测试样本,每张图像都是一个灰度图像,分辨率为 28x28 像素。

先使用datasets.MNIST类来加载数据集,其中:

  • download=True 参数会检查root指定的目录下是否已经存在 MNIST 数据集文件,如果不存在,它会自动从互联网上下载并解压,如果已经存在,则会直接使用现有的数据集文件。
  • torchvision.transforms 模块用于对数据进行预处理操作,transforms.ToTensor() 用于将 PIL 图片或 NumPy 数组转换为 PyTorch 张量(Tensor)。transforms.Normalize的作用是将张量从 [0, 255]的像素自动归一化到 均值为0.1307,标准差为0.3081的正态分布上。
python 复制代码
from torchvision.transforms import *
from torchvision import datasets

# 设置归一化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# 获取数据集,train=True训练集,=False测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) 
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


    100%|██████████| 9912422/9912422 [00:11<00:00, 860380.13it/s] 


    Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

    
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz

    100%|██████████| 28881/28881 [00:00<00:00, 124555.87it/s]

    Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
    
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


    100%|██████████| 1648877/1648877 [00:03<00:00, 484979.62it/s]


    Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
    
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
    Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


    100%|██████████| 4542/4542 [00:00<00:00, 1517365.89it/s]

    Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

那这个数据集究竟长什么样呢?我们可以用matplotlib库将这些图像的原始内容显示出来,以便直观察看。

python 复制代码
import matplotlib.pyplot as plt

def show_image(dataset, rows=1):
    # 设置可视化的画布大小  
    plt.figure(figsize=(10, 10))  

    # 从数据集中提取rows * 10个样本  
    num_samples = rows * 10
    indices = np.random.choice(len(dataset), num_samples, replace=False)  
    sample_dataset = [dataset[i] for i in indices]
    
    # 显示随机样本  
    for i, (image, label) in enumerate(sample_dataset):    
        # 转换为 numpy 数组并去除通道信息  
        image = image.numpy().squeeze()  # (C, H, W) -> (H, W)  
        plt.subplot(rows, 10, i + 1)  
        plt.imshow(image, cmap='gray')  
        plt.title(f'Label: {label}', fontsize=12)  
        plt.axis('off')  
    
    plt.tight_layout()  
    plt.show() 


show_image(train_dataset, 1)

可以看到,每个图像都与一个标签label(0 到 9 的数字)相关联,表示图像中显示的数字。

4. 剪枝前评估

为了在剪枝前对模型的性能预先有一个了解,我们会测试数据集对AlexNet模型进行一个准确率测试。

先实现一个evaluate评估方法,逻辑大概如下。

  • 用模型model对输入的批量数据inputs作分类预测,得到所有分类可能性的数值logits,并用argmax取可能性最大的值作为预测分类结果outputs。
  • 将预测分类结果outputs和目标分类targets进行比对,统计预测正确的数量num_correct。
  • 最后,计算正确预测数量num_correct与总数量num_samples的比值,得到准确率。
python 复制代码
from tqdm.auto import tqdm

def evaluate(model, dataloader):
    model.eval()
    num_samples, num_correct = 0, 0
    for inputs, targets in tqdm(dataloader, desc="eval"):
        logits = model(inputs)
        outputs = logits.argmax(dim=1)
        num_samples += targets.size(0)
        num_correct += (outputs == targets).sum()

    return (num_correct / num_samples * 100).item()        

将数据集封装为小批量数据加载器,批量大小batch_size设为64,shuffle=True表示对数据集进行随机性打乱顺序。

python 复制代码
from torch.utils.data import DataLoader

# 设置DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

对测试数据集上运行评估方法,得到准确率。

python 复制代码
origin_accuracy = evaluate(model, test_loader)
origin_accuracy
    97.98999786376953

5. 剪枝

这一部分我们将实现一个基于稀疏度的剪枝,例如:稀疏度0.8表示剪掉张量中80%的参数。

在剪枝前,首先需要计算一个掩码,它决定了哪些权重需要被剪枝,哪些权重需要被保留。

python 复制代码
def calc_mask(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)
    # 计算张量中的元素总数和需要置零的个数
    num_elements = tensor.numel()
    num_zeros = round(num_elements * sparsity)
    # 计算每个元素的绝对值,作为重要性度量
    importance = tensor.abs()
    # 根据需要置零的元素数量找到相应分位阀值
    threshold = importance.view(-1).kthvalue(num_zeros).values
    # 计算掩码:将大于阀值的置为1,小于阀值的置为0
    mask = torch.gt(importance, threshold)
    return mask

创建一个剪枝器类来管理剪枝操作,其中:

  • 在构造方法中,完成了模型每层掩码的计算,并缓存在masks变量中;
  • 在prune方法中,在一个新克隆的模型实例上,对每一层权重应用剪枝掩码,得到剪枝后的张量。
python 复制代码
import copy

class SparsePruner:
    def __init__(self, model, sparsity_dict):
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 1:
                masks[name] = calc_mask(param, sparsity_dict[name])
        self.masks = masks

    @torch.no_grad()
    def prune(self, model):
        # to_prune_model = copy.deepcopy(self.origin_model)
        for name, param in model.named_parameters():
            if name in self.masks:
                param.mul_(self.masks[name])
        return model

在sparsity_dict字典中定义每一层的稀疏度,并调用剪枝方法。

python 复制代码
sparsity_dict = {
    'conv1.weight': 0.85,
    'conv2.weight': 0.8,
    'fc1.weight': 0.75,
    'fc2.weight': 0.7,
    'fc3.weight': 0.8,
}

pruner = SparsePruner(model, sparsity_dict)
pruned_model = pruner.prune(copy.deepcopy(model))

print(f"model sparsity after prune: {get_model_sparsity(pruned_model):.4f}")
    model sparsity after prune: 0.7387

可以看到,剪枝后模型的稀疏度为0.7387, 这表示模型中有73.87%的参数都被置为了0,相应的模型大小已经变成原来的26.13%。

下面评估下剪枝操作给模型的准确率带来有多大程度的影响。

python 复制代码
accuracy_after_prune = evaluate(pruned_model, test_loader)
accuracy_after_prune
    66.45999908447266

剪枝后,模型的准确率从97.99% 下降到了 66.46%。

6. 微调

这部分我们将对上面剪枝后的模型进行微调,目的是尽可能将模型性能恢复到接近剪枝前的水平。

首先,写一个训练函数train,功能是在指定的数据集上完成一轮训练。

python 复制代码
def train(model, dataloader, loss_fn, optimizer, pruner):
    model.train()
    for inputs, targets in tqdm(dataloader, desc="train"):
        optimizer.zero_grad()  
        
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()
        
        pruner.prune(model)
  • optimizer.zero_grad():用于在每个小批量迭代前清零梯度;
  • model(inputs):前向传播完成分类预测;
  • loss_fn:计算预测结果outputs与目标结果targets之间的损失;
  • loss.backward(): 损失反向传播计算每层的梯度;
  • optimizer.step(): 根据梯度来更新权重参数值;
  • pruner.prune(model): 对模型参数进行剪枝,始终保证训练期间模型参数的稀疏度;
python 复制代码
num_epochs = 5

optimizer = torch.optim.SGD(pruned_model.parameters(),  lr=0.01, momentum=0.5)
loss_fn = nn.CrossEntropyLoss()  

best_pruned_model_checkpoint = None
best_accuracy = 0
for i in range(num_epochs):
    train(pruned_model, train_loader, loss_fn, optimizer, pruner)
    accuracy = evaluate(pruned_model, test_loader)
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_pruned_model_checkpoint = copy.deepcopy(pruned_model.state_dict())

    print(f"epoch: {i+1}, accuracy: {accuracy:.2f}%, best_accuracy: {best_accuracy:.2f}%")
    

通过微调,我们将剪枝后模型的准确率从66.46%恢复到了98.31%,比剪枝前97.99%还要略高一点,达到了预期目标。

小结:本文以一个比较经典的LeNet卷积神经网络作为开始,介绍了一种训练后剪枝的实施过程,通过先剪枝并评估性能损失,再通过微调来恢复模型性能。这个网络比较简单,所以只进行了一轮剪枝-微调步骤,实际场景中对于参数量大的模型,可能需要用迭代剪枝的方法循环多次剪枝-微调步骤,以便让剪枝的影响和结果更为可控。

参考阅读

相关推荐
盼小辉丶2 小时前
TensorFlow深度学习实战——情感分析模型
深度学习·神经网络·tensorflow
好评笔记2 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云2 小时前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
AI街潜水的八角3 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
叫我:松哥4 小时前
基于Python django的音乐用户偏好分析及可视化系统设计与实现
人工智能·后端·python·mysql·数据分析·django
熊文豪5 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
Vol火山5 小时前
AI引领工业制造智能化革命:机器视觉与时序数据预测的双重驱动
人工智能·制造
tuan_zhang6 小时前
第17章 安全培训筑牢梦想根基
人工智能·安全·工业软件·太空探索·战略欺骗·算法攻坚
Antonio9156 小时前
【opencv】第10章 角点检测
人工智能·opencv·计算机视觉
互联网资讯6 小时前
详解共享WiFi小程序怎么弄!
大数据·运维·网络·人工智能·小程序·生活