创建一个使用Spring AI框架构建RAG(Retrieval-Augmented Generation)系统的案例

本文使用Spring AI框架构建一个RAG(检索增强生成,Retrieval-Augmented Generation)系统,演示如何实现一个智能文档问答系统。

项目概述

我们将构建一个基于Spring AI的RAG系统,它可以:

  • 摄取和处理文档(PDF、文本等)
  • 将文档分割成片段(chunks)并向量化存储
  • 根据用户问题检索相关文档片段
  • 结合检索到的上下文生成准确回答

1. 项目依赖配置

首先创建Maven项目,添加以下依赖:

pom.xml:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.0</version>
        <relativePath/>
    </parent>
    
    <groupId>com.example</groupId>
    <artifactId>spring-ai-rag-demo</artifactId>
    <version>1.0.0</version>
    
    <properties>
        <!-- spring-boot-starter-parent passes java.version to the compiler's "release" option,
             which overrides maven.compiler.source/target; set java.version to really target 21. -->
        <java.version>21</java.version>
        <spring-ai.version>0.8.0</spring-ai.version>
    </properties>
    
    <dependencies>
        <!-- Spring Boot Starters -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        
        <!-- Spring AI Core (OpenAI chat + embedding client) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Vector Store -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-chroma-store</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Document Readers -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-tika-document-reader</artifactId>
            <version>${spring-ai.version}</version>
        </dependency>
        
        <!-- Text splitting: TokenTextSplitter is part of spring-ai-core, which the OpenAI starter
             already brings in. spring-ai-transformers (a local ONNX embedding client) is not
             needed here and may contribute a second EmbeddingClient bean next to OpenAI's. -->
    </dependencies>
    
    <build>
        <plugins>
            <!-- Packages an executable jar (java -jar target/spring-ai-rag-demo-1.0.0.jar). -->
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
    
    <repositories>
        <repository>
            <id>spring-milestones</id>
            <name>Spring Milestones</name>
            <url>https://repo.spring.io/milestone</url>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
</project>

2. 配置文件

创建 application.yml

文件内容(通常位于 src/main/resources/application.yml):
# Application configuration for the Spring AI RAG demo.
spring:
  ai:
    openai:
      # Resolved from the OPENAI_API_KEY environment variable at startup.
      api-key: ${OPENAI_API_KEY}
      chat:
        options:
          model: gpt-3.5-turbo
          temperature: 0.7
      embedding:
        options:
          model: text-embedding-ada-002
    
    vectorstore:
      # NOTE(review): RagConfiguration constructs ChromaVectorStore with this same URL and
      # collection name hard-coded, so these keys are not read by the code in this article.
      # Keep them in sync, or inject them into RagConfiguration (e.g. with @Value).
      chroma:
        url: http://localhost:8000
        collection-name: documents

server:
  port: 8080

logging:
  level:
    # Verbose logging of Spring AI internals; lower this outside development.
    org.springframework.ai: DEBUG

3. 核心配置类

RagConfiguration.java:
package com.example.rag.config;

import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.reader.pdf.PdfDocumentReader;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.chroma.ChromaVectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;

@Configuration
public class RagConfiguration {

    /** Chroma server address (hard-coded; duplicates spring.ai.vectorstore.chroma.url). */
    private static final String CHROMA_URL = "http://localhost:8000";

    /** Chroma collection that uploaded chunks are written to and searched in. */
    private static final String CHROMA_COLLECTION = "documents";

    /** Plain RestTemplate bean; not referenced by the other beans declared here. */
    @Bean
    public RestTemplate restTemplate() {
        return new RestTemplate();
    }

    /**
     * Token-based splitter that DocumentService applies to documents before storing them.
     *
     * <p>NOTE(review): the number and meaning of the constructor arguments differ between
     * Spring AI versions; confirm them against the TokenTextSplitter actually on the classpath.
     */
    @Bean
    public TokenTextSplitter textSplitter() {
        final var splitter = new TokenTextSplitter(500, 100, 5, 10000, true);
        return splitter;
    }

    /**
     * Chroma-backed vector store that embeds content with the injected EmbeddingClient.
     *
     * <p>NOTE(review): confirm this constructor exists in the Spring AI version in use; some
     * versions take a ChromaApi client rather than a base URL.
     */
    @Bean
    public VectorStore vectorStore(EmbeddingClient embeddingClient) {
        ChromaVectorStore store =
                new ChromaVectorStore(embeddingClient, CHROMA_URL, CHROMA_COLLECTION, true);
        return store;
    }
}

