基于 Top-K Logits 的 LLM 知识蒸馏实战

从 14B 到 1.5B:基于 Top-K Logits 的 LLM 知识蒸馏实战(Qwen2.5 全流程)

把 14B 大模型的能力"塞进"1.5B 小模型里,显存怎么压?速度怎么提?教师 logits 怎么存?本文用一个完整可跑的项目,带你打通 Qwen2.5 系列的 Logits 离线蒸馏全链路。

一、为什么要做 Logits 蒸馏

1.1 痛点

大模型推理成本高是业界共识。Qwen2.5-14B 能力强,但部署成本高;Qwen2.5-1.5B 轻量,但能力弱。知识蒸馏(Knowledge Distillation) 就是让学生模型在教师模型的"指导"下训练,逼近教师性能。

1.2 为什么选 Logits 蒸馏而非响应蒸馏

主流蒸馏分两派:

方法 教师输出 信息量 训练速度
Logits 蒸馏(本项目) 每个位置全词表的概率分布 极大 快(离线预计算)
响应蒸馏 教师生成的文本答案 有限 慢(需在线生成)

Logits 蒸馏的核心优势是 暗知识(Dark Knowledge)------教师模型对每个 token 的概率分布包含丰富的类间关系信息。比如做数学题时,"divide" 和 "multiply" 的概率对比,硬标签(标准答案)给不了,但 logits 能给。

而且教师 logits 可以离线预计算并存储,训练阶段无需再跑教师模型,大幅节省训练时间。

二、项目概览

  • 教师模型:Qwen2.5-14B-Instruct-AWQ(4-bit 量化)
  • 学生模型:Qwen2.5-1.5B-Instruct
  • 数据集:MetaMathQA(数学应用题)
  • 方法:Top-K Logits 离线蒸馏
  • 硬件:2× RTX 3090 (24GB)

为什么教师/学生选同系列

Qwen2.5 全系列(0.5B 到 72B)使用完全相同的 tokenizer(词表 152,064)。教师 logits 的 token 索引和学生 logits 的 token 索引天然对齐,零额外处理。

三、核心概念:温度、Top-K 与暗知识

3.1 温度(Temperature)的作用

复制代码
标准 Softmax (T=1):   p_i = exp(logit_i / 1) / Σ exp(logit_j / 1)
高温 Softmax (T=3):   p_i = exp(logit_i / 3) / Σ exp(logit_j / 3)

温度越高,概率分布越平滑 ,暴露出更多非顶部的概率信息。本项目用 T=3.0

3.2 Top-K 稀疏存储的"空间魔法"

Qwen2.5 词表 152,064,每个位置完整 logits 在 fp16 下占 304KB。对于 5000 条样本 × 2048 位置,全量存储需要约 2.9 TB

只保留概率最高的 K=50 个 token:

复制代码
全量: [N, 2048, 152064]  →  ~2.9 TB
Top-K: [N, 2048, 50]     →  ~1.4 GB
节省: 99.97%

而教师模型 Top-50 token 通常覆盖了 99%+ 的概率质量,信息几乎无损。

四、全流程详解

整个 pipeline 分四个阶段:

复制代码
MetaMathQA
    │
    ▼
[阶段0] prepare_data.py    分词 → input_ids.pt [N, 2048]
    │
    ▼
[阶段1] generate_logits.py  教师推理 + Top-K → top_indices.npy
    │
    ▼
[阶段2] train_student.py    学生蒸馏训练 → final_model/
    │
    ▼
[阶段3] eval_compare.py     推理对比

阶段 0:数据准备

python 复制代码
load_dataset("meta-math/MetaMathQA")
   → shuffle(seed=42).select(range(5000))
   → tokenizer(queries, max_length=2048, padding="max_length")
   → 保存为 input_ids.pt / attention_mask.pt

关键点 :把分词结果缓存成 .pt 文件,避免教师推理和学生训练两个阶段重复分词,同时保证输入完全对齐。

阶段 1:教师 Logits 生成

这是项目最有工程价值的一步:

python 复制代码
for batch_start in range(0, N, batch_size):
    input_ids_batch = input_ids[batch_start:batch_end].cuda()
    outputs = model(input_ids_batch, attention_mask=attention_mask_batch)
    logits = outputs.logits                    # [bs, 2048, 152064]

    # 提取 Top-K
    values, indices = torch.topk(logits, k=50) # [bs, 2048, 50]

    # 立即释放显存
    del outputs, logits
    torch.cuda.empty_cache()

    # 写入 mmap 文件
    top_indices[batch_start:batch_end] = indices.cpu().numpy().astype(np.int32)
    top_values[batch_start:batch_end]  = values.cpu().numpy().astype(np.float16)

    # 更新进度
    save_progress(batch_end)

三个工程亮点

  1. numpy mmap 预分配 :用 mode='w+' 预分配全量空间,每个样本在文件中有唯一物理偏移量,支持精确断点续传。
  2. 断点续传progress.json 记录最后完成索引,重启后从断点继续,不破坏已有数据。
  3. 即时显存释放 :每个 batch 后 del + empty_cache,把显存占用压到最低。

阶段 2:学生蒸馏训练

蒸馏损失公式

复制代码
Loss = α · T² · KL(p_teacher || p_student) + (1 - α) · CE(p_student, y_true)
       └─────────────────────────────────┘   └────────────────────────────┘
                   蒸馏项(软标签)                硬标签项(标准训练)

参数:α = 0.7, T = 3.0

为什么乘 T² ?温度 T 平滑了 logits 分布。根据链式法则,KL 损失对 student logits 的梯度会多一个 1/T 因子,高温下梯度量级急剧减小。乘 T² 补偿这个衰减,使 KL 损失梯度与标准 CE 在同一数量级。

Top-K 上的高效 KL 计算

python 复制代码
# 提取学生在教师 Top-K 位置对应的 logits(避免创建全词表张量)
student_topk = torch.gather(student_logits, dim=-1, index=teacher_top_indices)

p_teacher      = softmax(teacher_top_values / T, dim=-1)
log_p_student  = log_softmax(student_topk / T, dim=-1)

kl_loss = (p_teacher * (p_teacher.log() - log_p_student)).sum(dim=-1).mean()
kl_loss = kl_loss * T * T   # 梯度补偿

torch.gather 只在 Top-K 维度计算,避免创建 B, S, 152064 的全张量,计算和显存都大幅减少。

五、踩坑实录(重点推荐阅读)

项目过程踩了两个非常典型的坑,写出来给大家避雷。

5.1 cuBLAS fp16 GEMM Bug

现象 :教师模型前向推理在 lm_head 层崩溃:

复制代码
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling
`cublasGemmEx(... a, CUDA_R_16F, ... b, CUDA_R_16F, ...)`

根因分析

复制代码
PyTorch 2.10.0+cu128
   └── 捆绑 nvidia-cublas-cu12==12.8.4.1(有 bug)
          └── cublasGemmEx
                 ├── CUDA_R_16F  (fp16)  ❌
                 ├── CUDA_R_16BF (bf16)  ❌
                 └── CUDA_R_32F  (fp32)  ✅

注意 AWQ 的 Marlin 量化内核不受影响 (它有自有 int4→fp16 反量化路径,不走 cuBLAS)。所以教师模型的量化层正常,但 lm_head(标准 nn.Linear)会触发 bug。

临时方案 :把教师 lm_head 包装为 fp32 计算:

python 复制代码
class _Float32Linear(torch.nn.Module):
    def forward(self, x):
        return F.linear(x.float(), self.weight)

显存多占 ~1.5GB,RTX 3090 承受得住。但学生模型全是标准 Linear,无法逐个包装。

根治方案:升级 cuBLAS 到修复版本:

bash 复制代码
pip install nvidia-cublas-cu12==12.9.1.4 --force-reinstall

pip 会警告 torch 2.10.0+cu128 requires nvidia-cublas-cu12==12.8.4.1,这只是依赖声明警告,实际运行完全正常。

5.2 单卡 OOM 优化

学生模型训练时,原始配置 batch_size=4, max_seq_len=2048 导致显存爆炸:

复制代码
Logits 大小 = 4 × 2048 × 152064 × 2 bytes ≈ 2.5 GB

反向传播时这个大张量必须缓存,24GB 显存瞬间爆掉。

解法:Micro-batch + 梯度累积

python 复制代码
# config.py
per_device_batch_size = 1          # 从 4 降到 1
gradient_accumulation_steps = 8    # 从 2 升到 8
# 等效全局 batch size = 1 × 8 = 8(保持不变)

