文章目录
0.前置准备
1.安装ollama
然后安装如下三个模型
shell
C:\Users\GT-JYW-3>ollama list
NAME ID SIZE MODIFIED
qwen3-embedding:0.6b ac6da0dfba84 639 MB 8 hours ago
qwen3.5:2b 324d162be6ca 2.7 GB 25 hours ago
deepseek-r1:1.5b e0979632db5a 1.1 GB 25 hours ago
第一个是检索用的,后面两个对话用,根据自己需要安装别的也行
对话用的模型需要带有 tools 标签,检索用的模型需要带有 embedding 标签
2.下载Qdrant
Qdrant是一个本地轻型的向量数据库
下载地址: https://github.com/qdrant/qdrant/releases
windows直接双击启动就行
3.环境准备
jdk17,然后下面就是完整的代码,直接复制粘贴就可以用
对话直接访问 ip:port/chat?msg=你的问题。 流式返回就是chatS?msg=
再加一个&sessionId可以基于redis进行会话存储(使用的是DB:6)
文件检索的目录我写死了一个默认值,可以在 KnowledgeBaseConfig.java 的 defaultPath(或 application.yaml 的 knowledge-base.default-path)中自行修改和扩展
然后访问ip:port/file/scan 先进行索引创建 然后再/file/chat?msg=你的问题即可
代码99%都是AI写的,本人亲测是可以读取文件夹里面PPTX,MD等文件内容的(按下面 extractContent 的实现,DOCX、PPT 等其他格式目前只会索引文件名、路径、大小等基本信息)
只是模型太笨,可以换个更大的模型,这里用的几个都是很小的来测试的
1.整体结构

2.pom文件
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.3.0</version>
<relativePath/>
</parent>
<groupId>com.cxl</groupId>
<artifactId>springai-demo</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>springai-demo</name>
<description>springai-demo</description>
<url/>
<licenses>
<license/>
</licenses>
<developers>
<developer/>
</developers>
<scm>
<connection/>
<developerConnection/>
<tag/>
<url/>
</scm>
<properties>
<java.version>17</java.version>
<spring-ai.version>1.0.0-M6</spring-ai.version>
</properties>
<repositories>
<repository>
<id>spring-milestones</id>
<name>Spring Milestones</name>
<url>https://repo.spring.io/milestone</url>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>${spring-ai.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-qdrant-store-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version> <!-- 与 spring-ai-bom 使用同一版本;该版本已由 BOM 管理,此行可省略 -->
</dependency>
<!-- Spring Boot Web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Spring AI Core -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
</dependency>
<!-- Spring AI Ollama Starter -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<!-- Spring Boot DevTools -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
<scope>runtime</scope>
<optional>true</optional>
</dependency>
<!-- Lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Apache POI for PPTX extraction -->
<dependency>
<groupId>org.apache.poi</groupId>
<artifactId>poi-ooxml</artifactId>
<version>5.2.5</version>
</dependency>
<!-- Spring Data Redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Spring AOP -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Test -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
</path>
<path>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
3.application.yaml文件
yaml
# ===================================================================
# Spring AI Demo 配置文件
# ===================================================================
# 服务器配置
server:
port: 12000 # 应用端口
# Spring 配置
spring:
application:
name: springai-demo # 应用名称
# Redis 配置 - 用于存储对话历史
data:
redis:
host: localhost # Redis 服务器地址
port: 6379 # Redis 端口
database: 6 # Redis 数据库编号 (0-15)
# AI 大模型配置
ai:
# Ollama 本地大模型服务配置
ollama:
base-url: http://localhost:11434 # Ollama API 地址
embedding:
model: qwen3-embedding:0.6b # 嵌入模型 (用于 RAG 向量化)
enabled: true
# 嵌入向量配置
embedding:
options:
model: qwen3-embedding:0.6b # 嵌入模型名称
# 向量数据库配置 (Qdrant)
vectorstore:
qdrant:
host: localhost # Qdrant 服务地址
port: 6334 # Qdrant gRPC 端口 (REST API 端口为 6333)
collection-name: my_docs # 向量集合名称
# ===================================================================
# 知识库配置 (Knowledge Base)
# 用途: FileSearchService 用于 RAG 知识检索
# ===================================================================
knowledge-base:
default-path: "C:\\Users\\GT-JYW-3\\Documents\\doct" # 知识库文件默认目录
collection-name: my_docs # Qdrant 向量集合名称
host: localhost # Qdrant 服务地址
port: 6334 # Qdrant 端口
top-k: 5 # 知识检索返回的相关文档数量
# ===================================================================
# 对话历史配置 (Chat History)
# 用途: ChatHistoryService + ChatHistoryAspect 用于存储会话上下文
# ===================================================================
chat-history:
enabled: true # 是否启用对话历史功能
key-prefix: "chat:history:" # Redis Key 前缀
max-size: 20 # 每个会话最大保存的消息数量
expire-days: 7 # 历史记录过期天数
# ===================================================================
# System 预制词配置 (AI 角色设定)
# 用途: ChatService 调用不同模型时使用的系统提示词
# ===================================================================
system:
  settings:
    # Fallback prompt, used when a model has no dedicated entry below
    default-prompt: "你是一个有帮助的AI助手,请用简洁专业的语言回答用户问题。"
    # Model-specific prompts.
    # The keys are wrapped in "[...]" on purpose: when Spring Boot binds to a Map it strips every
    # character other than letters, digits, '-' and '.' from unbracketed keys, so a plain
    # `qwen3.5:2b` key would lose its ':' and never match the model name used for the lookup.
    prompts:
      # Qwen model prompt
      "[qwen3.5:2b]": "你是Qwen3.5模型,一个由阿里云开发的大语言模型。你擅长中文理解和生成,请用专业、准确的语言回答用户问题。"
      # DeepSeek model prompt
      "[deepseek-r1:1.5b]": "你是DeepSeek-R1模型,一个专注于推理和思考的AI助手。请在回答问题时展示清晰的思考过程,提供深入的分析。"
