可以自我反思的检索增强生成

Self-RAG 系统实现详解

一、Self-RAG 简介

Self-RAG(Self-Reflective Retrieval-Augmented Generation)赋予大模型 "自我反思与修正" 能力,可有效改善检索不准确、生成结果不可靠等问题。

二、Self-RAG 工作流程

  1. 检索文档:携带初始或修改后的问题进行文档检索,获取文档后开展上下文相关性评估。

  2. 上下文评估

    • 若评估为不相关,且此前未执行过 Query2Doc 转换,则进行 Query2Doc 转换,随后返回步骤 1 重新检索。
    • 若评估为相关,或虽不相关但已尝试过 Query2Doc 转换,则筛选出相关文档,进入生成答案环节 。
  3. 生成答案:依据筛选出的相关文档生成答案。

  4. 评估答案

    • 检查答案是否基于上下文(Supported)

      • 若 "否"(存在幻觉情况),返回步骤 3 重新生成答案。
      • 若 "是",继续检查答案是否有用。
    • 检查答案是否有用(Useful)

      • 若 "是",流程结束(END)。

      • 若 "否"(答案跑题或无效),检查查询重写次数:

        • 未达重写上限,执行查询重写(Rewrite Query),返回步骤 1 重新检索。
        • 已达重写上限,直接结束(END),不再进行修正尝试。

三、实现所需节点

  1. 检索节点:基于问题检索信息。
  2. 上下文评估节点:评估检索结果与问题的相关性,决定后续操作。
  3. 生成答案节点:在相关性评估通过后生成答案。
  4. 评估答案是否有用节点:判断答案有效性,决定是否进入查询重写流程。
  5. 转换查询节点(query2doc) :在上下文评估不通过或答案无用时,转换查询内容 。
  6. 重写查询节点:对无效答案对应的查询进行重写。
  7. 结束节点:当答案有用或与问题完全不相关时,终止流程。

四、节点实现详解

1. 检索节点

思考过程

  • 输入:用户问题

  • 操作:在向量数据库检索相似文档

  • 函数接收参数:用户问题、最大查询数

  • 检索方法:通过向量数据库的 similarity_search 获取目标列表

  • 上下文构建:遍历 related_doc,字符串拼接文档内容形成 context

  • 输出:初步检索的文档列表

实现代码

python

ini 复制代码
def rag_retrieve(question, k=3): 
    related_docs = zhidu_db.similarity_search(question, k=3) 
    context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(related_docs)]) 
    return related_docs, context

2. 转换查询节点

思考过程

  • 输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)

  • 操作:调用 Query2doc 传入用户问题,通过提示词模板解答问题

  • 输出:Query2doc 函数返回的答案

实现代码

ini 复制代码
def transform_query2doc(state):
    print("---transform_query2doc---")

    # node input
    state_dict = state["keys"]
    question = state_dict["question"] # 获取原始问题
    documents = state_dict["documents"] # 获取当前文档(后续使用)
    context = state_dict["context"]   # 获取当前上下文(后续使用)
    query2doc_count = state_dict.get("query2doc_count", 0) # 获取转换计数
    rewrite_count = state_get("rewrite_count", 0) # 获取重写计数

    # task - 核心操作!
    context_query = query2doc(question)

    # node output
    return {"keys": {"context": context,
                    "documents": documents,
                    "question": question, # 保留原始问题
                    "context_query": context_query, # 添加转换后的查询
                    "query2doc_count": query2doc_count + 1, # 增加转换计数
                    "rewrite_count": rewrite_count}} # 保持重写计数不变

3. 重写查询节点

思考过程

  • 输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)

  • 过程:调用 question_rewrite 传入用户问题,借助提示词模板解答

  • 输出:question_rewrite 函数返回的答案,统计并限制重写次数避免死循环

实现代码

ini 复制代码
def transform_query_rewrite(state): 
    print("---transform_query---") 
    state_dict = state["keys"] 
    question = state_dict["question"] 
    documents = state_dict["documents"] 
    context = state_dict["context"] 
    query2doc_count = state_dict.get("query2doc_count", 0) 
    rewrite_count = state_dict.get("rewrite_count", 0)

    context_query = question_rewrite(question)
    return {
        "keys": {
            "context": context,
            "documents": documents,
            "question": question,
            "context_query": context_query,
            "query2doc_count": query2doc_count,
            "rewrite_count": rewrite_count + 1
        }
    }

4. 上下文评估节点

