文章目录
cross-encoder一种检索模型,和双路召回机制不一样,各有优缺点。
cross-encoder最大的特点就是会将query(问题)和document(候选文本)一起分析。
一般的流程是,双路召回先粗排,cross-encoder再精排。
计算句子对相似度
代码:
python
from sentence_transformers import CrossEncoder
# 1. 加载预训练模型
# 这里使用微软开源的 MiniLM 模型,速度快且效果好
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 2. 准备句子对
# 假设我们要判断"查询"和"文档"的相关性
pairs = [
['如何重置路由器密码?', '忘记路由器管理密码了怎么办?'], # 语义高度相关
['如何重置路由器密码?', '路由器指示灯闪烁的含义'], # 语义相关性低
['如何重置路由器密码?', '最新的iPhone价格是多少?'] # 完全无关
]
# 3. 预测分数
scores = model.predict(pairs)
# 4. 输出结果
for pair, score in zip(pairs, scores):
print(f"查询: {pair[0]}")
print(f"文档: {pair[1]}")
print(f"相关性分数: {score:.4f}")
print("-" * 30)
输出结果:
bash
查询: 如何重置路由器密码?
文档: 忘记路由器管理密码了怎么办?
相关性分数: 7.0266
------------------------------
查询: 如何重置路由器密码?
文档: 路由器指示灯闪烁的含义
相关性分数: 3.6690
------------------------------
查询: 如何重置路由器密码?
文档: 最新的iPhone价格是多少?
相关性分数: 3.6907
解读:
第一组的文档很明显更贴近查询,所以给了高分,和我们预期的效果一致。
搜索结果的"重排序"
python
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
# --- 第一步:双编码器召回 (Bi-Encoder Retrieval) ---
# 模拟从海量数据库中快速检索
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
query = "如何治疗感冒?"
# 假设这是数据库中的文档片段
corpus = [
"感冒通常由病毒引起,建议多休息多喝水。",
"Python 是一种高级编程语言。",
"治疗流感的特效药需要医生处方。",
"感冒和流感的区别在于症状的严重程度。",
"今天天气真不错,适合出去野餐。"
]
# 将文档转换为向量(实际项目中通常预先计算好)
corpus_embeddings = bi_encoder.encode(corpus, convert_to_tensor=True)
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
# 计算余弦相似度并获取 Top 3 结果
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
top_results = torch.topk(cos_scores, k=3)
print(f"--- 双编码器召回结果 (Top 3) ---")
for score, idx in zip(top_results.values, top_results.indices):
print(f"[{score:.4f}] {corpus[idx]}")
# --- 第二步:交叉编码器重排序 (Cross-Encoder Reranking) ---
# 将召回的 3 个结果交给 Cross-Encoder 进行精准判断
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 构建待重排序的句子对
rerank_pairs = [[query, corpus[idx]] for idx in top_results.indices]
# 获取精准分数
rerank_scores = reranker.predict(rerank_pairs)
# 将分数和原文重新组合并排序
final_results = sorted(zip(rerank_scores, [corpus[idx] for idx in top_results.indices]), reverse=True)
print(f"\n--- Cross-Encoder 重排序结果 ---")
for score, text in final_results:
print(f"[{score:.4f}] {text}")
输出结果:
bash
--- 双编码器召回结果 (Top 3) ---
[0.8146] 治疗流感的特效药需要医生处方。
[0.3674] 感冒和流感的区别在于症状的严重程度。
[0.3576] 感冒通常由病毒引起,建议多休息多喝水。
--- Cross-Encoder 重排序结果 ---
[7.2241] 治疗流感的特效药需要医生处方。
[4.5378] 感冒通常由病毒引起,建议多休息多喝水。
[4.3828] 感冒和流感的区别在于症状的严重程度。
解读:
双编码器可能会因为关键词匹配(如"感冒"和"流感")给出较高的初步分数,但 Cross-Encoder 能更精准地判断"治疗感冒"和"治疗流感"的区别,从而给出更符合用户意图的排序。