10-让模型更小更聪明,学而不忘:知识蒸馏与持续学习

开篇:两个核心问题

在大模型的实际应用中,我们常常面临两个看似矛盾的需求:

问题1:模型太大怎么办?

yaml 复制代码
GPT-4: 1.76万亿参数 → 推理成本高昂
Llama 3 70B: 70B参数 → 需要多张A100
  ↓
能否压缩到更小的模型,但保持性能?

示例

场景 问题
边缘设备部署 手机/IoT设备无法运行70B模型
实时响应 大模型推理延迟高(秒级)
成本控制 API调用费用过高

解决方案知识蒸馏(Knowledge Distillation) - 用小模型学习大模型的知识


问题2:模型会"遗忘"怎么办?

css 复制代码
场景:持续学习新任务
  任务A(医疗问答) → 训练后效果好
  任务B(法律咨询) → 训练后效果好,但任务A性能下降!
  任务C(代码生成) → 训练后效果好,但任务A、B都变差!

示例

训练阶段 医疗问答准确率 法律咨询准确率 代码生成准确率
只训练医疗 90% - -
加入法律训练 65% ⚠️ 88% -
加入代码训练 45% ⚠️⚠️ 62% ⚠️ 85%

这就是灾难性遗忘(Catastrophic Forgetting)

解决方案回放机制(Experience Replay) - 让模型记住旧知识


第一部分:知识蒸馏(Knowledge Distillation)

什么是知识蒸馏?

核心思想:让小模型(学生)学习大模型(教师)的"思考过程",而不仅仅是最终答案。

类比:学习数学

erlang 复制代码
传统训练(学标准答案):
  题目:"2 + 2 = ?"
  标签:4
  学生学到:2 + 2 = 4

知识蒸馏(学老师的思路):
  题目:"2 + 2 = ?"
  老师的想法:
    - 4的概率:95%(最可能)
    - 3的概率:3%(也有点像)
    - 5的概率:2%(数字接近)
    - 其他:0.1%
  学生学到:不仅是答案,还有"为什么其他答案不对"的信息

关键洞察:大模型的**软标签(Soft Label)**包含了比硬标签(Hard Label)更丰富的信息。

蒸馏的基本原理

硬标签 vs 软标签

硬标签(One-hot)

python 复制代码
# 问题:"天空是什么颜色?"
hard_label = {
    "蓝色": 1.0,
    "红色": 0.0,
    "绿色": 0.0,
    "黑色": 0.0
}

软标签(概率分布)

python 复制代码
# 大模型的输出
soft_label = {
    "蓝色": 0.85,  # 大多数情况
    "黑色": 0.10,  # 夜晚的天空
    "红色": 0.03,  # 日出/日落
    "灰色": 0.02   # 阴天
}

优势 :软标签包含了上下文信息不确定性

深入理解:LLM中的硬标签vs软标签

上面的例子是分类任务,但LLM是生成任务,硬标签和软标签具体是什么样子呢?

关键理解:LLM的每一步生成本质上是分类任务

虽然LLM看起来是"生成"任务,但每预测一个词,本质上是在整个词表(通常50,000个词)上做分类:

python 复制代码
# LLM的生成过程
输入:"今天天气"
↓
需要预测:下一个词是什么?
↓
这是一个 50,000 分类问题(从词表中选一个词)

具体例子:预测下一个词

场景:输入 "今天天气",预测下一个词

硬标签(传统训练)

python 复制代码
# 训练数据
输入:  "今天天气"
标签:  "很好"  # Token ID: 1234

# One-hot 硬标签(词表大小=50000)
hard_label = [0, 0, 0, ..., 1, ..., 0, 0]
                           ↑
                      位置1234为1
                      其他49999个位置都是0

# 损失函数
loss = CrossEntropy(model_output, hard_label)
# 只关心位置1234的预测概率

问题

  • ❌ 只告诉模型"正确答案是'很好'"
  • ❌ 没有告诉模型为什么"不错"也可以(同义)
  • ❌ 没有告诉模型为什么"真好"也合理(相似表达)
  • ❌ 没有告诉模型为什么"糟糕"不对(反义)

软标签(教师模型的输出)

python 复制代码
# 大模型(教师)的实际输出
输入: "今天天气"

# 教师模型的 logits(未归一化的分数)
teacher_logits = {
    "很好":  8.5,
    "不错":  8.2,
    "真好":  7.8,
    "挺好":  7.5,
    "还行":  6.2,
    "一般":  4.5,
    "不好":  2.1,
    "糟糕":  1.3,
    ... # 其他49992个词
}

# 经过 softmax(T=1) 得到概率(软标签)
soft_label = {
    "很好":  0.52,  # 最可能
    "不错":  0.31,  # 也很可能
    "真好":  0.10,  # 还算可能
    "挺好":  0.04,  # 有点可能
    "还行":  0.02,  # 稍微可能
    "一般":  0.008, # 不太可能
    "不好":  0.001, # 很不可能
    "糟糕":  0.0005,# 几乎不可能
    其他词:   0.0015 # 极不可能
}

丰富的信息

  1. ✅ "很好"是最佳答案(52%)
  2. ✅ "不错"也很合理(31%)--- 同义词信息
  3. ✅ "真好"、"挺好"可以接受 --- 相似表达
  4. ✅ "一般"不太好但不是完全错误 --- 中性词
  5. ✅ "糟糕"几乎不可能 --- 反义词

完整的生成过程对比

让我们看一个完整句子的生成:

输入:"写一首关于春天的诗"

传统训练(硬标签)

python 复制代码
步骤1:
输入: "写一首关于春天的诗\n"
硬标签:"春"
模型学习:第一个字必须是"春" ✗

步骤2:
输入: "写一首关于春天的诗\n春"
硬标签:"风"
模型学习:第二个字必须是"风" ✗

步骤3:
输入: "写一首关于春天的诗\n春风"
硬标签:"拂"
模型学习:第三个字必须是"拂" ✗

结果:每一步只知道"正确答案",不知道其他选项为什么不对
→ 缺乏灵活性,容易过拟合

知识蒸馏(软标签)

python 复制代码
步骤1:
输入: "写一首关于春天的诗\n"

教师模型的软标签(概率分布):
{
    "春": 0.45,  # 最常见的开头
    "暖": 0.15,  # 也可以
    "阳": 0.10,  # 比较常见
    "万": 0.08,  # "万物复苏"
    "东": 0.05,  # "东风"
    "柳": 0.03,  # 也是春天的意象
    其他: 0.14
}

学生模型学习到:
✓ "春"字最好(直接点题)
✓ "暖"、"阳"也不错(温暖的意象)
✓ "万"可以(万物复苏)
✓ 其他字可能性很低

步骤2:
输入: "写一首关于春天的诗\n春"

教师模型的软标签:
{
    "风": 0.35,  # "春风"搭配
    "光": 0.20,  # "春光"
    "雨": 0.15,  # "春雨"
    "日": 0.10,  # "春日"
    "意": 0.08,  # "春意"
    "花": 0.05,  # "春花"
    其他: 0.07
}

学生模型学习到:
✓ "风"最佳(春风是经典搭配)
✓ "光"、"雨"、"日"都是好的选择
✓ 这些词都和"春"搭配良好
✓ 词语搭配的概率分布

结果:学生模型理解了多种可能性,生成更灵活、自然

数学形式对比

硬标签 Loss
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L hard = − log ⁡ P student ( y true ∣ x ) L_{\text{hard}} = -\log P_{\text{student}}(y_{\text{true}} | x) </math>Lhard=−logPstudent(ytrue∣x)

只优化"正确答案"的概率。

软标签 Loss(蒸馏)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L soft = − ∑ i = 1 V P teacher ( y i ∣ x ) log ⁡ P student ( y i ∣ x ) L_{\text{soft}} = -\sum_{i=1}^{V} P_{\text{teacher}}(y_i | x) \log P_{\text{student}}(y_i | x) </math>Lsoft=−i=1∑VPteacher(yi∣x)logPstudent(yi∣x)

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> V V </math>V:词表大小(如50,000)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P teacher ( y i ∣ x ) P_{\text{teacher}}(y_i | x) </math>Pteacher(yi∣x):教师模型对第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个词的预测概率
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P student ( y i ∣ x ) P_{\text{student}}(y_i | x) </math>Pstudent(yi∣x):学生模型对第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个词的预测概率

