通过 PromptTemplate 生成干净的 SQL 查询语句并执行SQL查询语句

问题描述

在使用 LangChain 和 Llama 模型生成 SQL 查询时,遇到了 sqlite3.OperationalError 错误。错误信息如下:

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT Name 
FROM MediaType 
LIMIT 5;
```]

错误发生的原因是生成的 SQL 查询包含了不必要的 Markdown 代码块标记 ```,也就是在生成SQL语句的过程中,产生了其他的不干净文本,导致 SQL 语法错误。

最终解决方案

通过修改 PromptTemplate 来生成干净的 SQL 查询,确保生成的查询不包含任何 Markdown 代码块标记或附加评论。以下是解决方案的详细步骤和代码实现:

1. 初始化环境

首先,初始化所需的环境变量和模型:

python 复制代码
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")

# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)

2. 定义自定义提示模板

定义一个自定义的 PromptTemplate,用于生成干净的 SQL 查询:

python 复制代码
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)

3. 创建 SQL 查询链

创建一个 SQL 查询链,并使用自定义提示模板:

python 复制代码
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

4. 构造输入数据字典

构造输入数据字典,其中包含方言、表结构、问题和行数限制:

python 复制代码
input_data = {
    "dialect": db.dialect,                    # 数据库方言,如 "sqlite"
    "table_info": db.get_table_info(),        # 表结构信息
    "input": "What name of MediaType is?",    # 问题
    "top_k": 5                                # 行数限制
}

5. 调用链生成并执行 SQL 查询

调用链生成 SQL 查询,确保生成的查询不包含 Markdown 代码块标记,然后执行查询并打印结果:

python 复制代码
response = write_query.invoke(input_data)
query = response["query"]

# 执行 SQL 查询并打印结果
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke({"query": query})
print(result)

总结

通过修改 PromptTemplate 来生成 SQL 查询时,明确要求返回的 SQL 查询不包含任何附加评论或 Markdown 格式,确保生成的 SQL 查询是干净的、可执行的。这样可以避免由多余的标记导致的 SQL 语法错误。

最后提供完整代码:

python 复制代码
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase

load_dotenv()


# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
    
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"])  # 注意需要传递列表
print(f"\n Original table info: {table_info}")

   
#  初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
    input_variables=["dialect", "input", "table_info", "top_k"],
    template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)


write_query  = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 构造输入数据字典,其中包含方言、表结构、问题和行数限制
input_data = {
    "dialect": db.dialect,                    # 数据库方言,如 "sqlite"
    "table_info": db.get_table_info(),          # 表结构信息
    "question": "What name of MediaType is?",
    "top_k": 5
}

# 调用链生成 SQL 查询,返回结果为一个字典,包含键 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)

#执行SQL语句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)

#两个动作合起来搞成链
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)

输出:

相关推荐
m0_6724496023 分钟前
使用Java操作Excel
java·python·excel
张琪杭33 分钟前
PyTorch大白话解释算子二
人工智能·pytorch·python
weixin_3077791338 分钟前
PySpark实现获取Cloud Storage上Parquet文件的数据结构,并自动在Google BigQuery里建表和生成对应的建表和导入数据的SQL
数据仓库·python·spark·云计算·googlecloud
Watink Cpper39 分钟前
[MySQL初阶]MySQL(1)MySQL的理解、库的操作、表的操作
linux·运维·服务器·数据库·c++·后端·mysql
დ旧言~40 分钟前
【Python】基础知识四
python
尘世壹俗人44 分钟前
spark写数据库用连接池找不到driver类
大数据·数据库·spark
修昔底德1 小时前
费曼学习法12 - 告别 Excel!用 Python Pandas 开启数据分析高效之路 (Pandas 入门篇)
人工智能·python·学习·excel·pandas
网络风云1 小时前
Django 5实用指南(十二)异步处理与Celery集成
后端·python·django
蹦蹦跳跳真可爱5891 小时前
Python----线性代数(线性代数基础:标量,向量,矩阵,张量)
python·线性代数·矩阵
计算机学长大白2 小时前
Redis是什么?如何使用Redis进行缓存操作?
数据库·redis·缓存