深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)

一、背景:为什么需要模型剪枝?

随着深度学习的发展,模型参数量和计算量呈指数级增长。以ResNet18为例,其在ImageNet上的参数量约为1100万,虽然在服务器端运行流畅,但在移动端或嵌入式设备上部署时,内存和计算资源的限制使得直接使用大模型变得困难。模型剪枝(Model Pruning)作为模型压缩的核心技术之一,通过删除冗余的神经元或通道,在保持模型性能的前提下显著降低模型大小和计算量,是解决这一问题的关键手段。

在前面一篇文章我们也提到了模型压缩的一些基本定义和核心原理:《深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏》。

本文将基于PyTorch框架,以ResNet18在CIFAR-10数据集上的分类任务为例,详细讲解结构化通道剪枝的完整实现流程,包括模型训练、剪枝策略、剪枝后结构调整、微调及效果评估。

二、整体流程概览

本文代码的核心流程可总结为以下6步:

  1. 环境初始化与数据集加载
  2. 原始模型训练与评估
  3. 卷积层结构化剪枝(以conv1层为例)
  4. 剪枝后模型结构调整(BN层、残差下采样层等)
  5. 剪枝模型微调
  6. 剪枝前后模型效果对比
    特地说明:在这里选择conv1层作为例子,不是因为选择这个就会效果更好。

三、关键步骤代码解析

3.1 环境初始化与数据集准备

首先需要配置计算设备(GPU/CPU),并加载CIFAR-10数据集。CIFAR-10包含10类32x32的彩色图像,训练集5万张,测试集1万张。

python 复制代码
def setup_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1,1]
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    return train_dataset, test_dataset

3.2 原始模型训练

使用预训练的ResNet18模型,修改全连接层输出为10类(匹配CIFAR-10的类别数),并进行5轮训练:

python 复制代码
def create_model(device):
    model = models.resnet18(pretrained=True)  # 加载ImageNet预训练权重
    model.fc = nn.Linear(512, 10)  # 修改输出层为10类
    return model.to(device)

def train_model(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
    return model

3.3 结构化通道剪枝核心实现

本文重点是对卷积层进行结构化剪枝(按通道剪枝),具体步骤如下:

3.3.1 计算通道重要性

通过计算卷积核的L2范数评估通道重要性。假设卷积层权重维度为[out_channels, in_channels, kernel_h, kernel_w],将每个输出通道的权重展平为一维向量,计算其L2范数,范数越小表示该通道对模型性能贡献越低,越应被剪枝。

python 复制代码
layer = dict(model.named_modules())[layer_name]  # 获取目标卷积层
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # 计算每个输出通道的L2范数
3.3.2 生成剪枝掩码

根据剪枝比例(如20%),选择范数最小的通道生成掩码:

python 复制代码
num_channels = weight.shape[0]  # 原始输出通道数(如ResNet18的conv1层为64)
num_prune = int(num_channels * amount)  # 需剪枝的通道数(如64*0.2=12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # 找到最不重要的12个通道

mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # 掩码:保留的通道标记为True(52个),剪枝的标记为False(12个)
3.3.3 替换卷积层

创建新的卷积层,仅保留掩码为True的通道:

python 复制代码
new_conv = nn.Conv2d(
    in_channels=layer.in_channels,
    out_channels=num_channels - num_prune,  # 剪枝后输出通道数(52)
    kernel_size=layer.kernel_size,
    stride=layer.stride,
    padding=layer.padding,
    bias=layer.bias is not None
).to(device)  # 移动到模型所在设备

new_conv.weight.data = layer.weight.data[mask]  # 保留掩码为True的通道权重
if layer.bias is not None:
    new_conv.bias.data = layer.bias.data[mask]  # 偏置同理
3.3.4 关键:剪枝后结构调整

直接剪枝会导致后续层(如BN层、残差连接中的下采样层)的输入/输出通道不匹配,必须同步调整:

(1) 调整BN层

卷积层后通常接BN层,BN的num_features需与卷积输出通道数一致:

python 复制代码
if 'conv1' in layer_name:
    bn1 = model.bn1
    new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # 新BN层通道数52
    with torch.no_grad():
        # 同步原始BN层的参数(仅保留未被剪枝的通道)
        new_bn1.weight.data = bn1.weight.data[mask].clone()
        new_bn1.bias.data = bn1.bias.data[mask].clone()
        new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
        new_bn1.running_var.data = bn1.running_var.data[mask].clone()
    model.bn1 = new_bn1

(2) 调整残差下采样层

ResNet的残差块(如layer1.0)中,若主路径的通道数被剪枝,需要通过1x1卷积的下采样层(downsample)匹配 shortcut 的通道数:

python 复制代码
block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:
    # 原始无downsample,创建新的1x1卷积+BN
    downsample_conv = nn.Conv2d(
        in_channels=new_conv.out_channels,  # 52(剪枝后的conv1输出)
        out_channels=block.conv2.out_channels,  # 64(主路径conv2的输出)
        kernel_size=1,
        stride=1,
        bias=False
    ).to(device)
    torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # 初始化权重
    
    downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)
    block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # 添加downsample层