含义 :学生模型要学习教师模型的整个概率分布(50,000个词的概率),而不仅仅是最高概率的那个词。

实际案例:翻译任务

任务:翻译 "The weather is nice today"

硬标签训练的模型

python 复制代码
训练数据只有一个答案:
"The weather is nice today" → "今天天气很好"

模型学到的:
步骤1: "The" → "今" (概率1.0)
步骤2: "weather" → "天" (概率1.0)
...

结果:模型只会生成这一种翻译 ✗
缺乏灵活性,不能处理同义表达

软标签训练的模型

python 复制代码
教师模型(GPT-4)的输出分布:

步骤1: "The" →
{
    "今": 0.60,    # 最常见
    "今日": 0.25,  # 也可以
    "今儿": 0.08,  # 口语化
    其他: 0.07
}

步骤4: "nice" →
{
    "很": 0.35,  # "很好"
    "不": 0.30,  # "不错"
    "真": 0.15,  # "真好"
    "挺": 0.10,  # "挺好"
    其他: 0.10
}

学生模型学到:
✓ 多种表达方式的概率
✓ "今天天气很好" (最常见)
✓ "今天天气不错" (也很好)
✓ "今日天气真好" (稍正式)
✓ 不同表达的合理性 ✓

代码实现对比

python 复制代码
import torch
import torch.nn.functional as F

vocab_size = 50000  # 词表大小

# 输入:"今天天气",预测下一个词
input_text = "今天天气"

# ===== 硬标签训练 =====
# 真实标签:"很好" (token_id = 1234)
hard_label = torch.zeros(vocab_size)
hard_label[1234] = 1.0  # one-hot

student_logits = student_model(input_text)
loss_hard = F.cross_entropy(
    student_logits.unsqueeze(0),
    torch.tensor([1234])
)
# 只关心位置1234的概率 ✗

# ===== 软标签训练(蒸馏)=====
with torch.no_grad():
    teacher_logits = teacher_model(input_text)

# 使用温度软化
temperature = 2.0
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
# teacher_probs 形状: [50000]
# 例如:[0.0001, 0.0002, ..., 0.52(1234), 0.31(5678), ...]
#  ↑ 所有50000个位置都有概率值!✓

student_probs_soft = F.softmax(
    student_logits / temperature,
    dim=-1
)

# 软标签损失(KL散度)
loss_soft = F.kl_div(
    student_probs_soft.log(),
    teacher_probs,
    reduction='batchmean'
) * (temperature ** 2)

# 最终损失:软标签 + 硬标签
loss = 0.7 * loss_soft + 0.3 * loss_hard

软标签的实际价值总结

在LLM中,软标签传递了:

  1. 同义词信息

    • "很好"、"不错"、"真好"都是合理答案
    • 概率反映了它们的相似度
  2. 词语搭配知识

    • "春"后面接"风"、"光"、"雨"都合理
    • 概率反映了搭配的常见程度
  3. 上下文理解

    • 不同上下文下,同一个位置的词分布不同
    • 教师模型的理解被传递给学生
  4. 生成多样性

    • 学生知道多个选择都可行
    • 不会过拟合到单一答案

类比总结

scss 复制代码
硬标签:
告诉你 "2+2=4"
→ 只知道答案是4

软标签:
告诉你 "2+2=4(95%), 也可能是3.9(3%)或4.1(2%),
        但绝不是10(0.001%)"
→ 理解了整个数值空间的合理性

这就是为什么知识蒸馏在LLM中如此有效------软标签包含了教师模型对整个词表空间的深刻理解!

温度(Temperature)

为了让软标签更"软",使用温度参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> softmax ( z i , T ) = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) \text{softmax}(z_i, T) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} </math>softmax(zi,T)=∑jexp(zj/T)exp(zi/T)

效果

温度 <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T 效果 分布特征
<math xmlns="http://www.w3.org/1998/Math/MathML"> T = 1 T = 1 </math>T=1 标准softmax 正常分布
<math xmlns="http://www.w3.org/1998/Math/MathML"> T > 1 T > 1 </math>T>1 软化分布 更平滑,信息更丰富
<math xmlns="http://www.w3.org/1998/Math/MathML"> T → ∞ T \to \infty </math>T→∞ 均匀分布 所有类别概率接近
<math xmlns="http://www.w3.org/1998/Math/MathML"> T < 1 T < 1 </math>T<1 锐化分布 接近one-hot

示例

python 复制代码
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 1.0, 0.1])

# T=1: 标准softmax
p_t1 = F.softmax(logits, dim=0)
print(f"T=1:   {p_t1}")
# 输出: [0.8360, 0.1131, 0.0416, 0.0092]

# T=3: 软化
p_t3 = F.softmax(logits / 3, dim=0)
print(f"T=3:   {p_t3}")
# 输出: [0.5021, 0.2447, 0.1512, 0.1020]

# T=10: 更软
p_t10 = F.softmax(logits / 10, dim=0)
print(f"T=10:  {p_t10}")
# 输出: [0.3174, 0.2631, 0.2370, 0.1825]

观察:温度越高,概率分布越平滑,包含的"暗知识(Dark Knowledge)"越多。


方法1:Logits蒸馏(经典方法)

原理

学生模型同时学习两个目标

  1. 与真实标签的匹配(硬标签损失)
  2. 与教师模型的匹配(软标签损失)

数据流可视化

ini 复制代码
输入:"这部电影还不错"
    |
    |  (同时输入两个模型)
    |
    ├──────────────────────┬──────────────────────┐
    ↓                      ↓                      ↓
教师模型                学生模型              真实标签
[4.5, 0.3, 1.2]       [3.8, 0.8, 1.0]        "正面"
    |                      |                      |
    | T=3                  | T=3                  |
    ↓                      ↓                      |
[0.65, 0.15, 0.20]    [0.55, 0.22, 0.23]        |
  (软标签)               (学生软预测)            |
    |                      |                      |
    └──────KL散度─────────┘                      |
          ↓                                       |
      软标签损失                                  |
       = 0.042                                    |
                                                  |
                            学生logits            |
                          [3.8, 0.8, 1.0]         |
                                |   T=1           |
                                ↓                 |
                          [0.85, 0.04, 0.11]      |
                                |                 |
                                └──交叉熵─────────┘
                                      ↓
                                  硬标签损失
                                   = 0.163
                                      |
        ┌─────────────────────────────┴─────────────────────────────┐
        ↓                                                             ↓
    α * T² * 软标签损失                              (1-α) * 硬标签损失
  = 0.7 * 9 * 0.042                                = 0.3 * 0.163
  = 0.265                                          = 0.049
        |                                                             |
        └─────────────────────────┬───────────────────────────────┘
                                  ↓
                            总损失 = 0.314
                                  ↓
                            反向传播,更新学生模型

关键理解

  • 学生模型只前向传播一次,得到一个logits输出
  • 这个输出参与两个损失计算
    • 软化后(T=3)与教师的软化输出比较
    • 标准化(T=1)与真实标签比较
  • 两个损失加权求和,共同指导学生学习

损失函数

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L KD = α ⋅ L soft + ( 1 − α ) ⋅ L hard \mathcal{L}{\text{KD}} = \alpha \cdot \mathcal{L}{\text{soft}} + (1-\alpha) \cdot \mathcal{L}_{\text{hard}} </math>LKD=α⋅Lsoft+(1−α)⋅Lhard

其中:

软标签损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L soft = T 2 ⋅ KL ( softmax ( z S T )   ∥   softmax ( z T T ) ) \mathcal{L}_{\text{soft}} = T^2 \cdot \text{KL}\left( \text{softmax}\left(\frac{z^S}{T}\right) \, \Big\| \, \text{softmax}\left(\frac{z^T}{T}\right) \right) </math>Lsoft=T2⋅KL(softmax(TzS) softmax(TzT))

硬标签损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L hard = CrossEntropy ( z S , y ) \mathcal{L}_{\text{hard}} = \text{CrossEntropy}(z^S, y) </math>Lhard=CrossEntropy(zS,y)

