AI Fundamentals Notes, Part 23: Building an Agent That Can Query a Database

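1. Introduction

This article builds an agent that can query a database: you type a question, and the agent generates the SQL, runs it against the database, and returns the result. It implements four functions:

  • Generate SQL: produce a SQL query from the user's question
  • Execute SQL: run the SQL and fetch the result
  • Validate SQL: block statements that modify data, so that only queries are allowed
  • Fix SQL: if the SQL contains an error, attempt to correct it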

2. Preparation

2.1 Environment

Install the following third-party dependencies:

bash
pip install -U langchain langchain-ollama langchain-community langchain-core

If you run into errors, try upgrading pydantic:

bash
pip install --upgrade pydantic

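2.2 Data

Download the SQLite sample database from https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip, extract it, and place the database file at the path used by the connection code in section 3.2 (db/Chinook.db). For the tables and the relationships between them, see https://www.sqlitetutorial.net/wp-content/uploads/2018/03/sqlite-sample-database-diagram.pdf.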

3. Implementation

3.1 Importing dependencies

python
import re
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage

3.2 Connecting to the database and testing the connection

python
print("Setting up database...")

db = SQLDatabase.from_uri("sqlite:///db/Chinook.db")

# Check the connection and print basic info
try:
    # Test the connection by listing the table names
    tables = list(db.get_usable_table_names())
    print("Connected to db successfully!")
    print(f"Found {len(tables)} tables: [{','.join(tables)}]")
except Exception as e:
    print(f"Failed to connect to db: {e}")

# Get the schema of all tables
SCHEMA = db.get_table_info()
print("\nTable Schema:", SCHEMA)

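When run in Jupyter Notebook or VS Code, a successful setup prints the connection message, the table list, and the table schema.

You can list all tables in the database with the following SQL:

python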
db.run("select name from sqlite_master where type='table'")
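The same check can be run without LangChain. The following is a minimal sketch using Python's built-in sqlite3 module; the helper name `list_tables` and the two demo tables are illustrative assumptions, not part of the original code.

```python
import sqlite3

def list_tables(conn: sqlite3.Connection) -> list[str]:
    """Return the names of all tables, sorted alphabetically."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]

# Demo on an in-memory database; for the real file, use
# sqlite3.connect("db/Chinook.db") instead.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE albums (AlbumId INTEGER PRIMARY KEY, Title TEXT)")
conn.execute("CREATE TABLE artists (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
print(list_tables(conn))
```

Querying sqlite_master directly is handy for confirming that the database file itself is intact before debugging the LangChain layer.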

3.3 Setting up the LLM

python
######## Setting Up Ollama
base_url = "http://localhost:11434"
# Optionally pass temperature=0 for more deterministic output
llm = ChatOllama(model="qwen2.5", base_url=base_url)

# Smoke test: confirm that the model responds
response = llm.invoke(input="Hello, how are you?")
response.pretty_print()
print("Initialized Ollama")

3.4 Implementing the database tools

3.4.1 Getting the database schema

python
##### Step 3: SQL Tools

@tool
def get_database_schema(table_name: str | None = None) -> str:
    """Get the schema of a database for SQL query generation.
    Use this first to understand table structure before creating SQL queries."""

    print(f"Getting schema for {table_name if table_name else 'all tables'}")

    if not table_name:
        print("Retrieving schema for all tables")
        return SCHEMA

    try:
        # Match the requested name case-insensitively against the real table names
        tables = list(db.get_usable_table_names())
        match = next((t for t in tables if t.lower() == table_name.lower()), None)
        if match is None:
            return f"Error 404: Table {table_name} not found. Available tables: {','.join(tables)}"
        result = db.get_table_info([match])
        print(f"Retrieved schema for {match}")
        return result
    except Exception as e:
        return f"Error retrieving schema for {table_name}: {e}"

In Jupyter Notebook or VS Code, invoking the tool with a test input confirms that it works.
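To see what a single-table schema lookup involves underneath, here is a minimal sketch that uses only the built-in sqlite3 module. The helper name `table_schema` and the demo table are assumptions made for illustration; the tool above goes through LangChain's SQLDatabase instead.

```python
import sqlite3

def table_schema(conn: sqlite3.Connection, table_name: str) -> str:
    """Return the CREATE TABLE statement for table_name (case-insensitive)."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE type='table' AND lower(name) = lower(?)",
        (table_name,),
    ).fetchone()
    if row is None:
        return f"Error: table {table_name} not found"
    return row[0]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE albums (AlbumId INTEGER PRIMARY KEY, Title TEXT)")
print(table_schema(conn, "ALBUMS"))   # found despite the different case
print(table_schema(conn, "missing"))  # returns an error string
```

Returning an error string rather than raising mirrors the tool's behavior, so the model receives a message it can act on.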

3.4.2 Generating SQL from the question

python
@tool
def generate_sql_query(question: str, schema_info: str | None = None) -> str:
    """Generate a SQL SELECT query from a natural language question using database schema.
    Always use this after getting schema information."""
    print(f"Generating SQL for: {question[:100]}...")

    # Use the provided schema, or fall back to the full schema
    schema_to_use = schema_info if schema_info else SCHEMA
    prompt = f"""
        Based on this database schema:
            {schema_to_use}
        Generate a SQL query to answer this question: {question}
        Rules:
         - Use only SELECT statements
         - Include only existing columns and tables
         - Add appropriate WHERE, GROUP BY, ORDER BY clauses as needed
         - Limit results to 10 rows unless specified otherwise
         - Use proper SQL syntax for SQLite

        Return only the SQL query, nothing else
        """

    try:
        response = llm.invoke(prompt)
        query = response.content.strip()
        return query
    except Exception as e:
        # Return an error string so the agent always receives text it can act on
        print(f"Error: {e}")
        return f"Error: failed to generate SQL: {e}"

Test generate_sql_query in the notebook by invoking it with a sample question.
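Even when the prompt asks for the query only, a model may wrap its reply in a Markdown code fence. The following sketch shows one way to normalize such a reply; the helper name `strip_sql_fences` and the sample reply are hypothetical and not part of the original code.

```python
import re

def strip_sql_fences(text: str) -> str:
    """Remove Markdown code-fence markers and a trailing semicolon."""
    cleaned = re.sub(r"```(?:sql)?\s*", "", text.strip(), flags=re.IGNORECASE)
    return cleaned.strip().rstrip(";").strip()

# A reply in the shape a chat model might produce
reply = "```sql\nSELECT Name FROM artists LIMIT 10;\n```"
print(strip_sql_fences(reply))
```

Putting this cleanup in one helper keeps it consistent across every tool that receives model output.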

3.4.3 Validating the generated SQL

python
@tool
def validate_sql_query(query: str) -> str:
    """Validate SQL query for safety and syntax before execution.
    Returns 'Valid:<query>' if safe or 'Error: <message>' if unsafe."""
    print(f"Validating SQL: {query[:100]}...")
    # Clean up the query
    clean_query = query.strip()
    # Remove SQL code block markers if present
    clean_query = re.sub(r'```sql\s*', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'```\s*', '', clean_query)
    clean_query = clean_query.strip().rstrip(";")

    # Check 1: Must be a SELECT statement
    if not clean_query.lower().startswith("select"):
        return "Error: Only SELECT statements are allowed"
    # Check 2: Block dangerous SQL keywords, matched as whole words so that
    # a column name that merely contains a keyword is not rejected
    dangerous_keywords = ['INSERT', 'UPDATE', 'DELETE', 'ALTER', 'DROP', 'CREATE', 'TRUNCATE']
    query_upper = clean_query.upper()
    for keyword in dangerous_keywords:
        if re.search(rf"\b{keyword}\b", query_upper):
            return f"Error: {keyword} operations are not allowed"
    print("√ Query validation passed")
    return f"Valid:{clean_query}"

Test validate_sql_query in the notebook by passing it both a safe query and an unsafe one.
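A keyword check like this is easiest to reason about with a few concrete inputs. Below is a standalone sketch of the same idea without the @tool decorator; the function name `is_read_only`, the keyword tuple, and the sample queries (including the UpdatedAt column) are assumptions made for illustration.

```python
import re

FORBIDDEN = ("INSERT", "UPDATE", "DELETE", "ALTER", "DROP", "CREATE", "TRUNCATE")

def is_read_only(query: str) -> bool:
    """Return True if query is a single SELECT with no forbidden keywords."""
    q = query.strip().rstrip(";").strip()
    if not q.upper().startswith("SELECT"):
        return False
    if ";" in q:  # conservatively reject stacked statements
        return False
    # \b matches whole words only, so "UpdatedAt" does not count as UPDATE
    return not any(re.search(rf"\b{kw}\b", q, re.IGNORECASE) for kw in FORBIDDEN)

print(is_read_only("SELECT * FROM albums"))         # True
print(is_read_only("SELECT UpdatedAt FROM logs"))   # True
print(is_read_only("SELECT 1; DROP TABLE albums"))  # False
print(is_read_only("DELETE FROM albums"))           # False
```

The check is deliberately conservative: a forbidden word or a semicolon inside a string literal is also rejected, which is a reasonable trade-off for an agent that should only read.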

3.4.4 Executing the generated SQL

python
@tool
def execute_sql_query(query: str) -> str:
    """Execute an SQL query and return the results.
    Only use this after verifying the query for safety.
    """
    print(f"Executing SQL: {query[:100]}...")
    try:
        clean_query = query.strip()
        # Drop the "Valid:" prefix added by validate_sql_query
        if clean_query.startswith("Valid:"):
            clean_query = clean_query[len("Valid:"):]

        # Remove SQL code block markers and the trailing semicolon
        clean_query = re.sub(r'```sql\s*', '', clean_query, flags=re.IGNORECASE)
        clean_query = re.sub(r'```\s*', '', clean_query)
        clean_query = clean_query.strip().rstrip(";")

        # Execute the query
        result = db.run(clean_query)
        print("Query executed successfully.")
        if result:
            return f"Query result:\n{result}"
        else:
            return "Query executed successfully, but no results were found."
    except Exception as e:
        return f"Error: {e}"

Test execute_sql_query in the notebook by passing it a validated query.

3.4.5 Fixing a failed SQL query

python
@tool
def fix_sql_error(original_query: str, error_message: str, question: str) -> str:
    """Fix a failed SQL query by analyzing the error and generating a corrected version.
    Use this when validation or execution fails."""

    print(f"Fixing SQL error: {error_message[:100]}...")
    fix_prompt = f"""
            The following SQL query failed:
            Query: {original_query}
            Error: {error_message}
            Original Question: {question}
            Database Schema:
            {SCHEMA}
            Analyze the error and provide a corrected SQL query that:
            1. Fixes the specific error mentioned
            2. Still answers the original question
            3. Uses only valid table and column names from the schema
            4. Follows SQLite syntax rules
            Return only the corrected SQL query, nothing else.
            """
    try:
        response = llm.invoke(fix_prompt)
        fixed_query = response.content.strip()
        print(f"Fixed query: {fixed_query[:100]}...")
        return fixed_query
    except Exception as e:
        # Return an error string instead of an implicit None
        print(f"Failed to fix query: {e}")
        return f"Error: failed to fix query: {e}"

Test fix_sql_error in the notebook by passing it a failing query, its error message, and the original question.
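In this article the agent itself decides when to call the fix tool. To make the execute, fix, and retry cycle explicit, here is a plain-Python sketch under stated assumptions: `run_with_retries` is a hypothetical helper, and `fake_fix` stands in for the LLM call by correcting one known typo.

```python
import sqlite3
from typing import Callable

def run_with_retries(
    conn: sqlite3.Connection,
    query: str,
    fix: Callable[[str, str], str],
    max_attempts: int = 3,
) -> tuple[bool, object]:
    """Run query; on failure, ask fix(query, error) for a new query and retry."""
    last_error = ""
    for _ in range(max_attempts):
        try:
            return True, conn.execute(query).fetchall()
        except sqlite3.Error as exc:
            last_error = str(exc)
            query = fix(query, last_error)
    return False, last_error

# Stand-in for the LLM: corrects one known typo in the table name
def fake_fix(query: str, error: str) -> str:
    return query.replace("albms", "albums")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE albums (Title TEXT)")
conn.execute("INSERT INTO albums VALUES ('Demo')")
print(run_with_retries(conn, "SELECT Title FROM albms", fake_fix))
```

Capping the number of attempts keeps a query that cannot be repaired from looping forever, and the last error message is returned so it can be explained to the user.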

3.5 Defining a unified system prompt

python
###### System Prompt
SQL_SYSTEM_PROMPT = f"""You are an expert SQL analyst working with the Chinook sample database.
        Database Schema:
        {SCHEMA}

        Your workflow for answering questions:
        1. Use `get_database_schema` first to understand available tables and columns (if needed)
        2. Use `generate_sql_query` to create SQL based on the question
        3. Use `validate_sql_query` to check the query for safety and syntax
        4. Use `execute_sql_query` to run the validated query
        5. If there's an error, use `fix_sql_error` to correct it and try again (up to 3 times)
        6. Provide a clear answer based on the query results

        Rules:
         - Always follow the workflow step by step
         - If a query fails, use the fix tool and try again
         - Provide clear, informative answers
         - Be precise with table and column names
         - Handle errors gracefully and try to fix them
         - If you fail after 3 attempts, explain what went wrong

        Available tools for each step:
        get_database_schema: Get table structure info
        generate_sql_query: Create SQL from question
        validate_sql_query: Check query safety/syntax
        execute_sql_query: Run the query
        fix_sql_error: Fix failed queries

        Remember: Always validate queries before executing them for safety.
"""

3.6 Creating an agent that can execute SQL

python
### Create an Agent

tools = [
    get_database_schema,
    generate_sql_query,
    validate_sql_query,
    execute_sql_query,
    fix_sql_error
]

sql_agent = create_agent(
    model=llm,
    tools=tools,
    system_prompt=SQL_SYSTEM_PROMPT
)

In Jupyter Notebook or VS Code, displaying sql_agent shows the structure of the agent.

3.7 Wrapping everything in a single interface

python
###### Query Functions ###############

def ask_question(question):
    """Ask the SQL agent a question using the full workflow."""
    print(f"\n{'#'*60}")
    print(f"SQL AGENT - Question: {question}")
    print('#'*60)
    inputs = {"messages": [{"role": "user", "content": question}]}
    for event in sql_agent.stream(inputs, stream_mode="values"):
        msg = event["messages"][-1]

        # Show tool usage
        if getattr(msg, "tool_calls", None):
            for tc in msg.tool_calls:
                print(f"\nUsing tool: {tc['name']}")
                print(f"Args: {str(tc['args'])[:200]}")
        # Show the model's answer (skip the user's question and raw tool output)
        elif msg.type == "ai" and msg.content:
            print(f"\nAnswer is:\n{msg.content}")

In Jupyter Notebook or VS Code, try a test question:

python
ask_question("What is the average UnitPrice in invoice_items")

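The output shows that the SQL executed successfully and that the agent returned the average UnitPrice of the invoice items.

To verify the result, run "SELECT AVG(UnitPrice) FROM invoice_items" directly against the database.

The agent's answer matches the actual value.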
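The direct check can also be scripted with the built-in sqlite3 module. The sketch below runs the same query against an in-memory stand-in table whose three prices are made-up demo values, not Chinook data; pointing the connection at db/Chinook.db would run it on the real table.

```python
import sqlite3

# Stand-in table; for the real check, use sqlite3.connect("db/Chinook.db")
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE invoice_items (UnitPrice REAL)")
conn.executemany(
    "INSERT INTO invoice_items VALUES (?)", [(0.99,), (0.99,), (1.99,)]
)

(avg_price,) = conn.execute(
    "SELECT AVG(UnitPrice) FROM invoice_items"
).fetchone()
print(round(avg_price, 4))
```

Comparing this number with the agent's answer is a quick regression check to rerun whenever the prompt or the model changes.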
