- 添加依赖
xml
<properties>
<!-- Java release the project targets. NOTE(review): this only takes effect if the parent POM or the compiler plugin reads this property; confirm against the full pom.xml. -->
<java.version>21</java.version>
<!-- Single version shared by the Spring AI starters declared in the dependencies section. -->
<spring-ai.version>1.0.1</spring-ai.version>
<!-- Single version shared by the Spring Boot starters declared in the dependencies section. -->
<spring-boot.version>3.4.3</spring-boot.version>
</properties>
<dependencies>
    <!-- Spring AI: MCP server starter and the DeepSeek chat model starter. -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-mcp-server</artifactId>
        <version>${spring-ai.version}</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-deepseek</artifactId>
        <version>${spring-ai.version}</version>
    </dependency>

    <!-- Spring Boot starters. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-actuator</artifactId>
        <version>${spring-boot.version}</version>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
        <version>${spring-boot.version}</version>
    </dependency>
    <!-- Already pulled in transitively by spring-boot-starter-web; kept explicit as in the original setup. -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-logging</artifactId>
        <version>${spring-boot.version}</version>
    </dependency>

    <!--
      MySQL JDBC driver. The legacy coordinates mysql:mysql-connector-java were
      relocated to com.mysql:mysql-connector-j; use the current coordinates so
      Maven does not print a relocation warning and later upgrades resolve.
    -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <version>8.0.33</version>
    </dependency>

    <!-- Lombok is an annotation processor: needed at compile time only, not in the packaged application. -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.36</version>
        <scope>provided</scope>
    </dependency>
</dependencies>
- 编写tools
java
/**
 * Database tools exposed to the chat model through {@code @Tool}: list tables,
 * describe a table's columns, and run a query.
 *
 * <p>Every call opens its own JDBC connection to the requested database and
 * closes it before returning. The class keeps no per-instance state.
 *
 * <p>All arguments of the tool methods are produced by a language model and
 * must be treated as untrusted input: database names are validated before they
 * are appended to the JDBC URL, and values are bound as statement parameters
 * rather than formatted into SQL text.
 */
@Slf4j
@Service
public class Nl2SqlTools {

    /** Base JDBC URL; the database name is appended to it for each call. */
    private static final String URL = setting("NL2SQL_DB_URL", "jdbc:mysql://localhost:3306/");

    /** Database account used for every connection. */
    private static final String USER = setting("NL2SQL_DB_USER", "root");

    // NOTE(review): the fallback below is a credential committed to source control.
    // Set NL2SQL_DB_PASSWORD in real deployments and use a least-privileged,
    // read-only account, because getResults executes model-generated SQL.
    private static final String PASSWORD = setting("NL2SQL_DB_PASSWORD", "iwzf.88xa([");

    /**
     * Characters accepted in a database name. Anything else (notably {@code ?},
     * {@code &} and {@code /}) could smuggle extra parameters into the JDBC URL.
     */
    private static final java.util.regex.Pattern DB_NAME =
            java.util.regex.Pattern.compile("[A-Za-z0-9_$-]+");

    /**
     * Reads a connection setting from the environment.
     *
     * @param envName  name of the environment variable to read
     * @param fallback value used when the variable is unset or blank
     * @return the environment value, or {@code fallback}
     */
    private static String setting(String envName, String fallback) {
        String value = System.getenv(envName);
        return (value == null || value.isBlank()) ? fallback : value;
    }