4. 文档处理服务

DocumentService.java:
package com.example.rag.service;

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.reader.pdf.PdfDocumentReader;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.List;
import java.util.Map;

@Service
public class DocumentService {

    @Autowired
    private VectorStore vectorStore;

    @Autowired
    private TokenTextSplitter textSplitter;

    /**
     * Reads one uploaded file, splits it into chunks and stores the chunks in the vector store.
     *
     * @param file the uploaded file
     * @return a human-readable summary naming the file and the number of chunks stored
     * @throws IOException if the upload cannot be copied to, or removed from, a temporary file
     */
    public String processDocument(MultipartFile file) throws IOException {
        int chunkCount = ingest(file);
        return successMessage(file.getOriginalFilename(), chunkCount);
    }

    /**
     * Processes several uploads independently; a failing file is reported and skipped.
     *
     * @param files the uploaded files
     * @return one line per file followed by a summary of successes and total chunks stored
     */
    public String processBatchDocuments(List<MultipartFile> files) {
        StringBuilder result = new StringBuilder();
        int successCount = 0;
        int totalChunks = 0;

        for (MultipartFile file : files) {
            try {
                // Use the chunk count directly instead of re-parsing it from the message text.
                int chunkCount = ingest(file);
                result.append(successMessage(file.getOriginalFilename(), chunkCount)).append("\n");
                successCount++;
                totalChunks += chunkCount;
            } catch (Exception e) {
                result.append(String.format("处理文档 %s 失败: %s\n", 
                                          file.getOriginalFilename(), e.getMessage()));
            }
        }

        result.append(String.format("\n批处理完成: 成功处理 %d/%d 个文件,总共生成 %d 个文档片段", 
                                  successCount, files.size(), totalChunks));
        return result.toString();
    }

    /**
     * Copies the upload to a temp file, reads it, tags it with metadata, splits it and stores it.
     *
     * @return the number of chunks written to the vector store
     */
    private int ingest(MultipartFile file) throws IOException {
        String filename = file.getOriginalFilename();
        // Only a sanitized extension is used as the suffix: the client-supplied name may be null
        // or contain path separators, which Files.createTempFile rejects.
        Path tempFile = Files.createTempFile("upload-", safeTempSuffix(filename));
        try {
            // The copy happens inside the try so the temp file is removed even if it fails;
            // the upload stream is closed deterministically.
            try (var upload = file.getInputStream()) {
                Files.copy(upload, tempFile, StandardCopyOption.REPLACE_EXISTING);
            }

            DocumentReader reader = createDocumentReader(tempFile, filename);
            List<Document> documents = reader.get();

            // One timestamp per upload so every document read from the same file agrees.
            long uploadTime = System.currentTimeMillis();
            for (Document doc : documents) {
                Map<String, Object> metadata = doc.getMetadata();
                metadata.put("source", filename);
                metadata.put("upload_time", uploadTime);
            }

            List<Document> splitDocuments = textSplitter.apply(documents);
            vectorStore.add(splitDocuments);
            return splitDocuments.size();
        } finally {
            Files.deleteIfExists(tempFile);
        }
    }

    /** Formats the per-file success line shared by single and batch uploads. */
    private static String successMessage(String filename, int chunkCount) {
        return String.format("成功处理文档: %s, 生成了 %d 个文档片段", filename, chunkCount);
    }

    /**
     * Builds a temp-file suffix from the file's extension (ASCII letters and digits only).
     *
     * @return e.g. ".pdf", or null (Files.createTempFile then uses ".tmp") when there is no usable extension
     */
    private static String safeTempSuffix(String filename) {
        if (filename == null) {
            return null;
        }
        int dot = filename.lastIndexOf('.');
        int separator = Math.max(filename.lastIndexOf('/'), filename.lastIndexOf('\\'));
        if (dot <= separator) {
            // No dot at all, or the dot belongs to a directory component.
            return null;
        }
        String extension = filename.substring(dot + 1).replaceAll("[^A-Za-z0-9]", "");
        return extension.isEmpty() ? null : "." + extension;
    }

