RAG - 高阶检索范式 - 迭代式检索

传统单轮召回,根据查询检索候选文档,然后直接进入排序,难以满足用户意图

迭代式检索,采用多轮查询-响应的交互形式,反复"检索-回答-判断"循环,逐步改进召回结果的质量

(1)根据当前查询进行召回,获得候选文档集合

(2)对候选文档排序,生成top-k个最相关文档

(3)利用阅读理解、信息抽取等技术,从上述文档总结一个粗略答案

(4)评估答案的质量,根据评估结果生成一个改进后的新查询

(5)将新查询代入步骤(1),开启新一轮迭代

py 复制代码
import torch
import numpy as np
from typing import List, Tuple, Dict
from rank_bm25 import BM25Okapi
import jieba
from transformers import AutoTokenizer, AutoModelForCausalLM

# ====================== 1. 配置项 ======================
# 模型配置(建议替换为中文优化模型,如ChatGLM3-6B、Llama2-Chinese)
MODEL_NAME = "THUDM/chatglm3-6b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 检索配置
TOP_K_RETRIEVE = 5  # 每轮召回Top5文档
TOP_K_RANK = 3      # 排序后取Top3文档生成答案
MAX_ITER = 3        # 最大迭代轮数(避免无限循环)
# 质量评估阈值(0-1,高于该值则停止迭代)
QUALITY_THRESHOLD = 0.8

