用KL散度将Qwen3-8B向量模型知识蒸馏给小模型BGE-m3

BGE-m3 是开源社区下载量最高的向量模型之一,在RAG(检索增强生成)中用于检索出文档。通常微调的一个难点是需要构造质量比较高的数据集,如果合成数据或者人工标注不够准确,过于绝对的监督信号会放大这种不准确,造成过拟合、灾难性遗忘。而用一个更强大的教师模型生成软标签进行知识蒸馏,可以提供更加精细的监督信号来缓解这种问题,值得当你的数据质量有限时可以尝试一下。

近期,阿里云发布了Qwen3-Embedding系列SOTA向量模型。本文利用系列中最强的 Qwen3-Embedding-8B (教师模型),将其知识蒸馏到 0.6B 的 BGE-m3 (学生模型)上。实验结果表明,通过该方法,学生模型在 scidocs-reranking 数据集上的 MAP@10 指标提升幅度达到 10.20%,领域外下降幅度低于2.5%。

项目代码已开源在 GitHub: github.com/kanhaoning/...

一、核心工具与方法

1.1 核心工具

  • 教师模型 : Qwen/Qwen3-Embedding-8B (8B参数)
  • 学生模型 : BAAI/bge-m3 (0.6B参数)
  • 训练/评测框架 : sentence-transformers
  • 数据集 : MTEB/scidocs-rerankingMTEB/Stackoverflowdupquestions-reranking
  • 推理加速框架 : vLLM

1.2 训练方法:基于KL散度的知识蒸馏

在本次实践采用 sentence-transformers 库中更适合排序蒸馏任务的 DistillKLDivLoss

这种方法的核心思想是:学生模型学习通过教师模型所计算出的查询与文档相似度的完整概率分布。这能更精细地传递教师模型的"排序偏好"。

其损失函数是教师模型概率分布 P_t 与学生模型概率分布 P_s 之间的 KL散度 (Kullback-Leibler Divergence)

具体步骤如下:

  1. 对于一个查询 Q 和一组文档 {P_1, P_2, ..., P_n}(包含一个正样本和多个负样本),教师模型和学生模型分别计算它们的相似度分数。
  2. 使用带有温度系数 τSoftmax 函数,将这些分数转换成概率分布。温度 τ 可以平滑概率分布,让教师模型不那么"绝对",从而为学生模型提供更丰富的学习信号。
  3. 计算两个概率分布之间的KL散度作为损失。这里乘一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ 2 \tau^2 </math>τ2是为了当Softmax于平滑时补偿梯度,使得梯度保持稳定。

公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( Q , { P i } ) = τ 2 ⋅ D K L ( P t ∣ ∣ P s ) = τ 2 ⋅ ∑ i P t ( Q , P i ) ⋅ log ⁡ ( P t ( Q , P i ) P s ( Q , P i ) ) L(Q, \{P_i\}) = \tau^2 \cdot D_{KL}(P_t || P_s) = \tau^2 \cdot \sum_{i} P_t(Q, P_i) \cdot \log\left(\frac{P_t(Q, P_i)}{P_s(Q, P_i)}\right) </math>L(Q,{Pi})=τ2⋅DKL(Pt∣∣Ps)=τ2⋅∑iPt(Q,Pi)⋅log(Ps(Q,Pi)Pt(Q,Pi))

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P t ( Q , P i ) = softmax ( M t ( Q , P i ) τ ) = exp ⁡ ( M t ( Q , P i ) / τ ) ∑ j exp ⁡ ( M t ( Q , P j ) / τ ) P_t(Q, P_i) = \text{softmax}\left(\frac{M_t(Q, P_i)}{\tau}\right) = \frac{\exp(M_t(Q, P_i) / \tau)}{\sum_{j} \exp(M_t(Q, P_j) / \tau)} </math>Pt(Q,Pi)=softmax(τMt(Q,Pi))=∑jexp(Mt(Q,Pj)/τ)exp(Mt(Q,Pi)/τ)
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( Q , P i ) = softmax ( M s ( Q , P i ) τ ) = exp ⁡ ( M s ( Q , P i ) / τ ) ∑ j exp ⁡ ( M s ( Q , P j ) / τ ) P_s(Q, P_i) = \text{softmax}\left(\frac{M_s(Q, P_i)}{\tau}\right) = \frac{\exp(M_s(Q, P_i) / \tau)}{\sum_{j} \exp(M_s(Q, P_j) / \tau)} </math>Ps(Q,Pi)=softmax(τMs(Q,Pi))=∑jexp(Ms(Q,Pj)/τ)exp(Ms(Q,Pi)/τ)

