模型蒸馏实战:知识蒸馏、层剪枝与结构化压缩

1. 引言

模型蒸馏(Knowledge Distillation)是将大模型(Teacher)的知识转移到小模型(Student)的技术。通过蒸馏,我们可以用 1/10 的参数量达到大模型 90%+ 的性能,大幅降低部署成本。

压缩方法全景:

方法 压缩比 精度损失 实现难度
知识蒸馏 5-20x
层剪枝 2-4x
注意力头剪枝 1.5-3x
量化 2-4x 很低
低秩分解 2-5x

2. 知识蒸馏基础

2.1 经典蒸馏框架

复制代码
Teacher (大模型) → soft targets (logits with temperature)
                         ↓
                    蒸馏损失 (KL散度)
                         ↓
Student (小模型) → hard targets (真实标签)
                         ↓
                    交叉熵损失

总损失 = α × 蒸馏损失 + (1-α) × 交叉熵损失

2.2 蒸馏实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """知识蒸馏损失"""

    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        # 软标签蒸馏损失
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        distill_loss = self.kl_loss(soft_student, soft_teacher) * (self.temperature ** 2)

        # 硬标签交叉熵损失
        hard_loss = self.ce_loss(student_logits, labels)

        # 加权组合
        total = self.alpha * distill_loss + (1 - self.alpha) * hard_loss
        return total, distill_loss.item(), hard_loss.item()

2.3 完整训练流程

python 复制代码
from transformers import AutoModelForCausalLM, AutoTokenizer

def distill_llm(teacher_id, student_id, dataset, epochs=3, lr=5e-5):
    """LLM 知识蒸馏"""
    teacher = AutoModelForCausalLM.from_pretrained(
        teacher_id, device_map="auto", torch_dtype=torch.float16
    )
    student = AutoModelForCausalLM.from_pretrained(
        student_id, device_map="auto", torch_dtype=torch.float16
    )
    teacher.eval()  # Teacher 不更新

    tokenizer = AutoTokenizer.from_pretrained(teacher_id)
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    distill_criterion = DistillationLoss(temperature=4.0, alpha=0.7)

    for epoch in range(epochs):
        total_loss = 0
        for batch in dataset:
            input_ids = batch["input_ids"].to(student.device)
            labels = batch["labels"].to(student.device)

            # Teacher 前向(不计算梯度)
            with torch.no_grad():
                teacher_outputs = teacher(input_ids=input_ids)
                teacher_logits = teacher_outputs.logits

            # Student 前向
            student_outputs = student(input_ids=input_ids)
            student_logits = student_outputs.logits

            # 计算蒸馏损失
            loss, distill_val, hard_val = distill_criterion(
                student_logits.view(-1, student_logits.size(-1)),
                teacher_logits.view(-1, teacher_logits.size(-1)),
                labels.view(-1),
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataset):.4f}")

    return student

3. 特征蒸馏

3.1 中间层特征对齐

python 复制代码
class FeatureDistillation(nn.Module):
    """中间层特征蒸馏"""

    def __init__(self, teacher_dims, student_dims):
        super().__init__()
        # 投影层:将 student 特征映射到 teacher 维度
        self.projectors = nn.ModuleList([
            nn.Linear(s_dim, t_dim)
            for s_dim, t_dim in zip(student_dims, teacher_dims)
        ])

    def forward(self, teacher_features, student_features):
        loss = 0
        for i, (t_feat, s_feat, proj) in enumerate(
            zip(teacher_features, student_features, self.projectors)
        ):
            s_proj = proj(s_feat)
            loss += F.mse_loss(s_proj, t_feat.detach())
        return loss / len(teacher_features)


class DistillableModel(nn.Module):
    """可蒸馏的模型包装器"""

    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        self.features = {}
        self.layer_names = layer_names

        # 注册 hook
        for name in layer_names:
            layer = dict(model.named_modules())[name]
            layer.register_forward_hook(self._get_hook(name))

    def _get_hook(self, name):
        def hook(module, input, output):
            if isinstance(output, tuple):
                self.features[name] = output[0]
            else:
                self.features[name] = output
        return hook

    def forward(self, *args, **kwargs):
        self.features = {}
        return self.model(*args, **kwargs)

4. 层剪枝

4.1 基于重要性的层剪枝

python 复制代码
def compute_layer_importance(model, dataloader):
    """计算每层的重要性分数"""
    importance = {}

    for name, param in model.named_parameters():
        if "layer" in name and "weight" in name:
            importance[name] = torch.zeros_like(param)

    model.eval()
    for batch in dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        for name, param in model.named_parameters():
            if name in importance and param.grad is not None:
                importance[name] += param.grad.abs().mean(dim=-1)

    # 归一化
    for name in importance:
        importance[name] = importance[name].cpu().numpy()

    return importance