# ===================================================================
# 应用全局配置
# ===================================================================
app:
context-aware-enabled: true # 是否启用上下文感知功能
default-top-k: 3 # 默认知识检索数量
session-timeout-minutes: 60 # 会话超时时间(分钟)
4.启用对话记忆注解
java
package com.cxl.annotation;
import java.lang.annotation.*;
/**
 * Method-level marker annotation, retained at runtime, used to opt a method into
 * chat-history handling.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface EnableChatHistory {
/** Switch for the annotated method; defaults to {@code true}. */
boolean enabled() default true;
}
5.对话记忆注解切面
java
package com.cxl.aspect;
import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatHistoryService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import java.lang.reflect.Method;
/**
 * Records the user message and the model reply of every method annotated with
 * {@link EnableChatHistory}.
 *
 * <p>The session id and the user message are read from the intercepted method's parameters named
 * {@code sessionId} and {@code msg}. When no session id is supplied the call is passed through
 * untouched.
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class ChatHistoryAspect {

    private final ChatHistoryService chatHistoryService;

    /**
     * Proceeds with the intercepted call and stores the exchange once the reply is known.
     *
     * <p>A {@code String} result is stored immediately. A {@link Flux} result is passed through
     * chunk by chunk (so streaming endpoints keep streaming) while the chunks are accumulated; the
     * exchange is stored when the stream completes normally.
     *
     * @param joinPoint the intercepted invocation
     * @return the original result, or a pass-through wrapper around it for {@link Flux} results
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("@annotation(com.cxl.annotation.EnableChatHistory)")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        EnableChatHistory annotation = method.getAnnotation(EnableChatHistory.class);
        // A missing annotation object (e.g. the signature resolved to an interface method) is
        // treated like the default "enabled" instead of failing with a NullPointerException.
        if (annotation != null && !annotation.enabled()) {
            return joinPoint.proceed();
        }

        Object[] args = joinPoint.getArgs();
        String[] paramNames = signature.getParameterNames();
        if (paramNames == null) {
            // Parameter names are not available, so sessionId/msg cannot be located.
            return joinPoint.proceed();
        }

        String sessionId = null;
        String userMessage = null;
        int count = Math.min(args.length, paramNames.length);
        for (int i = 0; i < count; i++) {
            if (!(args[i] instanceof String value)) {
                continue;
            }
            if ("sessionId".equals(paramNames[i])) {
                sessionId = value;
            } else if ("msg".equals(paramNames[i])) {
                userMessage = value;
            }
        }
        if (sessionId == null || sessionId.isEmpty()) {
            return joinPoint.proceed();
        }

        log.info("========== ChatHistoryAspect ==========");
        log.info("SessionId: {}", sessionId);
        log.info("UserMessage: {}", userMessage);

        Object result = joinPoint.proceed();
        if (result instanceof Flux<?> stream) {
            return recordStream(stream, sessionId, userMessage);
        }
        if (result instanceof String reply) {
            saveExchange(sessionId, userMessage, reply);
            log.info("Saved chat history to Redis");
        }
        return result;
    }

    /**
     * Wraps a stream so that its chunks are forwarded as they arrive and, on normal completion,
     * the concatenated reply is stored. Errors and cancellations store nothing.
     */
    private <T> Flux<T> recordStream(Flux<T> source, String sessionId, String userMessage) {
        // defer: every subscription gets its own accumulator.
        return Flux.defer(() -> {
            StringBuilder reply = new StringBuilder();
            return source
                    .doOnNext(chunk -> reply.append(chunk))
                    .doOnComplete(() -> {
                        saveExchange(sessionId, userMessage, reply.toString());
                        log.info("Saved stream chat history to Redis");
                    });
        });
    }

    /** Stores the user message (when present) followed by the assistant reply. */
    private void saveExchange(String sessionId, String userMessage, String reply) {
        if (userMessage != null) {
            chatHistoryService.addMessage(sessionId, "user", userMessage);
        }
        chatHistoryService.addMessage(sessionId, "assistant", reply);
    }
}
6.配置类
1.AIConfig
java
package com.cxl.config;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import java.util.List;
@Configuration
public class AIConfig {
@Bean
public OllamaApi ollamaApi() {
return new OllamaApi("http://localhost:11434");
}
@Bean
public ObservationRegistry observationRegistry() {
return ObservationRegistry.NOOP;
}
@Bean
@Primary
public ChatModel qwenChatModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
return new OllamaChatModel(
ollamaApi,
OllamaOptions.builder().model("qwen3.5:2b").build(),
null,
List.of(),
observationRegistry,
ModelManagementOptions.builder().build()
);
}
@Bean("deepseekChatModel")
public ChatModel deepseekChatModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
return new OllamaChatModel(
ollamaApi,
OllamaOptions.builder().model("deepseek-r1:1.5b").build(),
null,
List.of(),
observationRegistry,
ModelManagementOptions.builder().build()
);
}
@Bean
public EmbeddingModel embeddingModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
return new OllamaEmbeddingModel(
ollamaApi,
OllamaOptions.builder().model("qwen3-embedding:0.6b").build(),
observationRegistry,
ModelManagementOptions.builder().build()
);
}
}
2.AppConfig
java
package com.cxl.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
/**
 * Global configuration class.
 * Holds the application-wide switches and settings bound from the {@code app} prefix.
 */
@Data
@Component
@ConfigurationProperties(prefix = "app")
public class AppConfig {
/**
 * Global switch for context awareness.
 * true: context awareness enabled
 * false: context awareness disabled
 */
private boolean contextAwareEnabled = true;
/**
 * Default number of results returned by file retrieval.
 */
private int defaultTopK = 3;
/**
 * Session timeout, in minutes.
 */
private int sessionTimeoutMinutes = 60;
/**
 * Maximum file size for file scanning, in bytes.
 */
private long maxFileSize = 10 * 1024 * 1024; // 10MB
/**
 * Supported file types (file extensions).
 */
private String[] supportedFileTypes = {
"txt", "md", "json", "xml", "html", "css", "js", "java",
"pdf", "pptx", "ppt",
"jpg", "jpeg", "png", "gif", "webp",
"mp4", "avi", "mov", "wmv"
};
}
3.对话客户端配置
java
package com.cxl.config;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@Configuration
public class ChatClientConfig {
@Bean
@Primary
public ChatClient qwenChatClient(ChatModel chatModel) {
return ChatClient.builder(chatModel).build();
}
@Bean("deepseekChatClient")
public ChatClient deepseekChatClient(@Qualifier("deepseekChatModel") ChatModel chatModel) {
return ChatClient.builder(chatModel).build();
}
}
4.知识库配置(向量数据库)
java
package com.cxl.config;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
 * Knowledge-base settings bound from the {@code knowledge-base} configuration prefix.
 * Every property has a built-in default that applies when the key is not configured.
 */
@Configuration
@ConfigurationProperties(prefix = "knowledge-base")
public class KnowledgeBaseConfig {

    /** Directory used when no explicit directory is given. */
    private String defaultPath = "C:\\Users\\GT-JYW-3\\Documents\\doct";