参数:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> z T z^T </math>zT:教师模型的logits
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> z S z^S </math>zS:学生模型的logits
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y:真实标签
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T T </math>T:温度(通常2-10)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> α \alpha </math>α:平衡系数(通常0.7-0.9)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> T 2 T^2 </math>T2:补偿温度导致的梯度缩放

完整示例:理解损失计算

关键点:学生模型只有一个输出,但用两种方式评估

假设一个情感分类任务(正面/负面/中性),输入:"这部电影还不错"

python 复制代码
# ===== 第1步:教师模型推理 =====
teacher_logits = teacher_model("这部电影还不错")
# 输出 logits(未归一化的分数):
# [正面: 4.5, 负面: 0.3, 中性: 1.2]

# 教师的概率分布(T=1):
teacher_probs_T1 = softmax(teacher_logits / 1)
# [正面: 0.91, 负面: 0.01, 中性: 0.08]
# → 教师非常确信是"正面"

# 教师的软化概率(T=3,用于蒸馏):
teacher_probs_T3 = softmax(teacher_logits / 3)
# [正面: 0.65, 负面: 0.15, 中性: 0.20]
# → 软化后,"中性"和"负面"的概率提高了,包含更多信息


# ===== 第2步:学生模型推理(同一个输入) =====
student_logits = student_model("这部电影还不错")
# 输出 logits:
# [正面: 3.8, 负面: 0.8, 中性: 1.0]

# 学生的概率分布(T=1):
student_probs_T1 = softmax(student_logits / 1)
# [正面: 0.85, 负面: 0.04, 中性: 0.11]

# 学生的软化概率(T=3):
student_probs_T3 = softmax(student_logits / 3)
# [正面: 0.55, 负面: 0.22, 中性: 0.23]


# ===== 第3步:计算两个损失 =====

# 损失1:硬标签损失(与真实标签比较)
true_label = "正面"  # 人类标注的真实答案
hard_loss = CrossEntropy(student_probs_T1, true_label)
# = -log(0.85) = 0.163
# 评估:学生对正确答案(正面)的预测准确性

# 损失2:软标签损失(与教师的思考过程比较)
soft_loss = KL_Divergence(student_probs_T3, teacher_probs_T3)
# = 0.65*log(0.65/0.55) + 0.15*log(0.15/0.22) + 0.20*log(0.20/0.23)
# ≈ 0.042
# 评估:学生的思考过程与教师的相似度

# 损失3:总损失(加权组合)
alpha = 0.7  # 软标签权重
total_loss = alpha * (3^2) * soft_loss + (1 - alpha) * hard_loss
#          = 0.7 * 9 * 0.042 + 0.3 * 0.163
#          = 0.265 + 0.049
#          = 0.314

重点理解

  1. 学生模型只输出一次student_logits = [3.8, 0.8, 1.0]

  2. 这一个输出被两种方式使用

    • T=1 的概率与真实标签比较 → 硬标签损失
    • T=3 的概率与教师比较 → 软标签损失
  3. 为什么两个都需要?

    arduino 复制代码
    只用硬标签损失:
      学生只学会"这是正面评价"(结果)
      ✗ 不知道为什么不是中性或负面
    
    只用软标签损失:
      学生学会了教师的思考模式(过程)
      ✗ 但可能偏离真实标签(如果教师也会犯错)
    
    两者结合:
      ✓ 既学会了正确答案(硬标签)
      ✓ 又学会了思考过程(软标签)
  4. 软标签包含的"额外信息"

    csharp 复制代码
    硬标签告诉学生:
      "答案是正面" [1, 0, 0]
    
    软标签告诉学生:
      "正面可能性65%,但也有20%可能是中性(因为'还不错'
       比较温和),15%可能有些负面情绪"
      [0.65, 0.15, 0.20]
    
    学生学到:
      - 主要特征:积极词汇 → 正面
      - 细微差异:"还不错"比"很棒"更温和 → 也有中性成分
      - 边界情况:如何区分"温和正面"和"中性"

常见疑问解答

Q1: 为什么不只用软标签损失?教师已经学到了正确答案。

A: 因为教师也可能犯错,或者在某些样本上不够自信。硬标签提供了"ground truth",确保学生不会被教师的错误误导。

python 复制代码
# 例子:教师在困难样本上可能不确定
教师预测: [正面: 0.48, 负面: 0.52]  # 教师认为略偏负面
真实标签: "正面"                    # 实际是正面

只用软标签 → 学生学到"这是负面的"(错误)
加上硬标签 → 学生知道真实答案是正面,但也学到这是个"接近边界"的案例

Q2: 为什么不只用硬标签损失?传统训练不就是这样吗?

A: 硬标签只提供了"对/错"的信息,丢失了很多细节:

python 复制代码
硬标签: [1, 0, 0]  # 只知道第一个类是对的
软标签: [0.85, 0.10, 0.05]  # 知道第一个类最可能,第二个类也有点像,第三个类完全不像

学生从软标签学到:
- 类别之间的相似性(哪些类容易混淆)
- 决策边界的位置(多接近才算"像")
- 不确定性的估计(模型有多自信)

Q3: 学生模型到底输出什么?

A: 学生模型只输出一个东西:logits向量

python 复制代码
# 完整流程
input = "这部电影还不错"

# 学生只做一次前向传播
student_logits = student_model(input)
# → [3.8, 0.8, 1.0]  (就这一个输出!)

# 然后这个输出被用于两个损失计算:
# 用法1:软化后与教师比较
student_soft = softmax(student_logits / 3)  # T=3
loss_soft = KL(student_soft, teacher_soft)

# 用法2:标准化后与标签比较
student_normal = softmax(student_logits / 1)  # T=1
loss_hard = CE(student_normal, true_label)

# 两个损失加权求和
total_loss = 0.7 * loss_soft + 0.3 * loss_hard

# 反向传播,更新 student_model 的参数
total_loss.backward()

Q4: 为什么软标签损失要乘以 T²?

A: 因为温度会缩放梯度,T² 是为了补偿:

python 复制代码
# 当T增大时,softmax的输出变化变小
# 导致梯度也变小
# 乘以T²可以保持梯度的尺度与硬标签损失相当

# 数学上:
∂L_soft/∂z ∝ 1/T²  (温度缩放导致梯度变小)
L_soft × T² → 梯度恢复正常尺度

代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7):
        """
        Args:
            temperature: 软化温度
            alpha: 软标签损失的权重(硬标签权重为 1-alpha)
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Args:
            student_logits: 学生模型输出 [batch, num_classes]
            teacher_logits: 教师模型输出 [batch, num_classes]
            labels: 真实标签 [batch]

        关键:student_logits是同一个输出,被两种方式使用:
            - 软化后与教师比较(学习思考过程)
            - 直接与真实标签比较(学习正确答案)
        """
        # 1. 软标签损失(KL散度)
        # 用高温softmax软化,让概率分布更平滑
        student_soft = F.log_softmax(student_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)

        soft_loss = self.kl_div(student_soft, teacher_soft) * (self.temperature ** 2)
        # 目标:让学生的概率分布接近教师的概率分布

        # 2. 硬标签损失(交叉熵)
        # 用标准softmax(T=1),与真实标签比较
        hard_loss = self.ce(student_logits, labels)
        # 目标:让学生预测正确的类别

        # 3. 总损失(加权组合)
        loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
        # alpha=0.7 表示:70%关注教师的思考过程,30%关注正确答案

        return loss, soft_loss, hard_loss

# 使用示例
teacher_model = load_large_model()  # 70B模型
student_model = create_small_model()  # 7B模型

kd_loss = KnowledgeDistillationLoss(temperature=3.0, alpha=0.8)
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-5)

for batch in dataloader:
    inputs, labels = batch

    # 教师模型推理(不计算梯度)
    with torch.no_grad():
        teacher_logits = teacher_model(inputs)

    # 学生模型推理
    student_logits = student_model(inputs)

    # 计算蒸馏损失
    loss, soft_loss, hard_loss = kd_loss(
        student_logits, teacher_logits, labels
    )

    # 反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Total: {loss:.4f}, Soft: {soft_loss:.4f}, Hard: {hard_loss:.4f}")

