将重排序大模型Qwen3-Reranker-8B的知识蒸馏到小模型BGE-reranker-v2-m3上

BGE-reranker-v2-m3 是一个很好用的重排序模型,在RAG(检索增强生成)中用于进一步优化检索出的文档。但是也存在一个痛点:用大模型合成、甚至人工标注 (query, positive, negative) 三元组数据用于训练微调,过程麻烦且成本较高。

最近,阿里云发布了Qwen3-reranker系列SOTA重排序模型。本文将分享一个低成本的优化方案:利用系列中最强的 Qwen3-Reranker-8B (教师模型),将其知识蒸馏到 0.6B 的 BGE-reranker-v2-m3 (学生模型)上。实验结果表明,通过该方法,学生模型在 stackoverflowdupquestions-reranking 数据集上的 MRR@10 指标提升幅度达到 19.96%。

项目代码已开源在 GitHub: github.com/kanhaoning/...

一、核心工具与方法

1.1 核心工具

  • 教师模型 : Qwen/Qwen3-Reranker-8B (8B参数)
  • 学生模型 : BAAI/bge-reranker-v2-m3 (0.6B参数)
  • 训练/评测框架 : sentence-transformers
  • 数据集 : MTEB/stackoverflowdupquestions-reranking
  • 推理加速框架 : vLLM

1.2 训练方法:MarginMSE知识蒸馏

训练的目标是让学生模型学会模仿教师模型对不同样本打分的差异,而不是直接学习教师模型打出来的分数。具体来说,就是让学生模型对于(查询,更相关文档)和(查询,不相关文档)这两个组合的相关性分数之差,尽可能地接近教师模型给出的分数差。

使用的损失函数是 MarginMSE,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L ( Q , P + , P − ) = MSE ( ( M s ( Q , P + ) − M s ( Q , P − ) ) , ( M t ( Q , P + ) − M t ( Q , P − ) ) ) L(Q, P_+, P_-) = \text{MSE}( (M_s(Q, P_+) - M_s(Q, P_-)), (M_t(Q, P_+) - M_t(Q, P_-)) ) </math>L(Q,P+,P−)=MSE((Ms(Q,P+)−Ms(Q,P−)),(Mt(Q,P+)−Mt(Q,P−)))

其中:

  • Q 是查询(Query)。
  • P+ 是相关性更高的文档(Positive Passage)。
  • P- 是相关性更低的文档(Negative Passage)。
  • Mt 和 Ms 分别是教师(teacher)模型和学生(student)模型计算的logit分数。
  • MSE 是均方误差(Mean Squared Error)。

这种方法不要求学生模型完全复现教师模型的分数,只要求它学会区分"好"与"更好"的差异,这使学生模型可以学习与自身打分差异较大的教师模型的同时而在保留自身打分的特性。相比于SFT,知识蒸馏不需要人工标注,尤其适合在有大量垂直领域的原始数据,但是没有高质量的标注的场景。

二、环境准备

首先,确保你已安装所有必要的库。

我在实验中使用的主要库版本如下:

diff 复制代码
Package                 Version
----------------------- ------------------------
torch                   2.6.0
sentence-transformers   5.0.0
transformers            4.53.1
vllm                    0.8.4

如果尚未安装,可以使用 pip 命令安装。

bash 复制代码
pip install torch sentence-transformers==5.0.0 transformers==4.53.1 vllm==0.8.4
# modelscope 用于方便地下载国内模型
pip install modelscope

三、复现步骤

整个流程分为四步:生成教师分数 -> 构建训练数据 -> 训练学生模型 -> 性能评测。我已经将每一步都封装成了脚本,你只需要按顺序执行即可。

步骤 1:下载模型与数据集

首先,我们需要把教师模型、学生模型和数据集准备好。

1. 下载 BGE-reranker-v2-m3 (学生模型)

python 复制代码
from modelscope import snapshot_download
model_dir = snapshot_download('BAAI/bge-reranker-v2-m3', cache_dir='/path/to/your/models')

