RAG - 高阶检索范式 - 迭代式检索

传统单轮召回,根据查询检索候选文档,然后直接进入排序,难以满足用户意图

迭代式检索,采用多轮查询-响应的交互形式,反复"检索-回答-判断"循环,逐步改进召回结果的质量

(1)根据当前查询进行召回,获得候选文档集合

(2)对候选文档排序,生成top-k个最相关文档

(3)利用阅读理解、信息抽取等技术,从上述文档总结一个粗略答案

(4)评估答案的质量,根据评估结果生成一个改进后的新查询

(5)将新查询代入步骤(1),开启新一轮迭代

py 复制代码
import torch
import numpy as np
from typing import List, Tuple, Dict
from rank_bm25 import BM25Okapi
import jieba
from transformers import AutoTokenizer, AutoModelForCausalLM

# ====================== 1. 配置项 ======================
# 模型配置(建议替换为中文优化模型,如ChatGLM3-6B、Llama2-Chinese)
MODEL_NAME = "THUDM/chatglm3-6b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 检索配置
TOP_K_RETRIEVE = 5  # 每轮召回Top5文档
TOP_K_RANK = 3      # 排序后取Top3文档生成答案
MAX_ITER = 3        # 最大迭代轮数(避免无限循环)
# 质量评估阈值(0-1,高于该值则停止迭代)
QUALITY_THRESHOLD = 0.8