为什么有效?

信息量对比

arduino 复制代码
硬标签:
  "北京是中国的首都" → 标签: 正确 (1.0)
  信息量:1 bit

软标签:
  "北京是中国的首都" → 教师模型输出:
    - 正确: 0.95
    - 错误但相关: 0.03("北京在中国")
    - 错误且无关: 0.02("北京不存在")
  信息量:~5 bits(更丰富)

学生学到

  • ✅ 北京确实是首都(主要信息)
  • ✅ 一些接近正确的表述也有道理(细微差异)
  • ✅ 某些表述明显错误(边界信息)

方法2:特征蒸馏(白盒方法)

原理

不仅学习最终输出,还学习中间层的特征表示

架构对比

less 复制代码
Logits蒸馏:
  Teacher: Input → Transformer → Logits
                                    ↓
  Student: Input → Transformer → Logits
           (只匹配最后输出)

特征蒸馏:
  Teacher: Input → Layer1 → Layer2 → ... → Logits
                     ↓        ↓              ↓
  Student: Input → Layer1 → Layer2 → ... → Logits
           (匹配多个中间层)

损失函数

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L Feature = ∑ l ∈ L λ l ⋅ ∥ H l S − Proj ( H l T ) ∥ 2 \mathcal{L}{\text{Feature}} = \sum{l \in \mathcal{L}} \lambda_l \cdot \| H_l^S - \text{Proj}(H_l^T) \|^2 </math>LFeature=l∈L∑λl⋅∥HlS−Proj(HlT)∥2

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> H l T H_l^T </math>HlT:教师模型第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l 层的隐藏状态
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> H l S H_l^S </math>HlS:学生模型第 <math xmlns="http://www.w3.org/1998/Math/MathML"> l l </math>l 层的隐藏状态
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> Proj \text{Proj} </math>Proj:投影层(因为教师和学生的维度可能不同)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> L \mathcal{L} </math>L:选择的层(通常选几个关键层)

代码实现

python 复制代码
class FeatureDistillationLoss(nn.Module):
    def __init__(self, teacher_dim, student_dim, num_layers=4):
        super().__init__()
        # 投影层:将教师特征投影到学生维度
        self.projections = nn.ModuleList([
            nn.Linear(teacher_dim, student_dim)
            for _ in range(num_layers)
        ])
        self.mse = nn.MSELoss()

    def forward(self, student_features, teacher_features):
        """
        Args:
            student_features: List of [batch, seq_len, student_dim]
            teacher_features: List of [batch, seq_len, teacher_dim]
        """
        total_loss = 0

        for i, (s_feat, t_feat) in enumerate(
            zip(student_features, teacher_features)
        ):
            # 投影教师特征
            t_feat_proj = self.projections[i](t_feat)

            # MSE损失
            loss = self.mse(s_feat, t_feat_proj)
            total_loss += loss

        return total_loss / len(student_features)

# 使用示例
class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([...])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, return_features=False):
        features = []
        hidden = x

        for layer in self.layers:
            hidden = layer(hidden)
            if return_features:
                features.append(hidden)

        logits = self.lm_head(hidden)

        if return_features:
            return logits, features
        return logits

# 训练
feature_loss_fn = FeatureDistillationLoss(
    teacher_dim=4096,  # 教师隐藏维度
    student_dim=2048   # 学生隐藏维度
)

for batch in dataloader:
    inputs, labels = batch

    # 教师前向(获取中间特征)
    with torch.no_grad():
        teacher_logits, teacher_features = teacher_model(
            inputs, return_features=True
        )

    # 学生前向
    student_logits, student_features = student_model(
        inputs, return_features=True
    )

    # 特征蒸馏损失
    feature_loss = feature_loss_fn(student_features, teacher_features)

    # Logits蒸馏损失
    logit_loss = kd_loss(student_logits, teacher_logits, labels)

    # 总损失
    loss = logit_loss + 0.5 * feature_loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

优势

  • ✅ 学习更深层的知识表示
  • ✅ 更好的泛化能力
  • ✅ 训练更稳定

劣势

  • ⚠️ 需要访问教师模型内部(白盒)
  • ⚠️ 计算开销更大
  • ⚠️ 架构设计更复杂

方法3:响应蒸馏(黑盒方法)

原理

完全不需要教师模型的内部结构,只使用教师模型的文本输出。

适用场景

  • 教师模型是API(如GPT-4、Claude)
  • 教师模型不开源
  • 无法获取教师模型的logits

流程

arduino 复制代码
步骤1:用教师模型生成高质量数据
  输入: "解释量子计算"
  教师输出: "量子计算是利用量子力学原理进行计算的技术..."

步骤2:学生模型学习模仿教师的输出
  训练数据: (输入, 教师输出)
  目标: 最小化学生输出与教师输出的差异

实现方法

方法A:直接监督学习

python 复制代码
# 1. 收集教师响应
teacher_responses = []

for prompt in prompts:
    # 调用API
    response = teacher_model.generate(prompt)
    teacher_responses.append({
        "prompt": prompt,
        "response": response
    })

# 2. 用教师响应训练学生
for batch in create_dataloader(teacher_responses):
    prompts, responses = batch

    # 学生模型前向
    student_outputs = student_model(prompts)

    # 标准语言模型损失
    loss = cross_entropy_loss(student_outputs, responses)

    loss.backward()
    optimizer.step()

方法B:排序蒸馏(Ranking Distillation)

让学生学习教师对多个回答的偏好排序

python 复制代码
class RankingDistillation:
    def __init__(self, teacher_model, student_model):
        self.teacher = teacher_model
        self.student = student_model

    def collect_ranked_data(self, prompts):
        dataset = []

        for prompt in prompts:
            # 生成多个候选回答
            candidates = []
            for _ in range(4):
                response = self.student.generate(prompt, temperature=0.8)
                candidates.append(response)

            # 教师评分(使用reward model或直接打分)
            scores = []
            for candidate in candidates:
                score = self.teacher.score(prompt, candidate)
                scores.append(score)

            # 排序
            ranked = sorted(
                zip(candidates, scores),
                key=lambda x: x[1],
                reverse=True
            )

            dataset.append({
                "prompt": prompt,
                "best": ranked[0][0],
                "worst": ranked[-1][0]
            })

        return dataset

    def train(self, dataset):
        for batch in dataset:
            prompt = batch["prompt"]
            best = batch["best"]
            worst = batch["worst"]

            # 学生模型的log概率
            logp_best = self.student.log_prob(prompt, best)
            logp_worst = self.student.log_prob(prompt, worst)

            # 排序损失(类似DPO)
            loss = -F.logsigmoid(logp_best - logp_worst).mean()

            loss.backward()
            optimizer.step()

优缺点

优点

  • ✅ 完全黑盒,无需访问教师内部
  • ✅ 可以利用API服务
  • ✅ 简单易实现

缺点

  • ❌ 信息损失大(只有文本,没有概率分布)
  • ❌ 需要大量生成数据
  • ❌ 效果通常不如白盒方法

蒸馏效果对比

实验设置

  • 教师:Llama 3 70B
  • 学生:Llama 3 8B
  • 任务:MMLU(通用知识问答)
方法 准确率 相对教师 训练时间 需要访问
学生原始 62.3% -12.7% - -
Logits蒸馏 68.5% -6.5% 2天 Logits
特征蒸馏 69.8% -5.2% 3天 内部特征
响应蒸馏 65.7% -9.3% 1.5天 仅文本
教师模型 75.0% - - -

观察

  • 特征蒸馏效果最好,但需要白盒访问
  • Logits蒸馏是性价比最高的方法
  • 响应蒸馏适合API场景,效果略差

第二部分:灾难性遗忘(Catastrophic Forgetting)

什么是灾难性遗忘?

定义 :神经网络在学习新任务时,会急剧遗忘之前学习的任务知识。

类比:学生学习

复制代码
传统学生:
  学数学 → 数学能力 ↑
  学物理 → 数学能力保持,物理能力 ↑
  学化学 → 数学、物理保持,化学能力 ↑

