从零学习大模型(十一)-----Lottery Ticket Hypothesis剪枝

Lottery Ticket Hypothesis(LTH)是由 Frankle 和 Carbin 在 2019 年提出的一种剪枝方法,其核心思想是神经网络中存在可以单独训练的小型子网络(即"中奖票"),这些子网络可以在保持原始模型性能的情况下有效地训练。通过找到这些子网络,我们可以实现大模型的剪枝,从而减少模型的计算复杂度和存储需求。

实现过程

  1. 初始训练
    • 对于一个大型神经网络,首先对其进行完全训练,得到一个经过充分训练的基准模型。
    • 在此阶段,所有权重都将参与训练,并且模型逐渐逼近最优状态。
  2. 权重重要性评估与剪枝
    • 对训练后的模型,使用权重的重要性度量方法(如权重的绝对值大小)来评估每个权重在模型中的贡献。权重的重要性度量是基于这样一个假设:权重的绝对值越大,其对模型预测的贡献就越大。
    • 权重绝对值大小:在神经网络中,权重的绝对值大小可以用来衡量其对输出的影响程度。通常情况下,较大的权重对神经元的激活产生更显著的影响,因此对最终的预测结果也具有更大的贡献。反之,绝对值较小的权重对输出的影响较小,可以被认为是冗余的。
    • 具体步骤
      1. 计算权重的绝对值:对于每个神经网络层中的权重,计算其绝对值。
      2. 排序和选择:根据权重的绝对值大小进行排序,将绝对值较小的权重标记为不重要。
      3. 剪枝:剪去这些不重要的权重,使模型变得更加稀疏。
    • 剪枝后会得到一个稀疏子网络,这个子网络保留了大部分重要的连接,同时大大减少了参数数量。
  3. 重置权重和再训练
    • 将剪枝后的子网络的权重重置为它们在初始随机化时的值。这个步骤的目的是希望子网络能够独立训练,而不是依赖于剪枝前的已训练权重。
    • 通过将权重重置为初始状态,可以验证这些被称为"中奖票"的子网络是否具有足够的表达能力,能够单独训练达到与原始大模型相似的性能。
  4. 迭代剪枝
    • 对剪枝后的子网络进行再训练,直到它能够达到与原始模型相近的性能。
    • 如果目标是进一步减少模型大小,可以多次进行剪枝和再训练的过程,直到达到所需的压缩比例。
    • 每次迭代剪枝都会进一步减少不重要的权重,逐步形成一个稀疏、可训练的小型子网络。

代码实现Lottery Ticket Hypothesis的剪枝全过程

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
from tqdm import tqdm

# 检查是否可以使用GPU(针对MacBook M3芯片)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # 修改view以确保batch size匹配
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型
model = SimpleNet().to(device)
initial_state_dict = copy.deepcopy(model.state_dict())  # 保存初始权重

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 数据预处理和数据加载器
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='../datasets', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