思考过程

  • 输入:LangGraph 节点函数的标准输入 state 对象,提取用户问题和检索文档

  • 操作:遍历文档列表,评估每个文档与问题的相关性,筛选相关文档

  • 输出:更新后的 state 字典,包含仅含相关文档的 documents 列表、基于相关文档构建的 context 字符串、可能的标识位及其他状态信息

实现代码

python 复制代码
def grade_documents(state): 
    print("---Determines whether the retrieved documents are relevant to the question---") 
    state_dict = state["keys"] 
    question, documents = state_dict["question"], state_dict["documents"] 
    query2doc_count = state_dict.get("query2doc_count", 0) 
    rewrite_count = state_dict.get("rewrite_count", 0)

    filtered_docs, retrieve_enhance = [], "No"

    for d in documents:
        grade = context_grade_chain.invoke({"question": question, "context": d.page_content})
        print(f"Document (first 50): {d.page_content[:50]}... Grade: {grade}")
        if "yes" in grade.lower():
            filtered_docs.append(d)
        else:
            retrieve_enhance = "Yes"

    if query2doc_count > 0:
        retrieve_enhance = "No"

    context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(filtered_docs)]) if filtered_docs else ""

    return {
        "keys": {
            "context": context,
            "documents": filtered_docs,
            "question": question,
            "run_retrieve_enhance": retrieve_enhance,
            "query2doc_count": query2doc_count,
            "rewrite_count": rewrite_count
        }
    }

5. 生成答案节点

思考过程

  • 输入:State 对象,其中包含用户问题、筛选后的文档或构建的上下文

  • 操作:构建模板,填充问题和上下文,调用 LLM 生成答案

  • 输出:更新后的 state 字典,新增 generation 键存储 LLM 生成的答案,其余状态不变

实现代码

css 复制代码
from langchain.prompts import PromptTemplate 
from langchain_core.output_parsers import StrOutputParser

def generate(state): 
    print("---Generate answer---") 
    state_dict = state["keys"] 
    question = state_dict["question"] 
    documents = state_dict["documents"] 
    query2doc_count = state_dict.get("query2doc_count", 0) 
    rewrite_count = state_dict.get("rewrite_count", 0) 
    context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(documents)]) 
    prompt = PromptTemplate( input_variables=["question", "context"], template=prompt_template ) 
    rag_chain = prompt | llm | StrOutputParser() 
    generation = rag_chain.invoke({"context": context, "question": question}) 
    return { 
        "keys": { 
            "context": context, 
            "question": question, 
            "documents": documents, 
            "generation": generation, 
            "query2doc_count": query2doc_count, 
            "rewrite_count": rewrite_count 
        } 
    }

6. 评估答案是否有用节点

思考过程

  • 输入:state 对象、当前问题、生成答案、相关上下文与文档、查询重写计数器参数

  • 操作:评估答案有效性、是否基于上下文,依评估结果决定流程走向

  • 输出:表示下一跳转节点名称的字符串,用于工作流判断

实现代码

perl 复制代码
# --- 4. 答案评估函数部分 ---
def grade_generation_v_documents_and_question(state):

    print("---Determines whether the answer is relevant to the question---") # 日志:开始评估答案

    # a. 获取当前状态信息
    state_dict = state["keys"]
    question = state_dict["question"]
    context = state_dict["context"] # 注意:这里用的是合并后的 context 字符串
    generation = state_dict["generation"] # 获取上一步生成的答案
    rewrite_count = state_dict.get("rewrite_count", 0) # 获取查询重写次数

    # b. 第一层检查:答案是否基于上下文?
    print("---GRADE GENERATION vs CONTEXT (Supported?)---") # 日志:检查答案是否基于上下文
    grade = answer_supported_chain.invoke({"generation": generation, "context": context})

    # c. 判断第一层检查结果
    if "yes" in grade.lower(): # 如果答案是基于上下文的 ("yes")
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---") # 日志:判定答案有依据

        # d. 第二层检查:答案是否有用 (针对问题)?
        print("---GRADE GENERATION vs QUESTION (Useful?)---") # 日志:检查答案是否有用
        score = answer_useful_chain.invoke({"question": question, "generation": generation})

        # e. 判断第二层检查结果
        if "yes" in score.lower(): # 如果答案有用 ("yes")
            print("---DECISION: GENERATION ADDRESSES QUESTION---") # 日志:判定答案有用
            return "useful" # 返回 "useful",流程将走向 END
        else: # 如果答案无用 ("no")
            # f. 检查是否已重写过查询
            if rewrite_count < 1: # 如果还没重写过 (次数小于1)
                print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION, REWRITE---") # 日志:判定答案无用,准备重写
                return "not useful" # 返回 "not useful",流程将走向 transform_query_rewrite
            else: # 如果已经重写过一次或更多次
                print("---DECISION: GENERATION USELESS AFTER REWRITE, END---") # 日志:重写后答案仍无用,结束
                return "end" # 返回 "end",流程将走向 END (放弃治疗)
    else: # 如果答案不基于上下文 ("no")
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY GENERATE---") # 日志:判定答案无依据,重试生成
        return "not supported" # 返回 "not supported",流程将走向 generate (重新生成)