神经网络:
  学数学 → 数学能力 ↑
  学物理 → 数学能力 ↓↓,物理能力 ↑
  学化学 → 数学、物理能力 ↓↓↓,化学能力 ↑

实验:观察灾难性遗忘

实验设计

python 复制代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# 任务A:情感分类(电影评论)
# 任务B:主题分类(新闻文章)

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10000, 128)
        self.lstm = nn.LSTM(128, 256, batch_first=True)
        self.fc = nn.Linear(256, 2)

    def forward(self, x):
        emb = self.embedding(x)
        _, (hidden, _) = self.lstm(emb)
        return self.fc(hidden[-1])

# 评估函数
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

    return correct / total

# 实验流程
model = SimpleModel()

# 阶段1:训练任务A
print("=== 训练任务A(情感分类)===")
train_on_task_a(model, task_a_data, epochs=5)
acc_a_after_a = evaluate(model, task_a_test)
print(f"任务A准确率: {acc_a_after_a:.2%}")

# 阶段2:训练任务B
print("\n=== 训练任务B(主题分类)===")
train_on_task_b(model, task_b_data, epochs=5)
acc_a_after_b = evaluate(model, task_a_test)  # 再次评估任务A
acc_b_after_b = evaluate(model, task_b_test)
print(f"任务A准确率: {acc_a_after_b:.2%} (下降 {acc_a_after_a - acc_a_after_b:.2%})")
print(f"任务B准确率: {acc_b_after_b:.2%}")

# 阶段3:训练任务C
print("\n=== 训练任务C(实体识别)===")
train_on_task_c(model, task_c_data, epochs=5)
acc_a_after_c = evaluate(model, task_a_test)
acc_b_after_c = evaluate(model, task_b_test)
acc_c_after_c = evaluate(model, task_c_test)
print(f"任务A准确率: {acc_a_after_c:.2%} (下降 {acc_a_after_a - acc_a_after_c:.2%})")
print(f"任务B准确率: {acc_b_after_c:.2%} (下降 {acc_b_after_b - acc_b_after_c:.2%})")
print(f"任务C准确率: {acc_c_after_c:.2%}")

典型结果

less 复制代码
=== 训练任务A(情感分类)===
任务A准确率: 89.5%

=== 训练任务B(主题分类)===
任务A准确率: 67.2% (下降 22.3%) ⚠️
任务B准确率: 86.3%

=== 训练任务C(实体识别)===
任务A准确率: 52.1% (下降 37.4%) ⚠️⚠️
任务B准确率: 63.8% (下降 22.5%) ⚠️
任务C准确率: 84.7%

可视化

css 复制代码
准确率
  ^
  |  任务A
90|  ●━━━━╮
  |        ╰━━━━╮
  |              ╰━━━━━━━━●  任务C
  |                   任务B
60|                   ●━━━━╮
  |                        ╰━━━━●
  |
30|
  |
  +─────────────────────────────────> 训练阶段
     训练A    训练B         训练C

为什么会发生灾难性遗忘?

原因1:参数重写

神经网络的参数是共享的

css 复制代码
任务A学习:
  参数θ初始 → 参数θ_A(优化到任务A最优)

任务B学习:
  参数θ_A → 参数θ_B(优化到任务B最优)
  但θ_B可能对任务A很差!

数学解释

假设任务A和B的损失函数分别为 <math xmlns="http://www.w3.org/1998/Math/MathML"> L A ( θ ) \mathcal{L}_A(\theta) </math>LA(θ) 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> L B ( θ ) \mathcal{L}_B(\theta) </math>LB(θ)

  • 训练任务A后: <math xmlns="http://www.w3.org/1998/Math/MathML"> θ A = arg ⁡ min ⁡ L A ( θ ) \theta_A = \arg\min \mathcal{L}_A(\theta) </math>θA=argminLA(θ)
  • 训练任务B后: <math xmlns="http://www.w3.org/1998/Math/MathML"> θ B = arg ⁡ min ⁡ L B ( θ ) \theta_B = \arg\min \mathcal{L}_B(\theta) </math>θB=argminLB(θ)

问题: <math xmlns="http://www.w3.org/1998/Math/MathML"> θ B \theta_B </math>θB 不保证 <math xmlns="http://www.w3.org/1998/Math/MathML"> L A ( θ B ) \mathcal{L}_A(\theta_B) </math>LA(θB) 小!

原因2:梯度冲突

任务A和B的梯度可能相反

ini 复制代码
参数w的梯度:
  任务A: ∂L_A/∂w = +2.5  (希望增大w)
  任务B: ∂L_B/∂w = -3.0  (希望减小w)

结果:训练任务B时,w减小,损害任务A性能

原因3:分布偏移

新任务的数据分布不同

arduino 复制代码
任务A(医疗):
  - 词汇:"症状"、"诊断"、"治疗"
  - 句式:正式、专业

任务B(社交媒体):
  - 词汇:"点赞"、"转发"、"吐槽"
  - 句式:口语、非正式

模型适应B后,A的特征激活模式被改变

遗忘的数学分析

Fisher信息矩阵

核心思想:某些参数对旧任务"更重要",不应该被大幅修改。

Fisher信息矩阵定义:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> F i = E x ∼ D A [ ( ∂ log ⁡ p ( y ∣ x ; θ A ) ∂ θ i ) 2 ] F_i = \mathbb{E}_{x \sim \mathcal{D}_A} \left[ \left( \frac{\partial \log p(y|x; \theta_A)}{\partial \theta_i} \right)^2 \right] </math>Fi=Ex∼DA[(∂θi∂logp(y∣x;θA))2]

含义

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> F i F_i </math>Fi 大:参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ i \theta_i </math>θi 对任务A很重要(梯度大且稳定)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> F i F_i </math>Fi 小:参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> θ i \theta_i </math>θi 对任务A不太重要(可以安全修改)

可视化

makefile 复制代码
参数重要性:

θ₁: ████████████ F₁=12.5  (非常重要,不能改)
θ₂: ████████     F₂=8.0   (重要)
θ₃: ███          F₃=3.0   (一般)
θ₄: █            F₄=1.0   (不重要,可以改)

启发 :训练新任务时,重要参数应该少改或不改


第三部分:回放机制(Experience Replay)

什么是回放?

核心思想 :在学习新任务时,同时回放旧任务的数据,让模型"不忘初心"。

类比:学生复习

erlang 复制代码
不使用回放:
  第1天:学数学(专注数学,100%时间)
  第2天:学物理(专注物理,100%时间)← 忘记数学
  第3天:学化学(专注化学,100%时间)← 忘记数学和物理

使用回放:
  第1天:学数学(100%时间)
  第2天:学物理(70%时间) + 复习数学(30%时间)
  第3天:学化学(60%时间) + 复习数学、物理(40%时间)

回放方法1:原始数据回放(Naive Replay)

原理

保存旧任务的训练数据,与新任务数据混合训练。

实现

python 复制代码
class ExperienceReplayBuffer:
    def __init__(self, max_size=10000):
        self.buffer = []
        self.max_size = max_size

    def add_task_data(self, task_data):
        """
        添加新任务的数据到缓冲区

        Args:
            task_data: List of (input, label) tuples
        """
        # 如果缓冲区满了,随机替换旧数据
        for sample in task_data:
            if len(self.buffer) < self.max_size:
                self.buffer.append(sample)
            else:
                # 随机替换
                idx = random.randint(0, self.max_size - 1)
                self.buffer[idx] = sample

    def sample(self, batch_size):
        """随机采样"""
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

# 使用示例
replay_buffer = ExperienceReplayBuffer(max_size=5000)

# 训练任务A
for epoch in range(5):
    for batch in task_a_dataloader:
        inputs, labels = batch

        # 训练
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 保存数据到回放缓冲区
        replay_buffer.add_task_data(list(zip(inputs, labels)))

print(f"缓冲区大小: {len(replay_buffer.buffer)}")

