检索模型cross-encoder笔记

文章目录

cross-encoder一种检索模型,和双路召回机制不一样,各有优缺点。
cross-encoder最大的特点就是会将query(问题)和document(候选文本)一起分析。
一般的流程是,双路召回先粗排,cross-encoder再精排。

计算句子对相似度

代码:

python 复制代码
from sentence_transformers import CrossEncoder

# 1. 加载预训练模型
# 这里使用微软开源的 MiniLM 模型,速度快且效果好
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 2. 准备句子对
# 假设我们要判断"查询"和"文档"的相关性
pairs = [
    ['如何重置路由器密码?', '忘记路由器管理密码了怎么办?'],  # 语义高度相关
    ['如何重置路由器密码?', '路由器指示灯闪烁的含义'],      # 语义相关性低
    ['如何重置路由器密码?', '最新的iPhone价格是多少?']     # 完全无关
]

# 3. 预测分数
scores = model.predict(pairs)

# 4. 输出结果
for pair, score in zip(pairs, scores):
    print(f"查询: {pair[0]}")
    print(f"文档: {pair[1]}")
    print(f"相关性分数: {score:.4f}")
    print("-" * 30)

输出结果:

bash 复制代码
查询: 如何重置路由器密码?
文档: 忘记路由器管理密码了怎么办?
相关性分数: 7.0266
------------------------------
查询: 如何重置路由器密码?
文档: 路由器指示灯闪烁的含义
相关性分数: 3.6690
------------------------------
查询: 如何重置路由器密码?
文档: 最新的iPhone价格是多少?
相关性分数: 3.6907

解读:

第一组的文档很明显更贴近查询,所以给了高分,和我们预期的效果一致。

搜索结果的"重排序"
python 复制代码
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch

# --- 第一步:双编码器召回 (Bi-Encoder Retrieval) ---
# 模拟从海量数据库中快速检索
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
query = "如何治疗感冒?"

# 假设这是数据库中的文档片段
corpus = [
    "感冒通常由病毒引起,建议多休息多喝水。",
    "Python 是一种高级编程语言。",
    "治疗流感的特效药需要医生处方。",
    "感冒和流感的区别在于症状的严重程度。",
    "今天天气真不错,适合出去野餐。"
]

# 将文档转换为向量(实际项目中通常预先计算好)
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)

# 计算余弦相似度并获取 Top 3 结果
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
top_results = torch.topk(cos_scores, k=3)

print(f"--- 双编码器召回结果 (Top 3) ---")
for score, idx in zip(top_results.values, top_results.indices):
    print(f"[{score:.4f}] {corpus[idx]}")

# --- 第二步:交叉编码器重排序 (Cross-Encoder Reranking) ---
# 将召回的 3 个结果交给 Cross-Encoder 进行精准判断
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 构建待重排序的句子对
rerank_pairs = [[query, corpus[idx]] for idx in top_results.indices]

# 获取精准分数
rerank_scores = reranker.predict(rerank_pairs)

# 将分数和原文重新组合并排序
final_results = sorted(zip(rerank_scores, [corpus[idx] for idx in top_results.indices]), reverse=True)

print(f"\n--- Cross-Encoder 重排序结果 ---")
for score, text in final_results:
    print(f"[{score:.4f}] {text}")

输出结果:

bash 复制代码
--- 双编码器召回结果 (Top 3) ---
[0.8146] 治疗流感的特效药需要医生处方。
[0.3674] 感冒和流感的区别在于症状的严重程度。
[0.3576] 感冒通常由病毒引起,建议多休息多喝水。

--- Cross-Encoder 重排序结果 ---
[7.2241] 治疗流感的特效药需要医生处方。
[4.5378] 感冒通常由病毒引起,建议多休息多喝水。
[4.3828] 感冒和流感的区别在于症状的严重程度。

解读:

双编码器可能会因为关键词匹配(如"感冒"和"流感")给出较高的初步分数,但 Cross-Encoder 能更精准地判断"治疗感冒"和"治疗流感"的区别,从而给出更符合用户意图的排序。

相关推荐
许长安11 小时前
gRPC 数据包传输格式解析:从 Protobuf 到 HTTP/2
c++·经验分享·笔记·http·rpc
问心无愧051311 小时前
ctf show web入门47
前端·笔记
网络工程小王11 小时前
【LangGraph 状态持久化(Checkpoint)详解】学习笔记
jvm·人工智能·笔记·langchain
问心无愧051311 小时前
ctf show web入门81
前端·笔记
sheeta199811 小时前
TypeScript 学习笔记
笔记·学习·typescript
sheeta199812 小时前
Pinia核心笔记
前端·vue.js·笔记
Honker_yhw12 小时前
大数据管理与应用系列丛书《数据挖掘》(吕欣等著)读书笔记-数据预处理
笔记·学习
sakiko_12 小时前
Swift学习笔记26-使用第三方库
笔记·学习·swift
じ☆冷颜〃21 小时前
实分析与测度论、复分析、傅里叶分析、泛函分析、凸分析概述.
笔记·学习·数学建模·拓扑学·傅立叶分析
kobesdu1 天前
【ROS2实战笔记-19】ROS2 生命周期节点的启动顺序、状态转换陷阱与热备方案
java·前端·笔记·机器人·ros·ros2