Self-RAG 系统实现详解
一、Self-RAG 简介
Self-RAG(Self-Reflective Retrieval-Augmented Generation)赋予大模型 "自我反思与修正" 能力,可有效改善检索不准确、生成结果不可靠等问题。
二、Self-RAG 工作流程
-
检索文档:携带初始或修改后的问题进行文档检索,获取文档后开展上下文相关性评估。
-
上下文评估:
- 若评估为不相关,且此前未执行过 Query2Doc 转换,则进行 Query2Doc 转换,随后返回步骤 1 重新检索。
- 若评估为相关,或虽不相关但已尝试过 Query2Doc 转换,则筛选出相关文档,进入生成答案环节 。
-
生成答案:依据筛选出的相关文档生成答案。
-
评估答案:
-
检查答案是否基于上下文(Supported) :
- 若 "否"(存在幻觉情况),返回步骤 3 重新生成答案。
- 若 "是",继续检查答案是否有用。
-
检查答案是否有用(Useful) :
-
若 "是",流程结束(END)。
-
若 "否"(答案跑题或无效),检查查询重写次数:
- 未达重写上限,执行查询重写(Rewrite Query),返回步骤 1 重新检索。
- 已达重写上限,直接结束(END),不再进行修正尝试。
-
-
三、实现所需节点
- 检索节点:基于问题检索信息。
- 上下文评估节点:评估检索结果与问题的相关性,决定后续操作。
- 生成答案节点:在相关性评估通过后生成答案。
- 评估答案是否有用节点:判断答案有效性,决定是否进入查询重写流程。
- 转换查询节点(query2doc) :在上下文评估不通过或答案无用时,转换查询内容 。
- 重写查询节点:对无效答案对应的查询进行重写。
- 结束节点:当答案有用或与问题完全不相关时,终止流程。
四、节点实现详解
1. 检索节点
思考过程:
-
输入:用户问题
-
操作:在向量数据库检索相似文档
-
函数接收参数:用户问题、最大查询数
-
检索方法:通过向量数据库的
similarity_search
获取目标列表 -
上下文构建:遍历
related_doc
,字符串拼接文档内容形成context
-
输出:初步检索的文档列表
实现代码:
python
ini
def rag_retrieve(question, k=3):
related_docs = zhidu_db.similarity_search(question, k=3)
context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(related_docs)])
return related_docs, context
2. 转换查询节点
思考过程:
-
输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)
-
操作:调用
Query2doc
传入用户问题,通过提示词模板解答问题 -
输出:
Query2doc
函数返回的答案
实现代码:
ini
def transform_query2doc(state):
print("---transform_query2doc---")
# node input
state_dict = state["keys"]
question = state_dict["question"] # 获取原始问题
documents = state_dict["documents"] # 获取当前文档(后续使用)
context = state_dict["context"] # 获取当前上下文(后续使用)
query2doc_count = state_dict.get("query2doc_count", 0) # 获取转换计数
rewrite_count = state_get("rewrite_count", 0) # 获取重写计数
# task - 核心操作!
context_query = query2doc(question)
# node output
return {"keys": {"context": context,
"documents": documents,
"question": question, # 保留原始问题
"context_query": context_query, # 添加转换后的查询
"query2doc_count": query2doc_count + 1, # 增加转换计数
"rewrite_count": rewrite_count}} # 保持重写计数不变
3. 重写查询节点
思考过程:
-
输入:LangGraph 节点函数的 State 对象(含流程所有状态信息)
-
过程:调用
question_rewrite
传入用户问题,借助提示词模板解答 -
输出:
question_rewrite
函数返回的答案,统计并限制重写次数避免死循环
实现代码:
ini
def transform_query_rewrite(state):
print("---transform_query---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
context = state_dict["context"]
query2doc_count = state_dict.get("query2doc_count", 0)
rewrite_count = state_dict.get("rewrite_count", 0)
context_query = question_rewrite(question)
return {
"keys": {
"context": context,
"documents": documents,
"question": question,
"context_query": context_query,
"query2doc_count": query2doc_count,
"rewrite_count": rewrite_count + 1
}
}
4. 上下文评估节点
思考过程:
-
输入:LangGraph 节点函数的标准输入 state 对象,提取用户问题和检索文档
-
操作:遍历文档列表,评估每个文档与问题的相关性,筛选相关文档
-
输出:更新后的 state 字典,包含仅含相关文档的
documents
列表、基于相关文档构建的context
字符串、可能的标识位及其他状态信息
实现代码:
python
def grade_documents(state):
print("---Determines whether the retrieved documents are relevant to the question---")
state_dict = state["keys"]
question, documents = state_dict["question"], state_dict["documents"]
query2doc_count = state_dict.get("query2doc_count", 0)
rewrite_count = state_dict.get("rewrite_count", 0)
filtered_docs, retrieve_enhance = [], "No"
for d in documents:
grade = context_grade_chain.invoke({"question": question, "context": d.page_content})
print(f"Document (first 50): {d.page_content[:50]}... Grade: {grade}")
if "yes" in grade.lower():
filtered_docs.append(d)
else:
retrieve_enhance = "Yes"
if query2doc_count > 0:
retrieve_enhance = "No"
context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(filtered_docs)]) if filtered_docs else ""
return {
"keys": {
"context": context,
"documents": filtered_docs,
"question": question,
"run_retrieve_enhance": retrieve_enhance,
"query2doc_count": query2doc_count,
"rewrite_count": rewrite_count
}
}
5. 生成答案节点
思考过程:
-
输入:State 对象,其中包含用户问题、筛选后的文档或构建的上下文
-
操作:构建模板,填充问题和上下文,调用 LLM 生成答案
-
输出:更新后的 state 字典,新增
generation
键存储 LLM 生成的答案,其余状态不变
实现代码:
css
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
def generate(state):
print("---Generate answer---")
state_dict = state["keys"]
question = state_dict["question"]
documents = state_dict["documents"]
query2doc_count = state_dict.get("query2doc_count", 0)
rewrite_count = state_dict.get("rewrite_count", 0)
context = "\n".join([f"上下文{i+1}: {doc.page_content} \n" for i, doc in enumerate(documents)])
prompt = PromptTemplate( input_variables=["question", "context"], template=prompt_template )
rag_chain = prompt | llm | StrOutputParser()
generation = rag_chain.invoke({"context": context, "question": question})
return {
"keys": {
"context": context,
"question": question,
"documents": documents,
"generation": generation,
"query2doc_count": query2doc_count,
"rewrite_count": rewrite_count
}
}
6. 评估答案是否有用节点
思考过程:
-
输入:state 对象、当前问题、生成答案、相关上下文与文档、查询重写计数器参数
-
操作:评估答案有效性、是否基于上下文,依评估结果决定流程走向
-
输出:表示下一跳转节点名称的字符串,用于工作流判断
实现代码:
perl
# --- 4. 答案评估函数部分 ---
def grade_generation_v_documents_and_question(state):
print("---Determines whether the answer is relevant to the question---") # 日志:开始评估答案
# a. 获取当前状态信息
state_dict = state["keys"]
question = state_dict["question"]
context = state_dict["context"] # 注意:这里用的是合并后的 context 字符串
generation = state_dict["generation"] # 获取上一步生成的答案
rewrite_count = state_dict.get("rewrite_count", 0) # 获取查询重写次数
# b. 第一层检查:答案是否基于上下文?
print("---GRADE GENERATION vs CONTEXT (Supported?)---") # 日志:检查答案是否基于上下文
grade = answer_supported_chain.invoke({"generation": generation, "context": context})
# c. 判断第一层检查结果
if "yes" in grade.lower(): # 如果答案是基于上下文的 ("yes")
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---") # 日志:判定答案有依据
# d. 第二层检查:答案是否有用 (针对问题)?
print("---GRADE GENERATION vs QUESTION (Useful?)---") # 日志:检查答案是否有用
score = answer_useful_chain.invoke({"question": question, "generation": generation})
# e. 判断第二层检查结果
if "yes" in score.lower(): # 如果答案有用 ("yes")
print("---DECISION: GENERATION ADDRESSES QUESTION---") # 日志:判定答案有用
return "useful" # 返回 "useful",流程将走向 END
else: # 如果答案无用 ("no")
# f. 检查是否已重写过查询
if rewrite_count < 1: # 如果还没重写过 (次数小于1)
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION, REWRITE---") # 日志:判定答案无用,准备重写
return "not useful" # 返回 "not useful",流程将走向 transform_query_rewrite
else: # 如果已经重写过一次或更多次
print("---DECISION: GENERATION USELESS AFTER REWRITE, END---") # 日志:重写后答案仍无用,结束
return "end" # 返回 "end",流程将走向 END (放弃治疗)
else: # 如果答案不基于上下文 ("no")
print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY GENERATE---") # 日志:判定答案无依据,重试生成
return "not supported" # 返回 "not supported",流程将走向 generate (重新生成)
7. 结束节点
思考过程 :在 workflow.add_conditional_edges
或 workflow.add_edge
中设置流程指向 END,当流程执行到 END 时,app.stream()
或 app.invoke()
停止并返回最终状态或结果。
实现代码:
bash
workflow.add_conditional_edges(
"generate", # 源节点
grade_generation_v_documents_and_question, # 条件判断函数
{ # 结果到目标节点的映射
"not supported": "generate", # 如果答案不被支持,回到 generate
"useful": END, # 如果答案有用,流向 END
"end": END, # 如果达到重写次数上限,也流向 END
"not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写
},
)
五、将节点串联成工作流
思考过程:
-
回顾自适应 RAG 工作流程,明确各环节先后顺序与逻辑关系 。
-
采用
TypedDict
定义统一结构存储问题、文档、上下文、答案、计数器等中间状态。 -
确定节点函数(
retrieve
、grade_documents
、generate
、transform_query2doc
、transform_query_rewrite
)与条件判断函数(decide_to_generate
、grade_generation_v_documents_and_question
)。 -
实例化
StateGraph
对象,注册节点、指定起点、链接节点(固定顺序用add_edge
,条件判断用add_conditional_edges
) 。
实现代码:
python
class GraphState(TypedDict):
keys: Dict[str, any]
workflow = StateGraph(GraphState)
# 添加节点
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query2doc", transform_query2doc)
workflow.add_node("transform_query_rewrite", transform_query_rewrite)
# 设置入口点
workflow.set_entry_point("retrieve")
# 添加固定边
workflow.add_edge("retrieve", "grade_documents")
workflow.add_edge("transform_query2doc", "retrieve") # Query2Doc 后重新检索
workflow.add_edge("transform_query_rewrite", "retrieve") # Rewrite Query 后重新检索
# 添加条件边 - 根据上下文评估结果决定走向
workflow.add_conditional_edges(
"grade_documents", # 源节点
decide_to_generate, # 条件判断函数
{ # 结果到目标节点的映射
"transform_query2doc": "transform_query2doc", # 如果需要转换查询
"generate": "generate", # 如果可以直接生成
},
)
# 添加条件边 - 根据答案评估结果决定走向
workflow.add_conditional_edges(
"generate", # 源节点
grade_generation_v_documents_and_question, # 条件判断函数
{ # 结果到目标节点的映射
"not supported": "generate", # 如果答案不被支持,回到 generate
"useful": END, # 如果答案有用,流向 END
"end": END, # 如果达到重写次数上限,也流向 END
"not useful": "transform_query_rewrite", # 如果答案无用且未达上限,去重写
},
)