# 训练任务B(同时回放任务A)
for epoch in range(5):
    for batch in task_b_dataloader:
        new_inputs, new_labels = batch

        # 1. 训练新任务数据
        outputs = model(new_inputs)
        loss_new = criterion(outputs, new_labels)

        # 2. 回放旧任务数据
        if len(replay_buffer.buffer) > 0:
            replay_samples = replay_buffer.sample(batch_size=32)
            replay_inputs, replay_labels = zip(*replay_samples)
            replay_inputs = torch.stack(replay_inputs)
            replay_labels = torch.tensor(replay_labels)

            replay_outputs = model(replay_inputs)
            loss_replay = criterion(replay_outputs, replay_labels)
        else:
            loss_replay = 0

        # 3. 总损失
        loss = loss_new + 0.5 * loss_replay  # 可调整权重

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

效果对比

方法 任务A (训练B后) 任务B 总平均
无回放 67.2% (-22.3%) 86.3% 76.8%
回放 (50%) 82.5% (-7.0%) 84.1% 83.3% ✅
回放 (100%) 87.2% (-2.3%) 82.8% 85.0% ✅✅

观察

  • 回放显著减少遗忘
  • 回放比例越高,旧任务保留越好(但新任务可能略有下降)

优缺点

优点

  • ✅ 简单有效
  • ✅ 理论上最优(如果数据充足)

缺点

  • 存储开销大(需要保存原始数据)
  • 隐私问题(无法删除用户数据)
  • 数据不平衡(多任务时缓冲区有限)

回放方法2:生成式回放(Generative Replay)

原理

不保存原始数据,而是用生成模型"回忆"旧数据

流程

css 复制代码
步骤1:训练任务A
  → 训练主模型(分类器)
  → 同时训练生成器G_A(学习生成任务A的数据)

步骤2:训练任务B
  → 用G_A生成"伪"任务A数据
  → 与真实任务B数据混合训练
  → 训练生成器G_B

实现

python 复制代码
class GenerativeReplay:
    def __init__(self, main_model, generator):
        """
        Args:
            main_model: 主任务模型(如分类器)
            generator: 生成模型(如VAE或diffusion)
        """
        self.main_model = main_model
        self.generator = generator

    def train_task(self, new_data, prev_generators=None):
        """
        训练新任务

        Args:
            new_data: 新任务的真实数据
            prev_generators: 之前任务的生成器列表
        """
        optimizer_main = torch.optim.Adam(self.main_model.parameters())
        optimizer_gen = torch.optim.Adam(self.generator.parameters())

        for epoch in range(num_epochs):
            for batch in new_data:
                real_inputs, real_labels = batch

                # ===== 训练主模型 =====
                # 1. 新任务数据
                outputs = self.main_model(real_inputs)
                loss_new = criterion(outputs, real_labels)

                # 2. 生成旧任务数据(回放)
                if prev_generators:
                    loss_replay = 0
                    for old_gen in prev_generators:
                        # 生成伪数据
                        with torch.no_grad():
                            fake_inputs = old_gen.generate(batch_size=32)

                        # 用当前模型预测(目标是保持旧模型的预测)
                        outputs = self.main_model(fake_inputs)

                        # 用旧模型的预测作为伪标签
                        with torch.no_grad():
                            pseudo_labels = old_main_model(fake_inputs).argmax(dim=1)

                        loss_replay += criterion(outputs, pseudo_labels)

                    loss_main = loss_new + 0.5 * loss_replay
                else:
                    loss_main = loss_new

                # 更新主模型
                optimizer_main.zero_grad()
                loss_main.backward()
                optimizer_main.step()

                # ===== 训练生成器 =====
                # 让生成器学习生成当前任务的数据
                fake_data = self.generator.generate(batch_size=32)
                gen_loss = generator_loss(fake_data, real_inputs)

                optimizer_gen.zero_grad()
                gen_loss.backward()
                optimizer_gen.step()

        return self.generator  # 返回当前生成器,供后续任务使用

# 使用
replay = GenerativeReplay(
    main_model=classifier,
    generator=VAE()
)

# 训练任务A
gen_a = replay.train_task(task_a_data, prev_generators=None)

# 训练任务B(使用gen_a回放)
gen_b = replay.train_task(task_b_data, prev_generators=[gen_a])

# 训练任务C
gen_c = replay.train_task(task_c_data, prev_generators=[gen_a, gen_b])

优缺点

优点

  • 不需要存储原始数据(解决隐私问题)
  • 内存需求小(只存生成器参数)
  • 可扩展(生成器可压缩)

缺点

  • 生成质量影响效果(生成器不好则回放失效)
  • 训练成本高(需要额外训练生成器)
  • 不适合复杂数据(如高分辨率图像、长文本)

回放方法3:知识蒸馏回放(Distillation Replay)

原理

结合知识蒸馏和回放:用旧模型的软标签作为回放目标。

关键洞察

  • 不需要保存原始数据
  • 不需要训练生成器
  • 当前模型在旧任务上的预测作为回放信号

实现

python 复制代码
class DistillationReplay:
    def __init__(self, model, temperature=2.0):
        self.model = model
        self.temperature = temperature
        self.old_model = None  # 保存旧模型

    def train_new_task(self, new_data, replay_data=None):
        """
        训练新任务,同时通过蒸馏回放旧知识

        Args:
            new_data: 新任务数据 (inputs, labels)
            replay_data: 回放数据(只需输入,不需要标签)
        """
        # 复制当前模型作为"旧模型"(冻结)
        if self.old_model is None:
            self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for new_batch in new_data:
                new_inputs, new_labels = new_batch

                # ===== 新任务损失 =====
                new_outputs = self.model(new_inputs)
                loss_new = F.cross_entropy(new_outputs, new_labels)

                # ===== 回放损失(蒸馏) =====
                if replay_data and len(replay_data) > 0:
                    # 采样回放数据
                    replay_inputs = random.sample(replay_data, k=32)
                    replay_inputs = torch.stack(replay_inputs)

                    # 当前模型的输出
                    curr_outputs = self.model(replay_inputs)

                    # 旧模型的输出(软标签)
                    with torch.no_grad():
                        old_outputs = self.old_model(replay_inputs)

                    # 蒸馏损失
                    curr_soft = F.log_softmax(
                        curr_outputs / self.temperature, dim=1
                    )
                    old_soft = F.softmax(
                        old_outputs / self.temperature, dim=1
                    )

                    loss_distill = F.kl_div(
                        curr_soft, old_soft, reduction='batchmean'
                    ) * (self.temperature ** 2)
                else:
                    loss_distill = 0

                # ===== 总损失 =====
                loss = loss_new + 0.5 * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # 更新旧模型
        self.old_model = copy.deepcopy(self.model)

# 使用
replay = DistillationReplay(model, temperature=3.0)

# 训练任务A
replay.train_new_task(
    new_data=task_a_data,
    replay_data=None  # 第一个任务无需回放
)

# 训练任务B(蒸馏回放任务A)
# 只需要任务A的输入(不需要标签)
task_a_inputs = [x for x, y in task_a_data]

replay.train_new_task(
    new_data=task_b_data,
    replay_data=task_a_inputs  # 回放任务A的输入
)

# 训练任务C
task_ab_inputs = task_a_inputs + [x for x, y in task_b_data]

replay.train_new_task(
    new_data=task_c_data,
    replay_data=task_ab_inputs  # 回放A和B
)

为什么有效?

旧模型的软标签包含了"如何解决旧任务"的知识

less 复制代码
任务A:情感分类
  输入: "这部电影很棒"
  旧模型输出: [正面: 0.92, 负面: 0.08]

学习任务B时:
  在相同输入上,新模型应该输出相似的概率分布
  → 保持了任务A的知识

优缺点

优点

  • 不需要存储数据(只需少量输入样本)
  • 不需要训练生成器
  • 简单高效
  • 内存友好

缺点

  • ⚠️ 需要保存旧模型副本(但可以定期合并)
  • ⚠️ 回放数据需要覆盖旧任务的分布

回放方法对比

方法 存储需求 隐私友好 效果 复杂度 适用场景
原始数据回放 高(原始数据) ❌ 低 ⭐⭐⭐⭐⭐ 数据可存储
生成式回放 中(生成器) ✅ 高 ⭐⭐⭐ 隐私敏感
蒸馏回放 低(模型副本) ✅ 高 ⭐⭐⭐⭐ 推荐