    /**
     * Chooses a reader by the original file name: PDF reader for ".pdf", Tika reader otherwise.
     * A null file name is treated as non-PDF.
     *
     * <p>NOTE(review): confirm PdfDocumentReader exists in the Spring AI version in use; some
     * versions only ship PagePdfDocumentReader / ParagraphPdfDocumentReader.
     */
    private DocumentReader createDocumentReader(Path filePath, String filename) {
        Resource resource = new org.springframework.core.io.FileSystemResource(filePath.toFile());
        boolean isPdf = filename != null
                && filename.toLowerCase(java.util.Locale.ROOT).endsWith(".pdf");
        if (isPdf) {
            return new PdfDocumentReader(resource);
        }
        // All other file types are delegated to the Tika reader.
        return new TikaDocumentReader(resource);
    }
}

5. RAG查询服务

RagService.java:
package com.example.rag.service;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Service
public class RagService {

    /** Minimum similarity score a retrieved chunk must reach to be used. */
    private static final double SIMILARITY_THRESHOLD = 0.7;

    /** Maximum number of UTF-16 chars kept in a source preview before "..." is appended. */
    private static final int PREVIEW_LENGTH = 200;

    /** Label used when a chunk has no "source" metadata entry. */
    private static final String UNKNOWN_SOURCE = "未知来源";

    @Autowired
    private ChatClient chatClient;

    @Autowired
    private VectorStore vectorStore;

    /** System prompt sent to the model; {context} is filled with the retrieved chunks. */
    private static final String SYSTEM_PROMPT_TEMPLATE = """
            你是一个智能助手,专门根据提供的上下文信息回答用户问题。
            
            请遵循以下规则:
            1. 仅基于提供的上下文信息回答问题
            2. 如果上下文中没有相关信息,请明确说明
            3. 保持回答准确、简洁且有帮助
            4. 如果可能,引用具体的信息来源
            
            上下文信息:
            {context}
            
            请基于以上上下文回答用户的问题。
            """;

    /**
     * Answers a question using retrieved document chunks as context.
     *
     * @param question   the user's question
     * @param maxResults maximum number of chunks to retrieve
     * @return the answer, the chunks used as sources, and the original question; when no chunk
     *         passes the similarity threshold, a fixed apology is returned without calling the model
     */
    public RagResponse query(String question, int maxResults) {
        List<Document> relevantDocuments = retrieveRelevantDocuments(question, maxResults);

        if (relevantDocuments.isEmpty()) {
            return new RagResponse(
                "抱歉,我在知识库中没有找到与您的问题相关的信息。请尝试重新表述问题或上传相关文档。",
                List.of(),
                question
            );
        }

        String context = buildContext(relevantDocuments);
        String answer = generateAnswer(question, context);
        return new RagResponse(answer, toDocumentInfos(relevantDocuments), question);
    }

    /**
     * Returns the chunks most similar to {@code query} without calling the chat model.
     *
     * @param query      search text
     * @param maxResults maximum number of chunks to return
     * @return previews of the matching chunks
     */
    public List<DocumentInfo> searchDocuments(String query, int maxResults) {
        return toDocumentInfos(retrieveRelevantDocuments(query, maxResults));
    }

    /** Runs a top-K similarity search filtered by {@link #SIMILARITY_THRESHOLD}. */
    private List<Document> retrieveRelevantDocuments(String question, int maxResults) {
        SearchRequest searchRequest = SearchRequest.query(question)
                .withTopK(maxResults)
                .withSimilarityThreshold(SIMILARITY_THRESHOLD);

        return vectorStore.similaritySearch(searchRequest);
    }

    /** Concatenates chunks as "source + content" sections separated by "---" lines. */
    private String buildContext(List<Document> documents) {
        return documents.stream()
                .map(doc -> String.format("来源:%s\n内容:%s\n", sourceOf(doc), doc.getContent()))
                .collect(Collectors.joining("\n---\n"));
    }

    /** Sends the filled system prompt plus the user question to the chat model and returns its text. */
    private String generateAnswer(String question, String context) {
        Map<String, Object> promptParams = Map.of("context", context);

        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(SYSTEM_PROMPT_TEMPLATE);
        Message systemMessage = systemPromptTemplate.createMessage(promptParams);
        UserMessage userMessage = new UserMessage(question);

        Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
        ChatResponse response = chatClient.call(prompt);

        return response.getResult().getOutput().getContent();
    }

    /** Maps retrieved chunks to API-facing {@link DocumentInfo} objects. */
    private List<DocumentInfo> toDocumentInfos(List<Document> documents) {
        return documents.stream()
                .map(this::extractDocumentInfo)
                .collect(Collectors.toList());
    }