# ====================== 2. 初始化工具:检索器/模型/分词器 ======================
class IterativeRetriever:
    def __init__(self, corpus: List[str]):
        """
        初始化迭代式检索器
        :param corpus: 文档库(原始文本列表)
        """
        # 初始化BM25检索器(中文分词)
        self.tokenized_corpus = [list(jieba.cut(doc.lower())) for doc in corpus]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        self.corpus = corpus
        
        # 初始化LLM(用于生成答案、评估质量、优化查询)
        self.tokenizer, self.model = self._init_llm()

    def _init_llm(self) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
        """初始化大模型"""
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
        ).to(DEVICE)
        model.eval()
        return tokenizer, model

    # ====================== 步骤1:检索候选文档 ======================
    def retrieve(self, query: str) -> List[Tuple[str, float]]:
        """
        步骤1:根据当前查询召回候选文档
        :param query: 当前查询
        :return: [(文档文本, BM25得分), ...](Top_K_RETRIEVE)
        """
        tokenized_query = list(jieba.cut(query.lower()))
        bm25_scores = self.bm25.get_scores(tokenized_query)
        # 按得分降序排序,取Top_K_RETRIEVE
        top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE]
        retrieve_results = [(self.corpus[idx], bm25_scores[idx]) for idx in top_indices]
        return retrieve_results

    # ====================== 步骤2:文档排序(BM25得分直接排序) ======================
    def rank_docs(self, retrieve_results: List[Tuple[str, float]]) -> List[str]:
        """
        步骤2:对候选文档排序,返回Top_K_RANK文档文本
        :param retrieve_results: 检索结果(含得分)
        :return: 排序后的Top_K_RANK文档列表
        """
        # 按BM25得分降序排序
        sorted_docs = sorted(retrieve_results, key=lambda x: x[1], reverse=True)
        # 提取Top_K_RANK文档文本
        top_k_docs = [doc for doc, _ in sorted_docs[:TOP_K_RANK]]
        return top_k_docs

    # ====================== 步骤3:生成粗略答案 ======================
    def generate_rough_answer(self, query: str, top_docs: List[str]) -> str:
        """
        步骤3:从Top-K文档中总结粗略答案
        :param query: 当前查询
        :param top_docs: 排序后的Top-K文档
        :return: 粗略答案
        """
        prompt = f"""
        请基于以下文档,简要回答查询「{query}」:
        文档1:{top_docs[0]}
        {"文档2:" + top_docs[1] if len(top_docs)>=2 else ""}
        {"文档3:" + top_docs[2] if len(top_docs)>=3 else ""}
        
        要求:
        1. 仅基于文档内容回答,不编造信息;
        2. 语言简洁,字数控制在100字以内;
        3. 直接给出答案,无需额外解释。
        """
        # LLM生成答案
        answer = self._llm_generate(prompt, max_new_tokens=150)
        return answer.strip()

    # ====================== 步骤4:评估答案质量 + 生成新查询 ======================
    def evaluate_and_rewrite_query(self, original_query: str, current_query: str, rough_answer: str, top_docs: List[str]) -> Tuple[float, str]:
        """
        步骤4:评估答案质量,并生成改进后的新查询
        :param original_query: 初始查询
        :param current_query: 当前查询
        :param rough_answer: 粗略答案
        :param top_docs: 排序后的Top-K文档
        :return: (答案质量评分[0-1], 新查询)
        """
        # 1. 评估答案质量
        eval_prompt = f"""
        请评估以下答案是否满足初始查询的需求,仅返回0-1之间的评分(保留2位小数):
        初始查询:{original_query}
        当前查询:{current_query}
        答案:{rough_answer}
        
        评分规则:
        - 0.0-0.5:答案未覆盖核心需求,信息缺失严重;
        - 0.5-0.8:答案覆盖部分核心需求,仍有信息缺失;
        - 0.8-1.0:答案完整覆盖核心需求,信息准确。
        """
        
        quality_score = self._llm_generate(eval_prompt, max_new_tokens=10)
        # 解析评分(容错处理)
        try:
            quality_score = float(quality_score.strip())
            quality_score = max(0.0, min(1.0, quality_score))  # 限制在0-1
        except:
            quality_score = 0.6  # 解析失败默认评分
        
        # 2. 生成改进后的新查询(若评分不足则优化)
        if quality_score >= QUALITY_THRESHOLD:
            new_query = current_query  # 质量达标,无需优化
        else:
            rewrite_prompt = f"""
            初始查询:{original_query}
            当前查询:{current_query}
            现有答案:{rough_answer}
            相关文档片段:{chr(10).join(top_docs)}
            
            该答案未完全满足需求,请生成一个更精准的新查询,要求:
            1. 补充缺失的核心维度(如预防/治疗/危害);
            2. 更具体、有针对性;
            3. 仅返回新查询,无需额外解释。
            """
            new_query = self._llm_generate(rewrite_prompt, max_new_tokens=100)
            new_query = new_query.strip() or current_query  # 容错:为空则沿用当前查询
        
        return quality_score, new_query

    # ====================== LLM生成辅助函数 ======================
    def _llm_generate(self, prompt: str, max_new_tokens: int = 200) -> str:
        """LLM生成通用函数(禁用梯度)"""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        ).to(DEVICE)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=3,
                temperature=0.7,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # 过滤prompt本身(避免模型重复输入)
        response = response.replace(prompt, "").strip()
        return response

    # ====================== 主迭代流程 ======================
    def iterative_retrieval(self, original_query: str) -> Dict:
        """
        迭代式检索主流程
        :param original_query: 初始查询
        :return: 最终结果(含各轮迭代信息)
        """
        # 初始化迭代状态
        current_query = original_query
        iteration_history = []
        final_answer = ""
        final_quality_score = 0.0

        # 迭代循环
        for iter_num in range(1, MAX_ITER + 1):
            print(f"\n===== 迭代轮数 {iter_num}/{MAX_ITER} =====")
            print(f"当前查询:{current_query}")

            # 步骤1:检索候选文档
            retrieve_results = self.retrieve(current_query)
            print(f"召回文档数:{len(retrieve_results)}")

            # 步骤2:文档排序
            top_docs = self.rank_docs(retrieve_results)
            print(f"排序后Top-{TOP_K_RANK}文档:{[doc[:50]+'...' for doc in top_docs]}")

            # 步骤3:生成粗略答案
            rough_answer = self.generate_rough_answer(current_query, top_docs)
            print(f"粗略答案:{rough_answer}")

            # 步骤4:评估质量 + 生成新查询
            quality_score, new_query = self.evaluate_and_rewrite_query(
                original_query, current_query, rough_answer, top_docs
            )
            print(f"答案质量评分:{quality_score:.2f}")
            print(f"优化后新查询:{new_query}")

            # 记录迭代历史
            iteration_history.append({
                "iter_num": iter_num,
                "current_query": current_query,
                "rough_answer": rough_answer,
                "quality_score": quality_score,
                "new_query": new_query,
                "top_docs": top_docs
            })

            # 更新状态
            current_query = new_query
            final_answer = rough_answer
            final_quality_score = quality_score

            # 终止条件:质量达标或达到最大迭代数
            if quality_score >= QUALITY_THRESHOLD:
                print(f"\n✅ 答案质量达标(评分≥{QUALITY_THRESHOLD}),提前终止迭代")
                break

        # 返回最终结果
        return {
            "original_query": original_query,
            "final_query": current_query,
            "final_answer": final_answer,
            "final_quality_score": final_quality_score,
            "iteration_history": iteration_history,
            "max_iter_reached": final_quality_score < QUALITY_THRESHOLD
        }

