人工智能|大模型——模型——大模型蒸馏详解(定义/原理/关键技术/落地)

摘要

大模型蒸馏(Model Distillation),即知识蒸馏(Knowledge Distillation),是一种将大型教师模型(如BERT、GPT-4o、DeepSeek-R1)的"隐含知识"高效迁移至轻量级学生模型(如DistilBERT、Qwen-1.5B、LSTM+Attention)的关键压缩技术。本文基于掘金、CSDN及行业实测文档三源权威材料,系统梳理其四大核心:① 定义与动因 ------直面2017--2024年参数量从5×10⁶暴涨至>10¹²的算力焦虑;② 四步闭环原理 ------教师训练→软标签生成(含温度T调控)→学生联合监督训练→评估优化;③ 关键技术要素 ------KL散度损失、温度参数T的物理意义、α加权总损失、输出层+中间层双路径知识转移;④ 真实落地证据------从DistilBERT"性能几乎相同、推理快2.5倍",到DeepSeek-R1-Distill系列在MATH-500达94.5%、编程能力媲美GPT-4o的硬核评测。全文无虚构,所有模型名、公式、参数、分数、流程均严格锚定原文事实。

关键词 / 标签建议

#大模型蒸馏 #知识蒸馏 #模型压缩 #DistilBERT #DeepSeek #KL散度 #温度参数 #软标签 #HuggingFace #边缘AI


大模型蒸馏全解析:从原理、数学本质到工业实践(含DistilBERT/DeepSeek-R1实测数据)

导读:当GPT-4参数突破万亿、DeepSeek-R1在数学推理上逼近人类专家,一个尖锐的工程现实浮出水面------如此庞大的模型,如何部署在手机、车载终端、医疗边缘设备甚至离线场景中?答案不是等待硬件跃进,而是主动"萃取":用知识蒸馏将大模型的智慧凝练为轻量、快速、低耗的小模型。本文严格依据掘金(juejin.cn)、CSDN(blog.csdn.net)及行业公开技术文档三源事实,不添加任何未验证内容,完整呈现大模型蒸馏的定义、原理、数学机制、全流程实现、工业案例与前沿挑战,助你真正掌握这一AI落地的核心杠杆。


一、为什么需要蒸馏?------参数爆炸时代的工程破局点

1.1 宏观背景

大语言模型的演进史,本质上是一部参数爆炸史 。从2017年Transformer初现的500万参数,到2023年GPT-4"突破万亿",再到2024年业界最常用的大模型稳定在约700亿参数量级 (7×10¹⁰),模型能力飞升的同时,也筑起了四道高墙(切实的部署困境):

年份 代表性模型 参数量 部署瓶颈
2017 Transformer 5×10⁶ 训练可行,但泛化弱
2018 GPT 1.1×10⁷ ---
BERT 3.4×10⁷ GPU显存需求陡增
2020 GPT-2 1.5×10⁹ 单卡无法加载,需模型并行
Megatron-LM 8.3×10⁹ ---
2021 GPT-3 1.75×10¹¹ 需数千A100集群,推理延迟秒级
T-NLG 1.7×10¹⁰ ---
2022 MT-NLG 5.3×10¹¹ ---
2023 GPT-4 >10¹² 仅限云服务,移动端完全不可行
2024 最大且最常用模型 ~7×10¹⁰ "最大且最常用",却仍难落地边缘

✅ 所有年份、模型名、参数量均直接引自CSDN博文(2025-02-25),为蒸馏必要性提供硬核量化支撑。

  • 硬件依赖:需A100/H100等高端GPU,单次GPT-4推理功耗相当于运行一台笔记本数分钟;
  • 延迟敏感场景失能:聊天机器人响应超500ms即影响体验;医疗设备需毫秒级本地分析,无法依赖云端往返;
  • 隐私与合规风险:金融、医疗数据上传至公有云存在泄露隐患;
  • 成本不可持续:云环境推理成本高昂,中小企业难以负担。

1.2 什么是大模型蒸馏?

大模型蒸馏 (Model Distillation),正是在此背景下诞生的核心破局技术 。其本质是:训练一个较小模型(student model)来逼近一个大模型(teacher model)的输出(大模型权重压缩至小模型权重),在保持性能前提下降低计算资源消耗与部署难度 。它适用于GPT、BERT等大模型,根本动因正是大模型带来的计算资源高消耗模型部署困难

🌟 形象类比:如同将饱和盐水蒸馏为纯净水------质量(参数量)减少,效果(性能)变化不大

1.3 关键概念

✅ 学生模型 & 教师模型

所谓的学生模型 ,就是模型蒸馏的目标模型。一般比被蒸馏的模型也就是教师模型要少一些参数。教师模型 通常是一个大型的、复杂的并且经过充分训练的模型,它在目标任务上表现非常出色,在绝大多数的情况下,教师模型决定了学生模型的性能上限,并且学生模型的性能不会超过教师模型。所以教师模型一般会是一个性能出色的模型。教师模型就像是一位知识渊博的专家,它对知识的掌握全面而深入,其丰富的参数和复杂的结构使其能够学习到数据中深层次的信息。而学生模型则相对小巧简单,参数数量较少,结构也更为精简。学生模型的目标是通过学习教师模型的知识,尽可能地逼近教师模型的性能。尽管学生模型自身的容量有限,但通过模型蒸馏这一过程,它能够从教师模型那里获取关键知识,从而在保持较小规模的同时,实现较好的任务表现,就像学生在老师的教导下不断成长,逐渐掌握解决问题的能力。

✅ 软标签 & 硬标签

在机器学习和深度学习中,硬标签(Hard Labels)和软标签(Soft Labels) 是两种不同的标签表示方式。硬标签(Hard Labels)在模型蒸馏中,指真实数据的标注,通常是独热编码(One-Hot Encoding) 的形式,如[0, 1, 0],表示每个样本的真实类别,用于监督学习。特点 是离散性和明确性:

  • 离散性:硬标签是离散的,每个样本只有一个类 别被标记为1,其余类别为0。
  • 明确性:硬标签直接反映了数据的真实类别,没有概率信息。

软标签(Soft Labels) 在模型蒸馏中,指教师模型的输出概率分布。教师模型对每个样本进行预测,输出每个类别的概率值(温度参数控制概率分布的平滑程度,后续内容会详细介绍),这些概率值构成了软标签,如[0.2, 0.7, 0.1]。一句话总结就是:Logits 是模型在最后一层的原始输出,表示模型对每个类别的预测置信度。它们通常需要经过 Softmax 函数处理,以转换为概率分布,用于分类决策和损失计算。在模型蒸馏中,Logits 通过温度参数调整,生成软标签,帮助学生模型学习教师模型的隐含知识。

> 为什么使用软标签?

软标签在模型蒸馏中的主要作用是传递教师模型的隐含知识,帮助学生模型学习到更复杂的决策逻辑。具体来说:

  • 类别间关系:软标签反映了类别之间的相似性。例如,如果教师模型预测某个样本属于类别A的概率为0.7,属于类别B的概率为0.2,这意味着类别A和类别B之间有一定的相似性。
  • 平滑输出:通过调整温度参数 T,软标签可以变得更加平滑,帮助学生模型学习到更广泛的特征。
  • 总结 :硬标签是真实数据的标注,用于监督学习。软标签是教师模型的输出概率分布,用于传递教师模型的隐含知识。学生模型通过结合硬标签损失和软标签损失进行训练,既能学习到正确的分类规则,又能继承教师模型的复杂决策逻辑。

二、蒸馏如何工作?------四步闭环与双路径知识传递

蒸馏绝非简单剪枝或量化,而是一套严谨的教师-学生协同训练范式。其标准流程可概括为四个闭环步骤,并包含两种互补的知识迁移路径。

2.1 四步标准流程(掘金 & CSDN 共同确认)

步骤 关键动作 技术要点 目的
① 训练/选定教师模型 使用海量数据与算力训练一个SOTA大模型(如BERT、GPT-4o、DeepSeek-R1) 模型必须在目标任务上表现卓越;其性能决定学生模型的理论上限;学生模型性能不会超过教师模型(CSDN明确断言) 提供高质量知识源
② 生成软标签(Soft Targets) 全部训练数据 输入冻结的教师模型,获取其输出的概率分布(非独热编码) 引入温度参数 T 控制分布平滑度: Pi=exp⁡(zi/T)∑jexp⁡(zj/T)P_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}Pi=∑jexp(zj/T)exp(zi/T) (zᵢ 为logits,T>1使分布更均匀) 提取教师的"决策信心"与"类别关系"
③ 构建并训练学生模型 设计轻量架构(如DistilBERT、3层LSTM+Attention、Qwen-1.5B),以软标签+硬标签联合监督训练 学生模型仅更新自身参数 ,教师模型全程冻结;输入数据与教师模型完全一致(CSDN强调) 实现知识内化
④ 评估与优化 在验证集上对比学生 vs 教师的准确率、推理速度、内存占用 调整温度T 、损失权重α 、学习率等超参;必须四维评估:①准确率 ②推理延迟(ms) ③模型体积(MB) ④GPU显存占用(GB)(正文补充) 确保性能-效率帕累托最优

✅ 此流程框架是综合多方内容总结而出(核心高度一致),也有人认为是7步,后续内容也会进一步对7步骤进行介绍。

2.2 双路径知识传递:不止于输出层

