RAG实战:用Java+向量数据库打造智能问答系统

RAG实战:用Java+向量数据库打造智能问答系统

最近RAG(检索增强生成)这个概念特别火,基本上做AI应用的都在聊这个。我也花了点时间研究了一下,用Java+向量数据库搞了个智能问答系统。今天就把整个过程,包括踩的那些坑,都分享出来。

什么是RAG?为什么需要它?

简单说,RAG就是先把文档向量化存起来,用户提问的时候,先从向量数据库里找到相关的文档片段,然后把这些片段和问题一起扔给大模型,让模型基于这些内容回答。

为什么要这么做?因为大模型虽然有知识,但:

  1. 训练数据有截止时间,不知道最新的信息
  2. 知识可能不够精准,比如你公司的内部文档
  3. 直接问可能胡说八道,但基于文档回答就靠谱多了

技术选型:选哪个向量数据库?

Java生态下,常用的向量数据库有这几个:

  • Milvus:功能强大,性能好,但部署稍微复杂
  • Chroma:简单易用,Python生态很成熟,Java支持一般
  • Qdrant:Rust写的,性能不错,API也挺友好
  • PostgreSQL + pgvector:如果项目本来就用PostgreSQL,这个最省事

我选了Milvus,主要是因为:

  1. 社区活跃,文档比较全
  2. Java客户端支持不错
  3. 性能确实好,支持大规模数据

第一步:搭建Milvus

Milvus可以用Docker快速启动:

bash 复制代码
# 下载docker-compose.yml
wget https://github.com/milvus-io/milvus/releases/download/v2.3.0/milvus-standalone-docker-compose.yml -O docker-compose.yml

# 启动
docker-compose up -d

启动后,Milvus会在19530端口提供gRPC服务,9200端口是REST API。

第二步:项目依赖

xml 复制代码
<dependencies>
    <dependency>
        <groupId>io.milvus</groupId>
        <artifactId>milvus-sdk-java</artifactId>
        <version>2.3.0</version>
    </dependency>
    
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
        <version>1.0.0</version>
    </dependency>
    
    <!-- 文档解析 -->
    <dependency>
        <groupId>org.apache.tika</groupId>
        <artifactId>tika-core</artifactId>
        <version>2.9.0</version>
    </dependency>
</dependencies>

第三步:文档处理和向量化

这是RAG的核心部分。先把文档加载进来,然后分块,最后向量化存到Milvus。

文档加载和分块

java 复制代码
@Service
@Slf4j
public class DocumentService {
    
    @Autowired
    private EmbeddingClient embeddingClient;
    
    // 加载文档(支持PDF、Word、TXT等)
    public List<String> loadDocuments(String filePath) {
        List<String> chunks = new ArrayList<>();
        
        try {
            // 用Tika解析文档
            Tika tika = new Tika();
            String content = tika.parseToString(new File(filePath));
            
            // 按段落分块,每块500字左右
            chunks = splitIntoChunks(content, 500);
            
        } catch (Exception e) {
            log.error("Failed to load document", e);
        }
        
        return chunks;
    }
    
    // 简单的分块策略:按句号和换行分割
    private List<String> splitIntoChunks(String text, int chunkSize) {
        List<String> chunks = new ArrayList<>();
        String[] sentences = text.split("[。\n]");
        
        StringBuilder currentChunk = new StringBuilder();
        for (String sentence : sentences) {
            if (currentChunk.length() + sentence.length() > chunkSize) {
                if (currentChunk.length() > 0) {
                    chunks.add(currentChunk.toString());
                    currentChunk = new StringBuilder();
                }
            }
            currentChunk.append(sentence).append("。");
        }
        
        if (currentChunk.length() > 0) {
            chunks.add(currentChunk.toString());
        }
        
        return chunks;
    }
    