    /** Builds the source label, preview and a copy of the metadata for one chunk. */
    private DocumentInfo extractDocumentInfo(Document document) {
        // Copy so API consumers cannot mutate the metadata map owned by the Document.
        Map<String, Object> metadata = new HashMap<>(document.getMetadata());
        return new DocumentInfo(sourceOf(document), preview(document.getContent()), metadata);
    }

    /** Returns the chunk's "source" metadata as text, or {@link #UNKNOWN_SOURCE} when absent. */
    private static String sourceOf(Document document) {
        Object source = document.getMetadata().get("source");
        return source != null ? source.toString() : UNKNOWN_SOURCE;
    }

    /** Truncates content to {@link #PREVIEW_LENGTH} chars plus "..."; null becomes "". */
    private static String preview(String content) {
        if (content == null) {
            return "";
        }
        if (content.length() <= PREVIEW_LENGTH) {
            return content;
        }
        int end = PREVIEW_LENGTH;
        // Never cut between the two halves of a surrogate pair (e.g. an emoji), which would
        // leave an unpaired surrogate in the response text.
        if (Character.isHighSurrogate(content.charAt(end - 1))) {
            end--;
        }
        return content.substring(0, end) + "...";
    }

    /** Response of {@link #query}: generated answer, supporting sources, and the question asked. */
    public static class RagResponse {
        private final String answer;
        private final List<DocumentInfo> sources;
        private final String originalQuestion;

        public RagResponse(String answer, List<DocumentInfo> sources, String originalQuestion) {
            this.answer = answer;
            this.sources = sources;
            this.originalQuestion = originalQuestion;
        }

        public String getAnswer() { return answer; }
        public List<DocumentInfo> getSources() { return sources; }
        public String getOriginalQuestion() { return originalQuestion; }
    }

    /** API view of one retrieved chunk: source label, short preview, and its metadata. */
    public static class DocumentInfo {
        private final String source;
        private final String preview;
        private final Map<String, Object> metadata;

        public DocumentInfo(String source, String preview, Map<String, Object> metadata) {
            this.source = source;
            this.preview = preview;
            this.metadata = metadata;
        }

        public String getSource() { return source; }
        public String getPreview() { return preview; }
        public Map<String, Object> getMetadata() { return metadata; }
    }
}

6. REST API控制器

RagController.java:
package com.example.rag.controller;

import com.example.rag.service.DocumentService;
import com.example.rag.service.RagService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.util.List;

@RestController
@RequestMapping("/api/rag")
@CrossOrigin(origins = "*")
public class RagController {

    @Autowired
    private DocumentService documentService;

    @Autowired
    private RagService ragService;

    /**
     * 上传单个文档
     */
    @PostMapping("/upload")
    public ResponseEntity<String> uploadDocument(@RequestParam("file") MultipartFile file) {
        try {
            if (file.isEmpty()) {
                return ResponseEntity.badRequest().body("文件不能为空");
            }

            String result = documentService.processDocument(file);
            return ResponseEntity.ok(result);
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body("文档处理失败: " + e.getMessage());
        }
    }

    /**
     * 批量上传文档
     */
    @PostMapping("/upload/batch")
    public ResponseEntity<String> uploadBatchDocuments(
            @RequestParam("files") List<MultipartFile> files) {
        try {
            if (files.isEmpty()) {
                return ResponseEntity.badRequest().body("文件列表不能为空");
            }

            String result = documentService.processBatchDocuments(files);
            return ResponseEntity.ok(result);
        } catch (Exception e) {
            return ResponseEntity.internalServerError()
                    .body("批量文档处理失败: " + e.getMessage());
        }
    }

    /**
     * RAG问答
     */
    @PostMapping("/query")
    public ResponseEntity<RagService.RagResponse> query(
            @RequestBody QueryRequest request) {
        try {
            if (request.getQuestion() == null || request.getQuestion().trim().isEmpty()) {
                return ResponseEntity.badRequest().build();
            }

            int maxResults = request.getMaxResults() > 0 ? request.getMaxResults() : 5;
            RagService.RagResponse response = ragService.query(request.getQuestion(), maxResults);
            
            return ResponseEntity.ok(response);
        } catch (Exception e) {
            return ResponseEntity.internalServerError().build();
        }
    }

    /**
     * 文档搜索
     */
    @GetMapping("/search")
    public ResponseEntity<List<RagService.DocumentInfo>> searchDocuments(
            @RequestParam String query,
            @RequestParam(defaultValue = "10") int maxResults) {
        try {
            List<RagService.DocumentInfo> results = ragService.searchDocuments(query, maxResults);
            return ResponseEntity.ok(results);
        } catch (Exception e) {
            return ResponseEntity.internalServerError().build();
        }
    }