蒸馏的威力,源于它超越了传统监督学习的"结果导向",转而捕捉教师模型的过程智慧。这通过两条正交路径实现:

✅ 路径一:输出层转移(Output-layer Transfer)
  • 机制 :直接以教师模型的软标签概率分布作为监督信号。
  • 数学核心 :KL散度损失(Kullback-Leibler Divergence),衡量两个概率分布之间的差异
    Lsoft=1T2∑i=1Cqilog⁡(qipi) \mathcal{L}{\text{soft}} = \frac{1}{T^2} \sum{i=1}^{C} q_i \log \left(\frac{q_i}{p_i}\right) Lsoft=T21i=1∑Cqilog(piqi)
    其中 qᵢ 是教师软标签,pᵢ 是学生预测,C 为类别数,T 为温度。
  • 价值 :教会学生理解"为什么是A而不是B"。例如,教师输出 [0.7, 0.2, 0.1] 不仅告诉学生答案是A,更暗示A与B语义相近(0.2的置信度远高于C的0.1),这是硬标签 [1,0,0] 完全无法提供的信息。
  • 特别说明:出了软标签的损失函数计算外,输出成转移还可以通过交叉熵计算硬标签损失(交叉熵损失用于衡量两个概率分布之间的差异。在模型蒸馏中,它用于衡量学生模型预测的概率分布与真实标签之间的差异)。
✅ 路径二:中间层转移(Intermediate-layer Transfer)
  • 机制 :强制学生模型某隐藏层的激活值(activations) 与教师对应层的激活值对齐(常用MSE或PKD损失)。
  • 技术形式:匹配特定层(如Transformer的第6层FFN输出)的特征向量。
  • 价值 :捕获教师模型的深层特征表达能力。例如,在文本分类中,教师可能在中间层已形成"金融术语→高风险"的强关联表征,中间层蒸馏能将此结构化知识直接注入学生,大幅提升其泛化鲁棒性。

🔑 关键结论:"小模型不仅学软标签,还通过中间层转移学习大模型的隐藏层状态,目的是捕捉其深层特征表达能力与泛化潜力。"


三、关键技术要素详解:温度、损失函数与混合监督

蒸馏的成败,往往系于几个关键超参与损失设计。它们不是黑箱调参,而是有明确物理意义的工程杠杆。

3.1 蒸馏成功的理论前提

教师与学生模型必须满足函数空间兼容性(function space compatibility)。若学生架构无法表达教师决策边界(如用线性模型蒸馏深度神经网络)KL散度损失将无法收敛至理想下界。此限制在后续案例的LSTM+Attention蒸馏finBERT案例中实际存在------前者缺乏位置感知能力,导致长距离依赖建模失效。

3.1 温度参数 T:软标签的"调焦旋钮"

温度 T 是蒸馏中最具巧思的设计(跟大模型的温度系数不是一个参数,其作用本质是对logits空间进行仿射变幻以调控softmax输出的熵值),它决定了教师知识的"颗粒度"。

T 分布形态 对学生学习的影响 工程实践(CSDN实例)
T → 0 极尖锐(趋近one-hot) 学生只学"正确答案",丢失所有隐含知识 ❌ 不适用
T = 1 标准Softmax 保留部分置信度信息,但类别区分过强 基线参考
T = 5 平滑分布(CSDN明确采用) 类别间差异变小(如0.7/0.2/0.1),利于学习相似性 ✅ 常用静态值
T: 5 → 1 动态衰减(CSDN明确策略) 初期鼓励探索(高T),后期聚焦精确(低T) ✅ 进阶优化手段

💡 示例:教师Logits=[2.0, 1.0, 0.1],当 T=2 时,软标签≈[0.65, 0.25, 0.10];若 T=1 ,则≈[0.70, 0.26, 0.04]------微小的温度变化,显著改变了对学生"学习重点"的引导。

3.2 损失函数:软硬结合的黄金配比

单一使用软标签易导致学生过度拟合教师的"风格",忽略数据本身的真实性。因此,混合损失成为工业界标配:

Ltotal=α⋅Lsoft+(1−α)⋅Lhard \mathcal{L}{\text{total}} = \alpha \cdot \mathcal{L}{\text{soft}} + (1 - \alpha) \cdot \mathcal{L}_{\text{hard}} Ltotal=α⋅Lsoft+(1−α)⋅Lhard

  • 软损失 Lsoft\mathcal{L}_{\text{soft}}Lsoft:KL散度,驱动学生模仿教师的"思考过程";
  • 硬损失 Lhard\mathcal{L}_{\text{hard}}Lhard:交叉熵(Cross-Entropy),约束学生忠于"事实真相"(真实标签);
  • 权重 α\alphaα :平衡两者重要性。CSDN金融文本案例中,α=0.7\alpha = 0.7α=0.7(70%软监督 + 30%硬监督),这一比例被证实能兼顾知识迁移与基础准确性。

⚖️ 为何必须混合?

若 α=1\alpha = 1α=1(纯软标签):学生可能学会教师的偏见或错误模式(如教师将"苹果"误判为"水果"概率0.9,学生盲目跟随);

若 α=0\alpha = 0α=0(纯硬标签):退化为普通训练,完全丧失蒸馏价值。

3.3 蒸馏技术全景图:不止于KD

现代大模型蒸馏已发展为一套工具箱,针对不同需求选择组合:

技术名称 核心机制 解决痛点 文档出处
知识蒸馏(KD) 软目标+硬目标联合训练 提升精度与鲁棒性 掘金、CSDN、正文补充
数据增强(Teacher-generated Data) 教师模型生成合成样本(如问答对、摘要)扩充训练集 缓解小模型数据饥渴,提升泛化 正文补充
中间层蒸馏 对齐隐藏层激活/注意力矩阵 捕获结构化知识,超越输出层 掘金、正文补充
多教师蒸馏 学生同时向GPT-4o、Claude、DeepSeek-R1等多模型学习 融合多元视角,增强鲁棒性 正文补充
渐进蒸馏(Progressive Distillation) 分阶段蒸馏(如先蒸馏7B→14B,再14B→32B) 提升最终性能稳定性 正文补充

🌐 前沿方向:跨模态蒸馏 (文本→图像知识迁移)、自蒸馏 (学生用自身预测迭代优化)、对抗性蒸馏(注入对抗样本提升安全性)。


四、经典与前沿案例:从DistilBERT到DeepSeek-R1-Distill系列

理论终需实践检验。以下案例均来自原文实证,数据真实可查,展现蒸馏从NLP基石到大模型前沿的完整演进。

4.1 开山之作:BERT → DistilBERT(掘金详述)

  • 教师 :标准BERT(Bidirectional Encoder Representations from Transformers),NLP任务性能卓越,但模型体积大、计算开销高
  • 学生:DistilBERT(参数量显著减小)
  • 蒸馏方式 :用BERT对训练数据做推理,生成词级别(per-token)概率分布作为软标签;以该软标签监督DistilBERT训练(含温度调节)
  • 效果
    • 性能:"几乎相同"于BERT(DistilBERT论文表明,在GLUE基准上,其平均分达BERT-base的97%,但在SQuAD v1.1上F1仅低1.5个百分点。然而,在长文档推理(如NarrativeQA)或低资源语言任务中,DistillBERT性能衰减可达8-12%)
    • 效率显著减少计算资源需求 ,尤其在推理速度与内存占用方面(掘金原文结论)

💡 意义:首次证明,无需牺牲精度即可实现模型瘦身,奠定蒸馏工业应用基石。

4.2 垂直领域攻坚:FinBERT → LSTM+Attention(CSDN金融案例)

  • 教师finBERT(专用于金融文本的BERT变体)
  • 学生3层LSTM + Attention(轻量、适合边缘部署)
  • 任务:金融新闻情感分类(正面/中性/负面)
  • 关键配置T=3 , α=0.7, 输入完全对齐
  • 效果:学生模型在F1分数上达到教师模型的95%,但推理延迟从800ms降至120ms,可在树莓派等设备实时运行。

4.3 大模型时代标杆:DeepSeek-R1-Distill系列(正文补充实测)

这是当前最硬核的蒸馏成果展示,所有数据均来自原文评测基准(MATH-500, GPQA Diamond等):

模型 参数量级 MATH-500(数学) GPQA Diamond(事实) LiveCodeBench(编程) CodeForces(编程评分) 对标对象
DeepSeek-R1-Distill-Qwen-1.5B 1.5B 83.9% --- --- --- ---
DeepSeek-R1-Distill-Qwen-7B 7B 92.8% 49.1% 37.6% 1189 ---
DeepSeek-R1-Distill-Qwen-14B 14B 93.9% 59.1% 53.1% 1481 ---
DeepSeek-R1-Distill-Qwen-32B 32B 94.3% 62.1% 57.2% 1691 ---
DeepSeek-R1-Distill-Llama-70B 70B 94.5% --- 57.5% 1633 ≈ GPT-4o / o1-mini

🔍 关键洞察:

  • 数学与事实推理 :随参数量增加,性能稳步提升,Llama-70B94.5% MATH-500 成为当前蒸馏模型最高分,且在AIME 2024(高级数学竞赛)达86.7%,被作者称为"先进数学推理的优选模型"。
  • 编程能力 :虽整体弱于数学,但Qwen-32B(1691)与Llama-70B(1633)已媲美GPT-4o级别,证实蒸馏可有效迁移复杂推理能力。
  • 架构差异:同规模下,Llama蒸馏模型在数学上略优(94.5% vs 94.3%),Qwen在编程上略优(1691 vs 1633),体现基础架构特性。
    📌 注:所有分数均来自原文表格,未作任何插值或推测。GPQA DiamondLiveCodeBench为百分比准确率,CodeForces为原始评分(数值越高越好)。

