文章目录
- 前言
- [一、 安装依赖包](#一、 安装依赖包)
- [二、 设置数据库连接](#二、 设置数据库连接)
- [三、 扫描数据库结构](#三、 扫描数据库结构)
- [四、 生成 SQL 查询](#四、 生成 SQL 查询)
- [五、 执行 SQL 查询](#五、 执行 SQL 查询)
- [六、 运行示例](#六、 运行示例)
- [七、 封装成类](#七、 封装成类)
- 总结
前言
前一篇文章中,我们一起写了一个agent,为了简化代码是直接传递sql的,这一篇文章我们将通过大模型根据我们的自然语言生成sql,然后再通过agent查询数据并交给大模型思考得出结果。
一、 安装依赖包
首先,我们需要安装必要的 Python 包。我们将使用 langchain
和 SQLAlchemy
进行数据库连接和查询生成。
bash
pip install langchain sqlalchemy
二、 设置数据库连接
我们将以 SQLite 数据库为例,展示如何设置数据库连接并创建一个示例表。
python
from sqlalchemy import create_engine
# 设置数据库连接
engine = create_engine('sqlite:///example.db')
# 创建一个示例表并插入一些数据
with engine.connect() as connection:
connection.execute("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT,
age INTEGER
)
""")
connection.execute("""
INSERT INTO users (name, age) VALUES
('Alice', 30),
('Bob', 25),
('Charlie', 35)
""")
三、 扫描数据库结构
为了让语言模型生成正确的 SQL 查询,我们需要提供数据库的结构信息(表名和列名)。我们将使用 SQLAlchemy 的 inspect
模块来扫描数据库结构。
python
from sqlalchemy import inspect
def inspect_db_structure(engine):
inspector = inspect(engine)
structure = {}
for table_name in inspector.get_table_names():
columns = inspector.get_columns(table_name)
structure[table_name] = [column['name'] for column in columns]
return structure
db_structure = inspect_db_structure(engine)
四、 生成 SQL 查询
我们将使用 LangChain 的大语言模型(LLM)来生成 SQL 查询。为此,我们需要定义一个提示模板,并将用户的自然语言请求和数据库结构信息传递给模型。
python
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# 初始化LangChain的LLM
llm = OpenAI(api_key="YOUR_OPENAI_API_KEY")
# 定义生成SQL查询的提示模板
sql_generation_prompt = PromptTemplate(
template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n"
"User request: {request}\n"
"Database structure: {db_structure}\n"
"SQL query:",
input_variables=["request", "db_structure"]
)
def generate_sql_query(request, db_structure):
query = llm({
"prompt": sql_generation_prompt.format(
request=request,
db_structure=db_structure
)
})
return query["choices"][0]["text"].strip()
五、 执行 SQL 查询
定义执行 SQL 查询的工具函数,并使用 LangChain 初始化 Agent 来执行查询。
python
from langchain.agents import initialize_agent, Tool
# 创建数据库会话
Session = sessionmaker(bind=engine)
session = Session()
# 定义执行SQL查询的工具函数
def execute_sql_query(query):
try:
result = session.execute(query)
return result.fetchall()
except Exception as e:
return str(e)
# 定义LangChain的工具
sql_tool = Tool(
name="SQL Executor",
func=execute_sql_query,
description="Executes SQL queries and returns the result"
)
# 创建自定义Agent
agent = initialize_agent(
tools=[sql_tool],
llm=llm,
agent_type="zero_shot",
prompt_template=PromptTemplate(
template="You are an SQL agent. Execute the following SQL query: {query}",
input_variables=["query"]
)
)
六、 运行示例
结合以上所有步骤,使用自定义 Agent 自动生成 SQL 查询并执行:
python
def main():
user_request = "Find all users older than 30"
sql_query = generate_sql_query(user_request, db_structure)
print(f"Generated SQL Query: {sql_query}")
result = agent({"query": sql_query})
print(result)
if __name__ == "__main__":
main()
七、 封装成类
我们将上述功能封装到一个类中。
封装到一个类中有许多好处,包括模块化、可重用性、扩展性、简化复杂性和增强可维护性。通过封装,我们可以将复杂的逻辑抽象出来,使得代码更容易理解和维护,并且可以在不同的项目或不同的部分中重复使用。以下是封装后的完整代码:
python
class SQLAgent:
def __init__(self, database_url, api_key):
# 设置数据库连接
self.engine = create_engine(database_url)
self.Session = sessionmaker(bind=self.engine)
self.session = self.Session()
# 扫描数据库结构
self.db_structure = self.inspect_db_structure()
# 初始化LangChain的LLM
self.llm = OpenAI(api_key=api_key)
# 定义执行SQL查询的工具函数
def execute_sql_query(query):
try:
result = self.session.execute(query)
return result.fetchall()
except Exception as e:
return str(e)
# 定义LangChain的工具
sql_tool = Tool(
name="SQL Executor",
func=execute_sql_query,
description="Executes SQL queries and returns the result"
)
# 定义生成SQL查询的提示模板
self.sql_generation_prompt = PromptTemplate(
template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n"
"User request: {request}\n"
"Database structure: {db_structure}\n"
"SQL query:",
input_variables=["request", "db_structure"]
)
# 创建自定义Agent
self.agent = initialize_agent(
tools=[sql_tool],
llm=self.llm,
agent_type="zero_shot",
prompt_template=PromptTemplate(
template="You are an SQL agent. Execute the following SQL query: {query}",
input_variables=["query"]
)
)
def inspect_db_structure(self):
inspector = inspect(self.engine)
structure = {}
for table_name in inspector.get_table_names():
columns = inspector.get_columns(table_name)
structure[table_name] = [column['name'] for column in columns]
return structure
def generate_sql_query(self, request):
query = self.llm({
"prompt": self.sql_generation_prompt.format(
request=request,
db_structure=self.db_structure
)
})
return query["choices"][0]["text"].strip()
def execute(self, user_request):
sql_query = self.generate_sql_query(user_request)
print(f"Generated SQL Query: {sql_query}")
result = self.agent({"query": sql_query})
return result
# 使用示例
if __name__ == "__main__":
database_url = 'sqlite:///example.db'
api_key = 'YOUR_OPENAI_API_KEY'
sql_agent = SQLAgent(database_url, api_key)
user_request = "Find all users older than 30"
result = sql_agent.execute(user_request)
print(result)
总结
通过本文的示例,我们展示了如何使用 LangChain 和 SQLAlchemy 创建一个自定义的 SQL 查询 Agent。生产的步骤主要就是:扫描数据库结构 -> 生成 SQL 查询
-> 执行 SQL 查询。
大家可以使用自己喜欢的大模型来测试和练习