    // 请求对象类
    public static class QueryRequest {
        private String question;
        private int maxResults = 5;

        // Constructors, getters, setters
        public QueryRequest() {}

        public String getQuestion() { return question; }
        public void setQuestion(String question) { this.question = question; }
        
        public int getMaxResults() { return maxResults; }
        public void setMaxResults(int maxResults) { this.maxResults = maxResults; }
    }
}

7. 主应用程序类

SpringAiRagApplication.java:
package com.example.rag;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class SpringAiRagApplication {

    /**
     * Starts the Spring application context for the RAG demo.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SpringAiRagApplication.class);
        application.run(args);
    }
}

8. 使用示例

启动Chroma向量数据库

命令:
# Start Chroma with Docker, exposing its HTTP API on port 8000 (the URL RagConfiguration uses).
# NOTE(review): consider pinning an explicit image tag instead of the implicit "latest"; Chroma's
# HTTP API may have changed across releases, so confirm the tag is compatible with the Spring AI
# Chroma client version in use.
docker run -p 8000:8000 chromadb/chroma

环境变量设置

命令:
# Read by spring.ai.openai.api-key through the ${OPENAI_API_KEY} placeholder in application.yml.
export OPENAI_API_KEY=your_openai_api_key_here

API使用示例

  1. 上传文档
命令:
# Upload one document as the multipart form field "file" (POST /api/rag/upload).
curl -X POST -F "file=@document.pdf" http://localhost:8080/api/rag/upload
  2. 问答查询
命令:
# Ask a question; maxResults caps how many chunks are retrieved as context for the answer.
curl -X POST http://localhost:8080/api/rag/query \
  -H "Content-Type: application/json" \
  -d '{"question": "这个文档主要讲了什么内容?", "maxResults": 5}'
  3. 文档搜索
命令:
# Search without generating an answer. Non-ASCII query text must be percent-encoded:
# curl sends a literal URL's UTF-8 bytes as-is, and Tomcat rejects them in the request line.
curl -G "http://localhost:8080/api/rag/search" \
  --data-urlencode "query=人工智能" \
  --data-urlencode "maxResults=10"

9. 前端集成示例(HTML + JavaScript)

页面代码(例如保存为 index.html):
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Spring AI RAG Demo</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .section { margin: 20px 0; padding: 20px; border: 1px solid #ddd; border-radius: 5px; }
        .response { background-color: #f5f5f5; padding: 15px; margin: 10px 0; border-radius: 5px; }
        .sources { margin-top: 15px; }
        .source-item { background-color: #e9ecef; padding: 10px; margin: 5px 0; border-radius: 3px; }
        input, textarea, button { width: 100%; padding: 10px; margin: 5px 0; }
        button { background-color: #007bff; color: white; border: none; cursor: pointer; }
        button:hover { background-color: #0056b3; }
    </style>
</head>
<body>
    <h1>Spring AI RAG 智能问答系统</h1>

    <!-- 文档上传 -->
    <div class="section">
        <h2>上传文档</h2>
        <input type="file" id="fileInput" multiple accept=".pdf,.txt,.doc,.docx">
        <button onclick="uploadDocuments()">上传文档</button>
        <div id="uploadResult"></div>
    </div>

    <!-- 问答界面 -->
    <div class="section">
        <h2>智能问答</h2>
        <textarea id="questionInput" rows="3" placeholder="请输入您的问题..."></textarea>
        <button onclick="askQuestion()">提问</button>
        <div id="answerResult"></div>
    </div>

    <script>
        // Absolute base URL of the RagController endpoints, so the page can be served from a
        // different origin than the API (the controller allows cross-origin requests).
        const API_BASE = 'http://localhost:8080/api/rag';

        /**
         * Sends every selected file to the batch upload endpoint and shows the server's reply.
         * The reply echoes client-supplied file names and error messages, so it is rendered as
         * plain text (textContent) rather than parsed as HTML, which would allow script injection.
         */
        async function uploadDocuments() {
            const fileInput = document.getElementById('fileInput');
            const files = fileInput.files;
            
            if (files.length === 0) {
                alert('请选择要上传的文件');
                return;
            }

            const formData = new FormData();
            for (let i = 0; i < files.length; i++) {
                formData.append('files', files[i]);
            }

            const output = document.getElementById('uploadResult');
            // Render as text; pre-wrap keeps the newline-separated per-file lines of the report.
            const show = (text) => {
                const box = document.createElement('div');
                box.className = 'response';
                box.style.whiteSpace = 'pre-wrap';
                box.textContent = text;
                output.replaceChildren(box);
            };

            try {
                const response = await fetch(`${API_BASE}/upload/batch`, {
                    method: 'POST',
                    body: formData
                });
                show(await response.text());
            } catch (error) {
                show(`上传失败: ${error.message}`);
            }
        }

        /**
         * Posts the typed question to /query and renders the answer via displayAnswer().
         * Errors (network failure or a non-2xx status) are shown as plain text.
         */
        async function askQuestion() {
            const question = document.getElementById('questionInput').value.trim();
            
            if (!question) {
                alert('请输入问题');
                return;
            }

            try {
                const response = await fetch(`${API_BASE}/query`, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        question: question,
                        maxResults: 5
                    })
                });

                // /query answers 400/500 with an empty body, so check the status before parsing
                // JSON; otherwise the user would see an opaque JSON syntax error.
                if (!response.ok) {
                    throw new Error(`HTTP ${response.status}`);
                }
                const result = await response.json();
                displayAnswer(result);
            } catch (error) {
                const box = document.createElement('div');
                box.className = 'response';
                box.textContent = `查询失败: ${error.message}`;
                document.getElementById('answerResult').replaceChildren(box);
            }
        }

        /**
         * Escapes the HTML-significant characters (& < > " ') so arbitrary text can be placed
         * inside markup. null/undefined become an empty string.
         */
        function escapeHtml(value) {
            return String(value ?? '')
                .replace(/&/g, '&amp;')
                .replace(/</g, '&lt;')
                .replace(/>/g, '&gt;')
                .replace(/"/g, '&quot;')
                .replace(/'/g, '&#39;');
        }

        /**
         * Renders a /query response: the answer, then one card per source chunk.
         * The answer (model output) and previews (uploaded document text) are untrusted, so
         * every server-provided value is escaped before it goes into innerHTML.
         */
        function displayAnswer(result) {
            let html = `<div class="response">
                <h3>回答:</h3>
                <p style="white-space: pre-wrap">${escapeHtml(result.answer)}</p>
            </div>`;

            if (result.sources && result.sources.length > 0) {
                html += '<div class="sources"><h4>参考来源:</h4>';
                result.sources.forEach((source, index) => {
                    html += `<div class="source-item">
                        <strong>来源 ${index + 1}: ${escapeHtml(source.source)}</strong>
                        <p>${escapeHtml(source.preview)}</p>
                    </div>`;
                });
                html += '</div>';
            }

            document.getElementById('answerResult').innerHTML = html;
        }
    </script>
</body>
</html>

总结

这个Spring AI RAG案例展示了:

  1. 文档处理:支持PDF、文本等多种格式的文档上传和处理
  2. 向量存储:使用Chroma作为向量数据库存储文档的向量(embeddings)
  3. 智能检索:基于语义相似性检索相关文档片段
  4. 智能生成:结合检索到的上下文生成准确回答
  5. RESTful API:提供完整的HTTP API接口
  6. 前端集成:包含简单的Web界面进行交互
相关推荐
学习编程的小羊10 分钟前
Spring Boot 全局异常处理与日志监控实战
java·spring boot·后端
Sword9927 分钟前
🎮 AI编程新时代:Trae×Three.js打造沉浸式3D魔方游戏
前端·ai编程·trae
hongweihao1 小时前
Cursor 不讲武德,我反手开通了Claude max试试能不能用 Claude code 代替它
ai编程·claude
LEAFF1 小时前
手把手教你使用Coze开发一个AI翻译应用
agent·ai编程·coze
CoderLiu2 小时前
AI提示词工程优化指南:8个技巧,释放大语言模型的全部潜力
前端·人工智能·ai编程
量子位2 小时前
拿下3D生成行业新标杆!昆仑万维Matrix-3D新模型鲨疯了,一张图建模游戏场景
ai编程
量子位2 小时前
GitHub独立时代落幕!CEO离职创业,微软全面接管
github·ai编程
量子位2 小时前
黄仁勋子女成长路径曝光:一个学烘焙一个开酒吧,从基层做到英伟达高管
ai编程·nvidia
咚咚咚ddd3 小时前
Cursor Figma MCP: 从设计稿到自动代码实现(DTC)
ai编程·cursor·mcp