    public String getDefaultPath() {
        return this.defaultPath;
    }

    public void setDefaultPath(String defaultPath) {
        this.defaultPath = defaultPath;
    }

    /** Name of the vector collection. */
    private String collectionName = "my_docs";

    public String getCollectionName() {
        return this.collectionName;
    }

    public void setCollectionName(String collectionName) {
        this.collectionName = collectionName;
    }

    /** Host of the vector database. */
    private String host = "localhost";

    public String getHost() {
        return this.host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    /** Port of the vector database. */
    private int port = 6334;

    public int getPort() {
        return this.port;
    }

    public void setPort(int port) {
        this.port = port;
    }

    /** Number of documents to retrieve per query. */
    private int topK = 5;

    public int getTopK() {
        return this.topK;
    }

    public void setTopK(int topK) {
        this.topK = topK;
    }
}
5.系统配置
java
package com.cxl.config;
import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.Map;
/**
 * System-prompt configuration.
 * Holds the system prompt to use for each model, bound from the {@code system.settings} prefix.
 */
@Data
@Component
@ConfigurationProperties(prefix = "system.settings")
public class SystemSettings {

    /**
     * Per-model system prompts.
     * key: model name
     * value: system prompt text
     */
    private Map<String, String> prompts = new HashMap<>();

    /**
     * Default system prompt, used when a model has no dedicated entry.
     */
    private String defaultPrompt = "你是一个有帮助的AI助手。";

    /**
     * Returns the system prompt for the given model.
     *
     * <p>The exact model name is tried first. If that misses, the name is looked up again with
     * every character other than letters, digits, {@code '.'}, {@code '-'} and {@code '_'}
     * removed: Spring Boot strips such characters (for example the {@code ':'} in
     * {@code qwen3.5:2b}) from map keys that are not written in {@code "[...]"} bracket notation,
     * so an unbracketed YAML key would otherwise never match.
     *
     * @param modelName model name; may be {@code null}
     * @return the model-specific prompt, or the default prompt when none is configured
     */
    public String getSystemPrompt(String modelName) {
        if (modelName == null || prompts == null) {
            return defaultPrompt;
        }
        String exact = prompts.get(modelName);
        if (exact != null) {
            return exact;
        }
        String relaxed = prompts.get(toRelaxedKey(modelName));
        return relaxed != null ? relaxed : defaultPrompt;
    }

    /** Approximates the key Spring Boot produces for an unbracketed map key. */
    private static String toRelaxedKey(String key) {
        StringBuilder relaxed = new StringBuilder(key.length());
        for (int i = 0; i < key.length(); i++) {
            char ch = key.charAt(i);
            boolean asciiLetterOrDigit = (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
                    || (ch >= '0' && ch <= '9');
            if (asciiLetterOrDigit || ch == '.' || ch == '-' || ch == '_') {
                relaxed.append(ch);
            }
        }
        return relaxed.toString();
    }
}
7.接口
1.基本对话&流式返回
java
package com.cxl.controller;
import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
/**
 * REST endpoints for plain chat: one returning the whole reply, one streaming it.
 * Both handlers are annotated with {@link EnableChatHistory} and delegate to {@link ChatService}.
 */
@RestController
@RequestMapping
@RequiredArgsConstructor
public class ChatController {
private final ChatService chatService;
/**
 * Answers one message and returns the complete reply.
 *
 * @param msg       user message (required)
 * @param sessionId optional session id, forwarded to the chat service
 * @param model     model alias; {@code "qwen"} when omitted
 * @return whatever {@code ChatService.chat} returns for these arguments
 */
@EnableChatHistory
@GetMapping("/chat")
public String chat(
@RequestParam String msg,
@RequestParam(required = false) String sessionId,
@RequestParam(defaultValue = "qwen") String model) {
return chatService.chat(msg, model, sessionId);
}
/**
 * Answers one message as a {@code text/event-stream} response.
 *
 * @param msg       user message (required)
 * @param sessionId optional session id, forwarded to the chat service
 * @param model     model alias; {@code "qwen"} when omitted
 * @return the stream returned by {@code ChatService.chatStreamRaw}
 */
@EnableChatHistory
@GetMapping(value = "/chatS", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<String> chatStream(
@RequestParam String msg,
@RequestParam(required = false) String sessionId,
@RequestParam(defaultValue = "qwen") String model) {
return chatService.chatStreamRaw(msg, model, sessionId);
}
}
2.基于文件目录对话
java
package com.cxl.controller;
import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatHistoryService;
import com.cxl.service.ChatService;
import com.cxl.service.FileSearchService;
import com.cxl.service.FileSearchService.DocumentInfo;
import com.cxl.service.SessionManagerService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * REST endpoints under {@code /file}: index a directory, chat with retrieved file content as
 * context (blocking and streaming), clear the index, and toggle per-session context awareness.
 */
@RestController
@RequestMapping("/file")
@RequiredArgsConstructor
public class FileSearchController {

    /** Maximum number of characters of each retrieved document copied into the prompt. */
    private static final int MAX_SNIPPET_LENGTH = 500;

    private final FileSearchService fileSearchService;
    private final ChatService chatService;
    private final SessionManagerService sessionManagerService;
    private final ChatHistoryService chatHistoryService;

    /**
     * Scans a directory and reports how many documents were indexed.
     *
     * @param directory directory to scan; passed through as-is (may be {@code null})
     * @param sessionId optional session whose last-activity time is refreshed
     * @return a map with {@code status}, {@code message} and {@code fileCount}
     */
    @GetMapping("/scan")
    public Map<String, Object> scanDirectory(
            @RequestParam(required = false) String directory,
            @RequestParam(required = false) String sessionId) {
        if (sessionId != null) {
            SessionManagerService.SessionInfo session = sessionManagerService.getOrCreateSession(sessionId);
            session.updateLastActivity();
        }
        int fileCount = fileSearchService.scanDirectory(directory);
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("message", "文件夹扫描完成");
        result.put("fileCount", fileCount);
        return result;
    }

    /**
     * Answers a question, optionally prefixed with retrieved file content, and returns the
     * complete reply.
     *
     * @param msg          user question
     * @param sessionId    session id (required)
     * @param model        model alias; {@code "qwen"} when omitted
     * @param contextAware when non-null, overrides and updates the session's context-aware flag
     * @param topK         number of documents to retrieve; 3 when omitted
     * @param fileType     optional {@code filetype} filter for the retrieval
     * @return the reply from the chat service
     */
    @EnableChatHistory
    @GetMapping("/chat")
    public String fileChat(
            @RequestParam String msg,
            @RequestParam String sessionId,
            @RequestParam(defaultValue = "qwen") String model,
            @RequestParam(required = false) Boolean contextAware,
            @RequestParam(defaultValue = "3") int topK,
            @RequestParam(required = false) String fileType) {
        String prompt = buildPrompt(msg, sessionId, contextAware, topK, fileType);
        return chatService.chat(prompt, model, sessionId);
    }