2. 下载 Qwen3-Reranker-8B (教师模型)

python 复制代码
from modelscope import snapshot_download
model_dir = snapshot_download('Qwen/Qwen3-Reranker-8B', cache_dir='/path/to/your/models')

3. 下载 Stack Overflow 数据集 访问 MTEB/stackoverflowdupquestions-reranking,手动下载 train.jsonl.gztest.jsonl.gz,然后解压。

bash 复制代码
gunzip train.jsonl.gz
gunzip test.jsonl.gz

将解压后的 train.jsonltest.jsonl 文件放到你的项目目录下。

步骤 2:生成教师模型 Logits 分数

这一步,我们的目标是用更强大的教师模型 Qwen3-Reranker-8B 为数据集中的每一个 (query, passage) 对计算一个logit分数。由于数据集很大(训练集有约60万个query-passage pair),直接用 transformers 跑会非常慢。为了大幅提高效率,我们采用 vLLM 框架进行推理加速。

执行脚本:

bash 复制代码
bash generate_logits.sh

这个脚本会调用 generate_logits.py。在运行前,请修改脚本内的 --model_path ,使其指向你下载好的 Qwen3-Reranker-8B 模型路径。

实现原理

Qwen3-Reranker 将"重排序"任务转化为了一个"生成"任务。输入判断查询(Query)和文档(Document)是否相关的提示词后,预测下一个词是 "yes" 还是 "no",以此来判断文档与查询的相关性。

官方给出的相关性分数计算公式是基于Softmax的概率:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> score ( q , d ) = e l o g i t ( yes ∣ I , q , d ) e l o g i t ( yes ∣ I , q , d ) + e l o g i t ( no ∣ I , q , d ) \text{score}(q, d) = \frac{e^{logit(\text{yes}|I,q,d)}} {e^{logit(\text{yes}|I,q,d)} + e^{logit(\text{no}|I,q,d)}} </math>score(q,d)=elogit(yes∣I,q,d)+elogit(no∣I,q,d)elogit(yes∣I,q,d)

这个公式将模型的输出转换为一个 01 之间的概率值。

在知识蒸馏(knowledge distillation)的场景中,我们需要 Qwen3-Reranker 这种 decoder-only 架构模型提供类似于 cross-encoder 架构中的等效 logit 值。为此,需要对原始的概率得分进行反 sigmoid 变换,即:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logit = log ⁡ ( score 1 − score ) \text{logit} = \log\left(\frac{\text{score}}{1 - \text{score}}\right) </math>logit=log(1−scorescore)

最终,用于知识蒸馏的 logit 值可以通过以下公式获得:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> logit = log ⁡ P ( yes ∣ I , q , d ) − log ⁡ P ( no ∣ I , q , d ) \text{logit} = \log P(\text{yes} \mid I, q, d) - \log P(\text{no} \mid I, q, d) </math>logit=logP(yes∣I,q,d)−logP(no∣I,q,d)

其中,P(yes) 和 P(no) 分别表示模型生成"yes"和"no"的条件概率。得益于 vLLM 提供的接口,我们可以高效地获取每个 token 的对数概率值。接下来我们将基于 vLLM 框架实现上述核心逻辑代码。

代码实现

1. 构造Qwen3-Reranker模型输入

我们首先需要按照 Qwen3-Reranker 指定的模板,将 (query, passage) 对格式化为一段提示词。

python 复制代码
# 文件: generate_logits.py

def format_and_tokenize_inputs(
    tokenizer: AutoTokenizer,
    queries: List[str],
    docs: List[str],
    instruction: str,
    max_length: int
) -> List[TokensPrompt]:
    """使用 apply_chat_template 格式化并 tokenize 输入"""
    messages = []
    for query, doc in zip(queries, docs):
        # 这是模型要求的标准对话格式
        message = [
            {"role": "system", "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\"."},
            {"role": "user", "content": f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"}
        ]
        messages.append(message)
    
    # 使用 tokenizer 的模板功能,高效地将文本转换为 token IDs
    templated_messages = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True, # add_generation_prompt=True 会自动添加 assistant 角色的起始符
        enable_thinking=False
    )
    
    # 截断超长序列并转换为 vLLM 接受的 TokensPrompt 格式
    processed_messages = [ele[:max_length] for ele in templated_messages]
    final_messages = [TokensPrompt(prompt_token_ids=ele) for ele in processed_messages]
    return final_messages

