检索模型cross-encoder笔记

文章目录

cross-encoder一种检索模型,和双路召回机制不一样,各有优缺点。
cross-encoder最大的特点就是会将query(问题)和document(候选文本)一起分析。
一般的流程是,双路召回先粗排,cross-encoder再精排。

计算句子对相似度

代码:

python 复制代码
from sentence_transformers import CrossEncoder

# 1. 加载预训练模型
# 这里使用微软开源的 MiniLM 模型,速度快且效果好
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 2. 准备句子对
# 假设我们要判断"查询"和"文档"的相关性
pairs = [
    ['如何重置路由器密码?', '忘记路由器管理密码了怎么办?'],  # 语义高度相关
    ['如何重置路由器密码?', '路由器指示灯闪烁的含义'],      # 语义相关性低
    ['如何重置路由器密码?', '最新的iPhone价格是多少?']     # 完全无关
]

# 3. 预测分数
scores = model.predict(pairs)

# 4. 输出结果
for pair, score in zip(pairs, scores):
    print(f"查询: {pair[0]}")
    print(f"文档: {pair[1]}")
    print(f"相关性分数: {score:.4f}")
    print("-" * 30)

输出结果:

bash 复制代码
查询: 如何重置路由器密码?
文档: 忘记路由器管理密码了怎么办?
相关性分数: 7.0266
------------------------------
查询: 如何重置路由器密码?
文档: 路由器指示灯闪烁的含义
相关性分数: 3.6690
------------------------------
查询: 如何重置路由器密码?
文档: 最新的iPhone价格是多少?
相关性分数: 3.6907

解读:

第一组的文档很明显更贴近查询,所以给了高分,和我们预期的效果一致。

搜索结果的"重排序"
python 复制代码
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

# --- 第一步:双编码器召回 (Bi-Encoder Retrieval) ---
# 模拟从海量数据库中快速检索
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
query = "如何治疗感冒?"

# 假设这是数据库中的文档片段
corpus = [
    "感冒通常由病毒引起,建议多休息多喝水。",
    "Python 是一种高级编程语言。",
    "治疗流感的特效药需要医生处方。",
    "感冒和流感的区别在于症状的严重程度。",
    "今天天气真不错,适合出去野餐。"
]

# 将文档转换为向量(实际项目中通常预先计算好)
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

# 计算余弦相似度并获取 Top 3 结果
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
top_results = torch.topk(cos_scores, k=3)

print(f"--- 双编码器召回结果 (Top 3) ---")
for score, idx in zip(top_results.values, top_results.indices):
    print(f"[{score:.4f}] {corpus[idx]}")

# --- 第二步:交叉编码器重排序 (Cross-Encoder Reranking) ---
# 将召回的 3 个结果交给 Cross-Encoder 进行精准判断
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 构建待重排序的句子对
rerank_pairs = [[query, corpus[idx]] for idx in top_results.indices]

# 获取精准分数
rerank_scores = reranker.predict(rerank_pairs)

# 将分数和原文重新组合并排序
final_results = sorted(zip(rerank_scores, [corpus[idx] for idx in top_results.indices]), reverse=True)

print(f"\n--- Cross-Encoder 重排序结果 ---")
for score, text in final_results:
    print(f"[{score:.4f}] {text}")

输出结果:

bash 复制代码
--- 双编码器召回结果 (Top 3) ---
[0.8146] 治疗流感的特效药需要医生处方。
[0.3674] 感冒和流感的区别在于症状的严重程度。
[0.3576] 感冒通常由病毒引起,建议多休息多喝水。

--- Cross-Encoder 重排序结果 ---
[7.2241] 治疗流感的特效药需要医生处方。
[4.5378] 感冒通常由病毒引起,建议多休息多喝水。
[4.3828] 感冒和流感的区别在于症状的严重程度。

解读:

双编码器可能会因为关键词匹配(如"感冒"和"流感")给出较高的初步分数,但 Cross-Encoder 能更精准地判断"治疗感冒"和"治疗流感"的区别,从而给出更符合用户意图的排序。

相关推荐
safedebug8 小时前
此服务器的证书无效 您可能正在连接到一个伪装
笔记·学习
做cv的小昊8 小时前
【TJU】应用统计学——第六周作业(3.3 两个正态总体参数的假设检验、3.4 非正态总体参数的假设检验、4.1 一元线性回归分析)
笔记·算法·数学建模·矩阵·回归·线性回归·学习方法
晨晖29 小时前
linux笔记6
linux·运维·笔记
古方路杰出青年9 小时前
学习笔记1:Python FastAPI极简后端API示例解析
笔记·后端·python·学习·fastapi
羊群智妍9 小时前
2026年GEO监测工具大全|免费AI搜索优化直接用
笔记
江湖人称小鱼哥19 小时前
Obsidian-Graphify-让你的笔记库自己长出知识图谱
笔记·知识图谱·obsidian·claude code·graphify·卡帕西
苦 涩21 小时前
考研408笔记之计算机网络(三)——数据链路层
笔记·计算机网络·考研408
三品吉他手会点灯21 小时前
STM32F103 学习笔记-21-串口通信(第4节)—串口发送和接收代码讲解(中)
笔记·stm32·单片机·嵌入式硬件·学习
雾岛听蓝1 天前
Qt操作指南:窗口组成与菜单栏
开发语言·经验分享·笔记·qt
北山有鸟1 天前
【学习笔记】MIPI CSI-2 协议全解析:从底层封包到像素解析
linux·驱动开发·笔记·学习·相机