从零学习大模型(十一)-----Lottery Ticket Hypothesis剪枝

Lottery Ticket Hypothesis(LTH)是由 Frankle 和 Carbin 在 2019 年提出的一种剪枝方法,其核心思想是神经网络中存在可以单独训练的小型子网络(即"中奖票"),这些子网络可以在保持原始模型性能的情况下有效地训练。通过找到这些子网络,我们可以实现大模型的剪枝,从而减少模型的计算复杂度和存储需求。

实现过程

  1. 初始训练
    • 对于一个大型神经网络,首先对其进行完全训练,得到一个经过充分训练的基准模型。
    • 在此阶段,所有权重都将参与训练,并且模型逐渐逼近最优状态。
  2. 权重重要性评估与剪枝
    • 对训练后的模型,使用权重的重要性度量方法(如权重的绝对值大小)来评估每个权重在模型中的贡献。权重的重要性度量是基于这样一个假设:权重的绝对值越大,其对模型预测的贡献就越大。
    • 权重绝对值大小:在神经网络中,权重的绝对值大小可以用来衡量其对输出的影响程度。通常情况下,较大的权重对神经元的激活产生更显著的影响,因此对最终的预测结果也具有更大的贡献。反之,绝对值较小的权重对输出的影响较小,可以被认为是冗余的。
    • 具体步骤
      1. 计算权重的绝对值:对于每个神经网络层中的权重,计算其绝对值。
      2. 排序和选择:根据权重的绝对值大小进行排序,将绝对值较小的权重标记为不重要。
      3. 剪枝:剪去这些不重要的权重,使模型变得更加稀疏。
    • 剪枝后会得到一个稀疏子网络,这个子网络保留了大部分重要的连接,同时大大减少了参数数量。
  3. 重置权重和再训练
    • 将剪枝后的子网络的权重重置为它们在初始随机化时的值。这个步骤的目的是希望子网络能够独立训练,而不是依赖于剪枝前的已训练权重。
    • 通过将权重重置为初始状态,可以验证这些被称为"中奖票"的子网络是否具有足够的表达能力,能够单独训练达到与原始大模型相似的性能。
  4. 迭代剪枝
    • 对剪枝后的子网络进行再训练,直到它能够达到与原始模型相近的性能。
    • 如果目标是进一步减少模型大小,可以多次进行剪枝和再训练的过程,直到达到所需的压缩比例。
    • 每次迭代剪枝都会进一步减少不重要的权重,逐步形成一个稀疏、可训练的小型子网络。

代码实现Lottery Ticket Hypothesis的剪枝全过程

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
from tqdm import tqdm

# 检查是否可以使用GPU(针对MacBook M3芯片)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)  # 修改view以确保batch size匹配
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型
model = SimpleNet().to(device)
initial_state_dict = copy.deepcopy(model.state_dict())  # 保存初始权重

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 数据预处理和数据加载器
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='../datasets', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

