一、场景背景
后端开发中,经常遇到业务方/用户想用自然语言(如"查最近30天热销产品")查询数据库的需求。传统方案需硬编码SQL分支,维护成本高,而通过大模型将自然语言转SQL(NL2SQL),可实现动态、智能的数据查询,尤其适合电商、报表、运营分析等场景。
二、核心架构:4层闭环设计
用户输入自然语言
后端接口层:接收&预处理
大模型层:自然语言转SQL
SQL校验层:安全&语法检查
数据库执行SQL
结果转换层:SQL结果转自然语言
返回给用户/前端
三、分步实现细节
1. 第一步:后端接口层(以Go为例)
1.1 接收用户输入
用Gin框架写一个HTTP接口,接收用户的自然语言查询(需限制输入长度,避免过载):
go
package main
import (
"github.com/gin-gonic/gin"
"net/http"
)
func main() {
r := gin.Default()
// 自然语言查询接口
r.POST("/api/nl2sql/query", func(c *gin.Context) {
var req struct {
UserQuery string `json:"user_query" binding:"required,max=500"` // 限制最长500字符
UserId string `json:"user_id" binding:"required"` // 用于权限校验
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 后续步骤:调用大模型+校验SQL+查数据库...
})
r.Run(":8080")
}
1.2 预处理:补充上下文
用户输入可能模糊(如"热销产品"未说明时间范围),需补充默认上下文,提升大模型生成SQL的准确性:
go
// 补充上下文逻辑
func addContext(userQuery string) string {
// 1. 时间上下文:默认补充"最近30天"(可根据业务调整)
if contains(userQuery, "热销") && !contains(userQuery, "天") && !contains(userQuery, "月") {
userQuery += "(时间范围:最近30天)"
}
// 2. 业务上下文:补充数据库所属业务(如电商产品库)
userQuery += "【注:查询的是电商产品数据库,包含产品表(products)、订单表(orders)】"
return userQuery
}
// 辅助函数:判断字符串是否包含子串
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}
2. 第二步:大模型层(自然语言转SQL)
2.1 选模型&调用方式
- 轻量级场景:用DeepSeek 7B/1.5B(本地部署,需GPU:7B约需16G显存,1.5B约需4G显存)
- 便捷场景:调用DeepSeek API(无需本地部署,按token计费)
以调用DeepSeek API为例,需传入3个核心参数(参考CSDN NL2SQL最佳实践):
go
package main
import (
"bytes"
"encoding/json"
"net/http"
)
// 大模型请求体
type LLMRequest struct {
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"` // 0.1-0.3,越低越精准
}
// 大模型响应体
type LLMResponse struct {
Choices []struct {
Text string `json:"text"`
} `json:"choices"`
}
// 调用DeepSeek生成SQL
func generateSQL(userQueryWithCtx string) (string, error) {
// 1. 构造Prompt(关键:清晰表结构+查询需求+约束)
prompt := `请根据以下数据库信息生成MySQL兼容的SQL语句:
--------------------------
【表结构】
1. 表名:products(产品表)
- product_id (int, 主键):产品ID
- product_name (varchar):产品名称
- category (varchar):产品分类(如"家电""数码")
- price (decimal):单价(元)
2. 表名:orders(订单表)
- order_id (int, 主键):订单ID
- product_id (int):关联products.product_id
- order_time (datetime):下单时间(格式:YYYY-MM-DD HH:MM:SS)
- sales_num (int):销售数量(单订单)
- status (varchar):订单状态("已支付""已取消",仅统计"已支付")
--------------------------
【查询需求】` + userQueryWithCtx + `
--------------------------
【约束】
1. 仅返回可直接执行的MySQL SQL语句,无需解释;
2. 涉及时间范围需用BETWEEN或>=/<+具体日期(如最近30天:order_time >= DATE_SUB(CURDATE(), INTERVAL 30 DAY));
3. 销售数量需汇总(如SUM(sales_num)),热销定义为"销售数量总和前10";
4. 多表关联用显式JOIN,避免隐式连接。`
// 2. 调用DeepSeek API(需替换为实际API地址和密钥)
apiKey := "YOUR_DEEPSEEK_API_KEY"
apiURL := "https://api.deepseek.com/v1/chat/completions"
reqBody, _ := json.Marshal(LLMRequest{
Prompt: prompt,
MaxTokens: 500,
Temperature: 0.2,
})
client := &http.Client{}
req, _ := http.NewRequest("POST", apiURL, bytes.NewBuffer(reqBody))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
// 3. 解析响应,提取SQL
var llmResp LLMResponse
json.NewDecoder(resp.Body).Decode(&llmResp)
if len(llmResp.Choices) == 0 {
return "", fmt.Errorf("大模型未返回SQL")
}
return llmResp.Choices[0].Text, nil
}
2.2 示例:用户输入→生成SQL
用户输入:"查最近热销产品"
补充上下文后:"查最近热销产品(时间范围:最近30天)【注:查询的是电商产品数据库...】"
大模型生成的SQL:
sql
SELECT
p.product_id,
p.product_name,
p.category,
SUM(o.sales_num) AS total_sales
FROM products p
INNER JOIN orders o ON p.product_id = o.product_id
WHERE o.status = '已支付'
AND o.order_time >= DATE_SUB(CURDATE(), INTERVAL 30 DAY)
GROUP BY p.product_id, p.product_name, p.category
ORDER BY total_sales DESC
LIMIT 10;
3. 第三步:SQL校验层(安全+准确性双检)
必须加! 避免大模型生成危险SQL(如DROP、UPDATE)或语法错误SQL,分2层校验:
3.1 安全校验(防注入&敏感操作)
- 用正则匹配危险关键字(DROP、DELETE、UPDATE等);
- 基于RBAC权限校验(如普通用户禁止查敏感字段):
go
import (
"regexp"
"github.com/go-sql-driver/mysql"
)
// 安全校验SQL
func checkSQLSecurity(sql string, userId string) error {
// 1. 禁止危险操作
dangerPattern := regexp.MustCompile(`(?i)DROP|DELETE|UPDATE|TRUNCATE|ALTER`)
if dangerPattern.MatchString(sql) {
return fmt.Errorf("禁止执行危险SQL:%s", dangerPattern.FindString(sql))
}
// 2. 权限校验(示例:普通用户禁止查单价字段)
isAdmin := checkUserIsAdmin(userId) // 从权限系统获取用户角色
if !isAdmin && strings.Contains(sql, "price") {
return fmt.Errorf("无权限查询产品单价")
}
// 3. 防止全表扫描(如无WHERE条件的SELECT *)
if strings.Contains(sql, "SELECT *") && !strings.Contains(sql, "WHERE") {
return fmt.Errorf("禁止无条件全表查询,请补充筛选条件")
}
return nil
}
3.2 语法&逻辑校验
- 用MySQL解析库(如
go-sqlparser)检查SQL语法; - 验证表名、字段名是否存在于数据库:
go
import (
"github.com/xwb1989/sqlparser"
)
// 语法&逻辑校验
func checkSQLValidity(sql string) error {
// 1. 解析SQL,检查语法错误
stmt, err := sqlparser.Parse(sql)
if err != nil {
return fmt.Errorf("SQL语法错误:%s", err.Error())
}
// 2. 提取涉及的表名,验证是否存在(需从数据库元数据获取)
tables := getTablesFromStmt(stmt) // 自定义函数:从AST中提取表名
existingTables := []string{"products", "orders"} // 实际从INFORMATION_SCHEMA.TABLES查询
for _, tbl := range tables {
if !contains(existingTables, tbl) {
return fmt.Errorf("表不存在:%s", tbl)
}
}
// 3. 验证聚合函数是否正确(如热销需SUM(sales_num))
if strings.Contains(sql, "热销") && !strings.Contains(sql, "SUM(sales_num)") {
return fmt.Errorf("热销查询需汇总销售数量(SUM(sales_num))")
}
return nil
}
// 辅助函数:从SQL语句中提取表名
func getTablesFromStmt(stmt sqlparser.Statement) []string {
var tables []string
sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.AliasedTableExpr:
if tn, ok := n.Expr.(sqlparser.TableName); ok {
tables = append(tables, tn.Name.String())
}
}
return true, nil
}, stmt)
return tables
}
4. 第四步:执行SQL&转换结果
4.1 执行校验后的SQL
用Go的database/sql库连接MySQL执行SQL(注意用参数化查询,避免注入):
go
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
// 执行SQL并获取结果
func executeSQL(sql string) ([]map[string]interface{}, error) {
// 连接数据库(实际需用配置中心管理DSN)
dsn := "user:password@tcp(127.0.0.1:3306)/ecommerce?parseTime=true"
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
}
defer db.Close()
// 执行SQL
rows, err := db.Query(sql)
if err != nil {
return nil, err
}
defer rows.Close()
// 获取列名
cols, _ := rows.Columns()
// 遍历结果,转为map(方便后续转自然语言)
var results []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(cols))
valuePtrs := make([]interface{}, len(cols))
for i := range values {
valuePtrs[i] = &values[i]
}
rows.Scan(valuePtrs...)
row := make(map[string]interface{})
for i, col := range cols {
row[col] = values[i]
}
results = append(results, row)
}
return results, nil
}
4.2 结果转自然语言(再调用大模型)
将SQL返回的结构化数据(如[{"product_name":"iPhone 15","total_sales":1200},...])转为用户易懂的自然语言:
go
// 结果转自然语言
func resultToNaturalLang(results []map[string]interface{}, userQuery string) (string, error) {
// 构造Prompt:告诉大模型需整理结果
prompt := fmt.Sprintf(`请将以下SQL查询结果整理成自然语言,回答用户问题:
--------------------------
【用户问题】%s
【查询结果】%v
--------------------------
【要求】
1. 语言简洁,分点列出(如"1. XX产品:销量XX件");
2. 保留关键数据(产品名、销量),无需提及SQL细节;
3. 若结果为空,提示"暂无符合条件的热销产品"。`, userQuery, results)
// 调用DeepSeek API(同2.1的调用逻辑)
// ...(省略重复的API调用代码)
llmResp.Choices[0].Text // 示例返回:
/*
最近30天热销产品如下:
1. iPhone 15:销量1200件
2. 小米14:销量980件
3. 华为Mate 60 Pro:销量850件
...(共10款)
*/
return llmResp.Choices[0].Text, nil
}
四、关键注意事项
1. 性能优化
- 缓存高频SQL:对重复查询(如"今日热销")缓存生成的SQL,避免重复调用大模型;
- 异步执行:复杂SQL(如跨月统计)用消息队列异步处理,返回任务ID给用户,查询完成后通知。
2. 数据安全
- 脱敏处理:若结果含敏感数据(如单价),需按用户权限脱敏(如普通用户显示"≥1000元");
- 审计日志:记录所有用户查询(用户ID、自然语言、生成的SQL、结果),用于追溯。
3. 准确性提升
- 微调大模型 :若通用大模型生成SQL不准确,可用企业内部SQL样本(如历史报表SQL)微调DeepSeek 7B,步骤参考:
- 准备样本:
[{"自然语言":"查家电类热销产品"},{"SQL":"SELECT...WHERE category='家电'..."}] - 用PEFT库进行LoRA微调(需GPU);
- 部署微调后的模型替换通用API。
- 准备样本:
五、完整流程串联代码
go
// 完整接口逻辑
r.POST("/api/nl2sql/query", func(c *gin.Context) {
// 1. 接收输入
var req struct {
UserQuery string `json:"user_query" binding:"required,max=500"`
UserId string `json:"user_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 2. 补充上下文
userQueryWithCtx := addContext(req.UserQuery)
// 3. 生成SQL
sql, err := generateSQL(userQueryWithCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成SQL失败:" + err.Error()})
return
}
// 4. 校验SQL
if err := checkSQLSecurity(sql, req.UserId); err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": "SQL安全校验失败:" + err.Error()})
return
}
if err := checkSQLValidity(sql); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "SQL逻辑错误:" + err.Error()})
return
}
// 5. 执行SQL
results, err := executeSQL(sql)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "执行SQL失败:" + err.Error()})
return
}
// 6. 结果转自然语言
nlResult, err := resultToNaturalLang(results, req.UserQuery)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "结果转换失败:" + err.Error()})
return
}
// 7. 返回响应
c.JSON(http.StatusOK, gin.H{
"result": nlResult,
"sql": sql, // 可选:给技术用户展示SQL
})
})
六、总结
通过"自然语言→大模型转SQL→校验执行→结果转自然语言"的闭环,后端可快速实现智能数据查询功能,核心优势:
- 无需硬编码SQL分支,适配动态动态