目录
- 第一部分:剪枝与稀疏化基础理论
- 第一章:绪论------为什么需要模型剪枝
- 第二章:稀疏性的数学基础------范数、测度与信息论
- 第三章:最优脑损伤(OBD)与二阶方法
- 第二部分:非结构化剪枝方法
- 第四章:幅度剪枝与迭代剪枝
- 第五章:SNIP、GraSP 与一次射击剪枝
- 第六章:彩票假说(Lottery Ticket Hypothesis)
- 第三部分:结构化剪枝方法
- 第七章:通道剪枝与滤波器剪枝
- 第八章:注意力头剪枝与层剪枝
- 第九章:大语言模型的结构化剪枝
- 第四部分:稀疏格式与稀疏计算
- 第十章:稀疏矩阵存储格式------CSR、CSC、COO、BSR
- 第十一章:稀疏矩阵乘法与硬件加速
- 第十二章:2:4 结构化稀疏与 NVIDIA Ampere
- 第五部分:完整可运行代码实现
- 第十三章:从零实现非结构化剪枝系统
- 第十四章:从零实现结构化剪枝系统
- 第十五章:从零实现稀疏矩阵运算
- 第十六章:完整剪枝 Pipeline 与精度对比
- 附录
第一部分:剪枝与稀疏化基础理论
第一章:绪论------为什么需要模型剪枝
1.1 模型冗余性的理论基础
1.1.1 过参数化现象
现代深度学习模型普遍采用**过参数化(over-parameterization)**策略------参数量远大于训练样本数。
观察:
| 模型 | 参数量 | 训练样本(ImageNet) | 参数/样本比 |
|---|---|---|---|
| ResNet-50 | 25.6M | 1.28M | 20:1 |
| ViT-Large | 307M | 1.28M | 240:1 |
| GPT-3 | 175B | ~300B tokens | ~0.58:1 |
| LLaMA-70B | 70B | ~1T tokens | ~0.07:1 |
问题:为什么过参数化的模型仍然泛化良好?
定理 1.1(过参数化的隐式正则化) :在梯度下降训练中,过参数化模型倾向于收敛到平坦极小值(flat minima),其 Hessian 矩阵的特征值谱具有大量接近零的特征值。
证明思路:设损失函数为 L(θ)\mathcal{L}(\theta)L(θ),在极小值 θ∗\theta^*θ∗ 处,Hessian H=∇2L(θ∗)H = \nabla^2 \mathcal{L}(\theta^*)H=∇2L(θ∗) 的特征值分解为 H=UΛUTH = U \Lambda U^TH=UΛUT。过参数化意味着 HHH 是低秩的------存在大量特征值 λi≈0\lambda_i \approx 0λi≈0。
这些接近零的特征值对应的方向是冗余方向 ------沿着这些方向移动参数,损失几乎不变。这为剪枝提供了理论基础。□\square□
1.1.2 冗余性的量化
定义 1.1(有效参数数量):模型的有效参数数量定义为:
deff=∑i=1dλiλi+αd_{\text{eff}} = \sum_{i=1}^{d} \frac{\lambda_i}{\lambda_i + \alpha}deff=i=1∑dλi+αλi
其中 λi\lambda_iλi 是 Hessian 的特征值,α\alphaα 是正则化参数。
当 α→0\alpha \to 0α→0 时,deffd_{\text{eff}}deff 趋近于 Hessian 的秩------即真正"有用"的参数数量。
实证观察 :对于典型的大语言模型,deff/dd_{\text{eff}} / ddeff/d 通常在 10%-50% 之间------意味着 50%-90% 的参数是冗余的。
1.2 剪枝的核心动机
剪枝(Pruning) 是移除模型中冗余参数(将其设为零)的过程。
| 效益 | 说明 | 数学刻画 |
|---|---|---|
| 减少存储 | 稀疏矩阵只需要存储非零元素 | 存储 ∝\propto∝ 非零元素数量 |
| 减少计算 | 跳过零元素的乘法 | FLOPs ∝\propto∝ 非零元素数量 |
| 减少内存带宽 | 更少的数据需要从内存读取 | 带宽 ∝\propto∝ 非零元素数量 |
| 可能提高泛化 | 剪枝起到正则化的作用 | 经验观察 |
1.3 剪枝方法的分类
1.3.1 按粒度分类
| 粒度 | 说明 | 稀疏模式 | 硬件友好度 |
|---|---|---|---|
| 非结构化(Unstructured) | 单个权重 | 任意位置 | 低 |
| 模式块(Pattern/Block) | 固定模式(如 2:4) | 规则模式 | 高 |
| 通道(Channel) | 整个通道 | 整行/列 | 高 |
| 滤波器(Filter) | 整个卷积核 | 结构化 | 高 |
| 层(Layer) | 整个层 | 层级别 | 最高 |
1.3.2 按时机分类
| 时机 | 说明 | 代表方法 |
|---|---|---|
| 训练后剪枝(Post-Training) | 训练完成后剪枝 | 幅度剪枝 |
| 训练中剪枝(During Training) | 训练过程中逐步剪枝 | 迭代剪枝、GMP |
| 训练前剪枝(Before Training) | 初始化时就确定稀疏模式 | SNIP、GraSP |
| 一次射击剪枝(One-Shot) | 不需要重新训练 | SparseGPT、Wanda |
1.3.3 按评分标准分类
| 标准 | 说明 | 代表方法 |
|---|---|---|
| 幅度(Magnitude) | $ | w |
| 梯度(Gradient) | $ | \partial \mathcal{L}/\partial w |
| Hessian | 二阶信息 | OBS、OBD |
| 激活感知 | 考虑输入激活的影响 | Wanda |
| 信息论 | Fisher 信息、互信息 | Fisher 剪枝 |
第二章:稀疏性的数学基础------范数、测度与信息论
2.1 稀疏性的形式化
2.1.1 ℓ0\ell_0ℓ0 "范数"
定义 2.1(ℓ0\ell_0ℓ0 "范数") :向量 w∈Rn\mathbf{w} \in \mathbb{R}^nw∈Rn 的 ℓ0\ell_0ℓ0 "范数"定义为非零元素的数量:
∥w∥0=∣{i:wi≠0}∣=∑i=1n1wi≠0\|\mathbf{w}\|0 = |\{i : w_i \neq 0\}| = \sum{i=1}^n \mathbf{1}w_i \\neq 0∥w∥0=∣{i:wi=0}∣=i=1∑n1wi=0
注意 :∥⋅∥0\|\cdot\|_0∥⋅∥0 不是真正的范数(不满足齐次性:∥cw∥0=∥w∥0\|c\mathbf{w}\|_0 = \|\mathbf{w}\|_0∥cw∥0=∥w∥0 对 c≠0c \neq 0c=0)。
稀疏率 :s=1−∥w∥0/ns = 1 - \|\mathbf{w}\|_0 / ns=1−∥w∥0/n,即零元素的比例。
2.1.2 ℓ0\ell_0ℓ0 约束优化
剪枝问题可以形式化为 ℓ0\ell_0ℓ0 约束优化:
minwL(w)s.t.∥w∥0≤k\min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) \quad \text{s.t.} \quad \|\mathbf{w}\|_0 \leq kwminL(w)s.t.∥w∥0≤k
或等价地,拉格朗日形式:
minwL(w)+λ∥w∥0\min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) + \lambda \|\mathbf{w}\|_0wminL(w)+λ∥w∥0
问题 :ℓ0\ell_0ℓ0 范数是非凸、非连续的,直接优化是 NP-hard 的。
定理 2.1(ℓ0\ell_0ℓ0 优化的 NP-hardness) :求解 min∥w∥0≤k∥Aw−b∥2\min_{\|\mathbf{w}\|_0 \leq k} \|\mathbf{A}\mathbf{w} - \mathbf{b}\|^2min∥w∥0≤k∥Aw−b∥2 是 NP-hard 的(可以归约到稀疏回归问题)。
2.1.3 ℓ1\ell_1ℓ1 松弛
ℓ1\ell_1ℓ1 范数 是 ℓ0\ell_0ℓ0 的最紧凸松弛:
∥w∥1=∑i=1n∣wi∣\|\mathbf{w}\|1 = \sum{i=1}^n |w_i|∥w∥1=i=1∑n∣wi∣
定理 2.2(ℓ1\ell_1ℓ1 诱导稀疏性) :在满足受限等距性质(Restricted Isometry Property, RIP) 的条件下,ℓ1\ell_1ℓ1 最小化与 ℓ0\ell_0ℓ0 最小化等价------ℓ1\ell_1ℓ1 正则化自动产生稀疏解。
RIP 条件 :矩阵 A∈Rm×n\mathbf{A} \in \mathbb{R}^{m \times n}A∈Rm×n 满足 kkk-RIP,如果存在 δk∈(0,1)\delta_k \in (0, 1)δk∈(0,1) 使得对所有 kkk-稀疏向量 x\mathbf{x}x:
(1−δk)∥x∥2≤∥Ax∥2≤(1+δk)∥x∥2(1 - \delta_k)\|\mathbf{x}\|^2 \leq \|\mathbf{A}\mathbf{x}\|^2 \leq (1 + \delta_k)\|\mathbf{x}\|^2(1−δk)∥x∥2≤∥Ax∥2≤(1+δk)∥x∥2
定理 2.3(LASSO 的稀疏恢复) :如果 A\mathbf{A}A 满足 2k2k2k-RIP 且 δ2k<2−1\delta_{2k} < \sqrt{2} - 1δ2k<2 −1,则 LASSO 问题
minw12∥Aw−b∥2+λ∥w∥1\min_{\mathbf{w}} \frac{1}{2}\|\mathbf{A}\mathbf{w} - \mathbf{b}\|^2 + \lambda\|\mathbf{w}\|_1wmin21∥Aw−b∥2+λ∥w∥1
的解 w^\hat{\mathbf{w}}w^ 满足:
∥w^−w∗∥2≤C1σk(w∗)1k+C2ϵ\|\hat{\mathbf{w}} - \mathbf{w}^*\|_2 \leq C_1 \frac{\sigma_k(\mathbf{w}^*)_1}{\sqrt{k}} + C_2 \epsilon∥w^−w∗∥2≤C1k σk(w∗)1+C2ϵ
其中 σk(w)1=min∥v∥0≤k∥w−v∥1\sigma_k(\mathbf{w})1 = \min{\|\mathbf{v}\|_0 \leq k} \|\mathbf{w} - \mathbf{v}\|_1σk(w)1=min∥v∥0≤k∥w−v∥1 是最佳 kkk-项逼近误差。□\square□
2.2 剪枝的信息论视角
2.2.1 Fisher 信息与参数重要性
定义 2.2(Fisher 信息矩阵) :模型参数 θ\thetaθ 的 Fisher 信息矩阵定义为:
F(θ)=Ex∼pdata∇θlogp(y∣x;θ)⋅(∇θlogp(y∣x;θ))TF(\theta) = \mathbb{E}{x \sim p{\text{data}}} \left\\nabla_\\theta \\log p(y\|x; \\theta) \\cdot (\\nabla_\\theta \\log p(y\|x; \\theta))\^T\\rightF(θ)=Ex∼pdata∇θlogp(y∣x;θ)⋅(∇θlogp(y∣x;θ))T
Fisher 信息矩阵的对角元素:
Fii=E(∂logp(y∣x;θ)∂θi)2F_{ii} = \mathbb{E}\left\\left(\\frac{\\partial \\log p(y\|x; \\theta)}{\\partial \\theta_i}\\right)\^2\\rightFii=E(∂θi∂logp(y∣x;θ))2
衡量了参数 θi\theta_iθi 对模型输出的"重要性"。
定理 2.4(Fisher 信息与 KL 散度) :在参数 θ\thetaθ 的局部邻域内,模型输出分布的 KL 散度近似为:
DKL(pθ∥pθ+δθ)≈12δθTF(θ)δθD_{\text{KL}}(p_{\theta} \| p_{\theta + \delta\theta}) \approx \frac{1}{2} \delta\theta^T F(\theta) \delta\thetaDKL(pθ∥pθ+δθ)≈21δθTF(θ)δθ
推论 2.1 :移除参数 θi\theta_iθi(设为零)导致的 KL 散度增加约为:
ΔDKL≈12Fiiθi2\Delta D_{\text{KL}} \approx \frac{1}{2} F_{ii} \theta_i^2ΔDKL≈21Fiiθi2
因此,Fiiθi2F_{ii} \theta_i^2Fiiθi2 可以作为参数重要性的度量------同时考虑了参数值的大小(θi2\theta_i^2θi2)和参数对输出的影响(FiiF_{ii}Fii)。
2.2.2 移除参数的信息损失
定理 2.5(剪枝的信息论下界) :设原始模型参数为 θ∈Rn\theta \in \mathbb{R}^nθ∈Rn,剪枝后的参数为 θ^\hat{\theta}θ^(∥θ^∥0=k\|\hat{\theta}\|_0 = k∥θ^∥0=k),则模型输出分布的 KL 散度满足:
DKL(pθ∥pθ^)≥12∑i∈SFii(θi−θ^i)2D_{\text{KL}}(p_\theta \| p_{\hat{\theta}}) \geq \frac{1}{2} \sum_{i \in \mathcal{S}} F_{ii} (\theta_i - \hat{\theta}_i)^2DKL(pθ∥pθ^)≥21i∈S∑Fii(θi−θ^i)2
其中 S\mathcal{S}S 是被剪枝的参数集合。
最优剪枝策略 :选择使信息损失最小的 kkk 个参数保留:
S∗=argmin∣S∣=n−k∑i∈SFiiθi2\mathcal{S}^* = \arg\min_{|\mathcal{S}| = n-k} \sum_{i \in \mathcal{S}} F_{ii} \theta_i^2S∗=arg∣S∣=n−kmini∈S∑Fiiθi2
等价于保留 Fiiθi2F_{ii} \theta_i^2Fiiθi2 最大的 kkk 个参数。
2.3 剪枝与泛化
2.3.1 剪枝作为正则化
定理 2.6(剪枝的隐式正则化效应) :设 θ^pruned\hat{\theta}_{\text{pruned}}θ^pruned 是剪枝后的参数,θ∗\theta^*θ∗ 是未剪枝的最优参数。在一定条件下:
ELtest(θ\^pruned)≤ELtest(θ∗)+O(klognT)\mathbb{E}\\mathcal{L}_{\\text{test}}(\\hat{\\theta}_{\\text{pruned}}) \leq \mathbb{E}\\mathcal{L}_{\\text{test}}(\\theta\^\*) + O\left(\frac{k \log n}{T}\right)ELtest(θ\^pruned)≤ELtest(θ∗)+O(Tklogn)
其中 kkk 是保留的参数数量,TTT 是训练样本数。
直觉:剪枝减少了模型的有效复杂度,从而减少了泛化误差的上界。但过度剪枝会增加近似误差。
2.3.2 剪枝-精度的权衡曲线
定义 2.3(剪枝曲线) :剪枝曲线是稀疏率 sss 与模型精度 Acc(s)\text{Acc}(s)Acc(s) 之间的关系。
经验观察:剪枝曲线通常呈现三个阶段:
- 平坦区 (s<s1s < s_1s<s1):精度几乎不变,冗余参数被移除
- 缓慢下降区 (s1<s<s2s_1 < s < s_2s1<s<s2):精度缓慢下降
- 急剧下降区 (s>s2s > s_2s>s2):精度急剧下降,关键参数被移除
定理 2.7(剪枝曲线的理论形状) :对于线性模型 y^=wTx\hat{y} = \mathbf{w}^T \mathbf{x}y^=wTx,使用最优幅度剪枝,测试 MSE 为:
MSE(s)=MSE(0)+σ2n⋅s1−s⋅∑i∈Pwi2σi2\text{MSE}(s) = \text{MSE}(0) + \frac{\sigma^2}{n} \cdot \frac{s}{1-s} \cdot \sum_{i \in \mathcal{P}} \frac{w_i^2}{\sigma_i^2}MSE(s)=MSE(0)+nσ2⋅1−ss⋅i∈P∑σi2wi2
其中 P\mathcal{P}P 是被剪枝的参数集合,σi2\sigma_i^2σi2 是第 iii 个特征的方差。
第三章:最优脑损伤(OBD)与二阶方法
3.1 最优脑损伤(Optimal Brain Damage)
3.1.1 问题形式化
LeCun et al., 1989 提出的最优脑损伤(OBD) 是第一个系统性的剪枝方法,基于损失函数的二阶泰勒展开。
设模型参数为 θ∈Rn\theta \in \mathbb{R}^nθ∈Rn,损失函数为 L(θ)\mathcal{L}(\theta)L(θ)。在最优解 θ∗\theta^*θ∗ 处的二阶泰勒展开:
L(θ∗+δθ)≈L(θ∗)+12δθTHδθ\mathcal{L}(\theta^* + \delta\theta) \approx \mathcal{L}(\theta^*) + \frac{1}{2} \delta\theta^T H \delta\thetaL(θ∗+δθ)≈L(θ∗)+21δθTHδθ
其中 H=∇2L(θ∗)H = \nabla^2 \mathcal{L}(\theta^*)H=∇2L(θ∗) 是 Hessian 矩阵(一阶项为零,因为在最优解处梯度为零)。
3.1.2 单参数删除的误差
定理 3.1(OBD 的删除准则) :删除参数 θi\theta_iθi(设为零)导致的损失增加为:
δLi=12Hiiθi∗2\delta\mathcal{L}i = \frac{1}{2} H{ii} \theta_i^{*2}δLi=21Hiiθi∗2
证明 :设 δθ=−θi∗ei\delta\theta = -\theta_i^* \mathbf{e}_iδθ=−θi∗ei(即第 iii 个分量变为零),则:
δL=12(−θi∗ei)TH(−θi∗ei)=12θi∗2Hii\delta\mathcal{L} = \frac{1}{2} (-\theta_i^* \mathbf{e}_i)^T H (-\theta_i^* \mathbf{e}i) = \frac{1}{2} \theta_i^{*2} H{ii}δL=21(−θi∗ei)TH(−θi∗ei)=21θi∗2Hii
其中 ei\mathbf{e}_iei 是第 iii 个标准基向量。□\square□
OBD 的剪枝准则 :删除使 δLi=Hiiθi∗2\delta\mathcal{L}i = H{ii} \theta_i^{*2}δLi=Hiiθi∗2 最小的参数。
3.1.3 对角近似
问题 :计算完整 Hessian H∈Rn×nH \in \mathbb{R}^{n \times n}H∈Rn×n 的复杂度为 O(n2)O(n^2)O(n2)(存储)和 O(n3)O(n^3)O(n3)(求逆),对于大模型不可行。
OBD 的近似 :假设 Hessian 是对角 的------即 H≈diag(H11,...,Hnn)H \approx \text{diag}(H_{11}, \dots, H_{nn})H≈diag(H11,...,Hnn)。
这等价于假设各参数的二阶导数之间不相关:
∂2L∂θi∂θj≈0,i≠j\frac{\partial^2 \mathcal{L}}{\partial \theta_i \partial \theta_j} \approx 0, \quad i \neq j∂θi∂θj∂2L≈0,i=j
对角 Hessian 的计算:
Hii=∂2L∂θi2=E(∂logp(y∣x;θ)∂θi)2+E∂2logp(y∣x;θ)∂θi2H_{ii} = \frac{\partial^2 \mathcal{L}}{\partial \theta_i^2} = \mathbb{E}\left\\left(\\frac{\\partial \\log p(y\|x; \\theta)}{\\partial \\theta_i}\\right)\^2\\right + \mathbb{E}\left\\frac{\\partial\^2 \\log p(y\|x; \\theta)}{\\partial \\theta_i\^2}\\rightHii=∂θi2∂2L=E(∂θi∂logp(y∣x;θ))2+E∂θi2∂2logp(y∣x;θ)
第一项是 Fisher 信息的对角元素,第二项在最优解处通常很小(可以忽略)。
推论 3.1:在对角近似下,OBD 的剪枝准则近似为:
δLi≈12Fiiθi∗2\delta\mathcal{L}i \approx \frac{1}{2} F{ii} \theta_i^{*2}δLi≈21Fiiθi∗2
这与 Fisher 信息的参数重要性度量一致。
3.2 最优脑外科(Optimal Brain Surgeon)
3.2.1 OBS 的框架
Hassibi & Stork, 1993 提出的最优脑外科(OBS) 改进了 OBD------删除一个参数后,允许其他参数自由调整以补偿误差。
问题 :删除参数 θi\theta_iθi 后,找到最优的参数调整 δθ\delta\thetaδθ:
minδθδθTHδθs.t.eiT(θ∗+δθ)=0\min_{\delta\theta} \delta\theta^T H \delta\theta \quad \text{s.t.} \quad \mathbf{e}_i^T (\theta^* + \delta\theta) = 0δθminδθTHδθs.t.eiT(θ∗+δθ)=0
约束条件确保 θi+δθi=0\theta_i + \delta\theta_i = 0θi+δθi=0(第 iii 个参数被设为零)。
3.2.2 OBS 的最优解
定理 3.2(OBS 的最优调整):使用拉格朗日乘子法,最优调整为:
δθ∗=−θi∗H−1iiH−1ei\delta\theta^* = -\frac{\theta_i^*}{H\^{-1}_{ii}} H^{-1} \mathbf{e}_iδθ∗=−H−1iiθi∗H−1ei
对应的损失增加为:
δLiOBS=θi∗22H−1ii\delta\mathcal{L}i^{\text{OBS}} = \frac{\theta_i^{*2}}{2H\^{-1}{ii}}δLiOBS=2H−1iiθi∗2
证明:拉格朗日函数为:
L(δθ,λ)=12δθTHδθ+λ(eiTθ∗+eiTδθ)\mathcal{L}(\delta\theta, \lambda) = \frac{1}{2} \delta\theta^T H \delta\theta + \lambda (\mathbf{e}_i^T \theta^* + \mathbf{e}_i^T \delta\theta)L(δθ,λ)=21δθTHδθ+λ(eiTθ∗+eiTδθ)
对 δθ\delta\thetaδθ 求导并令其为零:
Hδθ+λei=0 ⟹ δθ=−λH−1eiH \delta\theta + \lambda \mathbf{e}_i = 0 \implies \delta\theta = -\lambda H^{-1} \mathbf{e}_iHδθ+λei=0⟹δθ=−λH−1ei
代入约束:
eiTθ∗−λeiTH−1ei=0 ⟹ λ=θi∗H−1ii\mathbf{e}_i^T \theta^* - \lambda \mathbf{e}_i^T H^{-1} \mathbf{e}i = 0 \implies \lambda = \frac{\theta_i^*}{H\^{-1}{ii}}eiTθ∗−λeiTH−1ei=0⟹λ=H−1iiθi∗
因此 δθ∗=−θi∗H−1iiH−1ei\delta\theta^* = -\frac{\theta_i^*}{H\^{-1}_{ii}} H^{-1} \mathbf{e}_iδθ∗=−H−1iiθi∗H−1ei。
损失增加:
δL=12(δθ∗)THδθ∗=12θi∗2H−1ii2eiTH−1HH−1ei=θi∗22H−1ii\delta\mathcal{L} = \frac{1}{2} (\delta\theta^*)^T H \delta\theta^* = \frac{1}{2} \frac{\theta_i^{*2}}{H\^{-1}_{ii}^2} \mathbf{e}_i^T H^{-1} H H^{-1} \mathbf{e}i = \frac{\theta_i^{*2}}{2H\^{-1}{ii}}δL=21(δθ∗)THδθ∗=21H−1ii2θi∗2eiTH−1HH−1ei=2H−1iiθi∗2
□\square□
3.2.3 OBS vs OBD
定理 3.3(OBS 优于 OBD) :对于任意参数 θi\theta_iθi:
δLiOBS=θi∗22H−1ii≤12Hiiθi∗2=δLiOBD\delta\mathcal{L}i^{\text{OBS}} = \frac{\theta_i^{*2}}{2H\^{-1}{ii}} \leq \frac{1}{2} H_{ii} \theta_i^{*2} = \delta\mathcal{L}_i^{\text{OBD}}δLiOBS=2H−1iiθi∗2≤21Hiiθi∗2=δLiOBD
等号成立当且仅当 HHH 是对角矩阵。
证明 :对于正定矩阵 HHH,H−1ii≥1/HiiH\^{-1}{ii} \geq 1/H{ii}H−1ii≥1/Hii(由 Cauchy-Schwarz 不等式)。□\square□
直觉:OBS 允许其他参数补偿被删除参数的功能,因此损失更小。OBD 假设其他参数不变,因此损失更大。
3.2.4 OBS 的高效实现
问题 :OBS 需要 H−1H^{-1}H−1,计算复杂度为 O(n3)O(n^3)O(n3)。
高效 OBS(EOBS):使用 Sherman-Morrison 公式进行增量更新。
设当前 Hessian 逆为 H−1H^{-1}H−1,删除参数 θi\theta_iθi 后,更新后的 Hessian(去掉第 iii 行和列)的逆可以通过以下公式计算:
H−1iˉ,iˉnew=H−1iˉ,iˉ−H−1iˉ,iH−1i,iˉH−1iiH\^{-1}{\bar{i},\bar{i}}^{\text{new}} = H\^{-1}{\bar{i},\bar{i}} - \frac{H\^{-1}{\bar{i},i} H\^{-1}{i,\bar{i}}}{H\^{-1}_{ii}}H−1iˉ,iˉnew=H−1iˉ,iˉ−H−1iiH−1iˉ,iH−1i,iˉ
其中 iˉ\bar{i}iˉ 表示除第 iii 个之外的所有索引。
第二部分:非结构化剪枝方法
第四章:幅度剪枝与迭代剪枝
4.1 单次幅度剪枝
4.1.1 基本定义
幅度剪枝(Magnitude Pruning) 是最简单的剪枝方法------删除绝对值最小的权重:
w^i={wiif ∣wi∣>τ0otherwise\hat{w}_i = \begin{cases} w_i & \text{if } |w_i| > \tau \\ 0 & \text{otherwise} \end{cases}w^i={wi0if ∣wi∣>τotherwise
其中阈值 τ\tauτ 由目标稀疏率 sss 决定:τ=Percentile(∣w∣,s×100)\tau = \text{Percentile}(|\mathbf{w}|, s \times 100)τ=Percentile(∣w∣,s×100)。
4.1.2 幅度剪枝的理论基础
定理 4.1(幅度剪枝的最优性条件) :对于线性模型 y^=wTx\hat{y} = \mathbf{w}^T \mathbf{x}y^=wTx,x∼N(0,I)\mathbf{x} \sim \mathcal{N}(0, I)x∼N(0,I),使用对角 Fisher 信息近似,幅度剪枝等价于最小化 Fisher 信息加权的参数重要性:
importance(wi)=Fiiwi2=Exi2⋅wi2=wi2\text{importance}(w_i) = F_{ii} w_i^2 = \mathbb{E}x_i\^2 \cdot w_i^2 = w_i^2importance(wi)=Fiiwi2=Exi2⋅wi2=wi2
因为 Fii=Exi2=1F_{ii} = \mathbb{E}x_i\^2 = 1Fii=Exi2=1(对于标准高斯输入)。
推论 4.1:当输入特征是独立同分布的标准高斯时,幅度剪枝是最优的(在对角 Fisher 近似下)。
但实际中:输入特征通常不是独立同分布的,幅度剪枝可能不是最优的。
4.1.3 幅度剪枝的局限
问题 1:权重大小不等于重要性
某些权重可能很小,但对模型输出有重要影响(例如,通过与其他大权重的交互)。
问题 2:不考虑权重之间的相关性
幅度剪枝独立评估每个权重,忽略了权重之间的协同效应。
问题 3:一次性剪枝精度损失大
高稀疏率下,单次剪枝的精度损失远大于迭代剪枝。
4.2 迭代剪枝(Iterative Pruning)
4.2.1 基本思想
迭代剪枝 (也称为 逐步剪枝 或 贪心剪枝):每次只剪枝一小部分权重,然后微调恢复精度,重复直到达到目标稀疏率。
算法 4.1(迭代幅度剪枝):
输入:预训练模型 θ_0,目标稀疏率 s,每次剪枝比例 p
输出:稀疏模型 θ_sparse
θ = θ_0
当前稀疏率 = 0
while 当前稀疏率 < s:
1. 剪枝:删除 |θ_i| 最小的 p 比例参数
2. 微调:在训练数据上微调剩余参数
3. 更新当前稀疏率
返回 θ
4.2.2 迭代剪枝的理论分析
定理 4.2(迭代剪枝的误差上界) :设每次剪枝比例为 ppp,微调后损失恢复率为 ρ\rhoρ(即微调后损失为剪枝前的 ρ\rhoρ 倍),LLL 次迭代后的总损失为:
LL≤L0⋅∏l=1L(1+(1−ρ)p)\mathcal{L}_L \leq \mathcal{L}0 \cdot \prod{l=1}^{L} (1 + (1-\rho) p)LL≤L0⋅l=1∏L(1+(1−ρ)p)
当 ppp 很小且 ρ\rhoρ 接近 1 时,LL≈L0⋅(1+(1−ρ)Lp)\mathcal{L}_L \approx \mathcal{L}_0 \cdot (1 + (1-\rho) L p)LL≈L0⋅(1+(1−ρ)Lp)------损失随迭代次数线性增长。
与单次剪枝的比较 :单次剪枝率 s=1−(1−p)Ls = 1 - (1-p)^Ls=1−(1−p)L 时,单次剪枝的损失远大于迭代剪枝。
4.2.3 剪枝调度(Pruning Schedule)
问题 :每次剪枝的比例 ppp 如何选择?
常见调度:
| 调度 | 公式 | 说明 |
|---|---|---|
| 线性 | pl=p0p_l = p_0pl=p0 | 每次固定比例 |
| 多项式 | pl=p0⋅(1−l/L)αp_l = p_0 \cdot (1 - l/L)^\alphapl=p0⋅(1−l/L)α | 逐渐减小比例 |
| 指数 | pl=p0⋅γlp_l = p_0 \cdot \gamma^lpl=p0⋅γl | 指数衰减 |
| 余弦 | pl=p0⋅(1+cos(πl/L))/2p_l = p_0 \cdot (1 + \cos(\pi l/L))/2pl=p0⋅(1+cos(πl/L))/2 | 余弦退火 |
定理 4.3(最优剪枝调度):在一定假设下,余弦调度在相同总剪枝率下达到最低的最终损失。
4.3 渐进式幅度剪枝(GMP)
4.3.1 渐进式剪枝
渐进式幅度剪枝(Gradual Magnitude Pruning, GMP) (Zhu & Gupta, 2017)在训练过程中连续地增加稀疏率。
稀疏率调度:
s(t)=sf+(si−sf)(1−t−titf−ti)3s(t) = s_f + (s_i - s_f) \left(1 - \frac{t - t_i}{t_f - t_i}\right)^3s(t)=sf+(si−sf)(1−tf−tit−ti)3
其中:
- sis_isi:初始稀疏率(通常为 0)
- sfs_fsf:最终稀疏率
- tit_iti:开始剪枝的训练步数
- tft_ftf:结束剪枝的训练步数
4.3.2 GMP 的理论基础
定理 4.4(GMP 的收敛性):在平滑性和有界梯度假设下,GMP 训练的模型收敛到损失函数的一个稳定点,其梯度范数满足:
∥∇L(θT)∥2≤O(1T+sfT)\|\nabla \mathcal{L}(\theta_T)\|^2 \leq O\left(\frac{1}{\sqrt{T}} + \frac{s_f}{\sqrt{T}}\right)∥∇L(θT)∥2≤O(T 1+T sf)
直觉:渐进式剪枝允许模型在训练过程中逐步适应稀疏约束,比一次性剪枝更温和。
第五章:SNIP、GraSP 与一次射击剪枝
5.1 SNIP:单次剪枝识别重要连接
5.1.1 核心思想
SNIP (Lee et al., 2019)在训练之前就确定稀疏模式,基于单次前向-反向传播的梯度信息。
连接灵敏度(Connection Sensitivity):
si=∣∂L∂ci⋅ci∣s_i = \left|\frac{\partial \mathcal{L}}{\partial c_i} \cdot c_i\right|si= ∂ci∂L⋅ci
其中 ci∈{0,1}c_i \in \{0, 1\}ci∈{0,1} 是连接掩码(1 表示保留,0 表示删除),L\mathcal{L}L 是损失函数。
直觉 :sis_isi 衡量了如果将连接 iii 从"存在"变为"不存在",损失会变化多少。
5.1.2 SNIP 的数学推导
定理 5.1(SNIP 的梯度-幅度准则):连接灵敏度可以简化为:
si=∣wi⋅gi∣s_i = |w_i \cdot g_i|si=∣wi⋅gi∣
其中 gi=∂L/∂wig_i = \partial \mathcal{L} / \partial w_igi=∂L/∂wi 是权重 wiw_iwi 的梯度。
证明 :设 wic=ci⋅wiw_i^c = c_i \cdot w_iwic=ci⋅wi,则:
∂L∂ci=∂L∂wic⋅∂wic∂ci=gi⋅wi\frac{\partial \mathcal{L}}{\partial c_i} = \frac{\partial \mathcal{L}}{\partial w_i^c} \cdot \frac{\partial w_i^c}{\partial c_i} = g_i \cdot w_i∂ci∂L=∂wic∂L⋅∂ci∂wic=gi⋅wi
因此 si=∣gi⋅wi∣s_i = |g_i \cdot w_i|si=∣gi⋅wi∣。□\square□
SNIP 算法:
1. 采样一批数据
2. 前向传播,计算损失
3. 反向传播,计算梯度 g
4. 计算灵敏度 s_i = |w_i * g_i|
5. 保留 s_i 最大的 k 个连接
6. 使用保留的连接进行正常训练
5.1.3 SNIP 的理论性质
定理 5.2(SNIP 与 Fisher 信息的关系) :连接灵敏度 si=∣wigi∣s_i = |w_i g_i|si=∣wigi∣ 是 Fisher 信息加权参数重要性 Fiiwi2F_{ii} w_i^2Fiiwi2 的单样本无偏估计:
Esi2=Ewi2gi2=wi2Egi2=wi2Fii\mathbb{E}s_i\^2 = \mathbb{E}w_i\^2 g_i\^2 = w_i^2 \mathbb{E}g_i\^2 = w_i^2 F_{ii}Esi2=Ewi2gi2=wi2Egi2=wi2Fii
推论:SNIP 可以看作是基于单样本 Fisher 信息的参数重要性评估。
5.2 GraSP:梯度信号剪枝
5.2.1 核心思想
GraSP (Wang et al., 2020)改进了 SNIP------不仅考虑梯度的大小,还考虑梯度的方向。
定义 5.1(GraSP 评分):
GraSPi=−gi⋅(Hg)i\text{GraSP}_i = -g_i \cdot (\mathbf{H}\mathbf{g})_iGraSPi=−gi⋅(Hg)i
其中 H\mathbf{H}H 是 Hessian 矩阵,g\mathbf{g}g 是梯度向量。
直觉 :GraSP 评分衡量了删除参数 iii 后,梯度信号的变化。如果 GraSPi>0\text{GraSP}_i > 0GraSPi>0,删除该参数会增强 梯度信号(有利于训练);如果 GraSPi<0\text{GraSP}_i < 0GraSPi<0,删除该参数会减弱梯度信号。
5.2.2 GraSP 的理论推导
定理 5.3(GraSP 的最优性):GraSP 评分等价于最小化删除参数后的梯度范数变化:
GraSPi=−∂∥∇L∥2∂ci∣c=1\text{GraSP}i = -\frac{\partial \|\nabla \mathcal{L}\|^2}{\partial c_i}\Bigg|{c=\mathbf{1}}GraSPi=−∂ci∂∥∇L∥2 c=1
证明 :设掩码为 ccc,参数为 wic=ciwiw_i^c = c_i w_iwic=ciwi。梯度范数为:
∥∇L∥2=∑igi2\|\nabla \mathcal{L}\|^2 = \sum_i g_i^2∥∇L∥2=i∑gi2
对 cic_ici 求导(使用链式法则和 Hessian):
∂∥∇L∥2∂ci=2∑jgj∂gj∂ci=2∑jgj∂2L∂wj∂wi⋅wi=2wi(Hg)i\frac{\partial \|\nabla \mathcal{L}\|^2}{\partial c_i} = 2 \sum_j g_j \frac{\partial g_j}{\partial c_i} = 2 \sum_j g_j \frac{\partial^2 \mathcal{L}}{\partial w_j \partial w_i} \cdot w_i = 2 w_i (\mathbf{H}\mathbf{g})_i∂ci∂∥∇L∥2=2j∑gj∂ci∂gj=2j∑gj∂wj∂wi∂2L⋅wi=2wi(Hg)i
因此 GraSPi=−gi(Hg)i\text{GraSP}_i = -g_i (\mathbf{H}\mathbf{g})_iGraSPi=−gi(Hg)i(忽略常数因子)。□\square□
5.2.3 GraSP 的高效计算
问题 :计算 Hg\mathbf{H}\mathbf{g}Hg 需要 Hessian-向量乘积,复杂度为 O(n2)O(n^2)O(n2)。
高效实现 :使用 Pearlmutter 技巧------Hessian-向量乘积可以在一次前向-反向传播中计算:
Hv=∇θ(gTv)\mathbf{H}\mathbf{v} = \nabla_\theta (\mathbf{g}^T \mathbf{v})Hv=∇θ(gTv)
其中 v\mathbf{v}v 是任意向量。
5.3 SparseGPT:大语言模型的一次射击剪枝
5.3.1 核心思想
SparseGPT(Frantar & Alistarh, 2023)将 GPTQ 的思想从量化推广到剪枝------通过 Hessian 信息进行最优权重删除和误差补偿。
问题形式化 :对于权重矩阵 W∈Rm×nW \in \mathbb{R}^{m \times n}W∈Rm×n 和校准输入 XXX:
minW^∥(W−W^)X∥F2s.t.∥W^∥0≤k\min_{\hat{W}} \|(W - \hat{W})X\|_F^2 \quad \text{s.t.} \quad \|\hat{W}\|_0 \leq kW^min∥(W−W^)X∥F2s.t.∥W^∥0≤k
5.3.2 SparseGPT 的算法
算法 5.1(SparseGPT):
输入:权重 W (m×n),校准输入 X,目标稀疏率 s
输出:稀疏权重 W_sparse
1. 计算 Hessian: H = 2/n * X^T X + λI
2. 计算 H^{-1}
3. 对每一行 i = 1, ..., m:
a. 初始化: w = W[i, :], δ = 0
b. 对每一列 j = 1, ..., n:
- 如果 |w[j]| < τ (阈值):
* 删除: w[j] = 0
* 补偿: w[j+1:n] -= (w[j] / H^{-1}[j,j]) * H^{-1}[j, j+1:n]
- 否则:
* 保留: W_sparse[i, j] = w[j]
4. 返回 W_sparse
5.3.3 SparseGPT 与 GPTQ 的统一
定理 5.4(SparseGPT-GPTQ 统一框架):SparseGPT 和 GPTQ 都基于相同的最优脑外科框架,区别在于:
- GPTQ:将权重量化到最近的量化级别,补偿误差
- SparseGPT:将权重设为零(或保留),补偿误差
两者可以统一为:对每个权重,选择一个"目标值"(量化级别或零),然后用 Hessian 信息补偿误差。
5.4 Wanda:权重与激活剪枝
5.4.1 核心思想
Wanda(Sun et al., 2024)提出了一种更简单但同样有效的方法:同时考虑权重的幅度和对应输入激活的幅度。
评分准则:
scoreij=∣Wij∣⋅∥Xj∥2\text{score}{ij} = |W{ij}| \cdot \|X_j\|_2scoreij=∣Wij∣⋅∥Xj∥2
其中 XjX_jXj 是输入激活的第 jjj 个通道(在校准数据上取平均)。
直觉 :权重 WijW_{ij}Wij 的重要性取决于:
- 权重本身的大小 ∣Wij∣|W_{ij}|∣Wij∣
- 对应输入通道的"活跃程度" ∥Xj∥2\|X_j\|_2∥Xj∥2
5.4.2 Wanda 与 OBS 的关系
定理 5.5(Wanda 作为 OBS 的近似):在对角 Hessian 近似下,OBS 的参数重要性为:
importanceijOBS=Hjj⋅Wij2\text{importance}{ij}^{\text{OBS}} = H{jj} \cdot W_{ij}^2importanceijOBS=Hjj⋅Wij2
其中 Hjj=2T∑tXtj2∝∥Xj∥22H_{jj} = \frac{2}{T}\sum_t X_{tj}^2 \propto \|X_j\|_2^2Hjj=T2∑tXtj2∝∥Xj∥22。
因此:
importanceijOBS∝∥Xj∥22⋅Wij2=(∥Xj∥2⋅∣Wij∣)2\text{importance}_{ij}^{\text{OBS}} \propto \|X_j\|2^2 \cdot W{ij}^2 = (\|X_j\|2 \cdot |W{ij}|)^2importanceijOBS∝∥Xj∥22⋅Wij2=(∥Xj∥2⋅∣Wij∣)2
Wanda 的评分 scoreij=∣Wij∣⋅∥Xj∥2\text{score}{ij} = |W{ij}| \cdot \|X_j\|_2scoreij=∣Wij∣⋅∥Xj∥2 正是 OBS 重要性的平方根------单调等价。□\square□
第六章:彩票假说(Lottery Ticket Hypothesis)
6.1 核心陈述
彩票假说(Lottery Ticket Hypothesis, LTH)(Frankle & Carlin, 2019):
猜想 6.1(彩票假说) :对于一个随机初始化的密集网络 f(θ0)f(\theta_0)f(θ0),存在一个子网络 f(θ0⊙m)f(\theta_0 \odot m)f(θ0⊙m)(其中 mmm 是二值掩码),该子网络从相同的初始化开始训练,可以在相同或更少的迭代次数内达到与原始密集网络相当的精度。
称这样的子网络为中奖彩票(winning ticket)。
6.1.1 形式化
设密集网络参数为 θ0∈Rn\theta_0 \in \mathbb{R}^nθ0∈Rn(随机初始化),训练后的参数为 θT\theta_TθT。掩码 m∈{0,1}nm \in \{0, 1\}^nm∈{0,1}n。
中奖彩票 :存在 mmm 使得:
L(train(m⊙θ0,T))≤L(θT)\mathcal{L}(\text{train}(m \odot \theta_0, T)) \leq \mathcal{L}(\theta_T)L(train(m⊙θ0,T))≤L(θT)
其中 train(θ0,T)\text{train}(\theta_0, T)train(θ0,T) 表示从 θ0\theta_0θ0 开始训练 TTT 步的结果。
6.2 迭代剪枝找到中奖彩票
6.2.1 迭代幅度剪枝(IMP)
算法 6.1(迭代幅度剪枝,IMP):
输入:随机初始化 θ_0,训练步数 T,剪枝比例 p,迭代次数 L
输出:中奖彩票 (θ_0 ⊙ m, m)
θ = θ_0
m = 1 (全 1 掩码)
for l = 1, ..., L:
1. 训练: θ_T = train(m ⊙ θ, T)
2. 剪枝: 删除 |θ_T| 最小的 p 比例参数,更新 m
3. 重置: θ = θ_0 (回到初始值)
返回 (m ⊙ θ_0, m)
关键步骤 :第 3 步将参数重置为初始值 θ0\theta_0θ0------这是彩票假说的核心要求。
6.2.2 为什么需要重置?
实验观察 :如果不重置为 θ0\theta_0θ0,而是继续从当前参数训练,精度会显著下降。
理论解释 :θ0\theta_0θ0 的某些特定初始化值恰好适合后续的训练动态。重置确保了这些"幸运"的初始化值被保留。
6.2.3 彩票假说的理论分析
定理 6.1(中奖彩票的存在性,简化版) :对于过参数化的线性网络 f(θ)=θTxf(\theta) = \theta^T xf(θ)=θTx,使用梯度下降训练,存在一个稀疏率为 sss 的子网络,其中:
s≤1−O(deffn)s \leq 1 - O\left(\frac{d_{\text{eff}}}{n}\right)s≤1−O(ndeff)
使得该子网络可以从 θ0\theta_0θ0 开始训练到与密集网络相同的精度。
证明思路 :过参数化意味着 Hessian 的有效秩 deff≪nd_{\text{eff}} \ll ndeff≪n。沿着 Hessian 的零空间方向,参数可以被剪枝而不影响训练动态。□\square□
6.3 彩票假说的局限与扩展
6.3.1 大规模模型的挑战
问题:IMP 需要多次从头训练,对于大模型(如 GPT-3)计算成本过高。
解决方案:
- 一次性剪枝(SparseGPT、Wanda):不需要重新训练
- 转移彩票:在小模型上找到彩票,迁移到大模型
- 动态稀疏训练(DST):在训练过程中动态调整稀疏模式
6.3.2 稀疏迁移学习
定理 6.2(彩票的可迁移性):在源任务上找到的中奖彩票,在目标任务上的表现优于随机稀疏模式,但不如在目标任务上直接找到的彩票。
经验观察:迁移的彩票提供了一个良好的初始化,但仍需要少量微调。
第三部分:结构化剪枝方法
第七章:通道剪枝与滤波器剪枝
7.1 通道剪枝的基本思想
7.1.1 为什么需要结构化剪枝?
非结构化剪枝的问题:
- 稀疏矩阵在通用硬件上难以加速
- 需要专门的稀疏计算库和硬件
- 实际加速比远低于理论值
结构化剪枝的优势:
- 剪枝后的模型仍然是密集的(只是更小)
- 可以直接使用标准的密集计算库
- 在通用硬件上实现真正的加速
7.1.2 通道剪枝的形式化
对于线性层 Y=XWY = XWY=XW,W∈Rm×nW \in \mathbb{R}^{m \times n}W∈Rm×n,通道剪枝移除 WWW 的整行或整列:
- 输入通道剪枝 :移除 WWW 的列,同时移除对应输入特征
- 输出通道剪枝 :移除 WWW 的行,同时移除对应输出特征
剪枝后的层 :Y′=X′W′Y' = X' W'Y′=X′W′,其中 W′∈Rm′×n′W' \in \mathbb{R}^{m' \times n'}W′∈Rm′×n′,m′≤mm' \leq mm′≤m,n′≤nn' \leq nn′≤n。
7.2 通道重要性度量
7.2.1 基于幅度的度量
ℓ1\ell_1ℓ1 范数:
importance(ci)=∥Wi,:∥1=∑j∣Wij∣\text{importance}(c_i) = \|W_{i,:}\|1 = \sum_j |W{ij}|importance(ci)=∥Wi,:∥1=j∑∣Wij∣
ℓ2\ell_2ℓ2 范数:
importance(ci)=∥Wi,:∥2=∑jWij2\text{importance}(c_i) = \|W_{i,:}\|2 = \sqrt{\sum_j W{ij}^2}importance(ci)=∥Wi,:∥2=j∑Wij2
定理 7.1(ℓ2\ell_2ℓ2 范数与输出方差) :对于随机输入 XXX,EX=0\mathbb{E}X = 0EX=0,Cov(X)=I\text{Cov}(X) = ICov(X)=I,第 iii 个输出的方差为:
Var(Yi)=∥Wi,:∥22\text{Var}(Y_i) = \|W_{i,:}\|_2^2Var(Yi)=∥Wi,:∥22
因此,ℓ2\ell_2ℓ2 范数小的通道对输出的影响也小。
7.2.2 基于梯度的度量
Taylor 展开 :移除通道 iii 的损失变化为:
δLi≈∣∂L∂Yi⋅Yi∣=∣gi⋅Yi∣\delta\mathcal{L}_i \approx \left|\frac{\partial \mathcal{L}}{\partial Y_i} \cdot Y_i\right| = |g_i \cdot Y_i|δLi≈ ∂Yi∂L⋅Yi =∣gi⋅Yi∣
其中 gi=∂L/∂Yig_i = \partial \mathcal{L} / \partial Y_igi=∂L/∂Yi 是输出的梯度。
批量平均:
importance(ci)=1T∑t=1T∣gi,t⋅Yi,t∣\text{importance}(c_i) = \frac{1}{T} \sum_{t=1}^{T} |g_{i,t} \cdot Y_{i,t}|importance(ci)=T1t=1∑T∣gi,t⋅Yi,t∣
7.2.3 基于 Hessian 的度量
OBS 风格:
importance(ci)=∥Wi,:∥22H−1ii\text{importance}(c_i) = \frac{\|W_{i,:}\|2^2}{H\^{-1}{ii}}importance(ci)=H−1ii∥Wi,:∥22
7.3 通道剪枝的算法
7.3.1 贪心剪枝
算法 7.1(贪心通道剪枝):
输入:模型,目标通道数 k
输出:剪枝后的模型
1. 计算每个通道的重要性
2. 按重要性排序
3. 删除最不重要的 m-k 个通道
4. 微调模型
7.3.2 迭代通道剪枝
算法 7.2(迭代通道剪枝):
输入:模型,目标稀疏率 s,每次剪枝比例 p
输出:剪枝后的模型
while 当前稀疏率 < s:
1. 计算通道重要性
2. 剪枝 p 比例的通道
3. 微调
返回模型
7.4 通道剪枝的误差分析
定理 7.2(通道剪枝的输出误差) :设原始层为 Y=XWY = XWY=XW,剪枝后为 Y′=X′W′Y' = X'W'Y′=X′W′,则:
∥Y−Y′∥F≤∥X∥F⋅∥W−W′∥F\|Y - Y'\|_F \leq \|X\|_F \cdot \|W - W'\|_F∥Y−Y′∥F≤∥X∥F⋅∥W−W′∥F
其中 W′W'W′ 是将被剪枝通道设为零后的权重。
推论 7.1 :如果被剪枝通道的权重范数之和为 ϵ\epsilonϵ,则输出误差上界为 ∥X∥F⋅ϵ\|X\|_F \cdot \epsilon∥X∥F⋅ϵ。
第八章:注意力头剪枝与层剪枝
8.1 注意力头剪枝
8.1.1 Transformer 中的冗余
观察(Michel et al., 2019):在训练好的 Transformer 中,大量的注意力头是冗余的------移除它们对模型性能影响很小。
定义 8.1(注意力头重要性) :注意力头 hhh 的重要性定义为:
Ih=∣∂L∂αh⋅αh∣I_h = \left|\frac{\partial \mathcal{L}}{\partial \alpha_h} \cdot \alpha_h\right|Ih= ∂αh∂L⋅αh
其中 αh\alpha_hαh 是该头的注意力权重。
8.1.2 头重要性的计算
定理 8.1(头重要性的简化) :在多头注意力中,头 hhh 的重要性可以近似为:
Ih≈1T∑t=1T∥∂L∂Ah(t)⊙Ah(t)∥FI_h \approx \frac{1}{T} \sum_{t=1}^{T} \left\|\frac{\partial \mathcal{L}}{\partial A_h^{(t)}} \odot A_h^{(t)}\right\|_FIh≈T1t=1∑T ∂Ah(t)∂L⊙Ah(t) F
其中 Ah(t)A_h^{(t)}Ah(t) 是第 ttt 个样本的注意力权重矩阵。
8.1.3 交叉头剪枝
问题:独立评估每个头的重要性可能不是最优的------某些头可能单独看不重要,但与其他头组合时很重要。
解决方案 :考虑头之间的交互------使用联合重要性:
Ih1,h2=Ih1+Ih2+interaction(h1,h2)I_{h_1, h_2} = I_{h_1} + I_{h_2} + \text{interaction}(h_1, h_2)Ih1,h2=Ih1+Ih2+interaction(h1,h2)
8.2 层剪枝
8.2.1 层重要性度量
方法 1:基于输出变化
importance(l)=E∥fl(x)−fl−1(x)∥2\text{importance}(l) = \mathbb{E}\left\\\|f_l(x) - f_{l-1}(x)\\\|\^2\\rightimportance(l)=E∥fl(x)−fl−1(x)∥2
即第 lll 层对表示的改变量。
方法 2:基于角度距离
importance(l)=1−Efl(x)Tfl−1(x)∥fl(x)∥⋅∥fl−1(x)∥\text{importance}(l) = 1 - \frac{\mathbb{E}f_l(x)\^T f_{l-1}(x)}{\|f_l(x)\| \cdot \|f_{l-1}(x)\|}importance(l)=1−∥fl(x)∥⋅∥fl−1(x)∥Efl(x)Tfl−1(x)
方法 3:基于逐层损失
importance(l)=L(f1,...,fl,id,...,id)−L(f1,...,fL)\text{importance}(l) = \mathcal{L}(f_1, \dots, f_l, \text{id}, \dots, \text{id}) - \mathcal{L}(f_1, \dots, f_L)importance(l)=L(f1,...,fl,id,...,id)−L(f1,...,fL)
其中 id\text{id}id 表示恒等映射(跳过该层)。
8.2.2 层剪枝的算法
算法 8.1(逐层剪枝):
输入:L 层模型,目标层数 L'
输出:剪枝后的模型
1. 计算每层的重要性
2. 按重要性排序
3. 删除最不重要的 L-L' 层
4. 微调
注意:层剪枝通常只能删除少量层(10%-30%),否则精度会急剧下降。
8.3 大语言模型的结构化剪枝
8.3.1 LLM-Pruner
LLM-Pruner(Ma et al., 2023)提出了一种针对大语言模型的结构化剪枝框架:
- 依赖图分析:自动识别 LLM 中的可剪枝结构
- 梯度重要性:使用一阶梯度和 Hessian 信息评估结构重要性
- 快速恢复:使用 LoRA 进行高效微调
8.3.2 ShortGPT
ShortGPT(Men et al., 2024)发现 LLM 中存在大量的层冗余:
观察 :在 LLM 中,相邻层的表示高度相似------可以用余弦相似度衡量:
sim(l)=cos(hl,hl−1)=hlThl−1∥hl∥⋅∥hl−1∥\text{sim}(l) = \cos(h_l, h_{l-1}) = \frac{h_l^T h_{l-1}}{\|h_l\| \cdot \|h_{l-1}\|}sim(l)=cos(hl,hl−1)=∥hl∥⋅∥hl−1∥hlThl−1
当 sim(l)≈1\text{sim}(l) \approx 1sim(l)≈1 时,第 $l} 层几乎不改变表示,可以安全删除。
第四部分:稀疏格式与稀疏计算
第九章:稀疏矩阵存储格式
9.1 CSR(Compressed Sparse Row)
9.1.1 格式定义
CSR(压缩稀疏行) 格式使用三个数组存储稀疏矩阵 A∈Rm×nA \in \mathbb{R}^{m \times n}A∈Rm×n:
- values :非零元素的值,长度 nnznnznnz
- col_indices :每个非零元素的列索引,长度 nnznnznnz
- row_ptr :每行的起始位置,长度 m+1m+1m+1
示例:
A=(100203004050)A = \begin{pmatrix} 1 & 0 & 0 & 2 \\ 0 & 3 & 0 & 0 \\ 4 & 0 & 5 & 0 \end{pmatrix}A= 104030005200
- values = 1, 2, 3, 4, 5
- col_indices = 0, 3, 1, 0, 2
- row_ptr = 0, 2, 3, 5
9.1.2 存储复杂度
定理 9.1(CSR 的存储):CSR 格式需要:
存储=nnz⋅(bval+bidx)+(m+1)⋅bidx\text{存储} = nnz \cdot (b_{\text{val}} + b_{\text{idx}}) + (m+1) \cdot b_{\text{idx}}存储=nnz⋅(bval+bidx)+(m+1)⋅bidx
其中 bvalb_{\text{val}}bval 是值的比特数,bidxb_{\text{idx}}bidx 是索引的比特数。
对于 FP32 值和 INT32 索引:存储=nnz⋅8+(m+1)⋅4\text{存储} = nnz \cdot 8 + (m+1) \cdot 4存储=nnz⋅8+(m+1)⋅4 字节。
盈亏平衡点 :当 nnz<m⋅n⋅bvalbval+bidx+bidx/mnnz < \frac{m \cdot n \cdot b_{\text{val}}}{b_{\text{val}} + b_{\text{idx}} + b_{\text{idx}}/m}nnz<bval+bidx+bidx/mm⋅n⋅bval 时,CSR 比密集存储更省空间。
对于 FP32、大矩阵:nnz<mn⋅48=mn/2nnz < \frac{mn \cdot 4}{8} = mn/2nnz<8mn⋅4=mn/2,即稀疏率 > 50%。
9.2 CSC(Compressed Sparse Column)
CSC 是 CSR 的列版本:
- values:非零元素的值
- row_indices:每个非零元素的行索引
- col_ptr:每列的起始位置
9.3 COO(Coordinate)
COO(坐标格式) 存储每个非零元素的 (行, 列, 值):
- row_indices:行索引
- col_indices:列索引
- values:值
存储 :nnz⋅(2bidx+bval)nnz \cdot (2 b_{\text{idx}} + b_{\text{val}})nnz⋅(2bidx+bval)
优点:构建简单,适合动态插入非零元素。
9.4 BSR(Block Sparse Row)
BSR(块稀疏行) 格式将矩阵分成 br×bcb_r \times b_cbr×bc 的块,只存储非零块。
优点:
- 更好的数据局部性
- 适合 GPU 的块操作
- 索引开销更小(每个块只有一个索引)
存储 :nnzblocks⋅br⋅bc⋅bval+nnzblocks⋅bidx+(m/br+1)⋅bidxnnz_{\text{blocks}} \cdot b_r \cdot b_c \cdot b_{\text{val}} + nnz_{\text{blocks}} \cdot b_{\text{idx}} + (m/b_r + 1) \cdot b_{\text{idx}}nnzblocks⋅br⋅bc⋅bval+nnzblocks⋅bidx+(m/br+1)⋅bidx
第十章:稀疏矩阵乘法与硬件加速
10.1 稀疏-密集矩阵乘法(SpMM)
10.1.1 CSR 格式的 SpMM
算法 10.1(CSR SpMM):
输入:CSR 矩阵 A (m×n), 密集矩阵 B (n×k)
输出:C = A @ B (m×k)
for i = 0, ..., m-1:
for j = row_ptr[i], ..., row_ptr[i+1]-1:
col = col_indices[j]
val = values[j]
for p = 0, ..., k-1:
C[i, p] += val * B[col, p]
复杂度 :O(nnz⋅k)O(nnz \cdot k)O(nnz⋅k)
10.1.2 SpMM 的并行化
GPU 并行策略:
- 行并行:每个线程处理一行
- 元素并行:每个线程处理一个非零元素
- 块并行:每个线程块处理一个块
10.2 稀疏-密集向量乘法(SpMV)
算法 10.2(CSR SpMV):
输入:CSR 矩阵 A (m×n), 密集向量 x (n)
输出:y = A @ x (m)
for i = 0, ..., m-1:
sum = 0
for j = row_ptr[i], ..., row_ptr[i+1]-1:
sum += values[j] * x[col_indices[j]]
y[i] = sum
复杂度 :O(nnz)O(nnz)O(nnz)
10.3 稀疏计算的实际加速
10.3.1 加速比分析
理论加速比:
Speedup理论=mnnnz=11−s\text{Speedup}_{\text{理论}} = \frac{mn}{nnz} = \frac{1}{1-s}Speedup理论=nnzmn=1−s1
其中 sss 是稀疏率。
实际加速比:由于内存访问模式不规则、并行度不足等原因,实际加速比通常远低于理论值。
| 稀疏率 | 理论加速 | 实际加速(CPU) | 实际加速(GPU) |
|---|---|---|---|
| 50% | 2x | 1.2x | 1.1x |
| 90% | 10x | 3x | 2x |
| 95% | 20x | 5x | 3x |
| 99% | 100x | 10x | 5x |
10.3.2 结构化稀疏的优势
问题:非结构化稀疏的实际加速比很低。
解决方案 :使用结构化稀疏------稀疏模式是规则的,可以利用硬件的块操作。
第十一章:2:4 结构化稀疏与 NVIDIA Ampere
11.1 2:4 稀疏模式
11.1.1 定义
2:4 稀疏 (也称为 N:M 稀疏 ,N=2,M=4N=2, M=4N=2,M=4):在每 4 个连续元素中,恰好有 2 个为零。
示例:
w=w1,0,w3,0,0,w6,w7,0,... \mathbf{w} = w_1, 0, w_3, 0, 0, w_6, w_7, 0, \\dotsw=w1,0,w3,0,0,w6,w7,0,...
每 4 个元素中恰好 2 个非零。
11.1.2 理论分析
定理 11.1(2:4 稀疏的表达能力):2:4 稀疏的等效参数量为原始的 50%,但通过选择最优的稀疏模式,可以保留大部分模型精度。
最优稀疏模式:在每 4 个元素中,保留绝对值最大的 2 个:
maski:i+4=top-2-mask(∣wi:i+4∣)\text{mask}{i:i+4} = \text{top-2-mask}(|w{i:i+4}|)maski:i+4=top-2-mask(∣wi:i+4∣)
11.1.3 硬件支持
NVIDIA A100/H100 的稀疏张量核心:
- 支持 2:4 结构化稀疏的矩阵乘法
- 理论吞吐量是密集矩阵的 2 倍
- 稀疏矩阵只需要存储非零元素和 2-bit 索引
工作原理:
1. 稀疏矩阵 A (2:4 格式):存储非零值 + 2-bit 索引
2. 密集矩阵 B:正常存储
3. 硬件使用索引从 B 中选择对应列,与 A 的非零值相乘
4. 结果 C = A @ B
11.2 2:4 稀疏的训练
11.2.1 稀疏感知训练
算法 11.1(2:4 稀疏感知训练):
1. 初始化:密集权重 W
2. 每个训练步:
a. 前向传播:使用 2:4 稀疏权重
- 应用 2:4 掩码:W_sparse = W * mask
- 计算输出:Y = X @ W_sparse^T
b. 反向传播:使用 STE
- 梯度通过掩码传递:grad_W = grad_Y @ X
c. 更新:W = W - lr * grad_W
d. 重新计算掩码(每隔 N 步)
11.2.2 掩码更新策略
静态掩码:训练开始时确定掩码,训练过程中不变。
动态掩码:每隔 N 步重新计算掩码(基于当前权重的幅度)。
渐进式掩码:从密集开始,逐渐增加稀疏率。
11.3 2:4 稀疏的精度
11.3.1 实验结果
| 模型 | 方法 | 密集精度 | 2:4 稀疏精度 | 精度下降 |
|---|---|---|---|---|
| ResNet-50 | 训练后剪枝 | 76.1% | 75.2% | 0.9% |
| ResNet-50 | 稀疏感知训练 | 76.1% | 76.0% | 0.1% |
| BERT-Large | 训练后剪枝 | 91.0% | 90.5% | 0.5% |
| BERT-Large | 稀疏感知训练 | 91.0% | 90.9% | 0.1% |
11.3.2 与量化的结合
2:4 稀疏 + INT8 量化:
Y=(XINT8⋅WINT8, sparse)⋅sX⋅sWY = (X_{\text{INT8}} \cdot W_{\text{INT8, sparse}}) \cdot s_X \cdot s_WY=(XINT8⋅WINT8, sparse)⋅sX⋅sW
理论加速:2x(稀疏)× 2x(INT8 vs FP16)= 4x
第五部分:完整可运行代码实现
第十二章:从零实现非结构化剪枝系统
python
"""
非结构化剪枝系统的完整实现。
包含:幅度剪枝、迭代剪枝、SNIP、Wanda。
"""
import numpy as np
from typing import Tuple, Optional
def magnitude_prune(weights: np.ndarray, sparsity: float) -> Tuple[np.ndarray, np.ndarray]:
"""幅度剪枝。
删除绝对值最小的权重。
Args:
weights: 权重矩阵
sparsity: 目标稀疏率 (0-1)
Returns:
pruned_weights: 剪枝后的权重
mask: 二值掩码 (1=保留, 0=删除)
"""
# 计算阈值
flat_weights = np.abs(weights.flatten())
k = int(len(flat_weights) * (1 - sparsity))
threshold = np.sort(flat_weights)[k - 1] if k > 0 else np.inf
# 创建掩码
mask = (np.abs(weights) >= threshold).astype(float)
# 应用掩码
pruned_weights = weights * mask
return pruned_weights, mask
def snip_prune(
weights: np.ndarray,
gradients: np.ndarray,
sparsity: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""SNIP 剪枝。
基于连接灵敏度 |w * grad| 进行剪枝。
Args:
weights: 权重矩阵
gradients: 梯度矩阵
sparsity: 目标稀疏率
Returns:
pruned_weights: 剪枝后的权重
mask: 二值掩码
"""
# 计算连接灵敏度
sensitivity = np.abs(weights * gradients)
# 计算阈值
k = int(weights.size * (1 - sparsity))
threshold = np.sort(sensitivity.flatten())[k - 1] if k > 0 else np.inf
# 创建掩码
mask = (sensitivity >= threshold).astype(float)
return weights * mask, mask
def wanda_prune(
weights: np.ndarray,
activations: np.ndarray,
sparsity: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""Wanda 剪枝。
基于 |w| * ||activation|| 进行剪枝。
Args:
weights: 权重矩阵 (m, n)
activations: 输入激活 (T, n)
sparsity: 目标稀疏率
Returns:
pruned_weights: 剪枝后的权重
mask: 二值掩码
"""
# 计算激活范数
act_norms = np.linalg.norm(activations, axis=0) # (n,)
# 计算重要性分数
scores = np.abs(weights) * act_norms[np.newaxis, :] # (m, n)
# 计算阈值
k = int(weights.size * (1 - sparsity))
threshold = np.sort(scores.flatten())[k - 1] if k > 0 else np.inf
# 创建掩码
mask = (scores >= threshold).astype(float)
return weights * mask, mask
def iterative_magnitude_prune(
weights: np.ndarray,
X: np.ndarray,
Y: np.ndarray,
target_sparsity: float,
n_steps: int = 10,
n_finetune_steps: int = 50,
lr: float = 0.001,
) -> Tuple[np.ndarray, np.ndarray, list]:
"""迭代幅度剪枝。
Args:
weights: 初始权重
X: 训练输入
Y: 训练目标
target_sparsity: 目标稀疏率
n_steps: 剪枝步骤数
n_finetune_steps: 每步微调的迭代数
lr: 学习率
Returns:
final_weights: 最终权重
final_mask: 最终掩码
loss_history: 损失历史
"""
current_weights = weights.copy()
loss_history = []
sparsity_per_step = 1 - (1 - target_sparsity) ** (1 / n_steps)
for step in range(n_steps):
# 剪枝
_, mask = magnitude_prune(current_weights, sparsity_per_step)
current_weights = current_weights * mask
# 微调
for _ in range(n_finetune_steps):
# 前向传播
Y_pred = X @ current_weights.T
loss = np.mean((Y - Y_pred) ** 2)
loss_history.append(loss)
# 反向传播
grad = 2 / X.shape[0] * (Y_pred - Y).T @ X
# 更新(只更新非零位置)
current_weights = current_weights - lr * grad * mask
current_sparsity = 1 - np.mean(mask)
print(f" 步骤 {step + 1}/{n_steps}: 稀疏率 = {current_sparsity:.2%}, "
f"损失 = {loss_history[-1]:.6f}")
final_mask = (current_weights != 0).astype(float)
return current_weights, final_mask, loss_history
def demonstrate_pruning():
"""演示各种剪枝方法。"""
np.random.seed(42)
# 设置
m, n = 128, 256
T = 500
sparsity = 0.9 # 90% 稀疏率
print("=" * 70)
print("非结构化剪枝方法对比")
print("=" * 70)
print(f"权重: {m}x{n}, 稀疏率: {sparsity:.0%}")
print()
# 生成数据
W = np.random.randn(m, n) * 0.02
X = np.random.randn(T, n) * 0.5
Y = X @ W.T + np.random.randn(T, m) * 0.01
# 参考 MSE
Y_ref = X @ W.T
ref_mse = np.mean((Y - Y_ref) ** 2)
# 计算梯度(用于 SNIP)
Y_pred = X @ W.T
grad = 2 / T * (Y_pred - Y).T @ X # (m, n)
results = []
# 方法 1:幅度剪枝
W_mag, mask_mag = magnitude_prune(W, sparsity)
Y_mag = X @ W_mag.T
mse_mag = np.mean((Y_ref - Y_mag) ** 2)
results.append(("幅度剪枝", mse_mag, np.mean(mask_mag)))
# 方法 2:SNIP
W_snip, mask_snip = snip_prune(W, grad, sparsity)
Y_snip = X @ W_snip.T
mse_snip = np.mean((Y_ref - Y_snip) ** 2)
results.append(("SNIP", mse_snip, np.mean(mask_snip)))
# 方法 3:Wanda
W_wanda, mask_wanda = wanda_prune(W, X, sparsity)
Y_wanda = X @ W_wanda.T
mse_wanda = np.mean((Y_ref - Y_wanda) ** 2)
results.append(("Wanda", mse_wanda, np.mean(mask_wanda)))
# 方法 4:迭代幅度剪枝
W_imp, mask_imp, _ = iterative_magnitude_prune(
W, X, Y, target_sparsity=sparsity, n_steps=5, n_finetune_steps=20, lr=0.0001
)
Y_imp = X @ W_imp.T
mse_imp = np.mean((Y_ref - Y_imp) ** 2)
results.append(("迭代幅度剪枝", mse_imp, np.mean(mask_imp)))
# 打印结果
print(f"\n {'方法':>15} {'MSE':>15} {'非零比例':>10} {'相对误差':>12}")
print(f" {'-'*15} {'-'*15} {'-'*10} {'-'*12}")
for name, mse, nnz_ratio in results:
rel_err = np.sqrt(mse) / np.std(Y_ref) if mse > 0 else 0
print(f" {name:>15} {mse:>15.10f} {nnz_ratio:>10.2%} {rel_err:>12.6f}")
# 不同稀疏率的分析
print(f"\n 不同稀疏率下的幅度剪枝 MSE:")
print(f" {'稀疏率':>10} {'MSE':>15} {'相对误差':>12} {'非零元素':>10}")
print(f" {'-'*10} {'-'*15} {'-'*12} {'-'*10}")
for s in [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]:
W_s, mask_s = magnitude_prune(W, s)
Y_s = X @ W_s.T
mse_s = np.mean((Y_ref - Y_s) ** 2)
rel_err = np.sqrt(mse_s) / np.std(Y_ref)
nnz = int(mask_s.sum())
print(f" {s:>10.0%} {mse_s:>15.10f} {rel_err:>12.6f} {nnz:>10}")
if __name__ == "__main__":
demonstrate_pruning()
第十三章:从零实现结构化剪枝系统
python
"""
结构化剪枝系统的完整实现。
包含:通道剪枝、注意力头剪枝、层剪枝。
"""
import numpy as np
from typing import List, Tuple
def channel_prune_l1(
W: np.ndarray,
keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""基于 L1 范数的通道剪枝。
Args:
W: 权重矩阵 (m, n),每行是一个通道
keep_ratio: 保留比例
Returns:
W_pruned: 剪枝后的权重
channel_mask: 通道掩码
"""
m = W.shape[0]
k = max(1, int(m * keep_ratio))
# 计算每个通道的 L1 范数
channel_norms = np.linalg.norm(W, ord=1, axis=1)
# 保留范数最大的 k 个通道
top_k_indices = np.argsort(channel_norms)[-k:]
channel_mask = np.zeros(m)
channel_mask[top_k_indices] = 1
W_pruned = W * channel_mask[:, np.newaxis]
return W_pruned, channel_mask
def channel_prune_l2(
W: np.ndarray,
keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""基于 L2 范数的通道剪枝。
Args:
W: 权重矩阵 (m, n)
keep_ratio: 保留比例
Returns:
W_pruned: 剪枝后的权重
channel_mask: 通道掩码
"""
m = W.shape[0]
k = max(1, int(m * keep_ratio))
# 计算每个通道的 L2 范数
channel_norms = np.linalg.norm(W, ord=2, axis=1)
top_k_indices = np.argsort(channel_norms)[-k:]
channel_mask = np.zeros(m)
channel_mask[top_k_indices] = 1
W_pruned = W * channel_mask[:, np.newaxis]
return W_pruned, channel_mask
def channel_prune_gradient(
W: np.ndarray,
gradients: np.ndarray,
activations: np.ndarray,
keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""基于梯度的通道剪枝。
重要性 = ||通道权重|| * ||通道梯度|| * ||通道激活||
Args:
W: 权重矩阵 (m, n)
gradients: 输出梯度 (T, m)
activations: 输入激活 (T, n)
keep_ratio: 保留比例
Returns:
W_pruned: 剪枝后的权重
channel_mask: 通道掩码
"""
m = W.shape[0]
k = max(1, int(m * keep_ratio))
# 计算通道重要性
weight_norms = np.linalg.norm(W, axis=1) # (m,)
grad_norms = np.linalg.norm(gradients, axis=0) # (m,)
act_norms = np.linalg.norm(activations, axis=0) # (n,)
# 重要性 = 权重范数 * 梯度范数
importance = weight_norms * grad_norms
top_k_indices = np.argsort(importance)[-k:]
channel_mask = np.zeros(m)
channel_mask[top_k_indices] = 1
W_pruned = W * channel_mask[:, np.newaxis]
return W_pruned, channel_mask
def attention_head_prune(
Q: np.ndarray,
K: np.ndarray,
V: np.ndarray,
O: np.ndarray,
X: np.ndarray,
keep_ratio: float,
n_heads: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""注意力头剪枝。
Args:
Q, K, V, O: 注意力层的权重矩阵
X: 输入激活
keep_ratio: 头保留比例
n_heads: 注意力头数量
Returns:
Q_pruned, K_pruned, V_pruned, O_pruned: 剪枝后的权重
head_mask: 头掩码
"""
d_model = Q.shape[0]
d_head = d_model // n_heads
k = max(1, int(n_heads * keep_ratio))
# 计算每个头的重要性
head_importance = np.zeros(n_heads)
for h in range(n_heads):
# 提取该头的权重
start = h * d_head
end = (h + 1) * d_head
Q_h = Q[start:end, :]
K_h = K[start:end, :]
V_h = V[start:end, :]
# 重要性 = 各权重矩阵的 Frobenius 范数之和
head_importance[h] = (
np.linalg.norm(Q_h, 'fro') +
np.linalg.norm(K_h, 'fro') +
np.linalg.norm(V_h, 'fro')
)
# 选择最重要的头
top_k_heads = np.argsort(head_importance)[-k:]
head_mask = np.zeros(n_heads)
head_mask[top_k_heads] = 1
# 创建通道掩码
channel_mask = np.repeat(head_mask, d_head)
# 应用掩码
Q_pruned = Q * channel_mask[:, np.newaxis]
K_pruned = K * channel_mask[:, np.newaxis]
V_pruned = V * channel_mask[:, np.newaxis]
O_pruned = O * channel_mask[np.newaxis, :]
return Q_pruned, K_pruned, V_pruned, O_pruned, head_mask
def layer_importance_cosine(
hidden_states: List[np.ndarray],
) -> np.ndarray:
"""基于余弦相似度的层重要性。
Args:
hidden_states: 各层的隐藏状态列表
Returns:
importance: 层重要性分数
"""
n_layers = len(hidden_states) - 1
importance = np.zeros(n_layers)
for l in range(n_layers):
h_l = hidden_states[l].flatten()
h_l1 = hidden_states[l + 1].flatten()
# 余弦相似度
cos_sim = np.dot(h_l, h_l1) / (np.linalg.norm(h_l) * np.linalg.norm(h_l1) + 1e-8)
# 重要性 = 1 - 余弦相似度(变化越大越重要)
importance[l] = 1 - cos_sim
return importance
def layer_prune(
weights_list: List[np.ndarray],
hidden_states: List[np.ndarray],
keep_ratio: float,
) -> Tuple[List[np.ndarray], np.ndarray]:
"""层剪枝。
Args:
weights_list: 各层的权重列表
hidden_states: 各层的隐藏状态
keep_ratio: 层保留比例
Returns:
pruned_weights: 剪枝后的权重列表
layer_mask: 层掩码
"""
n_layers = len(weights_list)
k = max(1, int(n_layers * keep_ratio))
# 计算层重要性
importance = layer_importance_cosine(hidden_states)
# 选择最重要的层
top_k_layers = np.argsort(importance)[-k:]
layer_mask = np.zeros(n_layers)
layer_mask[top_k_layers] = 1
pruned_weights = [w for w, m in zip(weights_list, layer_mask) if m > 0]
return pruned_weights, layer_mask
def demonstrate_structured_pruning():
"""演示结构化剪枝。"""
np.random.seed(42)
print("=" * 70)
print("结构化剪枝演示")
print("=" * 70)
# 通道剪枝
print("\n 1. 通道剪枝")
print(" " + "-" * 40)
m, n = 64, 128
W = np.random.randn(m, n) * 0.02
for keep_ratio in [0.5, 0.25, 0.1]:
# L1 剪枝
W_l1, mask_l1 = channel_prune_l1(W, keep_ratio)
nnz_l1 = int(mask_l1.sum())
# L2 剪枝
W_l2, mask_l2 = channel_prune_l2(W, keep_ratio)
nnz_l2 = int(mask_l2.sum())
print(f" 保留 {keep_ratio:.0%}: L1 保留 {nnz_l1} 通道, L2 保留 {nnz_l2} 通道")
# 注意力头剪枝
print("\n 2. 注意力头剪枝")
print(" " + "-" * 40)
d_model = 256
n_heads = 8
d_head = d_model // n_heads
Q = np.random.randn(d_model, d_model) * 0.02
K = np.random.randn(d_model, d_model) * 0.02
V = np.random.randn(d_model, d_model) * 0.02
O = np.random.randn(d_model, d_model) * 0.02
X = np.random.randn(32, d_model) * 0.5
for keep_ratio in [0.75, 0.5, 0.25]:
Q_p, K_p, V_p, O_p, head_mask = attention_head_prune(
Q, K, V, O, X, keep_ratio, n_heads
)
n_kept = int(head_mask.sum())
print(f" 保留 {keep_ratio:.0%}: {n_kept}/{n_heads} 头")
# 层重要性分析
print("\n 3. 层重要性分析")
print(" " + "-" * 40)
n_layers = 12
d = 256
hidden_states = [np.random.randn(32, d) * 0.5]
for l in range(n_layers):
# 模拟:层越深,表示变化越小
h_new = hidden_states[-1] + np.random.randn(32, d) * (0.5 / (l + 1))
hidden_states.append(h_new)
importance = layer_importance_cosine(hidden_states)
print(f" 层重要性(1 - 余弦相似度):")
for l, imp in enumerate(importance):
bar = "█" * int(imp * 200)
print(f" 层 {l:2d}: {imp:.4f} {bar}")
if __name__ == "__main__":
demonstrate_structured_pruning()
第十四章:从零实现稀疏矩阵运算
python
"""
稀疏矩阵运算的完整实现。
包含:CSR 格式、SpMV、SpMM、2:4 稀疏。
"""
import numpy as np
from typing import Tuple
class CSRMatrix:
"""CSR 格式的稀疏矩阵。"""
def __init__(self, values, col_indices, row_ptr, shape):
self.values = values
self.col_indices = col_indices
self.row_ptr = row_ptr
self.shape = shape
@classmethod
def from_dense(cls, A: np.ndarray, threshold: float = 0.0) -> 'CSRMatrix':
"""从密集矩阵创建 CSR 矩阵。"""
m, n = A.shape
values = []
col_indices = []
row_ptr = [0]
for i in range(m):
for j in range(n):
if abs(A[i, j]) > threshold:
values.append(A[i, j])
col_indices.append(j)
row_ptr.append(len(values))
return cls(
np.array(values),
np.array(col_indices, dtype=np.int32),
np.array(row_ptr, dtype=np.int32),
(m, n)
)
def to_dense(self) -> np.ndarray:
"""转换为密集矩阵。"""
m, n = self.shape
A = np.zeros((m, n))
for i in range(m):
for j in range(self.row_ptr[i], self.row_ptr[i + 1]):
A[i, self.col_indices[j]] = self.values[j]
return A
def spmv(self, x: np.ndarray) -> np.ndarray:
"""稀疏矩阵-密集向量乘法 (SpMV)。"""
m, n = self.shape
y = np.zeros(m)
for i in range(m):
s = 0.0
for j in range(self.row_ptr[i], self.row_ptr[i + 1]):
s += self.values[j] * x[self.col_indices[j]]
y[i] = s
return y
def spmm(self, B: np.ndarray) -> np.ndarray:
"""稀疏矩阵-密集矩阵乘法 (SpMM)。"""
m, n = self.shape
k = B.shape[1]
C = np.zeros((m, k))
for i in range(m):
for jj in range(self.row_ptr[i], self.row_ptr[i + 1]):
j = self.col_indices[jj]
val = self.values[jj]
for p in range(k):
C[i, p] += val * B[j, p]
return C
@property
def nnz(self) -> int:
return len(self.values)
@property
def sparsity(self) -> float:
return 1 - self.nnz / (self.shape[0] * self.shape[1])
def apply_2_4_sparsity(W: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""应用 2:4 结构化稀疏。
在每 4 个连续元素中保留绝对值最大的 2 个。
Args:
W: 权重矩阵
Returns:
W_sparse: 稀疏权重
mask: 2:4 掩码
"""
W_flat = W.flatten()
n = len(W_flat)
mask = np.zeros(n)
# 每 4 个元素处理一组
for i in range(0, n, 4):
group = W_flat[i:i + 4]
if len(group) < 4:
# 不足 4 个的组,全部保留
mask[i:i + len(group)] = 1
continue
# 选择绝对值最大的 2 个
top_2 = np.argsort(np.abs(group))[-2:]
for idx in top_2:
mask[i + idx] = 1
mask = mask.reshape(W.shape)
W_sparse = W * mask
return W_sparse, mask
def demonstrate_sparse_operations():
"""演示稀疏矩阵运算。"""
np.random.seed(42)
print("=" * 70)
print("稀疏矩阵运算演示")
print("=" * 70)
# 创建稀疏矩阵
m, n = 128, 256
sparsity = 0.9
# 生成稀疏矩阵
A = np.random.randn(m, n) * 0.02
mask = np.random.random((m, n)) > sparsity
A_sparse = A * mask
print(f"\n 矩阵大小: {m}x{n}")
print(f" 稀疏率: {sparsity:.0%}")
print(f" 非零元素: {int(mask.sum())}")
# CSR 格式
csr = CSRMatrix.from_dense(A_sparse, threshold=1e-10)
print(f"\n CSR 格式:")
print(f" NNZ: {csr.nnz}")
print(f" 稀疏率: {csr.sparsity:.2%}")
print(f" 存储 (CSR): {csr.nnz * 8 + (m + 1) * 4} 字节")
print(f" 存储 (密集): {m * n * 4} 字节")
print(f" 压缩比: {m * n * 4 / (csr.nnz * 8 + (m + 1) * 4):.2f}x")
# SpMV
print(f"\n SpMV (稀疏矩阵-向量乘法):")
x = np.random.randn(n)
y_dense = A_sparse @ x
y_sparse = csr.spmv(x)
spmv_error = np.max(np.abs(y_dense - y_sparse))
print(f" 最大误差: {spmv_error:.2e}")
# SpMM
print(f"\n SpMM (稀疏矩阵-矩阵乘法):")
k = 64
B = np.random.randn(n, k)
C_dense = A_sparse @ B
C_sparse = csr.spmm(B)
spmm_error = np.max(np.abs(C_dense - C_sparse))
print(f" 最大误差: {spmm_error:.2e}")
# 2:4 结构化稀疏
print(f"\n 2:4 结构化稀疏:")
W = np.random.randn(64, 64) * 0.02
W_24, mask_24 = apply_2_4_sparsity(W)
actual_sparsity = 1 - np.mean(mask_24)
nnz_per_group = np.mean([np.sum(mask_24.flatten()[i:i + 4]) for i in range(0, mask_24.size, 4)])
print(f" 矩阵大小: {W.shape}")
print(f" 目标稀疏率: 50%")
print(f" 实际稀疏率: {actual_sparsity:.2%}")
print(f" 每组平均非零数: {nnz_per_group:.1f}")
# 误差分析
reconstruction_error = np.linalg.norm(W - W_24, 'fro') / np.linalg.norm(W, 'fro')
print(f" 重建相对误差: {reconstruction_error:.6f}")
if __name__ == "__main__":
demonstrate_sparse_operations()
第十五章:完整剪枝 Pipeline 与精度对比
python
"""
完整的剪枝 Pipeline。
对比各种剪枝方法在不同稀疏率下的精度。
"""
import numpy as np
from typing import Dict, List, Tuple
def run_full_pruning_comparison():
"""运行完整的剪枝方法对比。"""
np.random.seed(42)
# 设置
m, n = 128, 256
T = 500
print("=" * 70)
print("剪枝方法综合对比")
print("=" * 70)
print(f"权重: {m}x{n}, 训练样本: {T}")
print()
# 生成数据
W = np.random.randn(m, n) * 0.02
X = np.random.randn(T, n) * 0.5
Y = X @ W.T + np.random.randn(T, m) * 0.001
# 参考
Y_ref = X @ W.T
ref_mse = np.mean((Y - Y_ref) ** 2)
ref_std = np.std(Y_ref)
print(f" 参考 MSE: {ref_mse:.10f}")
print()
# 测试不同的稀疏率
sparsities = [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]
# 定义剪枝方法
def magnitude_prune(W, s):
flat = np.abs(W.flatten())
k = int(W.size * (1 - s))
if k == 0:
return np.zeros_like(W), np.zeros_like(W)
threshold = np.sort(flat)[k - 1]
mask = (np.abs(W) >= threshold).astype(float)
return W * mask, mask
def random_prune(W, s):
mask = (np.random.random(W.shape) >= s).astype(float)
return W * mask, mask
def column_prune(W, s):
n = W.shape[1]
k = max(1, int(n * (1 - s)))
col_norms = np.linalg.norm(W, axis=0)
top_k = np.argsort(col_norms)[-k:]
mask = np.zeros_like(W)
mask[:, top_k] = 1
return W * mask, mask
def row_prune(W, s):
m = W.shape[0]
k = max(1, int(m * (1 - s)))
row_norms = np.linalg.norm(W, axis=1)
top_k = np.argsort(row_norms)[-k:]
mask = np.zeros_like(W)
mask[top_k, :] = 1
return W * mask, mask
methods = {
"幅度剪枝": magnitude_prune,
"随机剪枝": random_prune,
"列剪枝 (结构化)": column_prune,
"行剪枝 (结构化)": row_prune,
}
# 运行对比
results = {}
for method_name, method_fn in methods.items():
results[method_name] = []
for s in sparsities:
W_pruned, mask = method_fn(W, s)
Y_pruned = X @ W_pruned.T
mse = np.mean((Y_ref - Y_pruned) ** 2)
rel_err = np.sqrt(mse) / ref_std if mse > 0 else 0
nnz = int(mask.sum())
results[method_name].append((s, mse, rel_err, nnz))
# 打印结果
print(f" {'方法':>18} | ", end="")
for s in sparsities:
print(f"{'s=' + str(int(s*100)) + '%':>12} | ", end="")
print()
print(f" {'-'*18} | " + " | ".join(["-" * 12] * len(sparsities)) + " |")
print(f"\n MSE:")
for method_name, method_results in results.items():
print(f" {method_name:>18} | ", end="")
for s, mse, rel_err, nnz in method_results:
print(f"{mse:>12.8f} | ", end="")
print()
print(f"\n 相对误差:")
for method_name, method_results in results.items():
print(f" {method_name:>18} | ", end="")
for s, mse, rel_err, nnz in method_results:
print(f"{rel_err:>12.6f} | ", end="")
print()
# 存储分析
print(f"\n 存储分析 (非结构化 vs 结构化):")
print(f" {'稀疏率':>8} {'非结构化(KB)':>14} {'结构化(KB)':>14} {'压缩比':>10}")
print(f" {'-'*8} {'-'*14} {'-'*14} {'-'*10}")
dense_size = m * n * 4 / 1024 # KB
for s in sparsities:
# 非结构化:需要 CSR 格式
nnz = int(m * n * (1 - s))
sparse_csr_kb = (nnz * 8 + (m + 1) * 4) / 1024
# 结构化(行剪枝):只需要存储保留的行
k_rows = max(1, int(m * (1 - s)))
struct_kb = k_rows * n * 4 / 1024
print(f" {s:>8.0%} {sparse_csr_kb:>14.2f} {struct_kb:>14.2f} {dense_size / struct_kb:>10.2f}x")
if __name__ == "__main__":
run_full_pruning_comparison()
附录:关键公式汇总
A.1 稀疏性基础
| 公式 | 表达式 |
|---|---|
| ℓ0\ell_0ℓ0 "范数" | $|\mathbf{w}|_0 = |
| ℓ1\ell_1ℓ1 范数 | $|\mathbf{w}|_1 = \sum_i |
| 稀疏率 | s=1−∣w∣0/ns = 1 - |\mathbf{w}|_0 / ns=1−∣w∣0/n |
A.2 参数重要性
| 方法 | 重要性公式 |
|---|---|
| 幅度 | $ |
| SNIP | $ |
| Fisher | Fii⋅wi2F_{ii} \cdot w_i^2Fii⋅wi2 |
| OBD | Hii⋅wi2H_{ii} \cdot w_i^2Hii⋅wi2 |
| OBS | wi2/H−1iiw_i^2 / H\^{-1}_{ii}wi2/H−1ii |
| Wanda | $ |
| GraSP | −gi⋅(Hg)i-g_i \cdot (\mathbf{H}\mathbf{g})_i−gi⋅(Hg)i |
A.3 剪枝理论
| 定理 | 表达式 |
|---|---|
| OBD 损失增加 | δLi=12Hiiθi2\delta\mathcal{L}i = \frac{1}{2} H{ii} \theta_i^2δLi=21Hiiθi2 |
| OBS 损失增加 | δLi=θi22H−1ii\delta\mathcal{L}i = \frac{\theta_i^2}{2H\^{-1}{ii}}δLi=2H−1iiθi2 |
| 6 dB/bit 规则 | SQNR=6.02b−4.77\text{SQNR} = 6.02b - 4.77SQNR=6.02b−4.77 dB |
A.4 稀疏矩阵格式
| 格式 | 存储 |
|---|---|
| CSR | nnz⋅(bval+bidx)+(m+1)⋅bidxnnz \cdot (b_{\text{val}} + b_{\text{idx}}) + (m+1) \cdot b_{\text{idx}}nnz⋅(bval+bidx)+(m+1)⋅bidx |
| CSC | 同 CSR(列版本) |
| COO | nnz⋅(2bidx+bval)nnz \cdot (2 b_{\text{idx}} + b_{\text{val}})nnz⋅(2bidx+bval) |
| BSR | nnzblocks⋅br⋅bc⋅bval+索引开销nnz_{\text{blocks}} \cdot b_r \cdot b_c \cdot b_{\text{val}} + \text{索引开销}nnzblocks⋅br⋅bc⋅bval+索引开销 |
参考文献
- LeCun, Y., Denker, J., & Solla, S. (1989). Optimal brain damage. NeurIPS.
- Hassibi, B., & Stork, D. (1993). Second order derivatives for network pruning: Optimal brain surgeon. NeurIPS.
- Frankle, J., & Carlin, M. (2019). The lottery ticket hypothesis: Finding sparse, trainable neural networks. ICLR.
- Lee, J., et al. (2019). SNIP: Single-shot network pruning based on connection sensitivity. ICLR.
- Wang, C., et al. (2020). Picking winning tickets before training by preserving gradient flow. ICLR.
- Frantar, E., & Alistarh, D. (2023). SparseGPT: Massive language models can be accurately pruned in one-shot. ICML.
- Sun, M., et al. (2024). A simple and effective pruning approach for large language models. ICLR.
- Ma, X., et al. (2023). LLM-Pruner: On the structural pruning of large language models. NeurIPS.
- Zhu, M., & Gupta, S. (2017). To prune, or not to prune: Exploring the efficacy of pruning for model compression. ICLR Workshop.
- Michel, P., Levy, O., & Neubig, G. (2019). Are sixteen heads really better than one? NeurIPS.