# ====================== 2. 初始化工具:检索器/模型/分词器 ======================
class IterativeRetriever:
    def __init__(self, corpus: List[str]):
        """
        初始化迭代式检索器
        :param corpus: 文档库(原始文本列表)
        """
        # 初始化BM25检索器(中文分词)
        self.tokenized_corpus = [list(jieba.cut(doc.lower())) for doc in corpus]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        self.corpus = corpus
        
        # 初始化LLM(用于生成答案、评估质量、优化查询)
        self.tokenizer, self.model = self._init_llm()

    def _init_llm(self) -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
        """初始化大模型"""
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
        ).to(DEVICE)
        model.eval()
        return tokenizer, model

    # ====================== 步骤1:检索候选文档 ======================
    def retrieve(self, query: str) -> List[Tuple[str, float]]:
        """
        步骤1:根据当前查询召回候选文档
        :param query: 当前查询
        :return: [(文档文本, BM25得分), ...](Top_K_RETRIEVE)
        """
        tokenized_query = list(jieba.cut(query.lower()))
        bm25_scores = self.bm25.get_scores(tokenized_query)
        # 按得分降序排序,取Top_K_RETRIEVE
        top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE]
        retrieve_results = [(self.corpus[idx], bm25_scores[idx]) for idx in top_indices]
        return retrieve_results

    # ====================== 步骤2:文档排序(BM25得分直接排序) ======================
    def rank_docs(self, retrieve_results: List[Tuple[str, float]]) -> List[str]:
        """
        步骤2:对候选文档排序,返回Top_K_RANK文档文本
        :param retrieve_results: 检索结果(含得分)
        :return: 排序后的Top_K_RANK文档列表
        """
        # 按BM25得分降序排序
        sorted_docs = sorted(retrieve_results, key=lambda x: x[1], reverse=True)
        # 提取Top_K_RANK文档文本
        top_k_docs = [doc for doc, _ in sorted_docs[:TOP_K_RANK]]
        return top_k_docs

    # ====================== 步骤3:生成粗略答案 ======================
    def generate_rough_answer(self, query: str, top_docs: List[str]) -> str:
        """
        步骤3:从Top-K文档中总结粗略答案
        :param query: 当前查询
        :param top_docs: 排序后的Top-K文档
        :return: 粗略答案
        """
        prompt = f"""
        请基于以下文档,简要回答查询「{query}」:
        文档1:{top_docs[0]}
        {"文档2:" + top_docs[1] if len(top_docs)>=2 else ""}
        {"文档3:" + top_docs[2] if len(top_docs)>=3 else ""}
        
        要求:
        1. 仅基于文档内容回答,不编造信息;
        2. 语言简洁,字数控制在100字以内;
        3. 直接给出答案,无需额外解释。
        """
        # LLM生成答案
        answer = self._llm_generate(prompt, max_new_tokens=150)
        return answer.strip()

    # ====================== 步骤4:评估答案质量 + 生成新查询 ======================
    def evaluate_and_rewrite_query(self, original_query: str, current_query: str, rough_answer: str, top_docs: List[str]) -> Tuple[float, str]:
        """
        步骤4:评估答案质量,并生成改进后的新查询
        :param original_query: 初始查询
        :param current_query: 当前查询
        :param rough_answer: 粗略答案
        :param top_docs: 排序后的Top-K文档
        :return: (答案质量评分[0-1], 新查询)
        """
        # 1. 评估答案质量
        eval_prompt = f"""
        请评估以下答案是否满足初始查询的需求,仅返回0-1之间的评分(保留2位小数):
        初始查询:{original_query}
        当前查询:{current_query}
        答案:{rough_answer}
        
        评分规则:
        - 0.0-0.5:答案未覆盖核心需求,信息缺失严重;
        - 0.5-0.8:答案覆盖部分核心需求,仍有信息缺失;
        - 0.8-1.0:答案完整覆盖核心需求,信息准确。
        """
        
        quality_score = self._llm_generate(eval_prompt, max_new_tokens=10)
        # 解析评分(容错处理)
        try:
            quality_score = float(quality_score.strip())
            quality_score = max(0.0, min(1.0, quality_score))  # 限制在0-1
        except:
            quality_score = 0.6  # 解析失败默认评分
        
        # 2. 生成改进后的新查询(若评分不足则优化)
        if quality_score >= QUALITY_THRESHOLD:
            new_query = current_query  # 质量达标,无需优化
        else:
            rewrite_prompt = f"""
            初始查询:{original_query}
            当前查询:{current_query}
            现有答案:{rough_answer}
            相关文档片段:{chr(10).join(top_docs)}
            
            该答案未完全满足需求,请生成一个更精准的新查询,要求:
            1. 补充缺失的核心维度(如预防/治疗/危害);
            2. 更具体、有针对性;
            3. 仅返回新查询,无需额外解释。
            """
            new_query = self._llm_generate(rewrite_prompt, max_new_tokens=100)
            new_query = new_query.strip() or current_query  # 容错:为空则沿用当前查询
        
        return quality_score, new_query

    # ====================== LLM生成辅助函数 ======================
    def _llm_generate(self, prompt: str, max_new_tokens: int = 200) -> str:
        """LLM生成通用函数(禁用梯度)"""
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        ).to(DEVICE)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=3,
                temperature=0.7,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else 0
            )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # 过滤prompt本身(避免模型重复输入)
        response = response.replace(prompt, "").strip()
        return response

    # ====================== 主迭代流程 ======================
    def iterative_retrieval(self, original_query: str) -> Dict:
        """
        迭代式检索主流程
        :param original_query: 初始查询
        :return: 最终结果(含各轮迭代信息)
        """
        # 初始化迭代状态
        current_query = original_query
        iteration_history = []
        final_answer = ""
        final_quality_score = 0.0

        # 迭代循环
        for iter_num in range(1, MAX_ITER + 1):
            print(f"\n===== 迭代轮数 {iter_num}/{MAX_ITER} =====")
            print(f"当前查询:{current_query}")

            # 步骤1:检索候选文档
            retrieve_results = self.retrieve(current_query)
            print(f"召回文档数:{len(retrieve_results)}")

            # 步骤2:文档排序
            top_docs = self.rank_docs(retrieve_results)
            print(f"排序后Top-{TOP_K_RANK}文档:{[doc[:50]+'...' for doc in top_docs]}")

            # 步骤3:生成粗略答案
            rough_answer = self.generate_rough_answer(current_query, top_docs)
            print(f"粗略答案:{rough_answer}")

            # 步骤4:评估质量 + 生成新查询
            quality_score, new_query = self.evaluate_and_rewrite_query(
                original_query, current_query, rough_answer, top_docs
            )
            print(f"答案质量评分:{quality_score:.2f}")
            print(f"优化后新查询:{new_query}")

            # 记录迭代历史
            iteration_history.append({
                "iter_num": iter_num,
                "current_query": current_query,
                "rough_answer": rough_answer,
                "quality_score": quality_score,
                "new_query": new_query,
                "top_docs": top_docs
            })

            # 更新状态
            current_query = new_query
            final_answer = rough_answer
            final_quality_score = quality_score

            # 终止条件:质量达标或达到最大迭代数
            if quality_score >= QUALITY_THRESHOLD:
                print(f"\n✅ 答案质量达标(评分≥{QUALITY_THRESHOLD}),提前终止迭代")
                break

        # 返回最终结果
        return {
            "original_query": original_query,
            "final_query": current_query,
            "final_answer": final_answer,
            "final_quality_score": final_quality_score,
            "iteration_history": iteration_history,
            "max_iter_reached": final_quality_score < QUALITY_THRESHOLD
        }