2. 计算 Logit 分数

接收 vLLM 的推理结果,并从中提取 "yes" 和 "no" 的对数概率(logprobs),最后计算它们的差值。

python 复制代码
# 文件: generate_logits.py

def compute_scores_vllm(
    model: LLM,
    tokenizer: AutoTokenizer,
    sampling_params: SamplingParams,
    batch_queries: List[str],
    batch_docs: List[str],
    instruction: str,
    max_length: int
) -> List[float]:
    """计算分数的函数"""
    # 获取 'yes' 和 'no' 两个词对应的 token ID
    true_token = tokenizer("yes", add_special_tokens=False).input_ids[0]
    false_token = tokenizer("no", add_special_tokens=False).input_ids[0]

    # 1. 格式化输入
    tokenized_batch = format_and_tokenize_inputs(tokenizer, batch_queries, batch_docs, instruction, max_length)
    # 2. 使用 vLLM 并行推理
    outputs = model.generate(tokenized_batch, sampling_params=sampling_params, use_tqdm=False)

    scores = []
    for output in outputs:
        # 3. 从推理结果中提取 logprobs
        # 我们只关心生成的第一个 token,所以取 logprobs[-1]
        final_logprobs = output.outputs[0].logprobs[-1]

        # 4. 获取 'yes' 和 'no' 的对数概率,如果某个词不存在于 top logprobs 中,给一个很小的默认值
        true_logprob = final_logprobs.get(true_token, -10.0)
        if not isinstance(true_logprob, float): # vLLM 可能返回 Logprob 对象
            true_logprob = true_logprob.logprob

        false_logprob = final_logprobs.get(false_token, -10.0)
        if not isinstance(false_logprob, float):
            false_logprob = false_logprob.logprob

        # 5. 核心:计算yes和no的对数概率差值作为Logit
        logit_diff = true_logprob - false_logprob
        scores.append(logit_diff)
        
    return scores

脚本与产出

generate_logits.sh 脚本负责调用上述 Python 代码,并传入必要的参数,如模型路径、输入文件名和批处理大小。

bash 复制代码
# generate_logits.sh 内容
#!/bin/bash

# 使用 vLLM + Qwen3-Reranker-8B 生成训练/测试数据的 logits 分数
python generate_logits.py \
  --model_path your_path_to/Qwen3-Reranker-8B \
  --input_files train.jsonl test.jsonl \
  --output_suffix _distill_qwen3_8b_vLLMlogit \
  --batch_size 8 \
  --max_model_len 8192 \
  --gpu_memory_utilization 0.9 \
  --task_instruction "Given a web search query, retrieve relevant passages that answer the query"

脚本运行成功后,将生成两个新的jsonl文件,分别对应构建训练集、测试集所需的logit分数:

  • train_distill_qwen3_8b_vLLMlogit.jsonl
  • test_distill_qwen3_8b_vLLMlogit.jsonl

每一行是一个(query, passage, score)pair,以下是一个具体例子:

json 复制代码
{"query": "String isNullOrEmpty in Java?", "passage": "Java equivalent of c# String.IsNullOrEmpty() and String.IsNullOrWhiteSpace()", "score": 0.875}

步骤 3:构建训练样本

接下来,我们需要将上一步生成的Logit分数文件,转换为 MarginMSE 损失函数需要的三元组格式 (query, positive, negative, score_diff)

执行脚本:

bash 复制代码
bash create_triplets.sh

该脚本会调用 create_triplets.py,它会为每个 query 下的高分 passage(正例)匹配若干个低分 passage(负例),并计算它们的分数差。这一步是由Gemini生成的采样方法,不一定是最优解。