# ====================== 3. 测试与示例 ======================
if __name__ == "__main__":
    # 模拟文档库(实际场景替换为真实文档)
    test_corpus = [
        "儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上,避免含糖饮料",
        "治疗儿童肥胖症不能节食,需均衡饮食+适度运动,极端病例可就医",
        "儿童肥胖的危害包括高血压、糖尿病、心理自卑等,需早干预",
        "家庭习惯对儿童肥胖影响大,家长应以身作则,减少视屏时间",
        "儿童肥胖症的药物治疗仅适用于重度肥胖,需医生全程监督",
        "预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物,少吃油炸食品",
        "运动干预是儿童肥胖治疗的核心,推荐游泳、球类等中等强度运动",
        "儿童肥胖症的长期影响包括成年后心脑血管疾病风险升高"
    ]

    # 初始化迭代式检索器
    retriever = IterativeRetriever(corpus=test_corpus)

    # 初始查询
    original_query = "如何预防和治疗儿童肥胖症?"

    # 执行迭代式检索
    final_result = retriever.iterative_retrieval(original_query)

    # 输出最终结果
    print("\n" + "="*80)
    print("【迭代式检索最终结果】")
    print(f"初始查询:{final_result['original_query']}")
    print(f"最终优化查询:{final_result['final_query']}")
    print(f"最终答案:{final_result['final_answer']}")
    print(f"最终质量评分:{final_result['final_quality_score']:.2f}")
    print(f"是否达到最大迭代数:{final_result['max_iter_reached']}")

结果

text 复制代码
===== 迭代轮数 1/3 =====
当前查询:如何预防和治疗儿童肥胖症?
召回文档数:5
排序后Top-3文档:['儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...', '治疗儿童肥胖症不能节食,需均衡饮食+适度运动...', '儿童肥胖的危害包括高血压、糖尿病、心理自卑等...']
粗略答案:预防儿童肥胖需控制高糖饮食、每天运动60分钟;治疗需均衡饮食+适度运动,不节食。
答案质量评分:0.70
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?

===== 迭代轮数 2/3 =====
当前查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
召回文档数:5
排序后Top-3文档:['预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物...', '运动干预是儿童肥胖治疗的核心,推荐游泳、球类等...', '儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...']
粗略答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
答案质量评分:0.85
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?

✅ 答案质量达标(评分≥0.8),提前终止迭代

================================================================================
【迭代式检索最终结果】
初始查询:如何预防和治疗儿童肥胖症?
最终优化查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
最终答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
最终质量评分:0.85
是否达到最大迭代数:False

top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE] 含义

对 BM25 检索得分排序,提取 Top-K 最高分文档的索引,最终得到 "得分最高的前 N 个文档在原始文档库中的位置"。

解读

假设:

bm25_scores 是一维数组,存储每个文档的 BM25 得分(长度 = 文档库总数):

py 复制代码
#(对应 5 个文档的得分)
bm25_scores = np.array([1.2, 3.5, 0.8, 2.9, 4.1])

TOP_K_RETRIEVE = 3(要取得分最高的前 3 个文档)

np.argsort(bm25_scores) ------ 对得分排序,返回 "原索引"
np.argsort() 是 NumPy 的核心函数,对数组升序排序,但返回的不是排序后的值,而是原始数组中对应元素的索引

目的:知道 "哪些位置的文档得分高 / 低",而非仅知道得分本身

python 复制代码
np.argsort(bm25_scores)  # 输入[1.2, 3.5, 0.8, 2.9, 4.1]
# 输出:array([2, 0, 3, 1, 4])