    // 向量化并存入Milvus
    public void storeDocuments(String collectionName, List<String> chunks) {
        try {
            MilvusServiceClient client = getMilvusClient();
            
            // 创建集合(如果不存在)
            createCollectionIfNotExists(client, collectionName);
            
            // 批量向量化
            List<List<Float>> embeddings = new ArrayList<>();
            for (String chunk : chunks) {
                List<Float> embedding = embeddingClient.embed(chunk);
                embeddings.add(embedding);
            }
            
            // 构建数据
            List<Long> ids = LongStream.range(0, chunks.size())
                .boxed()
                .collect(Collectors.toList());
            
            List<InsertParam.Field> fields = new ArrayList<>();
            fields.add(new InsertParam.Field("id", ids));
            fields.add(new InsertParam.Field("text", chunks));
            fields.add(new InsertParam.Field("vector", embeddings));
            
            // 插入数据
            InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(collectionName)
                .withFields(fields)
                .build();
            
            client.insert(insertParam);
            
            // 刷新一下,确保数据可搜索
            client.flush(collectionName);
            
            log.info("Stored {} chunks to Milvus", chunks.size());
            
        } catch (Exception e) {
            log.error("Failed to store documents", e);
            throw new RuntimeException(e);
        }
    }
}

这里踩了个坑:刚开始我直接用字符串长度分块,结果有时候会把一个完整的句子切开,导致语义不完整。后来改成按句子分块,效果好多了。

更智能的分块策略

后来我改进了分块策略,考虑了语义边界:

java 复制代码
private List<String> splitIntoSemanticChunks(String text, int chunkSize) {
    List<String> chunks = new ArrayList<>();
    
    // 先用正则分割成段落
    String[] paragraphs = text.split("\\n\\s*\\n");
    
    for (String para : paragraphs) {
        if (para.length() <= chunkSize) {
            chunks.add(para);
        } else {
            // 段落太长,再按句子分割
            String[] sentences = para.split("[。!?]");
            StringBuilder currentChunk = new StringBuilder();
            
            for (String sentence : sentences) {
                if (currentChunk.length() + sentence.length() > chunkSize) {
                    chunks.add(currentChunk.toString());
                    currentChunk = new StringBuilder(sentence);
                } else {
                    currentChunk.append(sentence).append("。");
                }
            }
            
            if (currentChunk.length() > 0) {
                chunks.add(currentChunk.toString());
            }
        }
    }
    
    return chunks;
}

第四步:Milvus集成

java 复制代码
@Component
@Slf4j
public class MilvusService {
    
    private MilvusServiceClient client;
    
    @PostConstruct
    public void init() {
        ConnectParam connectParam = ConnectParam.newBuilder()
            .withHost("localhost")
            .withPort(19530)
            .build();
        
        this.client = new MilvusServiceClient(connectParam);
        log.info("Connected to Milvus");
    }
    
    // 创建集合
    public void createCollectionIfNotExists(String collectionName) {
        if (hasCollection(collectionName)) {
            log.info("Collection {} already exists", collectionName);
            return;
        }
        
        // 定义字段
        FieldType idField = FieldType.newBuilder()
            .withName("id")
            .withDataType(DataType.Int64)
            .withPrimaryKey(true)
            .withAutoID(false)
            .build();
        
        FieldType textField = FieldType.newBuilder()
            .withName("text")
            .withDataType(DataType.VarChar)
            .withMaxLength(10000)
            .build();
        
        FieldType vectorField = FieldType.newBuilder()
            .withName("vector")
            .withDataType(DataType.FloatVector)
            .withDimension(1536) // OpenAI embedding维度
            .build();
        
        CollectionSchema schema = CollectionSchema.newBuilder()
            .withField(idField)
            .withField(textField)
            .withField(vectorField)
            .build();
        
        CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .withSchema(schema)
            .build();
        
        client.createCollection(createParam);
        
        // 创建索引
        IndexType indexType = IndexType.HNSW;
        Map<String, Object> indexParams = new HashMap<>();
        indexParams.put("M", 16);
        indexParams.put("efConstruction", 200);
        
        CreateIndexParam indexParam = CreateIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withFieldName("vector")
            .withIndexType(indexType)
            .withExtraParam(indexParams)
            .build();
        
        client.createIndex(indexParam);
        
        log.info("Created collection {}", collectionName);
    }
    