else:
    # 原有downsample层,调整输入通道
    downsample_conv = block.downsample[0]
    downsample_conv.in_channels = new_conv.out_channels  # 输入通道改为52
    downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用掩码筛选

(3) 前向传播验证

调整后需验证模型能否正常前向传播,避免通道不匹配导致的错误:

python 复制代码
with torch.no_grad():
    test_input = torch.randn(1, 3, 32, 32).to(device)  # 测试输入(B, C, H, W)
    try:
        model(test_input)
        print("✅ 前向传播验证通过")
    except Exception as e:
        print(f"❌ 验证失败: {str(e)}")
        raise

3.3的总结,直接上代码

python 复制代码
def prune_conv_layer(model, layer_name, amount=0.2):
    # 获取模型当前所在设备
    device = next(model.parameters()).device  # 新增:获取设备
    
    layer = dict(model.named_modules())[layer_name]
    weight = layer.weight.data
    channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)
    
    num_channels = weight.shape[0]  # 原始通道数(如 64)
    num_prune = int(num_channels * amount)
    _, indices = torch.topk(channel_norm, k=num_prune, largest=False)
    
    mask = torch.ones(num_channels, dtype=torch.bool)
    mask[indices] = False  # 生成剪枝掩码(长度 64,52 个 True)
    
    new_conv = nn.Conv2d(
        in_channels=layer.in_channels,
        out_channels=num_channels - num_prune,  # 剪枝后通道数(如 52)
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=layer.bias is not None
    )
    new_conv = new_conv.to(device)  # 新增:移动到模型所在设备
    
    new_conv.weight.data = layer.weight.data[mask]  # 保留 mask 为 True 的通道
    if layer.bias is not None:
        new_conv.bias.data = layer.bias.data[mask]
    
    # 替换原始卷积层
    parent_name, sep, name = layer_name.rpartition('.')
    parent = model.get_submodule(parent_name)
    setattr(parent, name, new_conv)

    if 'conv1' in layer_name:
        # 1. 更新与 conv1 直接关联的 BN1 层
        bn1 = model.bn1
        new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # 新 BN 层通道数 52
        new_bn1 = new_bn1.to(device)  # 新增:移动到模型所在设备
        with torch.no_grad():
            new_bn1.weight.data = bn1.weight.data[mask].clone()
            new_bn1.bias.data = bn1.bias.data[mask].clone()
            new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
            new_bn1.running_var.data = bn1.running_var.data[mask].clone()
        model.bn1 = new_bn1

        # 2. 处理残差连接中的 downsample(关键修正:添加缺失的 downsample)
        block = model.layer1[0]
        if not hasattr(block, 'downsample') or block.downsample is None:
            # 原始无 downsample,需创建新的 1x1 卷积+BN 来匹配通道
            downsample_conv = nn.Conv2d(
                in_channels=new_conv.out_channels,  # 52
                out_channels=block.conv2.out_channels,  # 64(主路径输出通道数)
                kernel_size=1,
                stride=1,
                bias=False
            )
            downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备
            # 初始化 1x1 卷积权重(这里简单复制原模型可能的统计量,实际可根据需求调整)
            torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')
            
            downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            downsample_bn = downsample_bn.to(device)  # 新增:移动到模型所在设备
            with torch.no_grad():
                # 初始化 BN 参数(可保持默认,或根据原模型统计量调整)
                downsample_bn.weight.fill_(1.0)
                downsample_bn.bias.zero_()
                downsample_bn.running_mean.zero_()
                downsample_bn.running_var.fill_(1.0)
            
            block.downsample = nn.Sequential(downsample_conv, downsample_bn)
            print("✅ 为 layer1.0 添加新的 downsample 层")
        else:
            # 原有 downsample 层,调整输入通道
            downsample_conv = block.downsample[0]
            downsample_conv.in_channels = new_conv.out_channels  # 输入通道调整为 52
            downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用 mask 筛选
            downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备
            
            downsample_bn = block.downsample[1]
            new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            new_downsample_bn = new_downsample_bn.to(device)  # 新增:移动到模型所在设备
            with torch.no_grad():
                new_downsample_bn.weight.data = downsample_bn.weight.data.clone()
                new_downsample_bn.bias.data = downsample_bn.bias.data.clone()
                new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()
                new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()
            block.downsample[1] = new_downsample_bn

        # 3. 同步 layer1.0.conv1 的输入通道(保持原有逻辑)
        next_convs = ['layer1.0.conv1']
        for conv_path in next_convs:
            try:
                conv = model.get_submodule(conv_path)
                if conv.in_channels != new_conv.out_channels:
                    print(f"同步输入通道: {conv.in_channels} → {new_conv.out_channels}")
                    conv.in_channels = new_conv.out_channels
                    conv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())
                    conv = conv.to(device)  # 新增:移动到模型所在设备
            except AttributeError as e:
                print(f"⚠️ 卷积层调整失败: {conv_path} ({str(e)})")

        # 验证前向传播
        with torch.no_grad():
            test_input = torch.randn(1, 3, 32, 32).to(device)  # 确保测试输入也在相同设备
            try:
                model(test_input)
                print("✅ 前向传播验证通过")
            except Exception as e:
                print(f"❌ 验证失败: {str(e)}")
                raise

    return model