    /**
     * Streaming variant of {@link #fileChat}; same parameters, reply delivered as
     * {@code text/event-stream}.
     */
    @EnableChatHistory
    @GetMapping(value = "/chatS", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> fileChatStream(
            @RequestParam String msg,
            @RequestParam String sessionId,
            @RequestParam(defaultValue = "qwen") String model,
            @RequestParam(required = false) Boolean contextAware,
            @RequestParam(defaultValue = "3") int topK,
            @RequestParam(required = false) String fileType) {
        String prompt = buildPrompt(msg, sessionId, contextAware, topK, fileType);
        return chatService.chatStreamRaw(prompt, model, sessionId);
    }

    /**
     * Clears the document index and, when a session id is given, that session's chat history.
     *
     * @param sessionId optional session whose history is cleared as well
     * @return a map with {@code status} and {@code message}
     */
    @GetMapping("/clear")
    public Map<String, Object> clearAll(@RequestParam(required = false) String sessionId) {
        fileSearchService.clearAll();
        if (sessionId != null) {
            chatHistoryService.clearHistory(sessionId);
        }
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("message", sessionId != null
                ? "已清除向量数据库和会话历史记录"
                : "已清除向量数据库中的所有文档");
        return result;
    }

    /**
     * Sets the context-aware flag of a session.
     *
     * @param sessionId    session to update (created if absent)
     * @param contextAware new flag value
     * @return a map echoing {@code status}, {@code contextAware} and {@code sessionId}
     */
    @GetMapping("/context")
    public Map<String, Object> setContextAware(
            @RequestParam String sessionId,
            @RequestParam boolean contextAware) {
        SessionManagerService.SessionInfo session = sessionManagerService.getOrCreateSession(sessionId);
        session.setContextAware(contextAware);
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("contextAware", contextAware);
        result.put("sessionId", sessionId);
        return result;
    }

    /**
     * Builds the prompt shared by the blocking and the streaming file-chat endpoints: refreshes
     * the session, resolves the effective context-aware flag and, when it is on, prepends
     * snippets of the retrieved documents to the user question.
     */
    private String buildPrompt(String msg, String sessionId, Boolean contextAware, int topK, String fileType) {
        SessionManagerService.SessionInfo session = sessionManagerService.getOrCreateSession(sessionId);
        session.updateLastActivity();
        boolean useContextAware = session.isContextAware();
        if (contextAware != null) {
            useContextAware = contextAware;
            session.setContextAware(useContextAware);
        }

        StringBuilder prompt = new StringBuilder();
        // The retrieval result is only ever used when context awareness is on, so the vector
        // search is skipped entirely otherwise.
        if (useContextAware) {
            Map<String, Object> filters = new HashMap<>();
            if (fileType != null && !fileType.isEmpty()) {
                filters.put("filetype", fileType);
            }
            List<DocumentInfo> documents = fileSearchService.search(msg, topK, filters);
            if (documents != null && !documents.isEmpty()) {
                prompt.append("根据以下文件内容回答问题:\n\n");
                for (DocumentInfo doc : documents) {
                    appendDocument(prompt, doc);
                }
            }
        }
        return prompt.append("用户问题: ").append(msg).toString();
    }

    /** Appends one document's name, path and a length-limited content snippet. */
    private void appendDocument(StringBuilder prompt, DocumentInfo doc) {
        String snippet = "";
        boolean truncated = false;
        if (doc.content != null) {
            truncated = doc.content.length() > MAX_SNIPPET_LENGTH;
            snippet = doc.content.substring(0, Math.min(MAX_SNIPPET_LENGTH, doc.content.length()));
        }
        prompt.append("文件: ").append(doc.filename)
                .append("\n路径: ").append(doc.path)
                .append("\n内容: ").append(snippet)
                // Only mark the snippet as cut off when something was actually dropped.
                .append(truncated ? "...\n\n" : "\n\n");
    }
}
8.Service
1.ChatHistoryService
java
package com.cxl.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Stores per-session chat history in Redis lists.
 *
 * <p>Each entry is one string of the form {@code timestamp|role|content} under the key
 * {@code keyPrefix + sessionId}. Settings are bound from the {@code chat-history} prefix.
 */
@Slf4j
@Service
@ConfigurationProperties(prefix = "chat-history")
public class ChatHistoryService {

    /** Prefix of the Redis key of every session. */
    private String keyPrefix = "chat:history:";
    /** Maximum number of entries kept per session; a non-positive value disables trimming. */
    private int maxSize = 20;
    /** Time-to-live of a session's history in days; a non-positive value sets no expiry. */
    private int expireDays = 7;
    /** Master switch; when false nothing is written or read. */
    private boolean enabled = true;

    private final StringRedisTemplate redisTemplate;

    public ChatHistoryService(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Appends one message to a session's history. Redis failures are logged, not propagated.
     *
     * @param sessionId session id; nothing is stored when null or empty
     * @param role      message role, e.g. {@code "user"} or {@code "assistant"}
     * @param content   message text; nothing is stored when null
     */
    public void addMessage(String sessionId, String role, String content) {
        if (!enabled) {
            log.debug("Chat history is disabled, skipping save");
            return;
        }
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionId is null or empty, skipping save");
            return;
        }
        if (content == null) {
            // String.format would otherwise persist the literal text "null".
            log.debug("content is null, skipping save");
            return;
        }
        String key = keyPrefix + sessionId;
        String message = String.format("%s|%s|%s",
                LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME),
                role,
                content
        );
        try {
            redisTemplate.opsForList().rightPush(key, message);
            if (maxSize > 0) {
                // Keep only the newest maxSize entries.
                redisTemplate.opsForList().trim(key, -maxSize, -1);
            }
            if (expireDays > 0) {
                // A non-positive timeout would make Redis delete the key right away.
                redisTemplate.expire(key, expireDays, TimeUnit.DAYS);
            }
            log.info("Saved message to Redis - sessionId: {}, role: {}, key: {}", sessionId, role, key);
        } catch (Exception e) {
            log.error("Failed to save message to Redis: {}", e.getMessage(), e);
        }
    }

