从 14B 到 1.5B:基于 Top-K Logits 的 LLM 知识蒸馏实战(Qwen2.5 全流程)
把 14B 大模型的能力"塞进"1.5B 小模型里,显存怎么压?速度怎么提?教师 logits 怎么存?本文用一个完整可跑的项目,带你打通 Qwen2.5 系列的 Logits 离线蒸馏全链路。
一、为什么要做 Logits 蒸馏
1.1 痛点
大模型推理成本高是业界共识。Qwen2.5-14B 能力强,但部署成本高;Qwen2.5-1.5B 轻量,但能力弱。知识蒸馏(Knowledge Distillation) 就是让学生模型在教师模型的"指导"下训练,逼近教师性能。
1.2 为什么选 Logits 蒸馏而非响应蒸馏
主流蒸馏分两派:
| 方法 | 教师输出 | 信息量 | 训练速度 |
|---|---|---|---|
| Logits 蒸馏(本项目) | 每个位置全词表的概率分布 | 极大 | 快(离线预计算) |
| 响应蒸馏 | 教师生成的文本答案 | 有限 | 慢(需在线生成) |
Logits 蒸馏的核心优势是 暗知识(Dark Knowledge)------教师模型对每个 token 的概率分布包含丰富的类间关系信息。比如做数学题时,"divide" 和 "multiply" 的概率对比,硬标签(标准答案)给不了,但 logits 能给。
而且教师 logits 可以离线预计算并存储,训练阶段无需再跑教师模型,大幅节省训练时间。
二、项目概览
- 教师模型:Qwen2.5-14B-Instruct-AWQ(4-bit 量化)
- 学生模型:Qwen2.5-1.5B-Instruct
- 数据集:MetaMathQA(数学应用题)
- 方法:Top-K Logits 离线蒸馏
- 硬件:2× RTX 3090 (24GB)
为什么教师/学生选同系列
Qwen2.5 全系列(0.5B 到 72B)使用完全相同的 tokenizer(词表 152,064)。教师 logits 的 token 索引和学生 logits 的 token 索引天然对齐,零额外处理。
三、核心概念:温度、Top-K 与暗知识
3.1 温度(Temperature)的作用
标准 Softmax (T=1): p_i = exp(logit_i / 1) / Σ exp(logit_j / 1)
高温 Softmax (T=3): p_i = exp(logit_i / 3) / Σ exp(logit_j / 3)
温度越高,概率分布越平滑 ,暴露出更多非顶部的概率信息。本项目用 T=3.0。
3.2 Top-K 稀疏存储的"空间魔法"
Qwen2.5 词表 152,064,每个位置完整 logits 在 fp16 下占 304KB。对于 5000 条样本 × 2048 位置,全量存储需要约 2.9 TB。
只保留概率最高的 K=50 个 token:
全量: [N, 2048, 152064] → ~2.9 TB
Top-K: [N, 2048, 50] → ~1.4 GB
节省: 99.97%
而教师模型 Top-50 token 通常覆盖了 99%+ 的概率质量,信息几乎无损。
四、全流程详解
整个 pipeline 分四个阶段:
MetaMathQA
│
▼
[阶段0] prepare_data.py 分词 → input_ids.pt [N, 2048]
│
▼
[阶段1] generate_logits.py 教师推理 + Top-K → top_indices.npy
│
▼
[阶段2] train_student.py 学生蒸馏训练 → final_model/
│
▼
[阶段3] eval_compare.py 推理对比
阶段 0:数据准备
python
load_dataset("meta-math/MetaMathQA")
→ shuffle(seed=42).select(range(5000))
→ tokenizer(queries, max_length=2048, padding="max_length")
→ 保存为 input_ids.pt / attention_mask.pt
关键点 :把分词结果缓存成 .pt 文件,避免教师推理和学生训练两个阶段重复分词,同时保证输入完全对齐。
阶段 1:教师 Logits 生成
这是项目最有工程价值的一步:
python
for batch_start in range(0, N, batch_size):
input_ids_batch = input_ids[batch_start:batch_end].cuda()
outputs = model(input_ids_batch, attention_mask=attention_mask_batch)
logits = outputs.logits # [bs, 2048, 152064]
# 提取 Top-K
values, indices = torch.topk(logits, k=50) # [bs, 2048, 50]
# 立即释放显存
del outputs, logits
torch.cuda.empty_cache()
# 写入 mmap 文件
top_indices[batch_start:batch_end] = indices.cpu().numpy().astype(np.int32)
top_values[batch_start:batch_end] = values.cpu().numpy().astype(np.float16)
# 更新进度
save_progress(batch_end)
三个工程亮点:
- numpy mmap 预分配 :用
mode='w+'预分配全量空间,每个样本在文件中有唯一物理偏移量,支持精确断点续传。 - 断点续传 :
progress.json记录最后完成索引,重启后从断点继续,不破坏已有数据。 - 即时显存释放 :每个 batch 后
del+empty_cache,把显存占用压到最低。
阶段 2:学生蒸馏训练
蒸馏损失公式:
Loss = α · T² · KL(p_teacher || p_student) + (1 - α) · CE(p_student, y_true)
└─────────────────────────────────┘ └────────────────────────────┘
蒸馏项(软标签) 硬标签项(标准训练)
参数:α = 0.7, T = 3.0。
为什么乘 T² ?温度 T 平滑了 logits 分布。根据链式法则,KL 损失对 student logits 的梯度会多一个 1/T 因子,高温下梯度量级急剧减小。乘 T² 补偿这个衰减,使 KL 损失梯度与标准 CE 在同一数量级。
Top-K 上的高效 KL 计算:
python
# 提取学生在教师 Top-K 位置对应的 logits(避免创建全词表张量)
student_topk = torch.gather(student_logits, dim=-1, index=teacher_top_indices)
p_teacher = softmax(teacher_top_values / T, dim=-1)
log_p_student = log_softmax(student_topk / T, dim=-1)
kl_loss = (p_teacher * (p_teacher.log() - log_p_student)).sum(dim=-1).mean()
kl_loss = kl_loss * T * T # 梯度补偿
用 torch.gather 只在 Top-K 维度计算,避免创建 B, S, 152064 的全张量,计算和显存都大幅减少。
五、踩坑实录(重点推荐阅读)
项目过程踩了两个非常典型的坑,写出来给大家避雷。
5.1 cuBLAS fp16 GEMM Bug
现象 :教师模型前向推理在 lm_head 层崩溃:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling
`cublasGemmEx(... a, CUDA_R_16F, ... b, CUDA_R_16F, ...)`
根因分析:
PyTorch 2.10.0+cu128
└── 捆绑 nvidia-cublas-cu12==12.8.4.1(有 bug)
└── cublasGemmEx
├── CUDA_R_16F (fp16) ❌
├── CUDA_R_16BF (bf16) ❌
└── CUDA_R_32F (fp32) ✅
注意 AWQ 的 Marlin 量化内核不受影响 (它有自有 int4→fp16 反量化路径,不走 cuBLAS)。所以教师模型的量化层正常,但 lm_head(标准 nn.Linear)会触发 bug。
临时方案 :把教师 lm_head 包装为 fp32 计算:
python
class _Float32Linear(torch.nn.Module):
def forward(self, x):
return F.linear(x.float(), self.weight)
显存多占 ~1.5GB,RTX 3090 承受得住。但学生模型全是标准 Linear,无法逐个包装。
根治方案:升级 cuBLAS 到修复版本:
bash
pip install nvidia-cublas-cu12==12.9.1.4 --force-reinstall
pip 会警告
torch 2.10.0+cu128 requires nvidia-cublas-cu12==12.8.4.1,这只是依赖声明警告,实际运行完全正常。
5.2 单卡 OOM 优化
学生模型训练时,原始配置 batch_size=4, max_seq_len=2048 导致显存爆炸:
Logits 大小 = 4 × 2048 × 152064 × 2 bytes ≈ 2.5 GB
反向传播时这个大张量必须缓存,24GB 显存瞬间爆掉。
解法:Micro-batch + 梯度累积
python
# config.py
per_device_batch_size = 1 # 从 4 降到 1
gradient_accumulation_steps = 8 # 从 2 升到 8
# 等效全局 batch size = 1 × 8 = 8(保持不变)
工作机制:
[ 样本 1 ] → 前向+反向 → 梯度累加暂存
[ 样本 2 ] → 前向+反向 → 梯度累加暂存
...
[ 样本 8 ] → 前向+反向 → 梯度累加 → 【optimizer.step() 更新参数】
显存占用降到原来的 25%,等效 batch size 不变,模型收敛效果无影响。
5.3 AWQ 加载方式变了
transformers 4.48+ 弃用了 autoawq,改用 gptqmodel 作为 AWQ 加载后端:
bash
pip install gptqmodel
不要再装 autoawq------它已官方废弃,且会强制降级 transformers 到 4.47.1,与新版本冲突。
六、关键工程设计总结
| 决策 | 选择 | 原因 |
|---|---|---|
| 教师/学生同系列 | Qwen2.5 全家桶 | 词表 100% 一致,零额外处理 |
| 存储格式 | numpy mmap | 训练时随机读取高效,不需全部加载内存 |
| Top-K=50 | 稀疏存储 | 节省 ~99.97% 存储空间 |
| KL 计算 | torch.gather |
只在 Top-K 维度计算,避免 B,S,152K 全张量 |
| 断点续传 | progress.json |
logits 生成耗时长,支持中断恢复 |
| mmap 懒加载 | __getitem__ 内 _setup_mmap() |
解决多 worker DataLoader 的 pickle 问题 |
| 梯度累积 | micro-batch=1, accum=8 | 单卡跑得动大词表模型 |
mmap 懒加载的小技巧
np.memmap 对象包含操作系统文件描述符,不可被 pickle 序列化 。DataLoader num_workers > 0 时会把 Dataset 通过 pickle 分发给子进程,直接初始化会报 PicklingError。
解决方案:构造函数中设 _top_indices = None,在 __getitem__ 首次调用时才打开 mmap------每个 worker 进程独立打开自己的文件描述符。
七、评估效果
蒸馏成功的标志:
- ✅ 蒸馏模型回答更结构化 (有步骤、有公式、有
\boxed{}答案) - ✅ 减少幻觉(不会凭空编造不合理假设)
- ✅ 回答风格更接近 14B 教师模型
蒸馏不能解决的:
- ❌ 算术计算能力(1.5B 参数量的固有限制)
- ❌ 教师模型本身不具备的知识
八、一键运行
bash
# 环境
conda create -n logits-distill python=3.10
conda activate logits-distill
pip install torch==2.10.0 torchvision==0.25.0 \
--index-url https://download.pytorch.org/whl/cu128
pip install -r requirements.txt
# 国内镜像(必要)
export HF_ENDPOINT=https://hf-mirror.com
# 全流程
bash run_all.sh
# 或分步
python prepare_data.py
python generate_logits.py
accelerate launch --num_processes=2 train_student.py
python eval_compare.py
写在最后
Logits 蒸馏看似原理简单(KL 散度 + 温度),但工程实现里的细节决定了能不能跑通:
- 存储压缩:Top-K 稀疏存储让 TB 级数据降到 GB 级
- 显存优化:mmap + 梯度累积让单卡 3090 也能跑大词表模型
- 环境陷阱:cuBLAS bug、AWQ 后端变更等坑要提前规避
希望这个项目能帮到你。如有问题欢迎评论区交流。