def prune_layers(model, prune_ratio=0.3):
    """剪枝指定比例的层"""
    # 计算每层重要性
    layer_importance = {}
    for name, module in model.named_modules():
        if hasattr(module, 'weight'):
            importance = module.weight.data.abs().mean().item()
            layer_importance[name] = importance

    # 排序并选择要剪枝的层
    sorted_layers = sorted(layer_importance.items(), key=lambda x: x[1])
    num_prune = int(len(sorted_layers) * prune_ratio)
    layers_to_prune = [name for name, _ in sorted_layers[:num_prune]]

    print(f"剪枝 {num_prune} 层: {layers_to_prune}")

    # 执行剪枝
    for name in layers_to_prune:
        module = dict(model.named_modules())[name]
        # 将权重置零(相当于移除该层)
        nn.init.zeros_(module.weight)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.zeros_(module.bias)

    return model, layers_to_prune

4.2 LLM 层剪枝

python 复制代码
def prune_llm_layers(model, keep_layers=None, prune_ratio=0.25):
    """LLM 层剪枝"""
    if keep_layers is None:
        total = model.config.num_hidden_layers
        keep = int(total * (1 - prune_ratio))
        # 保留前几层和后几层(通常更重要)
        keep_layers = list(range(keep // 2)) + list(range(total - keep // 2, total))

    # 重建模型的层
    original_layers = model.model.layers
    pruned_layers = nn.ModuleList([
        original_layers[i] for i in keep_layers
    ])

    model.model.layers = pruned_layers
    model.config.num_hidden_layers = len(keep_layers)

    print(f"保留 {len(keep_layers)}/{total} 层: {keep_layers}")
    return model

5. 注意力头剪枝

python 复制代码
def prune_attention_heads(model, dataloader, prune_ratio=0.3):
    """基于重要性的注意力头剪枝"""
    head_importance = {}

    # 计算每个头的重要性
    for batch in dataloader:
        outputs = model(**batch, output_attentions=True)
        attentions = outputs.attentions  # 每层的注意力权重

        for layer_idx, attn in enumerate(attentions):
            if layer_idx not in head_importance:
                head_importance[layer_idx] = torch.zeros(attn.size(1))

            # 头重要性 = 注意力熵(低熵 = 高重要性)
            entropy = -(attn.mean(0) * torch.log(attn.mean(0) + 1e-8)).sum(-1)
            head_importance[layer_idx] += entropy.detach().cpu()

    # 剪枝
    for layer_idx in head_importance:
        scores = head_importance[layer_idx]
        num_heads = len(scores)
        num_prune = int(num_heads * prune_ratio)

        # 保留最重要的头
        _, keep_indices = torch.topk(scores, num_heads - num_prune)

        # 将不重要的头权重置零
        layer = model.model.layers[layer_idx].self_attn
        for head_idx in range(num_heads):
            if head_idx not in keep_indices:
                head_dim = model.config.hidden_size // num_heads
                start = head_idx * head_dim
                end = start + head_dim
                layer.q_proj.weight.data[start:end] = 0
                layer.k_proj.weight.data[start:end] = 0
                layer.v_proj.weight.data[start:end] = 0

    return model

6. 结构化压缩 Pipeline

python 复制代码
class ModelCompressor:
    """完整的模型压缩 Pipeline"""

    def __init__(self, teacher_model, student_model=None):
        self.teacher = teacher_model
        self.student = student_model or teacher_model

    def compress(self, dataloader, config):
        """执行完整压缩流程"""
        results = {}

        # Step 1: 知识蒸馏
        if config.get("distill", True):
            print("=== 知识蒸馏 ===")
            self.student = distill_llm(
                self.teacher, self.student, dataloader,
                epochs=config.get("distill_epochs", 3),
            )
            results["distill"] = True

        # Step 2: 层剪枝
        if config.get("prune_layers", 0) > 0:
            print("=== 层剪枝 ===")
            self.student, pruned = prune_llm_layers(
                self.student, prune_ratio=config["prune_layers"]
            )
            results["pruned_layers"] = pruned

        # Step 3: 注意力头剪枝
        if config.get("prune_heads", 0) > 0:
            print("=== 注意力头剪枝 ===")
            self.student = prune_attention_heads(
                self.student, dataloader, config["prune_heads"]
            )

        # Step 4: 量化
        if config.get("quantize", False):
            print("=== 量化 ===")
            from transformers import BitsAndBytesConfig
            # ... 量化代码

        return self.student, results

# 使用
compressor = ModelCompressor(teacher_model, student_model)
compressed_model, results = compressor.compress(dataloader, {
    "distill": True,
    "distill_epochs": 3,
    "prune_layers": 0.25,
    "prune_heads": 0.3,
    "quantize": True,
})

7. 蒸馏效果对比

配置 参数量 推理速度 精度保持
Teacher (7B) 7B 1x 100%
Student (1.3B) 1.3B 5.4x 89%
+ 蒸馏 1.3B 5.4x 93%
+ 层剪枝 25% 1.0B 7.0x 91%
+ 头剪枝 30% 0.8B 8.8x 88%
+ INT8 量化 0.8B 12x 87%

8. 总结

模型蒸馏与压缩的核心要点:

  1. 知识蒸馏是最有效的压缩方法,精度损失最小
  2. 层剪枝实现简单,效果显著(保留首尾层效果更好)
  3. 注意力头剪枝可以进一步压缩,但需要仔细评估
  4. 组合使用:蒸馏 → 剪枝 → 量化,逐步压缩
  5. 蒸馏后的模型仍可继续微调,恢复部分精度损失