产出: 此步骤会生成最终的训练和评估文件:

  • train_distill_qwen3_8b_vLLMlogit_margin_sampled.jsonl
  • test_distill_qwen3_8b_vLLMlogit_margin_sampled.jsonl

以下是一条数据的具体例子:

json 复制代码
{"query": "String isNullOrEmpty in Java?", "positive": "Java equivalent of c# String.IsNullOrEmpty() and String.IsNullOrWhiteSpace()", "negative": "isLocalHost(String hostNameOrIpAddress) in Java", "score": 6.231092929840088}

步骤 4:训练学生模型

现在我们开始训练(蒸馏)学生模型 bge-reranker-v2-m3

执行脚本:

bash 复制代码
bash train.sh

此脚本会调用 train.py,使用 sentence-transformers 框架提供的 MarginMSELoss 来进行微调。

关键参数说明:

  • --model_name_or_path: 确保指向原始的 bge-reranker-v2-m3 模型。
  • --per_device_train_batch_size: 根据你的 GPU 显存大小调整。
  • --nproc_per_node: 使用的 GPU 数量。

训练过程日志会显示 eval_loss,我们可以依据此指标来保存最佳模型。

步骤 5:性能评测与对比

训练完成后,我们评测一下效果

执行脚本:

bash 复制代码
bash evaluate.sh

该脚本会调用 evaluation.py,分别评估蒸馏前蒸馏后的模型在测试集上的性能,并清晰地展示对比结果。

关键参数说明:

  • --model_before_path: 指向原始 bge-reranker-v2-m3 模型。
  • --model_after_path: 指向 train.sh 训练产出的模型 checkpoint 路径(例如 output/checkpoint-1217)。

四、结果分析

经过蒸馏,bge-reranker-v2-m3 模型在 stackoverflowdupquestions-reranking 测试集上的各项重排指标都获得了明显提升

指标 (Metric) 蒸馏前 (Before) 蒸馏后 (After) 绝对提升 相对提升
MAP 0.472061 0.565317 +0.093256 +19.76% 🚀
MRR@10 0.478234 0.573779 +0.095545 +19.98% 🚀
NDCG@10 0.547284 0.639033 +0.091748 +16.76% 🚀

从上表可以看出所有核心评估指标(MAP, MRR, NDCG)均有16-20%的增长。这表明在这个场景有效将知识从大模型蒸馏到小模型上(但是整体分数还是较低,可能这个数据集比较有难度)

五、参考文献

Qwen3 Embedding: Advancing Text Embedding and Reranking Through Foundation Models

Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation


如果这篇文章对你有帮助,请给我的Github项目点个star吧:github.com/kanhaoning/...


本文首发于知乎平台,原文链接:zhuanlan.zhihu.com/p/192822324...

相关推荐
没学上了11 分钟前
Qt轮廓分析设计+算法+避坑
算法
用户9704438781161 小时前
taobao商品详情数据获取实战方法
算法·html
yu2024111 小时前
【【异世界历险之数据结构世界(二叉树)】】
数据结构·算法
补三补四2 小时前
RNN(循环神经网络)
人工智能·rnn·深度学习·神经网络·算法
逐闲3 小时前
LeetCode热题100【第一天】
算法·leetcode
爱吃涮毛肚的肥肥(暂时吃不了版)4 小时前
剑指offer——模拟:顺时针打印矩阵
算法·leetcode·矩阵
chao_7894 小时前
动态规划题解——乘积最大子数组【LeetCode】
python·算法·leetcode·动态规划
今天背单词了吗9804 小时前
算法学习笔记:16.哈希算法 ——从原理到实战,涵盖 LeetCode 与考研 408 例题
笔记·学习·算法
前端拿破轮5 小时前
字符串消消乐你会吗?😋😋😋
算法·leetcode·面试
EndingCoder5 小时前
图算法在前端的复杂交互
前端·算法·图算法