In this article we'll look at how to build an agent that can answer questions over an enterprise SQL database. At a high level, our SQL agent will:
- Fetch the available tables from the database
- Decide which tables are relevant to the question
- Fetch the schemas of the relevant tables
- Generate a query based on the question and the table schemas
- Double-check the query for common mistakes using the LLM
- Execute the query and return the results
- Correct errors reported by the database engine until the query runs successfully
- Formulate a response based on the query results
1. Setup
Let's start by installing some dependencies. We'll use the SQL database wrapper and toolkit from langchain-community, and we'll also need a LangChain chat model.
python
pip install -U langgraph langchain_community "langchain[openai]"
Then initialize the LLM; we'll use OpenAI as the model provider.
python
from langchain.chat_models import init_chat_model
llm = init_chat_model(model="openai:gpt-4.1-mini")
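If you haven't configured credentials yet, here is a minimal sketch for supplying the OpenAI API key, assuming you authenticate via the OPENAI_API_KEY environment variable:
python
import getpass
import os

# Prompt for the key only if it isn't already set in the environment
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")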
We'll create a SQLite database for this agent. SQLite is a lightweight database that's easy to set up and use. We'll load the Chinook database, a sample database that represents a digital media store. For convenience, the database file (Chinook.db) is hosted in a public GCS bucket, so we download it from there:
python
import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
response = requests.get(url)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")
Running this creates a Chinook.db file locally. Next, we use the handy SQLDatabase wrapper from the langchain_community package to interact with the database. The wrapper provides a simple interface for executing SQL queries and fetching results:
python
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(f"Dialect: {db.dialect}")
print(f"Available tables: {db.get_usable_table_names()}")
print(f'Sample output: {db.run("SELECT * FROM Artist LIMIT 5;")}')
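If you want to look at a single table in more detail, the same wrapper can also print a table's schema together with a few sample rows. A small optional sketch (Artist is just one of the Chinook tables):
python
# Show the CREATE TABLE statement and a few sample rows for one table
print(db.get_table_info(["Artist"]))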
Running this prints the data that ships with the database. Next, let's look at the built-in tools that langchain-community provides for interacting with our SQLDatabase, including tools for listing tables, reading table schemas, and checking and running queries:
python
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
for tool in tools:
    print(f"{tool.name}: {tool.description}\n")
This prints the available tools:
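For reference, the toolkit typically exposes four tools; the exact descriptions vary a little between langchain-community versions, so treat this as an approximate summary:
python
# sql_db_query          - execute a SQL query against the database and return the result
# sql_db_schema         - return the schema and sample rows for the given tables
# sql_db_list_tables    - list the tables available in the database
# sql_db_query_checker  - ask the LLM to double-check a query before it is executed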
2. Using a Prebuilt Agent
With these tools in hand, we can initialize a prebuilt agent in a single line of code. To customize the agent's behavior, we just write a descriptive system prompt, as shown below:
python
from langgraph.prebuilt import create_react_agent
system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step. Then you should query the schema of the most
relevant tables.
""".format(
    dialect=db.dialect,
    top_k=5,
)
agent = create_react_agent(
    llm,
    tools,
    prompt=system_prompt,
)
This system prompt contains a number of instructions, such as always running certain tools before or after others. In the next section we'll enforce these behaviors through the graph's structure, which gives us a higher degree of control and lets us simplify the prompt. Let's run the agent:
python
question = "Which sales agent made the most in sales in 2009?"
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
The output shows that the agent did what we wanted: it correctly listed the tables in the SQLite database, fetched the schemas, wrote a query, checked it, and ran it to inform its final response.
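If you'd rather verify this programmatically than read through the streamed output, here is a small optional sketch (not part of the original flow) that re-runs the agent and prints the tool calls it made:
python
# Re-run the agent and summarize which tools it called, in order
result = agent.invoke({"messages": [{"role": "user", "content": question}]})
for message in result["messages"]:
    for tool_call in getattr(message, "tool_calls", []) or []:
        print(tool_call["name"], tool_call["args"])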
3. Building a Custom Agent
The prebuilt agent lets us get started quickly, but at every step it has access to the full set of tools. Above, we relied on the system prompt to constrain its behavior: for example, we instructed the agent to always start with the "list tables" tool and to run every query through the query-checker tool first. With a custom agent we get a higher degree of control in LangGraph. Below, we implement a simple ReAct-style agent with dedicated nodes for specific tool calls, using the same state as the agent above. The rough steps are: list the DB tables, call the "get schema" tool, generate a query, check the query, and so on. Putting these steps into dedicated nodes lets us, first, force a tool call where needed and, second, customize the prompt associated with each step. Let's build the workflow. First, grab the database schema tool and the query execution tool that we'll wrap in tool nodes:
python
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
Create the corresponding tool nodes:
python
from langgraph.prebuilt import ToolNode

get_schema_node = ToolNode([get_schema_tool], name="get_schema")
run_query_node = ToolNode([run_query_tool], name="run_query")
Then predefine a tool call for listing the tables:
python
def list_tables(state: MessagesState):
    # Hard-code the tool call structure
    tool_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "abc123",
        "type": "tool_call",
    }
    # Execute the tool and build the response messages
    tool_message = list_tables_tool.invoke(tool_call)
    return {"messages": [..., response]}
Force the model to make a specific tool call:
python
def call_get_schema(state: MessagesState):
    # Bind a single tool and force the model to call it
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}
Sketch the system prompts for query generation and query checking:
python
generate_query_system_prompt = """
# Key points of the safe-query guidelines:
- Automatically limit the number of results (top_k=5)
- No DML statements
- Only query the columns that are needed
"""

check_query_system_prompt = """
# Expert-level SQL review rules:
- NULL handling
- UNION vs UNION ALL
- Range (BETWEEN) boundaries
- Data type matches in predicates
- Proper quoting of identifiers
"""
And sketch the workflow functions:
python
def generate_query(state):
    # Bind the tool but do not force a call, so the model can also answer directly
    llm_with_tools = llm.bind_tools([run_query_tool])
    ...

def check_query(state):
    # Check exactly the query arguments from the pending tool call
    tool_call = state["messages"][-1].tool_calls[0]
    ...
    # Keep the message id consistent with the original tool call
    response.id = state["messages"][-1].id
The overall flow is: 1. force listing of the tables → 2. fetch the table schemas → 3. generate a safe query → 4. run an expert-level check on the query → 5. execute the query. Here is the complete code for this part:
python
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")
get_schema_node = ToolNode([get_schema_tool], name="get_schema")
run_query_tool = next(tool for tool in tools if tool.name == "sql_db_query")
run_query_node = ToolNode([run_query_tool], name="run_query")
# Example: create a predetermined tool call
def list_tables(state: MessagesState):
    tool_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "abc123",
        "type": "tool_call",
    }
    tool_call_message = AIMessage(content="", tool_calls=[tool_call])

    list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
    tool_message = list_tables_tool.invoke(tool_call)
    response = AIMessage(f"Available tables: {tool_message.content}")

    return {"messages": [tool_call_message, tool_message, response]}
# Example: force a model to create a tool call
def call_get_schema(state: MessagesState):
    # Note that LangChain enforces that all models accept `tool_choice="any"`
    # as well as `tool_choice=<string name of tool>`.
    llm_with_tools = llm.bind_tools([get_schema_tool], tool_choice="any")
    response = llm_with_tools.invoke(state["messages"])

    return {"messages": [response]}
generate_query_system_prompt = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
""".format(
    dialect=db.dialect,
    top_k=5,
)
def generate_query(state: MessagesState):
    system_message = {
        "role": "system",
        "content": generate_query_system_prompt,
    }
    # We do not force a tool call here, to allow the model to
    # respond naturally when it obtains the solution.
    llm_with_tools = llm.bind_tools([run_query_tool])
    response = llm_with_tools.invoke([system_message] + state["messages"])

    return {"messages": [response]}
check_query_system_prompt = """
You are a SQL expert with a strong attention to detail.
Double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes,
just reproduce the original query.
You will call the appropriate tool to execute the query after running this check.
""".format(dialect=db.dialect)
def check_query(state: MessagesState):
    system_message = {
        "role": "system",
        "content": check_query_system_prompt,
    }

    # Generate an artificial user message to check
    tool_call = state["messages"][-1].tool_calls[0]
    user_message = {"role": "user", "content": tool_call["args"]["query"]}
    llm_with_tools = llm.bind_tools([run_query_tool], tool_choice="any")
    response = llm_with_tools.invoke([system_message, user_message])
    response.id = state["messages"][-1].id

    return {"messages": [response]}
Finally, we assemble these steps into the full workflow using the Graph API. At the query-generation step we define conditional routing: if a SQL query was generated (a tool call is present), we route to the query-check node; if no tool call was made (the LLM has already produced the final answer), we end the run. Note that check_query reuses the id of the original message: because MessagesState upserts messages by id, the checked (and possibly rewritten) query replaces the model's original tool call in state instead of being appended alongside it.
python
def should_continue(state: MessagesState) -> Literal[END, "check_query"]:
    messages = state["messages"]
    last_message = messages[-1]
    if not last_message.tool_calls:
        return END
    else:
        return "check_query"
builder = StateGraph(MessagesState)
builder.add_node(list_tables)
builder.add_node(call_get_schema)
builder.add_node(get_schema_node, "get_schema")
builder.add_node(generate_query)
builder.add_node(check_query)
builder.add_node(run_query_node, "run_query")
builder.add_edge(START, "list_tables")
builder.add_edge("list_tables", "call_get_schema")
builder.add_edge("call_get_schema", "get_schema")
builder.add_edge("get_schema", "generate_query")
builder.add_conditional_edges(
    "generate_query",
    should_continue,
)
builder.add_edge("check_query", "run_query")
builder.add_edge("run_query", "generate_query")
agent = builder.compile()
Let's visualize the application:
python
from IPython.display import Image, display
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles
display(Image(agent.get_graph().draw_mermaid_png()))
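The rendering above assumes an IPython/Jupyter environment. If you are running this as a plain script, one option (a small sketch, not part of the original code) is to write the diagram to a file instead:
python
# draw_mermaid_png() returns PNG bytes, so we can simply save them to disk
png_bytes = agent.get_graph().draw_mermaid_png()
with open("sql_agent_graph.png", "wb") as f:
    f.write(png_bytes)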
This renders the graph diagram. Now for the familiar part: we can invoke the graph exactly as before:
python
question = "Which sales agent made the most in sales in 2009?"
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
With that, our custom workflow agent is complete. It is more explicit than the prebuilt agent: this architecture keeps SQL operations safe while using LangGraph's state management to implement a complex, multi-step interaction.