    // 搜索相似文档
    public List<String> search(String collectionName, List<Float> queryVector, int topK) {
        List<List<Float>> searchVectors = Collections.singletonList(queryVector);
        
        SearchParam searchParam = SearchParam.newBuilder()
            .withCollectionName(collectionName)
            .withVectorFieldName("vector")
            .withVectors(searchVectors)
            .withTopK(topK)
            .withMetricType(MetricType.L2)
            .withParams("{\"ef\": 64}")
            .withOutFields(Collections.singletonList("text"))
            .build();
        
        R<SearchResults> response = client.search(searchParam);
        
        if (response.getStatus() != R.Status.Success.getCode()) {
            throw new RuntimeException("Search failed: " + response.getMessage());
        }
        
        SearchResults results = response.getData();
        List<String> texts = new ArrayList<>();
        
        for (QueryResultsWrapper.RowRecord record : results.getRowRecords()) {
            Object text = record.get("text");
            if (text != null) {
                texts.add(text.toString());
            }
        }
        
        return texts;
    }
    
    private boolean hasCollection(String collectionName) {
        R<Boolean> response = client.hasCollection(
            HasCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build()
        );
        return response.getData();
    }
}

第五步:RAG查询服务

核心逻辑:先检索相关文档,再组装prompt,最后调用大模型。

java 复制代码
@Service
@Slf4j
public class RAGService {
    
    @Autowired
    private MilvusService milvusService;
    
    @Autowired
    private EmbeddingClient embeddingClient;
    
    @Autowired
    private ChatClient chatClient;
    
    public String answer(String collectionName, String question) {
        // 1. 问题向量化
        List<Float> questionVector = embeddingClient.embed(question);
        
        // 2. 从Milvus检索相关文档
        List<String> relevantDocs = milvusService.search(
            collectionName, 
            questionVector, 
            5  // 返回top 5
        );
        
        // 3. 组装prompt
        String context = String.join("\n\n", relevantDocs);
        String prompt = buildPrompt(question, context);
        
        // 4. 调用大模型
        String answer = chatClient.call(prompt);
        
        return answer;
    }
    
    private String buildPrompt(String question, String context) {
        return String.format(
            "基于以下文档内容回答问题。如果文档中没有相关信息,请回答" +
            "你不知道,不要编造答案。\n\n" +
            "文档内容:\n%s\n\n" +
            "问题:%s\n\n" +
            "回答:",
            context,
            question
        );
    }
}

实际使用中的优化

1. 混合检索

只用向量检索有时候不够准,我加了关键词检索作为补充:

java 复制代码
public List<String> hybridSearch(
        String collectionName, 
        String question, 
        int topK) {
    
    // 向量检索
    List<Float> questionVector = embeddingClient.embed(question);
    List<String> vectorResults = milvusService.search(
        collectionName, 
        questionVector, 
        topK * 2
    );
    
    // 关键词检索(简单的实现,实际可以用Elasticsearch)
    List<String> keywordResults = keywordSearch(collectionName, question, topK);
    
    // 合并结果,去重
    Set<String> combined = new LinkedHashSet<>();
    combined.addAll(vectorResults);
    combined.addAll(keywordResults);
    
    return new ArrayList<>(combined).subList(0, Math.min(topK, combined.size()));
}

2. 重排序(Rerank)

检索到的文档可能相关性不够好,加个重排序模型:

java 复制代码
public List<String> rerank(String question, List<String> candidates, int topK) {
    // 这里可以接入重排序模型,比如BGE Reranker
    // 简单实现:计算每个候选文档和问题的相似度
    List<Pair<String, Double>> scored = new ArrayList<>();
    
    for (String candidate : candidates) {
        double score = calculateSimilarity(question, candidate);
        scored.add(Pair.of(candidate, score));
    }
    
    // 按分数排序
    scored.sort((a, b) -> Double.compare(b.getRight(), a.getRight()));
    
    return scored.stream()
        .limit(topK)
        .map(Pair::getLeft)
        .collect(Collectors.toList());
}

3. 缓存

相同的问题没必要每次都检索,加个缓存:

java 复制代码
@Cacheable(value = "ragCache", key = "#question")
public String cachedAnswer(String collectionName, String question) {
    return answer(collectionName, question);
}

踩坑总结

  1. 向量维度不匹配:不同模型的embedding维度不一样,OpenAI是1536,有的模型是768。刚开始没注意,存的数据和查询的维度不一样,直接报错。

  2. 分块大小:块太小丢失上下文,块太大检索不准。试了几次,500-1000字比较合适。

  3. 检索数量:返回太多文档,Token消耗大,而且可能引入噪音。一般3-5个就够了。

  4. Milvus索引参数:HNSW的M和ef参数需要调优,默认值不一定适合你的数据量。

完整示例

最后贴个完整的Controller:

java 复制代码
@RestController
@RequestMapping("/api/rag")
@Slf4j
public class RAGController {
    
    @Autowired
    private RAGService ragService;
    
    @Autowired
    private DocumentService documentService;
    
    // 上传文档
    @PostMapping("/upload")
    public ResponseEntity<String> uploadDocument(
            @RequestParam("file") MultipartFile file,
            @RequestParam String collectionName) {
        try {
            // 保存文件
            String filePath = saveFile(file);
            
            // 加载和存储
            List<String> chunks = documentService.loadDocuments(filePath);
            documentService.storeDocuments(collectionName, chunks);
            
            return ResponseEntity.ok("文档上传成功,共处理 " + chunks.size() + " 个文档块");
        } catch (Exception e) {
            log.error("Upload failed", e);
            return ResponseEntity.status(500).body("上传失败:" + e.getMessage());
        }
    }
    
    // 问答
    @PostMapping("/ask")
    public ResponseEntity<String> ask(
            @RequestParam String collectionName,
            @RequestParam String question) {
        try {
            String answer = ragService.answer(collectionName, question);
            return ResponseEntity.ok(answer);
        } catch (Exception e) {
            log.error("Ask failed", e);
            return ResponseEntity.status(500).body("回答失败:" + e.getMessage());
        }
    }
}

下一步优化

现在的版本基本能用,但还有很多优化空间:

  1. 支持更多文档格式:Excel、PPT、图片OCR等
  2. 多模态RAG:支持图片、表格等
  3. 增量更新:文档更新后只更新变化的部分
  4. 多租户支持:不同用户使用不同的知识库

RAG虽然概念简单,但要做好还是有很多细节要处理。如果你也在做RAG相关的项目,欢迎交流经验。

完整代码我放在GitHub上了,需要的同学可以看看。记得给个Star哈哈。

相关推荐
篱笆院的狗2 小时前
Java 中如何创建多线程?
java·开发语言
北极糊的狐2 小时前
若依报错org.springframework.dao.DataIntegrityViolationException
数据库·mysql
晨晖22 小时前
二叉树遍历,先中后序遍历,c++版
开发语言·c++
醒过来摸鱼2 小时前
Java Compiler API使用
java·开发语言·python
wangchen_02 小时前
C/C++时间操作(ctime、chrono)
开发语言·c++
dazhong20122 小时前
Mybatis 敏感数据加解密插件完整实现方案
java·数据库·mybatis
Dev7z2 小时前
基于MATLAB HSI颜色空间的图像美颜系统设计与实现
开发语言·matlab
superman超哥2 小时前
仓颉语言中字符串常用方法的深度剖析与工程实践
开发语言·后端·python·c#·仓颉
薛晓刚2 小时前
2025 年度个人回顾总结
数据库