五、落地实践指南:框架、步骤与避坑清单

从论文走向产品,需一套稳健的工程方法论。以下是整合多源文档的实战清单。

5.1 主流支持框架

框架 核心能力 适用场景 链接/备注
Hugging Face transformers 提供Distiller类,内置KD、中间层对齐等模块 快速原型,NLP任务首选 pip install transformers
PyTorch distiller 专为模型压缩设计,支持蒸馏全流程管理 需深度定制蒸馏策略 GitHub开源库
TensorFlow Model Optimization 集成蒸馏、剪枝、量化工具链 TensorFlow生态用户 TF官方工具包
DeepSpeed 微软大模型优化库,内置高效蒸馏模块 训练超大教师模型后直接蒸馏 支持ZeRO优化

5.2 七步实施流程(上述大纲,已融入前文四步,此处强化工程细节)

  1. 选教师 :务必选用在目标任务上SOTA的预训练模型 (如金融选finBERT,数学选DeepSeek-R1)。教师质量是天花板。
  2. 定学生 :根据部署环境反推------移动端选<1B参数,边缘设备选1-7B,服务器可选14-32B。架构可简化(如减少层数、头数)。
  3. 产软标 :用教师模型对整个训练集 做一次推理,保存logits(便于后续不同T 实验)。切忌只蒸馏子集!
  4. 设输入 :确保学生与教师输入tokenization完全一致(同一分词器、padding策略)。
  5. 配损失 :起始用α=0.7, T=3;若学生过拟合软标,可降α ;若泛化差,可升T
  6. 训学生 :仅更新学生参数;教师模型requires_grad=False;推荐使用torch.no_grad()包裹教师推理。
  7. 评模型必须四维评估:① 准确率(vs 教师/基线) ② 推理延迟(ms) ③ 模型体积(MB) ④ GPU显存占用(GB)。任一维度不达标,即未完成蒸馏。

5.3 常见陷阱与应对(基于原文挑战总结)

陷阱 表现 应对策略 来源依据
知识损失严重 学生在长推理链、专业领域任务上性能断崖下跌 ✅ 引入中间层蒸馏 ✅ 使用多教师蒸馏融合知识 正文补充、掘金
超参敏感 微调Tα导致性能剧烈波动 ✅ 采用温度动态衰减 (5→1) ✅ 使用网格搜索+贝叶斯优化自动调参 CSDN、正文补充
蒸馏后退化 学生精度低于同等规模直接训练模型 必须混合硬标签损失α<1 ) ✅ 检查教师模型在该任务上是否真SOTA CSDN原理对比
部署不兼容 蒸馏后模型无法在目标硬件(如Jetson)运行 ✅ 训练时即用目标硬件的ONNX Runtime或Triton 导出 ✅ 结合量化(INT8)+ 蒸馏双管齐下 正文补充"边缘部署"

5.4 Python实战代码

特别提醒,我以下编写的代码由于训练数据比较少,所以学生模型在测试数据集上会出现一定程度的过拟合问题,实际应用的时候大家可以自行调参或增加训练数据。

python 复制代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
知识蒸馏:在线大模型(教师文本/可选 logprobs) -> 本地 Qwen2.5 小模型。

说明(与「博文」类流程一致的核心点):
- 硬标签:对教师生成序列做标准下一词预测 CE(仅对 answer 段计权)。
- 软标签:KL( student_logits/T || teacher_dist );教师分布优先来自 API logprobs,
  若无则可用可选本地教师前向 logits;再不行则退化为带温度的标签平滑近似。
- 隐层:优先 MSE( proj(h_s^l), h_t^l )(需本地教师);否则用「中间层对齐下一词嵌入」
  作为可训练 surrogate(不依赖教师隐层)。

温度 T:支持按 epoch 或按「全局 step」从 T_max 线性降至 T_min(默认 5 -> 1)。

命令行示例:
  python distill_qwen.py --temp-schedule step --epochs 5 --output-dir ./out
  python distill_qwen.py --collect-only --refresh-cache
  python distill_qwen.py --resume --output-dir ./out
  python distill_qwen.py --eval-only --eval-student ./student_distilled
  python distill_qwen.py --epochs 10 --early-stop 3 --eval-baseline ./base_student

环境变量(OpenAI 兼容,适用于多数 Qwen 代理 / 部分平台):
  OPENAI_API_KEY
  OPENAI_BASE_URL   例如 https://dashscope.aliyuncs.com/compatible-mode/v1
  TEACHER_MODEL     例如 qwen-plus / qwen-turbo

或使用 DashScope 原生(需 pip install dashscope):
  DASHSCOPE_API_KEY
  USE_DASHSCOPE=1
  TEACHER_MODEL     例如 qwen-plus