7. 结束节点

思考过程 :在 workflow.add_conditional_edgesworkflow.add_edge 中设置流程指向 END,当流程执行到 END 时,app.stream()app.invoke() 停止并返回最终状态或结果。

实现代码

bash 复制代码
workflow.add_conditional_edges(
    "generate", # 源节点
    grade_generation_v_documents_and_question, # 条件判断函数
    { # 结果到目标节点的映射
        "not supported": "generate", # 如果答案不被支持,回到 generate
        "useful": END, # 如果答案有用,流向 END
        "end": END, # 如果达到重写次数上限,也流向 END
        "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写
    },
)

五、将节点串联成工作流

思考过程

  1. 回顾自适应 RAG 工作流程,明确各环节先后顺序与逻辑关系 。

  2. 采用 TypedDict 定义统一结构存储问题、文档、上下文、答案、计数器等中间状态。

  3. 确定节点函数(retrievegrade_documentsgeneratetransform_query2doctransform_query_rewrite)与条件判断函数(decide_to_generategrade_generation_v_documents_and_question)。

  4. 实例化 StateGraph 对象,注册节点、指定起点、链接节点(固定顺序用 add_edge,条件判断用 add_conditional_edges) 。

实现代码

python 复制代码
class GraphState(TypedDict):
    keys: Dict[str, any]

workflow = StateGraph(GraphState)

# 添加节点
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query2doc", transform_query2doc)
workflow.add_node("transform_query_rewrite", transform_query_rewrite)

# 设置入口点
workflow.set_entry_point("retrieve")

# 添加固定边
workflow.add_edge("retrieve", "grade_documents")
workflow.add_edge("transform_query2doc", "retrieve") # Query2Doc 后重新检索
workflow.add_edge("transform_query_rewrite", "retrieve") # Rewrite Query 后重新检索

# 添加条件边 - 根据上下文评估结果决定走向
workflow.add_conditional_edges(
    "grade_documents", # 源节点
    decide_to_generate, # 条件判断函数
    { # 结果到目标节点的映射
        "transform_query2doc": "transform_query2doc", # 如果需要转换查询
        "generate": "generate", # 如果可以直接生成
    },
)

# 添加条件边 - 根据答案评估结果决定走向
workflow.add_conditional_edges(
    "generate", # 源节点
    grade_generation_v_documents_and_question, # 条件判断函数
    { # 结果到目标节点的映射
        "not supported": "generate", # 如果答案不被支持,回到 generate
        "useful": END, # 如果答案有用,流向 END
        "end": END, # 如果达到重写次数上限,也流向 END
        "not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写
    },
)

相关推荐
小技工丨2 分钟前
详解大语言模型生态系统概念:lama,llama.cpp,HuggingFace 模型 ,GGUF,MLX,lm-studio,ollama这都是什么?
人工智能·语言模型·llama
陈奕昆4 分钟前
大模型微调之LLaMA-Factory 系列教程大纲
人工智能·llama·大模型微调·llama-factory
上海云盾商务经理杨杨27 分钟前
AI如何重塑DDoS防护行业?六大变革与未来展望
人工智能·安全·web安全·ddos
一刀到底21138 分钟前
ai agent(智能体)开发 python3基础8 网页抓取中 selenium 和 Playwright 区别和联系
人工智能·python
每天都要写算法(努力版)43 分钟前
【神经网络与深度学习】改变随机种子可以提升模型性能?
人工智能·深度学习·神经网络
烟锁池塘柳01 小时前
【计算机视觉】三种图像质量评价指标详解:PSNR、SSIM与SAM
人工智能·深度学习·计算机视觉
小森77671 小时前
(六)机器学习---聚类与K-means
人工智能·机器学习·数据挖掘·scikit-learn·kmeans·聚类
RockLiu@8052 小时前
探索PyTorch中的空间与通道双重注意力机制:实现concise的scSE模块
人工智能·pytorch·python
进取星辰2 小时前
PyTorch 深度学习实战(23):多任务强化学习(Multi-Task RL)之扩展
人工智能·pytorch·深度学习
极客智谷2 小时前
Spring AI应用系列——基于ARK实现多模态模型应用
人工智能·后端