    /**
     * Runs one query against {@code db} and hands the open result set to {@code mapper}.
     *
     * @param sql    SQL text, optionally containing {@code ?} placeholders
     * @param db     database to connect to; must match {@link #DB_NAME}
     * @param mapper converts the result set into the return value while it is still open
     * @param params values bound, in order, to the {@code ?} placeholders in {@code sql}
     * @param <T>    type produced by {@code mapper}
     * @return whatever {@code mapper} returns
     * @throws IllegalArgumentException if {@code db} is null or contains disallowed characters
     * @throws RuntimeException         wrapping any {@link SQLException} raised by the driver
     */
    private static <T> T query(String sql, String db, RowMapper<T> mapper, Object... params) {
        if (db == null || !DB_NAME.matcher(db).matches()) {
            throw new IllegalArgumentException("Invalid database name: " + db);
        }
        // Log before executing so that a failing statement is still visible in the log.
        log.info("query sql: {}", sql);
        try (Connection connection = DriverManager.getConnection(URL + db, USER, PASSWORD);
             java.sql.PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < params.length; i++) {
                statement.setObject(i + 1, params[i]); // JDBC parameter indexes are 1-based
            }
            try (ResultSet resultSet = statement.executeQuery()) {
                return mapper.mapRow(resultSet);
            }
        } catch (SQLException e) {
            throw new RuntimeException("Query failed on database '" + db + "': " + sql, e);
        }
    }

    /**
     * Lists the tables of a database.
     *
     * @param db database name
     * @return table names, in the order the server returns them
     */
    @Tool(description = "Get tables in a database")
    public List<String> getTables(String db) {
        log.info("getTables db: {}", db);
        return query("show tables", db, rs -> {
            List<String> tables = new ArrayList<>();
            while (rs.next()) {
                tables.add(rs.getString(1));
            }
            return tables;
        });
    }

    /**
     * Describes the columns of one table.
     *
     * <p>The lookup is restricted to {@code db}; without the schema filter a table
     * with the same name in another database would contribute columns as well.
     *
     * @param db    database that owns the table
     * @param table table name
     * @return one {@link Column} per column, in declaration order; empty if the table is unknown
     */
    @Tool(description = "Get columns in a table")
    public List<Column> getColumns(String db, String table) {
        String sql = "SELECT column_name, data_type, column_comment FROM information_schema.columns"
                + " WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position";
        log.info("getColumns db: {}, table: {}", db, table);
        return query(sql, db, rs -> {
            List<Column> columns = new ArrayList<>();
            while (rs.next()) {
                Column column = new Column();
                column.setName(rs.getString(1));
                column.setType(rs.getString(2));
                column.setComment(rs.getString(3));
                columns.add(column);
            }
            return columns;
        }, db, table);
    }

    /**
     * Executes a query and returns all of its rows.
     *
     * <p>Rows are keyed by column label, so {@code AS} aliases are honoured, and
     * keep the column order of the select list.
     *
     * @param db  database to run the query in
     * @param sql the query text; generated by the model, so the database account
     *            should not be able to do more than read
     * @return one map per row, from column label to value
     * @throws IllegalArgumentException if {@code sql} is null or blank
     */
    @Tool(description = "Get results of a SQL query")
    public List<Map<String, Object>> getResults(String db, String sql) {
        log.info("getResults db: {}, sql: {}", db, sql);
        if (sql == null || sql.isBlank()) {
            throw new IllegalArgumentException("sql must not be empty");
        }
        return query(sql, db, rs -> {
            java.sql.ResultSetMetaData meta = rs.getMetaData();
            int columnCount = meta.getColumnCount();
            // Resolve the labels once instead of once per cell.
            String[] labels = new String[columnCount];
            for (int i = 0; i < columnCount; i++) {
                labels[i] = meta.getColumnLabel(i + 1);
            }
            List<Map<String, Object>> results = new ArrayList<>();
            while (rs.next()) {
                Map<String, Object> row = new java.util.LinkedHashMap<>();
                for (int i = 0; i < columnCount; i++) {
                    row.put(labels[i], rs.getObject(i + 1));
                }
                results.add(row);
            }
            return results;
        });
    }

    /** Name, data type and comment of a single table column. */
    @Getter
    @Setter
    public static class Column {
        private String name;
        private String type;
        private String comment;
    }

    /**
     * Converts an open result set into a value.
     *
     * <p>Despite the name, the callback receives the whole result set positioned
     * before the first row and is responsible for iterating it.
     *
     * @param <T> type of the produced value
     */
    @FunctionalInterface
    public interface RowMapper<T> {
        /**
         * @param rs open result set; it is closed by the caller after this method returns
         * @return the mapped value
         * @throws SQLException if reading from the result set fails
         */
        T mapRow(ResultSet rs) throws SQLException;
    }
}
- 编写用于测试的controller
java
/**
 * Test endpoint that turns a natural-language question into SQL by letting the
 * chat model call the {@link Nl2SqlTools} database tools.
 */
@RestController
public class Nl2SqlController {

    /**
     * Fixed instructions for the model, sent as the system message. The user's
     * question is deliberately not concatenated into this text; it travels
     * separately as the user message.
     */
    private static final String INSTRUCTIONS =
            "请将用户的自然语言问题转换为SQL查询语句,SQL查询语句必须符合Doris语法规范,SQL查询语句只能返回SQL语句本身,不能包含任何多余的文字说明。"
                    + "1. SQL查询语句必须符合Doris语法规范。"
                    + "2. 从用户的提问中,提取到数据库名。"
                    + "3. 根据数据库名,先获取到所有表,用户选择表之后,再根据表名获取所有的字段信息。"
                    + "4. 然后根据字段信息,生成SQL查询语句。"
                    + "5. 生成的Sql, 在调用执行sql 的方法,获取数据。";

    private final ChatClient client;

    /** Tool object handed to the model; created once and reused across requests. */
    private final Nl2SqlTools tools = new Nl2SqlTools();

    /**
     * @param builder chat client builder supplied by the application context
     */
    public Nl2SqlController(ChatClient.Builder builder) {
        this.client = builder.build();
    }

    /**
     * Answers a natural-language question about a database.
     *
     * @param nl the question, taken from the {@code nl} request parameter
     * @return the text content of the model's reply
     * @throws IllegalArgumentException if {@code nl} is missing or blank
     */
    @GetMapping("/nl2sql")
    public String nl2sql(String nl) {
        // Reject an empty question before any model call is made.
        if (nl == null || nl.isBlank()) {
            throw new IllegalArgumentException("Request parameter 'nl' must not be empty");
        }
        return client.prompt()
                .system(INSTRUCTIONS)
                .user(nl)
                .tools(tools)
                .call()
                .content();
    }
}
- 然后启动整个 Spring Boot 项目,调用 http://localhost:8080/nl2sql 接口进行测试即可。自然语言问题通过 `nl` 参数传入,并且问题中需要包含数据库名,例如:`http://localhost:8080/nl2sql?nl=查询test库中user表的总记录数`