BGE-m3 是开源社区下载量最高的向量模型之一,在RAG(检索增强生成)中用于检索出文档。通常微调的一个难点是需要构造质量比较高的数据集,如果合成数据或者人工标注不够准确,过于绝对的监督信号会放大这种不准确,造成过拟合、灾难性遗忘。而用一个更强大的教师模型生成软标签进行知识蒸馏,可以提供更加精细的监督信号来缓解这种问题,值得当你的数据质量有限时可以尝试一下。
近期,阿里云发布了Qwen3-Embedding系列SOTA向量模型。本文利用系列中最强的 Qwen3-Embedding-8B (教师模型),将其知识蒸馏到 0.6B 的 BGE-m3 (学生模型)上。实验结果表明,通过该方法,学生模型在 scidocs-reranking
数据集上的 MAP@10 指标提升幅度达到 10.20%,领域外下降幅度低于2.5%。
项目代码已开源在 GitHub: github.com/kanhaoning/...
一、核心工具与方法
1.1 核心工具
- 教师模型 :
Qwen/Qwen3-Embedding-8B
(8B参数) - 学生模型 :
BAAI/bge-m3
(0.6B参数) - 训练/评测框架 :
sentence-transformers
- 数据集 :
MTEB/scidocs-reranking
,MTEB/Stackoverflowdupquestions-reranking
- 推理加速框架 :
vLLM
1.2 训练方法:基于KL散度的知识蒸馏
在本次实践采用 sentence-transformers
库中更适合排序蒸馏任务的 DistillKLDivLoss
。
这种方法的核心思想是:学生模型学习通过教师模型所计算出的查询与文档相似度的完整概率分布。这能更精细地传递教师模型的"排序偏好"。
其损失函数是教师模型概率分布 P_t
与学生模型概率分布 P_s
之间的 KL散度 (Kullback-Leibler Divergence)。
具体步骤如下:
- 对于一个查询
Q
和一组文档{P_1, P_2, ..., P_n}
(包含一个正样本和多个负样本),教师模型和学生模型分别计算它们的相似度分数。 - 使用带有温度系数
τ
的Softmax
函数,将这些分数转换成概率分布。温度τ
可以平滑概率分布,让教师模型不那么"绝对",从而为学生模型提供更丰富的学习信号。 - 计算两个概率分布之间的KL散度作为损失。这里乘一个 <math xmlns="http://www.w3.org/1998/Math/MathML"> τ 2 \tau^2 </math>τ2是为了当Softmax于平滑时补偿梯度,使得梯度保持稳定。
公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( Q , { P i } ) = τ 2 ⋅ D K L ( P t ∣ ∣ P s ) = τ 2 ⋅ ∑ i P t ( Q , P i ) ⋅ log ( P t ( Q , P i ) P s ( Q , P i ) ) L(Q, \{P_i\}) = \tau^2 \cdot D_{KL}(P_t || P_s) = \tau^2 \cdot \sum_{i} P_t(Q, P_i) \cdot \log\left(\frac{P_t(Q, P_i)}{P_s(Q, P_i)}\right) </math>L(Q,{Pi})=τ2⋅DKL(Pt∣∣Ps)=τ2⋅∑iPt(Q,Pi)⋅log(Ps(Q,Pi)Pt(Q,Pi))
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> P t ( Q , P i ) = softmax ( M t ( Q , P i ) τ ) = exp ( M t ( Q , P i ) / τ ) ∑ j exp ( M t ( Q , P j ) / τ ) P_t(Q, P_i) = \text{softmax}\left(\frac{M_t(Q, P_i)}{\tau}\right) = \frac{\exp(M_t(Q, P_i) / \tau)}{\sum_{j} \exp(M_t(Q, P_j) / \tau)} </math>Pt(Q,Pi)=softmax(τMt(Q,Pi))=∑jexp(Mt(Q,Pj)/τ)exp(Mt(Q,Pi)/τ)
- <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( Q , P i ) = softmax ( M s ( Q , P i ) τ ) = exp ( M s ( Q , P i ) / τ ) ∑ j exp ( M s ( Q , P j ) / τ ) P_s(Q, P_i) = \text{softmax}\left(\frac{M_s(Q, P_i)}{\tau}\right) = \frac{\exp(M_s(Q, P_i) / \tau)}{\sum_{j} \exp(M_s(Q, P_j) / \tau)} </math>Ps(Q,Pi)=softmax(τMs(Q,Pi))=∑jexp(Ms(Q,Pj)/τ)exp(Ms(Q,Pi)/τ)
源码中为了避免数值下溢出, <math xmlns="http://www.w3.org/1998/Math/MathML"> P s ( Q , P i ) P_s(Q, P_i) </math>Ps(Q,Pi)是使用更稳定的log_softmax实现的。这里各个符号的含义如下:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> Q Q </math>Q 是查询(Query)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> { P i } \{P_i\} </math>{Pi} 是一组候选文档,包含一个正样本(Positive)和多个负样本(Negative)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> M t M_t </math>Mt 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> M s M_s </math>Ms 分别是教师(teacher)模型和学生(student)模型计算的相似度分数。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> τ \tau </math>τ 是温度系数(temperature),用于软化概率分布。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> D K L D_{KL} </math>DKL 代表 KL 散度。
二、环境准备
首先,确保你已安装所有必要的库。
我在实验中使用的主要库版本如下:
diff
Package Version
----------------------- ------------------------
torch 2.6.0
sentence-transformers 5.0.0
transformers 4.53.1
vllm 0.8.4
如果尚未安装,可以使用 pip
命令安装。
bash
pip install torch sentence-transformers==5.0.0 transformers==4.53.1 vllm==0.8.4
# modelscope 用于方便地下载国内模型
pip install modelscope
三、复现步骤
整个流程分为四步:生成教师分数 -> 构建训练数据 -> 训练学生模型 -> 性能评测。我已经将每一步都封装成了脚本,你只需要按顺序执行即可。
步骤 1:下载模型与数据集
首先,我们需要把教师模型、学生模型和数据集准备好。
1. 下载 BGE-m3 (学生模型)
python
from modelscope import snapshot_download
model_dir = snapshot_download('BAAI/bge-m3', cache_dir='/path/to/your/models')
2. 下载 Qwen3-Embedding-8B (教师模型)
python
from modelscope import snapshot_download
model_dir = snapshot_download('Qwen/Qwen3-Embedding-8B', cache_dir='/path/to/your/models')
3. 下载 Scidocs 数据集
注意这里直接拉取数据集可能会报错,建议手动下载:
访问 MTEB/scidocs-reranking,手动下载validation.jsonl.gz
和test.jsonl.gz
(均为3.5MB)到路径Embedding-Distillation/dataset_scidocs
,然后执行以下代码解压:
bash
gunzip validation.jsonl.gz
gunzip test.jsonl.gz
解压后Embedding-Distillation/dataset_scidocs目录下应该有 validation.jsonl
和 test.jsonl
两个文件。 为了降低复现成本,将使用这里的validation.jsonl为训练集,test.json为领域内测试集。
3. 下载 Stackoverflowdupquestions 数据集
访问 MTEB/stackoverflowdupquestions-reranking,手动下载test.jsonl.gz
(1.35MB)到路径Embedding-Distillation/dataset_stackoverflowdupquestions
,然后执行以下代码解压:
bash
gunzip test.jsonl.gz
解压后Embedding-Distillation/dataset_stackoverflowdupquestions目录下应该有test.jsonl
文件。 为了测试灾难性遗忘,将使用这里的test.jsonl为领域外测试集。
步骤 2:生成蒸馏数据集
这一步,我们的目标是用更强大的教师模型 Qwen3-Embedding-8B
为数据集中的每一个 (query, passage)
对计算一个相似度分数。由于数据集很大(训练集有约8.8万个passage),直接用 transformers
跑会非常慢。为了大幅提高速度,我们采用 vLLM
框架进行推理加速。
执行脚本:
bash
bash generate_distillation_data.sh
这个脚本会调用 generate_distillation_data.py
。在运行前,请修改脚本内的 --teacher_model_path
,使其指向你下载好的 Qwen3-Embedding-8B
模型路径。
实现思路
generate_distillation_data.py
的逻辑清晰高效,我们可以将其拆解为以下四步:
1. 读取与展开数据 脚本首先读取原始的 validation.jsonl
文件。文件中的每一行都包含一个 query
、positive
列表和 negative
列表。为了给 DistillKLDivLoss
准备输入,脚本会使用 itertools.product
(笛卡尔积)将positive和negative文档两两组合获取大量的 (query, positive_doc, negative_doc)
三元组。
python
# 关键代码片段 1: 使用笛卡尔积生成三元组
from itertools import product
for pos_item, neg_item in product(positives, negatives):
unique_texts.add(pos_item)
unique_texts.add(neg_item)
triplets.append({'query': query, 'positive': pos_item, 'negative': neg_item})
2. 批量向量化 为了最大化效率,脚本会收集所有不重复的文本(包括所有 query、positive 和 negative),然后用 vLLM
的 model.embed()
方法一次性将它们全部向量化。另外需要注意的是,Qwen3-Embedding官方代码会给Query加上一句任务的Instruct,比如Given a web search query, retrieve relevant passages that answer the query
,但是实测会降低蒸馏效果,所以没有使用。
python
# 关键代码片段 2: 使用 vLLM 进行高效批量编码
input_texts = list(unique_texts)
outputs = model.embed(input_texts)
all_embeddings = [torch.tensor(o.outputs.embedding) for o in outputs]
3. 计算教师分数 获得所有文本的向量后,脚本会为每一个三元组计算教师模型给出的余弦相似度分数。通常对于标准化的向量,余弦相似度可以使用更简单的内积等效实现。
python
# 关键代码片段 3: 计算(query, pos)和(query, neg)的相似度
emb_q = text_to_embedding.get(q_text)
emb_p = text_to_embedding.get(p_text)
emb_n = text_to_embedding.get(n_text)
sim_pos = similarity(emb_q, emb_p)
sim_neg = similarity(emb_q, emb_n)
4. 生成蒸馏文件 最后,脚本将每个三元组和它对应的两个相似度分数打包成一个新的 JSON 对象,写入到输出文件中。这个格式正是 DistillKLDivLoss
所需要的。
python
# 关键代码片段 4: 构建最终的输出记录
record = {
"query": q_text,
"positive": p_text,
"negative": n_text,
"label": [sim_pos, sim_neg] # "软标签"
}
f_out.write(json.dumps(record, ensure_ascii=False) + '\n')
注意这里一个样本只有(query, positive, negative)
三元组,将公式中的候选文档集合 {P_i}
简化为了 {positive, negative}
两个文档。教师模型生成的 label
即 [M_t(Q, P_positive), M_t(Q, P_negative)]
。实践时不用拘泥于只用单个negative的这种格式,sentence-transformers实现的DistillKLDivLoss也是支持一个样本有多个negative字段的,比如{"query": "q_text", "positive": "p_text", "negative1": "n_text1", "negative2": "n_text2", "label": [0.6, 0.3, 0.1]},但是这样会增大调参的难度。
最终会在Embedding-Distillation/dataset_scidocs
路径下生成输入模型训练的jsonl文件validation_kldiv_distill.jsonl
,这是一个样本例子:
JSON
{"query": "Beauty eMakeup: A Deep Makeup Transfer System", "positive": "Learning Hierarchical Features for Scene Labeling", "negative": "Registration with the Point Cloud Library: A Modular Framework for Aligning in 3-D", "label": [0.6058026552200317, 0.5828931331634521]}
步骤 3:训练学生模型
我们现在开始训练学生模型 BAAI/bge-m3
。
执行脚本:
bash
bash train.sh
这个命令会执行 train.py
脚本。在运行前,请务必检查并修改 train.sh
中的几个关键路径和参数:
--student_model_name_or_path
: 将your_path_to/bge-m3
修改为你下载好的bge-m3
模型路径。--train_dataset_path
: 确认它指向步骤2生成的validation_kldiv_distill.jsonl
文件。--output_dir
: 指定你想要保存训练模型检查点的目录。
该脚本使用 sentence-transformers
库,其核心逻辑主要有以下几步:
实现思路
1. 加载学生模型和数据集 脚本首先加载预训练的学生模型 bge-m3
。这是一个 SentenceTransformer
(双编码器)模型,它能将文本高效地映射到向量空间。接着,它加载我们在上一步生成的 .jsonl
蒸馏数据集。
python
# 文件: train.py
# 1. 加载学生模型
student_model = SentenceTransformer(model_args.student_model_name_or_path)
# 2. 加载蒸馏数据集
# 每条样本包含 query, positive, negative 和 label 字段
train_dataset = load_dataset("json", data_files=model_args.train_dataset_path)["train"]
2. 定义 DistillKLDivLoss
损失函数 该损失函数执行的一些细节如下:
- 它接收学生模型
model
和温度temperature
作为参数。 - 在训练时,它会使用学生模型计算
(query, positive)
和(query, negative)
的相似度分数。 - 然后,它将这些学生分数和数据集中由教师模型提供的
label
(即[teacher_pos_score, teacher_neg_score]
)都通过 Softmax 函数转换为概率分布。 - 最后,它计算这两个概率分布之间的 KL散度,并将其作为损失进行反向传播,从而指导学生模型去模仿教师模型的打分分布。
使用代码如下:
python
# 文件: train.py
from sentence_transformers import losses
# 定义KL散度蒸馏损失
train_loss = losses.DistillKLDivLoss(
model=student_model,
temperature=model_args.temperature # 温度参数,默认为2.0
)
3. 初始化并运行 SentenceTransformerTrainer
最后,我们将所有组件,包括学生模型、训练参数、蒸馏数据集和损失函数都传递给 SentenceTransformerTrainer
。这个trainer都实现了各种训练细节,如批处理、梯度累积、学习率调度、日志记录和模型保存。
python
# 文件: train.py
from sentence_transformers.trainer import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
model=student_model,
args=training_args,
train_dataset=train_dataset,
loss=train_loss,
)
# 启动训练
trainer.train()
关键参数说明:
--per_device_train_batch_size
和--gradient_accumulation_steps
:这两个参数共同决定了你的有效批处理大小 (比如16 * 32 = 512
)。这是一个重要的超参数,需要根据显存和实验效果进行调整。--learning_rate
: 学习率,决定了模型参数更新的步长。--bf16
: 启用 bfloat16 混合精度训练,相比fp16
可以提高训练的数值稳定性,但是可能在一些Compute Capability低于8的旧GPU上不支持(比如V100\2080Ti等)。
训练开始后,你可以通过日志实时监控训练进度。训练完成后,最终模型和检查点将保存在 --output_dir
指定的目录中,用于下一步性能评测。
步骤 4:性能评测
训练完成后,让我们来验证以下两个问题:
- 学生模型在领域内(
scidocs
)上的性能提升了多少? - 学生模型在领域领域外,比如(
stackoverflowdupquestions
)有没有大幅下降?
现在使用蒸馏前后的 bge-m3
模型,分别对领域内和领域外两个测试集用sentence-transformers实现的RerankingEvaluator
进快速进行评测。
执行脚本:
bash
bash evaluation.sh
这个命令会调用 evaluation.py
脚本来执行完整的评测流程。在运行前,请务必修改 evaluation.sh
中的模型路径:
--model_before
: 保持其为原始的BAAI/bge-m3
模型路径。--model_after
: 将其修改为你上一步训练产出的 checkpoint 路径 ,例如output/checkpoint-1000
。--in_domain_dataset
: 确认指向dataset_scidocs/test.jsonl
。--out_domain_dataset
: 确认指向dataset_stackoverflowdupquestions/test.jsonl
。
结果
等待 evaluation.py
脚本运行完毕后,生成如下表格:
指标 | 蒸馏前 | 蒸馏后 | 绝对变化 | 相对变化(%) |
---|---|---|---|---|
领域内 (scidocs) | ||||
map | 0.7744 | 0.8534 | +0.0790 | +10.20 |
mrr@10 | 0.9321 | 0.9554 | +0.0233 | +2.50 |
ndcg@10 | 0.8296 | 0.8973 | +0.0676 | +8.15 |
领域外 (stackoverflow) | ||||
map | 0.5168 | 0.5040 | -0.0129 | -2.49 |
mrr@10 | 0.5240 | 0.5116 | -0.0124 | -2.37 |
ndcg@10 | 0.5904 | 0.5774 | -0.0129 | -2.19 |
(注:你的具体数值可能会因训练的随机性、超参数设置等有微小差异,但总体趋势应保持一致。)
从结果中我们可以得出两个结论:
-
领域内性能显著提升 :在目标任务
scidocs
数据集上,3个关键指标都获得了明显增长。其中 MAP 指标提升了 10.20% ,这证明了知识蒸馏的有效性。学生模型成功地从教师模型Qwen3-Embedding-8B
那里学到了更精细的排序知识,使其在专业领域的文档排序能力变得更强。 -
通用能力基本保持 :在领域外的
stackoverflowdupquestions
数据集上,模型的性能只有一定的下降(MAP指标下降 2.49%),但是不算灾难性遗忘。学生模型在提升特定领域能力的同时,基本保留了其原有的通用能力。
这个方法整体实测下来,用软标签的损失函数DistillKLDivLoss整体上比硬标签的MultipleNegativesRankingLoss、TripletLoss在保留通用能力和提升泛化能力上都有一定优势,但是也有翻车的时候,比如教师模型表现不佳的数据集。具体实践得结合数据数量、质量、教师模型的性能考虑方法选用。
五、参考文献
Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models
Distilling Dense Representations for Ranking using Tightly-Coupled Teachers
如果这篇文章对你有帮助,请给我的Github项目点个star吧:github.com/kanhaoning/...
本文首发于知乎平台,原文链接:zhuanlan.zhihu.com/p/193369385...