3.4 剪枝模型微调

剪枝后模型的部分参数被删除,需要通过微调恢复性能。一开始,我们只是在微调时冻结了除 fc 层外的所有参数,但是效果并不好,当然分析原因,除了动了conv1的原因(conv1 是模型的第一个卷积层,负责提取最基础的图像特征(如边缘、纹理、颜色等)。这些底层特征对后续所有层的特征提取至关重要。 ),最重要的是裁剪后,需要对裁剪的层进行微调,确保参数适应新的特征维度。

微调时冻结了除 fc 层外的所有参数的代码和结果:

python 复制代码
for name, param in pruned_model.named_parameters():
        if 'fc' not in name:
            param.requires_grad = False
    optimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)
    print("微调剪枝后的模型")
    pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device,epochs=5)
bash 复制代码
原始模型准确率: 80.07%
剪枝后模型准确率: 37.80%

可以看到这个相差很大

本文选择解冻被剪枝的层(如conv1bn1)及相关层(如layer1.0.conv1downsample)进行参数更新:

python 复制代码
print("开始微调剪枝后的模型")
for name, param in pruned_model.named_parameters():
    # 仅解冻与剪枝相关的层
    if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
bash 复制代码
原始模型准确率: 78.94%
剪枝后模型准确率:  81.30%

重新微调了裁剪后的层后,结果有了很大改变。

四、实验结果与分析

通过代码中的evaluate_model函数评估剪枝前后的模型准确率:

python 复制代码
def evaluate_model(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    return acc

假设原始模型准确率为88.5%,剪枝20%通道后(模型大小降低约20%),通过微调可恢复至87.2%,验证了剪枝策略的有效性。

五、总结与改进方向

本文实现了基于通道L2范数的结构化剪枝,重点解决了剪枝后模型结构不一致的问题(如BN层、残差下采样层的调整),并通过微调恢复了模型性能。

在这个例子中,仅裁剪 conv1 层的影响

仅裁剪 conv1 层对模型的影响极大,原因如下:

  • 底层特征的重要性 : conv1 输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。
  • 结构连锁反应 : conv1 的输出通道减少会触发 bn1 、 layer1.0.conv1 、 downsample 等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
    实际应用中可从以下方向改进:

模型裁剪通常优先选择 中间层(如ResNet的 layer2 、 layer3 ) ,而非底层或顶层,原因如下:

  • 底层(如 conv1 ) :负责基础特征提取,裁剪后特征损失大,对性能影响显著。
  • 中间层(如 layer2 、 layer3 ) :特征具有一定抽象性但冗余度高(同一层的多个通道可能提取相似特征),裁剪后对性能影响较小。
  • 顶层(如 fc 层) :负责分类决策,参数密度高但冗余度低,裁剪易导致分类能力下降。
相关推荐
dudly13 分钟前
大语言模型评测体系全解析(下篇):工具链、学术前沿与实战策略
人工智能·语言模型
zzlyx9921 分钟前
AI大数据模型如何与thingsboard物联网结合
人工智能·物联网
说私域1 小时前
定制开发开源AI智能名片驱动下的海报工厂S2B2C商城小程序运营策略——基于社群口碑传播与子市场细分的实证研究
人工智能·小程序·开源·零售
HillVue1 小时前
AI,如何重构理解、匹配与决策?
人工智能·重构
skywalk81631 小时前
市面上哪款AI开源软件做ppt最好?
人工智能·powerpoint
小九九的爸爸1 小时前
我是如何让AI帮我还原设计稿的
前端·人工智能·ai编程
小wanga2 小时前
【递归、搜索与回溯】专题三 穷举vs暴搜vs回溯vs剪枝
c++·算法·机器学习·剪枝
hanniuniu132 小时前
网络安全厂商F5推出AI Gateway,化解大模型应用风险
人工智能·web安全·gateway
Iamccc13_2 小时前
智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
人工智能·数据分析·自动化
蹦蹦跳跳真可爱5893 小时前
Python----目标检测(使用YOLO 模型进行线程安全推理和流媒体源)
人工智能·python·yolo·目标检测·目标跟踪