模型剪枝与稀疏化:理论、算法与可运行实现

目录

  • 第一部分:剪枝与稀疏化基础理论
    • 第一章:绪论------为什么需要模型剪枝
    • 第二章:稀疏性的数学基础------范数、测度与信息论
    • 第三章:最优脑损伤(OBD)与二阶方法
  • 第二部分:非结构化剪枝方法
    • 第四章:幅度剪枝与迭代剪枝
    • 第五章:SNIP、GraSP 与一次射击剪枝
    • 第六章:彩票假说(Lottery Ticket Hypothesis)
  • 第三部分:结构化剪枝方法
    • 第七章:通道剪枝与滤波器剪枝
    • 第八章:注意力头剪枝与层剪枝
    • 第九章:大语言模型的结构化剪枝
  • 第四部分:稀疏格式与稀疏计算
    • 第十章:稀疏矩阵存储格式------CSR、CSC、COO、BSR
    • 第十一章:稀疏矩阵乘法与硬件加速
    • 第十二章:2:4 结构化稀疏与 NVIDIA Ampere
  • 第五部分:完整可运行代码实现
    • 第十三章:从零实现非结构化剪枝系统
    • 第十四章:从零实现结构化剪枝系统
    • 第十五章:从零实现稀疏矩阵运算
    • 第十六章:完整剪枝 Pipeline 与精度对比
  • 附录

第一部分:剪枝与稀疏化基础理论


第一章:绪论------为什么需要模型剪枝

1.1 模型冗余性的理论基础

1.1.1 过参数化现象

现代深度学习模型普遍采用**过参数化(over-parameterization)**策略------参数量远大于训练样本数。

观察

模型 参数量 训练样本(ImageNet) 参数/样本比
ResNet-50 25.6M 1.28M 20:1
ViT-Large 307M 1.28M 240:1
GPT-3 175B ~300B tokens ~0.58:1
LLaMA-70B 70B ~1T tokens ~0.07:1

问题:为什么过参数化的模型仍然泛化良好?

定理 1.1(过参数化的隐式正则化) :在梯度下降训练中,过参数化模型倾向于收敛到平坦极小值(flat minima),其 Hessian 矩阵的特征值谱具有大量接近零的特征值。

证明思路:设损失函数为 L(θ)\mathcal{L}(\theta)L(θ),在极小值 θ∗\theta^*θ∗ 处,Hessian H=∇2L(θ∗)H = \nabla^2 \mathcal{L}(\theta^*)H=∇2L(θ∗) 的特征值分解为 H=UΛUTH = U \Lambda U^TH=UΛUT。过参数化意味着 HHH 是低秩的------存在大量特征值 λi≈0\lambda_i \approx 0λi≈0。

这些接近零的特征值对应的方向是冗余方向 ------沿着这些方向移动参数,损失几乎不变。这为剪枝提供了理论基础。□\square□

1.1.2 冗余性的量化

定义 1.1(有效参数数量):模型的有效参数数量定义为:

deff=∑i=1dλiλi+αd_{\text{eff}} = \sum_{i=1}^{d} \frac{\lambda_i}{\lambda_i + \alpha}deff=i=1∑dλi+αλi

其中 λi\lambda_iλi 是 Hessian 的特征值,α\alphaα 是正则化参数。

当 α→0\alpha \to 0α→0 时,deffd_{\text{eff}}deff 趋近于 Hessian 的秩------即真正"有用"的参数数量。

实证观察 :对于典型的大语言模型,deff/dd_{\text{eff}} / ddeff/d 通常在 10%-50% 之间------意味着 50%-90% 的参数是冗余的。

1.2 剪枝的核心动机

剪枝(Pruning) 是移除模型中冗余参数(将其设为零)的过程。

效益 说明 数学刻画
减少存储 稀疏矩阵只需要存储非零元素 存储 ∝\propto∝ 非零元素数量
减少计算 跳过零元素的乘法 FLOPs ∝\propto∝ 非零元素数量
减少内存带宽 更少的数据需要从内存读取 带宽 ∝\propto∝ 非零元素数量
可能提高泛化 剪枝起到正则化的作用 经验观察

1.3 剪枝方法的分类

1.3.1 按粒度分类

粒度 说明 稀疏模式 硬件友好度
非结构化(Unstructured) 单个权重 任意位置
模式块(Pattern/Block) 固定模式(如 2:4) 规则模式
通道(Channel) 整个通道 整行/列
滤波器(Filter) 整个卷积核 结构化
层(Layer) 整个层 层级别 最高

1.3.2 按时机分类

时机 说明 代表方法
训练后剪枝(Post-Training) 训练完成后剪枝 幅度剪枝
训练中剪枝(During Training) 训练过程中逐步剪枝 迭代剪枝、GMP
训练前剪枝(Before Training) 初始化时就确定稀疏模式 SNIP、GraSP
一次射击剪枝(One-Shot) 不需要重新训练 SparseGPT、Wanda

1.3.3 按评分标准分类

标准 说明 代表方法
幅度(Magnitude) $ w
梯度(Gradient) $ \partial \mathcal{L}/\partial w
Hessian 二阶信息 OBS、OBD
激活感知 考虑输入激活的影响 Wanda
信息论 Fisher 信息、互信息 Fisher 剪枝

第二章:稀疏性的数学基础------范数、测度与信息论

2.1 稀疏性的形式化

2.1.1 ℓ0\ell_0ℓ0 "范数"

定义 2.1(ℓ0\ell_0ℓ0 "范数") :向量 w∈Rn\mathbf{w} \in \mathbb{R}^nw∈Rn 的 ℓ0\ell_0ℓ0 "范数"定义为非零元素的数量:

∥w∥0=∣{i:wi≠0}∣=∑i=1n1wi≠0\|\mathbf{w}\|0 = |\{i : w_i \neq 0\}| = \sum{i=1}^n \mathbf{1}w_i \\neq 0∥w∥0=∣{i:wi=0}∣=i=1∑n1wi=0

注意 :∥⋅∥0\|\cdot\|_0∥⋅∥0 不是真正的范数(不满足齐次性:∥cw∥0=∥w∥0\|c\mathbf{w}\|_0 = \|\mathbf{w}\|_0∥cw∥0=∥w∥0 对 c≠0c \neq 0c=0)。

稀疏率 :s=1−∥w∥0/ns = 1 - \|\mathbf{w}\|_0 / ns=1−∥w∥0/n,即零元素的比例。

2.1.2 ℓ0\ell_0ℓ0 约束优化

剪枝问题可以形式化为 ℓ0\ell_0ℓ0 约束优化:

min⁡wL(w)s.t.∥w∥0≤k\min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) \quad \text{s.t.} \quad \|\mathbf{w}\|_0 \leq kwminL(w)s.t.∥w∥0≤k

或等价地,拉格朗日形式:

min⁡wL(w)+λ∥w∥0\min_{\mathbf{w}} \mathcal{L}(\mathbf{w}) + \lambda \|\mathbf{w}\|_0wminL(w)+λ∥w∥0

问题 :ℓ0\ell_0ℓ0 范数是非凸、非连续的,直接优化是 NP-hard 的。

定理 2.1(ℓ0\ell_0ℓ0 优化的 NP-hardness) :求解 min⁡∥w∥0≤k∥Aw−b∥2\min_{\|\mathbf{w}\|_0 \leq k} \|\mathbf{A}\mathbf{w} - \mathbf{b}\|^2min∥w∥0≤k∥Aw−b∥2 是 NP-hard 的(可以归约到稀疏回归问题)。

2.1.3 ℓ1\ell_1ℓ1 松弛

ℓ1\ell_1ℓ1 范数 是 ℓ0\ell_0ℓ0 的最紧凸松弛:

∥w∥1=∑i=1n∣wi∣\|\mathbf{w}\|1 = \sum{i=1}^n |w_i|∥w∥1=i=1∑n∣wi∣

定理 2.2(ℓ1\ell_1ℓ1 诱导稀疏性) :在满足受限等距性质(Restricted Isometry Property, RIP) 的条件下,ℓ1\ell_1ℓ1 最小化与 ℓ0\ell_0ℓ0 最小化等价------ℓ1\ell_1ℓ1 正则化自动产生稀疏解。

RIP 条件 :矩阵 A∈Rm×n\mathbf{A} \in \mathbb{R}^{m \times n}A∈Rm×n 满足 kkk-RIP,如果存在 δk∈(0,1)\delta_k \in (0, 1)δk∈(0,1) 使得对所有 kkk-稀疏向量 x\mathbf{x}x:

(1−δk)∥x∥2≤∥Ax∥2≤(1+δk)∥x∥2(1 - \delta_k)\|\mathbf{x}\|^2 \leq \|\mathbf{A}\mathbf{x}\|^2 \leq (1 + \delta_k)\|\mathbf{x}\|^2(1−δk)∥x∥2≤∥Ax∥2≤(1+δk)∥x∥2

定理 2.3(LASSO 的稀疏恢复) :如果 A\mathbf{A}A 满足 2k2k2k-RIP 且 δ2k<2−1\delta_{2k} < \sqrt{2} - 1δ2k<2 −1,则 LASSO 问题

min⁡w12∥Aw−b∥2+λ∥w∥1\min_{\mathbf{w}} \frac{1}{2}\|\mathbf{A}\mathbf{w} - \mathbf{b}\|^2 + \lambda\|\mathbf{w}\|_1wmin21∥Aw−b∥2+λ∥w∥1

的解 w^\hat{\mathbf{w}}w^ 满足:

∥w^−w∗∥2≤C1σk(w∗)1k+C2ϵ\|\hat{\mathbf{w}} - \mathbf{w}^*\|_2 \leq C_1 \frac{\sigma_k(\mathbf{w}^*)_1}{\sqrt{k}} + C_2 \epsilon∥w^−w∗∥2≤C1k σk(w∗)1+C2ϵ

其中 σk(w)1=min⁡∥v∥0≤k∥w−v∥1\sigma_k(\mathbf{w})1 = \min{\|\mathbf{v}\|_0 \leq k} \|\mathbf{w} - \mathbf{v}\|_1σk(w)1=min∥v∥0≤k∥w−v∥1 是最佳 kkk-项逼近误差。□\square□

2.2 剪枝的信息论视角

2.2.1 Fisher 信息与参数重要性

定义 2.2(Fisher 信息矩阵) :模型参数 θ\thetaθ 的 Fisher 信息矩阵定义为:

F(θ)=Ex∼pdata∇θlog⁡p(y∣x;θ)⋅(∇θlog⁡p(y∣x;θ))TF(\theta) = \mathbb{E}{x \sim p{\text{data}}} \left\\nabla_\\theta \\log p(y\|x; \\theta) \\cdot (\\nabla_\\theta \\log p(y\|x; \\theta))\^T\\rightF(θ)=Ex∼pdata∇θlogp(y∣x;θ)⋅(∇θlogp(y∣x;θ))T

Fisher 信息矩阵的对角元素:

Fii=E(∂log⁡p(y∣x;θ)∂θi)2F_{ii} = \mathbb{E}\left\\left(\\frac{\\partial \\log p(y\|x; \\theta)}{\\partial \\theta_i}\\right)\^2\\rightFii=E(∂θi∂logp(y∣x;θ))2

衡量了参数 θi\theta_iθi 对模型输出的"重要性"。

定理 2.4(Fisher 信息与 KL 散度) :在参数 θ\thetaθ 的局部邻域内,模型输出分布的 KL 散度近似为:

DKL(pθ∥pθ+δθ)≈12δθTF(θ)δθD_{\text{KL}}(p_{\theta} \| p_{\theta + \delta\theta}) \approx \frac{1}{2} \delta\theta^T F(\theta) \delta\thetaDKL(pθ∥pθ+δθ)≈21δθTF(θ)δθ

推论 2.1 :移除参数 θi\theta_iθi(设为零)导致的 KL 散度增加约为:

ΔDKL≈12Fiiθi2\Delta D_{\text{KL}} \approx \frac{1}{2} F_{ii} \theta_i^2ΔDKL≈21Fiiθi2

