Lottery Ticket Hypothesis(LTH)是由 Frankle 和 Carbin 在 2019 年提出的一种剪枝方法,其核心思想是神经网络中存在可以单独训练的小型子网络(即"中奖票"),这些子网络可以在保持原始模型性能的情况下有效地训练。通过找到这些子网络,我们可以实现大模型的剪枝,从而减少模型的计算复杂度和存储需求。
实现过程
- 初始训练
- 对于一个大型神经网络,首先对其进行完全训练,得到一个经过充分训练的基准模型。
- 在此阶段,所有权重都将参与训练,并且模型逐渐逼近最优状态。
- 权重重要性评估与剪枝
- 对训练后的模型,使用权重的重要性度量方法(如权重的绝对值大小)来评估每个权重在模型中的贡献。权重的重要性度量是基于这样一个假设:权重的绝对值越大,其对模型预测的贡献就越大。
- 权重绝对值大小:在神经网络中,权重的绝对值大小可以用来衡量其对输出的影响程度。通常情况下,较大的权重对神经元的激活产生更显著的影响,因此对最终的预测结果也具有更大的贡献。反之,绝对值较小的权重对输出的影响较小,可以被认为是冗余的。
- 具体步骤 :
- 计算权重的绝对值:对于每个神经网络层中的权重,计算其绝对值。
- 排序和选择:根据权重的绝对值大小进行排序,将绝对值较小的权重标记为不重要。
- 剪枝:剪去这些不重要的权重,使模型变得更加稀疏。
- 剪枝后会得到一个稀疏子网络,这个子网络保留了大部分重要的连接,同时大大减少了参数数量。
- 重置权重和再训练
- 将剪枝后的子网络的权重重置为它们在初始随机化时的值。这个步骤的目的是希望子网络能够独立训练,而不是依赖于剪枝前的已训练权重。
- 通过将权重重置为初始状态,可以验证这些被称为"中奖票"的子网络是否具有足够的表达能力,能够单独训练达到与原始大模型相似的性能。
- 迭代剪枝
- 对剪枝后的子网络进行再训练,直到它能够达到与原始模型相近的性能。
- 如果目标是进一步减少模型大小,可以多次进行剪枝和再训练的过程,直到达到所需的压缩比例。
- 每次迭代剪枝都会进一步减少不重要的权重,逐步形成一个稀疏、可训练的小型子网络。
代码实现Lottery Ticket Hypothesis的剪枝全过程
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
from tqdm import tqdm
# 检查是否可以使用GPU(针对MacBook M3芯片)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 定义一个简单的神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 16 * 16, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(x.size(0), -1) # 修改view以确保batch size匹配
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化模型
model = SimpleNet().to(device)
initial_state_dict = copy.deepcopy(model.state_dict()) # 保存初始权重
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 数据预处理和数据加载器
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
trainset = torchvision.datasets.CIFAR10(root='../datasets', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
# 模型初始训练
def train(model, optimizer, criterion, dataloader, epochs=5):
model.train()
for epoch in range(epochs):
running_loss = 0.0
for i, (inputs, labels) in enumerate(dataloader, 0):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
# 在更新之前再次应用掩码,确保剪枝的权重不会被更新
apply_pruning_mask(model)
optimizer.step()
running_loss += loss.item()
if i % 100 == 99: # 每100个批次打印一次损失
print(f'Epoch [{epoch + 1}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
running_loss = 0.0
# 剪枝函数
def prune_by_magnitude(model, amount=0.2):
print("Starting prune_by_magnitude...")
# 计算每个参数的绝对值并排序
all_weights = []
for param in model.parameters():
if len(param.data.size()) != 1: # 忽略偏置项
all_weights.extend(param.cpu().data.abs().numpy().flatten())
threshold = torch.tensor(sorted(all_weights)[int(len(all_weights) * amount)])
# 根据阈值剪枝,并保存掩码
with torch.no_grad():
for name, param in tqdm(list(model.named_parameters()), desc="Applying pruning mask", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'):
if "weight" in name:
mask = (torch.abs(param) > threshold).float().to(device)
param.mul_(mask)
model.register_buffer(f"mask_{name.replace('.', '_')}", mask) # 保存掩码用于冻结权重
# 重置权重函数
def reset_weights(model, initial_state_dict):
state_dict = {k: v for k, v in initial_state_dict.items() if "mask" not in k}
model.load_state_dict(state_dict, strict=False)
# 修改优化器以冻结被剪枝的权重
def apply_pruning_mask(model):
with torch.no_grad():
for name, param in model.named_parameters():
if "weight" in name and hasattr(model, f"mask_{name.replace('.', '_')}"):
mask = getattr(model, f"mask_{name.replace('.', '_')}")
param.mul_(mask) # 确保被剪枝的权重保持为零
if __name__ == "__main__":
# 初始训练
train(model, optimizer, criterion, trainloader, epochs=5)
# 剪枝并重置权重
prune_by_magnitude(model, amount=0.2)
reset_weights(model, initial_state_dict)
# 再次训练前应用剪枝掩码,以确保被剪枝的权重保持为零
apply_pruning_mask(model)
# 再次训练
train(model, optimizer, criterion, trainloader, epochs=5)
# 迭代剪枝和再训练的过程可以继续进行,直到达到所需的压缩比例
优点
- 高效压缩:LTH 方法可以找到一个非常稀疏的子网络,使得模型的计算量和存储需求大幅降低,同时性能基本不受影响。
- 理论支持:LTH 提出了一个关于神经网络可训练性的理论假设,即在大型神经网络中存在一个子网络(中奖票),如果将其权重重置为初始值并独立训练,这个子网络可以达到与原始模型相近的性能。具体来说,LTH 假设在初始随机权重中已经存在一个可以有效训练的稀疏子网络,这个子网络在训练时具备足够的表示能力和学习能力。因此,通过找到这个子网络并重置其权重,可以在保持模型性能的前提下减少不必要的参数,从而实现模型的压缩。
- 适用于多种架构:这种方法可以应用于不同类型的神经网络架构,包括卷积神经网络(CNN)和 Transformer 等。
缺点
- 计算开销大:LTH 方法需要多次反复地训练、剪枝和重置权重,因此训练过程相对耗时且计算资源需求较高。
- 剪枝策略依赖于初始权重:剪枝后的模型性能与初始权重的选择关系密切,存在一定的随机性,这可能导致最终的剪枝效果不稳定。
应用场景
- 移动设备和嵌入式系统:LTH 可以用于在内存和计算能力有限的设备上部署深度学习模型,例如移动设备、边缘计算设备等,通过找到稀疏子网络来实现模型压缩和加速。
- 加速推理:对于需要实时推理的应用,剪枝后的稀疏子网络可以减少计算量,从而加速模型推理。
相关文献
- Frankle, J., & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR 2019.