    /**
     * Loads a session's history, oldest entry first.
     *
     * @param sessionId session id
     * @return the parsed messages; an empty list when history is disabled, the session id is
     *         null or empty, nothing is stored, or Redis fails
     */
    public List<ChatMessage> getHistory(String sessionId) {
        if (!enabled) {
            log.debug("Chat history is disabled, returning empty history");
            return new ArrayList<>();
        }
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionId is null or empty, returning empty history");
            return new ArrayList<>();
        }
        String key = keyPrefix + sessionId;
        log.info("Getting history from Redis - sessionId: {}, key: {}", sessionId, key);
        try {
            List<String> rawMessages = redisTemplate.opsForList().range(key, 0, -1);
            if (rawMessages == null || rawMessages.isEmpty()) {
                log.info("No history found in Redis for sessionId: {}", sessionId);
                return new ArrayList<>();
            }
            log.info("Found {} messages in Redis for sessionId: {}", rawMessages.size(), sessionId);
            List<ChatMessage> messages = new ArrayList<>(rawMessages.size());
            for (String raw : rawMessages) {
                // Limit 3: any '|' inside the content stays part of the content.
                String[] parts = raw.split("\\|", 3);
                if (parts.length < 3) {
                    continue;
                }
                ChatMessage msg = new ChatMessage();
                msg.setTimestamp(parts[0]);
                msg.setRole(parts[1]);
                msg.setContent(parts[2]);
                messages.add(msg);
            }
            return messages;
        } catch (Exception e) {
            log.error("Failed to get history from Redis: {}", e.getMessage(), e);
            return new ArrayList<>();
        }
    }

    /**
     * Returns the content of the most recent message with role {@code "user"}.
     *
     * @param sessionId session id
     * @return the content, or {@code null} when there is no such message
     */
    public String getLastUserMessage(String sessionId) {
        List<ChatMessage> history = getHistory(sessionId);
        for (int i = history.size() - 1; i >= 0; i--) {
            ChatMessage candidate = history.get(i);
            if ("user".equals(candidate.getRole())) {
                return candidate.getContent();
            }
        }
        return null;
    }

    /**
     * Deletes a session's history.
     *
     * @param sessionId session id; the call is a no-op when null or empty
     */
    public void clearHistory(String sessionId) {
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionId is null or empty, skipping clear");
            return;
        }
        redisTemplate.delete(keyPrefix + sessionId);
    }

    public String getKeyPrefix() { return keyPrefix; }
    public void setKeyPrefix(String keyPrefix) { this.keyPrefix = keyPrefix; }
    public int getMaxSize() { return maxSize; }
    public void setMaxSize(int maxSize) { this.maxSize = maxSize; }
    public int getExpireDays() { return expireDays; }
    public void setExpireDays(int expireDays) { this.expireDays = expireDays; }
    public boolean isEnabled() { return enabled; }
    public void setEnabled(boolean enabled) { this.enabled = enabled; }

    /** One stored history entry: timestamp, role and content, all as strings. */
    public static class ChatMessage {
        private String timestamp;
        private String role;
        private String content;

