Model Compression Trifecta in Deep Learning: Hands-On ResNet18 Pruning (3)

Introduction

The previous posts, "Model Compression Trifecta in Deep Learning: Hands-On ResNet18 Pruning (1)" and "(2)", explained and demonstrated the pruning procedure, but only applied it to a single layer. This post documents extending it to all of the residual layers in ResNet18. As mentioned before, the first layer is not pruned; the reasoning is covered in the earlier posts and revisited below.


1. Overview of the ResNet18 Architecture

ResNet18 is a classic lightweight residual network. Its core design uses residual blocks (BasicBlock) to counter the vanishing-gradient problem in deep networks. The full structure is shown below for the CIFAR-10-adjusted variant (note: in the complete script at the end of this post, the conv1 replacement is commented out, so the runs reported here actually used the standard 7×7/stride-2 stem):

| Layer | Type | Input size | Output size | Key parameters | Role |
|-------|------|------------|-------------|----------------|------|
| conv1 | Convolution | 3×32×32 | 64×32×32 | kernel=3, stride=1, pad=1 | Initial feature extraction |
| bn1 | BatchNorm | 64×32×32 | 64×32×32 | num_features=64 | Normalization, faster training |
| relu | Activation | 64×32×32 | 64×32×32 | - | Non-linearity |
| maxpool | Max pooling | 64×32×32 | 64×16×16 | kernel=3, stride=2, pad=1 | Spatial downsampling |
| layer1 | Residual group (2 BasicBlocks) | 64×16×16 | 64×16×16 | two 3×3 convs per block | Shallow feature refinement |
| layer2 | Residual group (2 BasicBlocks) | 64×16×16 | 128×8×8 | first block downsamples (stride=2) | Channel expansion + downsampling |
| layer3 | Residual group (2 BasicBlocks) | 128×8×8 | 256×4×4 | first block downsamples (stride=2) | Deeper feature abstraction |
| layer4 | Residual group (2 BasicBlocks) | 256×4×4 | 512×2×2 | first block downsamples (stride=2) | High-level semantic features |
| avgpool | Global average pooling | 512×2×2 | 512×1×1 | - | Collapses spatial dims to 1×1 |
| fc | Fully connected | 512 | 10 | in_features=512, out=10 | Classification output |

Note: this post prunes the residual blocks in layer1 through layer4 (8 BasicBlocks in total) and skips the global conv1 layer.
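If you want to reproduce this table yourself, torchinfo (already used in the full script below) prints per-layer output shapes. A minimal sketch, assuming the make_resnet18_cifar10 helper from the full script; keep in mind that with the conv1 replacement commented out, the printed spatial sizes will follow the standard stem rather than this table:

```python
from torchinfo import summary

model = make_resnet18_cifar10()  # helper defined in the full script below

# "output_size" adds per-layer output shapes next to the parameter counts
summary(model, input_size=(1, 3, 32, 32), col_names=("output_size", "num_params"))
```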


2. Pruning Strategy: Skip the First Layer, Prune the Residual Blocks

2.1 Why skip the first layer?

ResNet's first convolution (conv1) receives the raw input directly (a 3×32×32 image), and its weights extract basic features such as edges and textures. Pruning this layer risks breaking the feature alignment between the input and every subsequent layer, which can cause a large accuracy drop. The strategy in this post is therefore: keep the global conv1 and prune only the convolution layers inside the subsequent residual blocks.

2.2 Residual block pruning logic

Each residual block (BasicBlock) contains two 3×3 convolution layers (conv1 and conv2) plus the corresponding bn1 layer. The pruning targets are:

  • prune the output channels of the block's first convolution (conv1) by L1 norm;
  • update the second convolution's (conv2) input channels in sync, to match conv1's output channels;
  • adjust bn1's num_features and statistics (running_mean/running_var) to the new channel count (see the block printout below).
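For reference, printing one block of a stock torchvision resnet18 shows exactly which layers are coupled:

```python
from torchvision.models import resnet18

block = resnet18().layer1[0]
print(block)
# BasicBlock(
#   (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (relu): ReLU(inplace=True)
#   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# )
```

Because conv1 feeds bn1 and then conv2, all three must be resized together; conv2's output channels and bn2 stay untouched, so the residual addition still sees the channel count it expects.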

3. Code Walkthrough

3.1 Core pruning function: prune_resnet_block

This function prunes a single residual block. The key steps (excerpt):

```python
# /media/a/data4t/DL/model_prune/test_new.py (excerpt)
def prune_resnet_block(block, percent_to_prune):
    # Prune the first conv layer (block.conv1)
    conv1 = block.conv1
    mask1 = prune_conv_layer(conv1, percent_to_prune)  # boolean mask of channels to keep

    if mask1 is not None:
        # 1. Rebuild conv1, keeping only the masked output channels
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),  # channel count after pruning
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None
        )
        new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]  # slice weights by mask

        # 2. Rebuild conv2: its input channels must match conv1's new output
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # key step: input channels pruned in sync
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None
        )
        new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]  # slice the input-channel axis

        # 3. Rebuild bn1 so num_features matches the pruned channel count
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            new_bn1.weight.data = bn1.weight.data[mask1]    # affine scale
            new_bn1.bias.data = bn1.bias.data[mask1]        # affine shift
            new_bn1.running_mean = bn1.running_mean[mask1]  # sync running statistics
            new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1

        # Swap the rebuilt layers into the block
        block.conv1, block.conv2 = new_conv1, new_conv2
    return mask1
```

Key logic notes

  • prune_conv_layer scores each output channel by its L1 norm (np.sum(np.abs(weights), axis=(1, 2, 3))) and keeps the top (1 − percent) fraction of channels (a toy sketch follows this list);
  • mask1 is a boolean mask (True = keep), so sum(mask1) is the pruned channel count;
  • conv2's weights are sliced with [:, mask1, :, :], keeping its input channels aligned with conv1's output;
  • bn1's num_features, weight, bias, running_mean and running_var are all sliced by mask1, avoiding dimension-mismatch errors (such as the running_mean length mismatch hit earlier in this series).
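As a sanity check of the masking math, here is a minimal, self-contained sketch on a toy 8-channel layer (a made-up example, not part of this article's pipeline):

```python
import numpy as np
import torch.nn as nn

conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
weights = conv.weight.data.cpu().numpy()           # shape (8, 4, 3, 3)

l1_norm = np.sum(np.abs(weights), axis=(1, 2, 3))  # one importance score per output channel
num_prune = int(0.25 * len(l1_norm))               # prune 25% -> 2 channels
keep_indices = np.argsort(l1_norm)[num_prune:]     # drop the 2 smallest-norm channels

mask = np.zeros(len(l1_norm), dtype=bool)
mask[keep_indices] = True
print(mask.sum())                                  # 6 channels kept
print(weights[mask].shape)                         # (6, 4, 3, 3)
```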

3.2 Global pruning control: the prune_model function

This function walks over all of ResNet18's residual blocks, skipping the global conv1 and handling only the BasicBlocks in layer1 through layer4:

```python
# /media/a/data4t/DL/model_prune/test_new.py (excerpt)
def prune_model(model, pruning_percent):
    # Walk over all residual blocks (the global conv1 is a plain nn.Conv2d, so it is skipped)
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))  # collect every BasicBlock

    # Prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    return model
```

Key point: filtering with isinstance(module, BasicBlock) selects exactly the residual blocks, so only the target layers are pruned.
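You can verify which modules the isinstance filter catches — exactly the eight residual blocks, never the stem conv1:

```python
import torchvision
from torchvision.models import resnet18

model = resnet18()
names = [name for name, m in model.named_modules()
         if isinstance(m, torchvision.models.resnet.BasicBlock)]
print(names)
# ['layer1.0', 'layer1.1', 'layer2.0', 'layer2.1',
#  'layer3.0', 'layer3.1', 'layer4.0', 'layer4.1']
```

Note that the first blocks of layer2 through layer4 also contain a stride-2 downsample branch; pruning remains safe there because only conv1's intermediate channels are touched, while the downsample path maps the block input to the block output.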


4. Experimental Validation and Results

4.1 Model structure before vs. after pruning

The print_model_shapes function prints the key layer parameters before and after pruning (using the layer1.0 block as an example):

| Layer | Before pruning | After pruning (20%) | Change |
|-------|----------------|---------------------|--------|
| layer1.0.conv1 | in=64, out=64 | in=64, out=52 | int(0.2×64) = 12 output channels removed |
| layer1.0.bn1 | num_features=64 | num_features=52 | kept in sync with conv1's output |
| layer1.0.conv2 | in=64, out=64 | in=52, out=64 | input channels match conv1's output |
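A quick shape check (a minimal sketch; it assumes prune_model and its helpers from the full script below are already defined):

```python
import copy
from torchvision.models import resnet18

model = resnet18()
pruned = prune_model(copy.deepcopy(model), 0.2)  # prune_model from the full script below

blk = pruned.layer1[0]
print(blk.conv1.out_channels, blk.bn1.num_features, blk.conv2.in_channels)
# Expected: 52 52 52 — int(0.2 * 64) = 12 of 64 channels removed; conv2.out_channels stays 64
```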

4.2 Parameter count and accuracy

  • Parameter count: the original model has 11,181,642 parameters; pruning reduces this to 8,996,114, a 19.5% reduction (full torchinfo summaries below);
```text
Original model parameters:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
Pruned model parameters:
==========================================================================================
Total params: 8,996,114
Trainable params: 8,996,114
Non-trainable params: 0
Total mult-adds (M): 30.35
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.76
Params size (MB): 35.98
Estimated Total Size (MB): 36.76
==========================================================================================
```
  • Accuracy: the model reached 71.92% after the initial fine-tuning, and 82.05% after pruning plus a second fine-tuning (20 epochs before pruning, 15 after).
  • Something feels off here — the pruned model scoring higher is most plausibly a fine-tuning artifact: it resumes from the already-trained checkpoint and gets 15 extra epochs at a lower learning rate, so it has simply trained longer overall (the two configs are shown below). If anyone has a better explanation, please let me know!
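For reference, the two fine-tuning setups in the script differ in more than just the model:

```python
# Run 1 (before pruning): 20 epochs starting from ImageNet-pretrained weights
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Run 2 (after pruning): 15 more epochs, resuming from the run-1 best checkpoint
optimizer_pruned = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler_pruned = optim.lr_scheduler.CosineAnnealingLR(optimizer_pruned, T_max=100)
```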

5. Summary and Outlook

No summary this time — here is the complete code.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np
from collections import OrderedDict
import copy
from torchinfo import summary

def make_resnet18_cifar10():
    model = resnet18(pretrained=True)

    # Modify the first conv layer for CIFAR-10's 32x32 images
    # NOTE: left commented out, so the standard 7x7/stride-2 stem is actually used
    #model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    # Replace the final fully connected layer with a 10-class CIFAR-10 head
    num_ftrs = model.fc.in_features
    #model.fc = nn.Linear(num_ftrs, 10)
    model.fc = nn.Linear(512, 10)

    return model

def train(model, trainloader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
    train_loss = running_loss / len(trainloader)
    train_acc = 100. * correct / total
    
    print(f'Train Epoch: {epoch} | Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%')
    return train_loss, train_acc

def test(model, testloader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    test_loss /= len(testloader)
    test_acc = 100. * correct / total
    
    print(f'Test set: Average loss: {test_loss:.4f} | Acc: {test_acc:.2f}%\n')
    return test_loss, test_acc

def print_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

def prune_conv_layer(conv, percent_to_prune):
    weights = conv.weight.data.cpu().numpy()

    # L1 norm as the importance score, summed over axes (1, 2, 3): one score per output channel
    l1_norm = np.sum(np.abs(weights), axis=(1, 2, 3))

    # Number of channels to prune
    num_prune = int(percent_to_prune * len(l1_norm))

    if num_prune > 0:
        print(f"🔍 Pruning {conv} output channels from {conv.out_channels} → {conv.out_channels - num_prune}")
        # Indices of the channels to keep (the (1 - percent) fraction with the largest L1 norms)
        keep_indices = np.argsort(l1_norm)[num_prune:]
        mask = np.zeros(len(l1_norm), dtype=bool)
        mask[keep_indices] = True  # True = keep

        return mask
    return None

def prune_resnet_block(block, percent_to_prune):
    # Prune the first conv layer
    conv1 = block.conv1
    print(f"Before pruning, conv1 out_channels: {conv1.out_channels}")
    mask1 = prune_conv_layer(conv1, percent_to_prune)

    if mask1 is not None:
        print(f"After pruning, mask1 sum: {sum(mask1)}")  # printed inside the guard, since mask1 may be None

        # Rebuild the first conv layer with the pruned output channels
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None
        )

        # Copy the surviving weights
        with torch.no_grad():
            new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]
            if conv1.bias is not None:
                new_conv1.bias.data = conv1.bias.data[mask1]

        # Rebuild the second conv layer with its input channels pruned in sync
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # pruned channel count as the new input width
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None
        )

        # Copy weights, slicing the input-channel axis by the mask
        with torch.no_grad():
            new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]
            if conv2.bias is not None:
                new_conv2.bias.data = conv2.bias.data

        # Swap the rebuilt layers into the block
        block.conv1 = new_conv1
        block.conv2 = new_conv2

        # Rebuild the BatchNorm layer between the two convs
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            with torch.no_grad():
                new_bn1.weight.data = bn1.weight.data[mask1]
                new_bn1.bias.data = bn1.bias.data[mask1]
                new_bn1.running_mean = bn1.running_mean[mask1]
                new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1

        # Print the updated channel counts
        print(f"After pruning, new_conv1 out_channels: {new_conv1.out_channels}")
        print(f"After pruning, new_conv2 in_channels: {new_conv2.in_channels}")

        return mask1
    return None

def prune_model(model, pruning_percent):
    # Collect all residual blocks (the global conv1 is not a BasicBlock, so it is skipped)
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))
    
    # Prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    
    return model




def fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, epochs):
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
        test_loss, test_acc = test(model, testloader, criterion)
        
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')

        scheduler.step()
    
    print(f'Best test accuracy: {best_acc:.2f}%')
    return best_acc

def print_model_shapes(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            print(f"{name}: in_channels={module.in_channels}, out_channels={module.out_channels}")
        elif isinstance(module, nn.BatchNorm2d):
            print(f"{name}: num_features={module.num_features}")


if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Data preprocessing
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Load the datasets
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the model
    model = make_resnet18_cifar10()
    model = model.to(device)

    # Initial training (fine-tuning)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    print("Starting initial training (fine-tuning)...")
    best_acc = fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, 20)

    # Load the best checkpoint
    model.load_state_dict(torch.load('best_model.pth'))

    # 打印原始模型大小
    print("\nOriginal model size:")
    print_model_size(model)
    print("\n原始模型结构:")
    summary(model, input_size=(1, 3, 32, 32))
    # Create a deep copy of the model for pruning
    pruned_model = copy.deepcopy(model)
    # Run pruning
    pruning_percent = 0.2  # uniform pruning ratio for all blocks

    pruned_model = prune_model(pruned_model, pruning_percent)
    summary(pruned_model, input_size=(1, 3, 32, 32))
    # Called after pruning completes
    print("\nPruned model shapes:")
    print_model_shapes(pruned_model)
    # Print the pruned model size
    print("\nPruned model size:")
    print_model_size(pruned_model)
    # Define a new optimizer (a smaller learning rate is usually needed after pruning)
    optimizer_pruned = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler_pruned = optim.lr_scheduler.CosineAnnealingLR(optimizer_pruned, T_max=100)

    print("Starting fine-tuning after pruning...")
    best_pruned_acc = fine_tune_model(pruned_model, trainloader, testloader, criterion, optimizer_pruned, scheduler_pruned, 15)

    # Compare the original and pruned models
    print("\nResults Comparison:")
    print(f"Original model accuracy: {best_acc:.2f}%")
    print(f"Pruned model accuracy: {best_pruned_acc:.2f}%")
    print(f"Accuracy drop: {best_acc - best_pruned_acc:.2f}%")