传统单轮召回,根据查询检索候选文档,然后直接进入排序,难以满足用户意图
迭代式检索,采用多轮查询-响应的交互形式,反复"检索-回答-判断"循环,逐步改进召回结果的质量
(1)根据当前查询进行召回,获得候选文档集合
(2)对候选文档排序,生成top-k个最相关文档
(3)利用阅读理解、信息抽取等技术,从上述文档总结一个粗略答案
(4)评估答案的质量,根据评估结果生成一个改进后的新查询
(5)将新查询代入步骤(1),开启新一轮迭代
py
import torch
import numpy as np
from typing import List, Tuple, Dict
from rank_bm25 import BM25Okapi
import jieba
from transformers import AutoTokenizer, AutoModelForCausalLM
# ====================== 1. 配置项 ======================
# 模型配置(建议替换为中文优化模型,如ChatGLM3-6B、Llama2-Chinese)
MODEL_NAME = "THUDM/chatglm3-6b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 检索配置
TOP_K_RETRIEVE = 5 # 每轮召回Top5文档
TOP_K_RANK = 3 # 排序后取Top3文档生成答案
MAX_ITER = 3 # 最大迭代轮数(避免无限循环)
# 质量评估阈值(0-1,高于该值则停止迭代)
QUALITY_THRESHOLD = 0.8
# ====================== 2. 初始化工具:检索器/模型/分词器 ======================
class IterativeRetriever:
def __init__(self, corpus: List[str]):
"""
初始化迭代式检索器
:param corpus: 文档库(原始文本列表)
"""
# 初始化BM25检索器(中文分词)
self.tokenized_corpus = [list(jieba.cut(doc.lower())) for doc in corpus]
self.bm25 = BM25Okapi(self.tokenized_corpus)
self.corpus = corpus
# 初始化LLM(用于生成答案、评估质量、优化查询)
self.tokenizer, self.model = self._init_llm()
def _init_llm(self) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
"""初始化大模型"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
).to(DEVICE)
model.eval()
return tokenizer, model
# ====================== 步骤1:检索候选文档 ======================
def retrieve(self, query: str) -> List[Tuple[str, float]]:
"""
步骤1:根据当前查询召回候选文档
:param query: 当前查询
:return: [(文档文本, BM25得分), ...](Top_K_RETRIEVE)
"""
tokenized_query = list(jieba.cut(query.lower()))
bm25_scores = self.bm25.get_scores(tokenized_query)
# 按得分降序排序,取Top_K_RETRIEVE
top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE]
retrieve_results = [(self.corpus[idx], bm25_scores[idx]) for idx in top_indices]
return retrieve_results
# ====================== 步骤2:文档排序(BM25得分直接排序) ======================
def rank_docs(self, retrieve_results: List[Tuple[str, float]]) -> List[str]:
"""
步骤2:对候选文档排序,返回Top_K_RANK文档文本
:param retrieve_results: 检索结果(含得分)
:return: 排序后的Top_K_RANK文档列表
"""
# 按BM25得分降序排序
sorted_docs = sorted(retrieve_results, key=lambda x: x[1], reverse=True)
# 提取Top_K_RANK文档文本
top_k_docs = [doc for doc, _ in sorted_docs[:TOP_K_RANK]]
return top_k_docs
# ====================== 步骤3:生成粗略答案 ======================
def generate_rough_answer(self, query: str, top_docs: List[str]) -> str:
"""
步骤3:从Top-K文档中总结粗略答案
:param query: 当前查询
:param top_docs: 排序后的Top-K文档
:return: 粗略答案
"""
prompt = f"""
请基于以下文档,简要回答查询「{query}」:
文档1:{top_docs[0]}
{"文档2:" + top_docs[1] if len(top_docs)>=2 else ""}
{"文档3:" + top_docs[2] if len(top_docs)>=3 else ""}
要求:
1. 仅基于文档内容回答,不编造信息;
2. 语言简洁,字数控制在100字以内;
3. 直接给出答案,无需额外解释。
"""
# LLM生成答案
answer = self._llm_generate(prompt, max_new_tokens=150)
return answer.strip()
# ====================== 步骤4:评估答案质量 + 生成新查询 ======================
def evaluate_and_rewrite_query(self, original_query: str, current_query: str, rough_answer: str, top_docs: List[str]) -> Tuple[float, str]:
"""
步骤4:评估答案质量,并生成改进后的新查询
:param original_query: 初始查询
:param current_query: 当前查询
:param rough_answer: 粗略答案
:param top_docs: 排序后的Top-K文档
:return: (答案质量评分[0-1], 新查询)
"""
# 1. 评估答案质量
eval_prompt = f"""
请评估以下答案是否满足初始查询的需求,仅返回0-1之间的评分(保留2位小数):
初始查询:{original_query}
当前查询:{current_query}
答案:{rough_answer}
评分规则:
- 0.0-0.5:答案未覆盖核心需求,信息缺失严重;
- 0.5-0.8:答案覆盖部分核心需求,仍有信息缺失;
- 0.8-1.0:答案完整覆盖核心需求,信息准确。
"""
quality_score = self._llm_generate(eval_prompt, max_new_tokens=10)
# 解析评分(容错处理)
try:
quality_score = float(quality_score.strip())
quality_score = max(0.0, min(1.0, quality_score)) # 限制在0-1
except:
quality_score = 0.6 # 解析失败默认评分
# 2. 生成改进后的新查询(若评分不足则优化)
if quality_score >= QUALITY_THRESHOLD:
new_query = current_query # 质量达标,无需优化
else:
rewrite_prompt = f"""
初始查询:{original_query}
当前查询:{current_query}
现有答案:{rough_answer}
相关文档片段:{chr(10).join(top_docs)}
该答案未完全满足需求,请生成一个更精准的新查询,要求:
1. 补充缺失的核心维度(如预防/治疗/危害);
2. 更具体、有针对性;
3. 仅返回新查询,无需额外解释。
"""
new_query = self._llm_generate(rewrite_prompt, max_new_tokens=100)
new_query = new_query.strip() or current_query # 容错:为空则沿用当前查询
return quality_score, new_query
# ====================== LLM生成辅助函数 ======================
def _llm_generate(self, prompt: str, max_new_tokens: int = 200) -> str:
"""LLM生成通用函数(禁用梯度)"""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024
).to(DEVICE)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=3,
temperature=0.7,
repetition_penalty=1.2,
pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# 过滤prompt本身(避免模型重复输入)
response = response.replace(prompt, "").strip()
return response
# ====================== 主迭代流程 ======================
def iterative_retrieval(self, original_query: str) -> Dict:
"""
迭代式检索主流程
:param original_query: 初始查询
:return: 最终结果(含各轮迭代信息)
"""
# 初始化迭代状态
current_query = original_query
iteration_history = []
final_answer = ""
final_quality_score = 0.0
# 迭代循环
for iter_num in range(1, MAX_ITER + 1):
print(f"\n===== 迭代轮数 {iter_num}/{MAX_ITER} =====")
print(f"当前查询:{current_query}")
# 步骤1:检索候选文档
retrieve_results = self.retrieve(current_query)
print(f"召回文档数:{len(retrieve_results)}")
# 步骤2:文档排序
top_docs = self.rank_docs(retrieve_results)
print(f"排序后Top-{TOP_K_RANK}文档:{[doc[:50]+'...' for doc in top_docs]}")
# 步骤3:生成粗略答案
rough_answer = self.generate_rough_answer(current_query, top_docs)
print(f"粗略答案:{rough_answer}")
# 步骤4:评估质量 + 生成新查询
quality_score, new_query = self.evaluate_and_rewrite_query(
original_query, current_query, rough_answer, top_docs
)
print(f"答案质量评分:{quality_score:.2f}")
print(f"优化后新查询:{new_query}")
# 记录迭代历史
iteration_history.append({
"iter_num": iter_num,
"current_query": current_query,
"rough_answer": rough_answer,
"quality_score": quality_score,
"new_query": new_query,
"top_docs": top_docs
})
# 更新状态
current_query = new_query
final_answer = rough_answer
final_quality_score = quality_score
# 终止条件:质量达标或达到最大迭代数
if quality_score >= QUALITY_THRESHOLD:
print(f"\n✅ 答案质量达标(评分≥{QUALITY_THRESHOLD}),提前终止迭代")
break
# 返回最终结果
return {
"original_query": original_query,
"final_query": current_query,
"final_answer": final_answer,
"final_quality_score": final_quality_score,
"iteration_history": iteration_history,
"max_iter_reached": final_quality_score < QUALITY_THRESHOLD
}
# ====================== 3. 测试与示例 ======================
if __name__ == "__main__":
# 模拟文档库(实际场景替换为真实文档)
test_corpus = [
"儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上,避免含糖饮料",
"治疗儿童肥胖症不能节食,需均衡饮食+适度运动,极端病例可就医",
"儿童肥胖的危害包括高血压、糖尿病、心理自卑等,需早干预",
"家庭习惯对儿童肥胖影响大,家长应以身作则,减少视屏时间",
"儿童肥胖症的药物治疗仅适用于重度肥胖,需医生全程监督",
"预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物,少吃油炸食品",
"运动干预是儿童肥胖治疗的核心,推荐游泳、球类等中等强度运动",
"儿童肥胖症的长期影响包括成年后心脑血管疾病风险升高"
]
# 初始化迭代式检索器
retriever = IterativeRetriever(corpus=test_corpus)
# 初始查询
original_query = "如何预防和治疗儿童肥胖症?"
# 执行迭代式检索
final_result = retriever.iterative_retrieval(original_query)
# 输出最终结果
print("\n" + "="*80)
print("【迭代式检索最终结果】")
print(f"初始查询:{final_result['original_query']}")
print(f"最终优化查询:{final_result['final_query']}")
print(f"最终答案:{final_result['final_answer']}")
print(f"最终质量评分:{final_result['final_quality_score']:.2f}")
print(f"是否达到最大迭代数:{final_result['max_iter_reached']}")
结果
text
===== 迭代轮数 1/3 =====
当前查询:如何预防和治疗儿童肥胖症?
召回文档数:5
排序后Top-3文档:['儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...', '治疗儿童肥胖症不能节食,需均衡饮食+适度运动...', '儿童肥胖的危害包括高血压、糖尿病、心理自卑等...']
粗略答案:预防儿童肥胖需控制高糖饮食、每天运动60分钟;治疗需均衡饮食+适度运动,不节食。
答案质量评分:0.70
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
===== 迭代轮数 2/3 =====
当前查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
召回文档数:5
排序后Top-3文档:['预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物...', '运动干预是儿童肥胖治疗的核心,推荐游泳、球类等...', '儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...']
粗略答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
答案质量评分:0.85
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
✅ 答案质量达标(评分≥0.8),提前终止迭代
================================================================================
【迭代式检索最终结果】
初始查询:如何预防和治疗儿童肥胖症?
最终优化查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
最终答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
最终质量评分:0.85
是否达到最大迭代数:False
top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE] 含义
对 BM25 检索得分排序,提取 Top-K 最高分文档的索引,最终得到 "得分最高的前 N 个文档在原始文档库中的位置"。
解读
假设:
bm25_scores 是一维数组,存储每个文档的 BM25 得分(长度 = 文档库总数):
py
#(对应 5 个文档的得分)
bm25_scores = np.array([1.2, 3.5, 0.8, 2.9, 4.1])
TOP_K_RETRIEVE = 3(要取得分最高的前 3 个文档)
np.argsort(bm25_scores) ------ 对得分排序,返回 "原索引"
np.argsort() 是 NumPy 的核心函数,对数组升序排序,但返回的不是排序后的值,而是原始数组中对应元素的索引
目的:知道 "哪些位置的文档得分高 / 低",而非仅知道得分本身
python
np.argsort(bm25_scores) # 输入[1.2, 3.5, 0.8, 2.9, 4.1]
# 输出:array([2, 0, 3, 1, 4])
解释:
升序排序后的得分是 0.8 → 1.2 → 2.9 → 3.5 → 4.1
这些得分对应的原始索引是:2(0.8)、0(1.2)、3(2.9)、1(3.5)、4(4.1)
[::-1] ------ 反转数组(升序→降序)
Python 的切片语法,作用是反转数组 / 列表,把升序的索引转为降序(对应得分从高到低)
python
np.argsort(bm25_scores)[::-1] # 先排序再反转
# 输出:array([4, 1, 3, 0, 2])
反转后,索引顺序对应得分降序:4(4.1)、1(3.5)、3(2.9)、0(1.2)、2(0.8)
此时索引顺序就是 "得分从高到低的文档位置"
[:TOP_K_RETRIEVE] ------ 取前 N 个索引(Top-K)
[:N] 是切片语法,取数组前 N 个元素,对应 "得分最高的前 TOP_K_RETRIEVE 个文档索引"
python
np.argsort(bm25_scores)[::-1][:3] # 取前3个索引
# 输出:array([4, 1, 3])
解释:
最终得到的 top_indices 是 [4,1,3],即:
文档库中索引 4的文档得分最高(4.1),索引 1 次之(3.5),索引 3 第三(2.9)
为什么要这么写?
在迭代式检索的场景中:
bm25_scores 是 "每个文档的得分",但需要的是 "哪些文档得分高"(索引),而非得分本身
链式操作简洁高效(NumPy 数组操作比纯 Python 列表快得多)
最终通过 top_indices 能精准定位到文档库中 "最相关的 Top-K 文档",为后续排序、生成答案提供基础
等价写法
python
# 步骤1:升序排序,返回索引
sorted_indices_asc = np.argsort(bm25_scores)
# 步骤2:转为降序
sorted_indices_desc = sorted_indices_asc[::-1]
# 步骤3:取Top-K
top_indices = sorted_indices_desc[:TOP_K_RETRIEVE]
混淆 np.sort() 和 np.argsort()
np.sort(bm25_scores):返回排序后的得分值(如 [0.8,1.2,2.9,3.5,4.1]),丢失索引
np.argsort(bm25_scores):返回排序后的索引,保留文档位置,这是检索场景需要的
气态
忘记反转([::-1]):会导致取 "得分最低的 Top-K",完全偏离需求
切片顺序错误:[TOP_K_RETRIEVE:] 会取 "后 N 个",而非前 N 个