# ====================== 3. 测试与示例 ======================
if __name__ == "__main__":
    # 模拟文档库(实际场景替换为真实文档)
    test_corpus = [
        "儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上,避免含糖饮料",
        "治疗儿童肥胖症不能节食,需均衡饮食+适度运动,极端病例可就医",
        "儿童肥胖的危害包括高血压、糖尿病、心理自卑等,需早干预",
        "家庭习惯对儿童肥胖影响大,家长应以身作则,减少视屏时间",
        "儿童肥胖症的药物治疗仅适用于重度肥胖,需医生全程监督",
        "预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物,少吃油炸食品",
        "运动干预是儿童肥胖治疗的核心,推荐游泳、球类等中等强度运动",
        "儿童肥胖症的长期影响包括成年后心脑血管疾病风险升高"
    ]

    # 初始化迭代式检索器
    retriever = IterativeRetriever(corpus=test_corpus)

    # 初始查询
    original_query = "如何预防和治疗儿童肥胖症?"

    # 执行迭代式检索
    final_result = retriever.iterative_retrieval(original_query)

    # 输出最终结果
    print("\n" + "="*80)
    print("【迭代式检索最终结果】")
    print(f"初始查询:{final_result['original_query']}")
    print(f"最终优化查询:{final_result['final_query']}")
    print(f"最终答案:{final_result['final_answer']}")
    print(f"最终质量评分:{final_result['final_quality_score']:.2f}")
    print(f"是否达到最大迭代数:{final_result['max_iter_reached']}")

结果

text 复制代码
===== 迭代轮数 1/3 =====
当前查询:如何预防和治疗儿童肥胖症?
召回文档数:5
排序后Top-3文档:['儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...', '治疗儿童肥胖症不能节食,需均衡饮食+适度运动...', '儿童肥胖的危害包括高血压、糖尿病、心理自卑等...']
粗略答案:预防儿童肥胖需控制高糖饮食、每天运动60分钟;治疗需均衡饮食+适度运动,不节食。
答案质量评分:0.70
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?

===== 迭代轮数 2/3 =====
当前查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
召回文档数:5
排序后Top-3文档:['预防儿童肥胖的核心是规律进餐,多吃蔬菜和全谷物...', '运动干预是儿童肥胖治疗的核心,推荐游泳、球类等...', '儿童肥胖症预防需控制高糖饮食,每天运动60分钟以上...']
粗略答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
答案质量评分:0.85
优化后新查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?

✅ 答案质量达标(评分≥0.8),提前终止迭代

================================================================================
【迭代式检索最终结果】
初始查询:如何预防和治疗儿童肥胖症?
最终优化查询:如何预防和治疗儿童肥胖症?具体饮食和运动方案有哪些?
最终答案:预防:规律进餐、多吃蔬菜全谷物、少吃油炸食品,每天运动60分钟;治疗:游泳/球类等中等强度运动,均衡饮食不节食。
最终质量评分:0.85
是否达到最大迭代数:False

top_indices = np.argsort(bm25_scores)[::-1][:TOP_K_RETRIEVE] 含义

对 BM25 检索得分排序,提取 Top-K 最高分文档的索引,最终得到 "得分最高的前 N 个文档在原始文档库中的位置"。

解读

假设:

bm25_scores 是一维数组,存储每个文档的 BM25 得分(长度 = 文档库总数):

py 复制代码
#(对应 5 个文档的得分)
bm25_scores = np.array([1.2, 3.5, 0.8, 2.9, 4.1])

