在深度学习的世界里,我们习惯了"大力出奇迹":模型越大,参数越多,效果似乎就越好。但你是否想过,这些庞大的参数中,可能绝大多数都是在"陪跑"?
2019 年 ICLR 的一篇经典论文提出了著名的"彩票假设"(The Lottery Ticket Hypothesis),颠覆了我们对神经网络参数冗余的认知。今天我们就来深入解读这篇论文,看看如何从大模型中找到那张"中奖彩票",并动手实现一个最小 Demo。
1. 论文核心:什么是"彩票假设"?
我们在训练一个巨大的密集网络时,往往能得到很好的效果。但这篇论文告诉我们,这个密集网络内部包含了一些稀疏的子网络(sparse subnetworks) 。
彩票假设的核心定义是:
一个随机初始化的密集神经网络,包含一个子网络,如果将该子网络单独拿出来,并使用原始的初始化权重进行训练,它能够在相似的迭代次数内达到与原始网络相当的测试准确率 。
通俗的理解:
训练大网络就像买了一大把彩票。之所以大网络效果好,是因为参数够多,里面大概率包含了一张"中奖"的彩票(即那个特定的子网络结构 + 特定的初始权重)。其他的参数只是为了保证你买到了这张彩票而已。
2. 核心创新与关键技术
这篇论文不仅仅是一个理论猜想,它提供了一套具体的算法来挖掘这些子网络。
2.1 关键算法:迭代幅度剪枝 (Iterative Magnitude Pruning)
作者发现,简单的一步剪枝往往不够极致。为了找到参数量仅为原模型 10%-20% 的中奖彩票,论文提出了以下流程 :
- 随机初始化 :初始化一个网络,保存其初始权重 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ 0 \theta_0 </math>θ0。
- 训练 :正常训练网络直到收敛,得到参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ j \theta_j </math>θj。
- 剪枝 :移除幅度(绝对值)最小的 <math xmlns="http://www.w3.org/1998/Math/MathML"> p % p\% </math>p% 的参数,生成掩码(Mask) <math xmlns="http://www.w3.org/1998/Math/MathML"> m m </math>m。
- 重置(The Magic Step) :这是最关键的创新点。不要 继续微调,也不要 重新随机初始化,而是将剩余参数的值恢复到步骤 1 中的初始值 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ 0 \theta_0 </math>θ0(即应用 <math xmlns="http://www.w3.org/1998/Math/MathML"> m ⊙ θ 0 m \odot \theta_0 </math>m⊙θ0)。
- 循环:重复上述步骤,直到达到目标的稀疏度。
2.2 初始化的决定性作用
论文通过实验证明了一个反直觉的现象:如果你找到了这个稀疏架构,但给它重新随机初始化(Random Reinitialization),它的训练效果会大打折扣,收敛变慢且精度下降 。
这说明,"中奖彩票"之所以能中奖,不仅仅是因为它的结构 (长得好),更是因为它的初始权重(出身好)。
2.3 针对深层网络的优化
对于像 VGG-19 或 ResNet-18 这样的深层网络,简单的剪枝容易失败。作者引入了学习率预热(Learning Rate Warmup) ,成功在这些深层网络中找到了中奖彩票 。
3. 实际应用场景
虽然我们现在通常还是先训练大模型,但彩票假设为未来提供了巨大的想象空间:
-
端侧设备的高效推理:
找到中奖彩票后,我们可以获得参数量减少 90% 以上的模型。这对于手机、IoT 设备等资源受限的场景至关重要,能显著降低存储需求和推理能耗 。
-
训练加速(稀疏训练):
如果我们能开发出在训练早期就识别出"中奖彩票"的方法,就可以在训练过程中直接优化这个小网络,从而大幅节省昂贵的 GPU 算力 。
-
模型迁移与设计:
在一个任务上发现的"中奖"结构,可能隐含了处理该类数据的最佳归纳偏置(Inductive Bias),可以迁移到相似任务中,指导新的网络架构设计 9999。
4. 最小可运行 Demo (PyTorch)
下面是一个基于 PyTorch 的最小实现,展示了如何训练、剪枝、并回溯权重来寻找"中奖彩票"。
Python
python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import copy
# 1. 定义一个简单的全连接网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 300)
self.fc2 = nn.Linear(300, 100)
self.fc3 = nn.Linear(100, 10)
def forward(self, x):
x = torch.flatten(x, 1)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# 初始化模型
model = SimpleNet()
# === 关键步骤 1: 保存原始初始化权重 (\theta_0) ===
# deepcopy 确保我们存下的是数值副本,而不是引用
initial_state_dict = copy.deepcopy(model.state_dict())
print(f"原始参数量 (fc1): {torch.count_nonzero(model.fc1.weight)}")
# === 关键步骤 2: 模拟训练过程 (\theta_j) ===
# 这里用随机更新模拟训练,实际场景中通过 Loss 反向传播更新
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(10): # 模拟训练 10 步
dummy_input = torch.randn(1, 784)
dummy_target = torch.randn(1, 10)
output = model(dummy_input)
loss = nn.MSELoss()(output, dummy_target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("训练后权重已更新...")
# === 关键步骤 3: 全局幅度剪枝 (Pruning) ===
# 移除 20% 幅度最小的权重
parameters_to_prune = (
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2, # 剪掉 20%
)
# 打印剪枝后的非零参数量
print(f"剪枝后非零参数量 (fc1): {torch.count_nonzero(model.fc1.weight)}")
# 此时 model 中包含了 mask,权重虽然被 mask 了,但数值还是训练后的数值
# === 关键步骤 4: 重置回初始权重 (Reset to \theta_0) ===
# 这是彩票假设的灵魂:保持剪枝的 Mask 不变,但把权重值变回最初的样子
# 获取当前的 mask (PyTorch prune 会把 mask 存在 'weight_mask' 缓冲区中)
masks = {}
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
masks[name + '.weight'] = module.weight_mask
# 移除 PyTorch 的 prune hook,使权重变为普通的 Parameter
for module, name in parameters_to_prune:
prune.remove(module, name)
# 加载原始初始化权重
model.load_state_dict(initial_state_dict)
# 重新应用 Mask
with torch.no_grad():
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 将原始权重乘以我们计算出的 Mask
# 这样我们就得到了:原始的初始化值 + 稀疏的结构
module.weight.data.mul_(masks[name + '.weight'])
print("已重置为初始权重,但在相同位置进行了稀疏化。")
print(f"重置后非零参数量 (fc1): {torch.count_nonzero(model.fc1.weight)}")
print("这就是一张'中奖彩票'!")
总结
"彩票假设"不仅揭示了深度学习中巨大的参数冗余,也为我们指明了方向:优秀的网络并不一定需要那么大,关键在于找到那个对的结构和对的起点。 随着硬件和算法的进步,也许未来我们真的能直接"打印"彩票,而不需要买下整个彩票店。
参考来源:arXiv:1803.03635v5 [cs.LG]