第四部分:蒸馏与回放的结合

为什么结合?

蒸馏和回放解决不同问题

技术 解决的问题 典型场景
知识蒸馏 模型太大,需要压缩 部署到边缘设备
回放机制 持续学习时遗忘 多任务增量学习

结合场景

markdown 复制代码
场景:在线学习系统
  - 大模型(教师):在云端,持续学习新任务
  - 小模型(学生):在设备端,需要同步云端知识

挑战:
  1. 学生模型太小,无法直接学习所有任务
  2. 教师模型学习新任务时会遗忘

解决方案:蒸馏 + 回放

方法1:逐步蒸馏(Progressive Distillation)

原理

每学习一个新任务,就蒸馏一次

css 复制代码
任务A → 训练Teacher_A → 蒸馏到Student_A
任务B → 训练Teacher_B(回放A)→ 蒸馏到Student_B(回放A)
任务C → 训练Teacher_C(回放A+B)→ 蒸馏到Student_C(回放A+B)

实现

python 复制代码
class ProgressiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=3.0):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        self.replay_buffer = []

    def learn_task(self, task_data, task_id):
        """
        学习新任务,同时回放旧任务

        Args:
            task_data: 新任务数据
            task_id: 任务ID
        """
        print(f"\n=== 学习任务 {task_id} ===")

        # ===== 步骤1:训练教师模型(带回放) =====
        print("步骤1: 训练教师模型")
        self._train_teacher(task_data)

        # ===== 步骤2:蒸馏到学生模型(带回放) =====
        print("步骤2: 蒸馏到学生模型")
        self._distill_to_student(task_data)

        # ===== 步骤3:保存回放数据 =====
        self.replay_buffer.extend(
            random.sample(task_data, k=min(1000, len(task_data)))
        )
        print(f"回放缓冲区大小: {len(self.replay_buffer)}")

    def _train_teacher(self, new_data):
        """训练教师模型(带回放)"""
        optimizer = torch.optim.Adam(self.teacher.parameters(), lr=1e-5)

        for epoch in range(5):
            # 混合新数据和回放数据
            if self.replay_buffer:
                replay_sample = random.sample(
                    self.replay_buffer,
                    k=min(len(new_data), len(self.replay_buffer))
                )
                mixed_data = list(new_data) + replay_sample
            else:
                mixed_data = new_data

            random.shuffle(mixed_data)

            for inputs, labels in DataLoader(mixed_data, batch_size=32):
                outputs = self.teacher(inputs)
                loss = F.cross_entropy(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _distill_to_student(self, new_data):
        """蒸馏到学生模型(带回放)"""
        optimizer = torch.optim.Adam(self.student.parameters(), lr=1e-5)

        # 混合数据
        if self.replay_buffer:
            all_data = list(new_data) + self.replay_buffer
        else:
            all_data = new_data

        for epoch in range(3):
            for inputs, labels in DataLoader(all_data, batch_size=32):
                # 教师预测
                with torch.no_grad():
                    teacher_logits = self.teacher(inputs)

                # 学生预测
                student_logits = self.student(inputs)

                # 蒸馏损失
                loss = self._kd_loss(
                    student_logits, teacher_logits, labels
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _kd_loss(self, student_logits, teacher_logits, labels):
        """知识蒸馏损失"""
        # 软标签损失
        student_soft = F.log_softmax(
            student_logits / self.temperature, dim=1
        )
        teacher_soft = F.softmax(
            teacher_logits / self.temperature, dim=1
        )
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')
        soft_loss *= (self.temperature ** 2)

        # 硬标签损失
        hard_loss = F.cross_entropy(student_logits, labels)

        # 组合
        return 0.7 * soft_loss + 0.3 * hard_loss

# 使用
teacher = LargeModel()
student = SmallModel()

progressive = ProgressiveDistillation(teacher, student, temperature=3.0)

# 逐步学习多个任务
for task_id, task_data in enumerate(all_tasks):
    progressive.learn_task(task_data, task_id)

    # 评估
    print(f"任务 {task_id} 完成后的性能:")
    for eval_id in range(task_id + 1):
        acc = evaluate(student, eval_tasks[eval_id])
        print(f"  任务 {eval_id}: {acc:.2%}")

效果

任务 标准微调 蒸馏(无回放) 蒸馏+回放
任务A 89% 81% 87% ✅
任务B 62% (-27%) 78% (-3%) 84% ✅
任务C 48% (-41%) 76% (-5%) 82% ✅
平均 66.3% 78.3% 84.3% ✅✅

方法2:自蒸馏回放(Self-Distillation Replay)

原理

模型作为自己的教师

css 复制代码
不需要单独的大模型教师
  ↓
模型A(任务A训练后)→ 作为教师
  ↓
模型B(学习任务B)← 从模型A蒸馏 + 学习新任务
  ↓
模型B → 作为教师
  ↓
模型C(学习任务C)← 从模型B蒸馏 + 学习新任务

实现

python 复制代码
class SelfDistillationReplay:
    def __init__(self, model, temperature=2.0, alpha=0.5):
        self.model = model
        self.temperature = temperature
        self.alpha = alpha  # 蒸馏权重
        self.old_model = None
        self.replay_inputs = []

    def learn_task(self, new_data):
        """学习新任务,通过自蒸馏保持旧知识"""

        # 保存旧模型
        if self.old_model is not None:
            # 更新回放输入
            new_inputs = [x for x, _ in random.sample(new_data, k=100)]
            self.replay_inputs.extend(new_inputs)

        self.old_model = copy.deepcopy(self.model)
        self.old_model.eval()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for inputs, labels in DataLoader(new_data, batch_size=32):
                # ===== 新任务损失 =====
                outputs = self.model(inputs)
                loss_new = F.cross_entropy(outputs, labels)

                # ===== 自蒸馏损失 =====
                if self.old_model and self.replay_inputs:
                    # 采样回放输入
                    replay_batch = random.sample(
                        self.replay_inputs,
                        k=min(32, len(self.replay_inputs))
                    )
                    replay_batch = torch.stack(replay_batch)

                    # 当前模型输出
                    curr_logits = self.model(replay_batch)

                    # 旧模型输出
                    with torch.no_grad():
                        old_logits = self.old_model(replay_batch)

                    # 蒸馏损失
                    loss_distill = self._compute_kd_loss(
                        curr_logits, old_logits
                    )
                else:
                    loss_distill = 0

                # ===== 总损失 =====
                loss = (1 - self.alpha) * loss_new + self.alpha * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _compute_kd_loss(self, student_logits, teacher_logits):
        """计算KD损失"""
        student_soft = F.log_softmax(
            student_logits / self.temperature, dim=1
        )
        teacher_soft = F.softmax(
            teacher_logits / self.temperature, dim=1
        )
        return F.kl_div(
            student_soft, teacher_soft, reduction='batchmean'
        ) * (self.temperature ** 2)

# 使用
model = MyModel()
self_distill = SelfDistillationReplay(model, temperature=2.0, alpha=0.5)

# 持续学习
for task_data in all_tasks:
    self_distill.learn_task(task_data)

优势

  • 不需要单独的教师模型(节省内存)
  • 简单易实现
  • 适合资源受限场景

第五部分:实践建议与总结

选择指南

场景1:模型压缩

目标:将70B模型压缩到7B

推荐方案

markdown 复制代码
1. Logits蒸馏(首选)
   - 温度: T=3-5
   - alpha: 0.7-0.9
   - 训练: 2-3 epochs

2. 如果可以访问内部:特征蒸馏
   - 选择关键层(每隔4层选一个)
   - 特征损失权重: 0.5

3. 如果只有API:响应蒸馏
   - 收集大量高质量生成数据
   - 使用排序蒸馏提高效果

场景2:持续学习

目标:模型不断学习新任务,不遗忘旧任务

推荐方案

makefile 复制代码
任务数量少(2-5个):
  → 原始数据回放(效果最好)
  → 每个任务保存1000-5000样本

任务数量中等(5-20个):
  → 蒸馏回放(推荐)
  → 保存旧模型副本 + 少量输入样本

任务数量多(20+个):
  → 生成式回放 或 压缩回放
  → 定期合并模型

场景3:在线学习系统

目标:云端大模型 + 边缘小模型,持续更新

推荐方案

markdown 复制代码
逐步蒸馏 + 回放:
  1. 云端:大模型学习新任务(带回放)
  2. 蒸馏:定期蒸馏到小模型
  3. 部署:推送小模型到边缘设备

周期:每周或每月一次

超参数建议

知识蒸馏

参数 推荐值 说明
温度 T 2-5 越大越软,信息越丰富
alpha 0.7-0.9 软标签权重
学习率 1e-5 ~ 5e-5 比正常训练小
Epochs 2-5 不要过拟合

回放机制

参数 推荐值 说明
缓冲区大小 1000-10000/任务 取决于内存
回放比例 0.3-0.5 旧数据占比
回放权重 0.5-1.0 回放损失的权重

监控指标

蒸馏训练

python 复制代码
def monitor_distillation(teacher, student, val_data):
    metrics = {}

    # 1. 性能差距
    teacher_acc = evaluate(teacher, val_data)
    student_acc = evaluate(student, val_data)
    metrics['performance_gap'] = teacher_acc - student_acc
    print(f"性能差距: {metrics['performance_gap']:.2%}")
    # 期望: <10%

    # 2. 输出分布相似度
    kl_divs = []
    for inputs, _ in val_data:
        with torch.no_grad():
            t_logits = teacher(inputs)
            s_logits = student(inputs)
            kl = F.kl_div(
                F.log_softmax(s_logits, dim=1),
                F.softmax(t_logits, dim=1),
                reduction='batchmean'
            )
            kl_divs.append(kl.item())

    metrics['avg_kl'] = sum(kl_divs) / len(kl_divs)
    print(f"平均KL散度: {metrics['avg_kl']:.4f}")
    # 期望: <0.5

    return metrics

持续学习

python 复制代码
def monitor_continual_learning(model, all_task_data, current_task):
    """监控灾难性遗忘"""
    print(f"\n=== 完成任务 {current_task} 后的评估 ===")

    accuracies = []
    for task_id in range(current_task + 1):
        acc = evaluate(model, all_task_data[task_id])
        accuracies.append(acc)
        print(f"任务 {task_id}: {acc:.2%}")

    # 平均准确率
    avg_acc = sum(accuracies) / len(accuracies)
    print(f"平均准确率: {avg_acc:.2%}")

    # 遗忘程度(旧任务性能下降)
    if current_task > 0:
        old_tasks_acc = sum(accuracies[:-1]) / (len(accuracies) - 1)
        print(f"旧任务平均: {old_tasks_acc:.2%}")

        # 与初始训练时对比
        forgetting = initial_accs[current_task-1] - old_tasks_acc
        print(f"遗忘程度: {forgetting:.2%}")
        # 期望: <10%

常见问题与解决

问题1:蒸馏效果差

markdown 复制代码
症状:学生模型准确率比教师低很多(>15%)
原因:
  1. 温度太低 → 软标签不够软
  2. alpha太小 → 软标签权重不够
  3. 学生太小 → 容量不足

解决:
  1. 增大温度(T: 3→5)
  2. 增大alpha(0.7→0.9)
  3. 增大学生模型(如果可能)
  4. 使用特征蒸馏(如果是白盒)

问题2:持续学习仍然遗忘

markdown 复制代码
症状:即使用了回放,旧任务性能仍下降>20%
原因:
  1. 回放数据太少
  2. 回放权重太小
  3. 任务差异太大(分布偏移严重)

解决:
  1. 增大回放缓冲区(1000→5000)
  2. 增大回放权重(0.3→0.7)
  3. 使用更多回放策略(混合多种方法)
  4. 降低学习率(减缓参数变化)

问题3:训练时间过长

markdown 复制代码
症状:蒸馏或回放训练时间是正常的2倍以上
原因:
  1. 回放数据太多
  2. 教师模型推理慢
  3. 没有并行优化

解决:
  1. 减小回放比例(50%→30%)
  2. 缓存教师模型的输出(预计算)
  3. 使用混合精度训练(FP16)
  4. 批量蒸馏(一次蒸馏多个样本)

小结

核心要点

知识蒸馏:让小模型学习大模型的"思考过程"

  • Logits蒸馏:学习输出概率分布(最常用)
  • 特征蒸馏:学习中间层表示(效果更好)
  • 响应蒸馏:学习文本输出(黑盒场景)

灾难性遗忘:持续学习新任务时,模型会快速遗忘旧知识

  • 原因:参数共享、梯度冲突、分布偏移
  • 表现:旧任务性能急剧下降(20-40%)

回放机制:通过"复习"旧任务防止遗忘

  • 原始回放:保存真实数据(效果最好,但有隐私问题)
  • 生成式回放:用生成器重建数据(隐私友好)
  • 蒸馏回放:用旧模型的软标签(推荐,简单高效)

结合使用:蒸馏 + 回放 = 小而不忘的模型

  • 逐步蒸馏:每学习新任务就蒸馏一次
  • 自蒸馏回放:模型作为自己的教师

核心公式

知识蒸馏损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L KD = α ⋅ T 2 ⋅ KL ( softmax ( z S / T ) ∥ softmax ( z T / T ) ) + ( 1 − α ) ⋅ CE ( z S , y ) \mathcal{L}_{\text{KD}} = \alpha \cdot T^2 \cdot \text{KL}\left(\text{softmax}(z^S/T) \| \text{softmax}(z^T/T)\right) + (1-\alpha) \cdot \text{CE}(z^S, y) </math>LKD=α⋅T2⋅KL(softmax(zS/T)∥softmax(zT/T))+(1−α)⋅CE(zS,y)

蒸馏回放损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L total = L new + λ ⋅ KL ( π curr ( x old ) ∥ π old ( x old ) ) \mathcal{L}{\text{total}} = \mathcal{L}{\text{new}} + \lambda \cdot \text{KL}(\pi_{\text{curr}}(x_{\text{old}}) \| \pi_{\text{old}}(x_{\text{old}})) </math>Ltotal=Lnew+λ⋅KL(πcurr(xold)∥πold(xold))

实践检查清单

蒸馏前

  • 确定教师和学生模型架构
  • 选择蒸馏类型(白盒/黑盒)
  • 准备蒸馏数据(教师的输出)
  • 设置超参数(T, alpha)

持续学习前

  • 评估基线性能(每个任务单独训练)
  • 选择回放策略(数据/生成/蒸馏)
  • 分配回放缓冲区大小
  • 设计评估指标(监控遗忘)

训练中

  • 监控性能差距(学生 vs 教师)
  • 监控旧任务性能(防止遗忘)
  • 调整回放比例(平衡新旧任务)

训练后

  • 全面评估所有任务
  • 计算平均遗忘率
  • 对比压缩率(模型大小 vs 性能)

知识蒸馏让模型更小,回放机制让模型不忘 ------ 两者结合,打造高效、持续学习的AI系统!

相关推荐
Cosolar30 分钟前
AI Agent 的记忆战争:OpenClaw vs Hermes vs QwenPaw vs HiClaw,谁真正"记得住"?
人工智能·后端·面试
M ? A1 小时前
VuReact:Vue转React的增量编译利器
前端·vue.js·后端·react.js·面试·开源·vureact
kuntli1 小时前
思维树:让AI像人一样多路思考
aigc
aircrushin1 小时前
给宝宝办了个宴,朋友用trae做的工具帮了大忙
前端·后端
码上小翔哥1 小时前
Jackson 配置深度解析
java·后端
程序员Sunday1 小时前
爆肝万字!这应该是全网最全的 Codex 实战教程了
前端·后端·ai编程
aircrushin1 小时前
朋友用trae搭建的工具,解决了旅行拍照共享的大事儿
前端·后端
星栈1 小时前
把业务逻辑写成纯函数之后,我再也不想写 Service 层了
后端·开源
未秃头的程序猿1 小时前
如何用 AI 写出符合规范的 Java 代码?我总结了 7 条有效建议
java·后端·ai编程
阿聪谈架构1 小时前
第10章:Agent 记忆系统 —— 让 AI 真正"记住"你
人工智能·后端