I. Introduction
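In the previous post, "nl2sql: Automatic SQL Generation with Spring AI Alibaba Nl2sql", I built a simple nl2sql component on the Spring AI Alibaba stack. In this post I use the LangChain4j framework to implement ChatBI. The application generates SQL, runs that SQL against the target database, and then has the LLM compose the final answer from the user's question and the query results.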
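The ChatBI flow described above has three stages. First, the LLM turns the question into SQL. Second, the database executes that SQL. Third, the LLM turns the question plus the result into an answer. The sketch below shows only how the stages compose. It assumes plain Java functions as stand-ins for the model and JDBC calls. ChatBiPipeline and its constructor parameters are hypothetical names, not LangChain4j API. In the LangChain4j implementation later in this post, SqlDatabaseContentRetriever covers the first two stages and the AI service covers the third.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical sketch of the ChatBI flow: each stage is a plain function here,
// standing in for the real LLM and database calls.
final class ChatBiPipeline {
    private final Function<String, String> generateSql;             // question -> SQL (LLM)
    private final Function<String, String> executeSql;              // SQL -> result text (database)
    private final BiFunction<String, String, String> composeAnswer; // (question, result) -> answer (LLM)

    ChatBiPipeline(Function<String, String> generateSql,
                   Function<String, String> executeSql,
                   BiFunction<String, String, String> composeAnswer) {
        this.generateSql = generateSql;
        this.executeSql = executeSql;
        this.composeAnswer = composeAnswer;
    }

    String answer(String question) {
        String sql = generateSql.apply(question);        // 1) natural language -> SQL
        String result = executeSql.apply(sql);           // 2) run the SQL
        return composeAnswer.apply(question, result);    // 3) answer from question + result
    }
}
```

Stubbing the stages with lambdas makes the data flow easy to see and to test before any model is involved.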
II. Environment requirements
JDK: 17+
langchain4j: 1.10.0
h2: 2.3+
III. Maven dependencies
```xml
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j</artifactId>
    <version>1.10.0</version>
</dependency>
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-community-dashscope</artifactId>
    <version>1.10.0-beta18</version>
</dependency>
<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-experimental-sql</artifactId>
    <version>1.10.0-beta18</version>
</dependency>
<dependency>
    <groupId>com.h2database</groupId>
    <artifactId>h2</artifactId>
    <version>2.3.232</version>
</dependency>
```
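Notes:
1) langchain4j: the LangChain4j core library.
2) langchain4j-community-dashscope: LangChain4j integration for the DashScope model APIs (Alibaba's Qwen models).
3) langchain4j-experimental-sql: describes the table structure to the LLM, generates SQL from the user's question, executes it, and passes the result to the LLM.
4) h2: an in-memory database that holds the test data.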
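langchain4j-experimental-sql executes SQL written by the model. For that reason, the SqlDatabaseContentRetriever Javadoc warns against production use and asks for a database user with very limited, read-only permissions. If you build a similar pipeline yourself, a cheap extra filter is to reject anything that is not a single SELECT (or WITH) statement before running it. The helper below is a hypothetical sketch: SqlGuard and looksReadOnly are illustrative names, not library API. It is deliberately naive and is no substitute for database permissions.

```java
import java.util.regex.Pattern;

// Hypothetical helper, not part of LangChain4j: a coarse pre-check on model-generated SQL.
// WITH is allowed for CTE queries; some databases also accept data-modifying CTEs,
// which is one more reason this is only a filter and not a security boundary.
final class SqlGuard {
    private static final Pattern READ_ONLY_START =
            Pattern.compile("^(SELECT|WITH)\\b", Pattern.CASE_INSENSITIVE);

    static boolean looksReadOnly(String sql) {
        if (sql == null) {
            return false;
        }
        String s = sql.trim();
        // Allow one trailing ';' but reject anything that still contains ';' (stacked statements)
        if (s.endsWith(";")) {
            s = s.substring(0, s.length() - 1).trim();
        }
        if (s.isEmpty() || s.contains(";")) {
            return false;
        }
        return READ_ONLY_START.matcher(s).find();
    }
}
```

Rejecting any remaining ';' is conservative: it also refuses a harmless SELECT whose string literal contains a semicolon. That trade-off is usually acceptable for a guard.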
IV. Implementation
1. Preparing test data
Two files are needed. Create a sql folder under the resources directory and put create_tables.sql and prefill_tables.sql in it.
1) File 1: create_tables.sql
```sql
CREATE TABLE customers (
    customer_id INT PRIMARY KEY,
    first_name  VARCHAR(50),
    last_name   VARCHAR(50),
    email       VARCHAR(100)
);

CREATE TABLE products (
    product_id   INT PRIMARY KEY,
    product_name VARCHAR(100),
    price        DECIMAL(10, 2)
);

CREATE TABLE orders (
    order_id    INT PRIMARY KEY,
    customer_id INT,
    product_id  INT,
    quantity    INT,
    order_date  DATE,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id),
    FOREIGN KEY (product_id) REFERENCES products (product_id)
);
```
2) File 2: prefill_tables.sql
```sql
INSERT INTO customers (customer_id, first_name, last_name, email) VALUES
    (1, 'John', 'Doe', 'john.doe@example.com'),
    (2, 'Jane', 'Smith', 'jane.smith@example.com'),
    (3, 'Alice', 'Johnson', 'alice.johnson@example.com'),
    (4, 'Bob', 'Williams', 'bob.williams@example.com'),
    (5, 'Carol', 'Brown', 'carol.brown@example.com');

INSERT INTO products (product_id, product_name, price) VALUES
    (10, 'Notebook', 12.99),
    (20, 'Pen', 1.50),
    (30, 'Desk Lamp', 23.99),
    (40, 'Backpack', 49.99),
    (50, 'Stapler', 7.99);

INSERT INTO orders (order_id, customer_id, product_id, quantity, order_date) VALUES
    (100, 1, 10, 2, '2024-04-20'),
    (200, 2, 20, 5, '2024-04-21'),
    (300, 3, 10, 1, '2024-04-22'),
    (400, 4, 30, 1, '2024-04-23'),
    (500, 5, 40, 1, '2024-04-24'),
    (600, 1, 50, 3, '2024-04-25'),
    (700, 2, 10, 2, '2024-04-26'),
    (800, 3, 40, 1, '2024-04-27'),
    (900, 4, 20, 10, '2024-04-28'),
    (10000, 5, 30, 2, '2024-04-29');
```
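Knowing the expected values makes it easier to judge whether the assistant's answers are right. The sketch below recomputes each customer's total spending (quantity × price) from a copy of the seed rows above. A correct SQL answer to a question such as "how much has each customer spent?" should return the same totals. SeedData and spendingByCustomer are hypothetical names, and BigDecimal keeps the monetary sums exact.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical ground-truth helper: a copy of the rows in prefill_tables.sql,
// used to spot-check the assistant's answers without going through the LLM.
final class SeedData {
    record Order(int orderId, int customerId, int productId, int quantity) {}

    static final Map<Integer, BigDecimal> PRICES = Map.of(
            10, new BigDecimal("12.99"), 20, new BigDecimal("1.50"),
            30, new BigDecimal("23.99"), 40, new BigDecimal("49.99"),
            50, new BigDecimal("7.99"));

    static final List<Order> ORDERS = List.of(
            new Order(100, 1, 10, 2), new Order(200, 2, 20, 5),
            new Order(300, 3, 10, 1), new Order(400, 4, 30, 1),
            new Order(500, 5, 40, 1), new Order(600, 1, 50, 3),
            new Order(700, 2, 10, 2), new Order(800, 3, 40, 1),
            new Order(900, 4, 20, 10), new Order(10000, 5, 30, 2));

    // Same result as: SELECT o.customer_id, SUM(o.quantity * p.price)
    //                 FROM orders o JOIN products p ON o.product_id = p.product_id
    //                 GROUP BY o.customer_id
    static Map<Integer, BigDecimal> spendingByCustomer() {
        Map<Integer, BigDecimal> totals = new TreeMap<>(); // sorted by customer_id
        for (Order o : ORDERS) {
            BigDecimal line = PRICES.get(o.productId()).multiply(BigDecimal.valueOf(o.quantity()));
            totals.merge(o.customerId(), line, BigDecimal::add);
        }
        return totals;
    }
}
```

With this data, customer 5 (Carol Brown) has the highest total at 97.97: one Backpack at 49.99 plus two Desk Lamps at 23.99 each. Customer 1 (John Doe) totals 49.95.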
2. Code implementation
2.1 Create the H2 data source and run the SQL in the two files above to initialize the data
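```java
import org.h2.jdbcx.JdbcDataSource;

import javax.sql.DataSource;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;

// Creates an in-memory H2 database and runs the two scripts to create and fill the tables
private static DataSource createDataSource() {
    JdbcDataSource dataSource = new JdbcDataSource();
    // DB_CLOSE_DELAY=-1 keeps the in-memory database alive until the JVM exits
    dataSource.setURL("jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");
    dataSource.setUser("sa");
    dataSource.setPassword("sa");
    String createTablesScript = read("sql/create_tables.sql");
    execute(createTablesScript, dataSource);
    String prefillTablesScript = read("sql/prefill_tables.sql");
    execute(prefillTablesScript, dataSource);
    return dataSource;
}

// Reads a classpath resource (src/main/resources/...) as UTF-8 text
private static String read(String path) {
    try (InputStream in = Thread.currentThread().getContextClassLoader().getResourceAsStream(path)) {
        if (in == null) {
            throw new IllegalStateException("Resource not found on classpath: " + path);
        }
        return new String(in.readAllBytes(), StandardCharsets.UTF_8);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

// Executes each ';'-separated statement; blank fragments (e.g. after the last ';') are skipped
private static void execute(String sql, DataSource dataSource) {
    try (Connection connection = dataSource.getConnection();
         Statement statement = connection.createStatement()) {
        for (String sqlStatement : sql.split(";")) {
            String trimmed = sqlStatement.trim();
            if (!trimmed.isEmpty()) {
                statement.execute(trimmed);
            }
        }
    } catch (SQLException e) {
        throw new RuntimeException(e);
    }
}
```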
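The execute helper splits each script on every ';'. That works for the two scripts above because none of their string literals contains a semicolon, but a literal such as 'a;b' would be cut in half. If your own test data needs semicolons inside literals, a quote-aware splitter is one option. The sketch below uses SqlScriptSplitter, a hypothetical name. It tracks only single-quoted literals, treats '' as an escaped quote, and does not handle comments or procedural blocks.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical alternative to sql.split(";"): splits only on semicolons that are
// outside single-quoted string literals. Comments and procedural blocks are not handled.
final class SqlScriptSplitter {
    static List<String> split(String script) {
        List<String> statements = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        boolean inLiteral = false;
        for (int i = 0; i < script.length(); i++) {
            char c = script.charAt(i);
            if (c == '\'') {
                inLiteral = !inLiteral; // an escaped '' toggles twice, so we stay inside the literal
            }
            if (c == ';' && !inLiteral) {
                addIfNotBlank(statements, current);
                current.setLength(0);
            } else {
                current.append(c);
            }
        }
        addIfNotBlank(statements, current); // trailing statement without ';'
        return statements;
    }

    private static void addIfNotBlank(List<String> statements, StringBuilder sb) {
        String s = sb.toString().trim();
        if (!s.isEmpty()) {
            statements.add(s);
        }
    }
}
```

Inside execute, the loop would then become for (String s : SqlScriptSplitter.split(sql)) { statement.execute(s); }.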
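2.2 Create an nl2sql AI assistant that uses Alibaba's Qwen (Tongyi Qianwen) as the LLM. SqlDatabaseContentRetriever connects to the database, generates SQL from the user's question and executes it. The LLM then produces the answer from the returned results.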
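```java
import dev.langchain4j.community.model.dashscope.QwenChatModel;
import dev.langchain4j.experimental.rag.content.retriever.sql.SqlDatabaseContentRetriever;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import javax.sql.DataSource;

private static Assistant createAssistant() {
    DataSource dataSource = createDataSource();
    // Qwen via DashScope; the API key is read from an environment variable instead of being hard-coded
    ChatModel chatModel = QwenChatModel.builder()
            .apiKey(System.getenv("DASHSCOPE_API_KEY"))
            .modelName("qwen-flash")
            .build();

    // Experimental: describes the schema to the LLM, has it write SQL, runs that SQL and returns
    // the result as retrieved content. It executes model-generated SQL, so use a read-only DB user.
    ContentRetriever contentRetriever = SqlDatabaseContentRetriever.builder()
            .dataSource(dataSource)
            .chatModel(chatModel)
            .build();

    return AiServices.builder(Assistant.class)
            .chatModel(chatModel)
            .contentRetriever(contentRetriever)
            // keep the last 10 messages so follow-up questions have context
            .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
            .build();
}
```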
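The article does not show the Assistant interface that AiServices implements. A minimal version consistent with the assistant.answer(userQuery) call in main is sketched below. The method name comes from that call, and everything else is an assumption. With a single unannotated String parameter, LangChain4j treats the argument as the user message.

```java
// Hypothetical: the original article does not show this interface.
// AiServices generates the implementation at runtime; the single String
// argument is sent to the model as the user message.
interface Assistant {
    String answer(String query);
}
```

Because it has a single abstract method, a lambda can stand in for the real assistant in tests.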
2.3 Write a main method to test the program
```java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Scanner;

public static void main(String[] args) {
    Assistant assistant = createAssistant();
    startConversationWith(assistant);
}

// Console loop: type a question, or "exit" to quit.
// Output goes through SLF4J, so it only appears if a binding (e.g. logback or slf4j-simple) is on the classpath.
public static void startConversationWith(Assistant assistant) {
    Logger log = LoggerFactory.getLogger(Assistant.class);
    try (Scanner scanner = new Scanner(System.in)) {
        while (true) {
            log.info("==================================================");
            log.info("User: ");
            if (!scanner.hasNextLine()) {
                break; // end of input
            }
            String userQuery = scanner.nextLine();
            log.info("==================================================");
            if ("exit".equalsIgnoreCase(userQuery.trim())) {
                break;
            }
            String agentAnswer = assistant.answer(userQuery);
            log.info("==================================================");
            log.info("Assistant: {}", agentAnswer);
        }
    }
}
```
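startConversationWith is hard to exercise without an API key, because every line goes to the model. One option is to factor the loop so that it accepts any Function<String, String> as the assistant, as in the hypothetical ConversationLoop below. It also stops cleanly at end of input and skips blank lines.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.function.Function;

// Hypothetical, dependency-free variant of the conversation loop: the assistant is any
// Function<String, String>, so the loop can be tested without calling the model.
final class ConversationLoop {
    static List<String> run(Scanner scanner, Function<String, String> assistant) {
        List<String> answers = new ArrayList<>();
        // hasNextLine() guards against end of input, where nextLine() would throw
        while (scanner.hasNextLine()) {
            String userQuery = scanner.nextLine().trim();
            if ("exit".equalsIgnoreCase(userQuery)) {
                break;
            }
            if (userQuery.isEmpty()) {
                continue; // skip blank lines instead of sending them to the model
            }
            answers.add(assistant.apply(userQuery));
        }
        return answers;
    }
}
```

With the real assistant, the call would be ConversationLoop.run(new Scanner(System.in), assistant::answer).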
The test results are shown in the screenshot below:
