基于ReAction范式的问答系统实现demo

基于ReAction范式的问答系统实现demo

参考文档

ReAct论文解读:LLM ReAct范式,在大语言模型中结合推理和动作

说明

由于我最近在做一个基于图数据库的问答系统,所以样例就以查询图数据背景,实现过程仅供参考,希望能够大家带来帮助。

源码

python 复制代码
import os
import json
from typing import Generator, Optional, Dict, Any
from neo4j import GraphDatabase
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()


# ----------------------------
# Neo4j 工具类
# ----------------------------
class Neo4jSearchTool:
    def __init__(self):
        self.driver = GraphDatabase.driver(
            os.getenv("NEO4J_URI"),
            auth=(
                os.getenv("NEO4J_USER"),
                os.getenv("NEO4J_PASSWORD")
            )
        )

    def run(self, query: str) -> str:
        try:
            with self.driver.session() as session:
                result = session.run(query)
                data = [dict(record) for record in result]
                return json.dumps(data, ensure_ascii=False) if data else "[]"
        except Exception as e:
            return f"ERROR: {str(e)}"


class Neo4jSchemaTool:
    def __init__(self, driver):
        self.driver = driver

    def get_node_schema(self, session):
        q = """
        CALL db.schema.nodeTypeProperties()
        YIELD nodeType, propertyName, propertyTypes
        RETURN nodeType, propertyName, propertyTypes
        """
        schema = {}
        for rec in session.run(q):
            label = rec["nodeType"].strip(":`")
            prop = rec["propertyName"]
            types = ", ".join(rec["propertyTypes"]) or "Unknown"
            schema.setdefault(label, {})[prop] = types
        return schema

    # ----------------------------------------------------------------------
    def get_relationship_schema(self, session):
        """
        For each relType: collect property definitions + a sampled (srcLabel, tgtLabel).
        """
        # 1) property map
        q_props = """
        CALL db.schema.relTypeProperties()
        YIELD relType, propertyName, propertyTypes
        RETURN relType, propertyName, propertyTypes
        """
        rel_schema = {}
        for rec in session.run(q_props):
            rtype = rec["relType"].strip(":`")
            prop = rec["propertyName"]
            if prop:
                types = ", ".join(rec["propertyTypes"]) or "Unknown"
                rel_schema.setdefault(rtype, {})[prop] = types

        # 2) sample endpoints for each relationship type
        for rtype in rel_schema:
            q_sample = f"""
            MATCH (s)-[r:`{rtype}`]->(t)
            WITH labels(s)[0] AS src, labels(t)[0] AS tgt
            RETURN src, tgt LIMIT 1
            """
            rec = session.run(q_sample).single()
            if rec:
                rel_schema[rtype]["_endpoints"] = [rec["src"], rec["tgt"]]
            else:  # no relationship instance found
                rel_schema[rtype]["_endpoints"] = ["Unknown", "Unknown"]

        return rel_schema

    def get_schema(self) -> dict:
        """提取数据库中的所有标签、关系和属性"""
        with self.driver.session() as session:
            # 获取所有节点标签

            labels = self.get_node_schema(session)
            rel_types = self.get_relationship_schema(session)

            return {
                "NodeTypes": labels,
                "RelationshipTypes": rel_types
            }
    def format_schema_prompt(self) -> str:
        """将schema转换为自然语言描述"""
        schema = self.get_schema()
        prompt = "数据库包含以下结构:\n"

        # 标签和属性
        prompt += "## 节点类型\n"
        prompt += json.dumps(schema["NodeTypes"],ensure_ascii=False)

        # 关系
        prompt += "\n## 关系类型\n"
        prompt += json.dumps(schema["RelationshipTypes"],ensure_ascii=False)

        return prompt


class AnswerValidator:
    @staticmethod
    def is_valid_answer(observation: str) -> bool:
        """检查工具返回是否包含有效答案"""
        if observation.startswith("ERROR") or observation == "[]":
            return False

        try:
            data = json.loads(observation)
            if isinstance(data, list) and len(data) > 0:
                first_item = data[0]
                # 检查是否有非空值
                return any(v for v in first_item.values() if v not in [None, ""])
            return False
        except:
            return False

    @staticmethod
    def should_terminate(llm_response: str) -> bool:
        """通过LLM判断是否应该终止"""
        prompt = f"""判断以下模型响应是否包含最终答案:
响应内容:{llm_response}

只需返回true或false:"""
        response = OpenAI().chat.completions.create(
            model="deepseek-chat",
            messages=[{"role": "user", "content": prompt}],
            temperature=0
        )
        return "true" in response.choices[0].message.content.lower()

