如何使用LangChain自定义agent的制作(2) - 让大模型帮我们生成sql

文章目录

  • 前言
  • [一、 安装依赖包](#一、 安装依赖包)
  • [二、 设置数据库连接](#二、 设置数据库连接)
  • [三、 扫描数据库结构](#三、 扫描数据库结构)
  • [四、 生成 SQL 查询](#四、 生成 SQL 查询)
  • [五、 执行 SQL 查询](#五、 执行 SQL 查询)
  • [六、 运行示例](#六、 运行示例)
  • [七、 封装成类](#七、 封装成类)
  • 总结

前言

前一篇文章中,我们一起写了一个agent,为了简化代码是直接传递sql的,这一篇文章我们将通过大模型根据我们的自然语言生成sql,然后再通过agent查询数据并交给大模型思考得出结果。


一、 安装依赖包

首先,我们需要安装必要的 Python 包。我们将使用 langchainSQLAlchemy 进行数据库连接和查询生成。

bash 复制代码
pip install langchain sqlalchemy

二、 设置数据库连接

我们将以 SQLite 数据库为例,展示如何设置数据库连接并创建一个示例表。

python 复制代码
from sqlalchemy import create_engine

# 设置数据库连接
engine = create_engine('sqlite:///example.db')

# 创建一个示例表并插入一些数据
with engine.connect() as connection:
    connection.execute("""
    CREATE TABLE IF NOT EXISTS users (
        id INTEGER PRIMARY KEY,
        name TEXT,
        age INTEGER
    )
    """)
    connection.execute("""
    INSERT INTO users (name, age) VALUES
    ('Alice', 30),
    ('Bob', 25),
    ('Charlie', 35)
    """)

三、 扫描数据库结构

为了让语言模型生成正确的 SQL 查询,我们需要提供数据库的结构信息(表名和列名)。我们将使用 SQLAlchemy 的 inspect 模块来扫描数据库结构。

python 复制代码
from sqlalchemy import inspect

def inspect_db_structure(engine):
    inspector = inspect(engine)
    structure = {}
    for table_name in inspector.get_table_names():
        columns = inspector.get_columns(table_name)
        structure[table_name] = [column['name'] for column in columns]
    return structure

db_structure = inspect_db_structure(engine)

四、 生成 SQL 查询

我们将使用 LangChain 的大语言模型(LLM)来生成 SQL 查询。为此,我们需要定义一个提示模板,并将用户的自然语言请求和数据库结构信息传递给模型。

python 复制代码
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

# 初始化LangChain的LLM
llm = OpenAI(api_key="YOUR_OPENAI_API_KEY")

# 定义生成SQL查询的提示模板
sql_generation_prompt = PromptTemplate(
    template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n"
             "User request: {request}\n"
             "Database structure: {db_structure}\n"
             "SQL query:",
    input_variables=["request", "db_structure"]
)

def generate_sql_query(request, db_structure):
    query = llm({
        "prompt": sql_generation_prompt.format(
            request=request,
            db_structure=db_structure
        )
    })
    return query["choices"][0]["text"].strip()

五、 执行 SQL 查询

定义执行 SQL 查询的工具函数,并使用 LangChain 初始化 Agent 来执行查询。

python 复制代码
from langchain.agents import initialize_agent, Tool

# 创建数据库会话
Session = sessionmaker(bind=engine)
session = Session()

# 定义执行SQL查询的工具函数
def execute_sql_query(query):
    try:
        result = session.execute(query)
        return result.fetchall()
    except Exception as e:
        return str(e)

# 定义LangChain的工具
sql_tool = Tool(
    name="SQL Executor",
    func=execute_sql_query,
    description="Executes SQL queries and returns the result"
)

# 创建自定义Agent
agent = initialize_agent(
    tools=[sql_tool],
    llm=llm,
    agent_type="zero_shot",
    prompt_template=PromptTemplate(
        template="You are an SQL agent. Execute the following SQL query: {query}",
        input_variables=["query"]
    )
)

六、 运行示例

结合以上所有步骤,使用自定义 Agent 自动生成 SQL 查询并执行:

python 复制代码
def main():
    user_request = "Find all users older than 30"
    sql_query = generate_sql_query(user_request, db_structure)
    print(f"Generated SQL Query: {sql_query}")
    
    result = agent({"query": sql_query})
    print(result)

if __name__ == "__main__":
    main()

七、 封装成类

我们将上述功能封装到一个类中。

封装到一个类中有许多好处,包括模块化、可重用性、扩展性、简化复杂性和增强可维护性。通过封装,我们可以将复杂的逻辑抽象出来,使得代码更容易理解和维护,并且可以在不同的项目或不同的部分中重复使用。以下是封装后的完整代码:

python 复制代码
class SQLAgent:
    def __init__(self, database_url, api_key):
        # 设置数据库连接
        self.engine = create_engine(database_url)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()

        # 扫描数据库结构
        self.db_structure = self.inspect_db_structure()

        # 初始化LangChain的LLM
        self.llm = OpenAI(api_key=api_key)

        # 定义执行SQL查询的工具函数
        def execute_sql_query(query):
            try:
                result = self.session.execute(query)
                return result.fetchall()
            except Exception as e:
                return str(e)

        # 定义LangChain的工具
        sql_tool = Tool(
            name="SQL Executor",
            func=execute_sql_query,
            description="Executes SQL queries and returns the result"
        )

        # 定义生成SQL查询的提示模板
        self.sql_generation_prompt = PromptTemplate(
            template="You are an AI assistant. Given the following user request and the database structure, generate a SQL query.\n"
                     "User request: {request}\n"
                     "Database structure: {db_structure}\n"
                     "SQL query:",
            input_variables=["request", "db_structure"]
        )

        # 创建自定义Agent
        self.agent = initialize_agent(
            tools=[sql_tool],
            llm=self.llm,
            agent_type="zero_shot",
            prompt_template=PromptTemplate(
                template="You are an SQL agent. Execute the following SQL query: {query}",
                input_variables=["query"]
            )
        )

    def inspect_db_structure(self):
        inspector = inspect(self.engine)
        structure = {}
        for table_name in inspector.get_table_names():
            columns = inspector.get_columns(table_name)
            structure[table_name] = [column['name'] for column in columns]
        return structure

    def generate_sql_query(self, request):
        query = self.llm({
            "prompt": self.sql_generation_prompt.format(
                request=request,
                db_structure=self.db_structure
            )
        })
        return query["choices"][0]["text"].strip()

    def execute(self, user_request):
        sql_query = self.generate_sql_query(user_request)
        print(f"Generated SQL Query: {sql_query}")
        
        result = self.agent({"query": sql_query})
        return result

# 使用示例
if __name__ == "__main__":
    database_url = 'sqlite:///example.db'
    api_key = 'YOUR_OPENAI_API_KEY'
    sql_agent = SQLAgent(database_url, api_key)
    
    user_request = "Find all users older than 30"
    result = sql_agent.execute(user_request)
    print(result)

总结

通过本文的示例,我们展示了如何使用 LangChain 和 SQLAlchemy 创建一个自定义的 SQL 查询 Agent。生产的步骤主要就是:扫描数据库结构 -> 生成 SQL 查询

-> 执行 SQL 查询。

大家可以使用自己喜欢的大模型来测试和练习

相关推荐
古希腊掌管学习的神28 分钟前
[搜广推]王树森推荐系统——矩阵补充&最近邻查找
python·算法·机器学习·矩阵
LucianaiB1 小时前
探索CSDN博客数据:使用Python爬虫技术
开发语言·爬虫·python
PieroPc4 小时前
Python 写的 智慧记 进销存 辅助 程序 导入导出 excel 可打印
开发语言·python·excel
梧桐树04297 小时前
python常用内建模块:collections
python
Dream_Snowar8 小时前
速通Python 第三节
开发语言·python
蓝天星空9 小时前
Python调用open ai接口
人工智能·python
jasmine s9 小时前
Pandas
开发语言·python
郭wes代码9 小时前
Cmd命令大全(万字详细版)
python·算法·小程序
leaf_leaves_leaf9 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零19 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志