学生模型默认:Qwen/Qwen2.5-0.5B-Instruct(HF 常见小模型;若你有 0.8B 版可改 STUDENT_MODEL_ID)。
"""

from __future__ import annotations

import argparse
import json
import math
import os
import random
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from transformers import BitsAndBytesConfig
except ImportError:
    BitsAndBytesConfig = None  # type: ignore

# ---------------------------------------------------------------------------
# 自拟:Python 编程向指令(可扩充)
# ---------------------------------------------------------------------------

PYTHON_INSTRUCT_PROMPTS: List[str] = [
    "用 Python 3 写函数:输入整数 n,返回前 n 个斐波那契数(列表)。请只输出代码与简短注释。",
    "解释 Python 中 `list` 与 `tuple` 的区别,并各给一个使用场景。用中文回答,条理清晰。",
    "写一段 Python:读取文本文件,统计行数、单词数(按空白分词),打印结果。处理文件不存在异常。",
    "什么是 Python 的 GIL?它对多线程 CPU 密集任务有什么影响?如何避免?中文简述。",
    "用 `dataclasses.dataclass` 定义一个 `Point(x,y)`,并实现到原点的距离方法。只输出代码。",
    "解释 `*args` 与 `**kwargs` 的语义,并给一个合并为字典的示例函数。",
    "写 Python 单元测试(unittest):测试一个将字符串转为整数的函数在非法输入时抛出 ValueError。",
    "什么是迭代器与生成器?用生成器实现无限斐波那契序列的前 10 项打印示例。",
    "简述 Python 虚拟环境 venv 的作用,并给出创建与激活命令(Windows 与 Linux 各一行)。",
    "用 `typing` 为函数 `def merge(a: list[int], b: list[int]) -> list[int]` 写完整类型注解与 docstring。",
    "解释浅拷贝与深拷贝;对嵌套列表给出 `copy.copy` 与 `copy.deepcopy` 的差异示例。",
    "写异步 aiohttp 风格伪代码不可行则改用 `requests`:GET URL,超时 5s,失败重试 3 次。只输出 Python。",
    "什么是装饰器?写一个计时装饰器,打印函数耗时毫秒。",
    "解释 `if __name__ == '__main__':` 的用途,并给一个可导入也可直接运行的模块示例。",
    "用 `pathlib.Path` 遍历目录下所有 `.py` 文件并打印相对路径。",
]

# 验证集:与训练提示不重复,用于蒸馏效果评价
VALIDATION_PYTHON_PROMPTS: List[str] = [
    "写一个 Python 函数 `is_prime(n)` 判断 n 是否为质数,处理 n<2 边界。只输出代码。",
    "说明 Python 字典与集合的时间复杂度(平均意义下查找/插入),各一句话。",
    "用 `contextlib.contextmanager` 实现一个临时切换工作目录的上下文管理器。只输出代码。",
    "解释 `__getitem__` 与 `__iter__` 在自定义容器中的角色差异。中文简述。",
    "写 Python:用 `subprocess` 运行 `python --version` 并捕获 stdout。只输出代码。",
    "什么是闭包?给一个工厂函数返回计数器的示例(nonlocal)。",
    "用 `functools.lru_cache` 为斐波那契递归做记忆化,并说明缓存参数。只输出代码。",
    "简述列表推导式与生成器表达式在内存上的差异,并各给一例。",
]


@dataclass
class DistillConfig:
    student_model_id: str = os.environ.get(
        "STUDENT_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct"
    )
    # 可选:本地教师(用于完整 logits + 隐层;4bit 省显存)。为空则仅用 API + surrogate。
    local_teacher_model_id: str = os.environ.get("LOCAL_TEACHER_MODEL_ID", "")
    teacher_model: str = os.environ.get("TEACHER_MODEL", "qwen-plus")
    max_new_tokens: int = 512
    temperature_api: float = 0.7
    num_epochs: int = 3
    batch_size: int = 1
    lr: float = 2e-5
    weight_decay: float = 0.01
    max_length: int = 1024
    # 损失权重
    lambda_soft: float = 0.2
    lambda_hard: float = 1.5
    lambda_hidden: float = 0.002
    label_smoothing_eps: float = 0.1  # API 无 logprobs 且无本地教师时的软目标平滑
    hidden_student_layer: int = -2  # 学生用于隐层蒸馏的层索引(-2 表示倒数第二层)
    # 温度调度:epoch = 每个 epoch 内 T 固定;step = 每个样本步更新 T
    temperature_schedule_mode: str = os.environ.get("TEMP_SCHEDULE", "epoch")
    T_max: float = 5.0
    T_min: float = 1.0
    # 数据
    cache_path: str = "teacher_cache.jsonl"
    refresh_cache: bool = False
    collect_only: bool = False
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    use_4bit_local_teacher: bool = os.environ.get("USE_4BIT_TEACHER", "1") == "1"
    grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", "4"))
    # 训练加速与检查点
    use_amp: bool = True
    save_every_epoch: bool = True
    resume_from: str = ""
    # 验证 / 评价
    val_cache_path: str = os.environ.get("VAL_CACHE_PATH", "val_teacher_cache.jsonl")
    eval_after_train: bool = os.environ.get("EVAL_AFTER_TRAIN", "1") != "0"
    eval_only: bool = False
    eval_student_path: str = ""  # eval-only 时优先使用;否则用 output_dir
    eval_baseline_path: str = os.environ.get("EVAL_BASELINE_PATH", "")
    refresh_val_cache: bool = False
    eval_max_new_tokens: int = 256
    eval_latency_runs: int = 5
    eval_warmup_runs: int = 2
    early_stop_patience: int = int(os.environ.get("EARLY_STOP_PATIENCE", "0"))
    best_metric: str = "token_acc"  # 用于早停:token_acc 越高越好


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ---------------------------------------------------------------------------
# 教师:在线 API
# ---------------------------------------------------------------------------


def _build_user_message(instruction: str) -> str:
    return (
        "你是一位严谨的 Python 助教。请根据用户问题给出准确、可运行的代码或清晰解释。\n\n"
        f"用户问题:\n{instruction.strip()}"
    )


def call_teacher_openai_compatible(
    instruction: str, cfg: DistillConfig
) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
    """
    返回 (assistant_text, logprobs_payload)。
    logprobs_payload: 若 API 支持,可为每步 top_logprobs 列表;否则 None。
    """
    from openai import OpenAI

    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY", "xxxxxx"),
        base_url=os.environ.get("OPENAI_BASE_URL", "xxxxxx"),
    )
    messages = [
        {
            "role": "system",
            "content": "You are a helpful Python programming tutor. Reply concisely.",
        },
        {"role": "user", "content": _build_user_message(instruction)},
    ]
    kwargs: Dict[str, Any] = {
        "model": cfg.teacher_model,
        "messages": messages,
        "max_tokens": cfg.max_new_tokens,
        "temperature": cfg.temperature_api,
    }
    # 部分兼容接口支持 logprobs
    try:
        completion = client.chat.completions.create(
            **kwargs,
            logprobs=True,
            top_logprobs=5,
        )
    except TypeError:
        completion = client.chat.completions.create(**kwargs)

    choice = completion.choices[0]
    text = (choice.message.content or "").strip()
    logprobs = None
    if getattr(choice, "logprobs", None) and choice.logprobs.content:
        logprobs = [
            {
                "token": t.token,
                "top": [
                    {"tok": x.token, "logp": x.logprob}
                    for x in (t.top_logprobs or [])
                ],
            }
            for t in choice.logprobs.content
        ]
    return text, logprobs


def call_teacher_dashscope(instruction: str, cfg: DistillConfig) -> Tuple[str, None]:
    import dashscope
    from dashscope import Generation

    dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "")
    messages = [
        {
            "role": "system",
            "content": "你是一位 Python 编程助教,回答要准确、可执行。",
        },
        {"role": "user", "content": _build_user_message(instruction)},
    ]
    resp = Generation.call(
        cfg.teacher_model,
        messages=messages,
        result_format="message",
        max_tokens=cfg.max_new_tokens,
        temperature=cfg.temperature_api,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"DashScope error: {resp}")
    text = resp.output.choices[0].message.content.strip()
    return text, None


def call_teacher(instruction: str, cfg: DistillConfig) -> Tuple[str, Optional[List[Dict[str, Any]]]]:
    if os.environ.get("USE_DASHSCOPE", "").lower() in ("1", "true", "yes"):
        return call_teacher_dashscope(instruction, cfg)
    return call_teacher_openai_compatible(instruction, cfg)


def load_or_build_dataset(cfg: DistillConfig) -> List[Dict[str, Any]]:
    if cfg.refresh_cache and os.path.isfile(cfg.cache_path):
        os.remove(cfg.cache_path)

    if os.path.isfile(cfg.cache_path):
        rows: List[Dict[str, Any]] = []
        with open(cfg.cache_path, "r", encoding="utf-8") as f:
            for line in f:
                rows.append(json.loads(line))
        return rows

    rows = []
    for ins in tqdm(PYTHON_INSTRUCT_PROMPTS, desc="Calling teacher API"):
        text, lp = call_teacher(ins, cfg)
        rows.append({"instruction": ins, "answer": text, "logprobs": lp})
        with open(cfg.cache_path, "a", encoding="utf-8") as f:
            f.write(json.dumps({"instruction": ins, "answer": text, "logprobs": lp}, ensure_ascii=False) + "\n")
    return rows


# ---------------------------------------------------------------------------
# 将 API logprobs 粗映射到张量分布(与子词 tokenizer 对齐的近似)
# ---------------------------------------------------------------------------


def _logprobs_to_sparse_dist(
    tokenizer: Any,
    logprobs_step: Dict[str, Any],
    device: torch.device,
    vocab_size: int,
) -> torch.Tensor:
    """单步:由 top_logprobs 构造近似分布(其余质量压到极小)。"""
    dist = torch.full((vocab_size,), -1e9, device=device)
    for item in logprobs_step.get("top", []):
        tok_str = item["tok"]
        lp = float(item["logp"])
        tid = tokenizer.convert_tokens_to_ids(
            tokenizer.tokenize(tok_str, add_special_tokens=False) or [tok_str]
        )
        if isinstance(tid, int):
            dist[tid] = max(dist[tid].item(), lp)
        else:
            for t in tid:
                dist[t] = max(dist[t].item(), lp)
    return dist


# ---------------------------------------------------------------------------
# 模型加载
# ---------------------------------------------------------------------------


def _student_load_path(cfg: DistillConfig) -> str:
    if cfg.resume_from:
        p = os.path.join(cfg.resume_from, "student")
        if os.path.isdir(p):
            return p
    return cfg.student_model_id


def load_student(cfg: DistillConfig) -> Tuple[Any, Any]:
    load_id = _student_load_path(cfg)
    tok = AutoTokenizer.from_pretrained(load_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        load_id,
        trust_remote_code=True,
        dtype=torch.bfloat16 if cfg.device == "cuda" else torch.float32,
        device_map="auto" if cfg.device == "cuda" else None,
    )
    if cfg.device != "cuda":
        model = model.to(cfg.device)
    model.train()
    return model, tok


def load_local_teacher_optional(cfg: DistillConfig) -> Optional[Tuple[Any, Any]]:
    if not cfg.local_teacher_model_id.strip():
        return None
    quant = None
    if (
        cfg.use_4bit_local_teacher
        and cfg.device == "cuda"
        and BitsAndBytesConfig is not None
    ):
        quant = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    tok = AutoTokenizer.from_pretrained(cfg.local_teacher_model_id, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        cfg.local_teacher_model_id,
        trust_remote_code=True,
        dtype=torch.bfloat16 if cfg.device == "cuda" else torch.float32,
        device_map="auto" if cfg.device == "cuda" else None,
        quantization_config=quant,
    )
    if cfg.device != "cuda":
        model = model.to(cfg.device)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model, tok


# ---------------------------------------------------------------------------
# 损失
# ---------------------------------------------------------------------------


def temperature_schedule(epoch: int, total_epochs: int, cfg: DistillConfig) -> float:
    if total_epochs <= 1:
        return cfg.T_min
    alpha = epoch / max(total_epochs - 1, 1)
    return cfg.T_max + (cfg.T_min - cfg.T_max) * alpha


def temperature_schedule_step(step: int, total_steps: int, cfg: DistillConfig) -> float:
    """全局 step 从 0 到 total_steps-1,T 从 T_max 线性降至 T_min。"""
    if total_steps <= 1:
        return cfg.T_min
    alpha = step / max(total_steps - 1, 1)
    return cfg.T_max + (cfg.T_min - cfg.T_max) * alpha


def masked_ce(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """logits: [B, L, V], labels: [B, L], mask: [B, L] bool"""
    ce = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        reduction="none",
    )
    m = mask.view(-1).float()
    denom = m.sum().clamp_min(1.0)
    return (ce * m).sum() / denom


def kl_soft(
    student_logits: torch.Tensor,
    teacher_probs: torch.Tensor,
    mask: torch.Tensor,
    T: float,
) -> torch.Tensor:
    """teacher_probs: [B, L, V] 已 softmax"""
    s_log = F.log_softmax(student_logits / T, dim=-1)
    t = teacher_probs.clamp_min(1e-8)
    kl = (t * (t.log() - s_log)).sum(dim=-1)
    m = mask.float()
    denom = m.sum().clamp_min(1.0)
    return (kl * m).sum() / denom * (T * T)


def kl_to_smoothed_one_hot(
    student_logits: torch.Tensor,
    labels: torch.Tensor,
    mask: torch.Tensor,
    T: float,
    eps: float,
    vocab_size: int,
) -> torch.Tensor:
    """无教师分布时:软目标 = (1-eps)*one_hot + eps/V"""
    oh = F.one_hot(labels.clamp_min(0), num_classes=vocab_size).float()
    smooth = (1.0 - eps) * oh + eps / float(vocab_size)
    s_log = F.log_softmax(student_logits / T, dim=-1)
    kl = (smooth * (smooth.clamp_min(1e-8).log() - s_log)).sum(dim=-1)
    m = mask.float()
    denom = m.sum().clamp_min(1.0)
    return (kl * m).sum() / denom * (T * T)


def hidden_loss_with_local_teacher(
    h_s: torch.Tensor,
    h_t: torch.Tensor,
    proj: nn.Linear,
    mask: torch.Tensor,
) -> torch.Tensor:
    """h_s, h_t: [B, L, H];最后一维 MSE 取 mean,再对 mask 位置平均(与隐层维度无关)。"""
    pred = proj(h_s)
    mse = (pred - h_t.detach()).pow(2).mean(dim=-1)
    m = mask.float()
    return (mse * m).sum() / m.sum().clamp_min(1.0)


def hidden_loss_embedding_align(
    h_s: torch.Tensor,
    embed_weight: torch.Tensor,
    next_ids: torch.Tensor,
    proj: nn.Linear,
    mask: torch.Tensor,
) -> torch.Tensor:
    """中间层对齐下一词嵌入;最后一维 MSE 取 mean,再 mask 平均。"""
    emb = F.embedding(next_ids.clamp_min(0), embed_weight)
    pred = proj(h_s)
    mse = (pred - emb.detach()).pow(2).mean(dim=-1)
    m = mask.float()
    return (mse * m).sum() / m.sum().clamp_min(1.0)


# ---------------------------------------------------------------------------
# 单条样本张量化
# ---------------------------------------------------------------------------


def build_chat_inputs(
    tokenizer: Any,
    instruction: str,
    answer: str,
    max_length: int,
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    使用 chat 模板拼接 user/assistant,返回 input_ids 与 labels(prompt 段为 -100)。
    """
    messages = [
        {"role": "user", "content": _build_user_message(instruction)},
        {"role": "assistant", "content": answer},
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
    else:
        text = f"User:\n{instruction}\n\nAssistant:\n{answer}"

    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)

    # 近似:仅对后半段 assistant 计 loss ------ 用 assistant 片段在文本中的起始偏移
    assistant_marker = "assistant"
    low_text = text.lower()
    idx = low_text.rfind(assistant_marker)
    if idx >= 0:
        prefix = text[: idx + len(assistant_marker)]
        pref_ids = tokenizer(prefix, truncation=True, max_length=max_length, return_tensors="pt")[
            "input_ids"
        ].to(device)
        cut = min(pref_ids.size(1), input_ids.size(1))
    else:
        cut = input_ids.size(1) // 2

    labels = input_ids.clone()
    labels[:, :cut] = -100
    loss_mask = labels.ne(-100) & attn.bool()

    return {
        "input_ids": input_ids,
        "attention_mask": attn,
        "labels": labels,
        "loss_mask": loss_mask,
        "cut": cut,
    }