源码中为了避免数值下溢出, <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( Q , P i ) P_s(Q, P_i) </math>Ps(Q,Pi)是使用更稳定的log_softmax实现的。这里各个符号的含义如下:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 是查询(Query)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> { P i } \{P_i\} </math>{Pi} 是一组候选文档,包含一个正样本(Positive)和多个负样本(Negative)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> M t M_t </math>Mt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> M s M_s </math>Ms 分别是教师(teacher)模型和学生(student)模型计算的相似度分数。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ 是温度系数(temperature),用于软化概率分布。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L D_{KL} </math>DKL 代表 KL 散度。

二、环境准备

首先,确保你已安装所有必要的库。

我在实验中使用的主要库版本如下:

diff 复制代码
Package                 Version
----------------------- ------------------------
torch                   2.6.0
sentence-transformers   5.0.0
transformers            4.53.1
vllm                    0.8.4

如果尚未安装,可以使用 pip 命令安装。

bash 复制代码
pip install torch sentence-transformers==5.0.0 transformers==4.53.1 vllm==0.8.4
# modelscope 用于方便地下载国内模型
pip install modelscope

三、复现步骤

整个流程分为四步:生成教师分数 -> 构建训练数据 -> 训练学生模型 -> 性能评测。我已经将每一步都封装成了脚本,你只需要按顺序执行即可。

步骤 1:下载模型与数据集

首先,我们需要把教师模型、学生模型和数据集准备好。

1. 下载 BGE-m3 (学生模型)

python 复制代码
from modelscope import snapshot_download
model_dir = snapshot_download('BAAI/bge-m3', cache_dir='/path/to/your/models')

2. 下载 Qwen3-Embedding-8B (教师模型)

python 复制代码
from modelscope import snapshot_download
model_dir = snapshot_download('Qwen/Qwen3-Embedding-8B', cache_dir='/path/to/your/models')

3. 下载 Scidocs 数据集

注意这里直接拉取数据集可能会报错,建议手动下载:

访问 MTEB/scidocs-reranking,手动下载validation.jsonl.gztest.jsonl.gz(均为3.5MB)到路径Embedding-Distillation/dataset_scidocs,然后执行以下代码解压:

bash 复制代码
gunzip validation.jsonl.gz
gunzip test.jsonl.gz

解压后Embedding-Distillation/dataset_scidocs目录下应该有 validation.jsonltest.jsonl 两个文件。 为了降低复现成本,将使用这里的validation.jsonl为训练集,test.json为领域内测试集。

3. 下载 Stackoverflowdupquestions 数据集

访问 MTEB/stackoverflowdupquestions-reranking,手动下载test.jsonl.gz(1.35MB)到路径Embedding-Distillation/dataset_stackoverflowdupquestions,然后执行以下代码解压:

bash 复制代码
gunzip test.jsonl.gz

解压后Embedding-Distillation/dataset_stackoverflowdupquestions目录下应该有test.jsonl 文件。 为了测试灾难性遗忘,将使用这里的test.jsonl为领域外测试集。

步骤 2:生成蒸馏数据集