        public String getTimestamp() { return timestamp; }
        public void setTimestamp(String timestamp) { this.timestamp = timestamp; }
        public String getRole() { return role; }
        public void setRole(String role) { this.role = role; }
        public String getContent() { return content; }
        public void setContent(String content) { this.content = content; }
    }
}
2.ChatService
java
package com.cxl.service;
import com.cxl.config.SystemSettings;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@Service
public class ChatService {
private final ChatClient qwenChatClient;
private final ChatClient deepseekChatClient;
private final SystemSettings systemSettings;
private final ChatHistoryService chatHistoryService;
private final DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
public ChatService(
ChatClient qwenChatClient,
@Qualifier("deepseekChatClient") ChatClient deepseekChatClient,
SystemSettings systemSettings,
ChatHistoryService chatHistoryService) {
this.qwenChatClient = qwenChatClient;
this.deepseekChatClient = deepseekChatClient;
this.systemSettings = systemSettings;
this.chatHistoryService = chatHistoryService;
}
public String chat(String msg, String model) {
return chatWithHistory(msg, model, null);
}
public String chat(String msg, String model, String sessionId) {
return chatWithHistory(msg, model, sessionId);
}
private String chatWithHistory(String msg, String model, String sessionId) {
String startTime = LocalDateTime.now().format(formatter);
log.info("========== 对话开始 ==========");
log.info("开始时间: {}", startTime);
log.info("使用模型: {}", model);
log.info("用户消息: {}", msg);
if (sessionId != null) {
log.info("会话ID: {}", sessionId);
}
ChatClient chatClient = getChatClient(model);
String modelName = getModelName(model);
String systemPrompt = systemSettings.getSystemPrompt(modelName);
log.info("System预制词: {}", systemPrompt);
long startMillis = System.currentTimeMillis();
ChatClient.ChatClientRequestSpec promptSpec = chatClient.prompt()
.system(systemPrompt);
if (sessionId != null) {
List<ChatHistoryService.ChatMessage> history = chatHistoryService.getHistory(sessionId);
if (!history.isEmpty()) {
StringBuilder historyContext = new StringBuilder();
historyContext.append("\n以下是之前的对话历史:\n");
for (ChatHistoryService.ChatMessage chatMsg : history) {
historyContext.append(chatMsg.getRole()).append(": ").append(chatMsg.getContent()).append("\n");
}
promptSpec = promptSpec.system(systemPrompt + historyContext.toString());
}
}
String content = promptSpec.user(msg).call().content();
long endMillis = System.currentTimeMillis();
log.info("响应耗时: {} ms", endMillis - startMillis);
log.info("AI回复: {}", content);
log.info("========== 对话结束 ==========\n");
return content;
}
public Flux<String> chatStreamRaw(String msg, String model) {
return chatStreamWithHistory(msg, model, null);
}
public Flux<String> chatStreamRaw(String msg, String model, String sessionId) {
return chatStreamWithHistory(msg, model, sessionId);
}
private Flux<String> chatStreamWithHistory(String msg, String model, String sessionId) {
String startTime = LocalDateTime.now().format(formatter);
log.info("========== 流式对话开始 ==========");
log.info("开始时间: {}", startTime);
log.info("使用模型: {}", model);
log.info("用户消息: {}", msg);
if (sessionId != null) {
log.info("会话ID: {}", sessionId);
}
ChatClient chatClient = getChatClient(model);
String modelName = getModelName(model);
String systemPrompt = systemSettings.getSystemPrompt(modelName);
log.info("System预制词: {}", systemPrompt);
long startMillis = System.currentTimeMillis();
ChatClient.ChatClientRequestSpec promptSpec = chatClient.prompt()
.system(systemPrompt);
if (sessionId != null) {
List<ChatHistoryService.ChatMessage> history = chatHistoryService.getHistory(sessionId);
if (!history.isEmpty()) {
StringBuilder historyContext = new StringBuilder();
historyContext.append("\n以下是之前的对话历史:\n");
for (ChatHistoryService.ChatMessage chatMsg : history) {
historyContext.append(chatMsg.getRole()).append(": ").append(chatMsg.getContent()).append("\n");
}
promptSpec = promptSpec.system(systemPrompt + historyContext.toString());
}
}
return promptSpec.user(msg)
.stream()
.content()
.doOnComplete(() -> {
long endMillis = System.currentTimeMillis();
log.info("响应耗时: {} ms", endMillis - startMillis);
log.info("========== 流式对话结束 ==========\n");
})
.doOnError(error -> {
log.error("流式对话发生错误: {}", error.getMessage(), error);
});
}
private ChatClient getChatClient(String model) {
if ("deepseek".equalsIgnoreCase(model)) {
return deepseekChatClient;
}
return qwenChatClient;
}
private String getModelName(String model) {
if ("deepseek".equalsIgnoreCase(model)) {
return "deepseek-r1:1.5b";
}
return "qwen3.5:2b";
}
}
3.FileSearchService
java
package com.cxl.service;
import com.cxl.config.KnowledgeBaseConfig;
import lombok.extern.slf4j.Slf4j;
import org.apache.poi.xslf.usermodel.XMLSlideShow;
import org.apache.poi.xslf.usermodel.XSLFSlide;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@Service
public class FileSearchService {
private final VectorStore vectorStore;
private final KnowledgeBaseConfig knowledgeBaseConfig;
/**
 * Creates the service with the vector store and the knowledge-base settings it stores
 * in its fields.
 */
public FileSearchService(VectorStore vectorStore, KnowledgeBaseConfig knowledgeBaseConfig) {
this.vectorStore = vectorStore;
this.knowledgeBaseConfig = knowledgeBaseConfig;
}
/**
 * Recursively scans a directory and adds one document per readable file to the vector store.
 *
 * @param directoryPath directory to scan; the configured default path is used when null or empty
 * @return number of documents added; 0 when the directory is missing, is not a directory,
 *         yields no documents, or the scan fails
 */
public int scanDirectory(String directoryPath) {
    final String root;
    if (directoryPath != null && !directoryPath.isEmpty()) {
        root = directoryPath;
    } else {
        root = knowledgeBaseConfig.getDefaultPath();
    }
    log.info("开始扫描文件夹: {}", root);

    File rootDir = new File(root);
    boolean usable = rootDir.exists() && rootDir.isDirectory();
    if (!usable) {
        log.error("文件夹不存在或不是目录: {}", root);
        return 0;
    }

    List<Document> collected = new ArrayList<>();
    try {
        scanFile(rootDir, collected);
        if (collected.isEmpty()) {
            return 0;
        }
        vectorStore.add(collected);
        log.info("成功扫描 {} 个文件并添加到向量数据库", collected.size());
        return collected.size();
    } catch (Exception e) {
        log.error("扫描文件夹失败: {}", e.getMessage(), e);
        return 0;
    }
}
/**
 * Walks {@code file} recursively and appends one {@link Document} per file with non-blank
 * extracted content.
 *
 * <p>The document id is a name-based UUID derived from the file's absolute path, so scanning the
 * same file again produces the same id instead of a fresh random one; a store that upserts by id
 * then replaces the earlier entry rather than accumulating duplicates.
 *
 * @param file      file or directory to process
 * @param documents collector the created documents are appended to
 * @throws IOException declared for callers; propagated from nested calls
 */
private void scanFile(File file, List<Document> documents) throws IOException {
    if (file.isDirectory()) {
        File[] children = file.listFiles();
        if (children != null) {
            for (File child : children) {
                scanFile(child, documents);
            }
        }
        return;
    }

    String content = extractContent(file);
    if (content == null || content.trim().isEmpty()) {
        return;
    }

    String absolutePath = file.getAbsolutePath();
    Map<String, Object> metadata = new HashMap<>();
    metadata.put("filename", file.getName());
    metadata.put("path", absolutePath);
    metadata.put("filetype", getFileType(file.getName()));
    metadata.put("extension", getFileExtension(file.getName()));
    // Stored as int like before, but clamped so files over 2 GB do not wrap to a negative size.
    metadata.put("size", (int) Math.min(file.length(), Integer.MAX_VALUE));

    String stableId = UUID.nameUUIDFromBytes(
            absolutePath.getBytes(java.nio.charset.StandardCharsets.UTF_8)).toString();
    Document doc = Document.builder()
            .id(stableId)
            .text(content)
            .metadata(metadata)
            .build();
    documents.add(doc);
}
private String extractContent(File file) {
try {
String extension = getFileExtension(file.getName()).toLowerCase();
Path path = Paths.get(file.getAbsolutePath());
if (extension.equals("txt") || extension.equals("md") || extension.equals("json") ||
extension.equals("xml") || extension.equals("html") || extension.equals("css") ||
extension.equals("js") || extension.equals("java") || extension.equals("py") ||
extension.equals("go") || extension.equals("rs") || extension.equals("c") ||
extension.equals("cpp") || extension.equals("h") || extension.equals("sql")) {
return Files.readString(path);
} else if (extension.equals("pptx") || extension.equals("ppt")) {
return extractPptContent(file);
} else {
return String.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: %s",
file.getName(), file.getAbsolutePath(), file.length(), getFileType(file.getName()));
}
} catch (Exception e) {
log.warn("提取文件内容失败: {} - {}", file.getAbsolutePath(), e.getMessage());
return null;
}
}
private String extractPptContent(File file) {
String extension = getFileExtension(file.getName()).toLowerCase();
StringBuilder content = new StringBuilder();
content.append("文件名: ").append(file.getName()).append("\n");
content.append("路径: ").append(file.getAbsolutePath()).append("\n\n");
if (!extension.equals("pptx")) {
return String.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: PPT(仅支持.pptx格式)",
file.getName(), file.getAbsolutePath(), file.length());
}
try (FileInputStream fis = new FileInputStream(file);
XMLSlideShow ppt = new XMLSlideShow(fis)) {
List<XSLFSlide> slides = ppt.getSlides();
for (int i = 0; i < slides.size(); i++) {
XSLFSlide slide = slides.get(i);
content.append("=== 第 ").append(i + 1).append(" 页 ===\n");
for (Object shape : slide.getShapes()) {
if (shape instanceof org.apache.poi.xslf.usermodel.XSLFTextShape) {
org.apache.poi.xslf.usermodel.XSLFTextShape textShape =
(org.apache.poi.xslf.usermodel.XSLFTextShape) shape;
String text = textShape.getText();
if (text != null && !text.trim().isEmpty()) {
content.append(text).append("\n");
}
}
}
content.append("\n");
}
} catch (Exception e) {
log.warn("提取PPT内容失败: {} - {}", file.getAbsolutePath(), e.getMessage());
return String.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: PPT",
file.getName(), file.getAbsolutePath(), file.length());
}
return content.toString();
}
private String getFileType(String filename) {
String extension = getFileExtension(filename).toLowerCase();
switch (extension) {
case "txt": case "md": case "json": case "xml": case "html":
case "css": case "js": case "java": case "py": case "go": case "rs":
case "c": case "cpp": case "h": case "sql":
return "text";
case "pdf": case "doc": case "docx": return "document";
case "pptx": case "ppt": return "presentation";
case "jpg": case "jpeg": case "png": case "gif": case "webp": return "image";
case "mp4": case "avi": case "mov": case "wmv": return "video";
default: return "other";
}
}
private String getFileExtension(String filename) {
int lastDotIndex = filename.lastIndexOf('.');
return lastDotIndex > 0 ? filename.substring(lastDotIndex + 1) : "";
}
public List<DocumentInfo> search(String query, int topK, Map<String, Object> filters) {
log.info("开始检索: {}", query);
try {
int actualTopK = (topK > 0) ? topK : knowledgeBaseConfig.getTopK();
SearchRequest.Builder builder = SearchRequest.builder()
.query(query)
.topK(actualTopK);
if (filters != null && filters.containsKey("filetype")) {
String filterType = (String) filters.get("filetype");
Expression filter = new FilterExpressionBuilder()
.eq("filetype", filterType)
.build();
builder.filterExpression(filter);
}
List<org.springframework.ai.document.Document> results = vectorStore.similaritySearch(builder.build());
return results.stream()
.map(this::toDocumentInfo)
.collect(Collectors.toList());
} catch (Exception e) {
log.error("检索失败: {}", e.getMessage(), e);
return new ArrayList<>();
}
}
private DocumentInfo toDocumentInfo(org.springframework.ai.document.Document doc) {
DocumentInfo info = new DocumentInfo();
info.id = doc.getId();
info.content = doc.getText();
info.filename = (String) doc.getMetadata().getOrDefault("filename", "");
info.path = (String) doc.getMetadata().getOrDefault("path", "");
info.filetype = (String) doc.getMetadata().getOrDefault("filetype", "");
info.extension = (String) doc.getMetadata().getOrDefault("extension", "");
Object sizeObj = doc.getMetadata().get("size");
info.size = (sizeObj != null) ? ((Number) sizeObj).longValue() : 0L;
return info;
}
public void clearAll() {
try {
SearchRequest searchRequest = SearchRequest.builder()
.query("*")
.topK(1000)
.build();
List<org.springframework.ai.document.Document> allDocs = vectorStore.similaritySearch(searchRequest);
List<String> ids = allDocs.stream()
.map(org.springframework.ai.document.Document::getId)
.collect(Collectors.toList());
if (!ids.isEmpty()) {
vectorStore.delete(ids);
}
log.info("已清除向量数据库中的所有文档,共 {} 条", ids.size());
} catch (Exception e) {
log.error("清除向量数据库失败: {}", e.getMessage(), e);
}
}
public static class DocumentInfo {
public String id;
public String filename;
public String path;
public String content;
public String filetype;
public String extension;
public long size;
}
}
4.SessionManagerService
java
package com.cxl.service;
import lombok.Data;
import org.springframework.stereotype.Service;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
/**
 * In-memory registry of chat sessions, with a background task that periodically
 * drops sessions whose last activity is older than the timeout.
 *
 * <p>NOTE(review): {@link #close()} is expected to be invoked by the Spring container
 * at shutdown because the bean is {@link AutoCloseable} — confirm for the Spring
 * version in use. The cleanup thread is a daemon either way, so it never keeps the
 * JVM alive on its own.
 */