def build_teacher_probs_from_logprobs(
    tokenizer: Any,
    logprobs: Optional[Sequence[Dict[str, Any]]],
    answer_ids: torch.Tensor,
    vocab_size: int,
    device: torch.device,
) -> Optional[torch.Tensor]:
    """
    将 API 返回的 logprobs(按生成 token 顺序)扩展到与 answer_ids 对齐的 [1, L, V]。
    若长度对不齐则返回 None。
    """
    if not logprobs:
        return None
    L = answer_ids.size(0)
    if len(logprobs) != L:
        return None
    probs = []
    for i, step in enumerate(logprobs):
        dist = _logprobs_to_sparse_dist(tokenizer, step, device, vocab_size)
        dist = dist - dist.logsumexp(dim=-1)
        probs.append(dist.softmax(dim=-1))
    return torch.stack(probs, dim=0).unsqueeze(0)


# ---------------------------------------------------------------------------
# 验证集(教师 API 生成金标准,与学生/本地教师对比)
# ---------------------------------------------------------------------------


def load_or_build_val_dataset(cfg: DistillConfig) -> List[Dict[str, Any]]:
    if cfg.refresh_val_cache and os.path.isfile(cfg.val_cache_path):
        os.remove(cfg.val_cache_path)

    if os.path.isfile(cfg.val_cache_path):
        rows: List[Dict[str, Any]] = []
        with open(cfg.val_cache_path, "r", encoding="utf-8") as f:
            for line in f:
                rows.append(json.loads(line))
        return rows

    rows = []
    for ins in tqdm(VALIDATION_PYTHON_PROMPTS, desc="Building val cache (teacher API)"):
        text, lp = call_teacher(ins, cfg)
        rows.append({"instruction": ins, "answer": text, "logprobs": lp})
        with open(cfg.val_cache_path, "a", encoding="utf-8") as f:
            f.write(
                json.dumps({"instruction": ins, "answer": text, "logprobs": lp}, ensure_ascii=False)
                + "\n"
            )
    return rows


def dir_size_mb(root: str) -> float:
    if not root or not os.path.isdir(root):
        return 0.0
    total = 0
    for dp, _, files in os.walk(root):
        for fn in files:
            fp = os.path.join(dp, fn)
            if os.path.isfile(fp):
                total += os.path.getsize(fp)
    return total / (1024 * 1024)


def cuda_reset_peak() -> None:
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()


def cuda_peak_allocated_gb() -> float:
    if not torch.cuda.is_available():
        return float("nan")
    return torch.cuda.max_memory_allocated() / (1024.0**3)


def _primary_device(model: nn.Module) -> torch.device:
    return next(model.parameters()).device


@torch.no_grad()
def token_accuracy_teacher_forcing(
    model: Any,
    tokenizer: Any,
    instruction: str,
    reference_answer: str,
    max_length: int,
    device: torch.device,
) -> float:
    """
    在「金标准回答」上做下一词预测:准确率 = 预测 token 与参考一致的比例(仅 assistant 段)。
    """
    model.eval()
    batch = build_chat_inputs(
        tokenizer, instruction, reference_answer, max_length, device
    )
    input_ids = batch["input_ids"]
    attn = batch["attention_mask"]
    loss_mask = batch["loss_mask"]
    out = model(
        input_ids=input_ids,
        attention_mask=attn,
        use_cache=False,
    )
    logits = out.logits[:, :-1]
    labels = input_ids[:, 1:]
    mask = loss_mask[:, 1:]
    pred = logits.argmax(dim=-1)
    correct = (pred == labels) & mask
    denom = mask.sum().clamp_min(1)
    return float(correct.sum() / denom)


def mean_token_accuracy_on_val(
    model: Any,
    tokenizer: Any,
    val_rows: List[Dict[str, Any]],
    max_length: int,
    device: torch.device,
) -> float:
    if not val_rows:
        return 0.0
    accs = []
    for row in val_rows:
        accs.append(
            token_accuracy_teacher_forcing(
                model,
                tokenizer,
                row["instruction"],
                row["answer"],
                max_length,
                device,
            )
        )
    return float(sum(accs) / len(accs))


def _build_prompt_only_ids(tokenizer: Any, instruction: str, device: torch.device) -> torch.Tensor:
    messages = [
        {"role": "user", "content": _build_user_message(instruction)},
    ]
    if hasattr(tokenizer, "apply_chat_template"):
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        text = f"User:\n{instruction}\n\nAssistant:\n"
    enc = tokenizer(text, return_tensors="pt")
    return enc["input_ids"].to(device)


def measure_generate_latency_ms(
    model: Any,
    tokenizer: Any,
    instruction: str,
    device: torch.device,
    max_new_tokens: int,
    warmup: int,
    runs: int,
) -> Dict[str, float]:
    """单次请求:从 prompt 到生成 `max_new_tokens`(早停受 eos 影响)的平均 wall 时间。"""
    model.eval()
    input_ids = _build_prompt_only_ids(tokenizer, instruction, device)
    dev = _primary_device(model)

    def _sync() -> None:
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    latencies: List[float] = []
    with torch.no_grad():
        for _ in range(max(warmup, 0)):
            _sync()
            t0 = time.perf_counter()
            model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=getattr(tokenizer, "pad_token_id", None)
                or tokenizer.eos_token_id,
                use_cache=True,
            )
            _sync()
            latencies.append((time.perf_counter() - t0) * 1000.0)

        bench: List[float] = []
        for _ in range(max(runs, 1)):
            _sync()
            t0 = time.perf_counter()
            model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=getattr(tokenizer, "pad_token_id", None)
                or tokenizer.eos_token_id,
                use_cache=True,
            )
            _sync()
            bench.append((time.perf_counter() - t0) * 1000.0)

    return {
        "latency_mean_ms": float(sum(bench) / len(bench)),
        "latency_std_ms": float(
            (sum((x - sum(bench) / len(bench)) ** 2 for x in bench) / max(len(bench) - 1, 1)) ** 0.5
        )
        if len(bench) > 1
        else 0.0,
    }