因此,Fiiθi2F_{ii} \theta_i^2Fiiθi2 可以作为参数重要性的度量------同时考虑了参数值的大小(θi2\theta_i^2θi2)和参数对输出的影响(FiiF_{ii}Fii)。

2.2.2 移除参数的信息损失

定理 2.5(剪枝的信息论下界) :设原始模型参数为 θ∈Rn\theta \in \mathbb{R}^nθ∈Rn,剪枝后的参数为 θ^\hat{\theta}θ^(∥θ^∥0=k\|\hat{\theta}\|_0 = k∥θ^∥0=k),则模型输出分布的 KL 散度满足:

DKL(pθ∥pθ^)≥12∑i∈SFii(θi−θ^i)2D_{\text{KL}}(p_\theta \| p_{\hat{\theta}}) \geq \frac{1}{2} \sum_{i \in \mathcal{S}} F_{ii} (\theta_i - \hat{\theta}_i)^2DKL(pθ∥pθ^)≥21i∈S∑Fii(θi−θ^i)2

其中 S\mathcal{S}S 是被剪枝的参数集合。

最优剪枝策略 :选择使信息损失最小的 kkk 个参数保留:

S∗=arg⁡min⁡∣S∣=n−k∑i∈SFiiθi2\mathcal{S}^* = \arg\min_{|\mathcal{S}| = n-k} \sum_{i \in \mathcal{S}} F_{ii} \theta_i^2S∗=arg∣S∣=n−kmini∈S∑Fiiθi2

等价于保留 Fiiθi2F_{ii} \theta_i^2Fiiθi2 最大的 kkk 个参数。

2.3 剪枝与泛化

2.3.1 剪枝作为正则化

定理 2.6(剪枝的隐式正则化效应) :设 θ^pruned\hat{\theta}_{\text{pruned}}θ^pruned 是剪枝后的参数,θ∗\theta^*θ∗ 是未剪枝的最优参数。在一定条件下:

ELtest(θ\^pruned)≤ELtest(θ∗)+O(klog⁡nT)\mathbb{E}\\mathcal{L}_{\\text{test}}(\\hat{\\theta}_{\\text{pruned}}) \leq \mathbb{E}\\mathcal{L}_{\\text{test}}(\\theta\^\*) + O\left(\frac{k \log n}{T}\right)ELtest(θ\^pruned)≤ELtest(θ∗)+O(Tklogn)

其中 kkk 是保留的参数数量,TTT 是训练样本数。

直觉:剪枝减少了模型的有效复杂度,从而减少了泛化误差的上界。但过度剪枝会增加近似误差。

2.3.2 剪枝-精度的权衡曲线

定义 2.3(剪枝曲线) :剪枝曲线是稀疏率 sss 与模型精度 Acc(s)\text{Acc}(s)Acc(s) 之间的关系。

经验观察:剪枝曲线通常呈现三个阶段:

  1. 平坦区 (s<s1s < s_1s<s1):精度几乎不变,冗余参数被移除
  2. 缓慢下降区 (s1<s<s2s_1 < s < s_2s1<s<s2):精度缓慢下降
  3. 急剧下降区 (s>s2s > s_2s>s2):精度急剧下降,关键参数被移除

定理 2.7(剪枝曲线的理论形状) :对于线性模型 y^=wTx\hat{y} = \mathbf{w}^T \mathbf{x}y^=wTx,使用最优幅度剪枝,测试 MSE 为:

MSE(s)=MSE(0)+σ2n⋅s1−s⋅∑i∈Pwi2σi2\text{MSE}(s) = \text{MSE}(0) + \frac{\sigma^2}{n} \cdot \frac{s}{1-s} \cdot \sum_{i \in \mathcal{P}} \frac{w_i^2}{\sigma_i^2}MSE(s)=MSE(0)+nσ2⋅1−ss⋅i∈P∑σi2wi2

其中 P\mathcal{P}P 是被剪枝的参数集合,σi2\sigma_i^2σi2 是第 iii 个特征的方差。


第三章:最优脑损伤(OBD)与二阶方法

3.1 最优脑损伤(Optimal Brain Damage)

3.1.1 问题形式化

LeCun et al., 1989 提出的最优脑损伤(OBD) 是第一个系统性的剪枝方法,基于损失函数的二阶泰勒展开。

设模型参数为 θ∈Rn\theta \in \mathbb{R}^nθ∈Rn,损失函数为 L(θ)\mathcal{L}(\theta)L(θ)。在最优解 θ∗\theta^*θ∗ 处的二阶泰勒展开:

L(θ∗+δθ)≈L(θ∗)+12δθTHδθ\mathcal{L}(\theta^* + \delta\theta) \approx \mathcal{L}(\theta^*) + \frac{1}{2} \delta\theta^T H \delta\thetaL(θ∗+δθ)≈L(θ∗)+21δθTHδθ

其中 H=∇2L(θ∗)H = \nabla^2 \mathcal{L}(\theta^*)H=∇2L(θ∗) 是 Hessian 矩阵(一阶项为零,因为在最优解处梯度为零)。

3.1.2 单参数删除的误差

定理 3.1(OBD 的删除准则) :删除参数 θi\theta_iθi(设为零)导致的损失增加为:

δLi=12Hiiθi∗2\delta\mathcal{L}i = \frac{1}{2} H{ii} \theta_i^{*2}δLi=21Hiiθi∗2

证明 :设 δθ=−θi∗ei\delta\theta = -\theta_i^* \mathbf{e}_iδθ=−θi∗ei(即第 iii 个分量变为零),则:

δL=12(−θi∗ei)TH(−θi∗ei)=12θi∗2Hii\delta\mathcal{L} = \frac{1}{2} (-\theta_i^* \mathbf{e}_i)^T H (-\theta_i^* \mathbf{e}i) = \frac{1}{2} \theta_i^{*2} H{ii}δL=21(−θi∗ei)TH(−θi∗ei)=21θi∗2Hii

其中 ei\mathbf{e}_iei 是第 iii 个标准基向量。□\square□

OBD 的剪枝准则 :删除使 δLi=Hiiθi∗2\delta\mathcal{L}i = H{ii} \theta_i^{*2}δLi=Hiiθi∗2 最小的参数。

3.1.3 对角近似

问题 :计算完整 Hessian H∈Rn×nH \in \mathbb{R}^{n \times n}H∈Rn×n 的复杂度为 O(n2)O(n^2)O(n2)(存储)和 O(n3)O(n^3)O(n3)(求逆),对于大模型不可行。

OBD 的近似 :假设 Hessian 是对角 的------即 H≈diag(H11,...,Hnn)H \approx \text{diag}(H_{11}, \dots, H_{nn})H≈diag(H11,...,Hnn)。

这等价于假设各参数的二阶导数之间不相关:

∂2L∂θi∂θj≈0,i≠j\frac{\partial^2 \mathcal{L}}{\partial \theta_i \partial \theta_j} \approx 0, \quad i \neq j∂θi∂θj∂2L≈0,i=j

对角 Hessian 的计算

Hii=∂2L∂θi2=E(∂log⁡p(y∣x;θ)∂θi)2+E∂2log⁡p(y∣x;θ)∂θi2H_{ii} = \frac{\partial^2 \mathcal{L}}{\partial \theta_i^2} = \mathbb{E}\left\\left(\\frac{\\partial \\log p(y\|x; \\theta)}{\\partial \\theta_i}\\right)\^2\\right + \mathbb{E}\left\\frac{\\partial\^2 \\log p(y\|x; \\theta)}{\\partial \\theta_i\^2}\\rightHii=∂θi2∂2L=E(∂θi∂logp(y∣x;θ))2+E∂θi2∂2logp(y∣x;θ)

第一项是 Fisher 信息的对角元素,第二项在最优解处通常很小(可以忽略)。

推论 3.1:在对角近似下,OBD 的剪枝准则近似为:

δLi≈12Fiiθi∗2\delta\mathcal{L}i \approx \frac{1}{2} F{ii} \theta_i^{*2}δLi≈21Fiiθi∗2

这与 Fisher 信息的参数重要性度量一致。

3.2 最优脑外科(Optimal Brain Surgeon)

3.2.1 OBS 的框架

Hassibi & Stork, 1993 提出的最优脑外科(OBS) 改进了 OBD------删除一个参数后,允许其他参数自由调整以补偿误差。

问题 :删除参数 θi\theta_iθi 后,找到最优的参数调整 δθ\delta\thetaδθ:

min⁡δθδθTHδθs.t.eiT(θ∗+δθ)=0\min_{\delta\theta} \delta\theta^T H \delta\theta \quad \text{s.t.} \quad \mathbf{e}_i^T (\theta^* + \delta\theta) = 0δθminδθTHδθs.t.eiT(θ∗+δθ)=0

约束条件确保 θi+δθi=0\theta_i + \delta\theta_i = 0θi+δθi=0(第 iii 个参数被设为零)。

3.2.2 OBS 的最优解

定理 3.2(OBS 的最优调整):使用拉格朗日乘子法,最优调整为:

δθ∗=−θi∗H−1iiH−1ei\delta\theta^* = -\frac{\theta_i^*}{H\^{-1}_{ii}} H^{-1} \mathbf{e}_iδθ∗=−H−1iiθi∗H−1ei

对应的损失增加为:

δLiOBS=θi∗22H−1ii\delta\mathcal{L}i^{\text{OBS}} = \frac{\theta_i^{*2}}{2H\^{-1}{ii}}δLiOBS=2H−1iiθi∗2

证明:拉格朗日函数为:

L(δθ,λ)=12δθTHδθ+λ(eiTθ∗+eiTδθ)\mathcal{L}(\delta\theta, \lambda) = \frac{1}{2} \delta\theta^T H \delta\theta + \lambda (\mathbf{e}_i^T \theta^* + \mathbf{e}_i^T \delta\theta)L(δθ,λ)=21δθTHδθ+λ(eiTθ∗+eiTδθ)

对 δθ\delta\thetaδθ 求导并令其为零:

Hδθ+λei=0  ⟹  δθ=−λH−1eiH \delta\theta + \lambda \mathbf{e}_i = 0 \implies \delta\theta = -\lambda H^{-1} \mathbf{e}_iHδθ+λei=0⟹δθ=−λH−1ei

代入约束:

eiTθ∗−λeiTH−1ei=0  ⟹  λ=θi∗H−1ii\mathbf{e}_i^T \theta^* - \lambda \mathbf{e}_i^T H^{-1} \mathbf{e}i = 0 \implies \lambda = \frac{\theta_i^*}{H\^{-1}{ii}}eiTθ∗−λeiTH−1ei=0⟹λ=H−1iiθi∗

因此 δθ∗=−θi∗H−1iiH−1ei\delta\theta^* = -\frac{\theta_i^*}{H\^{-1}_{ii}} H^{-1} \mathbf{e}_iδθ∗=−H−1iiθi∗H−1ei。

损失增加:

δL=12(δθ∗)THδθ∗=12θi∗2H−1ii2eiTH−1HH−1ei=θi∗22H−1ii\delta\mathcal{L} = \frac{1}{2} (\delta\theta^*)^T H \delta\theta^* = \frac{1}{2} \frac{\theta_i^{*2}}{H\^{-1}_{ii}^2} \mathbf{e}_i^T H^{-1} H H^{-1} \mathbf{e}i = \frac{\theta_i^{*2}}{2H\^{-1}{ii}}δL=21(δθ∗)THδθ∗=21H−1ii2θi∗2eiTH−1HH−1ei=2H−1iiθi∗2

□\square□

3.2.3 OBS vs OBD

定理 3.3(OBS 优于 OBD) :对于任意参数 θi\theta_iθi:

δLiOBS=θi∗22H−1ii≤12Hiiθi∗2=δLiOBD\delta\mathcal{L}i^{\text{OBS}} = \frac{\theta_i^{*2}}{2H\^{-1}{ii}} \leq \frac{1}{2} H_{ii} \theta_i^{*2} = \delta\mathcal{L}_i^{\text{OBD}}δLiOBS=2H−1iiθi∗2≤21Hiiθi∗2=δLiOBD

