RAG实战:用Java+向量数据库打造智能问答系统
最近RAG(检索增强生成)这个概念特别火,基本上做AI应用的都在聊这个。我也花了点时间研究了一下,用Java+向量数据库搞了个智能问答系统。今天就把整个过程,包括踩的那些坑,都分享出来。
什么是RAG?为什么需要它?
简单说,RAG就是先把文档向量化存起来,用户提问的时候,先从向量数据库里找到相关的文档片段,然后把这些片段和问题一起扔给大模型,让模型基于这些内容回答。
为什么要这么做?因为大模型虽然有知识,但:
- 训练数据有截止时间,不知道最新的信息
- 知识可能不够精准,比如你公司的内部文档
- 直接问可能胡说八道,但基于文档回答就靠谱多了
技术选型:选哪个向量数据库?
Java生态下,常用的向量数据库有这几个:
- Milvus:功能强大,性能好,但部署稍微复杂
- Chroma:简单易用,Python生态很成熟,Java支持一般
- Qdrant:Rust写的,性能不错,API也挺友好
- PostgreSQL + pgvector:如果项目本来就用PostgreSQL,这个最省事
我选了Milvus,主要是因为:
- 社区活跃,文档比较全
- Java客户端支持不错
- 性能确实好,支持大规模数据
第一步:搭建Milvus
Milvus可以用Docker快速启动:
bash
# 下载docker-compose.yml
wget https://github.com/milvus-io/milvus/releases/download/v2.3.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
# 启动
docker-compose up -d
启动后,Milvus会在19530端口提供gRPC服务,9200端口是REST API。
第二步:项目依赖
xml
<dependencies>
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>1.0.0</version>
</dependency>
<!-- 文档解析 -->
<dependency>
<groupId>org.apache.tika</groupId>
<artifactId>tika-core</artifactId>
<version>2.9.0</version>
</dependency>
</dependencies>
第三步:文档处理和向量化
这是RAG的核心部分。先把文档加载进来,然后分块,最后向量化存到Milvus。
文档加载和分块
java
@Service
@Slf4j
public class DocumentService {
@Autowired
private EmbeddingClient embeddingClient;
// 加载文档(支持PDF、Word、TXT等)
public List<String> loadDocuments(String filePath) {
List<String> chunks = new ArrayList<>();
try {
// 用Tika解析文档
Tika tika = new Tika();
String content = tika.parseToString(new File(filePath));
// 按段落分块,每块500字左右
chunks = splitIntoChunks(content, 500);
} catch (Exception e) {
log.error("Failed to load document", e);
}
return chunks;
}
// 简单的分块策略:按句号和换行分割
private List<String> splitIntoChunks(String text, int chunkSize) {
List<String> chunks = new ArrayList<>();
String[] sentences = text.split("[。\n]");
StringBuilder currentChunk = new StringBuilder();
for (String sentence : sentences) {
if (currentChunk.length() + sentence.length() > chunkSize) {
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString());
currentChunk = new StringBuilder();
}
}
currentChunk.append(sentence).append("。");
}
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString());
}
return chunks;
}
// 向量化并存入Milvus
public void storeDocuments(String collectionName, List<String> chunks) {
try {
MilvusServiceClient client = getMilvusClient();
// 创建集合(如果不存在)
createCollectionIfNotExists(client, collectionName);
// 批量向量化
List<List<Float>> embeddings = new ArrayList<>();
for (String chunk : chunks) {
List<Float> embedding = embeddingClient.embed(chunk);
embeddings.add(embedding);
}
// 构建数据
List<Long> ids = LongStream.range(0, chunks.size())
.boxed()
.collect(Collectors.toList());
List<InsertParam.Field> fields = new ArrayList<>();
fields.add(new InsertParam.Field("id", ids));
fields.add(new InsertParam.Field("text", chunks));
fields.add(new InsertParam.Field("vector", embeddings));
// 插入数据
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withFields(fields)
.build();
client.insert(insertParam);
// 刷新一下,确保数据可搜索
client.flush(collectionName);
log.info("Stored {} chunks to Milvus", chunks.size());
} catch (Exception e) {
log.error("Failed to store documents", e);
throw new RuntimeException(e);
}
}
}
这里踩了个坑:刚开始我直接用字符串长度分块,结果有时候会把一个完整的句子切开,导致语义不完整。后来改成按句子分块,效果好多了。
更智能的分块策略
后来我改进了分块策略,考虑了语义边界:
java
private List<String> splitIntoSemanticChunks(String text, int chunkSize) {
List<String> chunks = new ArrayList<>();
// 先用正则分割成段落
String[] paragraphs = text.split("\\n\\s*\\n");
for (String para : paragraphs) {
if (para.length() <= chunkSize) {
chunks.add(para);
} else {
// 段落太长,再按句子分割
String[] sentences = para.split("[。!?]");
StringBuilder currentChunk = new StringBuilder();
for (String sentence : sentences) {
if (currentChunk.length() + sentence.length() > chunkSize) {
chunks.add(currentChunk.toString());
currentChunk = new StringBuilder(sentence);
} else {
currentChunk.append(sentence).append("。");
}
}
if (currentChunk.length() > 0) {
chunks.add(currentChunk.toString());
}
}
}
return chunks;
}
第四步:Milvus集成
java
@Component
@Slf4j
public class MilvusService {
private MilvusServiceClient client;
@PostConstruct
public void init() {
ConnectParam connectParam = ConnectParam.newBuilder()
.withHost("localhost")
.withPort(19530)
.build();
this.client = new MilvusServiceClient(connectParam);
log.info("Connected to Milvus");
}
// 创建集合
public void createCollectionIfNotExists(String collectionName) {
if (hasCollection(collectionName)) {
log.info("Collection {} already exists", collectionName);
return;
}
// 定义字段
FieldType idField = FieldType.newBuilder()
.withName("id")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(false)
.build();
FieldType textField = FieldType.newBuilder()
.withName("text")
.withDataType(DataType.VarChar)
.withMaxLength(10000)
.build();
FieldType vectorField = FieldType.newBuilder()
.withName("vector")
.withDataType(DataType.FloatVector)
.withDimension(1536) // OpenAI embedding维度
.build();
CollectionSchema schema = CollectionSchema.newBuilder()
.withField(idField)
.withField(textField)
.withField(vectorField)
.build();
CreateCollectionParam createParam = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withSchema(schema)
.build();
client.createCollection(createParam);
// 创建索引
IndexType indexType = IndexType.HNSW;
Map<String, Object> indexParams = new HashMap<>();
indexParams.put("M", 16);
indexParams.put("efConstruction", 200);
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName("vector")
.withIndexType(indexType)
.withExtraParam(indexParams)
.build();
client.createIndex(indexParam);
log.info("Created collection {}", collectionName);
}
// 搜索相似文档
public List<String> search(String collectionName, List<Float> queryVector, int topK) {
List<List<Float>> searchVectors = Collections.singletonList(queryVector);
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(collectionName)
.withVectorFieldName("vector")
.withVectors(searchVectors)
.withTopK(topK)
.withMetricType(MetricType.L2)
.withParams("{\"ef\": 64}")
.withOutFields(Collections.singletonList("text"))
.build();
R<SearchResults> response = client.search(searchParam);
if (response.getStatus() != R.Status.Success.getCode()) {
throw new RuntimeException("Search failed: " + response.getMessage());
}
SearchResults results = response.getData();
List<String> texts = new ArrayList<>();
for (QueryResultsWrapper.RowRecord record : results.getRowRecords()) {
Object text = record.get("text");
if (text != null) {
texts.add(text.toString());
}
}
return texts;
}
private boolean hasCollection(String collectionName) {
R<Boolean> response = client.hasCollection(
HasCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build()
);
return response.getData();
}
}
第五步:RAG查询服务
核心逻辑:先检索相关文档,再组装prompt,最后调用大模型。
java
@Service
@Slf4j
public class RAGService {
@Autowired
private MilvusService milvusService;
@Autowired
private EmbeddingClient embeddingClient;
@Autowired
private ChatClient chatClient;
public String answer(String collectionName, String question) {
// 1. 问题向量化
List<Float> questionVector = embeddingClient.embed(question);
// 2. 从Milvus检索相关文档
List<String> relevantDocs = milvusService.search(
collectionName,
questionVector,
5 // 返回top 5
);
// 3. 组装prompt
String context = String.join("\n\n", relevantDocs);
String prompt = buildPrompt(question, context);
// 4. 调用大模型
String answer = chatClient.call(prompt);
return answer;
}
private String buildPrompt(String question, String context) {
return String.format(
"基于以下文档内容回答问题。如果文档中没有相关信息,请回答" +
"你不知道,不要编造答案。\n\n" +
"文档内容:\n%s\n\n" +
"问题:%s\n\n" +
"回答:",
context,
question
);
}
}
实际使用中的优化
1. 混合检索
只用向量检索有时候不够准,我加了关键词检索作为补充:
java
public List<String> hybridSearch(
String collectionName,
String question,
int topK) {
// 向量检索
List<Float> questionVector = embeddingClient.embed(question);
List<String> vectorResults = milvusService.search(
collectionName,
questionVector,
topK * 2
);
// 关键词检索(简单的实现,实际可以用Elasticsearch)
List<String> keywordResults = keywordSearch(collectionName, question, topK);
// 合并结果,去重
Set<String> combined = new LinkedHashSet<>();
combined.addAll(vectorResults);
combined.addAll(keywordResults);
return new ArrayList<>(combined).subList(0, Math.min(topK, combined.size()));
}
2. 重排序(Rerank)
检索到的文档可能相关性不够好,加个重排序模型:
java
public List<String> rerank(String question, List<String> candidates, int topK) {
// 这里可以接入重排序模型,比如BGE Reranker
// 简单实现:计算每个候选文档和问题的相似度
List<Pair<String, Double>> scored = new ArrayList<>();
for (String candidate : candidates) {
double score = calculateSimilarity(question, candidate);
scored.add(Pair.of(candidate, score));
}
// 按分数排序
scored.sort((a, b) -> Double.compare(b.getRight(), a.getRight()));
return scored.stream()
.limit(topK)
.map(Pair::getLeft)
.collect(Collectors.toList());
}
3. 缓存
相同的问题没必要每次都检索,加个缓存:
java
@Cacheable(value = "ragCache", key = "#question")
public String cachedAnswer(String collectionName, String question) {
return answer(collectionName, question);
}
踩坑总结
-
向量维度不匹配:不同模型的embedding维度不一样,OpenAI是1536,有的模型是768。刚开始没注意,存的数据和查询的维度不一样,直接报错。
-
分块大小:块太小丢失上下文,块太大检索不准。试了几次,500-1000字比较合适。
-
检索数量:返回太多文档,Token消耗大,而且可能引入噪音。一般3-5个就够了。
-
Milvus索引参数:HNSW的M和ef参数需要调优,默认值不一定适合你的数据量。
完整示例
最后贴个完整的Controller:
java
@RestController
@RequestMapping("/api/rag")
@Slf4j
public class RAGController {
@Autowired
private RAGService ragService;
@Autowired
private DocumentService documentService;
// 上传文档
@PostMapping("/upload")
public ResponseEntity<String> uploadDocument(
@RequestParam("file") MultipartFile file,
@RequestParam String collectionName) {
try {
// 保存文件
String filePath = saveFile(file);
// 加载和存储
List<String> chunks = documentService.loadDocuments(filePath);
documentService.storeDocuments(collectionName, chunks);
return ResponseEntity.ok("文档上传成功,共处理 " + chunks.size() + " 个文档块");
} catch (Exception e) {
log.error("Upload failed", e);
return ResponseEntity.status(500).body("上传失败:" + e.getMessage());
}
}
// 问答
@PostMapping("/ask")
public ResponseEntity<String> ask(
@RequestParam String collectionName,
@RequestParam String question) {
try {
String answer = ragService.answer(collectionName, question);
return ResponseEntity.ok(answer);
} catch (Exception e) {
log.error("Ask failed", e);
return ResponseEntity.status(500).body("回答失败:" + e.getMessage());
}
}
}
下一步优化
现在的版本基本能用,但还有很多优化空间:
- 支持更多文档格式:Excel、PPT、图片OCR等
- 多模态RAG:支持图片、表格等
- 增量更新:文档更新后只更新变化的部分
- 多租户支持:不同用户使用不同的知识库
RAG虽然概念简单,但要做好还是有很多细节要处理。如果你也在做RAG相关的项目,欢迎交流经验。
完整代码我放在GitHub上了,需要的同学可以看看。记得给个Star哈哈。