这一步,我们的目标是用更强大的教师模型 Qwen3-Embedding-8B 为数据集中的每一个 (query, passage) 对计算一个相似度分数。由于数据集很大(训练集有约8.8万个passage),直接用 transformers 跑会非常慢。为了大幅提高速度,我们采用 vLLM 框架进行推理加速。

执行脚本:

bash 复制代码
bash generate_distillation_data.sh

这个脚本会调用 generate_distillation_data.py。在运行前,请修改脚本内的 --teacher_model_path ,使其指向你下载好的 Qwen3-Embedding-8B 模型路径。

实现思路

generate_distillation_data.py 的逻辑清晰高效,我们可以将其拆解为以下四步:

1. 读取与展开数据 脚本首先读取原始的 validation.jsonl 文件。文件中的每一行都包含一个 querypositive 列表和 negative 列表。为了给 DistillKLDivLoss 准备输入,脚本会使用 itertools.product (笛卡尔积)将positive和negative文档两两组合获取大量的 (query, positive_doc, negative_doc) 三元组。

python 复制代码
# 关键代码片段 1: 使用笛卡尔积生成三元组
from itertools import product

for pos_item, neg_item in product(positives, negatives):
    unique_texts.add(pos_item)
    unique_texts.add(neg_item)
    triplets.append({'query': query, 'positive': pos_item, 'negative': neg_item})

2. 批量向量化 为了最大化效率,脚本会收集所有不重复的文本(包括所有 query、positive 和 negative),然后用 vLLMmodel.embed() 方法一次性将它们全部向量化。另外需要注意的是,Qwen3-Embedding官方代码会给Query加上一句任务的Instruct,比如Given a web search query, retrieve relevant passages that answer the query,但是实测会降低蒸馏效果,所以没有使用。

python 复制代码
# 关键代码片段 2: 使用 vLLM 进行高效批量编码
input_texts = list(unique_texts)
outputs = model.embed(input_texts)
all_embeddings = [torch.tensor(o.outputs.embedding) for o in outputs]

3. 计算教师分数 获得所有文本的向量后,脚本会为每一个三元组计算教师模型给出的余弦相似度分数。通常对于标准化的向量,余弦相似度可以使用更简单的内积等效实现。

python 复制代码
# 关键代码片段 3: 计算(query, pos)和(query, neg)的相似度
emb_q = text_to_embedding.get(q_text)
emb_p = text_to_embedding.get(p_text)
emb_n = text_to_embedding.get(n_text)

sim_pos = similarity(emb_q, emb_p)
sim_neg = similarity(emb_q, emb_n)

4. 生成蒸馏文件 最后,脚本将每个三元组和它对应的两个相似度分数打包成一个新的 JSON 对象,写入到输出文件中。这个格式正是 DistillKLDivLoss 所需要的。

python 复制代码
# 关键代码片段 4: 构建最终的输出记录
record = {
    "query": q_text,
    "positive": p_text,
    "negative": n_text,
    "label": [sim_pos, sim_neg]  # "软标签"
}
f_out.write(json.dumps(record, ensure_ascii=False) + '\n')

注意这里一个样本只有(query, positive, negative)三元组,将公式中的候选文档集合 {P_i} 简化为了 {positive, negative} 两个文档。教师模型生成的 label[M_t(Q, P_positive), M_t(Q, P_negative)]。实践时不用拘泥于只用单个negative的这种格式,sentence-transformers实现的DistillKLDivLoss也是支持一个样本有多个negative字段的,比如{"query": "q_text", "positive": "p_text", "negative1": "n_text1", "negative2": "n_text2", "label": [0.6, 0.3, 0.1]},但是这样会增大调参的难度。

最终会在Embedding-Distillation/dataset_scidocs路径下生成输入模型训练的jsonl文件validation_kldiv_distill.jsonl,这是一个样本例子:

JSON 复制代码
{"query": "Beauty eMakeup: A Deep Makeup Transfer System", "positive": "Learning Hierarchical Features for Scene Labeling", "negative": "Registration with the Point Cloud Library: A Modular Framework for Aligning in 3-D", "label": [0.6058026552200317, 0.5828931331634521]}

步骤 3:训练学生模型

我们现在开始训练学生模型 BAAI/bge-m3

执行脚本:

bash 复制代码
bash train.sh

这个命令会执行 train.py 脚本。在运行前,请务必检查并修改 train.sh 中的几个关键路径和参数:

  • --student_model_name_or_path: 将 your_path_to/bge-m3 修改为你下载好的 bge-m3 模型路径。
  • --train_dataset_path: 确认它指向步骤2生成的 validation_kldiv_distill.jsonl 文件。
  • --output_dir: 指定你想要保存训练模型检查点的目录。

该脚本使用 sentence-transformers 库,其核心逻辑主要有以下几步:

实现思路

1. 加载学生模型和数据集 脚本首先加载预训练的学生模型 bge-m3。这是一个 SentenceTransformer (双编码器)模型,它能将文本高效地映射到向量空间。接着,它加载我们在上一步生成的 .jsonl 蒸馏数据集。

python 复制代码
# 文件: train.py

# 1. 加载学生模型
student_model = SentenceTransformer(model_args.student_model_name_or_path)

# 2. 加载蒸馏数据集
# 每条样本包含 query, positive, negative 和 label 字段
train_dataset = load_dataset("json", data_files=model_args.train_dataset_path)["train"]

2. 定义 DistillKLDivLoss 损失函数 该损失函数执行的一些细节如下:

  • 它接收学生模型model和温度temperature作为参数。
  • 在训练时,它会使用学生模型计算 (query, positive)(query, negative) 的相似度分数。
  • 然后,它将这些学生分数和数据集中由教师模型提供的 label(即 [teacher_pos_score, teacher_neg_score])都通过 Softmax 函数转换为概率分布。
  • 最后,它计算这两个概率分布之间的 KL散度,并将其作为损失进行反向传播,从而指导学生模型去模仿教师模型的打分分布。

使用代码如下:

python 复制代码
# 文件: train.py

from sentence_transformers import losses

# 定义KL散度蒸馏损失
train_loss = losses.DistillKLDivLoss(
    model=student_model,
    temperature=model_args.temperature # 温度参数,默认为2.0
)

3. 初始化并运行 SentenceTransformerTrainer 最后,我们将所有组件,包括学生模型、训练参数、蒸馏数据集和损失函数都传递给 SentenceTransformerTrainer。这个trainer都实现了各种训练细节,如批处理、梯度累积、学习率调度、日志记录和模型保存。

python 复制代码
# 文件: train.py

from sentence_transformers.trainer import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    loss=train_loss,
)

# 启动训练
trainer.train()

关键参数说明:

  • --per_device_train_batch_size--gradient_accumulation_steps:这两个参数共同决定了你的有效批处理大小 (比如16 * 32 = 512)。这是一个重要的超参数,需要根据显存和实验效果进行调整。
  • --learning_rate: 学习率,决定了模型参数更新的步长。
  • --bf16: 启用 bfloat16 混合精度训练,相比fp16可以提高训练的数值稳定性,但是可能在一些Compute Capability低于8的旧GPU上不支持(比如V100\2080Ti等)。

训练开始后,你可以通过日志实时监控训练进度。训练完成后,最终模型和检查点将保存在 --output_dir 指定的目录中,用于下一步性能评测。

步骤 4:性能评测

训练完成后,让我们来验证以下两个问题:

  1. 学生模型在领域内(scidocs)上的性能提升了多少?
  2. 学生模型在领域领域外,比如(stackoverflowdupquestions)有没有大幅下降?

现在使用蒸馏前后的 bge-m3 模型,分别对领域内和领域外两个测试集用sentence-transformers实现的RerankingEvaluator进快速进行评测。

执行脚本:

bash 复制代码
bash evaluation.sh

这个命令会调用 evaluation.py 脚本来执行完整的评测流程。在运行前,请务必修改 evaluation.sh 中的模型路径:

  • --model_before: 保持其为原始的 BAAI/bge-m3 模型路径。
  • --model_after: 将其修改为你上一步训练产出的 checkpoint 路径 ,例如 output/checkpoint-1000
  • --in_domain_dataset: 确认指向 dataset_scidocs/test.jsonl
  • --out_domain_dataset: 确认指向 dataset_stackoverflowdupquestions/test.jsonl

结果

等待 evaluation.py 脚本运行完毕后,生成如下表格:

指标 蒸馏前 蒸馏后 绝对变化 相对变化(%)
领域内 (scidocs)
map 0.7744 0.8534 +0.0790 +10.20
mrr@10 0.9321 0.9554 +0.0233 +2.50
ndcg@10 0.8296 0.8973 +0.0676 +8.15
领域外 (stackoverflow)
map 0.5168 0.5040 -0.0129 -2.49
mrr@10 0.5240 0.5116 -0.0124 -2.37
ndcg@10 0.5904 0.5774 -0.0129 -2.19

(注:你的具体数值可能会因训练的随机性、超参数设置等有微小差异,但总体趋势应保持一致。)

从结果中我们可以得出两个结论:

  1. 领域内性能显著提升 :在目标任务 scidocs 数据集上,3个关键指标都获得了明显增长。其中 MAP 指标提升了 10.20% ,这证明了知识蒸馏的有效性。学生模型成功地从教师模型 Qwen3-Embedding-8B 那里学到了更精细的排序知识,使其在专业领域的文档排序能力变得更强。

  2. 通用能力基本保持 :在领域外的 stackoverflowdupquestions 数据集上,模型的性能只有一定的下降(MAP指标下降 2.49%),但是不算灾难性遗忘。学生模型在提升特定领域能力的同时,基本保留了其原有的通用能力。

这个方法整体实测下来,用软标签的损失函数DistillKLDivLoss整体上比硬标签的MultipleNegativesRankingLoss、TripletLoss在保留通用能力和提升泛化能力上都有一定优势,但是也有翻车的时候,比如教师模型表现不佳的数据集。具体实践得结合数据数量、质量、教师模型的性能考虑方法选用。

五、参考文献

Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models

Distilling Dense Representations for Ranking using Tightly-Coupled Teachers


如果这篇文章对你有帮助,请给我的Github项目点个star吧:github.com/kanhaoning/...


本文首发于知乎平台,原文链接:zhuanlan.zhihu.com/p/193369385...

相关推荐
爱代码的小黄人1 小时前
利用劳斯判据分析右半平面极点数量的方法研究
算法·机器学习·平面
今天也好累5 小时前
C 语言基础第16天:指针补充
java·c语言·数据结构·笔记·学习·算法
大千AI助手5 小时前
直接偏好优化(DPO):原理、演进与大模型对齐新范式
人工智能·神经网络·算法·机器学习·dpo·大模型对齐·直接偏好优化
徐小夕7 小时前
再也不怕看不懂 GitHub 代码!这款AI开源项目,一键生成交互架构图
前端·算法·github
SirLancelot18 小时前
数据结构-Set集合(一)Set集合介绍、优缺点
java·开发语言·数据结构·后端·算法·哈希算法·set
YouQian7728 小时前
label 拓扑排序
数据结构·算法
YouQian7728 小时前
(补题)小塔的饭
算法
歌者長門8 小时前
做题笔记:某大讯飞真题28道
java·数据结构·算法
是店小二呀8 小时前
【动态规划 | 多状态问题】动态规划求解多状态问题
算法·动态规划
竹子_239 小时前
《零基础入门AI:传统机器学习核心算法解析(KNN、模型调优与朴素贝叶斯)》
人工智能·算法·机器学习