@Service
public class SessionManagerService implements AutoCloseable {
    private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
    // Daemon thread: a forgotten shutdown must not block JVM exit.
    private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1, runnable -> {
        Thread thread = new Thread(runnable, "session-cleanup");
        thread.setDaemon(true);
        return thread;
    });
    private int sessionTimeoutMinutes = 60;

    /** Schedules the expiry sweep to run once per timeout interval. */
    public SessionManagerService() {
        executorService.scheduleAtFixedRate(this::cleanupExpiredSessions,
                sessionTimeoutMinutes,
                sessionTimeoutMinutes,
                TimeUnit.MINUTES);
    }

    /**
     * Returns the session registered under {@code sessionId}, creating it if absent.
     * This does not refresh the session's last-activity time; callers do that through
     * {@link SessionInfo#updateLastActivity()}.
     */
    public SessionInfo getOrCreateSession(String sessionId) {
        return sessions.computeIfAbsent(sessionId, SessionInfo::new);
    }

    /** Removes the session with the given id, if present. */
    public void removeSession(String sessionId) {
        sessions.remove(sessionId);
    }

    /** Stops the background cleanup task. */
    @Override
    public void close() {
        executorService.shutdownNow();
    }

    /** Drops every session whose last activity lies further back than the timeout. */
    private void cleanupExpiredSessions() {
        long now = System.currentTimeMillis();
        // TimeUnit does the conversion in long arithmetic; "minutes * 60 * 1000" is int math.
        long timeoutMs = TimeUnit.MINUTES.toMillis(sessionTimeoutMinutes);
        sessions.entrySet().removeIf(entry -> now - entry.getValue().getLastActivityTime() > timeoutMs);
    }

    /** State kept per session: its id, last-activity timestamp and context-awareness flag. */
    @Data
    public static class SessionInfo {
        private final String sessionId;
        // Written by request threads and read by the cleanup thread, hence volatile.
        private volatile long lastActivityTime;
        private boolean contextAware = true;

        public SessionInfo(String sessionId) {
            this.sessionId = sessionId;
            this.lastActivityTime = System.currentTimeMillis();
        }

        /** Marks the session as active now. */
        public void updateLastActivity() {
            this.lastActivityTime = System.currentTimeMillis();
        }
    }
}
9.向量数据库初始化
java
package com.cxl;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
/**
 * Standalone command-line helper that creates the Qdrant collection used by the demo
 * if it does not exist yet. Progress and errors are printed to the console; the program
 * never throws.
 */
