1. 引言
模型蒸馏(Knowledge Distillation)是将大模型(Teacher)的知识转移到小模型(Student)的技术。通过蒸馏,在许多任务上可以用约 1/5~1/10 的参数量保留大模型约 90% 的性能(具体取决于任务与模型),从而大幅降低部署成本。
压缩方法全景:
| 方法 | 压缩比 | 精度损失 | 实现难度 |
|---|---|---|---|
| 知识蒸馏 | 5-20x | 低 | 中 |
| 层剪枝 | 2-4x | 低 | 低 |
| 注意力头剪枝 | 1.5-3x | 低 | 低 |
| 量化 | 2-4x | 很低 | 低 |
| 低秩分解 | 2-5x | 中 | 中 |
2. 知识蒸馏基础
2.1 经典蒸馏框架
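Teacher (大模型) logits / T → softmax → 软目标 (soft targets)
Student (小模型) logits / T → log_softmax → 与软目标计算 KL 散度 → 蒸馏损失
Student (小模型) logits → 与真实标签 (hard targets) 计算交叉熵 → 交叉熵损失
总损失 = α × T² × 蒸馏损失 + (1-α) × 交叉熵损失
即 $\mathcal{L} = \alpha\,T^2\,\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big) + (1-\alpha)\,\mathrm{CE}(y, z_s)$,其中 $z_t$、$z_s$ 分别为 Teacher 与 Student 的 logits,$T$ 为温度。软目标产生的梯度量级约按 $1/T^2$ 缩小,乘以 $T^2$ 可使其不随温度变化。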
2.2 蒸馏实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Classic knowledge-distillation loss.

    total = alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
            + (1 - alpha) * CE(student, labels)

    Args:
        temperature: softening temperature T.
        alpha: weight of the distillation (KL) term; ``1 - alpha`` weights the
            hard-label cross-entropy term.
        ignore_index: label value skipped by the cross-entropy term. When labels
            are class indices (one dimension fewer than the logits), positions
            carrying this label are also excluded from the KL term, so padding
            does not contribute to either loss.

    ``forward(student_logits, teacher_logits, labels)`` returns
    ``(total_loss, distill_loss_value, hard_loss_value)``: the first is a tensor
    for ``backward()``, the other two are Python floats for logging.
    """

    def __init__(self, temperature=4.0, alpha=0.5, ignore_index=-100):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ignore_index = ignore_index
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        # Upcast: softmax / log_softmax over a large vocabulary is unreliable in
        # float16, and mixed-dtype KL inputs are avoided.
        student_logits = student_logits.float()
        teacher_logits = teacher_logits.float()

        # Distill only on the positions the hard loss also counts.
        kd_student, kd_teacher = student_logits, teacher_logits
        if labels.dim() == student_logits.dim() - 1:
            keep = labels != self.ignore_index
            kd_student, kd_teacher = student_logits[keep], teacher_logits[keep]

        if kd_student.numel() == 0:
            # Every position is ignored: 'batchmean' would divide by zero, so
            # contribute an exact zero that still participates in autograd.
            distill_loss = student_logits.sum() * 0.0
        else:
            soft_student = F.log_softmax(kd_student / T, dim=-1)
            soft_teacher = F.softmax(kd_teacher / T, dim=-1)
            # T^2 keeps the soft-target gradient scale comparable across temperatures.
            distill_loss = self.kl_loss(soft_student, soft_teacher) * (T ** 2)

        hard_loss = self.ce_loss(student_logits, labels)
        total = self.alpha * distill_loss + (1 - self.alpha) * hard_loss
        return total, distill_loss.item(), hard_loss.item()
2.3 完整训练流程
python
from transformers import AutoModelForCausalLM, AutoTokenizer
def distill_llm(teacher_id, student_id, dataset, epochs=3, lr=5e-5,
                student_dtype=torch.float32):
    """Distill a causal-LM teacher into a student model.

    Args:
        teacher_id: model id / local path of the teacher, or an already-loaded
            ``nn.Module`` (``ModelCompressor`` passes model objects). When loaded
            from an id it uses float16.
        student_id: model id / local path of the student, or an already-loaded
            ``nn.Module``. Teacher and student must share a vocabulary, because
            the KL term compares their logits position by position.
        dataset: iterable of batches, each a mapping with ``"input_ids"`` and
            ``"labels"`` tensors of shape (batch, seq) and optionally
            ``"attention_mask"``. Labels follow the Hugging Face causal-LM
            convention: unshifted copies of the target tokens, -100 where ignored.
        epochs: number of passes over ``dataset``.
        lr: AdamW learning rate.
        student_dtype: dtype used when the student is loaded from an id.
            Defaults to float32 because AdamW applied directly to float16 weights
            is numerically unreliable (small squared gradients and eps=1e-8
            underflow in float16).

    Returns:
        The trained student model.
    """
    def _resolve(model_or_id, dtype):
        # Accept an already-instantiated model as well as a hub id / local path.
        if isinstance(model_or_id, nn.Module):
            return model_or_id
        return AutoModelForCausalLM.from_pretrained(
            model_or_id, device_map="auto", torch_dtype=dtype
        )

    teacher = _resolve(teacher_id, torch.float16)
    student = _resolve(student_id, student_dtype)
    teacher.eval()   # the teacher is frozen
    student.train()  # from_pretrained returns models in eval mode (dropout off)

    teacher_device = next(teacher.parameters()).device
    student_device = next(student.parameters()).device
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    distill_criterion = DistillationLoss(temperature=4.0, alpha=0.7)

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in dataset:
            inputs = {"input_ids": batch["input_ids"]}
            if "attention_mask" in batch:
                # Without the mask, padding tokens would be attended to.
                inputs["attention_mask"] = batch["attention_mask"]
            labels = batch["labels"]

            # Teacher forward pass: no gradients needed.
            with torch.no_grad():
                teacher_logits = teacher(
                    **{k: v.to(teacher_device) for k, v in inputs.items()}
                ).logits
            student_logits = student(
                **{k: v.to(student_device) for k, v in inputs.items()}
            ).logits

            # Causal LM: the logit at position t predicts token t+1, so pair
            # logits[:, :-1] with labels[:, 1:] (KL uses the same positions).
            device = student_logits.device
            vocab_size = student_logits.size(-1)
            s_logits = student_logits[:, :-1, :].reshape(-1, vocab_size)
            t_logits = teacher_logits[:, :-1, :].to(
                device=device, dtype=student_logits.dtype
            ).reshape(-1, teacher_logits.size(-1))
            targets = labels[:, 1:].reshape(-1).to(device)

            loss, _, _ = distill_criterion(s_logits, t_logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches ourselves: `dataset` may have no len(), or be empty.
        print(f"Epoch {epoch+1}, Loss: {total_loss/max(num_batches, 1):.4f}")
    return student
3. 特征蒸馏
3.1 中间层特征对齐
python
class FeatureDistillation(nn.Module):
    """Intermediate-layer feature distillation loss.

    Each student feature is linearly projected to the width of the matching
    teacher feature and compared with the detached teacher feature by MSE; the
    result is the mean MSE over all feature pairs.

    Args:
        teacher_dims: last-dimension size of each teacher feature.
        student_dims: last-dimension size of each student feature, in the same
            order as ``teacher_dims``.

    Raises:
        ValueError: if the two dimension lists differ in length, or if
            ``forward`` receives a number of features different from the number
            of projectors (or none at all).
    """

    def __init__(self, teacher_dims, student_dims):
        super().__init__()
        if len(teacher_dims) != len(student_dims):
            raise ValueError(
                f"got {len(teacher_dims)} teacher dims but "
                f"{len(student_dims)} student dims"
            )
        # One projector per feature pair: student width -> teacher width.
        self.projectors = nn.ModuleList([
            nn.Linear(s_dim, t_dim)
            for s_dim, t_dim in zip(student_dims, teacher_dims)
        ])

    def forward(self, teacher_features, student_features):
        num_pairs = len(self.projectors)
        # zip() would silently truncate a mismatch and skew the average.
        if (num_pairs == 0
                or len(teacher_features) != num_pairs
                or len(student_features) != num_pairs):
            raise ValueError(
                f"expected {num_pairs} teacher and student features, got "
                f"{len(teacher_features)} and {len(student_features)}"
            )
        losses = [
            # The teacher is a fixed target: no gradient flows into it.
            F.mse_loss(proj(s_feat), t_feat.detach())
            for t_feat, s_feat, proj in zip(
                teacher_features, student_features, self.projectors
            )
        ]
        return sum(losses) / num_pairs
class DistillableModel(nn.Module):
    """Wrap a model and capture the outputs of selected submodules via forward hooks.

    After each call, ``self.features`` maps every name in ``layer_names`` whose
    submodule ran during the forward pass to that submodule's output (for tuple
    outputs, only the first element is kept).

    Args:
        model: the module to wrap.
        layer_names: submodule names as produced by ``model.named_modules()``.

    Raises:
        KeyError: if any name in ``layer_names`` is not a submodule of ``model``
            (checked before any hook is registered).
    """

    def __init__(self, model, layer_names):
        super().__init__()
        self.model = model
        self.features = {}
        self.layer_names = layer_names
        # Build the name -> module lookup once instead of once per layer name.
        modules = dict(model.named_modules())
        missing = [name for name in layer_names if name not in modules]
        if missing:
            raise KeyError(f"modules not found in model: {missing}")
        # Keep the handles so the hooks can be detached later; otherwise they stay
        # on `model` and keep storing activations into this wrapper indefinitely.
        self._hook_handles = [
            modules[name].register_forward_hook(self._get_hook(name))
            for name in layer_names
        ]

    def _get_hook(self, name):
        """Return a forward hook that stores the module output under ``name``."""
        def hook(module, input, output):
            if isinstance(output, tuple):
                self.features[name] = output[0]
            else:
                self.features[name] = output
        return hook

    def remove_hooks(self):
        """Detach every forward hook this wrapper registered on the wrapped model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def forward(self, *args, **kwargs):
        # Reset so features from a previous call never leak into this one.
        self.features = {}
        return self.model(*args, **kwargs)
4. 层剪枝
4.1 基于重要性的层剪枝
python
def compute_layer_importance(model, dataloader):
    """Accumulate gradient-magnitude importance scores for layer weights.

    For every parameter whose name contains both "layer" and "weight", this sums,
    over all batches, the mean absolute gradient along the parameter's last
    dimension. That gives one score per output row of a 2-D weight, and a single
    scalar for a 1-D weight.

    Args:
        model: module such that ``model(**batch)`` returns an object with a
            scalar ``.loss`` tensor.
        dataloader: iterable of keyword-argument dicts for ``model``.

    Returns:
        dict mapping parameter name -> float32 numpy array of shape
        ``param.shape[:-1]``. Gradients are cleared from ``model`` before
        returning, so they cannot leak into a later optimizer step.
    """
    # Scores have the shape of one gradient reduced over its last dim.
    importance = {
        name: torch.zeros(param.shape[:-1], dtype=torch.float32, device=param.device)
        for name, param in model.named_parameters()
        if "layer" in name and "weight" in name
    }

    model.eval()
    for batch in dataloader:
        # backward() accumulates into .grad; clear it so each batch contributes
        # only its own gradient.
        model.zero_grad(set_to_none=True)
        outputs = model(**batch)
        outputs.loss.backward()
        for name, param in model.named_parameters():
            if name in importance and param.grad is not None:
                importance[name] += param.grad.detach().float().abs().mean(dim=-1)

    model.zero_grad(set_to_none=True)
    # Convert to numpy on the host (float32, since numpy has no bfloat16).
    return {name: score.cpu().numpy() for name, score in importance.items()}
def prune_layers(model, prune_ratio=0.3):
    """Zero out the lowest-magnitude weighted submodules of ``model`` in place.

    Every submodule with a tensor ``weight`` is scored by the mean absolute value
    of that weight. The ``int(n * prune_ratio)`` lowest-scoring ones have their
    weight (and bias, if present) set to zero.

    Parameter shapes are unchanged, so this masks modules rather than removing
    them. For a layer on a residual branch, zeroing silences its contribution.
    For other modules, such as a normalization layer, it zeroes their output.

    Args:
        model: the module to prune.
        prune_ratio: fraction of weighted submodules to zero.

    Returns:
        ``(model, layers_to_prune)``: the same model object and the names of the
        zeroed submodules, lowest score first.
    """
    # Build the name -> module lookup once; it is reused when pruning.
    modules = dict(model.named_modules())

    # Score each submodule that actually holds a weight tensor; some modules
    # (e.g. a non-affine LayerNorm) register `weight` as None.
    layer_importance = {
        name: module.weight.detach().abs().mean().item()
        for name, module in modules.items()
        if isinstance(getattr(module, "weight", None), torch.Tensor)
    }

    # Sort ascending and take the least important ones.
    sorted_layers = sorted(layer_importance.items(), key=lambda item: item[1])
    num_prune = int(len(sorted_layers) * prune_ratio)
    layers_to_prune = [name for name, _ in sorted_layers[:num_prune]]
    print(f"剪枝 {num_prune} 层: {layers_to_prune}")

    for name in layers_to_prune:
        module = modules[name]
        nn.init.zeros_(module.weight)
        bias = getattr(module, "bias", None)
        if isinstance(bias, torch.Tensor):
            nn.init.zeros_(bias)
    return model, layers_to_prune
4.2 LLM 层剪枝
python
def prune_llm_layers(model, keep_layers=None, prune_ratio=0.25):
    """Physically remove decoder layers from a Llama-style causal LM, in place.

    Args:
        model: model exposing ``model.model.layers`` (an ``nn.ModuleList``) and
            ``model.config.num_hidden_layers``.
        keep_layers: indices of the layers to keep, in the order they should
            appear. If None, ``int(total * (1 - prune_ratio))`` layers are kept,
            split between the first and last layers.
        prune_ratio: fraction of layers to drop when ``keep_layers`` is None.

    Returns:
        The same model object, with its layer list and
        ``config.num_hidden_layers`` updated.
    """
    # Needed for the log line below, whether or not keep_layers was given.
    total = model.config.num_hidden_layers
    if keep_layers is None:
        keep = int(total * (1 - prune_ratio))
        # Keep the first and last layers (usually the most important). When
        # `keep` is odd, the extra layer goes to the front so exactly `keep` survive.
        head = keep - keep // 2
        tail = keep // 2
        keep_layers = list(range(head)) + list(range(total - tail, total))

    original_layers = model.model.layers
    pruned_layers = nn.ModuleList([original_layers[i] for i in keep_layers])

    # Hugging Face Llama-style attention modules index the KV cache by
    # `layer_idx`; renumber so the surviving layers map to slots 0..len-1.
    for new_idx, layer in enumerate(pruned_layers):
        for mod in (layer, getattr(layer, "self_attn", None)):
            if mod is not None and hasattr(mod, "layer_idx"):
                mod.layer_idx = new_idx

    model.model.layers = pruned_layers
    model.config.num_hidden_layers = len(keep_layers)
    print(f"保留 {len(keep_layers)}/{total} 层: {keep_layers}")
    return model
5. 注意力头剪枝
python
def prune_attention_heads(model, dataloader, prune_ratio=0.3):
    """Mask the least important attention heads of a Llama-style model in place.

    Head importance is measured by attention entropy. A head whose attention rows
    concentrate on few keys (low entropy) counts as important, so in every layer
    the ``int(num_heads * prune_ratio)`` heads with the HIGHEST accumulated
    entropy are pruned.

    A pruned head is masked by zeroing its rows in ``q_proj`` (and ``q_proj``'s
    bias, if any) and its input columns in ``o_proj``. That removes the head's
    contribution to the layer output. ``k_proj``/``v_proj`` are left untouched:
    under grouped-query attention one KV head is shared by several query heads.
    Parameter shapes do not change.

    Args:
        model: model such that ``model(**batch, output_attentions=True)``
            returns ``.attentions`` (one tensor per layer) and that exposes
            ``model.model.layers[i].self_attn.q_proj`` / ``.o_proj``.
        dataloader: iterable of keyword-argument dicts for ``model``.
        prune_ratio: fraction of heads to prune in each layer.

    Returns:
        The same model object.

    Raises:
        ValueError: if the model returns no attention weights.
    """
    entropy_sums = {}
    with torch.no_grad():  # scoring and masking need no autograd graph
        for batch in dataloader:
            outputs = model(**batch, output_attentions=True)
            if not outputs.attentions:
                # Some attention implementations do not return the weights.
                raise ValueError(
                    "model returned no attention weights; an attention "
                    "implementation that supports output_attentions is required"
                )
            for layer_idx, attn in enumerate(outputs.attentions):
                if attn is None:
                    continue
                # Assumes (batch, heads, query, key) as in Hugging Face outputs.
                # Upcast: 1e-8 underflows to 0 in float16, turning 0*log(0) into NaN.
                probs = attn.float()
                row_entropy = -(probs * torch.log(probs + 1e-8)).sum(-1)
                per_head = row_entropy.mean(dim=(0, 2)).cpu()  # -> (heads,)
                if layer_idx in entropy_sums:
                    entropy_sums[layer_idx] += per_head
                else:
                    entropy_sums[layer_idx] = per_head

        for layer_idx, scores in entropy_sums.items():
            num_heads = scores.numel()
            num_prune = int(num_heads * prune_ratio)
            if num_prune <= 0:
                continue
            # Low entropy = important, so prune the highest-entropy heads.
            prune_idx = torch.topk(scores, num_prune, largest=True).indices.tolist()

            attn_module = model.model.layers[layer_idx].self_attn
            # Derive head_dim from the weights (config.hidden_size // num_heads is
            # wrong for models with an explicit head_dim).
            head_dim = attn_module.q_proj.weight.shape[0] // num_heads
            q_bias = getattr(attn_module.q_proj, "bias", None)
            for head in prune_idx:
                start, end = head * head_dim, (head + 1) * head_dim
                attn_module.q_proj.weight[start:end] = 0
                if q_bias is not None:
                    q_bias[start:end] = 0
                attn_module.o_proj.weight[:, start:end] = 0
    return model
6. 结构化压缩 Pipeline
python
class ModelCompressor:
    """End-to-end compression pipeline: distill -> prune layers -> prune heads -> quantize.

    Args:
        teacher_model: the original (large) model.
        student_model: the smaller model to distill into. If omitted, the
            pruning steps operate on the teacher object itself and distillation
            is skipped.
    """

    def __init__(self, teacher_model, student_model=None):
        self.teacher = teacher_model
        # Compare with None explicitly: an `or` would discard a falsy module
        # (e.g. an empty nn.ModuleList / nn.Sequential defines __len__).
        self.student = student_model if student_model is not None else teacher_model

    def compress(self, dataloader, config):
        """Run the enabled compression steps and return ``(student, results)``.

        Recognized ``config`` keys: ``distill`` (default True),
        ``distill_epochs`` (default 3), ``prune_layers`` (ratio, 0 = off),
        ``prune_heads`` (ratio, 0 = off) and ``quantize`` (default False).
        """
        results = {}

        # Step 1: knowledge distillation
        if config.get("distill", True):
            if self.student is self.teacher:
                # Distilling a model into itself would train and overwrite the teacher.
                print("=== 跳过知识蒸馏: 未提供独立的 student 模型 ===")
                results["distill"] = False
            else:
                print("=== 知识蒸馏 ===")
                self.student = distill_llm(
                    self.teacher, self.student, dataloader,
                    epochs=config.get("distill_epochs", 3),
                )
                results["distill"] = True

        # Step 2: layer pruning
        if config.get("prune_layers", 0) > 0:
            print("=== 层剪枝 ===")
            # prune_llm_layers returns only the model, so recover the removed
            # layer indices by comparing layer object identity before and after.
            layers_before = list(self.student.model.layers)
            self.student = prune_llm_layers(
                self.student, prune_ratio=config["prune_layers"]
            )
            kept = {id(layer) for layer in self.student.model.layers}
            results["pruned_layers"] = [
                i for i, layer in enumerate(layers_before) if id(layer) not in kept
            ]

        # Step 3: attention-head pruning
        if config.get("prune_heads", 0) > 0:
            print("=== 注意力头剪枝 ===")
            self.student = prune_attention_heads(
                self.student, dataloader, config["prune_heads"]
            )

        # Step 4: quantization
        if config.get("quantize", False):
            print("=== 量化 ===")
            # TODO: quantization is not implemented here; report that explicitly
            # instead of silently returning an unquantized model.
            results["quantized"] = False

        return self.student, results
# Example usage: teacher_model, student_model and dataloader are assumed to be
# defined by the caller (they are not created in this snippet).
compressor = ModelCompressor(teacher_model, student_model)
pipeline_config = {
    "distill": True,          # step 1: knowledge distillation
    "distill_epochs": 3,
    "prune_layers": 0.25,     # step 2: prune ratio passed to prune_llm_layers
    "prune_heads": 0.3,       # step 3: per-layer ratio passed to prune_attention_heads
    "quantize": True,         # step 4: the quantization step is only a placeholder
}
compressed_model, results = compressor.compress(dataloader, pipeline_config)
7. 蒸馏效果对比
| 配置 | 参数量 | 推理速度 | 精度保持 |
|---|---|---|---|
| Teacher (7B) | 7B | 1x | 100% |
| Student (1.3B) | 1.3B | 5.4x | 89% |
| + 蒸馏 | 1.3B | 5.4x | 93% |
| + 层剪枝 25% | 1.0B | 7.0x | 91% |
| + 头剪枝 30% | 0.8B | 8.8x | 88% |
| + INT8 量化 | 0.8B | 12x | 87% |
注:上表为示意数据。第 5 节的头剪枝实现只是把权重置零,参数量和推理速度并不会因此改变;要获得表中的参数量下降与加速,需要真正移除对应的权重行/列。
8. 总结
模型蒸馏与压缩的核心要点:
- 知识蒸馏可实现最高的压缩比(5-20x),且精度损失较低;单看精度损失,量化最小(见第 1 节表格),但其压缩比有限
- 层剪枝实现简单,效果显著(保留首尾层效果更好)
- 注意力头剪枝可以进一步压缩,但需要仔细评估
- 组合使用:蒸馏 → 剪枝 → 量化,逐步压缩
- 蒸馏后的模型仍可继续微调,恢复部分精度损失