TOP_K_RETRIEVE = 3(要取得分最高的前 3 个文档)

np.argsort(bm25_scores) ------ 对得分排序,返回 "原索引"
np.argsort() 是 NumPy 的核心函数,对数组升序排序,但返回的不是排序后的值,而是原始数组中对应元素的索引

目的:知道 "哪些位置的文档得分高 / 低",而非仅知道得分本身

python 复制代码
np.argsort(bm25_scores)  # 输入[1.2, 3.5, 0.8, 2.9, 4.1]
# 输出:array([2, 0, 3, 1, 4])

解释:

升序排序后的得分是 0.8 → 1.2 → 2.9 → 3.5 → 4.1

这些得分对应的原始索引是:2(0.8)、0(1.2)、3(2.9)、1(3.5)、4(4.1)

[::-1] ------ 反转数组(升序→降序)

Python 的切片语法,作用是反转数组 / 列表,把升序的索引转为降序(对应得分从高到低)

python 复制代码
np.argsort(bm25_scores)[::-1]  # 先排序再反转
# 输出:array([4, 1, 3, 0, 2])

反转后,索引顺序对应得分降序:4(4.1)、1(3.5)、3(2.9)、0(1.2)、2(0.8)

此时索引顺序就是 "得分从高到低的文档位置"

[:TOP_K_RETRIEVE] ------ 取前 N 个索引(Top-K)
[:N] 是切片语法,取数组前 N 个元素,对应 "得分最高的前 TOP_K_RETRIEVE 个文档索引"

python 复制代码
np.argsort(bm25_scores)[::-1][:3]  # 取前3个索引
# 输出:array([4, 1, 3])

解释:

最终得到的 top_indices 是 [4,1,3],即:

文档库中索引 4的文档得分最高(4.1),索引 1 次之(3.5),索引 3 第三(2.9)

为什么要这么写?

在迭代式检索的场景中:

bm25_scores 是 "每个文档的得分",但需要的是 "哪些文档得分高"(索引),而非得分本身

链式操作简洁高效(NumPy 数组操作比纯 Python 列表快得多)

最终通过 top_indices 能精准定位到文档库中 "最相关的 Top-K 文档",为后续排序、生成答案提供基础

等价写法

python 复制代码
# 步骤1:升序排序,返回索引
sorted_indices_asc = np.argsort(bm25_scores)
# 步骤2:转为降序
sorted_indices_desc = sorted_indices_asc[::-1]
# 步骤3:取Top-K
top_indices = sorted_indices_desc[:TOP_K_RETRIEVE]

混淆 np.sort() 和 np.argsort()

np.sort(bm25_scores):返回排序后的得分值(如 [0.8,1.2,2.9,3.5,4.1]),丢失索引

np.argsort(bm25_scores):返回排序后的索引,保留文档位置,这是检索场景需要的

气态

忘记反转([::-1]):会导致取 "得分最低的 Top-K",完全偏离需求

切片顺序错误:[TOP_K_RETRIEVE:] 会取 "后 N 个",而非前 N 个

相关推荐
特立独行的猫a1 小时前
SSE技术详解及在MCP协议中的应用和优势
ai·sse·mcp
万俟淋曦2 小时前
【论文速递】2025年第33周(Aug-10-16)(Robotics/Embodied AI/LLM)
人工智能·深度学习·ai·机器人·论文·robotics·具身智能
我不是QI3 小时前
周志华《机器学习---西瓜书》三
人工智能·机器学习·ai
Elastic 中国社区官方博客4 小时前
Elasticsearch:数据脱节如何破坏现代调查
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
张彦峰ZYF5 小时前
AI赋能原则7解读思考:AI时代构建可组合的能力比单点专业更重要
人工智能·ai·ai赋能与落地
badfl5 小时前
OpenAI文本嵌入模型text-embedding-3是什么?
人工智能·机器学习·ai
~kiss~6 小时前
RAG - 高阶检索范式-step-by-step prompting -分步提示
ai
杨晓风-linda6 小时前
工作流基础知识
人工智能·ai·工作流·n8n
阿杰学AI6 小时前
AI核心知识40——大语言模型之Token(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·token