Building an NL2SQL Application with LangChain

Preface

For a brief introduction to LangGraph, see this blog post:

LangGraph开发Agent智能体应用【基础聊天机器人】 (Building Agent Applications with LangGraph: Basic Chatbot), CSDN blog

For Comparison: NL2SQL with LangChain

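A complete code implementation of an NL2SQL agent application built with LangChain is provided in this blog post:

LangChain开发LLM应用【入门指南】 (Building LLM Applications with LangChain: A Beginner's Guide), langchain developer community, CSDN blog

I will restate it here: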

from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase
 
# Database connection settings
username = 'root'
password = 'MyNewPass1!'
host = 'desk04v.mlprod.bjpdc.qihoo.net'
port = '3306'
database = 'test'
 
engine = create_engine(f'mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}')
db = SQLDatabase(engine)
result = db.run("select * FROM courses LIMIT 5;")
print(result)
 
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
 
llm = ChatOpenAI(model="gpt-4o", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
 
agent_executor.invoke(
    "Find the course with the most credits"
)

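The result looks like this:

I did not define an execution chain in the code. I only gave the LLM a toolset and let the agent decide for itself how to use it.

As you can see, the agent first listed the tables in the database, picked the ones that looked useful, inspected their schema and a preview of the data, then generated the SQL and executed it (validating it once before running it), and finally combined the results into a conclusion for me.

That is already quite intelligent.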

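Implementing NL2SQL with LangGraph

The LangGraph approach differs from LangChain's: you develop by repeatedly adding "nodes" and "edges" to a graph until they form a workflow.

Note: the workflow here is not simply a sequence of operations. A LangGraph workflow and a LangChain workflow do not sit at the same level of abstraction; you should get a feel for the difference once you have read through this example.

As usual, code first!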

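Step 1: Define the toolset

LangChain and LangGraph are interoperable (more precisely, LangGraph is an advanced framework within the LangChain ecosystem),

so we can use LangChain's SQLDatabaseToolkit directly.

If you dig into the source code, you will find that this toolkit contains four tools (the names in parentheses are the tool names that appear in the run log later in this post):

Execute SQL: QuerySQLDataBaseTool (sql_db_query)

Show table details: InfoSQLDatabaseTool (sql_db_schema)

Check SQL syntax: QuerySQLCheckerTool (sql_db_query_checker)

List all tables: ListSQLDatabaseTool (sql_db_list_tables)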
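To make the division of labor among these four tools concrete, here is a minimal stand-in built on an in-memory SQLite database with only the standard library. It is not LangChain's implementation: the function names, the `courses` table, and its rows are assumptions for illustration. In particular, the "checker" below merely asks SQLite to plan the query with EXPLAIN, whereas the toolkit's checker hands the SQL to the LLM for review.

```python
import sqlite3

# Hypothetical sample data; the article itself queries a MySQL database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE courses (id INTEGER PRIMARY KEY, name TEXT, credits INTEGER)")
conn.executemany(
    "INSERT INTO courses (name, credits) VALUES (?, ?)",
    [("Algebra", 4), ("Databases", 5), ("History", 2)],
)


def list_tables() -> str:
    """Role of sql_db_list_tables: return a comma-separated list of table names."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_info(table: str, sample_rows: int = 3) -> str:
    """Role of sql_db_schema: return the CREATE statement plus a few sample rows."""
    (ddl,) = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    rows = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,)).fetchall()
    return f"{ddl}\n/* sample rows: {rows} */"


def check_query(query: str) -> str:
    """Stand-in for sql_db_query_checker: let SQLite plan the query without running it."""
    try:
        conn.execute(f"EXPLAIN {query}")
        return query
    except sqlite3.Error as exc:
        return f"Error: {exc}"


def run_query(query: str) -> str:
    """Role of sql_db_query: execute the SQL and return the rows or an error message."""
    try:
        return str(conn.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


# Same order the agent followed earlier: list tables, inspect schema, check, execute.
print(list_tables())
print(table_info("courses"))
sql = "SELECT name, credits FROM courses ORDER BY credits DESC LIMIT 1"
print(check_query(sql))
print(run_query(sql))
```

Returning errors as strings instead of raising mirrors how the agent is expected to read a failure message and retry with a rewritten query.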

from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

# Database connection settings
username = 'root'
password = 'MyNewPass1!'
host = 'desk04v.mlprod.bjpdc.qihoo.net'
port = '3306'
database = 'test'

engine = create_engine(f'mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}')
db = SQLDatabase(engine)
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(temperature=0))
context = toolkit.get_context()
tools = toolkit.get_tools()

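Step 2: Define the LLM node and add it to the graph

Bind the tools to the LLM. This step is mandatory: the LLM can only produce a tool-calling plan once it has been told which tools are available.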
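As a rough picture of what binding amounts to, the sketch below builds a tool description in the OpenAI function-calling format and attaches it to a request payload, then shows the shape of a tool call that the tool node in this post reads (`name`, `args`, `id`). The helper `to_openai_tool`, the shortened description, the `table_names` parameter, and the `call_1` id are illustrative assumptions; the actual conversion is done inside the library.

```python
import json


def to_openai_tool(name: str, description: str, parameters: dict) -> dict:
    """Wrap a tool description in the OpenAI function-calling "tools" format."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,  # JSON Schema describing the arguments
        },
    }


schema_tool = to_openai_tool(
    name="sql_db_schema",
    description="Input is a comma-separated list of tables; output is their schema and sample rows.",
    parameters={
        "type": "object",
        "properties": {"table_names": {"type": "string"}},  # assumed argument name
        "required": ["table_names"],
    },
)

# The tool descriptions travel with the chat messages on each request,
# which is how the model learns what it is allowed to call.
request_payload = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "Find the course with the most credits"}],
    "tools": [schema_tool],
}
print(json.dumps(request_payload, indent=2))

# Shape of one requested tool call as the tool node consumes it (values are made up).
example_tool_call = {"name": "sql_db_schema", "args": {"table_names": "courses"}, "id": "call_1"}
```

Without the `tools` entry the model has no schema to fill in, so it can only answer in plain text.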

from typing import Annotated

from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages


class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# expt_llm = "gpt-4-1106-preview"
expt_llm = "gpt-4o"
llm = ChatOpenAI(temperature=0, model=expt_llm)
# Modification: tell the LLM which tools it can call
llm_with_tools = llm.bind_tools(tools)


def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)

Step 3: Define the tool node and add it to the graph

import json

from langchain_core.messages import ToolMessage


class BasicToolNode:
    """Run the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            print(tool_call["name"])
            print(self.tools_by_name[tool_call["name"]])
            tool_result = self.tools_by_name[tool_call["name"]].invoke(
                tool_call["args"]
            )
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

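Step 4: Define the "edges"

The add_edge method defines an edge directly; in this example it represents tools -> chatbot.

The add_conditional_edges method adds a conditionally routed edge; in this example it means chatbot goes either -> tools or -> end, depending on the situation.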

from typing import Literal


def route_tools(
    state: State,
) -> Literal["tools", "__end__"]:
    """Used in the conditional edge: route to the tool node if the last message has tool calls; otherwise route to the end."""
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"


# The `route_tools` function returns "tools" if the LLM asks to use a tool, and "__end__" to finish directly.
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node
    # It defaults to the identity function, but if you
    # want to use a node named something else apart from "tools",
    # You can update the value of the dictionary to something else
    # e.g., "tools": "my_tools"
    {"tools": "tools", "__end__": "__end__"},
)
# Any time a tool is called, we flow back to the chatbot
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()

Step 5: Draw the graph (optional)

from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # Rendering requires some extra dependencies and is optional
    pass

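The result looks like this:

The whole flow is simple. In plain words:

The question is passed to the LLM, which decides which tool to use. The graph then calls that tool and passes the result back to the LLM. Once it has the result, the LLM may call another tool or output the answer directly, so the loop either continues or terminates.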
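The loop just described can be imitated in a few lines of plain Python. This is a sketch under stated assumptions, not LangGraph itself: `fake_llm` is a scripted stand-in for the model, `TOOLS` holds one made-up tool, messages are plain dicts, and `stream` mimics the one-event-per-node shape that the run loop in this post iterates over.

```python
from typing import Iterator


def fake_llm(messages: list[dict]) -> dict:
    """Scripted stand-in for the LLM: request one tool call, then answer."""
    last = messages[-1]
    if last["role"] == "user":
        return {"role": "ai", "content": "",
                "tool_calls": [{"name": "list_tables", "args": {}, "id": "call_1"}]}
    return {"role": "ai", "content": f"The tables are: {last['content']}", "tool_calls": []}


TOOLS = {"list_tables": lambda: "courses, orders"}  # hypothetical tool


def chatbot(state: dict) -> dict:
    """LLM node: produce the next AI message from the conversation so far."""
    return {"messages": [fake_llm(state["messages"])]}


def tools(state: dict) -> dict:
    """Tool node: run every tool call requested in the last message."""
    calls = state["messages"][-1]["tool_calls"]
    return {"messages": [
        {"role": "tool", "name": c["name"], "tool_call_id": c["id"],
         "content": TOOLS[c["name"]](**c["args"])}
        for c in calls
    ]}


def route_tools(state: dict) -> str:
    """Conditional edge: go to the tool node only if tool calls were requested."""
    return "tools" if state["messages"][-1].get("tool_calls") else "__end__"


def stream(question: str) -> Iterator[dict]:
    """Yield one {node_name: update} event per executed node until __end__."""
    state = {"messages": [{"role": "user", "content": question}]}
    node = "chatbot"  # entry point
    while node != "__end__":
        update = chatbot(state) if node == "chatbot" else tools(state)
        state["messages"] += update["messages"]  # append new messages to the state
        yield {node: update}
        # chatbot -> conditional edge; tools -> always back to chatbot
        node = route_tools(state) if node == "chatbot" else "chatbot"


events = list(stream("Which tables exist?"))
for event in events:
    for node_name, update in event.items():
        print(node_name, "->", update["messages"][-1]["content"])
```

The point of the design is that no fixed sequence is coded anywhere: the only control flow is the router, and how many times the loop runs is decided by what the model returns.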

Step 6: Run it

Pass in the user's question through a streaming call:

from langchain_core.messages import BaseMessage

while True:
    user_input = input("User: ")
    if user_input.lower() in ["quit", "exit", "q"]:
        print("Goodbye!")
        break
    for event in graph.stream({"messages": [("user", user_input)]}):
        for value in event.values():
            if isinstance(value["messages"][-1], BaseMessage):
                print("Assistant:", value["messages"][-1].content)

The result looks like this:

User:  Find the users with the most expensive freight across the regions in May
Assistant: 
sql_db_list_tables
db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "arbitraments, courses, orders, scores, sink_chunjun_1, source_chunjun_1, students, test_binlog_1"
Assistant: 
sql_db_schema
description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "\nCREATE TABLE orders (\n\torder_id INTEGER NOT NULL COMMENT '\u8ba2\u5355ID', \n\tcustomer_id VARCHAR(255) COMMENT '\u5ba2\u6237ID', \n\temployee_id INTEGER COMMENT '\u5458\u5de5ID', \n\torder_date DATE COMMENT '\u8ba2\u5355\u65e5\u671f', \n\trequired_date DATE COMMENT '\u8981\u6c42\u4ea4\u8d27\u65e5\u671f', \n\tshipped_date DATE COMMENT '\u53d1\u8d27\u65e5\u671f', \n\tshipper_id INTEGER COMMENT '\u53d1\u8d27\u65b9\u5f0f', \n\tfreight DECIMAL(10, 2) COMMENT '\u8fd0\u8d39', \n\tship_name VARCHAR(255) COMMENT '\u6536\u8d27\u4eba\u540d\u79f0', \n\tship_address VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u5740', \n\tship_city VARCHAR(255) COMMENT '\u6536\u8d27\u57ce\u5e02', \n\tship_region VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u533a', \n\tship_postal_code VARCHAR(255) COMMENT '\u6536\u8d27\u90ae\u7f16', \n\torder_status VARCHAR(50) COMMENT '\u8ba2\u5355\u72b6\u6001', \n\tsnapshot_timestamp TIMESTAMP NULL COMMENT '\u5feb\u7167\u65f6\u95f4\u6233' DEFAULT CURRENT_TIMESTAMP, \n\tPRIMARY KEY (order_id)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from orders table:\norder_id\tcustomer_id\temployee_id\torder_date\trequired_date\tshipped_date\tshipper_id\tfreight\tship_name\tship_address\tship_city\tship_region\tship_postal_code\torder_status\tsnapshot_timestamp\n1\tCUST001\t1\t2024-05-01\t2024-05-05\t2024-05-03\t1\t100.50\t\u5f20\u4e09\t\u5317\u4eac\u5e02\u671d\u9633\u533a\u5efa\u56fd\u8def100\u53f7\t\u5317\u4eac\t\u534e\u5317\t100022\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n2\tCUST002\t2\t2024-05-02\t2024-05-06\t2024-05-04\t2\t200.75\t\u674e\u56db\t\u4e0a\u6d77\u5e02\u6d66\u4e1c\u65b0\u533a\u4e16\u7eaa\u5927\u9053200\u53f7\t\u4e0a\u6d77\t\u534e\u4e1c\t200120\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n3\tCUST003\t3\t2024-05-03\t2024-05-07\t2024-05-05\t3\t150.00\t\u738b\u4e94\t\u5e7f\u5dde\u5e02\u5929\u6cb3\u533a\u4f53\u80b2\u897f\u8def300\u53f7\t\u5e7f\u5dde\t\u534e\u5357\t510620\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n*/"
Assistant: 
sql_db_query_checker
description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90> llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy='') llm_chain=LLMChain(prompt=PromptTemplate(input_variables=['dialect', 'query'], template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy=''))
Assistant: "SELECT ship_region, customer_id, freight FROM orders \nWHERE MONTH(order_date) = 5 \nORDER BY freight DESC \nLIMIT 1;"
Assistant: 
sql_db_query
description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields." db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "[('\u534e\u4e1c', 'CUST008', Decimal('300.80'))]"
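Assistant: Among the regions in May, the user with the most expensive freight is:

- Region: 华东 (East China)
- User ID: CUST008
- Freight: 300.80 yuan
User:  For each region in May, report the user with the most expensive freight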
Assistant: 
sql_db_list_tables
db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "arbitraments, courses, orders, scores, sink_chunjun_1, source_chunjun_1, students, test_binlog_1"
Assistant: 
sql_db_schema
description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "\nCREATE TABLE orders (\n\torder_id INTEGER NOT NULL COMMENT '\u8ba2\u5355ID', \n\tcustomer_id VARCHAR(255) COMMENT '\u5ba2\u6237ID', \n\temployee_id INTEGER COMMENT '\u5458\u5de5ID', \n\torder_date DATE COMMENT '\u8ba2\u5355\u65e5\u671f', \n\trequired_date DATE COMMENT '\u8981\u6c42\u4ea4\u8d27\u65e5\u671f', \n\tshipped_date DATE COMMENT '\u53d1\u8d27\u65e5\u671f', \n\tshipper_id INTEGER COMMENT '\u53d1\u8d27\u65b9\u5f0f', \n\tfreight DECIMAL(10, 2) COMMENT '\u8fd0\u8d39', \n\tship_name VARCHAR(255) COMMENT '\u6536\u8d27\u4eba\u540d\u79f0', \n\tship_address VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u5740', \n\tship_city VARCHAR(255) COMMENT '\u6536\u8d27\u57ce\u5e02', \n\tship_region VARCHAR(255) COMMENT '\u6536\u8d27\u5730\u533a', \n\tship_postal_code VARCHAR(255) COMMENT '\u6536\u8d27\u90ae\u7f16', \n\torder_status VARCHAR(50) COMMENT '\u8ba2\u5355\u72b6\u6001', \n\tsnapshot_timestamp TIMESTAMP NULL COMMENT '\u5feb\u7167\u65f6\u95f4\u6233' DEFAULT CURRENT_TIMESTAMP, \n\tPRIMARY KEY (order_id)\n)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB\n\n/*\n3 rows from orders table:\norder_id\tcustomer_id\temployee_id\torder_date\trequired_date\tshipped_date\tshipper_id\tfreight\tship_name\tship_address\tship_city\tship_region\tship_postal_code\torder_status\tsnapshot_timestamp\n1\tCUST001\t1\t2024-05-01\t2024-05-05\t2024-05-03\t1\t100.50\t\u5f20\u4e09\t\u5317\u4eac\u5e02\u671d\u9633\u533a\u5efa\u56fd\u8def100\u53f7\t\u5317\u4eac\t\u534e\u5317\t100022\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n2\tCUST002\t2\t2024-05-02\t2024-05-06\t2024-05-04\t2\t200.75\t\u674e\u56db\t\u4e0a\u6d77\u5e02\u6d66\u4e1c\u65b0\u533a\u4e16\u7eaa\u5927\u9053200\u53f7\t\u4e0a\u6d77\t\u534e\u4e1c\t200120\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n3\tCUST003\t3\t2024-05-03\t2024-05-07\t2024-05-05\t3\t150.00\t\u738b\u4e94\t\u5e7f\u5dde\u5e02\u5929\u6cb3\u533a\u4f53\u80b2\u897f\u8def300\u53f7\t\u5e7f\u5dde\t\u534e\u5357\t510620\t\u5df2\u53d1\u8d27\t2024-06-04 17:05:11\n*/"
Assistant: 
sql_db_query_checker
description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!' db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90> llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy='') llm_chain=LLMChain(prompt=PromptTemplate(input_variables=['dialect', 'query'], template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x7fd64e081310>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x7fd64e099350>, temperature=0.0, openai_api_key=SecretStr('**********'), openai_api_base='https://api.360.cn/v1', openai_proxy=''))
Assistant: "SELECT ship_region, customer_id, MAX(freight) as max_freight\nFROM orders\nWHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'\nGROUP BY ship_region, customer_id"
Assistant: 
sql_db_query
description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields." db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x7fd64efb1f90>
Assistant: "[('\u534e\u5317', 'CUST001', Decimal('100.50')), ('\u534e\u4e1c', 'CUST002', Decimal('200.75')), ('\u534e\u5357', 'CUST003', Decimal('150.00')), ('\u534e\u5357', 'CUST004', Decimal('120.25')), ('\u897f\u5357', 'CUST005', Decimal('180.90')), ('\u897f\u5357', 'CUST006', Decimal('250.60')), ('\u534e\u4e1c', 'CUST007', Decimal('90.45')), ('\u534e\u4e1c', 'CUST008', Decimal('300.80')), ('\u534e\u4e2d', 'CUST009', Decimal('220.30')), ('\u897f\u5317', 'CUST010', Decimal('170.95'))]"
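Assistant: Here are the users with the most expensive freight in each region in May:

| Region | Customer ID | Max freight (¥) |
|--------|-------------|-----------------|
| 华北 (North China) | CUST001 | 100.50 |
| 华东 (East China) | CUST008 | 300.80 |
| 华南 (South China) | CUST003 | 150.00 |
| 西南 (Southwest) | CUST006 | 250.60 |
| 华中 (Central China) | CUST009 | 220.30 |
| 西北 (Northwest) | CUST010 | 170.95 |

As shown above, these are the users with the highest freight in each region, along with the corresponding freight.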
User:  q
Goodbye!
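A side note on reading the log above: the tool outputs appear wrapped in double quotes and full of `\uXXXX` sequences because the tool node serializes each result with `json.dumps`, which escapes non-ASCII characters by default. The sketch below reproduces this with one result string taken from the log; passing `ensure_ascii=False` is shown only as a possible alternative and is not what the code in this post does.

```python
import json

# A tool result as a Python string (decoded from the log above).
tool_result = "[('华东', 'CUST008', Decimal('300.80'))]"

escaped = json.dumps(tool_result)                       # what the tool node does
readable = json.dumps(tool_result, ensure_ascii=False)  # hypothetical alternative

print(escaped)
print(readable)

# Either form decodes back to the same string, so no information is lost.
assert json.loads(escaped) == json.loads(readable) == tool_result
```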

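Summary

Did you notice something odd?

For the second question in the session above (for each region in May, find the user with the most expensive freight),

the SQL that the assistant generated is actually wrong: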

SELECT ship_region, customer_id, MAX(freight) as max_freight
FROM orders
WHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'
GROUP BY ship_region, customer_id

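This is a medium-difficulty SQL problem that I found on LeetCode.

Because the query groups by both ship_region and customer_id, its result contains several users for the same region (the tool output above indeed lists ten rows), yet the final answer I was given is correct.

The reason: GPT-4o cheated. It did not produce SQL of this complexity; instead it wrote a query for an intermediate result and then post-processed that result itself before returning the answer to me.

This is only a test dataset with very little data. In a production environment, this would be a problem.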
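To see the difference concretely, the sketch below loads the ten (region, customer, freight) rows from the tool output above into an in-memory SQLite table and runs the generated query next to one that does the per-region selection inside SQL. The join-back query is one common way to write it, not the answer from the follow-up post; the placeholder order date is an assumption, and the query has only been checked against SQLite here, not MySQL.

```python
import sqlite3

# (ship_region, customer_id, freight) rows as returned by the tool in the log above.
MAY_ROWS = [
    ("华北", "CUST001", 100.50), ("华东", "CUST002", 200.75),
    ("华南", "CUST003", 150.00), ("华南", "CUST004", 120.25),
    ("西南", "CUST005", 180.90), ("西南", "CUST006", 250.60),
    ("华东", "CUST007", 90.45), ("华东", "CUST008", 300.80),
    ("华中", "CUST009", 220.30), ("西北", "CUST010", 170.95),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE orders (ship_region TEXT, customer_id TEXT, freight REAL, order_date TEXT)"
)
# Assumption: every row gets the same placeholder date inside May 2024.
conn.executemany("INSERT INTO orders VALUES (?, ?, ?, '2024-05-15')", MAY_ROWS)

GENERATED_SQL = """
SELECT ship_region, customer_id, MAX(freight) AS max_freight
FROM orders
WHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'
GROUP BY ship_region, customer_id
"""

# Compute each region's maximum first, then join back to recover the customer.
PER_REGION_SQL = """
SELECT o.ship_region, o.customer_id, o.freight
FROM orders AS o
JOIN (
    SELECT ship_region, MAX(freight) AS max_freight
    FROM orders
    WHERE order_date BETWEEN '2024-05-01' AND '2024-05-31'
    GROUP BY ship_region
) AS m ON o.ship_region = m.ship_region AND o.freight = m.max_freight
WHERE o.order_date BETWEEN '2024-05-01' AND '2024-05-31'
"""

generated_rows = conn.execute(GENERATED_SQL).fetchall()
per_region_rows = conn.execute(PER_REGION_SQL).fetchall()

print(len(generated_rows))   # one row per (region, customer) pair
print(len(per_region_rows))  # one row per region
for row in per_region_rows:
    print(row)
```

With the join-back form, the database returns only the winning row for each region (or several rows when customers tie), so the model no longer has to filter a large intermediate result in its own context.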

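How can this be fixed?

One option is to optimize with few-shot prompting.

I wrote a separate blog post about few-shot tuning; see: LangGraph开发Agent智能体应用【NL2SQL】(few-shot优化) (Building Agent Applications with LangGraph: NL2SQL, Few-Shot Optimization), CSDN blog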
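For readers who have not seen the technique, here is a minimal sketch of the general idea only; it is not the approach from that post. A few question/SQL pairs are rendered into a system prompt that precedes the user's question. The example pair, the prompt wording, and the `build_system_prompt` helper are all assumptions; the column names come from the `orders` schema in the log above.

```python
# One hypothetical example showing the "greatest row per group" pattern on another column.
FEW_SHOT_EXAMPLES = [
    {
        "question": "For each shipping city, find the order with the highest freight.",
        "sql": (
            "SELECT o.ship_city, o.order_id, o.freight FROM orders AS o "
            "JOIN (SELECT ship_city, MAX(freight) AS max_freight FROM orders "
            "GROUP BY ship_city) AS m "
            "ON o.ship_city = m.ship_city AND o.freight = m.max_freight"
        ),
    },
]


def build_system_prompt(examples: list[dict]) -> str:
    """Render question/SQL pairs into a system prompt for the model to imitate."""
    lines = ["You write MySQL queries. Follow the patterns in these examples:"]
    for i, example in enumerate(examples, start=1):
        lines.append(f"Example {i}")
        lines.append(f"Question: {example['question']}")
        lines.append(f"SQL: {example['sql']}")
    return "\n".join(lines)


system_prompt = build_system_prompt(FEW_SHOT_EXAMPLES)

# Assumed usage: place the system prompt ahead of the user's question in the message list.
messages = [
    ("system", system_prompt),
    ("user", "For each region in May, find the user with the most expensive freight"),
]
print(system_prompt)
```

The example deliberately uses a different grouping column than the test question, so that the model has to transfer the pattern rather than copy an answer.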

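References

🦜🕸️LangGraph - LangGraph

LangChain开发LLM应用【入门指南】 (Building LLM Applications with LangChain: A Beginner's Guide), langchain developer community, CSDN blog

LangGraph开发Agent智能体应用【基础聊天机器人】 (Building Agent Applications with LangGraph: Basic Chatbot), CSDN blog

Introduction | 🦜️🔗 LangChain

The code has been uploaded; download it as needed. Thanks, everyone.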
