不借助框架实现Text2SQL

相关阅读:

Vanna实现Text2SQL:https://core815.blog.csdn.net/article/details/155195612

LangChain实现Text2SQL:https://core815.blog.csdn.net/article/details/155105220

1.代码实现(Text2SQL)

python 复制代码
## 使用Qwen-coder,对保险客户数据表进行SQL查询
import json
import os
import dashscope
from dashscope.api_entities.dashscope_response import Role
import time
import pandas as pd
import re

# 从环境变量获取 dashscope 的 API Key
api_key = os.environ.get('DASHSCOPE_API_KEY')
dashscope.api_key = api_key


# 封装模型响应函数
def get_response(messages):
    response = dashscope.Generation.call(
        model='qwen-coder-plus',
        messages=messages,
        result_format='message'  # 将输出设置为message形式
    )
    return response


# 从模型响应中提取SQL代码
def get_sql_code(response):
    # 查找```sql和```之间的内容
    pattern = r'```sql(.*?)```'
    match = re.search(pattern, response.output.choices[0].message.content, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        # 如果没有找到```sql标记,尝试查找任何```之间的内容
        pattern = r'```(.*?)```'
        match = re.search(pattern, response.output.choices[0].message.content, re.DOTALL)
        if match:
            return match.group(1).strip()
        else:
            # 如果没有找到任何代码块,返回整个响应
            return response.output.choices[0].message.content


# 得到sql
def get_sql(query):
    start_time = time.time()
    sys_prompt = """我正在编写SQL,以下是数据库中的数据表和字段,请思考:哪些数据表和字段是该SQL需要的,然后编写对应的SQL,如果有多个查询语句,请尝试合并为一个。编写SQL请采用```sql
    """
    user_prompt = f"""-- language: SQL
### Question: {query}
### Input: {create_sql}
### Response:
Here is the SQL query I have generated to answer the question `{query}`:
```sql
"""

    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt}
    ]

    response = get_response(messages)
    return response


######## 需要人工设置 ########
save_file = f'sql_result_qwen_coder.xlsx'
qa_file = 'qa_list.txt'  # QA测试题
sql_file = './create_sql.txt'  # SQL数据表

# 读取 SQL数据表
with open(sql_file, 'r', encoding='utf-8') as file:
    create_sql = file.read()
# 读取 SQL问题列表
with open(qa_file, 'r', encoding='utf-8') as file:
    qa_list = file.read()
qa_list = qa_list.split('=====')

# 保存SQL结果
sql_list = []
markdown_list = []
time_list = []
for qa in qa_list:
    query = qa
    query = query.replace('\n', '')
    print("\n")
    print("分隔符-----------------------------------------------\n")
    print(query)
    start_time = time.time()
    # 请求生成sql
    #content, prompt_len = get_sql(query)
    response = get_sql(query)
    use_time = round(time.time()-start_time, 2)
    time_list.append(use_time)
    print('SQL生成时间:', use_time)
    print('response=', response.output.choices[0].message.content)
    # 提取生成的SQL
    sql = get_sql_code(response)
    print('SQL: {}'.format(sql))
    sql_list.append(sql)

result = pd.DataFrame(columns=['QA', 'SQL', 'time'])
result['QA'] = qa_list
result['SQL'] = sql_list
result['time'] = time_list
result.to_excel(save_file, index=False)
print(f'SQL结果已保存到 {save_file}')

效果:

2.评价生成的SQL

python 复制代码
# 对大模型SQL结果进行评测
import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
import traceback


def get_db_engine(database, host="localhost", user="root",
                  password="a123456", port=3306):
    """创建数据库连接引擎"""
    connection_str = f'mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}'
    engine = create_engine(connection_str)
    return engine


def get_db_session(engine):
    """创建数据库会话"""
    Session = sessionmaker(bind=engine)
    session = Session()
    return session


def get_markdown_result(session, sql):
    """
    执行SQL并返回markdown格式的结果
    返回值:
    - isok: 'Yes'表示执行成功,'No'表示执行失败
    - result: 成功时返回markdown表格,失败时返回错误信息
    """
    try:
        # 执行SQL查询
        result = session.execute(text(sql))

        # 获取列名
        columns = result.keys()

        # 获取所有数据
        rows = result.fetchall()

        if not rows:
            return 'Yes', '查询结果为空'

        # 构建markdown表格
        markdown = '| ' + ' | '.join(columns) + ' |\n'
        markdown += '| ' + ' | '.join(['---' for _ in columns]) + ' |\n'

        # 添加数据行
        for row in rows:
            markdown += '| ' + ' | '.join(str(cell) for cell in row) + ' |\n'

        return 'Yes', markdown

    except Exception as e:
        error_msg = str(e)
        traceback_msg = traceback.format_exc()
        return 'No', f'SQL执行错误: {error_msg}'
    finally:
        session.close()


filename = './sql_result_qwen_coder.xlsx'

# 获取数据库句柄
engine = get_db_engine(database="life_insurance")

# 读取待评测文件
df = pd.read_excel(filename)

df['能否运行'] = 0
markdown_list = []
for index, row in df.iterrows():
    sql = row['SQL']
    print(sql)
    session = get_db_session(engine)
    if str(sql) == 'nan':
        df.loc[index, '能否运行'] = 'No 没有找到SQL'
        continue

    # 如果有多个sql,只执行第一个
    sqls = sql.split(';')
    sql = sqls[0]
    isok, markdown_table = get_markdown_result(session, sql)
    markdown_list.append(markdown_table)
    if isok == 'Yes':
        df.loc[index, '能否运行'] = 'Yes'
    else:
        df.loc[index, '能否运行'] = 'No ' + markdown_table

df.to_excel(filename, index=False)
df

效果:

3.相关资源

百度网盘:https://pan.baidu.com/s/1GpIkMpVZ9XuxKuH4ctQ_yA?pwd=yb8d

相关推荐
老友@2 小时前
RAG 的诞生:为了让 AI 不再“乱编”
人工智能·搜索引擎·ai·语言模型·自然语言处理·rag
三条猫2 小时前
将3D CAD 模型结构树转换为图结构,用于训练CAD AI的思路
人工智能·3d·ai·cad·模型训练·图结构·结构树
2501_941803622 小时前
互联网技术的未来趋势与发展方向:AI、5G与云计算的深度融合
mysql
熊猫钓鱼>_>4 小时前
从零开始构建RPG游戏战斗系统:实战心得与技术要点
开发语言·人工智能·经验分享·python·游戏·ai·qoder
SamDeepThinking4 小时前
在 MySQL 里,不建议使用长事务的根因
后端·mysql
q***3754 小时前
MySQL输入密码后闪退?
数据库·mysql·adb
张人玉4 小时前
SQLite 快速入门 Cheat Sheet
数据库·sql·sqlite
杨DaB4 小时前
【MySQL】03 数据库的CRUD
数据库·mysql·adb