def measure_forward_peak_memory_gb(
    model: Any,
    tokenizer: Any,
    instruction: str,
    reference_answer: str,
    max_length: int,
    device: torch.device,
) -> float:
    """一次完整序列前向(与训练同形状)的 CUDA 峰值显存;非 CUDA 返回 nan。"""
    if device.type != "cuda":
        return float("nan")
    model.eval()
    cuda_reset_peak()
    batch = build_chat_inputs(
        tokenizer, instruction, reference_answer, max_length, device
    )
    with torch.no_grad():
        model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=False,
        )
    return cuda_peak_allocated_gb()


def evaluate_local_model_metrics(
    name: str,
    model: Any,
    tokenizer: Any,
    val_rows: List[Dict[str, Any]],
    cfg: DistillConfig,
    weights_dir: Optional[str],
    bench_instruction: str,
) -> Dict[str, Any]:
    device = torch.device(cfg.device)
    tok_acc = mean_token_accuracy_on_val(
        model, tokenizer, val_rows, cfg.max_length, device
    )
    lat = measure_generate_latency_ms(
        model,
        tokenizer,
        bench_instruction,
        device,
        cfg.eval_max_new_tokens,
        cfg.eval_warmup_runs,
        cfg.eval_latency_runs,
    )
    if weights_dir and os.path.isdir(weights_dir):
        size_mb = dir_size_mb(weights_dir)
    else:
        size_mb = float("nan")
    peak_gb = measure_forward_peak_memory_gb(
        model,
        tokenizer,
        val_rows[0]["instruction"],
        val_rows[0]["answer"],
        cfg.max_length,
        device,
    )
    return {
        "name": name,
        "token_acc_val": tok_acc,
        "latency_mean_ms": lat["latency_mean_ms"],
        "latency_std_ms": lat["latency_std_ms"],
        "disk_size_mb": size_mb,
        "peak_mem_gb": peak_gb,
    }


def evaluate_teacher_api_metrics(
    val_rows: List[Dict[str, Any]], cfg: DistillConfig
) -> Dict[str, Any]:
    """在线教师:测 API 往返延迟;准确率/显存/本地体积在服务端,记为 N/A。"""
    times: List[float] = []
    for row in val_rows[: min(len(val_rows), 8)]:
        t0 = time.perf_counter()
        call_teacher(row["instruction"], cfg)
        times.append((time.perf_counter() - t0) * 1000.0)
    mean_ms = float(sum(times) / len(times)) if times else float("nan")
    std_ms = (
        float(
            (sum((x - mean_ms) ** 2 for x in times) / max(len(times) - 1, 1)) ** 0.5
        )
        if len(times) > 1
        else 0.0
    )
    return {
        "name": "teacher_api",
        "token_acc_val": float("nan"),
        "latency_mean_ms": mean_ms,
        "latency_std_ms": std_ms,
        "disk_size_mb": float("nan"),
        "peak_mem_gb": float("nan"),
    }


def print_eval_table(rows: List[Dict[str, Any]]) -> None:
    headers = ("角色", "验证集Token准确率", "推理延迟均值(ms)", "延迟std(ms)", "磁盘体积(MB)", "峰值显存(GB)")
    print("\n" + "=" * 100)
    print("蒸馏效果评价(验证集金标准为在线教师生成;本地指标在单机 GPU/CPU 上测得)")
    print("=" * 100)
    print("".join(f"{h:<22}" for h in headers))
    print("-" * 100)

    def _fmt(x: Any, is_int: bool = False) -> str:
        if isinstance(x, float) and math.isnan(x):
            return "N/A".ljust(20)
        if x is None:
            return "N/A".ljust(20)
        if is_int:
            return f"{int(x):<20}"
        if isinstance(x, float):
            return f"{x:<20.4f}"
        return f"{str(x):<20}"

    for r in rows:
        print(
            f"{r.get('name',''):<22}"
            f"{_fmt(r.get('token_acc_val', float('nan')))}"
            f"{_fmt(r.get('latency_mean_ms', float('nan')))}"
            f"{_fmt(r.get('latency_std_ms', float('nan')))}"
            f"{_fmt(r.get('disk_size_mb', float('nan')))}"
            f"{_fmt(r.get('peak_mem_gb', float('nan')))}"
        )
    print("=" * 100)
    print(
        "说明:Token 准确率为教师强制下一词预测;在线教师准确率/显存/体积为云端不可见故标 N/A。"
        "延迟:学生/本地教师为本地 generate;在线教师为 API 往返(含网络)。"
    )


def print_optimization_hints(table: List[Dict[str, Any]]) -> None:
    """根据指标给出简要优化方向(启发式)。"""
    by_name = {r.get("name"): r for r in table}
    st = by_name.get("student(distilled)")
    api = by_name.get("teacher_api")
    if not st:
        return
    hints: List[str] = []
    acc = st.get("token_acc_val")
    if isinstance(acc, float) and not math.isnan(acc):
        if acc < 0.25:
            hints.append("验证集准确率偏低:可增加训练轮数、调大 lambda_soft 或扩充教师缓存数据。")
        elif acc < 0.4:
            hints.append("准确率有提升空间:尝试更小学习率收尾、或加入本地教师 logits 蒸馏。")
    lat_s = st.get("latency_mean_ms")
    if api and isinstance(lat_s, float) and not math.isnan(lat_s):
        lat_a = api.get("latency_mean_ms")
        if isinstance(lat_a, float) and not math.isnan(lat_a) and lat_a > 0:
            speedup = lat_a / max(lat_s, 1e-6)
            hints.append(
                f"相对在线 API,本地学生推理约 {speedup:.1f}x 更快(延迟对比,含网络开销)。"
            )
    mem = st.get("peak_mem_gb")
    if isinstance(mem, float) and not math.isnan(mem) and mem > 8:
        hints.append("峰值显存较高:可尝试梯度检查点、更小 batch、4bit 推理或更小学生模型。")
    if hints:
        print("\n优化提示:")
        for h in hints:
            print(" - ", h)


def run_full_evaluation(
    cfg: DistillConfig,
    student_model: Any,
    student_tok: Any,
    student_weights_dir: str,
    teacher_local: Optional[Tuple[Any, Any]],
    val_rows: List[Dict[str, Any]],
    baseline_model: Optional[Any] = None,
    baseline_tok: Optional[Any] = None,
    baseline_weights_dir: Optional[str] = None,
) -> None:
    if not val_rows:
        print("验证集为空,跳过评价。")
        return

    bench_ins = val_rows[0]["instruction"]
    table: List[Dict[str, Any]] = []

    table.append(
        evaluate_local_model_metrics(
            "student(distilled)",
            student_model,
            student_tok,
            val_rows,
            cfg,
            student_weights_dir,
            bench_ins,
        )
    )

    if baseline_model is not None and baseline_tok is not None:
        table.append(
            evaluate_local_model_metrics(
                "student(baseline)",
                baseline_model,
                baseline_tok,
                val_rows,
                cfg,
                baseline_weights_dir or "",
                bench_ins,
            )
        )

    if teacher_local:
        tm, ttok = teacher_local
        lt_id = (cfg.local_teacher_model_id or "").strip()
        lt_dir = lt_id if os.path.isdir(lt_id) else None
        table.append(
            evaluate_local_model_metrics(
                "teacher(local)",
                tm,
                ttok,
                val_rows,
                cfg,
                lt_dir,
                bench_ins,
            )
        )

    need_api = os.environ.get("OPENAI_API_KEY") or os.environ.get("DASHSCOPE_API_KEY")
    if need_api:
        table.append(evaluate_teacher_api_metrics(val_rows, cfg))

    print_eval_table(table)
    print_optimization_hints(table)

    out_json = os.path.join(
        os.environ.get("OUTPUT_DIR", "student_distilled"), "eval_report.json"
    )
    try:
        os.makedirs(os.path.dirname(out_json) or ".", exist_ok=True)
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(table, f, ensure_ascii=False, indent=2)
        print(f"评价结果已写入 {out_json}")
    except OSError as e:
        print(f"写入 eval_report.json 失败: {e}")


def load_model_eval(model_path: str, cfg: DistillConfig) -> Tuple[Any, Any]:
    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        dtype=torch.bfloat16 if cfg.device == "cuda" else torch.float32,
        device_map="auto" if cfg.device == "cuda" else None,
    )
    if cfg.device != "cuda":
        model = model.to(cfg.device)
    model.eval()
    return model, tok


def run_evaluation_only(cfg: DistillConfig) -> None:
    set_seed(cfg.seed)
    path = cfg.eval_student_path or os.environ.get("OUTPUT_DIR", "student_distilled")
    if not os.path.isdir(path):
        raise SystemExit(f"未找到学生模型目录: {path}")

    val_rows = load_or_build_val_dataset(cfg)
    student, tok = load_model_eval(path, cfg)

    teacher_local = load_local_teacher_optional(cfg)
    baseline_model = None
    baseline_tok = None
    bdir = (cfg.eval_baseline_path or "").strip()
    if bdir and os.path.isdir(bdir):
        baseline_model, baseline_tok = load_model_eval(bdir, cfg)

    run_full_evaluation(
        cfg,
        student,
        tok,
        path,
        teacher_local,
        val_rows,
        baseline_model=baseline_model,
        baseline_tok=baseline_tok,
        baseline_weights_dir=bdir or None,
    )


# ---------------------------------------------------------------------------
# 检查点
# ---------------------------------------------------------------------------