# 模型初始训练
def train(model, optimizer, criterion, dataloader, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(dataloader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # 在更新之前再次应用掩码,确保剪枝的权重不会被更新
            apply_pruning_mask(model)
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:  # 每100个批次打印一次损失
                print(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
                running_loss = 0.0

# 剪枝函数
def prune_by_magnitude(model, amount=0.2):
    print("Starting prune_by_magnitude...")
    # 计算每个参数的绝对值并排序
    all_weights = []
    for param in model.parameters():
        if len(param.data.size()) != 1:  # 忽略偏置项
            all_weights.extend(param.cpu().data.abs().numpy().flatten())
    threshold = torch.tensor(sorted(all_weights)[int(len(all_weights) * amount)])

    # 根据阈值剪枝,并保存掩码
    with torch.no_grad():
        for name, param in tqdm(list(model.named_parameters()), desc="Applying pruning mask", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'):
            if "weight" in name:
                mask = (torch.abs(param) > threshold).float().to(device)
                param.mul_(mask)
                model.register_buffer(f"mask_{name.replace('.', '_')}", mask)  # 保存掩码用于冻结权重

# 重置权重函数
def reset_weights(model, initial_state_dict):
    state_dict = {k: v for k, v in initial_state_dict.items() if "mask" not in k}
    model.load_state_dict(state_dict, strict=False)

# 修改优化器以冻结被剪枝的权重
def apply_pruning_mask(model):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" in name and hasattr(model, f"mask_{name.replace('.', '_')}"):
                mask = getattr(model, f"mask_{name.replace('.', '_')}")
                param.mul_(mask)  # 确保被剪枝的权重保持为零

if __name__ == "__main__":
    # 初始训练
    train(model, optimizer, criterion, trainloader, epochs=5)

    # 剪枝并重置权重
    prune_by_magnitude(model, amount=0.2)
    reset_weights(model, initial_state_dict)

    # 再次训练前应用剪枝掩码,以确保被剪枝的权重保持为零
    apply_pruning_mask(model)

    # 再次训练
    train(model, optimizer, criterion, trainloader, epochs=5)

    # 迭代剪枝和再训练的过程可以继续进行,直到达到所需的压缩比例

优点

  1. 高效压缩:LTH 方法可以找到一个非常稀疏的子网络,使得模型的计算量和存储需求大幅降低,同时性能基本不受影响。
  2. 理论支持:LTH 提出了一个关于神经网络可训练性的理论假设,即在大型神经网络中存在一个子网络(中奖票),如果将其权重重置为初始值并独立训练,这个子网络可以达到与原始模型相近的性能。具体来说,LTH 假设在初始随机权重中已经存在一个可以有效训练的稀疏子网络,这个子网络在训练时具备足够的表示能力和学习能力。因此,通过找到这个子网络并重置其权重,可以在保持模型性能的前提下减少不必要的参数,从而实现模型的压缩。
  3. 适用于多种架构:这种方法可以应用于不同类型的神经网络架构,包括卷积神经网络(CNN)和 Transformer 等。

缺点

  1. 计算开销大:LTH 方法需要多次反复地训练、剪枝和重置权重,因此训练过程相对耗时且计算资源需求较高。
  2. 剪枝策略依赖于初始权重:剪枝后的模型性能与初始权重的选择关系密切,存在一定的随机性,这可能导致最终的剪枝效果不稳定。

应用场景

  • 移动设备和嵌入式系统:LTH 可以用于在内存和计算能力有限的设备上部署深度学习模型,例如移动设备、边缘计算设备等,通过找到稀疏子网络来实现模型压缩和加速。
  • 加速推理:对于需要实时推理的应用,剪枝后的稀疏子网络可以减少计算量,从而加速模型推理。

相关文献

  • Frankle, J., & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR 2019.
相关推荐
神奇夜光杯18 分钟前
Python酷库之旅-第三方库Pandas(181)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
光明中黑暗19 分钟前
Python 学习笔记
笔记·python·学习
聪明的墨菲特i24 分钟前
VUE组件学习 | 六、v-if, v-else-if, v-else组件
前端·vue.js·学习·前端框架·vue
徐小夕@趣谈前端29 分钟前
MaxKB: 一款基于大语言模型的知识库问答系统
人工智能·语言模型·自然语言处理
ya888g31 分钟前
信息学奥赛复赛复习19-CSP-J2023-02公路-贪心算法、向上取整、向下取整
c++·算法
夜雨翦春韭40 分钟前
【代码随想录Day58】图论Part09
java·开发语言·数据结构·算法·leetcode·图论
正义的彬彬侠40 分钟前
近似线性可分支持向量机的原理推导
人工智能·机器学习·支持向量机·svm·近似线性可分支持向量机
正义的彬彬侠42 分钟前
绘制近似线性可分支持向量机的分类边界和支持向量
人工智能·python·机器学习·支持向量机·分类·svm
爱喝白开水a42 分钟前
零基础入门AI:一键本地运行各种开源大语言模型 - Ollama
人工智能·程序人生·语言模型·开源·大语言模型·开源大模型·大模型入门
纪怽ぅ1 小时前
LSTM——长短期记忆神经网络
python·深度学习·神经网络·算法·机器学习·lstm