等号成立当且仅当 HHH 是对角矩阵。

证明 :对于正定矩阵 HHH,H−1ii≥1/HiiH\^{-1}{ii} \geq 1/H{ii}H−1ii≥1/Hii(由 Cauchy-Schwarz 不等式)。□\square□

直觉:OBS 允许其他参数补偿被删除参数的功能,因此损失更小。OBD 假设其他参数不变,因此损失更大。

3.2.4 OBS 的高效实现

问题 :OBS 需要 H−1H^{-1}H−1,计算复杂度为 O(n3)O(n^3)O(n3)。

高效 OBS(EOBS):使用 Sherman-Morrison 公式进行增量更新。

设当前 Hessian 逆为 H−1H^{-1}H−1,删除参数 θi\theta_iθi 后,更新后的 Hessian(去掉第 iii 行和列)的逆可以通过以下公式计算:

H−1iˉ,iˉnew=H−1iˉ,iˉ−H−1iˉ,iH−1i,iˉH−1iiH\^{-1}{\bar{i},\bar{i}}^{\text{new}} = H\^{-1}{\bar{i},\bar{i}} - \frac{H\^{-1}{\bar{i},i} H\^{-1}{i,\bar{i}}}{H\^{-1}_{ii}}H−1iˉ,iˉnew=H−1iˉ,iˉ−H−1iiH−1iˉ,iH−1i,iˉ

其中 iˉ\bar{i}iˉ 表示除第 iii 个之外的所有索引。


第二部分:非结构化剪枝方法


第四章:幅度剪枝与迭代剪枝

4.1 单次幅度剪枝

4.1.1 基本定义

幅度剪枝(Magnitude Pruning) 是最简单的剪枝方法------删除绝对值最小的权重:

w^i={wiif ∣wi∣>τ0otherwise\hat{w}_i = \begin{cases} w_i & \text{if } |w_i| > \tau \\ 0 & \text{otherwise} \end{cases}w^i={wi0if ∣wi∣>τotherwise

其中阈值 τ\tauτ 由目标稀疏率 sss 决定:τ=Percentile(∣w∣,s×100)\tau = \text{Percentile}(|\mathbf{w}|, s \times 100)τ=Percentile(∣w∣,s×100)。

4.1.2 幅度剪枝的理论基础

定理 4.1(幅度剪枝的最优性条件) :对于线性模型 y^=wTx\hat{y} = \mathbf{w}^T \mathbf{x}y^=wTx,x∼N(0,I)\mathbf{x} \sim \mathcal{N}(0, I)x∼N(0,I),使用对角 Fisher 信息近似,幅度剪枝等价于最小化 Fisher 信息加权的参数重要性:

importance(wi)=Fiiwi2=Exi2⋅wi2=wi2\text{importance}(w_i) = F_{ii} w_i^2 = \mathbb{E}x_i\^2 \cdot w_i^2 = w_i^2importance(wi)=Fiiwi2=Exi2⋅wi2=wi2

因为 Fii=Exi2=1F_{ii} = \mathbb{E}x_i\^2 = 1Fii=Exi2=1(对于标准高斯输入)。

推论 4.1:当输入特征是独立同分布的标准高斯时,幅度剪枝是最优的(在对角 Fisher 近似下)。

但实际中:输入特征通常不是独立同分布的,幅度剪枝可能不是最优的。

4.1.3 幅度剪枝的局限

问题 1:权重大小不等于重要性

某些权重可能很小,但对模型输出有重要影响(例如,通过与其他大权重的交互)。

问题 2:不考虑权重之间的相关性

幅度剪枝独立评估每个权重,忽略了权重之间的协同效应。

问题 3:一次性剪枝精度损失大

高稀疏率下,单次剪枝的精度损失远大于迭代剪枝。

4.2 迭代剪枝(Iterative Pruning)

4.2.1 基本思想

迭代剪枝 (也称为 逐步剪枝贪心剪枝):每次只剪枝一小部分权重,然后微调恢复精度,重复直到达到目标稀疏率。

算法 4.1(迭代幅度剪枝)

复制代码
输入:预训练模型 θ_0,目标稀疏率 s,每次剪枝比例 p
输出:稀疏模型 θ_sparse

θ = θ_0
当前稀疏率 = 0
while 当前稀疏率 < s:
    1. 剪枝:删除 |θ_i| 最小的 p 比例参数
    2. 微调:在训练数据上微调剩余参数
    3. 更新当前稀疏率
返回 θ

4.2.2 迭代剪枝的理论分析

定理 4.2(迭代剪枝的误差上界) :设每次剪枝比例为 ppp,微调后损失恢复率为 ρ\rhoρ(即微调后损失为剪枝前的 ρ\rhoρ 倍),LLL 次迭代后的总损失为:

LL≤L0⋅∏l=1L(1+(1−ρ)p)\mathcal{L}_L \leq \mathcal{L}0 \cdot \prod{l=1}^{L} (1 + (1-\rho) p)LL≤L0⋅l=1∏L(1+(1−ρ)p)

当 ppp 很小且 ρ\rhoρ 接近 1 时,LL≈L0⋅(1+(1−ρ)Lp)\mathcal{L}_L \approx \mathcal{L}_0 \cdot (1 + (1-\rho) L p)LL≈L0⋅(1+(1−ρ)Lp)------损失随迭代次数线性增长。

与单次剪枝的比较 :单次剪枝率 s=1−(1−p)Ls = 1 - (1-p)^Ls=1−(1−p)L 时,单次剪枝的损失远大于迭代剪枝。

4.2.3 剪枝调度(Pruning Schedule)

问题 :每次剪枝的比例 ppp 如何选择?

常见调度

调度 公式 说明
线性 pl=p0p_l = p_0pl=p0 每次固定比例
多项式 pl=p0⋅(1−l/L)αp_l = p_0 \cdot (1 - l/L)^\alphapl=p0⋅(1−l/L)α 逐渐减小比例
指数 pl=p0⋅γlp_l = p_0 \cdot \gamma^lpl=p0⋅γl 指数衰减
余弦 pl=p0⋅(1+cos⁡(πl/L))/2p_l = p_0 \cdot (1 + \cos(\pi l/L))/2pl=p0⋅(1+cos(πl/L))/2 余弦退火

定理 4.3(最优剪枝调度):在一定假设下,余弦调度在相同总剪枝率下达到最低的最终损失。

4.3 渐进式幅度剪枝(GMP)

4.3.1 渐进式剪枝

渐进式幅度剪枝(Gradual Magnitude Pruning, GMP) (Zhu & Gupta, 2017)在训练过程中连续地增加稀疏率。

稀疏率调度

s(t)=sf+(si−sf)(1−t−titf−ti)3s(t) = s_f + (s_i - s_f) \left(1 - \frac{t - t_i}{t_f - t_i}\right)^3s(t)=sf+(si−sf)(1−tf−tit−ti)3

其中:

  • sis_isi:初始稀疏率(通常为 0)
  • sfs_fsf:最终稀疏率
  • tit_iti:开始剪枝的训练步数
  • tft_ftf:结束剪枝的训练步数

4.3.2 GMP 的理论基础

定理 4.4(GMP 的收敛性):在平滑性和有界梯度假设下,GMP 训练的模型收敛到损失函数的一个稳定点,其梯度范数满足:

∥∇L(θT)∥2≤O(1T+sfT)\|\nabla \mathcal{L}(\theta_T)\|^2 \leq O\left(\frac{1}{\sqrt{T}} + \frac{s_f}{\sqrt{T}}\right)∥∇L(θT)∥2≤O(T 1+T sf)

直觉:渐进式剪枝允许模型在训练过程中逐步适应稀疏约束,比一次性剪枝更温和。


第五章:SNIP、GraSP 与一次射击剪枝

5.1 SNIP:单次剪枝识别重要连接

5.1.1 核心思想

SNIP (Lee et al., 2019)在训练之前就确定稀疏模式,基于单次前向-反向传播的梯度信息。

连接灵敏度(Connection Sensitivity)

si=∣∂L∂ci⋅ci∣s_i = \left|\frac{\partial \mathcal{L}}{\partial c_i} \cdot c_i\right|si= ∂ci∂L⋅ci

其中 ci∈{0,1}c_i \in \{0, 1\}ci∈{0,1} 是连接掩码(1 表示保留,0 表示删除),L\mathcal{L}L 是损失函数。

直觉 :sis_isi 衡量了如果将连接 iii 从"存在"变为"不存在",损失会变化多少。

5.1.2 SNIP 的数学推导

定理 5.1(SNIP 的梯度-幅度准则):连接灵敏度可以简化为:

si=∣wi⋅gi∣s_i = |w_i \cdot g_i|si=∣wi⋅gi∣

其中 gi=∂L/∂wig_i = \partial \mathcal{L} / \partial w_igi=∂L/∂wi 是权重 wiw_iwi 的梯度。

证明 :设 wic=ci⋅wiw_i^c = c_i \cdot w_iwic=ci⋅wi,则:

∂L∂ci=∂L∂wic⋅∂wic∂ci=gi⋅wi\frac{\partial \mathcal{L}}{\partial c_i} = \frac{\partial \mathcal{L}}{\partial w_i^c} \cdot \frac{\partial w_i^c}{\partial c_i} = g_i \cdot w_i∂ci∂L=∂wic∂L⋅∂ci∂wic=gi⋅wi

因此 si=∣gi⋅wi∣s_i = |g_i \cdot w_i|si=∣gi⋅wi∣。□\square□

SNIP 算法

复制代码
1. 采样一批数据
2. 前向传播,计算损失
3. 反向传播,计算梯度 g
4. 计算灵敏度 s_i = |w_i * g_i|
5. 保留 s_i 最大的 k 个连接
6. 使用保留的连接进行正常训练

5.1.3 SNIP 的理论性质

定理 5.2(SNIP 与 Fisher 信息的关系) :连接灵敏度 si=∣wigi∣s_i = |w_i g_i|si=∣wigi∣ 是 Fisher 信息加权参数重要性 Fiiwi2F_{ii} w_i^2Fiiwi2 的单样本无偏估计:

Esi2=Ewi2gi2=wi2Egi2=wi2Fii\mathbb{E}s_i\^2 = \mathbb{E}w_i\^2 g_i\^2 = w_i^2 \mathbb{E}g_i\^2 = w_i^2 F_{ii}Esi2=Ewi2gi2=wi2Egi2=wi2Fii

推论:SNIP 可以看作是基于单样本 Fisher 信息的参数重要性评估。

5.2 GraSP:梯度信号剪枝

5.2.1 核心思想

GraSP (Wang et al., 2020)改进了 SNIP------不仅考虑梯度的大小,还考虑梯度的方向

定义 5.1(GraSP 评分)

GraSPi=−gi⋅(Hg)i\text{GraSP}_i = -g_i \cdot (\mathbf{H}\mathbf{g})_iGraSPi=−gi⋅(Hg)i

其中 H\mathbf{H}H 是 Hessian 矩阵,g\mathbf{g}g 是梯度向量。

直觉 :GraSP 评分衡量了删除参数 iii 后,梯度信号的变化。如果 GraSPi>0\text{GraSP}_i > 0GraSPi>0,删除该参数会增强 梯度信号(有利于训练);如果 GraSPi<0\text{GraSP}_i < 0GraSPi<0,删除该参数会减弱梯度信号。

5.2.2 GraSP 的理论推导

定理 5.3(GraSP 的最优性):GraSP 评分等价于最小化删除参数后的梯度范数变化:

GraSPi=−∂∥∇L∥2∂ci∣c=1\text{GraSP}i = -\frac{\partial \|\nabla \mathcal{L}\|^2}{\partial c_i}\Bigg|{c=\mathbf{1}}GraSPi=−∂ci∂∥∇L∥2 c=1

证明 :设掩码为 ccc,参数为 wic=ciwiw_i^c = c_i w_iwic=ciwi。梯度范数为:

∥∇L∥2=∑igi2\|\nabla \mathcal{L}\|^2 = \sum_i g_i^2∥∇L∥2=i∑gi2

对 cic_ici 求导(使用链式法则和 Hessian):

∂∥∇L∥2∂ci=2∑jgj∂gj∂ci=2∑jgj∂2L∂wj∂wi⋅wi=2wi(Hg)i\frac{\partial \|\nabla \mathcal{L}\|^2}{\partial c_i} = 2 \sum_j g_j \frac{\partial g_j}{\partial c_i} = 2 \sum_j g_j \frac{\partial^2 \mathcal{L}}{\partial w_j \partial w_i} \cdot w_i = 2 w_i (\mathbf{H}\mathbf{g})_i∂ci∂∥∇L∥2=2j∑gj∂ci∂gj=2j∑gj∂wj∂wi∂2L⋅wi=2wi(Hg)i

因此 GraSPi=−gi(Hg)i\text{GraSP}_i = -g_i (\mathbf{H}\mathbf{g})_iGraSPi=−gi(Hg)i(忽略常数因子)。□\square□

5.2.3 GraSP 的高效计算

问题 :计算 Hg\mathbf{H}\mathbf{g}Hg 需要 Hessian-向量乘积,复杂度为 O(n2)O(n^2)O(n2)。

高效实现 :使用 Pearlmutter 技巧------Hessian-向量乘积可以在一次前向-反向传播中计算:

Hv=∇θ(gTv)\mathbf{H}\mathbf{v} = \nabla_\theta (\mathbf{g}^T \mathbf{v})Hv=∇θ(gTv)

其中 v\mathbf{v}v 是任意向量。

5.3 SparseGPT:大语言模型的一次射击剪枝

5.3.1 核心思想

SparseGPT(Frantar & Alistarh, 2023)将 GPTQ 的思想从量化推广到剪枝------通过 Hessian 信息进行最优权重删除和误差补偿。

问题形式化 :对于权重矩阵 W∈Rm×nW \in \mathbb{R}^{m \times n}W∈Rm×n 和校准输入 XXX:

min⁡W^∥(W−W^)X∥F2s.t.∥W^∥0≤k\min_{\hat{W}} \|(W - \hat{W})X\|_F^2 \quad \text{s.t.} \quad \|\hat{W}\|_0 \leq kW^min∥(W−W^)X∥F2s.t.∥W^∥0≤k

5.3.2 SparseGPT 的算法

算法 5.1(SparseGPT)

复制代码
输入:权重 W (m×n),校准输入 X,目标稀疏率 s
输出:稀疏权重 W_sparse

1. 计算 Hessian: H = 2/n * X^T X + λI
2. 计算 H^{-1}
3. 对每一行 i = 1, ..., m:
   a. 初始化: w = W[i, :], δ = 0
   b. 对每一列 j = 1, ..., n:
      - 如果 |w[j]| < τ (阈值):
        * 删除: w[j] = 0
        * 补偿: w[j+1:n] -= (w[j] / H^{-1}[j,j]) * H^{-1}[j, j+1:n]
      - 否则:
        * 保留: W_sparse[i, j] = w[j]
4. 返回 W_sparse

5.3.3 SparseGPT 与 GPTQ 的统一

定理 5.4(SparseGPT-GPTQ 统一框架):SparseGPT 和 GPTQ 都基于相同的最优脑外科框架,区别在于:

  • GPTQ:将权重量化到最近的量化级别,补偿误差
  • SparseGPT:将权重设为零(或保留),补偿误差

两者可以统一为:对每个权重,选择一个"目标值"(量化级别或零),然后用 Hessian 信息补偿误差。

5.4 Wanda:权重与激活剪枝

5.4.1 核心思想

Wanda(Sun et al., 2024)提出了一种更简单但同样有效的方法:同时考虑权重的幅度和对应输入激活的幅度。

评分准则

scoreij=∣Wij∣⋅∥Xj∥2\text{score}{ij} = |W{ij}| \cdot \|X_j\|_2scoreij=∣Wij∣⋅∥Xj∥2

其中 XjX_jXj 是输入激活的第 jjj 个通道(在校准数据上取平均)。

直觉 :权重 WijW_{ij}Wij 的重要性取决于:

  • 权重本身的大小 ∣Wij∣|W_{ij}|∣Wij∣
  • 对应输入通道的"活跃程度" ∥Xj∥2\|X_j\|_2∥Xj∥2

5.4.2 Wanda 与 OBS 的关系

定理 5.5(Wanda 作为 OBS 的近似):在对角 Hessian 近似下,OBS 的参数重要性为:

importanceijOBS=Hjj⋅Wij2\text{importance}{ij}^{\text{OBS}} = H{jj} \cdot W_{ij}^2importanceijOBS=Hjj⋅Wij2

其中 Hjj=2T∑tXtj2∝∥Xj∥22H_{jj} = \frac{2}{T}\sum_t X_{tj}^2 \propto \|X_j\|_2^2Hjj=T2∑tXtj2∝∥Xj∥22。

因此:

importanceijOBS∝∥Xj∥22⋅Wij2=(∥Xj∥2⋅∣Wij∣)2\text{importance}_{ij}^{\text{OBS}} \propto \|X_j\|2^2 \cdot W{ij}^2 = (\|X_j\|2 \cdot |W{ij}|)^2importanceijOBS∝∥Xj∥22⋅Wij2=(∥Xj∥2⋅∣Wij∣)2

Wanda 的评分 scoreij=∣Wij∣⋅∥Xj∥2\text{score}{ij} = |W{ij}| \cdot \|X_j\|_2scoreij=∣Wij∣⋅∥Xj∥2 正是 OBS 重要性的平方根------单调等价。□\square□


第六章:彩票假说(Lottery Ticket Hypothesis)

6.1 核心陈述

彩票假说(Lottery Ticket Hypothesis, LTH)(Frankle & Carlin, 2019):

猜想 6.1(彩票假说) :对于一个随机初始化的密集网络 f(θ0)f(\theta_0)f(θ0),存在一个子网络 f(θ0⊙m)f(\theta_0 \odot m)f(θ0⊙m)(其中 mmm 是二值掩码),该子网络从相同的初始化开始训练,可以在相同或更少的迭代次数内达到与原始密集网络相当的精度。

称这样的子网络为中奖彩票(winning ticket)

6.1.1 形式化

设密集网络参数为 θ0∈Rn\theta_0 \in \mathbb{R}^nθ0∈Rn(随机初始化),训练后的参数为 θT\theta_TθT。掩码 m∈{0,1}nm \in \{0, 1\}^nm∈{0,1}n。

中奖彩票 :存在 mmm 使得:

L(train(m⊙θ0,T))≤L(θT)\mathcal{L}(\text{train}(m \odot \theta_0, T)) \leq \mathcal{L}(\theta_T)L(train(m⊙θ0,T))≤L(θT)

其中 train(θ0,T)\text{train}(\theta_0, T)train(θ0,T) 表示从 θ0\theta_0θ0 开始训练 TTT 步的结果。

6.2 迭代剪枝找到中奖彩票

6.2.1 迭代幅度剪枝(IMP)

算法 6.1(迭代幅度剪枝,IMP)

复制代码
输入:随机初始化 θ_0,训练步数 T,剪枝比例 p,迭代次数 L
输出:中奖彩票 (θ_0 ⊙ m, m)

θ = θ_0
m = 1  (全 1 掩码)
for l = 1, ..., L:
    1. 训练: θ_T = train(m ⊙ θ, T)
    2. 剪枝: 删除 |θ_T| 最小的 p 比例参数,更新 m
    3. 重置: θ = θ_0  (回到初始值)
返回 (m ⊙ θ_0, m)

关键步骤 :第 3 步将参数重置为初始值 θ0\theta_0θ0------这是彩票假说的核心要求。

6.2.2 为什么需要重置?

实验观察 :如果不重置为 θ0\theta_0θ0,而是继续从当前参数训练,精度会显著下降。

理论解释 :θ0\theta_0θ0 的某些特定初始化值恰好适合后续的训练动态。重置确保了这些"幸运"的初始化值被保留。

6.2.3 彩票假说的理论分析

定理 6.1(中奖彩票的存在性,简化版) :对于过参数化的线性网络 f(θ)=θTxf(\theta) = \theta^T xf(θ)=θTx,使用梯度下降训练,存在一个稀疏率为 sss 的子网络,其中:

s≤1−O(deffn)s \leq 1 - O\left(\frac{d_{\text{eff}}}{n}\right)s≤1−O(ndeff)

使得该子网络可以从 θ0\theta_0θ0 开始训练到与密集网络相同的精度。

证明思路 :过参数化意味着 Hessian 的有效秩 deff≪nd_{\text{eff}} \ll ndeff≪n。沿着 Hessian 的零空间方向,参数可以被剪枝而不影响训练动态。□\square□

6.3 彩票假说的局限与扩展

6.3.1 大规模模型的挑战

问题:IMP 需要多次从头训练,对于大模型(如 GPT-3)计算成本过高。

解决方案

  • 一次性剪枝(SparseGPT、Wanda):不需要重新训练
  • 转移彩票:在小模型上找到彩票,迁移到大模型
  • 动态稀疏训练(DST):在训练过程中动态调整稀疏模式

6.3.2 稀疏迁移学习

定理 6.2(彩票的可迁移性):在源任务上找到的中奖彩票,在目标任务上的表现优于随机稀疏模式,但不如在目标任务上直接找到的彩票。

经验观察:迁移的彩票提供了一个良好的初始化,但仍需要少量微调。


第三部分:结构化剪枝方法


第七章:通道剪枝与滤波器剪枝

7.1 通道剪枝的基本思想

7.1.1 为什么需要结构化剪枝?

非结构化剪枝的问题

  • 稀疏矩阵在通用硬件上难以加速
  • 需要专门的稀疏计算库和硬件
  • 实际加速比远低于理论值

结构化剪枝的优势

  • 剪枝后的模型仍然是密集的(只是更小)
  • 可以直接使用标准的密集计算库
  • 在通用硬件上实现真正的加速

7.1.2 通道剪枝的形式化

对于线性层 Y=XWY = XWY=XW,W∈Rm×nW \in \mathbb{R}^{m \times n}W∈Rm×n,通道剪枝移除 WWW 的整行或整列:

  • 输入通道剪枝 :移除 WWW 的列,同时移除对应输入特征
  • 输出通道剪枝 :移除 WWW 的行,同时移除对应输出特征

剪枝后的层 :Y′=X′W′Y' = X' W'Y′=X′W′,其中 W′∈Rm′×n′W' \in \mathbb{R}^{m' \times n'}W′∈Rm′×n′,m′≤mm' \leq mm′≤m,n′≤nn' \leq nn′≤n。

7.2 通道重要性度量

7.2.1 基于幅度的度量

ℓ1\ell_1ℓ1 范数

importance(ci)=∥Wi,:∥1=∑j∣Wij∣\text{importance}(c_i) = \|W_{i,:}\|1 = \sum_j |W{ij}|importance(ci)=∥Wi,:∥1=j∑∣Wij∣

ℓ2\ell_2ℓ2 范数

importance(ci)=∥Wi,:∥2=∑jWij2\text{importance}(c_i) = \|W_{i,:}\|2 = \sqrt{\sum_j W{ij}^2}importance(ci)=∥Wi,:∥2=j∑Wij2

定理 7.1(ℓ2\ell_2ℓ2 范数与输出方差) :对于随机输入 XXX,EX=0\mathbb{E}X = 0EX=0,Cov(X)=I\text{Cov}(X) = ICov(X)=I,第 iii 个输出的方差为:

Var(Yi)=∥Wi,:∥22\text{Var}(Y_i) = \|W_{i,:}\|_2^2Var(Yi)=∥Wi,:∥22

因此,ℓ2\ell_2ℓ2 范数小的通道对输出的影响也小。

7.2.2 基于梯度的度量

Taylor 展开 :移除通道 iii 的损失变化为:

δLi≈∣∂L∂Yi⋅Yi∣=∣gi⋅Yi∣\delta\mathcal{L}_i \approx \left|\frac{\partial \mathcal{L}}{\partial Y_i} \cdot Y_i\right| = |g_i \cdot Y_i|δLi≈ ∂Yi∂L⋅Yi =∣gi⋅Yi∣

其中 gi=∂L/∂Yig_i = \partial \mathcal{L} / \partial Y_igi=∂L/∂Yi 是输出的梯度。

批量平均

importance(ci)=1T∑t=1T∣gi,t⋅Yi,t∣\text{importance}(c_i) = \frac{1}{T} \sum_{t=1}^{T} |g_{i,t} \cdot Y_{i,t}|importance(ci)=T1t=1∑T∣gi,t⋅Yi,t∣

7.2.3 基于 Hessian 的度量

OBS 风格

importance(ci)=∥Wi,:∥22H−1ii\text{importance}(c_i) = \frac{\|W_{i,:}\|2^2}{H\^{-1}{ii}}importance(ci)=H−1ii∥Wi,:∥22

7.3 通道剪枝的算法

7.3.1 贪心剪枝

算法 7.1(贪心通道剪枝)

复制代码
输入:模型,目标通道数 k
输出:剪枝后的模型

1. 计算每个通道的重要性
2. 按重要性排序
3. 删除最不重要的 m-k 个通道
4. 微调模型

7.3.2 迭代通道剪枝

算法 7.2(迭代通道剪枝)

复制代码
输入:模型,目标稀疏率 s,每次剪枝比例 p
输出:剪枝后的模型

while 当前稀疏率 < s:
    1. 计算通道重要性
    2. 剪枝 p 比例的通道
    3. 微调
返回模型

7.4 通道剪枝的误差分析

定理 7.2(通道剪枝的输出误差) :设原始层为 Y=XWY = XWY=XW,剪枝后为 Y′=X′W′Y' = X'W'Y′=X′W′,则:

∥Y−Y′∥F≤∥X∥F⋅∥W−W′∥F\|Y - Y'\|_F \leq \|X\|_F \cdot \|W - W'\|_F∥Y−Y′∥F≤∥X∥F⋅∥W−W′∥F

其中 W′W'W′ 是将被剪枝通道设为零后的权重。

推论 7.1 :如果被剪枝通道的权重范数之和为 ϵ\epsilonϵ,则输出误差上界为 ∥X∥F⋅ϵ\|X\|_F \cdot \epsilon∥X∥F⋅ϵ。


第八章:注意力头剪枝与层剪枝

8.1 注意力头剪枝

8.1.1 Transformer 中的冗余

观察(Michel et al., 2019):在训练好的 Transformer 中,大量的注意力头是冗余的------移除它们对模型性能影响很小。

定义 8.1(注意力头重要性) :注意力头 hhh 的重要性定义为:

Ih=∣∂L∂αh⋅αh∣I_h = \left|\frac{\partial \mathcal{L}}{\partial \alpha_h} \cdot \alpha_h\right|Ih= ∂αh∂L⋅αh

其中 αh\alpha_hαh 是该头的注意力权重。

8.1.2 头重要性的计算

定理 8.1(头重要性的简化) :在多头注意力中,头 hhh 的重要性可以近似为:

Ih≈1T∑t=1T∥∂L∂Ah(t)⊙Ah(t)∥FI_h \approx \frac{1}{T} \sum_{t=1}^{T} \left\|\frac{\partial \mathcal{L}}{\partial A_h^{(t)}} \odot A_h^{(t)}\right\|_FIh≈T1t=1∑T ∂Ah(t)∂L⊙Ah(t) F

其中 Ah(t)A_h^{(t)}Ah(t) 是第 ttt 个样本的注意力权重矩阵。

8.1.3 交叉头剪枝

问题:独立评估每个头的重要性可能不是最优的------某些头可能单独看不重要,但与其他头组合时很重要。

解决方案 :考虑头之间的交互------使用联合重要性

Ih1,h2=Ih1+Ih2+interaction(h1,h2)I_{h_1, h_2} = I_{h_1} + I_{h_2} + \text{interaction}(h_1, h_2)Ih1,h2=Ih1+Ih2+interaction(h1,h2)

8.2 层剪枝

8.2.1 层重要性度量

方法 1:基于输出变化

importance(l)=E∥fl(x)−fl−1(x)∥2\text{importance}(l) = \mathbb{E}\left\\\|f_l(x) - f_{l-1}(x)\\\|\^2\\rightimportance(l)=E∥fl(x)−fl−1(x)∥2

即第 lll 层对表示的改变量。

方法 2:基于角度距离

importance(l)=1−Efl(x)Tfl−1(x)∥fl(x)∥⋅∥fl−1(x)∥\text{importance}(l) = 1 - \frac{\mathbb{E}f_l(x)\^T f_{l-1}(x)}{\|f_l(x)\| \cdot \|f_{l-1}(x)\|}importance(l)=1−∥fl(x)∥⋅∥fl−1(x)∥Efl(x)Tfl−1(x)

方法 3:基于逐层损失

importance(l)=L(f1,...,fl,id,...,id)−L(f1,...,fL)\text{importance}(l) = \mathcal{L}(f_1, \dots, f_l, \text{id}, \dots, \text{id}) - \mathcal{L}(f_1, \dots, f_L)importance(l)=L(f1,...,fl,id,...,id)−L(f1,...,fL)

其中 id\text{id}id 表示恒等映射(跳过该层)。

8.2.2 层剪枝的算法

算法 8.1(逐层剪枝)

复制代码
输入:L 层模型,目标层数 L'
输出:剪枝后的模型

1. 计算每层的重要性
2. 按重要性排序
3. 删除最不重要的 L-L' 层
4. 微调

注意:层剪枝通常只能删除少量层(10%-30%),否则精度会急剧下降。

8.3 大语言模型的结构化剪枝

8.3.1 LLM-Pruner

LLM-Pruner(Ma et al., 2023)提出了一种针对大语言模型的结构化剪枝框架:

  1. 依赖图分析:自动识别 LLM 中的可剪枝结构
  2. 梯度重要性:使用一阶梯度和 Hessian 信息评估结构重要性
  3. 快速恢复:使用 LoRA 进行高效微调

8.3.2 ShortGPT

ShortGPT(Men et al., 2024)发现 LLM 中存在大量的层冗余:

观察 :在 LLM 中,相邻层的表示高度相似------可以用余弦相似度衡量:

sim(l)=cos⁡(hl,hl−1)=hlThl−1∥hl∥⋅∥hl−1∥\text{sim}(l) = \cos(h_l, h_{l-1}) = \frac{h_l^T h_{l-1}}{\|h_l\| \cdot \|h_{l-1}\|}sim(l)=cos(hl,hl−1)=∥hl∥⋅∥hl−1∥hlThl−1

当 sim(l)≈1\text{sim}(l) \approx 1sim(l)≈1 时,第 $l} 层几乎不改变表示,可以安全删除。


第四部分:稀疏格式与稀疏计算


第九章:稀疏矩阵存储格式

9.1 CSR(Compressed Sparse Row)

9.1.1 格式定义

CSR(压缩稀疏行) 格式使用三个数组存储稀疏矩阵 A∈Rm×nA \in \mathbb{R}^{m \times n}A∈Rm×n:

  • values :非零元素的值,长度 nnznnznnz
  • col_indices :每个非零元素的列索引,长度 nnznnznnz
  • row_ptr :每行的起始位置,长度 m+1m+1m+1

示例

A=(100203004050)A = \begin{pmatrix} 1 & 0 & 0 & 2 \\ 0 & 3 & 0 & 0 \\ 4 & 0 & 5 & 0 \end{pmatrix}A= 104030005200

  • values = 1, 2, 3, 4, 5
  • col_indices = 0, 3, 1, 0, 2
  • row_ptr = 0, 2, 3, 5

9.1.2 存储复杂度

定理 9.1(CSR 的存储):CSR 格式需要:

存储=nnz⋅(bval+bidx)+(m+1)⋅bidx\text{存储} = nnz \cdot (b_{\text{val}} + b_{\text{idx}}) + (m+1) \cdot b_{\text{idx}}存储=nnz⋅(bval+bidx)+(m+1)⋅bidx

其中 bvalb_{\text{val}}bval 是值的比特数,bidxb_{\text{idx}}bidx 是索引的比特数。

对于 FP32 值和 INT32 索引:存储=nnz⋅8+(m+1)⋅4\text{存储} = nnz \cdot 8 + (m+1) \cdot 4存储=nnz⋅8+(m+1)⋅4 字节。

盈亏平衡点 :当 nnz<m⋅n⋅bvalbval+bidx+bidx/mnnz < \frac{m \cdot n \cdot b_{\text{val}}}{b_{\text{val}} + b_{\text{idx}} + b_{\text{idx}}/m}nnz<bval+bidx+bidx/mm⋅n⋅bval 时,CSR 比密集存储更省空间。

对于 FP32、大矩阵:nnz<mn⋅48=mn/2nnz < \frac{mn \cdot 4}{8} = mn/2nnz<8mn⋅4=mn/2,即稀疏率 > 50%。

9.2 CSC(Compressed Sparse Column)

CSC 是 CSR 的列版本:

  • values:非零元素的值
  • row_indices:每个非零元素的行索引
  • col_ptr:每列的起始位置

9.3 COO(Coordinate)

COO(坐标格式) 存储每个非零元素的 (行, 列, 值):

  • row_indices:行索引
  • col_indices:列索引
  • values:值

存储 :nnz⋅(2bidx+bval)nnz \cdot (2 b_{\text{idx}} + b_{\text{val}})nnz⋅(2bidx+bval)

优点:构建简单,适合动态插入非零元素。

9.4 BSR(Block Sparse Row)

BSR(块稀疏行) 格式将矩阵分成 br×bcb_r \times b_cbr×bc 的块,只存储非零块。

优点

  • 更好的数据局部性
  • 适合 GPU 的块操作
  • 索引开销更小(每个块只有一个索引)

存储 :nnzblocks⋅br⋅bc⋅bval+nnzblocks⋅bidx+(m/br+1)⋅bidxnnz_{\text{blocks}} \cdot b_r \cdot b_c \cdot b_{\text{val}} + nnz_{\text{blocks}} \cdot b_{\text{idx}} + (m/b_r + 1) \cdot b_{\text{idx}}nnzblocks⋅br⋅bc⋅bval+nnzblocks⋅bidx+(m/br+1)⋅bidx


第十章:稀疏矩阵乘法与硬件加速

10.1 稀疏-密集矩阵乘法(SpMM)

10.1.1 CSR 格式的 SpMM

算法 10.1(CSR SpMM)

复制代码
输入:CSR 矩阵 A (m×n), 密集矩阵 B (n×k)
输出:C = A @ B (m×k)

for i = 0, ..., m-1:
    for j = row_ptr[i], ..., row_ptr[i+1]-1:
        col = col_indices[j]
        val = values[j]
        for p = 0, ..., k-1:
            C[i, p] += val * B[col, p]

复杂度 :O(nnz⋅k)O(nnz \cdot k)O(nnz⋅k)

10.1.2 SpMM 的并行化

GPU 并行策略

  • 行并行:每个线程处理一行
  • 元素并行:每个线程处理一个非零元素
  • 块并行:每个线程块处理一个块

10.2 稀疏-密集向量乘法(SpMV)

算法 10.2(CSR SpMV)

复制代码
输入:CSR 矩阵 A (m×n), 密集向量 x (n)
输出:y = A @ x (m)

for i = 0, ..., m-1:
    sum = 0
    for j = row_ptr[i], ..., row_ptr[i+1]-1:
        sum += values[j] * x[col_indices[j]]
    y[i] = sum

复杂度 :O(nnz)O(nnz)O(nnz)

10.3 稀疏计算的实际加速

10.3.1 加速比分析

理论加速比

Speedup理论=mnnnz=11−s\text{Speedup}_{\text{理论}} = \frac{mn}{nnz} = \frac{1}{1-s}Speedup理论=nnzmn=1−s1

其中 sss 是稀疏率。

实际加速比:由于内存访问模式不规则、并行度不足等原因,实际加速比通常远低于理论值。

稀疏率 理论加速 实际加速(CPU) 实际加速(GPU)
50% 2x 1.2x 1.1x
90% 10x 3x 2x
95% 20x 5x 3x
99% 100x 10x 5x

10.3.2 结构化稀疏的优势

问题:非结构化稀疏的实际加速比很低。

解决方案 :使用结构化稀疏------稀疏模式是规则的,可以利用硬件的块操作。


第十一章:2:4 结构化稀疏与 NVIDIA Ampere

11.1 2:4 稀疏模式

11.1.1 定义

2:4 稀疏 (也称为 N:M 稀疏 ,N=2,M=4N=2, M=4N=2,M=4):在每 4 个连续元素中,恰好有 2 个为零。

示例

w=w1,0,w3,0,0,w6,w7,0,... \mathbf{w} = w_1, 0, w_3, 0, 0, w_6, w_7, 0, \\dotsw=w1,0,w3,0,0,w6,w7,0,...

每 4 个元素中恰好 2 个非零。

11.1.2 理论分析

定理 11.1(2:4 稀疏的表达能力):2:4 稀疏的等效参数量为原始的 50%,但通过选择最优的稀疏模式,可以保留大部分模型精度。

最优稀疏模式:在每 4 个元素中,保留绝对值最大的 2 个:

maski:i+4=top-2-mask(∣wi:i+4∣)\text{mask}{i:i+4} = \text{top-2-mask}(|w{i:i+4}|)maski:i+4=top-2-mask(∣wi:i+4∣)

11.1.3 硬件支持

NVIDIA A100/H100 的稀疏张量核心

  • 支持 2:4 结构化稀疏的矩阵乘法
  • 理论吞吐量是密集矩阵的 2 倍
  • 稀疏矩阵只需要存储非零元素和 2-bit 索引

工作原理

复制代码
1. 稀疏矩阵 A (2:4 格式):存储非零值 + 2-bit 索引
2. 密集矩阵 B:正常存储
3. 硬件使用索引从 B 中选择对应列,与 A 的非零值相乘
4. 结果 C = A @ B

11.2 2:4 稀疏的训练

11.2.1 稀疏感知训练

算法 11.1(2:4 稀疏感知训练)

复制代码
1. 初始化:密集权重 W
2. 每个训练步:
   a. 前向传播:使用 2:4 稀疏权重
      - 应用 2:4 掩码:W_sparse = W * mask
      - 计算输出:Y = X @ W_sparse^T
   b. 反向传播:使用 STE
      - 梯度通过掩码传递:grad_W = grad_Y @ X
   c. 更新:W = W - lr * grad_W
   d. 重新计算掩码(每隔 N 步)

11.2.2 掩码更新策略

静态掩码:训练开始时确定掩码,训练过程中不变。

动态掩码:每隔 N 步重新计算掩码(基于当前权重的幅度)。

渐进式掩码:从密集开始,逐渐增加稀疏率。

11.3 2:4 稀疏的精度

11.3.1 实验结果

模型 方法 密集精度 2:4 稀疏精度 精度下降
ResNet-50 训练后剪枝 76.1% 75.2% 0.9%
ResNet-50 稀疏感知训练 76.1% 76.0% 0.1%
BERT-Large 训练后剪枝 91.0% 90.5% 0.5%
BERT-Large 稀疏感知训练 91.0% 90.9% 0.1%

11.3.2 与量化的结合

2:4 稀疏 + INT8 量化

Y=(XINT8⋅WINT8, sparse)⋅sX⋅sWY = (X_{\text{INT8}} \cdot W_{\text{INT8, sparse}}) \cdot s_X \cdot s_WY=(XINT8⋅WINT8, sparse)⋅sX⋅sW

理论加速:2x(稀疏)× 2x(INT8 vs FP16)= 4x


第五部分:完整可运行代码实现


第十二章:从零实现非结构化剪枝系统

python 复制代码
"""
非结构化剪枝系统的完整实现。
包含:幅度剪枝、迭代剪枝、SNIP、Wanda。
"""

import numpy as np
from typing import Tuple, Optional


def magnitude_prune(weights: np.ndarray, sparsity: float) -> Tuple[np.ndarray, np.ndarray]:
    """幅度剪枝。

    删除绝对值最小的权重。

    Args:
        weights: 权重矩阵
        sparsity: 目标稀疏率 (0-1)

    Returns:
        pruned_weights: 剪枝后的权重
        mask: 二值掩码 (1=保留, 0=删除)
    """
    # 计算阈值
    flat_weights = np.abs(weights.flatten())
    k = int(len(flat_weights) * (1 - sparsity))
    threshold = np.sort(flat_weights)[k - 1] if k > 0 else np.inf

    # 创建掩码
    mask = (np.abs(weights) >= threshold).astype(float)

    # 应用掩码
    pruned_weights = weights * mask

    return pruned_weights, mask


def snip_prune(
    weights: np.ndarray,
    gradients: np.ndarray,
    sparsity: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """SNIP 剪枝。

    基于连接灵敏度 |w * grad| 进行剪枝。

    Args:
        weights: 权重矩阵
        gradients: 梯度矩阵
        sparsity: 目标稀疏率

    Returns:
        pruned_weights: 剪枝后的权重
        mask: 二值掩码
    """
    # 计算连接灵敏度
    sensitivity = np.abs(weights * gradients)

    # 计算阈值
    k = int(weights.size * (1 - sparsity))
    threshold = np.sort(sensitivity.flatten())[k - 1] if k > 0 else np.inf

    # 创建掩码
    mask = (sensitivity >= threshold).astype(float)

    return weights * mask, mask


def wanda_prune(
    weights: np.ndarray,
    activations: np.ndarray,
    sparsity: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Wanda 剪枝。

    基于 |w| * ||activation|| 进行剪枝。

    Args:
        weights: 权重矩阵 (m, n)
        activations: 输入激活 (T, n)
        sparsity: 目标稀疏率

    Returns:
        pruned_weights: 剪枝后的权重
        mask: 二值掩码
    """
    # 计算激活范数
    act_norms = np.linalg.norm(activations, axis=0)  # (n,)

    # 计算重要性分数
    scores = np.abs(weights) * act_norms[np.newaxis, :]  # (m, n)

    # 计算阈值
    k = int(weights.size * (1 - sparsity))
    threshold = np.sort(scores.flatten())[k - 1] if k > 0 else np.inf

    # 创建掩码
    mask = (scores >= threshold).astype(float)

    return weights * mask, mask


def iterative_magnitude_prune(
    weights: np.ndarray,
    X: np.ndarray,
    Y: np.ndarray,
    target_sparsity: float,
    n_steps: int = 10,
    n_finetune_steps: int = 50,
    lr: float = 0.001,
) -> Tuple[np.ndarray, np.ndarray, list]:
    """迭代幅度剪枝。

    Args:
        weights: 初始权重
        X: 训练输入
        Y: 训练目标
        target_sparsity: 目标稀疏率
        n_steps: 剪枝步骤数
        n_finetune_steps: 每步微调的迭代数
        lr: 学习率

    Returns:
        final_weights: 最终权重
        final_mask: 最终掩码
        loss_history: 损失历史
    """
    current_weights = weights.copy()
    loss_history = []

    sparsity_per_step = 1 - (1 - target_sparsity) ** (1 / n_steps)

    for step in range(n_steps):
        # 剪枝
        _, mask = magnitude_prune(current_weights, sparsity_per_step)
        current_weights = current_weights * mask

        # 微调
        for _ in range(n_finetune_steps):
            # 前向传播
            Y_pred = X @ current_weights.T
            loss = np.mean((Y - Y_pred) ** 2)
            loss_history.append(loss)

            # 反向传播
            grad = 2 / X.shape[0] * (Y_pred - Y).T @ X

            # 更新(只更新非零位置)
            current_weights = current_weights - lr * grad * mask

        current_sparsity = 1 - np.mean(mask)
        print(f"  步骤 {step + 1}/{n_steps}: 稀疏率 = {current_sparsity:.2%}, "
              f"损失 = {loss_history[-1]:.6f}")

    final_mask = (current_weights != 0).astype(float)
    return current_weights, final_mask, loss_history


def demonstrate_pruning():
    """演示各种剪枝方法。"""
    np.random.seed(42)

    # 设置
    m, n = 128, 256
    T = 500
    sparsity = 0.9  # 90% 稀疏率

    print("=" * 70)
    print("非结构化剪枝方法对比")
    print("=" * 70)
    print(f"权重: {m}x{n}, 稀疏率: {sparsity:.0%}")
    print()

    # 生成数据
    W = np.random.randn(m, n) * 0.02
    X = np.random.randn(T, n) * 0.5
    Y = X @ W.T + np.random.randn(T, m) * 0.01

    # 参考 MSE
    Y_ref = X @ W.T
    ref_mse = np.mean((Y - Y_ref) ** 2)

    # 计算梯度(用于 SNIP)
    Y_pred = X @ W.T
    grad = 2 / T * (Y_pred - Y).T @ X  # (m, n)

    results = []

    # 方法 1:幅度剪枝
    W_mag, mask_mag = magnitude_prune(W, sparsity)
    Y_mag = X @ W_mag.T
    mse_mag = np.mean((Y_ref - Y_mag) ** 2)
    results.append(("幅度剪枝", mse_mag, np.mean(mask_mag)))

    # 方法 2:SNIP
    W_snip, mask_snip = snip_prune(W, grad, sparsity)
    Y_snip = X @ W_snip.T
    mse_snip = np.mean((Y_ref - Y_snip) ** 2)
    results.append(("SNIP", mse_snip, np.mean(mask_snip)))

    # 方法 3:Wanda
    W_wanda, mask_wanda = wanda_prune(W, X, sparsity)
    Y_wanda = X @ W_wanda.T
    mse_wanda = np.mean((Y_ref - Y_wanda) ** 2)
    results.append(("Wanda", mse_wanda, np.mean(mask_wanda)))

    # 方法 4:迭代幅度剪枝
    W_imp, mask_imp, _ = iterative_magnitude_prune(
        W, X, Y, target_sparsity=sparsity, n_steps=5, n_finetune_steps=20, lr=0.0001
    )
    Y_imp = X @ W_imp.T
    mse_imp = np.mean((Y_ref - Y_imp) ** 2)
    results.append(("迭代幅度剪枝", mse_imp, np.mean(mask_imp)))

    # 打印结果
    print(f"\n  {'方法':>15} {'MSE':>15} {'非零比例':>10} {'相对误差':>12}")
    print(f"  {'-'*15} {'-'*15} {'-'*10} {'-'*12}")

    for name, mse, nnz_ratio in results:
        rel_err = np.sqrt(mse) / np.std(Y_ref) if mse > 0 else 0
        print(f"  {name:>15} {mse:>15.10f} {nnz_ratio:>10.2%} {rel_err:>12.6f}")

    # 不同稀疏率的分析
    print(f"\n  不同稀疏率下的幅度剪枝 MSE:")
    print(f"  {'稀疏率':>10} {'MSE':>15} {'相对误差':>12} {'非零元素':>10}")
    print(f"  {'-'*10} {'-'*15} {'-'*12} {'-'*10}")

    for s in [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]:
        W_s, mask_s = magnitude_prune(W, s)
        Y_s = X @ W_s.T
        mse_s = np.mean((Y_ref - Y_s) ** 2)
        rel_err = np.sqrt(mse_s) / np.std(Y_ref)
        nnz = int(mask_s.sum())
        print(f"  {s:>10.0%} {mse_s:>15.10f} {rel_err:>12.6f} {nnz:>10}")


if __name__ == "__main__":
    demonstrate_pruning()

第十三章:从零实现结构化剪枝系统

python 复制代码
"""
结构化剪枝系统的完整实现。
包含:通道剪枝、注意力头剪枝、层剪枝。
"""

import numpy as np
from typing import List, Tuple


def channel_prune_l1(
    W: np.ndarray,
    keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """基于 L1 范数的通道剪枝。

    Args:
        W: 权重矩阵 (m, n),每行是一个通道
        keep_ratio: 保留比例

    Returns:
        W_pruned: 剪枝后的权重
        channel_mask: 通道掩码
    """
    m = W.shape[0]
    k = max(1, int(m * keep_ratio))

    # 计算每个通道的 L1 范数
    channel_norms = np.linalg.norm(W, ord=1, axis=1)

    # 保留范数最大的 k 个通道
    top_k_indices = np.argsort(channel_norms)[-k:]
    channel_mask = np.zeros(m)
    channel_mask[top_k_indices] = 1

    W_pruned = W * channel_mask[:, np.newaxis]

    return W_pruned, channel_mask


def channel_prune_l2(
    W: np.ndarray,
    keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """基于 L2 范数的通道剪枝。

    Args:
        W: 权重矩阵 (m, n)
        keep_ratio: 保留比例

    Returns:
        W_pruned: 剪枝后的权重
        channel_mask: 通道掩码
    """
    m = W.shape[0]
    k = max(1, int(m * keep_ratio))

    # 计算每个通道的 L2 范数
    channel_norms = np.linalg.norm(W, ord=2, axis=1)

    top_k_indices = np.argsort(channel_norms)[-k:]
    channel_mask = np.zeros(m)
    channel_mask[top_k_indices] = 1

    W_pruned = W * channel_mask[:, np.newaxis]

    return W_pruned, channel_mask


def channel_prune_gradient(
    W: np.ndarray,
    gradients: np.ndarray,
    activations: np.ndarray,
    keep_ratio: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """基于梯度的通道剪枝。

    重要性 = ||通道权重|| * ||通道梯度|| * ||通道激活||

    Args:
        W: 权重矩阵 (m, n)
        gradients: 输出梯度 (T, m)
        activations: 输入激活 (T, n)
        keep_ratio: 保留比例

    Returns:
        W_pruned: 剪枝后的权重
        channel_mask: 通道掩码
    """
    m = W.shape[0]
    k = max(1, int(m * keep_ratio))

    # 计算通道重要性
    weight_norms = np.linalg.norm(W, axis=1)  # (m,)
    grad_norms = np.linalg.norm(gradients, axis=0)  # (m,)
    act_norms = np.linalg.norm(activations, axis=0)  # (n,)

    # 重要性 = 权重范数 * 梯度范数
    importance = weight_norms * grad_norms

    top_k_indices = np.argsort(importance)[-k:]
    channel_mask = np.zeros(m)
    channel_mask[top_k_indices] = 1

    W_pruned = W * channel_mask[:, np.newaxis]

    return W_pruned, channel_mask


def attention_head_prune(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    O: np.ndarray,
    X: np.ndarray,
    keep_ratio: float,
    n_heads: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """注意力头剪枝。

    Args:
        Q, K, V, O: 注意力层的权重矩阵
        X: 输入激活
        keep_ratio: 头保留比例
        n_heads: 注意力头数量

    Returns:
        Q_pruned, K_pruned, V_pruned, O_pruned: 剪枝后的权重
        head_mask: 头掩码
    """
    d_model = Q.shape[0]
    d_head = d_model // n_heads
    k = max(1, int(n_heads * keep_ratio))

    # 计算每个头的重要性
    head_importance = np.zeros(n_heads)

    for h in range(n_heads):
        # 提取该头的权重
        start = h * d_head
        end = (h + 1) * d_head

        Q_h = Q[start:end, :]
        K_h = K[start:end, :]
        V_h = V[start:end, :]

        # 重要性 = 各权重矩阵的 Frobenius 范数之和
        head_importance[h] = (
            np.linalg.norm(Q_h, 'fro') +
            np.linalg.norm(K_h, 'fro') +
            np.linalg.norm(V_h, 'fro')
        )

    # 选择最重要的头
    top_k_heads = np.argsort(head_importance)[-k:]
    head_mask = np.zeros(n_heads)
    head_mask[top_k_heads] = 1

    # 创建通道掩码
    channel_mask = np.repeat(head_mask, d_head)

    # 应用掩码
    Q_pruned = Q * channel_mask[:, np.newaxis]
    K_pruned = K * channel_mask[:, np.newaxis]
    V_pruned = V * channel_mask[:, np.newaxis]
    O_pruned = O * channel_mask[np.newaxis, :]

    return Q_pruned, K_pruned, V_pruned, O_pruned, head_mask


def layer_importance_cosine(
    hidden_states: List[np.ndarray],
) -> np.ndarray:
    """基于余弦相似度的层重要性。

    Args:
        hidden_states: 各层的隐藏状态列表

    Returns:
        importance: 层重要性分数
    """
    n_layers = len(hidden_states) - 1
    importance = np.zeros(n_layers)

    for l in range(n_layers):
        h_l = hidden_states[l].flatten()
        h_l1 = hidden_states[l + 1].flatten()

        # 余弦相似度
        cos_sim = np.dot(h_l, h_l1) / (np.linalg.norm(h_l) * np.linalg.norm(h_l1) + 1e-8)

        # 重要性 = 1 - 余弦相似度(变化越大越重要)
        importance[l] = 1 - cos_sim

    return importance


def layer_prune(
    weights_list: List[np.ndarray],
    hidden_states: List[np.ndarray],
    keep_ratio: float,
) -> Tuple[List[np.ndarray], np.ndarray]:
    """层剪枝。

    Args:
        weights_list: 各层的权重列表
        hidden_states: 各层的隐藏状态
        keep_ratio: 层保留比例

    Returns:
        pruned_weights: 剪枝后的权重列表
        layer_mask: 层掩码
    """
    n_layers = len(weights_list)
    k = max(1, int(n_layers * keep_ratio))

    # 计算层重要性
    importance = layer_importance_cosine(hidden_states)

    # 选择最重要的层
    top_k_layers = np.argsort(importance)[-k:]
    layer_mask = np.zeros(n_layers)
    layer_mask[top_k_layers] = 1

    pruned_weights = [w for w, m in zip(weights_list, layer_mask) if m > 0]

    return pruned_weights, layer_mask


def demonstrate_structured_pruning():
    """演示结构化剪枝。"""
    np.random.seed(42)

    print("=" * 70)
    print("结构化剪枝演示")
    print("=" * 70)

    # 通道剪枝
    print("\n  1. 通道剪枝")
    print("  " + "-" * 40)

    m, n = 64, 128
    W = np.random.randn(m, n) * 0.02

    for keep_ratio in [0.5, 0.25, 0.1]:
        # L1 剪枝
        W_l1, mask_l1 = channel_prune_l1(W, keep_ratio)
        nnz_l1 = int(mask_l1.sum())

        # L2 剪枝
        W_l2, mask_l2 = channel_prune_l2(W, keep_ratio)
        nnz_l2 = int(mask_l2.sum())

        print(f"    保留 {keep_ratio:.0%}: L1 保留 {nnz_l1} 通道, L2 保留 {nnz_l2} 通道")

    # 注意力头剪枝
    print("\n  2. 注意力头剪枝")
    print("  " + "-" * 40)

    d_model = 256
    n_heads = 8
    d_head = d_model // n_heads

    Q = np.random.randn(d_model, d_model) * 0.02
    K = np.random.randn(d_model, d_model) * 0.02
    V = np.random.randn(d_model, d_model) * 0.02
    O = np.random.randn(d_model, d_model) * 0.02
    X = np.random.randn(32, d_model) * 0.5

    for keep_ratio in [0.75, 0.5, 0.25]:
        Q_p, K_p, V_p, O_p, head_mask = attention_head_prune(
            Q, K, V, O, X, keep_ratio, n_heads
        )
        n_kept = int(head_mask.sum())
        print(f"    保留 {keep_ratio:.0%}: {n_kept}/{n_heads} 头")

    # 层重要性分析
    print("\n  3. 层重要性分析")
    print("  " + "-" * 40)

    n_layers = 12
    d = 256
    hidden_states = [np.random.randn(32, d) * 0.5]

    for l in range(n_layers):
        # 模拟:层越深,表示变化越小
        h_new = hidden_states[-1] + np.random.randn(32, d) * (0.5 / (l + 1))
        hidden_states.append(h_new)

    importance = layer_importance_cosine(hidden_states)

    print(f"    层重要性(1 - 余弦相似度):")
    for l, imp in enumerate(importance):
        bar = "█" * int(imp * 200)
        print(f"      层 {l:2d}: {imp:.4f} {bar}")


if __name__ == "__main__":
    demonstrate_structured_pruning()

第十四章:从零实现稀疏矩阵运算

python 复制代码
"""
稀疏矩阵运算的完整实现。
包含:CSR 格式、SpMV、SpMM、2:4 稀疏。
"""

import numpy as np
from typing import Tuple


class CSRMatrix:
    """CSR 格式的稀疏矩阵。"""

    def __init__(self, values, col_indices, row_ptr, shape):
        self.values = values
        self.col_indices = col_indices
        self.row_ptr = row_ptr
        self.shape = shape

    @classmethod
    def from_dense(cls, A: np.ndarray, threshold: float = 0.0) -> 'CSRMatrix':
        """从密集矩阵创建 CSR 矩阵。"""
        m, n = A.shape
        values = []
        col_indices = []
        row_ptr = [0]

        for i in range(m):
            for j in range(n):
                if abs(A[i, j]) > threshold:
                    values.append(A[i, j])
                    col_indices.append(j)
            row_ptr.append(len(values))

        return cls(
            np.array(values),
            np.array(col_indices, dtype=np.int32),
            np.array(row_ptr, dtype=np.int32),
            (m, n)
        )

    def to_dense(self) -> np.ndarray:
        """转换为密集矩阵。"""
        m, n = self.shape
        A = np.zeros((m, n))
        for i in range(m):
            for j in range(self.row_ptr[i], self.row_ptr[i + 1]):
                A[i, self.col_indices[j]] = self.values[j]
        return A

    def spmv(self, x: np.ndarray) -> np.ndarray:
        """稀疏矩阵-密集向量乘法 (SpMV)。"""
        m, n = self.shape
        y = np.zeros(m)

        for i in range(m):
            s = 0.0
            for j in range(self.row_ptr[i], self.row_ptr[i + 1]):
                s += self.values[j] * x[self.col_indices[j]]
            y[i] = s

        return y

    def spmm(self, B: np.ndarray) -> np.ndarray:
        """稀疏矩阵-密集矩阵乘法 (SpMM)。"""
        m, n = self.shape
        k = B.shape[1]
        C = np.zeros((m, k))

        for i in range(m):
            for jj in range(self.row_ptr[i], self.row_ptr[i + 1]):
                j = self.col_indices[jj]
                val = self.values[jj]
                for p in range(k):
                    C[i, p] += val * B[j, p]

        return C

    @property
    def nnz(self) -> int:
        return len(self.values)

    @property
    def sparsity(self) -> float:
        return 1 - self.nnz / (self.shape[0] * self.shape[1])


def apply_2_4_sparsity(W: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """应用 2:4 结构化稀疏。

    在每 4 个连续元素中保留绝对值最大的 2 个。

    Args:
        W: 权重矩阵

    Returns:
        W_sparse: 稀疏权重
        mask: 2:4 掩码
    """
    W_flat = W.flatten()
    n = len(W_flat)
    mask = np.zeros(n)

    # 每 4 个元素处理一组
    for i in range(0, n, 4):
        group = W_flat[i:i + 4]
        if len(group) < 4:
            # 不足 4 个的组,全部保留
            mask[i:i + len(group)] = 1
            continue

        # 选择绝对值最大的 2 个
        top_2 = np.argsort(np.abs(group))[-2:]
        for idx in top_2:
            mask[i + idx] = 1

    mask = mask.reshape(W.shape)
    W_sparse = W * mask

    return W_sparse, mask


def demonstrate_sparse_operations():
    """演示稀疏矩阵运算。"""
    np.random.seed(42)

    print("=" * 70)
    print("稀疏矩阵运算演示")
    print("=" * 70)

    # 创建稀疏矩阵
    m, n = 128, 256
    sparsity = 0.9

    # 生成稀疏矩阵
    A = np.random.randn(m, n) * 0.02
    mask = np.random.random((m, n)) > sparsity
    A_sparse = A * mask

    print(f"\n  矩阵大小: {m}x{n}")
    print(f"  稀疏率: {sparsity:.0%}")
    print(f"  非零元素: {int(mask.sum())}")

    # CSR 格式
    csr = CSRMatrix.from_dense(A_sparse, threshold=1e-10)
    print(f"\n  CSR 格式:")
    print(f"    NNZ: {csr.nnz}")
    print(f"    稀疏率: {csr.sparsity:.2%}")
    print(f"    存储 (CSR): {csr.nnz * 8 + (m + 1) * 4} 字节")
    print(f"    存储 (密集): {m * n * 4} 字节")
    print(f"    压缩比: {m * n * 4 / (csr.nnz * 8 + (m + 1) * 4):.2f}x")

    # SpMV
    print(f"\n  SpMV (稀疏矩阵-向量乘法):")
    x = np.random.randn(n)

    y_dense = A_sparse @ x
    y_sparse = csr.spmv(x)
    spmv_error = np.max(np.abs(y_dense - y_sparse))

    print(f"    最大误差: {spmv_error:.2e}")

    # SpMM
    print(f"\n  SpMM (稀疏矩阵-矩阵乘法):")
    k = 64
    B = np.random.randn(n, k)

    C_dense = A_sparse @ B
    C_sparse = csr.spmm(B)
    spmm_error = np.max(np.abs(C_dense - C_sparse))

    print(f"    最大误差: {spmm_error:.2e}")

    # 2:4 结构化稀疏
    print(f"\n  2:4 结构化稀疏:")
    W = np.random.randn(64, 64) * 0.02

    W_24, mask_24 = apply_2_4_sparsity(W)
    actual_sparsity = 1 - np.mean(mask_24)
    nnz_per_group = np.mean([np.sum(mask_24.flatten()[i:i + 4]) for i in range(0, mask_24.size, 4)])

    print(f"    矩阵大小: {W.shape}")
    print(f"    目标稀疏率: 50%")
    print(f"    实际稀疏率: {actual_sparsity:.2%}")
    print(f"    每组平均非零数: {nnz_per_group:.1f}")

    # 误差分析
    reconstruction_error = np.linalg.norm(W - W_24, 'fro') / np.linalg.norm(W, 'fro')
    print(f"    重建相对误差: {reconstruction_error:.6f}")


if __name__ == "__main__":
    demonstrate_sparse_operations()

第十五章:完整剪枝 Pipeline 与精度对比

python 复制代码
"""
完整的剪枝 Pipeline。
对比各种剪枝方法在不同稀疏率下的精度。
"""

import numpy as np
from typing import Dict, List, Tuple


def run_full_pruning_comparison():
    """运行完整的剪枝方法对比。"""
    np.random.seed(42)

    # 设置
    m, n = 128, 256
    T = 500

    print("=" * 70)
    print("剪枝方法综合对比")
    print("=" * 70)
    print(f"权重: {m}x{n}, 训练样本: {T}")
    print()

    # 生成数据
    W = np.random.randn(m, n) * 0.02
    X = np.random.randn(T, n) * 0.5
    Y = X @ W.T + np.random.randn(T, m) * 0.001

    # 参考
    Y_ref = X @ W.T
    ref_mse = np.mean((Y - Y_ref) ** 2)
    ref_std = np.std(Y_ref)

    print(f"  参考 MSE: {ref_mse:.10f}")
    print()

    # 测试不同的稀疏率
    sparsities = [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]

    # 定义剪枝方法
    def magnitude_prune(W, s):
        flat = np.abs(W.flatten())
        k = int(W.size * (1 - s))
        if k == 0:
            return np.zeros_like(W), np.zeros_like(W)
        threshold = np.sort(flat)[k - 1]
        mask = (np.abs(W) >= threshold).astype(float)
        return W * mask, mask

    def random_prune(W, s):
        mask = (np.random.random(W.shape) >= s).astype(float)
        return W * mask, mask

    def column_prune(W, s):
        n = W.shape[1]
        k = max(1, int(n * (1 - s)))
        col_norms = np.linalg.norm(W, axis=0)
        top_k = np.argsort(col_norms)[-k:]
        mask = np.zeros_like(W)
        mask[:, top_k] = 1
        return W * mask, mask

    def row_prune(W, s):
        m = W.shape[0]
        k = max(1, int(m * (1 - s)))
        row_norms = np.linalg.norm(W, axis=1)
        top_k = np.argsort(row_norms)[-k:]
        mask = np.zeros_like(W)
        mask[top_k, :] = 1
        return W * mask, mask

    methods = {
        "幅度剪枝": magnitude_prune,
        "随机剪枝": random_prune,
        "列剪枝 (结构化)": column_prune,
        "行剪枝 (结构化)": row_prune,
    }

    # 运行对比
    results = {}

    for method_name, method_fn in methods.items():
        results[method_name] = []

        for s in sparsities:
            W_pruned, mask = method_fn(W, s)
            Y_pruned = X @ W_pruned.T
            mse = np.mean((Y_ref - Y_pruned) ** 2)
            rel_err = np.sqrt(mse) / ref_std if mse > 0 else 0
            nnz = int(mask.sum())
            results[method_name].append((s, mse, rel_err, nnz))

    # 打印结果
    print(f"  {'方法':>18} | ", end="")
    for s in sparsities:
        print(f"{'s=' + str(int(s*100)) + '%':>12} | ", end="")
    print()
    print(f"  {'-'*18} | " + " | ".join(["-" * 12] * len(sparsities)) + " |")

    print(f"\n  MSE:")
    for method_name, method_results in results.items():
        print(f"  {method_name:>18} | ", end="")
        for s, mse, rel_err, nnz in method_results:
            print(f"{mse:>12.8f} | ", end="")
        print()

    print(f"\n  相对误差:")
    for method_name, method_results in results.items():
        print(f"  {method_name:>18} | ", end="")
        for s, mse, rel_err, nnz in method_results:
            print(f"{rel_err:>12.6f} | ", end="")
        print()

    # 存储分析
    print(f"\n  存储分析 (非结构化 vs 结构化):")
    print(f"  {'稀疏率':>8} {'非结构化(KB)':>14} {'结构化(KB)':>14} {'压缩比':>10}")
    print(f"  {'-'*8} {'-'*14} {'-'*14} {'-'*10}")

    dense_size = m * n * 4 / 1024  # KB

    for s in sparsities:
        # 非结构化:需要 CSR 格式
        nnz = int(m * n * (1 - s))
        sparse_csr_kb = (nnz * 8 + (m + 1) * 4) / 1024

        # 结构化(行剪枝):只需要存储保留的行
        k_rows = max(1, int(m * (1 - s)))
        struct_kb = k_rows * n * 4 / 1024

        print(f"  {s:>8.0%} {sparse_csr_kb:>14.2f} {struct_kb:>14.2f} {dense_size / struct_kb:>10.2f}x")


if __name__ == "__main__":
    run_full_pruning_comparison()

附录:关键公式汇总

A.1 稀疏性基础

公式 表达式
ℓ0\ell_0ℓ0 "范数" $|\mathbf{w}|_0 =
ℓ1\ell_1ℓ1 范数 $|\mathbf{w}|_1 = \sum_i
稀疏率 s=1−∣w∣0/ns = 1 - |\mathbf{w}|_0 / ns=1−∣w∣0/n

A.2 参数重要性

方法 重要性公式
幅度 $
SNIP $
Fisher Fii⋅wi2F_{ii} \cdot w_i^2Fii⋅wi2
OBD Hii⋅wi2H_{ii} \cdot w_i^2Hii⋅wi2
OBS wi2/H−1iiw_i^2 / H\^{-1}_{ii}wi2/H−1ii
Wanda $
GraSP −gi⋅(Hg)i-g_i \cdot (\mathbf{H}\mathbf{g})_i−gi⋅(Hg)i

A.3 剪枝理论

定理 表达式
OBD 损失增加 δLi=12Hiiθi2\delta\mathcal{L}i = \frac{1}{2} H{ii} \theta_i^2δLi=21Hiiθi2
OBS 损失增加 δLi=θi22H−1ii\delta\mathcal{L}i = \frac{\theta_i^2}{2H\^{-1}{ii}}δLi=2H−1iiθi2
6 dB/bit 规则 SQNR=6.02b−4.77\text{SQNR} = 6.02b - 4.77SQNR=6.02b−4.77 dB

A.4 稀疏矩阵格式

格式 存储
CSR nnz⋅(bval+bidx)+(m+1)⋅bidxnnz \cdot (b_{\text{val}} + b_{\text{idx}}) + (m+1) \cdot b_{\text{idx}}nnz⋅(bval+bidx)+(m+1)⋅bidx
CSC 同 CSR(列版本)
COO nnz⋅(2bidx+bval)nnz \cdot (2 b_{\text{idx}} + b_{\text{val}})nnz⋅(2bidx+bval)
BSR nnzblocks⋅br⋅bc⋅bval+索引开销nnz_{\text{blocks}} \cdot b_r \cdot b_c \cdot b_{\text{val}} + \text{索引开销}nnzblocks⋅br⋅bc⋅bval+索引开销

参考文献

  1. LeCun, Y., Denker, J., & Solla, S. (1989). Optimal brain damage. NeurIPS.
  2. Hassibi, B., & Stork, D. (1993). Second order derivatives for network pruning: Optimal brain surgeon. NeurIPS.
  3. Frankle, J., & Carlin, M. (2019). The lottery ticket hypothesis: Finding sparse, trainable neural networks. ICLR.
  4. Lee, J., et al. (2019). SNIP: Single-shot network pruning based on connection sensitivity. ICLR.
  5. Wang, C., et al. (2020). Picking winning tickets before training by preserving gradient flow. ICLR.
  6. Frantar, E., & Alistarh, D. (2023). SparseGPT: Massive language models can be accurately pruned in one-shot. ICML.
  7. Sun, M., et al. (2024). A simple and effective pruning approach for large language models. ICLR.
  8. Ma, X., et al. (2023). LLM-Pruner: On the structural pruning of large language models. NeurIPS.
  9. Zhu, M., & Gupta, S. (2017). To prune, or not to prune: Exploring the efficacy of pruning for model compression. ICLR Workshop.
  10. Michel, P., Levy, O., & Neubig, G. (2019). Are sixteen heads really better than one? NeurIPS.
相关推荐
赴山海bi1 小时前
家居类亚马逊Listing优化:DeepBI驱动的增长秘诀
人工智能
weixin_468466851 小时前
纳米 AI 搜索新手极速上手指南
人工智能·python·深度学习·搜索引擎·ai·语言模型·自然语言处理
逻辑君1 小时前
Foresight研究报告【20260011】
人工智能·线性代数·算法·矩阵
珊瑚里的鱼1 小时前
【动态规划】不同路径Ⅱ
算法·动态规划
YueJoy.AI1 小时前
AI应用的API安全:从认证到授权的完整指南
人工智能·ai·语言模型
码农小旋风1 小时前
使用 ChatGPT 聚合站前,先看安全和隐私判断清单
人工智能·安全·自然语言处理·chatgpt·claude
周易宅2 小时前
CLAUDE.md 与 MEMORY.md:AI 编程助手配置的两条平行铁轨
人工智能·ai·agent·claude
明志数科2 小时前
灵犀X2学跳舞技术解析:机器人动作学习的数据方案
人工智能·计算机视觉
不懂的浪漫2 小时前
Role Agent 方法论:如何把一个标准工作流 Agent 化
人工智能·ai·agent