解释:

升序排序后的得分是 0.8 → 1.2 → 2.9 → 3.5 → 4.1

这些得分对应的原始索引是:2(0.8)、0(1.2)、3(2.9)、1(3.5)、4(4.1)

[::-1] ------ 反转数组(升序→降序)

Python 的切片语法,作用是反转数组 / 列表,把升序的索引转为降序(对应得分从高到低)

python 复制代码
np.argsort(bm25_scores)[::-1]  # 先排序再反转
# 输出:array([4, 1, 3, 0, 2])

反转后,索引顺序对应得分降序:4(4.1)、1(3.5)、3(2.9)、0(1.2)、2(0.8)

此时索引顺序就是 "得分从高到低的文档位置"

[:TOP_K_RETRIEVE] ------ 取前 N 个索引(Top-K)
[:N] 是切片语法,取数组前 N 个元素,对应 "得分最高的前 TOP_K_RETRIEVE 个文档索引"

python 复制代码
np.argsort(bm25_scores)[::-1][:3]  # 取前3个索引
# 输出:array([4, 1, 3])

解释:

最终得到的 top_indices 是 [4,1,3],即:

文档库中索引 4的文档得分最高(4.1),索引 1 次之(3.5),索引 3 第三(2.9)

为什么要这么写?

在迭代式检索的场景中:

bm25_scores 是 "每个文档的得分",但需要的是 "哪些文档得分高"(索引),而非得分本身

链式操作简洁高效(NumPy 数组操作比纯 Python 列表快得多)

最终通过 top_indices 能精准定位到文档库中 "最相关的 Top-K 文档",为后续排序、生成答案提供基础

等价写法

python 复制代码
# 步骤1:升序排序,返回索引
sorted_indices_asc = np.argsort(bm25_scores)
# 步骤2:转为降序
sorted_indices_desc = sorted_indices_asc[::-1]
# 步骤3:取Top-K
top_indices = sorted_indices_desc[:TOP_K_RETRIEVE]

混淆 np.sort() 和 np.argsort()

np.sort(bm25_scores):返回排序后的得分值(如 [0.8,1.2,2.9,3.5,4.1]),丢失索引

np.argsort(bm25_scores):返回排序后的索引,保留文档位置,这是检索场景需要的

气态

忘记反转([::-1]):会导致取 "得分最低的 Top-K",完全偏离需求

切片顺序错误:[TOP_K_RETRIEVE:] 会取 "后 N 个",而非前 N 个

相关推荐
xcLeigh1 小时前
AI的提示词专栏:写作助手 Prompt,从提纲到完整文章
人工智能·ai·prompt·提示词
lbb 小魔仙2 小时前
AI Agent 开发终极手册:Manus、MetaGPT 与 CrewAI 深度对比
人工智能·ai
humors2213 小时前
四步生成喜欢的图片
人工智能·ai·图片·背景·祝福·头像
技术小甜甜3 小时前
[AI 工程实践] 远程调用 Ollama 报错解析:如何解决“本地文件找不到”的误区
ai·自动化·llm·agent·ollama·’人工智能·aider
Damon小智4 小时前
【TextIn大模型加速器 + 火山引擎】跨国药企多语言手册智能翻译系统设计与实现
人工智能·ai·ocr·agent·火山引擎
雪碧聊技术9 小时前
《2025全栈成长实录:Vue3→Spring Boot→云部署→AI探索,一个初级工程师的技术演进》
ai·年终总结·全栈·csdn博客之星
Android系统攻城狮12 小时前
XUbuntu22.04之视频编辑利器:kdenlive剪切视频片段+自动转码输出(二百八十七)
ai·音视频·视频转码·视频编辑·xubuntu22.04
乾元12 小时前
Network-as-Code:把 HCIE / CCIE 实验脚本转为企业级 CI 工程化流程
运维·网络·人工智能·安全·web安全·ai·架构
CoderJia程序员甲13 小时前
GitHub 热榜项目 - 日榜(2026-1-1)
ai·开源·大模型·github·ai教程
村口曹大爷13 小时前
[特殊字符] 2026年AI最新趋势深度解读:智能体崛起、多模态融合、全球竞速加剧
人工智能·ai