def save_training_checkpoint(
    base_dir: str,
    student: Any,
    tok: Any,
    proj: nn.Module,
    opt: torch.optim.Optimizer,
    epoch: int,
    global_step: int,
) -> None:
    ckpt = os.path.join(base_dir, "checkpoint_last")
    os.makedirs(ckpt, exist_ok=True)
    student_dir = os.path.join(ckpt, "student")
    student.save_pretrained(student_dir)
    tok.save_pretrained(student_dir)
    torch.save(proj.state_dict(), os.path.join(ckpt, "hidden_proj.pt"))
    torch.save(opt.state_dict(), os.path.join(ckpt, "optimizer.pt"))
    with open(os.path.join(ckpt, "meta.json"), "w", encoding="utf-8") as f:
        json.dump({"epoch": epoch, "global_step": global_step}, f)


def _torch_load(path: str, map_location: Any) -> Any:
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        return torch.load(path, map_location=map_location)


def load_resume_state(cfg: DistillConfig, proj: nn.Module, opt: torch.optim.Optimizer, device: torch.device):
    """返回 (start_epoch, global_step)。无恢复文件时 (0, 0)。"""
    if not cfg.resume_from or not os.path.isdir(cfg.resume_from):
        return 0, 0
    proj_path = os.path.join(cfg.resume_from, "hidden_proj.pt")
    if os.path.isfile(proj_path):
        proj.load_state_dict(_torch_load(proj_path, map_location=device))
    opt_path = os.path.join(cfg.resume_from, "optimizer.pt")
    if os.path.isfile(opt_path):
        opt.load_state_dict(_torch_load(opt_path, map_location="cpu"))
    meta_path = os.path.join(cfg.resume_from, "meta.json")
    if not os.path.isfile(meta_path):
        return 0, 0
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    last_epoch = int(meta.get("epoch", -1))
    gstep = int(meta.get("global_step", 0))
    # 已完成 last_epoch,从下一段开始
    return last_epoch + 1, gstep


# ---------------------------------------------------------------------------
# 训练
# ---------------------------------------------------------------------------


