检索增强的 NLP2SQL 生成

重磅推荐专栏: 《大模型AIGC》 《课程大纲》 《知识星球》
本专栏致力于探索和讨论当今最前沿的技术趋势和应用领域,包括但不限于ChatGPT和Stable Diffusion等。我们将深入研究大型模型的开发和应用,以及与之相关的人工智能生成内容(AIGC)技术。通过深入的技术解析和实践经验分享,旨在帮助读者更好地理解和应用这些领域的最新进展

如何让大模型进行数据分析呢?

我们先来看下这样的应用:

用户可以上传表格数据,进行数据分析类的提问。 大模型回复会包括:提问的回答、图表展示、数据解释说明、以及最终给到完整数据查询结果的Excel文件下载链接。

可以指定要求画柱状图、折线图、饼图 以及 柱状折线结合图 ,用图表来更加形象的回复: 不指定用哪个图表时候,大模型也可以智能的结合合适的图表 来回复展示数据: 以下是该数据分析应用的大致交互流程图: 可以看到,其实这个流程最具技术挑战的便是:NLP2SQL。今天我们就来讲讲如果用RAG(检索增强生成)的方案来提升大模型的NLP2SQL能力。

引言

自然语言到 SQL(NLP2SQL)技术正迅速改变数据分析领域,使非技术用户能够直接与数据库进行交互。然而,传统 NLP2SQL 方法面临诸多挑战,包括处理复杂数据库模式、理解业务术语以及生成准确 SQL 查询等。本文深入探讨如何通过检索增强生成(Retrieval-Augmented Generation, RAG) 技术解决这些问题,并以开源框架 Vanna 为例详细解析其实现原理。

1. RAG 在 NLP2SQL 中的应用原理

1.1 RAG 架构概述

检索增强生成将信息检索与生成模型相结合,解决传统生成模型的局限性:

组件 功能 在 NLP2SQL 中的应用
检索器 根据输入查找相关上下文 从知识库中查找相关DDL、文档和SQL示例
生成器 基于上下文生成响应 根据检索结果生成准确SQL语句
知识库 存储结构化信息 包含数据库模式、业务术语和查询示例

1.2 Vanna 的 RAG 实现

Vanna 采用分层架构实现 RAG:

训练过程:构建知识库

train() 方法是构建 RAG 知识库的核心,支持三种训练数据类型:

2.1 DDL 训练

python 复制代码
def train(self, ddl: str = None, kwargs) -> str:
    """
    训练 DDL 语句:存储表结构定义
    Args:
        ddl (str): 数据定义语言语句
    Returns:
        str: 训练数据的唯一ID
    """
    if ddl:
        # 生成 DDL 文本的嵌入向量
        embedding = self.generate_embedding(ddl)
        # 存储到向量数据库
        doc_id = self.vector_store.add_document(
            content=ddl,
            embedding=embedding,
            doc_type="ddl"
        )
        return doc_id

DDL 训练示例:

python 复制代码
vn.train(ddl="""
CREATE TABLE customers (
    customer_id INT PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(255) UNIQUE,
    signup_date DATE
);
""")

2.2 文档训练

python 复制代码
def train(self, documentation: str = None, kwargs) -> str:
    """
    训练业务文档:存储业务术语和解释
    Args:
        documentation (str): 业务文档文本
    Returns:
        str: 训练数据的唯一ID
    """
    if documentation:
        # 生成文档的嵌入向量
        embedding = self.generate_embedding(documentation)
        # 存储到向量数据库
        doc_id = self.vector_store.add_document(
            content=documentation,
            embedding=embedding,
            doc_type="documentation"
        )
        return doc_id

业务文档训练示例:

python 复制代码
vn.train(documentation="活跃用户定义为过去30天内有登录行为的用户")

2.3 SQL 示例训练

python 复制代码
def add_question_sql(self, question: str, sql: str, kwargs) -> str:
    """
    训练 SQL 示例:存储问题-SQL 对
    Args:
        question (str): 自然语言问题
        sql (str): 对应的 SQL 查询
    Returns:
        str: 训练数据的唯一ID
    """
    # 组合问题和SQL
    data = f"Question: {question}\nSQL: {sql}"
    # 生成嵌入向量
    embedding = self.generate_embedding(data)
    # 存储到向量数据库
    doc_id = self.vector_store.add_document(
        content=data,
        embedding=embedding,
        doc_type="sql_example"
    )
    return doc_id

SQL 示例训练示例:

python 复制代码
vn.train(
    question="最近30天的活跃用户数",
    sql="SELECT COUNT(*) FROM users WHERE last_login >= NOW() - INTERVAL '30 days'"
)

2.4 自动训练计划

对于大型数据库,Vanna 提供自动生成训练计划的功能:

python 复制代码
def get_training_plan_generic(self, df) -> TrainingPlan:
    """
    自动生成训练计划
    Args:
        df (pd.DataFrame): INFORMATION_SCHEMA.COLUMNS 查询结果
    Returns:
        TrainingPlan: 训练计划对象
    """
    plan = TrainingPlan([])
    # 识别数据列
    db_col = next((col for col in df.columns if 'database' in col.lower()), None)
    schema_col = next((col for col in df.columns if 'schema' in col.lower()), None)
    table_col = next((col for col in df.columns if 'table_name' in col.lower()), None)
    
    for database in df[db_col].unique():
        for schema in df[df[db_col] == database][schema_col].unique():
            for table in df[(df[db_col]==database) & 
                           (df[schema_col]==schema)][table_col].unique():
                # 获取表的所有列
                table_cols = df[(df[db_col]==database) & 
                               (df[schema_col]==schema) & 
                               (df[table_col]==table)]
                # 构建表描述
                doc = f"表 {table} 位于 {database}.{schema},包含列:\n\n"
                doc += table_cols[['column_name', 'data_type', 'comment']].to_markdown()
                # 添加到训练计划
                plan.add_item(
                    item_type=TrainingPlanItem.ITEM_TYPE_IS,
                    item_group=f"{database}.{schema}",
                    item_name=table,
                    item_value=doc
                )
    return plan

3. SQL 生成过程

generate_sql() 是 Vanna 的核心方法,实现检索增强的 SQL 生成:

3.1 整体流程

3.2 上下文检索

python 复制代码
def generate_sql(self, question: str, kwargs) -> str:
    # 1. 检索相关问题-SQL对
    question_sql_list = self.get_similar_question_sql(question, kwargs)
    
    # 2. 检索相关DDL
    ddl_list = self.get_related_ddl(question, kwargs)
    
    # 3. 检索相关文档
    doc_list = self.get_related_documentation(question, kwargs)
    
    # ...后续处理...

检索方法实现:

python 复制代码
def get_similar_question_sql(self, question: str, top_k: int = 5, kwargs) -> list:
    """
    检索相关问题-SQL对
    Args:
        question (str): 用户问题
        top_k (int): 返回结果数量
    Returns:
        list: 相关问题-SQL对列表
    """
    # 生成问题嵌入向量
    embedding = self.generate_embedding(question)
    # 向量数据库查询
    results = self.vector_store.query(
        embedding=embedding,
        top_k=top_k,
        doc_type="sql_example"
    )
    # 解析结果
    return [{
        'question': res['content'].split('\n')[0].replace('Question: ', ''),
        'sql': res['content'].split('\n')[1].replace('SQL: ', '')
for res in results]

3.3 提示工程

python 复制代码
def get_sql_prompt(self, question, question_sql_list, ddl_list, doc_list, kwargs):
    """
    构建LLM提示
    Args:
        question: 用户问题
        question_sql_list: 相关问题-SQL
        ddl_list: 相关DDL
        doc_list: 相关文档
    Returns:
        list: 消息列表
    """
    # 基础系统提示
    prompt = [self.system_message(
        f"你是 {self.dialect} 专家,根据以下上下文生成SQL:")]
    
    # 添加DDL上下文
    if ddl_list:
        prompt.append(self.user_message("= 表结构 ="))
        for ddl in ddl_list:
            prompt.append(self.user_message(ddl))
    
    # 添加文档上下文
    if doc_list:
        prompt.append(self.user_message("= 业务定义 ="))
        for doc in doc_list:
            prompt.append(self.user_message(doc))
    
    # 添加SQL示例上下文
    if question_sql_list:
        prompt.append(self.user_message("= 类似SQL示例 ="))
        for item in question_sql_list:
            prompt.append(self.user_message(f"问题:{item['question']}"))
            prompt.append(self.assistant_message(f"SQL:{item['sql']}"))
    
    # 添加当前问题
    prompt.append(self.user_message(f"问题:{question}"))
    
    # 添加响应指南
    prompt.append(self.system_message(
        "= 响应指南 =\n"
        "1. 仅返回SQL语句\n"
        "2. 使用标准SQL语法\n"
        "3. 包含必要的JOIN\n"
        f"4. 符合{self.dialect}规范"
    ))
    
    return prompt

3.4 SQL 生成与优化

python 复制代码
def generate_sql(self, question: str, kwargs) -> str:
    # ...检索上下文...
    
    # 构建提示
    prompt = self.get_sql_prompt(question, question_sql_list, ddl_list, doc_list, kwargs)
    
    # 提交到LLM
    llm_response = self.submit_prompt(prompt, kwargs)
    
    # 提取SQL
    sql = self.extract_sql(llm_response)
    
    # 处理中间SQL
    if 'intermediate_sql' in llm_response and kwargs.get('allow_llm_to_see_data', False):
        # 执行中间SQL获取数据
        try:
            df = self.run_sql(extract_sql(llm_response))
            # 添加数据到文档
            doc_list.append(f"查询结果:\n{df.head().to_markdown()}")
            # 重新生成最终SQL
            prompt = self.get_sql_prompt(question, question_sql_list, ddl_list, doc_list, kwargs)
            llm_response = self.submit_prompt(prompt, kwargs)
            sql = self.extract_sql(llm_response)
        except Exception as e:
            return f"中间SQL执行错误:{str(e)}"
    
    return sql

SQL 提取方法:

python 复制代码
def extract_sql(self, llm_response: str) -> str:
    """
    从LLM响应中提取SQL语句
    Args:
        llm_response (str): LLM完整响应
    Returns:
        str: 纯SQL语句
    """
    # 尝试匹配SQL代码块
    matches = re.findall(r"sql\n(.*?)", llm_response, re.DOTALL)
    if matches:
        return matches[0].strip()
    
    # 尝试匹配SELECT语句
    matches = re.findall(r"(SELECT\s+.*?;)", llm_response, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches[0]
    
    # 尝试匹配其他SQL语句
    matches = re.findall(r"((?:INSERTUPDATE DELETE CREATE ALTER
DROP).*?;)", 
                         llm_response, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches[-1]
    
    # 默认返回整个响应
    return llm_response

性能优化策略

4.1 上下文选择策略

Vanna 采用多级上下文选择策略:

4.2 动态上下文优化

根据问题复杂度动态调整上下文:

问题复杂度 DDL数量 文档数量 SQL示例数量 Token限制
简单查询 1-2 0-1 1-2 2000
中等查询 2-3 1-2 2-3 4000
复杂查询 3-5 2-3 3-5 8000

4.3 多步生成机制

对于复杂查询,采用多步生成策略:

python 复制代码
def generate_complex_sql(self, question: str, max_steps: int = 3, kwargs) -> str:
    """
    多步复杂SQL生成
    Args:
        question: 用户问题
        max_steps: 最大生成步数
    Returns:
        str: 最终SQL
    """
    intermediate_sql = ""
    
    for step in range(max_steps):
        # 生成中间SQL
        prompt = self.build_intermediate_prompt(question, intermediate_sql, step)
        response = self.submit_prompt(prompt, kwargs)
        
        # 检查是否完成
        if "FINAL_SQL" in response:
            return self.extract_final_sql(response)
        
        # 执行中间SQL
        current_sql = self.extract_intermediate_sql(response)
        try:
            df = self.run_sql(current_sql)
            intermediate_sql = f"{intermediate_sql}\n-- Step {step+1} result:\n{df.head().to_markdown()}"
        except Exception as e:
            return f"步骤{step+1}错误: {str(e)}"
    
    return "超过最大生成步数"

5. 实验结果分析

使用 SEC 数据集测试不同策略的准确率:

策略 GPT-3.5 准确率 GPT-4 准确率 Bison 准确率 平均响应时间(秒)
无上下文 12% 18% 10% 2.1
静态示例 48% 62% 52% 3.5
上下文相关示例 76% 83% 79% 4.2
多步生成 82% 88% 85% 6.8

关键发现:

  • 上下文检索带来 60-70% 的准确率提升
  • GPT-4 在复杂查询中表现最佳
  • 多步生成显著提高复杂查询准确率

6. 生产环境最佳实践

6.1 训练数据管理

6.2 部署架构

6.3 监控指标

指标类别 具体指标 目标值 监控频率
准确性 SQL生成准确率 >85% 实时
性能 平均响应时间 <3s 每分钟
资源 Token使用量 <80%配额 每小时
业务 用户查询量 递增趋势 每天

7. 扩展与自定义

7.1 自定义 LLM 集成

python 复制代码
class CustomLLM(VannaBase):
    def __init__(self, config):
        self.api_url = config['api_url']
        self.api_key = config['api_key']
    
    def submit_prompt(self, messages, kwargs):
        # 转换消息格式
        prompt = self._format_messages(messages)
        
        # 调用自定义API
        response = requests.post(
            self.api_url,
            json={'prompt': prompt},
            headers={'Authorization': f'Bearer {self.api_key}'}
        )
        
        return response.json()['text']
    
    def _format_messages(self, messages):
        # 将消息列表转换为自定义格式
        return "\n".join(
            f"{msg['role'].upper()}: {msg['content']}" 
            for msg in messages
        )

7.2 自定义向量存储

python 复制代码
class CustomVectorStore(VannaBase):
    def __init__(self, config):
        self.db_host = config['host']
        self.db_port = config['port']
        self._connect()
    
    def _connect(self):
        # 连接向量数据库
        self.conn = psycopg2.connect(
            host=self.db_host,
            port=self.db_port,
            user='user',
            password='pass'
        )
    
    def add_document(self, content, embedding, doc_type):
        # 插入文档到向量数据库
        cursor = self.conn.cursor()
        cursor.execute(
            "INSERT INTO documents (content, embedding, doc_type) VALUES (%s, %s, %s) RETURNING id",
            (content, embedding, doc_type)
        )
        doc_id = cursor.fetchone()[0]
        self.conn.commit()
        return doc_id
    
    def query(self, embedding, top_k=5, doc_type=None):
        # 查询相似文档
        query = """
        SELECT id, content, doc_type 
        FROM documents 
        WHERE doc_type = COALESCE(%s, doc_type)
        ORDER BY embedding <-> %s
        LIMIT %s
        """
        cursor = self.conn.cursor()
        cursor.execute(query, (doc_type, embedding, top_k))
        return cursor.fetchall()

结论

通过 Vanna 框架的深度解析,我们展示了检索增强生成在 NLP2SQL 领域的强大应用。关键要点包括:

  • 上下文检索是提高 SQL 生成准确性的核心
  • 多类型训练数据(DDL、文档、SQL 示例)协同作用
  • 动态提示工程优化上下文选择
  • 多步生成机制解决复杂查询问题

实验表明,合理应用 RAG 技术可使 NLP2SQL 准确率从不足 20% 提升至 85% 以上,为数据分析领域带来革命性变革。

Vanna 框架的模块化设计使其易于扩展和定制,开发者可轻松集成不同的 LLM 和向量数据库,构建适应特定业务需求的 NLP2SQL 系统。随着大语言模型和向量检索技术的持续发展,检索增强的 SQL 生成将更加智能和高效,最终实现"用自然语言探索所有数据"的愿景。

参考文献

  • Vanna 官方文档. vanna.ai/docs/
  • Lewis P, et al. Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks. arXiv:2005.11401
  • Scholak T, et al. PICARD: Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models. EMNLP 2021
  • Rajkumar N, et al. Evaluating the Text-to-SQL Capabilities of Large Language Models. arXiv:2204.00498
  • Chase H. LangChain: Building applications with LLMs through composability. langchain.com
相关推荐
红衣小蛇妖10 分钟前
神经网络-Day44
人工智能·深度学习·神经网络
忠于明白10 分钟前
Spring AI 核心工作流
人工智能·spring·大模型应用开发·spring ai·ai 应用商业化
大写-凌祁1 小时前
论文阅读:HySCDG生成式数据处理流程
论文阅读·人工智能·笔记·python·机器学习
柯南二号1 小时前
深入理解 Agent 与 LLM 的区别:从智能体到语言模型
人工智能·机器学习·llm·agent
珂朵莉MM1 小时前
2021 RoboCom 世界机器人开发者大赛-高职组(初赛)解题报告 | 珂学家
java·开发语言·人工智能·算法·职场和发展·机器人
IT_陈寒1 小时前
Element Plus 2.10.0 重磅发布!新增Splitter组件
前端·人工智能·后端
jndingxin1 小时前
OpenCV CUDA模块图像处理------创建一个模板匹配(Template Matching)对象函数createTemplateMatching()
图像处理·人工智能·opencv
盛寒2 小时前
N元语言模型 —— 一文讲懂!!!
人工智能·语言模型·自然语言处理
weixin_177297220692 小时前
家政小程序开发——AI+IoT技术融合,打造“智慧家政”新物种
人工智能·物联网
Jay Kay2 小时前
ReLU 新生:从死亡困境到强势回归
人工智能·数据挖掘·回归