# ----------------------------
# ReAct 引擎
# ----------------------------
class ReActQASystem:
    def __init__(self):
        self.llm = OpenAI()
        neo4j_driver = GraphDatabase.driver(
            os.getenv("NEO4J_URI"),
            auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD"))
        )
        self.tools = {
            "neo4j_search": Neo4jSearchTool(),
            "get_schema": Neo4jSchemaTool(neo4j_driver)
        }
        self.schema_prompt = self.tools["get_schema"].format_schema_prompt()
        self.max_steps = 5

    def _build_prompt(self, query: str, scratchpad: str = "") -> str:
        base_prompt = f"""你是一个审计专家,需要根据数据库结构编写准确的Cypher查询。

    {self.schema_prompt}

    可用工具:
    - neo4j_search: 执行Cypher查询,输入应为JSON格式的{{"query": "MATCH..."}}

    当前问题:{query}

    历史步骤:
    {scratchpad}

    严格按格式响应:
    Thought: 分析问题并确认需要查询的标签和关系
    Action:
    ```json
    {{"action": "工具名", "action_input": {{...}}}}
    ```"""
        return base_prompt

    def execute(self, query: str) -> Generator[str, None, None]:
        scratchpad = ""
        for step in range(self.max_steps):
            # 调用LLM生成响应
            prompt = self._build_prompt(query, scratchpad)
            print(f"LLM prompt: {prompt}")
            response = self.llm.chat.completions.create(
                model="deepseek-chat",
                messages=[{"role": "user", "content": prompt}],
                temperature=0
            )

            content = response.choices[0].message.content
            print(f"LLM Response: {content}")
            print(f"================================================")
            # 解析响应
            thought, action = self._parse_response(content)
            scratchpad += f"\n{content}\n"

            if not action:
                yield f"Final Answer: {thought}"
                break

            # 执行工具调用
            tool_name = action["action"]
            if AnswerValidator.should_terminate(action["action_input"]):
                yield f"Final Answer: {action['action_input']}"
                break
            elif tool_name in self.tools:
                tool_result = self.tools[tool_name].run(action["action_input"]["query"])
                observation = f"Observation: {tool_result}"
                scratchpad += observation + "\n"
                yield observation
            else:
                yield f"ERROR: 未知工具 {tool_name}"

    def _parse_response(self, text: str) -> tuple[str, Optional[Dict]]:
        thought = ""
        action = None

        # 提取Thought部分
        thought_start = text.find("Thought:") + len("Thought:")
        thought_end = text.find("Action:")
        if thought_start >= 0 and thought_end >= 0:
            thought = text[thought_start:thought_end].strip()

        # 提取Action部分
        action_start = text.find("```json") + len("```json")
        action_end = text.find("```", action_start)
        if action_start >= 0 and action_end >= 0:
            try:
                action = json.loads(text[action_start:action_end].strip())
            except json.JSONDecodeError:
                pass

        return thought, action


# ----------------------------
# 主程序
# ----------------------------
def main():
    qa_system = ReActQASystem()

    print("审计问答系统已启动(输入quit退出)")
    while True:
        query = input("\n用户提问: ")
        if query.lower() == "quit":
            break

        print("\n系统响应:")
        for response in qa_system.execute(query):
            print(response)


if __name__ == "__main__":
    main()

总结

欢迎大家留言,讨论

相关推荐
风逸hhh19 分钟前
python打卡day46@浙大疏锦行
开发语言·python
火兮明兮43 分钟前
Python训练第四十三天
开发语言·python
ascarl20102 小时前
准确--k8s cgroup问题排查
java·开发语言
互联网杂货铺2 小时前
完美搭建appium自动化环境
自动化测试·软件测试·python·测试工具·职场和发展·appium·测试用例
Gyoku Mint2 小时前
机器学习×第二卷:概念下篇——她不再只是模仿,而是开始决定怎么靠近你
人工智能·python·算法·机器学习·pandas·ai编程·matplotlib
fpcc2 小时前
跟我学c++中级篇——理解类型推导和C++不同版本的支持
开发语言·c++
莱茵菜苗2 小时前
Python打卡训练营day46——2025.06.06
开发语言·python
爱学习的小道长2 小时前
Python 构建法律DeepSeek RAG
开发语言·python
luojiaao3 小时前
【Python工具开发】k3q_arxml 简单但是非常好用的arxml编辑器,可以称为arxml杀手包
开发语言·python·编辑器
终焉代码3 小时前
STL解析——list的使用
开发语言·c++