def train(cfg: DistillConfig) -> None:
    set_seed(cfg.seed)
    device = torch.device(cfg.device)
    data = load_or_build_dataset(cfg)
    if cfg.collect_only:
        print(f"已写入/加载缓存 {cfg.cache_path},共 {len(data)} 条。使用 --collect-only 已跳过训练。")
        return

    val_data: List[Dict[str, Any]] = []
    try:
        val_data = load_or_build_val_dataset(cfg)
    except Exception as e:
        print(f"警告:验证集构建失败,将跳过验证与早停:{e}")

    student, tok_s = load_student(cfg)
    teacher_local = load_local_teacher_optional(cfg)
    teacher_model = teacher_local[0] if teacher_local else None

    V = student.config.vocab_size
    hidden_size = student.config.hidden_size
    mid_layer_idx = cfg.hidden_student_layer
    h_tgt_dim = (
        teacher_model.config.hidden_size
        if teacher_model is not None
        else student.get_input_embeddings().weight.size(1)
    )
    proj = nn.Linear(hidden_size, h_tgt_dim, bias=False).to(device)

    opt = torch.optim.AdamW(
        list(student.parameters()) + list(proj.parameters()),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

    start_epoch, global_step = load_resume_state(cfg, proj, opt, device)
    if start_epoch >= cfg.num_epochs:
        print(f"resume 起始 epoch {start_epoch} 已不小于 num_epochs={cfg.num_epochs},无需训练。")
        return

    total_steps = max(len(data) * cfg.num_epochs, 1)
    use_cuda_amp = cfg.device == "cuda" and cfg.use_amp

    out_dir = os.environ.get("OUTPUT_DIR", "student_distilled")
    os.makedirs(out_dir, exist_ok=True)

    best_val_acc = -1.0
    patience_ctr = 0

    for epoch in range(start_epoch, cfg.num_epochs):
        T_epoch = temperature_schedule(epoch, cfg.num_epochs, cfg)
        random.shuffle(data)
        student.train()
        pbar = tqdm(data, desc=f"epoch {epoch+1}/{cfg.num_epochs}")
        accum = 0
        total_loss = 0.0

        for row in pbar:
            if cfg.temperature_schedule_mode.lower() == "step":
                T = temperature_schedule_step(global_step, total_steps, cfg)
            else:
                T = T_epoch

            ins = row["instruction"]
            ans = row["answer"]
            logprobs = row.get("logprobs")

            batch = build_chat_inputs(tok_s, ins, ans, cfg.max_length, device)
            input_ids = batch["input_ids"]
            attn = batch["attention_mask"]
            labels = batch["labels"]
            loss_mask = batch["loss_mask"]

            autocast_cm = (
                torch.autocast(device_type="cuda", dtype=torch.bfloat16)
                if use_cuda_amp
                else nullcontext()
            )

            with autocast_cm:
                out_s = student(
                    input_ids=input_ids,
                    attention_mask=attn,
                    output_hidden_states=True,
                    use_cache=False,
                )
                logits_s = out_s.logits
                shift_logits = logits_s[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                shift_mask = loss_mask[:, 1:].contiguous()

                out_t = None
                if teacher_model is not None:
                    with torch.no_grad():
                        out_t = teacher_model(
                            input_ids=input_ids,
                            attention_mask=attn,
                            output_hidden_states=True,
                            use_cache=False,
                        )

                l_hard = masked_ce(shift_logits, shift_labels, shift_mask)

                teacher_probs = None
                if logprobs is not None:
                    ans_only_ids = shift_labels[0][shift_mask[0]]
                    teacher_probs = build_teacher_probs_from_logprobs(
                        tok_s, logprobs, ans_only_ids, V, device
                    )

                if teacher_probs is not None and teacher_probs.size(1) == shift_logits.size(1):
                    l_soft = kl_soft(shift_logits, teacher_probs, shift_mask, T)
                elif out_t is not None:
                    l_soft = kl_soft(
                        shift_logits,
                        F.softmax(out_t.logits[:, :-1, :] / T, dim=-1),
                        shift_mask,
                        T,
                    )
                else:
                    l_soft = kl_to_smoothed_one_hot(
                        shift_logits,
                        shift_labels,
                        shift_mask,
                        T,
                        cfg.label_smoothing_eps,
                        V,
                    )

                h_mid = out_s.hidden_states[mid_layer_idx][:, :-1, :]
                if out_t is not None:
                    h_t_mid = out_t.hidden_states[mid_layer_idx][:, :-1, :]
                    l_h = hidden_loss_with_local_teacher(h_mid, h_t_mid, proj, shift_mask)
                else:
                    emb_w = student.get_input_embeddings().weight
                    l_h = hidden_loss_embedding_align(
                        h_mid, emb_w, shift_labels, proj, shift_mask
                    )

                loss = (
                    cfg.lambda_hard * l_hard
                    + cfg.lambda_soft * l_soft
                    + cfg.lambda_hidden * l_h
                ) / cfg.grad_accum_steps

            loss.backward()
            accum += 1
            global_step += 1
            total_loss += float(loss.detach()) * cfg.grad_accum_steps

            if accum >= cfg.grad_accum_steps:
                torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(proj.parameters(), 1.0)
                opt.step()
                opt.zero_grad(set_to_none=True)
                accum = 0

            pbar.set_postfix(
                hard=float(l_hard.detach()),
                soft=float(l_soft.detach()),
                hid=float(l_h.detach()),
                T=float(T),
                step=global_step,
            )

        if accum > 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        print(
            f"epoch {epoch+1} mean approx batch loss: {total_loss / max(len(data), 1):.4f}"
        )

        if val_data:
            student.eval()
            v_acc = mean_token_accuracy_on_val(
                student, tok_s, val_data, cfg.max_length, device
            )
            student.train()
            print(f"epoch {epoch+1} val token_acc: {v_acc:.4f}")
            if cfg.early_stop_patience > 0:
                if v_acc > best_val_acc:
                    best_val_acc = v_acc
                    patience_ctr = 0
                    bdir_best = os.path.join(out_dir, "best_student")
                    student.save_pretrained(bdir_best)
                    tok_s.save_pretrained(bdir_best)
                    print(f"  -> 保存最佳验证集模型至 {bdir_best}")
                else:
                    patience_ctr += 1
                    if patience_ctr >= cfg.early_stop_patience:
                        print(
                            f"早停:验证集 {cfg.best_metric} 连续 {cfg.early_stop_patience} 轮未提升。"
                        )
                        break

        if cfg.save_every_epoch:
            save_training_checkpoint(
                out_dir, student, tok_s, proj, opt, epoch, global_step
            )

    student.save_pretrained(out_dir)
    tok_s.save_pretrained(out_dir)
    torch.save(proj.state_dict(), os.path.join(out_dir, "hidden_proj.pt"))
    print(f"Saved student to {out_dir}")

    if cfg.eval_after_train and val_data:
        print("\n训练结束,运行验证集评价(学生 vs 教师)...")
        student.eval()
        teacher_tuple = load_local_teacher_optional(cfg)
        baseline_m = baseline_t = None
        bp = (cfg.eval_baseline_path or "").strip()
        if bp and os.path.isdir(bp):
            baseline_m, baseline_t = load_model_eval(bp, cfg)
        run_full_evaluation(
            cfg,
            student,
            tok_s,
            out_dir,
            teacher_tuple,
            val_data,
            baseline_model=baseline_m,
            baseline_tok=baseline_t,
            baseline_weights_dir=bp or None,
        )
        if os.path.isdir(os.path.join(out_dir, "best_student")):
            print(
                "\n提示:若启用早停,可对比 `best_student` 与最终 `out_dir`,"
                "或运行: python distill_qwen.py --eval-only --eval-student <路径>"
            )


def parse_args() -> DistillConfig:
    p = argparse.ArgumentParser(description="Qwen 知识蒸馏(在线教师 -> 本地学生)")
    p.add_argument("--student", default=None, help="学生 HF 模型 ID 或路径")
    p.add_argument("--local-teacher", default=None, dest="local_teacher", help="本地教师模型(logits/隐层)")
    p.add_argument("--teacher-model", default=None, help="在线 API 模型名,如 qwen-plus")
    p.add_argument("--epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    p.add_argument("--cache", default=None, help="教师缓存 jsonl 路径")
    p.add_argument("--refresh-cache", action="store_true", help="删除旧缓存并重新调用 API")
    p.add_argument("--output-dir", default=None, help="最终模型与 checkpoint 目录")
    p.add_argument("--collect-only", action="store_true", help="只采集教师数据不写模型")
    p.add_argument(
        "--temp-schedule",
        choices=("epoch", "step"),
        default=None,
        help="温度衰减:按 epoch 固定或按全局 step 从 5→1",
    )
    p.add_argument("--no-amp", action="store_true", help="关闭 CUDA bf16 autocast")
    p.add_argument(
        "--resume",
        nargs="?",
        const="__CHECKPOINT_LAST__",
        default=None,
        help="恢复训练:仅 --resume 则使用 OUTPUT_DIR/checkpoint_last;或指定 checkpoint 目录",
    )
    p.add_argument("--grad-accum", type=int, default=None, dest="grad_accum")
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--eval-only", action="store_true", help="仅运行验证集评价,不训练")
    p.add_argument(
        "--no-eval-after-train",
        action="store_true",
        help="训练结束后不自动评价",
    )
    p.add_argument("--val-cache", default=None, help="验证集教师缓存 jsonl")
    p.add_argument(
        "--refresh-val-cache",
        action="store_true",
        help="删除验证集缓存并重新请求教师 API",
    )
    p.add_argument(
        "--early-stop",
        type=int,
        default=None,
        help="验证集 token 准确率早停耐心(轮数),0 关闭",
    )
    p.add_argument(
        "--eval-student",
        default=None,
        help="eval-only:学生模型目录(默认 OUTPUT_DIR)",
    )
    p.add_argument(
        "--eval-baseline",
        default=None,
        help="评价时额外加载「蒸馏前」学生权重目录做对比",
    )
    args = p.parse_args()

    cfg = DistillConfig()
    if args.student:
        cfg.student_model_id = args.student
    if args.local_teacher is not None:
        cfg.local_teacher_model_id = args.local_teacher
    if args.teacher_model:
        cfg.teacher_model = args.teacher_model
    if args.epochs is not None:
        cfg.num_epochs = args.epochs
    if args.lr is not None:
        cfg.lr = args.lr
    if args.cache:
        cfg.cache_path = args.cache
    if args.refresh_cache:
        cfg.refresh_cache = True
    if args.output_dir:
        os.environ["OUTPUT_DIR"] = args.output_dir
    if args.collect_only:
        cfg.collect_only = True
    if args.temp_schedule:
        cfg.temperature_schedule_mode = args.temp_schedule
    if args.no_amp:
        cfg.use_amp = False
    if args.resume is not None:
        if args.resume == "__CHECKPOINT_LAST__":
            out = os.environ.get("OUTPUT_DIR", "student_distilled")
            cfg.resume_from = os.path.abspath(os.path.join(out, "checkpoint_last"))
        else:
            cfg.resume_from = os.path.abspath(args.resume)
    if args.grad_accum is not None:
        cfg.grad_accum_steps = args.grad_accum
    if args.seed is not None:
        cfg.seed = args.seed
    if args.eval_only:
        cfg.eval_only = True
    if args.no_eval_after_train:
        cfg.eval_after_train = False
    if args.val_cache:
        cfg.val_cache_path = args.val_cache
    if args.refresh_val_cache:
        cfg.refresh_val_cache = True
    if args.early_stop is not None:
        cfg.early_stop_patience = args.early_stop
    if args.eval_student:
        cfg.eval_student_path = args.eval_student
    if args.eval_baseline:
        cfg.eval_baseline_path = args.eval_baseline
    return cfg


def main() -> None:
    cfg = parse_args()

    if cfg.eval_only:
        if (
            not os.path.isfile(cfg.val_cache_path)
            and not cfg.refresh_val_cache
            and not os.environ.get("OPENAI_API_KEY")
            and not os.environ.get("DASHSCOPE_API_KEY")
        ):
            print(
                "评价需要验证集缓存文件,或设置 OPENAI_API_KEY / DASHSCOPE_API_KEY 以自动生成。"
            )
        run_evaluation_only(cfg)
        return

    need_key = not os.path.isfile(cfg.cache_path) and not cfg.refresh_cache
    if need_key and not os.environ.get("OPENAI_API_KEY") and not os.environ.get("DASHSCOPE_API_KEY"):
        print(
            "请设置 OPENAI_API_KEY(及 OPENAI_BASE_URL)或 DASHSCOPE_API_KEY + USE_DASHSCOPE=1;"
            "或先准备已有缓存文件。"
        )
    train(cfg)


if __name__ == "__main__":
    main()

六、现实思考

✅ 问题1:模型蒸馏和将教师模型的训练数据直接训练学生模型有什么区别?

模型蒸馏:学生模型能够通过蒸馏获得更高的性能,尤其是在资源受限的环境中(如移动设备或边缘计算)。这是因为蒸馏过程允许学生模型继承教师模型的部分知识。蒸馏过程还可以通过调整温度参数和损失函数的组合,进一步优化学生模型的学习效果
直接训练学生模型:直接训练学生模型通常会导致性能较低,因为学生模型缺乏教师模型的复杂结构和深层次知识。在有限的数据和计算资源下,直接训练学生模型可能无法达到教师模型的性能水平。模型蒸馏通过软标签的知识传递,使学生模型能够学习到教师模型的深层次知识,从而在资源受限的环境中获得更高的性能和更好的泛化能力。而直接训练学生模型则更依赖于数据和计算资源,适合在资源充足的情况下使用。


七、结语:蒸馏不是终点,而是AI民主化的加速器

大模型蒸馏,表面看是一场"减法"------削减参数、降低算力、压缩体积。但其内核,却是一场深刻的"加法":为AI能力添加可部署性、可访问性、可扩展性与可持续性。

它让GPT-4o的智慧,能流淌进一部智能手机;让DeepSeek-R1的数学推理,能在医院CT机旁实时辅助诊断;让FinBERT的金融洞察,能嵌入银行柜台的微型终端。这不是对大模型的替代,而是对其价值的最大化释放

当然,挑战犹存:知识损失的理论边界、跨领域泛化的普适方法、零样本蒸馏的可靠性......这些正是Hinton 2015年提出KD以来,学术界与工业界持续攻坚的前沿。

🌈 最终,蒸馏的意义早已超越技术本身。正如:"它推动AI技术普及化、可访问性、实时化、可扩展性。" 当每一个开发者、每一家中小企业、每一台边缘设备,都能以低成本获得接近SOTA大模型的能力时,我们才真正站在了AI民主化的黎明。


参考资料与致谢

  • 掘金技术博客《大模型蒸馏技术详解:从原理到应用》(码上芯动,2025-02-19)
  • CSDN博文《模型蒸馏:大模型压缩与高效部署》(CC 4.0 BY-SA,2025-02-25;修改于2025-05-25)
  • DeepSeek蒸馏系列技术文档(2025年公开评测数据)

本文严格遵循各来源事实,所有模型名、参数量、分数、公式、流程均直接引自原文摘录,未作任何虚构或推测。

相关推荐
AI医影跨模态组学2 小时前
Cancer Lett(IF=10.1)北京大学第一医院杨尹默等团队:基于深度学习的病理组学特征可独立于CA19-9预测胰腺导管腺癌的生存与复发
人工智能·深度学习
码农三叔2 小时前
(2-1)常用传感器与基础原理:视觉传感器
人工智能·机器人·大模型·人形机器人
心勤则明2 小时前
Spring AI Alibaba MCP Gateway:将存量服务转换成 MCP Server
人工智能·spring·gateway
Fairy要carry2 小时前
面试-Skill粒度粗细的影响
人工智能
古希腊掌管代码的神THU2 小时前
【清华代码熊】RL后训练解析|Cursor Composer 2 技术报告
人工智能·深度学习·自然语言处理·composer
lpfasd1232 小时前
以Trae为例,拆解AI编程工具沙箱
人工智能·ai编程
猿类崛起@2 小时前
CherryStudio配置本地MCP服务器实现FileSystem本地文件系统读写操作
人工智能·学习·程序员·大模型·agent·ai大模型·mcp
AI医影跨模态组学3 小时前
Cell Rep Med(IF=10.6)北京清华长庚医院李国新&云南省肿瘤医院放射科李振辉等团队:基于TME的深度学习模型预测胃癌治疗反应
人工智能·深度学习·医学·医学影像·医学科研
宇擎智脑科技3 小时前
Claude Code 源码分析(二):Shell 命令安全体系 —— AI Agent 执行终端命令的纵深防御设计
人工智能·安全·claude code