文章目录
- 前言
- 一、系统架构总览
-
- [1.1 核心流程](#1.1 核心流程)
- [1.2 系统组件](#1.2 系统组件)
- 二、Schema注入与语义理解
-
- [2.1 Schema信息设计](#2.1 Schema信息设计)
- [2.2 少样本学习(Few-shot Examples)](#2.2 少样本学习(Few-shot Examples))
- 三、SQL生成与验证
-
- [3.1 LLM调用实现](#3.1 LLM调用实现)
- [3.2 SQL安全验证器](#3.2 SQL安全验证器)
- [3.3 高级验证:权限控制](#3.3 高级验证:权限控制)
- 四、SQL执行与结果处理
-
- [4.1 安全执行器](#4.1 安全执行器)
- [4.2 结果解释器](#4.2 结果解释器)
- 五、完整流程集成
-
- [5.1 Agent主控制器](#5.1 Agent主控制器)
- 六、性能优化与监控
-
- [6.1 缓存策略](#6.1 缓存策略)
- [6.2 监控指标](#6.2 监控指标)
- 七、总结与进阶思考
-
- [7.1 核心设计要点](#7.1 核心设计要点)
- [7.2 演进路线](#7.2 演进路线)
- [7.3 常见问题与解决方案](#7.3 常见问题与解决方案)
- 写在最后:
前言
"帮我看看上个月北京卖得最好的5款手机是什么?"------当业务人员急需一份数据时,传统的流程是:提需求给数据部门 -> 排期 -> 写SQL -> 导出Excel -> 邮件发送。这个过程少则半天,多则数天。
有没有可能让业务人员用自然语言直接查询数据库?让AI理解业务语义、自动生成SQL、执行查询,最后用自然语言返回结果?这就是Text-to-SQL技术的核心价值。
本文将深入讲解如何设计一个生产级的数据分析Agent,涵盖:
- 核心流程:从自然语言到SQL再到自然语言的全链路
- 关键技术:Schema注入、SQL安全验证、错误处理
- 高级特性:少样本学习、复杂查询、结果可视化、数据权限
一、系统架构总览
1.1 核心流程
否
是
用户自然语言提问
如:上个月北京销售额
语义理解与实体抽取
LLM
Schema信息注入
表结构+字段描述
SQL生成
LLM + Few-shot
SQL安全验证
正则/语法树
验证通过?
错误反馈
重新生成
执行SQL
只读连接
结果处理
聚合/分页
自然语言解释
LLM
返回答案
文本/图表
1.2 系统组件
数据层
核心处理层
用户层
Web界面
API Gateway
IM机器人
意图理解
SQL生成引擎
SQL验证器
查询执行器
结果解释器
OLAP数据库
Schema缓存
样本库
二、Schema注入与语义理解
2.1 Schema信息设计
要让LLM生成准确的SQL,必须提供完整的数据库结构信息。但直接暴露真实表名和字段名可能不够直观,我们可以添加业务语义描述。
python
# 数据库Schema定义
database_schema = {
"tables": [
{
"name": "sales",
"description": "销售记录表,每笔订单一条记录",
"columns": [
{"name": "id", "type": "int", "description": "主键ID"},
{"name": "product_name", "type": "varchar", "description": "商品名称"},
{"name": "category", "type": "varchar", "description": "商品品类,如手机、电脑、配件"},
{"name": "amount", "type": "decimal", "description": "销售额,单位元"},
{"name": "sale_date", "type": "date", "description": "销售日期"},
{"name": "region", "type": "varchar", "description": "销售地区,如北京、上海、广州"},
{"name": "user_id", "type": "int", "description": "用户ID"}
],
"examples": [
{"product_name": "iPhone14", "category": "手机", "amount": 6999.00,
"sale_date": "2024-01-15", "region": "北京"}
]
},
{
"name": "products",
"description": "商品信息表",
"columns": [
{"name": "id", "type": "int", "description": "商品ID"},
{"name": "product_name", "type": "varchar", "description": "商品名称"},
{"name": "category", "type": "varchar", "description": "商品品类"},
{"name": "price", "type": "decimal", "description": "单价"}
]
}
],
"relationships": [
"sales.product_name 关联 products.product_name"
]
}
def format_schema_for_prompt(schema):
"""将Schema格式化为LLM友好的文本"""
prompt = "数据库结构如下:\n\n"
for table in schema["tables"]:
prompt += f"表名:{table['name']}({table['description']})\n"
prompt += "字段:\n"
for col in table["columns"]:
prompt += f" - {col['name']} ({col['type']}):{col['description']}\n"
# 添加示例数据
if "examples" in table:
prompt += "示例数据:\n"
for ex in table["examples"]:
prompt += f" {ex}\n"
prompt += "\n"
if "relationships" in schema:
prompt += "表间关系:\n"
for rel in schema["relationships"]:
prompt += f" - {rel}\n"
return prompt
2.2 少样本学习(Few-shot Examples)
仅靠Schema描述,LLM可能对复杂业务问法理解不足。通过提供常见问题的SQL示例,可以显著提升准确率。
python
few_shot_examples = [
{
"question": "上个月北京的销售额是多少?",
"sql": "SELECT SUM(amount) FROM sales WHERE region = '北京' AND sale_date >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND sale_date < DATE_TRUNC('month', CURRENT_DATE)"
},
{
"question": "销量前5的商品有哪些?",
"sql": "SELECT product_name, COUNT(*) as sales_count FROM sales GROUP BY product_name ORDER BY sales_count DESC LIMIT 5"
},
{
"question": "手机类目的平均售价是多少?",
"sql": "SELECT AVG(amount) FROM sales WHERE category = '手机'"
},
{
"question": "上季度每个地区的销售额排行",
"sql": "SELECT region, SUM(amount) as total FROM sales WHERE sale_date >= DATE_TRUNC('quarter', CURRENT_DATE - INTERVAL '3 month') AND sale_date < DATE_TRUNC('quarter', CURRENT_DATE) GROUP BY region ORDER BY total DESC"
}
]
def build_prompt(user_question, schema, examples):
"""构建完整的提示词"""
prompt = """你是一个数据分析专家,需要根据用户的自然语言问题生成SQL查询语句。
"""
# 添加Schema信息
prompt += format_schema_for_prompt(schema)
prompt += "\n"
# 添加示例
prompt += "以下是一些常见问题及其对应的SQL示例:\n"
for ex in examples:
prompt += f"问题:{ex['question']}\nSQL:{ex['sql']}\n\n"
# 当前问题
prompt += f"现在,请根据以上信息,为用户的问题生成SQL:\n{user_question}\n"
prompt += "注意:只返回SQL语句,不要其他解释。如果无法生成,返回'ERROR:原因'。"
return prompt
三、SQL生成与验证
3.1 LLM调用实现
python
import openai
import json
import re
class SQLGenerator:
def __init__(self, api_key, schema, examples=None):
self.client = openai.OpenAI(api_key=api_key)
self.schema = schema
self.examples = examples or few_shot_examples
def generate(self, user_question, retry_count=2):
"""生成SQL,支持重试"""
for attempt in range(retry_count):
try:
prompt = build_prompt(user_question, self.schema, self.examples)
response = self.client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "你是一个专业的SQL生成助手,只返回SQL语句。"},
{"role": "user", "content": prompt}
],
temperature=0.1, # 低温度保证稳定性
max_tokens=500
)
sql = response.choices[0].message.content.strip()
# 提取SQL(可能包含markdown代码块)
sql = self.extract_sql(sql)
# 验证SQL
if self.validate(sql):
return sql
else:
print(f"SQL验证失败,第{attempt+1}次重试")
except Exception as e:
print(f"生成SQL异常:{e}")
return None
def extract_sql(self, text):
"""从响应中提取SQL(处理markdown代码块)"""
# 匹配 ```sql ... ```或 ```... ```
sql_pattern = r'```(?:sql)?\s*(.*?)\s*```'
matches = re.findall(sql_pattern, text, re.DOTALL)
if matches:
return matches[0].strip()
# 如果没有代码块,直接返回(假设就是SQL)
return text.strip()
3.2 SQL安全验证器
安全性是重中之重!必须确保生成的SQL是只读查询,不会修改数据。
python
import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML
class SQLValidator:
def __init__(self):
# 禁止的命令
self.forbidden_commands = [
'insert', 'update', 'delete', 'drop',
'alter', 'truncate', 'create', 'replace',
'grant', 'revoke'
]
# 只允许的DML命令
self.allowed_commands = ['select', 'with']
def validate(self, sql):
"""SQL安全验证主流程"""
# 1. 基础关键词检查
sql_lower = sql.lower()
for cmd in self.forbidden_commands:
# 确保是独立的单词,而不是字段名的一部分
if re.search(r'\b' + cmd + r'\b', sql_lower):
raise ValueError(f"禁止使用 {cmd} 命令")
# 2. 语法树解析验证
parsed = sqlparse.parse(sql)
if not parsed:
raise ValueError("SQL语法解析失败")
# 3. 验证是否为SELECT语句
statement = parsed[0]
if not self.is_select_only(statement):
raise ValueError("只允许SELECT查询语句")
# 4. 强制添加LIMIT(如果没有)
if 'limit' not in sql_lower:
sql += ' LIMIT 1000'
return sql
def is_select_only(self, statement):
"""验证语句是否为纯SELECT"""
for token in statement.tokens:
if token.ttype in Keyword and token.value.upper() not in ['SELECT', 'WITH', 'LIMIT', 'ORDER BY', 'GROUP BY', 'WHERE', 'HAVING', 'JOIN', 'ON', 'AND', 'OR', 'IN', 'BETWEEN', 'LIKE', 'AS', 'DISTINCT', 'FROM']:
# 如果是其他关键字,检查是否是SELECT相关的
if token.value.upper() not in ['UNION', 'INTERSECT', 'EXCEPT']:
return False
return True
3.3 高级验证:权限控制
在实际生产环境中,不同用户只能访问特定的数据。
python
class PermissionEnforcer:
def __init__(self, user_context):
self.user = user_context['user_id']
self.role = user_context['role']
self.regions = user_context.get('regions', [])
def enforce(self, sql):
"""根据用户权限改写SQL"""
if self.role == 'admin':
return sql # 管理员全权限
# 解析SQL,添加权限限制
parsed = sqlparse.parse(sql)[0]
# 根据用户可访问的地区添加WHERE条件
if self.regions:
# 检查是否已经有WHERE子句
has_where = False
for token in parsed.tokens:
if token.ttype is Keyword and token.value.upper() == 'WHERE':
has_where = True
# 在现有WHERE后面添加 AND region IN (...)
# 实际实现需要更复杂的语法树操作
pass
if not has_where:
# 添加WHERE region IN (...)
region_list = "', '".join(self.regions)
sql += f" WHERE region IN ('{region_list}')"
else:
# 在WHERE后面添加 AND条件
# 简化处理,实际应该用sqlparse操作语法树
sql = sql.replace("WHERE", f"WHERE region IN ('{region_list}') AND")
return sql
四、SQL执行与结果处理
4.1 安全执行器
python
import psycopg2
from contextlib import contextmanager
import time
class SafeExecutor:
def __init__(self, db_config):
self.db_config = db_config
self.validator = SQLValidator()
@contextmanager
def get_connection(self):
"""获取只读数据库连接"""
conn = psycopg2.connect(
**self.db_config,
options='-c default_transaction_read_only=on' # 强制只读
)
# 设置语句超时
with conn.cursor() as cur:
cur.execute("SET statement_timeout = '30s'")
try:
yield conn
finally:
conn.close()
def execute(self, sql, user_context=None):
"""执行SQL并返回结果"""
start_time = time.time()
try:
# 1. 安全验证
validated_sql = self.validator.validate(sql)
# 2. 权限控制
if user_context:
enforcer = PermissionEnforcer(user_context)
validated_sql = enforcer.enforce(validated_sql)
# 3. 执行查询
with self.get_connection() as conn:
with conn.cursor() as cur:
cur.execute(validated_sql)
# 获取列名
col_names = [desc[0] for desc in cur.description] if cur.description else []
# 获取结果(限制行数)
rows = cur.fetchmany(1000) # 最多1000行
# 如果还有更多,提示用户
has_more = len(rows) == 1000 and cur.fetchone() is not None
result = {
"success": True,
"columns": col_names,
"rows": rows,
"row_count": len(rows),
"has_more": has_more,
"execution_time": time.time() - start_time,
"sql": validated_sql
}
return result
except Exception as e:
return {
"success": False,
"error": str(e),
"execution_time": time.time() - start_time,
"sql": sql
}
4.2 结果解释器
将查询结果转换为自然语言,让业务人员更容易理解。
python
class ResultInterpreter:
def __init__(self, api_key):
self.client = openai.OpenAI(api_key=api_key)
def interpret(self, user_question, result):
"""将查询结果解释为自然语言"""
if not result["success"]:
return f"查询失败:{result['error']}"
# 构建上下文
context = f"""
用户问题:{user_question}
查询结果:
- 返回了 {result['row_count']} 行数据
- 字段:{', '.join(result['columns'])}
- 数据样例:
"""
# 添加前5行作为示例
for i, row in enumerate(result['rows'][:5]):
context += f" {i+1}. {dict(zip(result['columns'], row))}\n"
if result['has_more']:
context += "(还有更多数据未显示)\n"
context += f"\n执行时间:{result['execution_time']:.2f}秒"
# 调用LLM生成解释
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "你是一个数据分析助手,需要根据查询结果用自然语言回答用户的问题。回答要简洁、准确,如果有多个数据可以总结趋势。"},
{"role": "user", "content": context}
],
temperature=0.3
)
return response.choices[0].message.content
def generate_visualization(self, result):
"""生成可视化配置(ECharts格式)"""
if not result["success"] or result["row_count"] == 0:
return None
# 智能判断图表类型
columns = result["columns"]
rows = result["rows"]
# 如果有时间列和数值列,推荐折线图
date_cols = [col for col in columns if 'date' in col.lower() or 'time' in col.lower()]
numeric_cols = [col for col in columns if 'amount' in col.lower() or 'price' in col.lower() or 'count' in col.lower()]
if date_cols and numeric_cols:
# 折线图
chart_config = {
"type": "line",
"title": f"{numeric_cols[0]}趋势",
"xAxis": date_cols[0],
"series": [{
"name": numeric_cols[0],
"data": [row[columns.index(numeric_cols[0])] for row in rows]
}]
}
elif len(rows) <= 10 and numeric_cols:
# 条形图
chart_config = {
"type": "bar",
"title": f"{numeric_cols[0]}分布",
"xAxis": columns[0] if columns[0] != numeric_cols[0] else columns[1],
"series": [{
"name": numeric_cols[0],
"data": [row[columns.index(numeric_cols[0])] for row in rows]
}]
}
else:
# 默认表格
chart_config = {
"type": "table",
"columns": columns,
"data": rows[:20] # 最多20行
}
return chart_config
五、完整流程集成
5.1 Agent主控制器
python
class DataAnalysisAgent:
def __init__(self, openai_key, db_config, schema):
self.sql_generator = SQLGenerator(openai_key, schema)
self.executor = SafeExecutor(db_config)
self.interpreter = ResultInterpreter(openai_key)
def process(self, user_question, user_context=None):
"""处理用户请求"""
# 1. 生成SQL
sql = self.sql_generator.generate(user_question)
if not sql:
return {
"success": False,
"error": "无法理解您的问题,请换个说法试试",
"stage": "sql_generation"
}
# 2. 执行SQL
result = self.executor.execute(sql, user_context)
if not result["success"]:
# 如果是SQL错误,尝试重新生成
if "syntax" in result["error"].lower():
# 可以在这里加入错误重试逻辑
pass
return result
# 3. 解释结果
answer = self.interpreter.interpret(user_question, result)
# 4. 生成可视化
chart = self.interpreter.generate_visualization(result)
return {
"success": True,
"answer": answer,
"sql": sql,
"data": {
"columns": result["columns"],
"rows": result["rows"],
"row_count": result["row_count"]
},
"chart": chart,
"execution_time": result["execution_time"]
}
# 使用示例
agent = DataAnalysisAgent(
openai_key="your-key",
db_config={
"host": "localhost",
"database": "analytics",
"user": "readonly_user",
"password": "readonly_pass"
},
schema=database_schema
)
# 业务人员提问
result = agent.process(
"上个月北京地区销售额最高的3个品类是什么?",
user_context={"user_id": "1001", "role": "analyst", "regions": ["北京", "上海"]}
)
if result["success"]:
print(f"答案:{result['answer']}")
print(f"生成的SQL:{result['sql']}")
if result["chart"]:
print(f"图表配置:{result['chart']}")
else:
print(f"处理失败:{result['error']}")
六、性能优化与监控
6.1 缓存策略
python
import hashlib
import redis
class QueryCache:
def __init__(self, redis_client, ttl=3600):
self.redis = redis_client
self.ttl = ttl
def get_cache_key(self, user_question, user_context):
"""生成缓存键"""
content = f"{user_question}_{user_context.get('role')}_{user_context.get('regions')}"
return f"nl2sql:{hashlib.md5(content.encode()).hexdigest()}"
def get(self, user_question, user_context):
"""获取缓存"""
key = self.get_cache_key(user_question, user_context)
cached = self.redis.get(key)
return json.loads(cached) if cached else None
def set(self, user_question, user_context, result):
"""设置缓存"""
key = self.get_cache_key(user_question, user_context)
self.redis.setex(key, self.ttl, json.dumps(result))
6.2 监控指标
python
class MetricsCollector:
def __init__(self):
self.metrics = {
"total_queries": 0,
"success_queries": 0,
"avg_execution_time": 0,
"sql_generation_time": [],
"execution_time": [],
"error_types": {}
}
def record_query(self, start_time, result, stage="complete"):
"""记录查询指标"""
self.metrics["total_queries"] += 1
duration = time.time() - start_time
if result["success"]:
self.metrics["success_queries"] += 1
# 记录各阶段耗时
if stage == "sql_generation":
self.metrics["sql_generation_time"].append(duration)
elif stage == "execution":
self.metrics["execution_time"].append(duration)
# 记录错误类型
if not result["success"] and "error" in result:
error_type = type(result["error"]).__name__
self.metrics["error_types"][error_type] = \
self.metrics["error_types"].get(error_type, 0) + 1
七、总结与进阶思考
7.1 核心设计要点
| 维度 | 关键设计 | 生产建议 |
|---|---|---|
| Schema注入 | 字段业务描述 + 示例数据 | 定期更新示例,覆盖更多查询模式 |
| SQL生成 | Few-shot + 低温度 + 重试 | 建立SQL样本库,持续优化 |
| 安全验证 | 关键词过滤 + 语法树解析 | 强制只读连接,设置超时 |
| 权限控制 | 用户角色 + 数据维度 | 行级权限通过WHERE注入 |
| 结果解释 | 自然语言 + 可视化 | 根据数据特征智能选择图表 |
7.2 演进路线
V1
单表查询
简单聚合
V2
多表JOIN
子查询支持
V3
复杂分析
窗口函数
V4
多轮对话
上下文记忆
V5
主动洞察
异常预警
7.3 常见问题与解决方案
- SQL语法错误
- 问题:生成的SQL有语法错误
- 方案:捕获错误信息,作为上下文重新生成
- 歧义理解
- 问题:"销售额"可能指总金额、平均金额、订单数
- 方案:添加澄清对话,或从上下文中推断
- 性能问题
- 问题:复杂查询执行慢
- 方案:设置statement_timeout,返回部分结果
- 数据权限
- 问题:用户只能看自己部门的数据
- 方案:在SQL生成阶段注入权限条件
写在最后:
Text-to-SQL技术正在重塑数据分析的工作方式。从简单的单表查询到复杂的多维度分析,从被动回答到主动洞察,数据分析Agent的潜力远未被完全挖掘。