# 模型初始训练
def train(model, optimizer, criterion, dataloader, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(dataloader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # 在更新之前再次应用掩码,确保剪枝的权重不会被更新
            apply_pruning_mask(model)
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:  # 每100个批次打印一次损失
                print(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
                running_loss = 0.0

# 剪枝函数
def prune_by_magnitude(model, amount=0.2):
    print("Starting prune_by_magnitude...")
    # 计算每个参数的绝对值并排序
    all_weights = []
    for param in model.parameters():
        if len(param.data.size()) != 1:  # 忽略偏置项
            all_weights.extend(param.cpu().data.abs().numpy().flatten())
    threshold = torch.tensor(sorted(all_weights)[int(len(all_weights) * amount)])

    # 根据阈值剪枝,并保存掩码
    with torch.no_grad():
        for name, param in tqdm(list(model.named_parameters()), desc="Applying pruning mask", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'):
            if "weight" in name:
                mask = (torch.abs(param) > threshold).float().to(device)
                param.mul_(mask)
                model.register_buffer(f"mask_{name.replace('.', '_')}", mask)  # 保存掩码用于冻结权重

# 重置权重函数
def reset_weights(model, initial_state_dict):
    state_dict = {k: v for k, v in initial_state_dict.items() if "mask" not in k}
    model.load_state_dict(state_dict, strict=False)

# 修改优化器以冻结被剪枝的权重
def apply_pruning_mask(model):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name and hasattr(model, f"mask_{name.replace('.', '_')}"):
                mask = getattr(model, f"mask_{name.replace('.', '_')}")
                param.mul_(mask)  # 确保被剪枝的权重保持为零

if __name__ == "__main__":
    # 初始训练
    train(model, optimizer, criterion, trainloader, epochs=5)

    # 剪枝并重置权重
    prune_by_magnitude(model, amount=0.2)
    reset_weights(model, initial_state_dict)

    # 再次训练前应用剪枝掩码,以确保被剪枝的权重保持为零
    apply_pruning_mask(model)

    # 再次训练
    train(model, optimizer, criterion, trainloader, epochs=5)

    # 迭代剪枝和再训练的过程可以继续进行,直到达到所需的压缩比例

优点

  1. 高效压缩:LTH 方法可以找到一个非常稀疏的子网络,使得模型的计算量和存储需求大幅降低,同时性能基本不受影响。
  2. 理论支持:LTH 提出了一个关于神经网络可训练性的理论假设,即在大型神经网络中存在一个子网络(中奖票),如果将其权重重置为初始值并独立训练,这个子网络可以达到与原始模型相近的性能。具体来说,LTH 假设在初始随机权重中已经存在一个可以有效训练的稀疏子网络,这个子网络在训练时具备足够的表示能力和学习能力。因此,通过找到这个子网络并重置其权重,可以在保持模型性能的前提下减少不必要的参数,从而实现模型的压缩。
  3. 适用于多种架构:这种方法可以应用于不同类型的神经网络架构,包括卷积神经网络(CNN)和 Transformer 等。

缺点

  1. 计算开销大:LTH 方法需要多次反复地训练、剪枝和重置权重,因此训练过程相对耗时且计算资源需求较高。
  2. 剪枝策略依赖于初始权重:剪枝后的模型性能与初始权重的选择关系密切,存在一定的随机性,这可能导致最终的剪枝效果不稳定。

应用场景

  • 移动设备和嵌入式系统:LTH 可以用于在内存和计算能力有限的设备上部署深度学习模型,例如移动设备、边缘计算设备等,通过找到稀疏子网络来实现模型压缩和加速。
  • 加速推理:对于需要实时推理的应用,剪枝后的稀疏子网络可以减少计算量,从而加速模型推理。

相关文献

  • Frankle, J., & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR 2019.
相关推荐
芥末虾1 分钟前
【优选算法】KMP模式匹配算法 {算法介绍;算法原理:核心原理,如何求next数组;代码实现}
c语言·c++·算法·kmp·字符串模式匹配
曾几何时`6 分钟前
对撞双指针(七)三数之和
数据结构·算法·leetcode
薔薇十字7 分钟前
【代码随想录day36】【C++复健】1049. 最后一块石头的重量 II ; 494. 目标和 ;474.一和零
c++·算法·leetcode
一只小透明啊啊啊啊9 分钟前
【代码随想录】哈希
算法·哈希算法
m0_694938019 分钟前
Leetcode打卡:最小区间
算法·leetcode·职场和发展
学习前端的小z14 分钟前
【GPTs】Front-end Expert:助力前端开发的智能工具
人工智能·gpt·chatgpt·aigc
陈奕迅本讯17 分钟前
人力资源项目学习
java·学习
2401_8784673219 分钟前
大连环保公益管理系统|Java|SSM|Vue| 前后端分离
java·开发语言·学习·tomcat·maven
青椒大仙KI1122 分钟前
24/11/24 视觉笔记 滤镜
笔记·深度学习·计算机视觉
使者大牙27 分钟前
【LLM学习笔记】第四篇:模型压缩方法——量化、剪枝、蒸馏、分解
人工智能·深度学习·算法·机器学习