工作机制

复制代码
[ 样本 1 ] → 前向+反向 → 梯度累加暂存
[ 样本 2 ] → 前向+反向 → 梯度累加暂存
...
[ 样本 8 ] → 前向+反向 → 梯度累加 → 【optimizer.step() 更新参数】

显存占用降到原来的 25%,等效 batch size 不变,模型收敛效果无影响。

5.3 AWQ 加载方式变了

transformers 4.48+ 弃用了 autoawq,改用 gptqmodel 作为 AWQ 加载后端:

bash 复制代码
pip install gptqmodel

不要再装 autoawq------它已官方废弃,且会强制降级 transformers 到 4.47.1,与新版本冲突。

六、关键工程设计总结

决策 选择 原因
教师/学生同系列 Qwen2.5 全家桶 词表 100% 一致,零额外处理
存储格式 numpy mmap 训练时随机读取高效,不需全部加载内存
Top-K=50 稀疏存储 节省 ~99.97% 存储空间
KL 计算 torch.gather 只在 Top-K 维度计算,避免 B,S,152K 全张量
断点续传 progress.json logits 生成耗时长,支持中断恢复
mmap 懒加载 __getitem___setup_mmap() 解决多 worker DataLoader 的 pickle 问题
梯度累积 micro-batch=1, accum=8 单卡跑得动大词表模型

mmap 懒加载的小技巧

np.memmap 对象包含操作系统文件描述符,不可被 pickle 序列化 。DataLoader num_workers > 0 时会把 Dataset 通过 pickle 分发给子进程,直接初始化会报 PicklingError

解决方案:构造函数中设 _top_indices = None,在 __getitem__ 首次调用时才打开 mmap------每个 worker 进程独立打开自己的文件描述符。

七、评估效果

蒸馏成功的标志:

  • ✅ 蒸馏模型回答更结构化 (有步骤、有公式、有 \boxed{} 答案)
  • 减少幻觉(不会凭空编造不合理假设)
  • ✅ 回答风格更接近 14B 教师模型

蒸馏不能解决的:

  • ❌ 算术计算能力(1.5B 参数量的固有限制)
  • ❌ 教师模型本身不具备的知识

八、一键运行

bash 复制代码
# 环境
conda create -n logits-distill python=3.10
conda activate logits-distill
pip install torch==2.10.0 torchvision==0.25.0 \
    --index-url https://download.pytorch.org/whl/cu128
pip install -r requirements.txt

# 国内镜像(必要)
export HF_ENDPOINT=https://hf-mirror.com

# 全流程
bash run_all.sh

# 或分步
python prepare_data.py
python generate_logits.py
accelerate launch --num_processes=2 train_student.py
python eval_compare.py

写在最后

Logits 蒸馏看似原理简单(KL 散度 + 温度),但工程实现里的细节决定了能不能跑通:

  1. 存储压缩:Top-K 稀疏存储让 TB 级数据降到 GB 级
  2. 显存优化:mmap + 梯度累积让单卡 3090 也能跑大词表模型
  3. 环境陷阱:cuBLAS bug、AWQ 后端变更等坑要提前规避

希望这个项目能帮到你。如有问题欢迎评论区交流。

相关推荐
lkshop1 小时前
自研 GEO 系统实战:从架构设计到“一键投喂”多平台 AI 大模型
人工智能·geo
维基框架1 小时前
Claude Mythos Preview 发布后严重漏洞激增:安全还是营销?
人工智能·安全
Csvn1 小时前
AI Prompt 炼金术:让 AI 写代码 一次过
人工智能
HjhIron1 小时前
从 RAG 乱象到统一标准:MCP 凭什么成为 Agentic AI 的底座?
ai编程·mcp
Csvn1 小时前
AI 编程提效核心技巧(直接复制套用,大幅减少手写代码时间)
人工智能
delishcomcn1 小时前
预见性切割:机器学习如何提前预警碳带分切机的报废风险
人工智能·机器学习
拧AI螺丝2 小时前
你往 AI 里装的那些 skill,打开看过一眼吗?
人工智能·agent
学究天人2 小时前
数学星球:等价性(第1-4章)
人工智能
星释2 小时前
鸿蒙智能体开发实战:4.A2A 模式创建智能体
ai·harmonyos·鸿蒙·智能体