一、代码任务介绍
该代码实现了一套基于大语言模型(GLM-4)的自然语言转SQL查询问答系统,核心目标是让用户通过中文自然语言提问,系统自动完成「生成SQL语句→执行SQL查询→返回自然语言答案」的全流程。具体能力包括:
- 解析中文自然语言问题,生成符合MySQL语法规范的SQL查询语句;
- 自动执行生成的SQL语句,从指定MySQL数据库中获取查询结果;
- 将「用户问题、SQL语句、查询结果」整合,生成通俗易懂的中文回答;
- 适配智谱AI GLM-4模型的API调用规范,保证SQL生成和回答生成的准确性。
二、代码整体流程总结
- 前置配置阶段:配置MySQL数据库连接信息(地址、端口、账号等)和GLM-4模型参数(API Key、模型版本、温度等),初始化数据库连接实例;
- SQL生成阶段:接收用户自然语言问题 → 自动获取数据库表结构信息 → 调用GLM-4生成SQL语句 → 清洗SQL语句(去除markdown格式、多余空格);
- SQL执行阶段:使用清洗后的SQL语句查询数据库 → 获取并返回查询结果;
- 回答生成阶段:将用户问题、生成的SQL语句、数据库查询结果传入GLM-4 → 模型生成自然语言回答 → 输出最终结果。
三、代码逐段解析(含关键作用注释)
python
# 导入核心依赖库
from operator import itemgetter # 用于从字典/对象中提取指定字段(如后续提取生成的SQL语句)
import os # 用于读取环境变量中的智谱API Key,避免硬编码
from langchain_community.utilities import SQLDatabase # LangChain封装的数据库连接工具,简化数据库操作
from langchain_community.tools import QuerySQLDataBaseTool # SQL执行工具,安全执行SQL并返回结果
from langchain_openai import ChatOpenAI # OpenAI兼容的LLM客户端,适配智谱GLM-4 API
from langchain_core.runnables import RunnablePassthrough, RunnableLambda # LangChain核心执行组件:
# RunnablePassthrough用于透传输入数据,RunnableLambda用于封装自定义函数
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate # 提示词模板,定义LLM的输入格式
from langchain_core.output_parsers import StrOutputParser # 输出解析器,将LLM返回的复杂对象转为字符串
# ✅ 清洗 SQL 输出:解决LLM生成SQL时可能附带markdown格式(```sql/```)的问题,保证SQL可执行
def clean_sql_output(ai_msg):
# 兼容LLM不同返回格式:优先读取content属性,无则转为字符串
sql = ai_msg.content if hasattr(ai_msg, 'content') else str(ai_msg)
# 清理SQL语句:去除首尾空格 → 移除开头的```sql → 移除结尾的```→ 再次清理首尾空格
return sql.strip().removeprefix('```sql').removesuffix('```').strip()
# 构建SQL查询生成链:封装「获取表结构→生成SQL→清洗SQL」的完整逻辑
def create_sql_query_chain(llm, database: SQLDatabase):
# 定义获取数据库表结构的函数:接收任意输入,返回数据库所有表的结构信息(字段、类型、约束等)
def get_table_info(_):
return database.get_table_info()
# 定义SQL生成的提示词模板:约束LLM角色为SQL专家,输入表结构和用户问题,输出纯SQL语句
prompt = ChatPromptTemplate.from_messages([
("system", "You are an expert SQL assistant. Given the following database schema:\n{table_info}\n"
"Write a syntactically correct SQL query that answers the user's question. "
"Only return the SQL query, nothing else. Do not wrap it in markdown."),
("human", "{question}")
])
# 组装SQL生成链:
# 1. 组装输入数据(透传用户问题 + 调用函数获取表结构)
# 2. 将输入传入提示词模板,生成完整prompt
# 3. 调用LLM生成SQL语句
# 4. 清洗SQL语句格式,确保可执行
return (
{"question": RunnablePassthrough(), "table_info": RunnableLambda(get_table_info)}
| prompt
| llm
| RunnableLambda(clean_sql_output)
)
# === 数据库配置 ===
HOSTNAME = '127.0.0.1' # MySQL服务器地址
PORT = '3306' # MySQL端口号
DATABASE = 'test' # 目标数据库名
USERNAME = 'test' # 数据库登录用户名
PASSWORD = '123456' # 数据库登录密码
# 构建MySQL连接URI:使用pymysql驱动,指定utf8mb4字符集避免中文乱码
MYSQL_URI = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}?charset=utf8mb4'
# === 模型配置 ===
api_key = os.getenv('Zhipu_API_KEY') # 从环境变量读取智谱API Key(安全最佳实践)
model = ChatOpenAI(
model='glm-4-0520', # 指定使用GLM-4 0520版本模型
temperature=0, # 温度设为0,保证生成结果的确定性(适合SQL生成场景)
api_key=api_key, # 传入智谱API Key
base_url='https://open.bigmodel.cn/api/paas/v4' # 智谱API的base_url(兼容OpenAI格式)
)
# === 初始化数据库和工具 ===
db = SQLDatabase.from_uri(MYSQL_URI) # 初始化数据库连接实例
create_sql = create_sql_query_chain(llm=model, database=db) # 初始化SQL生成链
execute_sql = QuerySQLDataBaseTool(db=db) # 初始化SQL执行工具(绑定数据库实例)
# === 回答模板 ===
# 定义自然语言回答的提示词模板:整合问题、SQL、结果,要求模型生成中文回答
answer_prompt = PromptTemplate.from_template(
"""根据用户问题、对应 SQL 查询和结果,用中文回答问题:
问题:{question}
SQL 查询:{query}
查询结果:{result}
最终答案:"""
)
# 组装回答生成链:提示词模板 → 调用LLM → 解析为字符串输出
answer_chain = answer_prompt | model | StrOutputParser()
# === 最终链 ===
# 组装端到端的问答链:
# 1. 透传用户输入,新增query字段(值为生成的SQL语句),query对应 answer_prompt 模板
# 2. 基于query字段执行SQL,新增result字段(值为SQL查询结果)result对应 answer_prompt 模板
# 3. 将所有字段传入回答生成链,生成最终回答
chain = (
RunnablePassthrough.assign(query=create_sql)
.assign(result=itemgetter('query') | execute_sql)
| answer_chain
)
# RunnablePassthrough.assign 用于 新增 / 覆盖指定字段
#input_data = {"question": "一共有多少员工?"}
# assign 新增字段逻辑
new_data = RunnablePassthrough.assign(
# 新增字段名 = 计算该字段的Runnable
query=create_sql_chain # 用SQL生成链计算query字段值
).invoke(input_data)
# 最终 new_data 结构:
# {
# "question": "一共有多少员工?" # 原始字段(透传)
# "query": "SELECT COUNT(*) FROM employees" # assign新增的字段
# }
# === 测试 ===
if __name__ == "__main__":
# 测试场景1:聚合查询(统计员工总数)
print("=== 测试1:员工总数 ===")
resp1 = chain.invoke({"question": "请问:一共有多少个员工?"})
print(resp1)
# 测试场景2:条件查询(查询权限最高的员工及权限)
print("\n=== 测试2:权限最高员工 ===")
resp2 = chain.invoke({"question": "请问:哪个员工的权限最高?并且返回该员工的权限"})
print(resp2)
四、关键组件核心作用说明
| 组件/函数名称 | 核心作用 |
|---|---|
clean_sql_output |
清洗LLM生成的SQL语句,移除markdown格式标记,保证SQL语句可直接执行 |
create_sql_query_chain |
封装SQL生成的全流程,整合「表结构获取→提示词构建→LLM生成→格式清洗」 |
SQLDatabase |
建立与MySQL数据库的连接,提供表结构查询、SQL执行的基础能力 |
QuerySQLDataBaseTool |
安全执行SQL语句,封装数据库游标操作,避免直接操作底层连接 |
RunnablePassthrough |
透传用户输入数据,同时支持动态新增计算字段(如query、result) |
answer_chain |
整合「问题、SQL、结果」生成自然语言回答,提升用户体验 |
ChatOpenAI |
适配智谱GLM-4 API的LLM客户端,兼容LangChain的链式调用逻辑 |
五、使用说明
- 环境依赖安装:需提前安装
langchain-community langchain-core openai pymysql依赖包; - 环境变量配置:需在系统环境变量中配置
Zhipu_API_KEY(智谱AI的API密钥); - 数据库适配:修改「数据库配置」部分的HOSTNAME、PORT、DATABASE等参数,匹配实际的MySQL环境;
- 模型适配:可根据需求调整model参数(如glm-4-flash)或temperature值(平衡生成多样性与准确性)。