Architecture
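At a high level, the steps of these systems are:
- Convert question to SQL query: the model converts the user input into a SQL query.
- Execute SQL query: the query is run against the database.
- Answer the question: the model responds to the user input using the query results.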
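The three steps can be illustrated without any LLM at all. The snippet below is a minimal sketch. It assumes hypothetical stub functions (`question_to_sql`, `answer_from_result`) in place of the two model calls, plus a toy in-memory `Employee` table. It only shows how data flows from question to query, from query to result, and from result to answer.

```python
import sqlite3


def question_to_sql(question: str) -> str:
    """Hypothetical stand-in for the model's text-to-SQL step."""
    # A real system would prompt an LLM with the question and the schema.
    return "SELECT COUNT(*) FROM Employee;"


def answer_from_result(question: str, rows: list) -> str:
    """Hypothetical stand-in for the model's answer-generation step."""
    return f"There are {rows[0][0]} employees."


def run_pipeline(conn: sqlite3.Connection, question: str) -> str:
    query = question_to_sql(question)          # 1. question -> SQL
    rows = conn.execute(query).fetchall()      # 2. run the SQL
    return answer_from_result(question, rows)  # 3. result -> answer


# Toy database with three employees (illustrative data only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)",
                 [("Adams",), ("Edwards",), ("Peacock",)])
print(run_pipeline(conn, "How many employees are there?"))  # There are 3 employees.
```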

Sample data
Download the sample data and load it into a local SQLite database file, Chinook.db:
```shell
curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db
```
Chinook.db is now in our directory, and we can interact with it using the SQLAlchemy-driven SQLDatabase class:
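```python
from langchain_community.utilities import SQLDatabase

# Chinook.db was created in the current directory by the command above;
# "sqlite:///" followed by a relative path is SQLAlchemy's URL form for a local file.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
```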
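For intuition about what SQLDatabase reports, roughly the same table listing can be read from SQLite's own catalog with the standard-library sqlite3 module. This is a sketch only: SQLDatabase goes through SQLAlchemy rather than running this exact query. The helper `list_table_names` and the two-table toy schema are assumptions for illustration.

```python
import sqlite3


def list_table_names(conn: sqlite3.Connection) -> list:
    """Hypothetical helper: user table names from sqlite_master, sorted."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' "
        "AND name NOT LIKE 'sqlite_%' ORDER BY name"  # skip SQLite's internal tables
    ).fetchall()
    return [name for (name,) in rows]


# Toy stand-in for a couple of Chinook tables (illustrative data only)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")
conn.executemany("INSERT INTO Artist (Name) VALUES (?)", [("AC/DC",), ("Accept",)])

print(list_table_names(conn))  # ['Album', 'Artist']
print(conn.execute("SELECT * FROM Artist LIMIT 10;").fetchall())  # [(1, 'AC/DC'), (2, 'Accept')]
```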
Chains
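A chain is a composition of predictable steps. In LangGraph, we can represent a chain as a simple sequence of nodes. Let's create a sequence of steps that, given a question, does the following:
- converts the question into a SQL query;
- executes the query;
- uses the result to answer the original question.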
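This arrangement does not support every scenario. For example, the system will execute a SQL query for any user input, even "hello". Importantly, as we'll see below, some questions require more than one query to answer. We'll address these scenarios in the Agents section.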
Application state
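The LangGraph state of our application controls what data is input to the application, passed between steps, and output by the application. It is typically a TypedDict, but can also be a Pydantic BaseModel.
For this application, we can simply track the input question, the generated query, the query result, and the generated answer: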
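```python
from typing_extensions import TypedDict


class State(TypedDict):
    question: str  # the user's input question
    query: str     # the generated SQL query
    result: str    # the query result, as a string
    answer: str    # the final natural-language answer
```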
Now we just need functions that operate on this state and fill in its contents.
Convert question to SQL query
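The first step is to convert the user input into a SQL query. We want to reliably obtain only the SQL query, without Markdown formatting or explanations. To do that, we describe the expected output with a Pydantic model and parse the model's reply with LangChain's PydanticOutputParser.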
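```python
# Local config module (not shown); assumed to set the API key and the
# OpenAI-compatible base URL that ChatOpenAI uses (see the tool output further down).
from config import *
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model='deepseek-v3')
```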
We'll pull a prompt from the Prompt Hub to instruct the model.
```python
from langchain import hub

query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()
```
=============================== System Message ================================
Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Only use the following tables:
{table_info}
Question: {input}
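This prompt has several parameters that we need to fill in, such as the SQL dialect and the table schemas. LangChain's SQLDatabase object provides methods for these (db.dialect and db.get_table_info()). Because the model's reply is parsed with PydanticOutputParser, we also prepend a system message containing the parser's format instructions. Our write_query step simply fills in these parameters and prompts the model to generate a SQL query: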
```python
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser


class QueryOutput(BaseModel):
    """Generated SQL query."""

    # Field(description=...) is the Pydantic way to document a field;
    # the description ends up in the parser's format instructions.
    query: str = Field(description="Syntactically valid SQL query.")


parser = PydanticOutputParser(pydantic_object=QueryOutput)

# Prepend the parser's JSON format instructions to the system prompt from the hub
query_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "{format_instructions}"),
        query_prompt_template,
    ]
).partial(format_instructions=parser.get_format_instructions())


def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    # The model returns JSON text; the parser validates it into a QueryOutput
    chain = llm | parser
    result = chain.invoke(prompt)
    return {"query": result.query}
```
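The parsing step can be made concrete with a simplified, standard-library approximation of what happens to the model's reply. It strips an optional Markdown code fence, decodes the JSON, and checks for a string `query` field. The dataclass and `parse_query_output` are hypothetical. The real PydanticOutputParser validates against the Pydantic schema and raises its own OutputParserException on failure.

```python
import json
import re
from dataclasses import dataclass


@dataclass
class QueryOutput:
    query: str


# Matches an optional ```json ... ``` (or bare ```) fence around the payload
_FENCE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def parse_query_output(text: str) -> QueryOutput:
    """Simplified stand-in for PydanticOutputParser.parse (an assumption, not its code)."""
    match = _FENCE.search(text)
    payload = match.group(1) if match else text.strip()
    data = json.loads(payload)  # raises json.JSONDecodeError (a ValueError) on bad JSON
    if not isinstance(data, dict) or not isinstance(data.get("query"), str):
        raise ValueError("model output has no string 'query' field")
    return QueryOutput(query=data["query"])


raw = '```json\n{"query": "SELECT COUNT(*) AS EmployeeCount FROM Employee;"}\n```'
print(parse_query_output(raw).query)  # SELECT COUNT(*) AS EmployeeCount FROM Employee;
```

Asking for JSON and parsing it, rather than asking for bare SQL, gives a well-defined failure mode. If the reply cannot be parsed, you get an exception instead of prose silently reaching the database.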
Let's look at the contents of query_prompt:
```python
for message in query_prompt.messages:
    message.pretty_print()
```
================================ System Message ================================
{format_instructions}
================================ System Message ================================
Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Only use the following tables:
{table_info}
Question: {input}
Let's test it:
```python
write_query({"question": "How many Employees are there?"})
```
{'query': 'SELECT COUNT(*) AS EmployeeCount FROM Employee;'}
Execute query
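This is the most dangerous part of creating a SQL chain. Think carefully about whether it is acceptable to run automated queries over your data. Minimize the database connection permissions as far as possible. Consider adding a human approval step to the chain before the query is executed (see below).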
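One concrete way to minimize permissions with SQLite is to open the database file read-only. The sketch below uses the standard-library sqlite3 module with a URI containing `mode=ro` on a throwaway database file (illustrative data). With that connection, write statements are rejected by SQLite itself rather than relying on the prompt. With SQLAlchemy's pysqlite dialect, a URL of the form `sqlite:///file:Chinook.db?mode=ro&uri=true` should give a similar effect. Treat that as an assumption and check the SQLAlchemy documentation for your version.

```python
import sqlite3
import tempfile
from pathlib import Path

# Build a small database file to demonstrate with (illustrative data)
path = Path(tempfile.mkdtemp()) / "demo.db"
rw = sqlite3.connect(path)
rw.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
rw.execute("INSERT INTO Employee (LastName) VALUES ('Adams')")
rw.commit()
rw.close()

# A file: URI with mode=ro opens the same file read-only
ro = sqlite3.connect(path.resolve().as_uri() + "?mode=ro", uri=True)
print(ro.execute("SELECT COUNT(*) FROM Employee").fetchone())  # (1,)

write_blocked = False
try:
    ro.execute("DELETE FROM Employee")
except sqlite3.Error as exc:
    # SQLite refuses writes on a read-only connection (typically an OperationalError)
    write_blocked = True
    print("write rejected:", exc)
print(write_blocked)  # True
```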
To execute the query, we'll load a tool from langchain-community. Our execute_query node simply wraps this tool:
```python
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}
```
Let's test it:
```python
execute_query({'query': 'SELECT COUNT(*) AS EmployeeCount FROM Employee;'})
```
{'result': '[(8,)]'}
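Note that the result is a string, `'[(8,)]'`, the printed form of the fetched rows. That string is what later goes into the answer prompt. The QuerySQLDatabaseTool description (printed in the Agents section) also says that an incorrect query returns an error message. Below is a standard-library sketch of that behaviour, using a hypothetical `run_query_as_text` helper on a toy table; it is not the tool's actual implementation.

```python
import sqlite3


def run_query_as_text(conn: sqlite3.Connection, query: str) -> str:
    """Hypothetical helper: return the fetched rows, or the error, as a string."""
    try:
        rows = conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        # Reporting the error as text instead of raising lets a model read it and retry
        return f"Error: {exc}"
    return str(rows)  # printed form of a list of row tuples


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])

print(run_query_as_text(conn, "SELECT COUNT(*) AS EmployeeCount FROM Employee;"))  # [(2,)]
print(run_query_as_text(conn, "SELECT Salary FROM Employee;"))  # Error: no such column: Salary
```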
Generate answer
The last step generates an answer to the question from the information retrieved from the database:
```python
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}
```
Orchestrating with LangGraph
Finally, we compile the application into a single graph object. In this case, we simply connect the three steps into one sequence.
```python
from langgraph.graph import START, StateGraph

graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()
```
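Conceptually, a graph built with add_sequence runs the nodes one after another. Each node returns only the keys it changes, and those keys are merged into the shared state (a key without a custom reducer is simply overwritten). The pure-Python sketch below models that behaviour with stub nodes. `run_sequence` and the `stub_*` functions are hypothetical. The sketch ignores what LangGraph adds on top, such as reducers, checkpointing and streaming.

```python
from typing import Callable, Dict, List

Node = Callable[[dict], dict]


def run_sequence(nodes: List[Node], inputs: dict) -> Dict:
    """Run nodes in order; each returns a partial update that is merged into the state."""
    state = dict(inputs)  # copy so the caller's dict is not mutated
    for node in nodes:
        state.update(node(state))  # a key written again simply overwrites the old value
    return state


# Stub nodes standing in for write_query / execute_query / generate_answer
def stub_write_query(state: dict) -> dict:
    return {"query": "SELECT COUNT(*) FROM Employee;"}


def stub_execute_query(state: dict) -> dict:
    return {"result": "[(8,)]"}


def stub_generate_answer(state: dict) -> dict:
    return {"answer": f"{state['question']} -> {state['result']}"}


final = run_sequence(
    [stub_write_query, stub_execute_query, stub_generate_answer],
    {"question": "How many employees are there?"},
)
print(final["answer"])  # How many employees are there? -> [(8,)]
```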
```python
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))
```

Let's test the application!
```python
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)
```
{'write_query': {'query': 'SELECT COUNT(*) AS NumberOfEmployees FROM Employee;'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are **8 employees** in total.'}}
Human-in-the-loop
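LangGraph supports a number of features that are useful for this workflow. One of them is human-in-the-loop: we can interrupt the application before sensitive steps (such as executing the SQL query) for human review. This is enabled by LangGraph's persistence layer, which saves run progress to your storage of choice. Below, we specify in-memory storage: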
```python
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])

# Now that we're using persistence, we need to specify a thread ID
# so that we can continue the run after review.
config = {"configurable": {"thread_id": "1"}}

display(Image(graph.get_graph().draw_mermaid_png()))
```
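To see what interrupt_before and the checkpointer contribute, here is a toy pure-Python model. The run stops before a named step and saves the state under a thread ID. A later call with `None` as input picks up from the saved point, much as `graph.stream(None, config)` does in the approval loop. `PausableSequence` is hypothetical and only mimics the observable behaviour; it does not model how MemorySaver actually stores checkpoints.

```python
from typing import Callable, Dict, List, Optional, Tuple

Step = Tuple[str, Callable[[dict], dict]]


class PausableSequence:
    """Toy model of interrupt_before plus an in-memory checkpointer (hypothetical)."""

    def __init__(self, steps: List[Step], interrupt_before: List[str]):
        self.steps = steps
        self.interrupt_before = set(interrupt_before)
        # thread_id -> (saved state, index of the next step to run)
        self.checkpoints: Dict[str, Tuple[dict, int]] = {}

    def run(self, inputs: Optional[dict], thread_id: str) -> dict:
        resuming = inputs is None
        if resuming:
            saved, start = self.checkpoints[thread_id]
            state = dict(saved)
        else:
            state, start = dict(inputs), 0
        for i in range(start, len(self.steps)):
            name, fn = self.steps[i]
            # Pause before flagged steps, except the one we are resuming into
            if name in self.interrupt_before and not (resuming and i == start):
                self.checkpoints[thread_id] = (dict(state), i)
                return state
            state.update(fn(state))
        self.checkpoints[thread_id] = (dict(state), len(self.steps))
        return state


toy_graph = PausableSequence(
    [
        ("write_query", lambda s: {"query": "SELECT COUNT(*) FROM Employee;"}),
        ("execute_query", lambda s: {"result": "[(8,)]"}),  # stubbed result
        ("generate_answer", lambda s: {"answer": f"Result: {s['result']}"}),
    ],
    interrupt_before=["execute_query"],
)

paused = toy_graph.run({"question": "How many employees are there?"}, thread_id="1")
print(sorted(paused))  # ['query', 'question'], stopped before execute_query

# After a human approves, resume the same thread by passing None as input
final = toy_graph.run(None, thread_id="1")
print(final["answer"])  # Result: [(8,)]
```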

Let's repeat the same run, adding a simple yes/no approval step:
```python
for step in graph.stream(
    {"question": "How many employees are there?"},
    config,
    stream_mode="updates",
):
    print(step)

try:
    user_approval = input("Do you want to execute the query? (yes/no): ")
except Exception:
    user_approval = "no"

if user_approval.strip().lower() == "yes":
    # If approved, continue the graph execution
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")
```
{'write_query': {'query': 'SELECT COUNT(*) AS EmployeeCount FROM Employee;'}}
{'__interrupt__': ()}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': 'There are **8 employees** in total.'}}
Agents
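Agents use the reasoning capabilities of large language models (LLMs) to make decisions during execution. Using an agent gives the model more discretion over query generation and execution. Although their behavior is less predictable than the "chain" above, agents have some advantages:
• They can query the database as many times as needed to answer the user's question.
• They can recover from errors: when a generated query fails, they can catch the traceback and regenerate a corrected query.
• They can answer questions based on the database schema as well as its content (for example, describing a specific table).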
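The error-recovery point can be illustrated with a small retry loop. The loop runs the generated SQL, and if the database rejects it, hands the error message back to the generator for another attempt. This is a sketch with a scripted stand-in for the model; `fake_model` and `query_with_retries` are hypothetical. In the real agent, the LLM itself decides when and how to retry.

```python
import sqlite3
from typing import Callable, Optional


def query_with_retries(
    conn: sqlite3.Connection,
    generate_sql: Callable[[Optional[str]], str],
    max_attempts: int = 3,
):
    """Run generated SQL; on failure, feed the error text back to the generator."""
    error: Optional[str] = None
    for _ in range(max_attempts):
        sql = generate_sql(error)  # the generator sees the previous error, if any
        try:
            return conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            error = f"{sql!r} failed: {exc}"
    raise RuntimeError(f"giving up after {max_attempts} attempts; last error: {error}")


def fake_model(previous_error: Optional[str]) -> str:
    """Hypothetical stand-in for the model: guesses a wrong column, then corrects itself."""
    if previous_error is None:
        return "SELECT Name FROM Customer;"
    return "SELECT FirstName FROM Customer;"


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, FirstName TEXT)")
conn.execute("INSERT INTO Customer (FirstName) VALUES ('Luís')")
print(query_with_retries(conn, fake_model))  # [('Luís',)]
```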
Below, we assemble a minimal SQL agent. SQLDatabaseToolkit bundles the database tools the agent can call.
```python
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools
```
[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11919bee0>),
InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11919bee0>),
ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11919bee0>),
QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11919bee0>, llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x11d442b80>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x11d4841f0>, root_client=<openai.OpenAI object at 0x11d3e05b0>, root_async_client=<openai.AsyncOpenAI object at 0x11d442bb0>, model_name='deepseek-v3', model_kwargs={}, openai_api_key=SecretStr('**********'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1'), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x11d442b80>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x11d4841f0>, root_client=<openai.OpenAI object at 0x11d3e05b0>, root_async_client=<openai.AsyncOpenAI object at 0x11d442bb0>, model_name='deepseek-v3', model_kwargs={}, openai_api_key=SecretStr('**********'), openai_api_base='https://dashscope.aliyuncs.com/compatible-mode/v1'), output_parser=StrOutputParser(), llm_kwargs={}))]
System prompt
We also need to load a system prompt for our agent. It contains instructions on how the agent should behave.
```python
from langchain import hub

prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

assert len(prompt_template.messages) == 1
```
Let's fill in the parameters in the prompt:
```python
system_message = prompt_template.format(dialect="SQLite", top_k=5)
print(system_message)
```
System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.
Initializing the agent
We'll use a prebuilt LangGraph agent to build our agent. Note that the code below switches the agent's model to qwen-max.
```python
from langgraph.prebuilt import create_react_agent

# The agent itself uses qwen-max; the toolkit's query-checker tool keeps
# the LLM it was created with (deepseek-v3 above).
llm = ChatOpenAI(model="qwen-max")
agent_executor = create_react_agent(llm, tools, prompt=system_message)

display(Image(agent_executor.get_graph().draw_mermaid_png()))
```
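create_react_agent wires the model and the tools into a loop. At each turn, the model either requests a tool call or produces a final answer, and each tool result is fed back before the next model call. The sketch below mimics that loop with a scripted stand-in for the model over a toy Invoice table. `scripted_model`, `react_loop` and `TOOLS` are hypothetical; only the tool names are borrowed from the toolkit above. Real tool calling exchanges structured messages rather than tuples.

```python
import sqlite3
from typing import Callable, Dict, List, Tuple

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, BillingCountry TEXT, Total REAL)")
conn.executemany("INSERT INTO Invoice (BillingCountry, Total) VALUES (?, ?)",
                 [("USA", 5.0), ("Canada", 3.0), ("USA", 2.0)])  # toy rows

# Tool registry: name -> function taking one string argument
TOOLS: Dict[str, Callable[[str], str]] = {
    "sql_db_list_tables": lambda _: ", ".join(
        r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")),
    "sql_db_query": lambda q: str(conn.execute(q).fetchall()),
}


def scripted_model(history: List[Tuple[str, str]]) -> Tuple[str, str]:
    """Hypothetical model: returns (tool name, argument) or ("final", answer text)."""
    if len(history) == 0:
        return ("sql_db_list_tables", "")
    if len(history) == 1:
        return ("sql_db_query", "SELECT BillingCountry, SUM(Total) AS s FROM Invoice "
                                "GROUP BY BillingCountry ORDER BY s DESC LIMIT 1")
    return ("final", f"Top country and total: {history[-1][1]}")


def react_loop(model, tools, max_steps: int = 5) -> str:
    history: List[Tuple[str, str]] = []  # (tool name, observation) pairs
    for _ in range(max_steps):
        action, arg = model(history)
        if action == "final":
            return arg
        observation = tools[action](arg)  # act, then record what was observed
        history.append((action, observation))
    raise RuntimeError("no final answer within max_steps")


print(react_loop(scripted_model, TOOLS))  # Top country and total: [('USA', 7.0)]
```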

```python
question = "Which country's customers spent the most?"

result = agent_executor.invoke({"messages": [{"role": "user", "content": question}]})
```
```python
for m in result["messages"]:
    m.pretty_print()
```
================================ Human Message =================================
Which country's customers spent the most?
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (call_11959469bb4c42ab8faaee)
Call ID: call_11959469bb4c42ab8faaee
Args:
tool_input:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_6f549cbdefa94a5e80152e)
Call ID: call_6f549cbdefa94a5e80152e
Args:
table_names: Customer,Invoice
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (call_a859a5ed305e42d2b9048f)
Call ID: call_a859a5ed305e42d2b9048f
Args:
query: SELECT Customer.Country, SUM(Invoice.Total) AS TotalSpent FROM Invoice JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Customer.Country ORDER BY TotalSpent DESC LIMIT 5;
================================= Tool Message =================================
Name: sql_db_query_checker
```sql
SELECT Customer.Country, SUM(Invoice.Total) AS TotalSpent
FROM Invoice
JOIN Customer ON Invoice.CustomerId = Customer.CustomerId
GROUP BY Customer.Country
ORDER BY TotalSpent DESC
LIMIT 5;
```
================================== Ai Message ==================================
Tool Calls:
sql_db_query (call_5b569709098b49dd9336e5)
Call ID: call_5b569709098b49dd9336e5
Args:
query: SELECT Customer.Country, SUM(Invoice.Total) AS TotalSpent FROM Invoice JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Customer.Country ORDER BY TotalSpent DESC LIMIT 5;
================================= Tool Message =================================
Name: sql_db_query
[('USA', 523.06), ('Canada', 303.96), ('France', 195.1), ('Brazil', 190.1), ('Germany', 156.48)]
================================== Ai Message ==================================
The countries whose customers spent the most, in descending order, are as follows:
1. USA - $523.06
2. Canada - $303.96
3. France - $195.10
4. Brazil - $190.10
5. Germany - $156.48
These are the top 5 countries with the highest spending customers.
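The agent executes multiple queries until it has the information it needs:
- it lists the available tables;
- it retrieves the schemas of two tables (Customer and Invoice);
- it checks its query with the query checker, then queries both tables with a join.
The agent then uses the result of the final query to generate an answer to the original question.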
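The final query joins Invoice to Customer, groups by country and sums the invoice totals. The same pattern can be tried on a tiny in-memory stand-in for the two tables. The rows are invented for illustration and are not Chinook data. The column type is REAL rather than Chinook's NUMERIC(10, 2) to keep the arithmetic simple.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
INSERT INTO Customer VALUES (1, 'Brazil'), (2, 'Germany'), (3, 'Germany');
INSERT INTO Invoice VALUES (1, 1, 4.0), (2, 2, 1.5), (3, 3, 3.0), (4, 1, 1.0);
""")

# Same shape as the agent's query: join, aggregate per country, rank by total spent
rows = conn.execute("""
    SELECT Customer.Country, SUM(Invoice.Total) AS TotalSpent
    FROM Invoice
    JOIN Customer ON Invoice.CustomerId = Customer.CustomerId
    GROUP BY Customer.Country
    ORDER BY TotalSpent DESC
    LIMIT 5;
""").fetchall()
print(rows)  # [('Brazil', 5.0), ('Germany', 4.5)]
```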
The agent can handle qualitative questions as well:
```python
question = "Describe the playlisttrack table"

result = agent_executor.invoke({"messages": [{"role": "user", "content": question}]})
```
```python
for m in result["messages"]:
    m.pretty_print()
```
================================ Human Message =================================
Describe the playlisttrack table
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (call_b9ead21107664c6a851c85)
Call ID: call_b9ead21107664c6a851c85
Args:
table_names: playlisttrack
================================= Tool Message =================================
Name: sql_db_schema
Error: table_names {'playlisttrack'} not found in database
================================== Ai Message ==================================
It seems that there was an error because the table 'playlisttrack' could not be found in the database. Let me first retrieve a list of the available tables to check if the correct table name is being used.
Tool Calls:
sql_db_list_tables (call_fc6f21ca229c4a84adb266)
Call ID: call_fc6f21ca229c4a84adb266
Args:
tool_input:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
The correct table name is indeed 'PlaylistTrack', and it is available in the database. Let me now retrieve the schema for the 'PlaylistTrack' table.
Tool Calls:
sql_db_schema (call_ee9bcfe4dfae45edb01325)
Call ID: call_ee9bcfe4dfae45edb01325
Args:
table_names: PlaylistTrack
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
================================== Ai Message ==================================
The `PlaylistTrack` table has the following schema:
- `PlaylistId` (INTEGER, NOT NULL): This is a foreign key that references the `Playlist` table and is part of the composite primary key.
- `TrackId` (INTEGER, NOT NULL): This is a foreign key that references the `Track` table and is also part of the composite primary key.
The primary key for this table is a combination of `PlaylistId` and `TrackId`, which means that each track can only appear once in a given playlist. Here are a few sample rows from the `PlaylistTrack` table to illustrate:
| PlaylistId | TrackId |
|------------|---------|
| 1 | 3402 |
| 1 | 3389 |
| 1 | 3390 |
If you need more specific information or a query based on this table, please let me know!
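In this transcript, the schema tool rejects `playlisttrack` because it does not exactly match the stored name `PlaylistTrack`. The agent recovers by listing the tables first. If that extra round trip is undesirable, one option is to resolve requested names case-insensitively before calling the tool. The helper below is a hypothetical sketch and is not part of the toolkit.

```python
from typing import Iterable, List


def resolve_table_names(requested: str, available: Iterable[str]) -> List[str]:
    """Hypothetical helper: map comma-separated table names onto their exact stored spelling."""
    by_lower = {name.lower(): name for name in available}
    resolved = []
    for raw in requested.split(","):
        name = raw.strip()
        if not name:
            continue
        try:
            resolved.append(by_lower[name.lower()])
        except KeyError:
            raise ValueError(f"unknown table: {name!r}") from None
    return resolved


tables = ("Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, "
          "MediaType, Playlist, PlaylistTrack, Track").split(", ")
print(resolve_table_names("playlisttrack", tables))  # ['PlaylistTrack']
```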