public class QdrantCollectionCreator {
    private static final String COLLECTION_NAME = "my_docs";
    // NOTE(review): presumably has to match the output size of the embedding model — confirm.
    private static final int VECTOR_DIMENSION = 1024;
    private static final String HOST = "localhost";
    private static final int PORT = 6333;
    private static final String BASE_URL = "http://" + HOST + ":" + PORT;
    // Without timeouts an unresponsive server would make the tool hang indefinitely.
    private static final java.time.Duration CONNECT_TIMEOUT = java.time.Duration.ofSeconds(5);
    private static final java.time.Duration REQUEST_TIMEOUT = java.time.Duration.ofSeconds(30);

    /**
     * Runs the steps in order: reachability check, collection listing, existence check,
     * creation via PUT and, if that did not succeed, a POST fallback.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        System.out.println("=== Qdrant Collection Creator ===");
        HttpClient client = HttpClient.newBuilder()
                .connectTimeout(CONNECT_TIMEOUT)
                .build();

        System.out.println("\n1. Checking Qdrant service availability...");
        try {
            HttpResponse<String> healthResponse = send(client, request(BASE_URL + "/").GET().build());
            System.out.println("Qdrant service status: " + healthResponse.statusCode());
            System.out.println("Response: " + healthResponse.body());
        } catch (InterruptedException e) {
            abortInterrupted();
            return;
        } catch (Exception e) {
            System.err.println("ERROR: Cannot connect to Qdrant at " + BASE_URL);
            return;
        }

        System.out.println("\n2. Listing available collections...");
        try {
            HttpResponse<String> listResponse = send(client, request(BASE_URL + "/collections").GET().build());
            System.out.println("Status: " + listResponse.statusCode());
            System.out.println("Response: " + listResponse.body());
        } catch (InterruptedException e) {
            abortInterrupted();
            return;
        } catch (Exception e) {
            // Listing is informational only; carry on.
            System.out.println("Error: " + describe(e));
        }

        System.out.println("\n3. Checking if collection exists...");
        String collectionUrl = BASE_URL + "/collections/" + COLLECTION_NAME;
        try {
            HttpResponse<String> checkResponse = send(client, request(collectionUrl).GET().build());
            System.out.println("Status: " + checkResponse.statusCode());
            if (checkResponse.statusCode() == 200) {
                System.out.println("Collection '" + COLLECTION_NAME + "' already exists!");
                return;
            }
            System.out.println("Collection does not exist. Will create it...");
        } catch (InterruptedException e) {
            abortInterrupted();
            return;
        } catch (Exception e) {
            // A failed request says nothing about existence; report it as such and still try to create.
            System.out.println("Could not check collection (" + describe(e) + "). Will try to create it...");
        }

        System.out.println("\n4. Creating collection with PUT...");
        String requestBody = String.format("""
                {
                  "vectors": {
                    "size": %d,
                    "distance": "Cosine"
                  }
                }
                """, VECTOR_DIMENSION);
        try {
            HttpRequest createRequest = request(collectionUrl)
                    .header("Content-Type", "application/json")
                    .PUT(HttpRequest.BodyPublishers.ofString(requestBody))
                    .build();
            HttpResponse<String> createResponse = send(client, createRequest);
            System.out.println("Status: " + createResponse.statusCode());
            System.out.println("Response: " + createResponse.body());
            if (isSuccess(createResponse.statusCode())) {
                System.out.println("\n✓ Collection '" + COLLECTION_NAME + "' created successfully!");
                return;
            }
        } catch (InterruptedException e) {
            abortInterrupted();
            return;
        } catch (Exception e) {
            System.err.println("Error with PUT: " + describe(e));
        }

        // NOTE(review): kept from the original as a best-effort fallback; the documented way to
        // create a collection appears to be the PUT above — confirm this POST is ever accepted.
        System.out.println("\n5. Trying POST to /collections endpoint...");
        try {
            HttpRequest postRequest = request(BASE_URL + "/collections")
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(requestBody))
                    .build();
            HttpResponse<String> postResponse = send(client, postRequest);
            System.out.println("Status: " + postResponse.statusCode());
            System.out.println("Response: " + postResponse.body());
            if (isSuccess(postResponse.statusCode())) {
                System.out.println("\n✓ Collection '" + COLLECTION_NAME + "' created successfully!");
            } else {
                System.err.println("\n✗ Failed to create collection");
            }
        } catch (InterruptedException e) {
            abortInterrupted();
        } catch (Exception e) {
            System.err.println("Error with POST: " + describe(e));
        }
    }

    /** Starts a request builder for {@code url} with the common request timeout applied. */
    private static HttpRequest.Builder request(String url) {
        return HttpRequest.newBuilder()
                .uri(URI.create(url))
                .timeout(REQUEST_TIMEOUT);
    }

    /** Sends the request synchronously and returns the response with its body as a string. */
    private static HttpResponse<String> send(HttpClient client, HttpRequest request)
            throws java.io.IOException, InterruptedException {
        return client.send(request, HttpResponse.BodyHandlers.ofString());
    }

    /** Returns true for the status codes treated as successful creation (200, 201). */
    private static boolean isSuccess(int statusCode) {
        return statusCode == 200 || statusCode == 201;
    }

    /** Returns the exception message, or its class name when it has no message. */
    private static String describe(Exception e) {
        return (e.getMessage() != null) ? e.getMessage() : e.getClass().getName();
    }

    /** Restores the interrupt flag and reports that the run is being abandoned. */
    private static void abortInterrupted() {
        Thread.currentThread().interrupt();
        System.err.println("Interrupted, aborting.");
    }
}
10.启动类
java
package com.cxl;
import com.cxl.config.AppConfig;
import com.cxl.config.SystemSettings;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
/**
 * Entry point of the demo application. Registers {@code SystemSettings} and
 * {@code AppConfig} as configuration-properties classes.
 */
@SpringBootApplication
@EnableConfigurationProperties({SystemSettings.class, AppConfig.class})
public class SpringaiDemoApplication {

    /**
     * Launches the Spring application with this class as its primary source.
     *
     * @param args command-line arguments, passed through to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SpringaiDemoApplication.class);
        application.run(args);
    }
}