Building a RAG Knowledge Base from Scratch: A Simple Step-by-Step Tutorial (Java Edition)

Preface

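Python dominates the AI field, but in enterprise applications Java is still the backbone of many big-data and backend tech stacks. This article implements a complete RAG knowledge base system in pure Java, without any Python dependency. It is aimed at developers whose stack is built on Java.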


1. Technology Choices

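| Component | Choice | Notes |
|---|---|---|
| Vector database | Qdrant | High performance, written in Rust, official Java SDK (gRPC) |
| Embedding model | Ollama + bge-m3 | Runs locally and is called over HTTP |
| LLM | DeepSeek / Ollama | DeepSeek API recommended; easy to access from within China |
| Document parsing | Apache PDFBox / Apache POI | PDF and Word (.docx) parsing (Apache Tika is an alternative) |
| HTTP client | OkHttp / RestClient | Calls the embedding and LLM endpoints |
| JSON processing | Jackson | Serialization and deserialization |
| Logging | SLF4J + Logback | Unified logging |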

Prerequisites

  • JDK 17+

  • Maven / Gradle

  • Docker (to run Qdrant)

  • Ollama (to run the embedding model locally)


2. Project Setup

2.1 Maven Dependencies

xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
         http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>rag-knowledge-base</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
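        <!-- Qdrant Java client (gRPC). If startup fails with "No functional channel service provider found", add a gRPC transport such as io.grpc:grpc-netty-shaded -->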
        <dependency>
            <groupId>io.qdrant</groupId>
            <artifactId>client</artifactId>
            <version>1.8.0</version>
        </dependency>

        <!-- HTTP client -->
        <dependency>
            <groupId>com.squareup.okhttp3</groupId>
            <artifactId>okhttp</artifactId>
            <version>4.11.0</version>
        </dependency>

        <!-- JSON processing -->
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>2.15.2</version>
        </dependency>

        <!-- PDF parsing (PDFBox 3.x opens documents via org.apache.pdfbox.Loader) -->
        <dependency>
            <groupId>org.apache.pdfbox</groupId>
            <artifactId>pdfbox</artifactId>
            <version>3.0.0</version>
        </dependency>

        <!-- Word (.docx) parsing; legacy binary .doc files would also need poi-scratchpad -->
        <dependency>
            <groupId>org.apache.poi</groupId>
            <artifactId>poi-ooxml</artifactId>
            <version>5.2.3</version>
        </dependency>

        <!-- Lombok (generates getters/setters, constructors and loggers) -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.28</version>
            <scope>provided</scope>
        </dependency>

        <!-- Logging (Logback; also brings in the SLF4J API) -->
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.4.11</version>
        </dependency>
    </dependencies>

</project>

2.2 Start Qdrant (Docker)

bash

# Pull and start Qdrant: 6333 = REST API / web UI, 6334 = gRPC (used by the Java client)
docker run -d \
  --name qdrant \
  -p 6333:6333 \
  -p 6334:6334 \
  -v "$(pwd)/qdrant_storage:/qdrant/storage" \
  qdrant/qdrant:latest
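Before you start the application, check that both ports are actually listening. The Java client in section 3.6 talks gRPC on 6334, while 6333 serves only REST. The helper below is a hypothetical, JDK-only sketch (the class name `PortCheck` is invented for illustration). It just attempts a TCP connection with a timeout, so a successful connect shows only that something is listening on the port, not that the listener is Qdrant.

```java
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;

/** Hypothetical helper: checks whether a TCP port accepts connections. */
final class PortCheck {
    private PortCheck() {}

    /** Returns true if a TCP connection to host:port succeeds within timeoutMs milliseconds. */
    static boolean isReachable(String host, int port, int timeoutMs) {
        try (Socket socket = new Socket()) {
            socket.connect(new InetSocketAddress(host, port), timeoutMs);
            return true;
        } catch (IOException e) {
            // Connection refused, timed out, or host unknown
            return false;
        }
    }
}
```

For example, `PortCheck.isReachable("localhost", 6334, 1000)` should return `true` once the container is up.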

2.3 Start Ollama and Pull the Models

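bash

# Install Ollama first (download it from the official website)
# Pull the embedding model (bge-m3 is multilingual and handles Chinese well)
ollama pull bge-m3:latest

# Optional: pull a local LLM
ollama pull qwen2:7b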

3. Core Implementation

3.1 Configuration Class

java

package com.rag.config;

import lombok.Data;

@Data
public class RagConfig {
    // Qdrant settings: the Java client speaks gRPC, so use port 6334 (6333 is the REST port)
    private String qdrantHost = "localhost";
    private int qdrantPort = 6334;
    private String collectionName = "knowledge_base";

    // Embedding settings (Ollama's legacy endpoint: one text per request)
    private String embeddingUrl = "http://localhost:11434/api/embeddings";
    private String embeddingModel = "bge-m3:latest";

    // LLM settings (any OpenAI-compatible chat completions endpoint)
    private String llmUrl = "https://api.deepseek.com/v1/chat/completions";
    private String llmApiKey = "your-api-key"; // placeholder; never commit a real key
    private String llmModel = "deepseek-chat";

    // Chunking settings, measured in characters
    private int chunkSize = 500;
    private int chunkOverlap = 50;

    // Retrieval settings: number of chunks passed to the LLM
    private int topK = 3;
}
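Hard-coding the API key in `RagConfig` makes it easy to commit by accident. One option is to read overrides from environment variables, which gives the same `${VAR:default}` behavior as the YAML in section 5.3. The sketch below assumes that approach. `EnvConfig` is a hypothetical helper, and it takes the environment as a `Map` so that you can test it without touching `System.getenv()`.

```java
import java.util.Map;

/** Hypothetical helper mirroring the ${VAR:default} convention used in the YAML config. */
final class EnvConfig {
    private EnvConfig() {}

    /** Returns env[key] if present and non-blank, otherwise defaultValue. */
    static String get(Map<String, String> env, String key, String defaultValue) {
        String value = env.get(key);
        return (value == null || value.isBlank()) ? defaultValue : value;
    }

    /** Integer variant: falls back to defaultValue if the variable is missing or not a number. */
    static int getInt(Map<String, String> env, String key, int defaultValue) {
        String value = env.get(key);
        if (value == null || value.isBlank()) {
            return defaultValue;
        }
        try {
            return Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            return defaultValue;
        }
    }
}
```

In `Application`, you could wire it up with calls such as `config.setLlmApiKey(EnvConfig.get(System.getenv(), "DEEPSEEK_API_KEY", ""))` and `config.setQdrantPort(EnvConfig.getInt(System.getenv(), "QDRANT_PORT", 6334))`.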

3.2 Document Models

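java

// File: src/main/java/com/rag/model/DocumentChunk.java
package com.rag.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class DocumentChunk {
    private String id;                      // point ID; Qdrant accepts a UUID string or an unsigned integer
    private String content;                 // chunk text
    private Map<String, Object> metadata;   // e.g. source, chunk_index, timestamp
    private List<Float> embedding;          // vector produced by the embedding model
}

// File: src/main/java/com/rag/model/SearchResult.java
// (Java allows only one public top-level class per source file)
package com.rag.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

@Data
@NoArgsConstructor
@AllArgsConstructor
public class SearchResult {
    private String content;
    private Map<String, Object> metadata;
    private float score;                    // cosine similarity reported by Qdrant
}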

3.3 Document Loading and Parsing

java

package com.rag.loader;

import lombok.extern.slf4j.Slf4j;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.text.PDFTextStripper;
import org.apache.poi.xwpf.usermodel.XWPFDocument;
import org.apache.poi.xwpf.usermodel.XWPFParagraph;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;

@Slf4j
public class DocumentLoader {

    /**
     * Load a PDF file. PDFBox 3.x removed PDDocument.load(); Loader.loadPDF() replaces it.
     */
    public String loadPdf(String filePath) throws IOException {
        try (PDDocument document = Loader.loadPDF(new File(filePath))) {
            PDFTextStripper stripper = new PDFTextStripper();
            return stripper.getText(document);
        }
    }

    /**
     * Load a Word file in OOXML (.docx) format.
     */
    public String loadWord(String filePath) throws IOException {
        StringBuilder content = new StringBuilder();
        try (FileInputStream fis = new FileInputStream(filePath);
             XWPFDocument document = new XWPFDocument(fis)) {
            for (XWPFParagraph paragraph : document.getParagraphs()) {
                // Separate paragraphs with a blank line so TextSplitter sees paragraph boundaries
                content.append(paragraph.getText()).append("\n\n");
            }
        }
        return content.toString();
    }

    /**
     * Load a plain-text file, decoded explicitly as UTF-8 (not the platform default charset).
     */
    public String loadText(String filePath) throws IOException {
        return Files.readString(Path.of(filePath), StandardCharsets.UTF_8);
    }

    /**
     * Dispatch on the file extension (case-insensitive).
     */
    public String loadDocument(String filePath) throws IOException {
        String lower = filePath.toLowerCase(Locale.ROOT);
        if (lower.endsWith(".pdf")) {
            return loadPdf(filePath);
        } else if (lower.endsWith(".docx")) {
            return loadWord(filePath);
        } else if (lower.endsWith(".doc")) {
            // XWPFDocument cannot read the legacy binary .doc format (that needs HWPFDocument from poi-scratchpad)
            throw new IOException("Legacy .doc files are not supported; convert to .docx first: " + filePath);
        } else {
            return loadText(filePath);
        }
    }
}
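Text extracted by PDFBox often keeps the PDF's hard line wraps. A Chinese sentence can therefore be broken by single newlines mid-sentence, and the runs of blank lines between paragraphs vary in length. A light clean-up before chunking helps the paragraph-based splitter in section 3.4. The sketch below makes two heuristic assumptions: paragraphs are separated by blank lines, and a single newline between two CJK characters is a layout wrap. `TextCleaner` is a hypothetical name.

```java
/** Hypothetical pre-processing step for text extracted from PDFs, applied before chunking. */
final class TextCleaner {
    private TextCleaner() {}

    static String clean(String raw) {
        // Normalize Windows and old Mac line endings to \n
        String text = raw.replace("\r\n", "\n").replace('\r', '\n');
        // Drop trailing spaces and tabs at the end of each line
        text = text.replaceAll("[ \\t]+\\n", "\n");
        // Collapse three or more newlines into a single blank line (one paragraph break)
        text = text.replaceAll("\\n{3,}", "\n\n");
        // Join a single line break between two CJK characters (a hard-wrapped Chinese line)
        text = text.replaceAll("(?<=[\\u4e00-\\u9fff])\\n(?=[\\u4e00-\\u9fff])", "");
        return text.strip();
    }
}
```

You would call it as `TextCleaner.clean(documentLoader.loadDocument(path))` before passing the result to `textSplitter.splitText(...)`.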

3.4 Text Splitter

java

package com.rag.splitter;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;

@Slf4j
public class TextSplitter {
    private final int chunkSize;
    private final int chunkOverlap;

    public TextSplitter(int chunkSize, int chunkOverlap) {
        if (chunkSize <= 0 || chunkOverlap < 0 || chunkOverlap >= chunkSize) {
            throw new IllegalArgumentException("Require chunkSize > 0 and 0 <= chunkOverlap < chunkSize");
        }
        this.chunkSize = chunkSize;
        this.chunkOverlap = chunkOverlap;
    }

    /**
     * Split text into chunks: pack whole paragraphs up to chunkSize characters and
     * carry the last chunkOverlap characters of a saved chunk into the next one.
     */
    public List<String> splitText(String text) {
        List<String> chunks = new ArrayList<>();

        // Normalize line endings, then split on blank lines (paragraph boundaries)
        String[] paragraphs = text.replace("\r\n", "\n").split("\\n\\s*\\n");
        StringBuilder currentChunk = new StringBuilder();

        for (String rawParagraph : paragraphs) {
            String paragraph = rawParagraph.strip();
            if (paragraph.isEmpty()) {
                continue;
            }

            // A paragraph longer than chunkSize: flush the current chunk, then cut the paragraph
            if (paragraph.length() > chunkSize) {
                if (currentChunk.length() > 0) {
                    chunks.add(currentChunk.toString());
                    currentChunk = new StringBuilder();
                }
                chunks.addAll(splitLongParagraph(paragraph));
                continue;
            }

            // Adding this paragraph would overflow: save the current (non-empty) chunk and start a new one
            if (currentChunk.length() > 0 && currentChunk.length() + 2 + paragraph.length() > chunkSize) {
                String lastChunk = currentChunk.toString();
                chunks.add(lastChunk);
                currentChunk = new StringBuilder();
                // Overlap: seed the new chunk with the tail of the saved chunk, if it still fits
                int overlap = Math.min(chunkOverlap, lastChunk.length());
                if (overlap > 0 && overlap + 2 + paragraph.length() <= chunkSize) {
                    currentChunk.append(lastChunk.substring(lastChunk.length() - overlap));
                }
            }

            if (currentChunk.length() > 0) {
                currentChunk.append("\n\n");
            }
            currentChunk.append(paragraph);
        }

        if (currentChunk.length() > 0) {
            chunks.add(currentChunk.toString());
        }

        log.info("Text split into {} chunks", chunks.size());
        return chunks;
    }

    /**
     * Cut an over-long paragraph into pieces of at most chunkSize characters,
     * preferring to end each piece right after a Chinese full stop '。' or a newline.
     */
    private List<String> splitLongParagraph(String paragraph) {
        List<String> pieces = new ArrayList<>();
        int start = 0;

        while (start < paragraph.length()) {
            int end = Math.min(start + chunkSize, paragraph.length());
            if (end < paragraph.length()) {
                // Search backwards from the last character that still fits in this piece
                int lastPeriod = paragraph.lastIndexOf('。', end - 1);
                int lastNewLine = paragraph.lastIndexOf('\n', end - 1);
                int splitPos = Math.max(lastPeriod, lastNewLine);
                if (splitPos > start) {
                    end = splitPos + 1;
                }
            }
            pieces.add(paragraph.substring(start, end));
            start = end;
        }

        return pieces;
    }
}
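The simplest way to see what `chunkSize` and `chunkOverlap` mean numerically is a fixed sliding window. Each window is `chunkSize` characters long and starts `chunkSize - chunkOverlap` characters after the previous one, so consecutive chunks share `chunkOverlap` characters. This is a comparison sketch with the hypothetical name `SlidingWindowSplitter`, not the splitter above, and it ignores sentence boundaries.

```java
import java.util.ArrayList;
import java.util.List;

/** Minimal fixed-size sliding-window chunker (character based), for comparison with TextSplitter. */
final class SlidingWindowSplitter {
    private SlidingWindowSplitter() {}

    static List<String> split(String text, int chunkSize, int overlap) {
        if (chunkSize <= 0 || overlap < 0 || overlap >= chunkSize) {
            throw new IllegalArgumentException("require chunkSize > 0 and 0 <= overlap < chunkSize");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap; // each window starts `step` characters after the previous one
        for (int start = 0; start < text.length(); start += step) {
            int end = Math.min(start + chunkSize, text.length());
            chunks.add(text.substring(start, end));
            if (end == text.length()) {
                break; // this window already reaches the end of the text
            }
        }
        return chunks;
    }
}
```

With `chunkSize = 4` and `overlap = 1`, the string `"abcdefghij"` becomes `abcd`, `defg`, `ghij`. With the defaults of 500 and 50, each window starts 450 characters after the previous one.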

3.5 Embedding Client (Calling Ollama)

java

package com.rag.embedding;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

@Slf4j
public class EmbeddingClient {
    private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");

    private final OkHttpClient httpClient;
    private final ObjectMapper objectMapper;
    private final String embeddingUrl;
    private final String model;

    public EmbeddingClient(String embeddingUrl, String model) {
        this.httpClient = new OkHttpClient.Builder()
                .connectTimeout(30, TimeUnit.SECONDS)
                .readTimeout(60, TimeUnit.SECONDS)
                .build();
        this.objectMapper = new ObjectMapper();
        this.embeddingUrl = embeddingUrl;
        this.model = model;
    }

    /**
     * Get the vector for a single text.
     */
    public List<Float> getEmbedding(String text) throws IOException {
        return getEmbeddings(List.of(text)).get(0);
    }

    /**
     * Get vectors for several texts. The legacy /api/embeddings endpoint takes a single
     * "prompt", so this sends one HTTP request per text.
     */
    public List<List<Float>> getEmbeddings(List<String> texts) throws IOException {
        List<List<Float>> allEmbeddings = new ArrayList<>();

        for (String text : texts) {
            // Build the body with Jackson so quotes, newlines and control characters are escaped correctly
            ObjectNode body = objectMapper.createObjectNode();
            body.put("model", model);
            body.put("prompt", text);

            Request request = new Request.Builder()
                    .url(embeddingUrl)
                    .post(RequestBody.create(objectMapper.writeValueAsString(body), JSON))
                    .build();

            try (Response response = httpClient.newCall(request).execute()) {
                if (!response.isSuccessful()) {
                    throw new IOException("Embedding API call failed: HTTP " + response.code());
                }

                JsonNode embeddingArray = objectMapper.readTree(response.body().string()).get("embedding");
                if (embeddingArray == null || !embeddingArray.isArray()) {
                    throw new IOException("Embedding response contains no \"embedding\" array");
                }

                List<Float> embedding = new ArrayList<>(embeddingArray.size());
                for (JsonNode value : embeddingArray) {
                    embedding.add(value.floatValue());
                }
                allEmbeddings.add(embedding);
            }
        }

        return allEmbeddings;
    }
}
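The client above makes one HTTP round trip per chunk, because the legacy `/api/embeddings` endpoint takes a single `prompt`. Newer Ollama versions also expose `POST /api/embed`, which accepts `input` as a list and returns an `embeddings` array with one vector per input. This sketch assumes your Ollama release supports `/api/embed` (check the Ollama API documentation for your version). It only builds that request body with the JDK. It also shows that a hand-rolled JSON escaper must handle every control character below U+0020, not just newline, carriage return and tab. `OllamaEmbedRequest` is a hypothetical name.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Sketch: builds the body for Ollama's batch endpoint POST /api/embed (assumes a version that has it). */
final class OllamaEmbedRequest {
    private OllamaEmbedRequest() {}

    /** Quotes and escapes a string as a JSON string literal, including all control characters. */
    static String quote(String s) {
        StringBuilder sb = new StringBuilder("\"");
        for (char c : s.toCharArray()) {
            switch (c) {
                case '"' -> sb.append("\\\"");
                case '\\' -> sb.append("\\\\");
                case '\n' -> sb.append("\\n");
                case '\r' -> sb.append("\\r");
                case '\t' -> sb.append("\\t");
                default -> {
                    if (c < 0x20) {
                        // Remaining control characters must use the six-character escape form
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
                }
            }
        }
        return sb.append('"').toString();
    }

    /** Produces {"model":"...","input":["...","..."]}: one request for many texts. */
    static String body(String model, List<String> texts) {
        String inputs = texts.stream().map(OllamaEmbedRequest::quote).collect(Collectors.joining(","));
        return "{\"model\":" + quote(model) + ",\"input\":[" + inputs + "]}";
    }
}
```

With a JSON library already on the classpath, Jackson (as used above) is the simpler choice. The manual escaper is only needed if you want to avoid that dependency.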

3.6 Qdrant Vector Store

java

package com.rag.vectorstore;

import com.rag.model.DocumentChunk;
import com.rag.model.SearchResult;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections.Distance;
import io.qdrant.client.grpc.Collections.VectorParams;
import io.qdrant.client.grpc.JsonWithInt.Value;
import io.qdrant.client.grpc.Points.PointStruct;
import io.qdrant.client.grpc.Points.ScoredPoint;
import io.qdrant.client.grpc.Points.SearchPoints;
import lombok.extern.slf4j.Slf4j;

import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vectors;
import static io.qdrant.client.WithPayloadSelectorFactory.enable;

@Slf4j
public class QdrantVectorStore {
    private final QdrantClient client;
    private final String collectionName;
    private final int vectorSize;

    /** port must be Qdrant's gRPC port (6334 by default), not the REST port 6333. */
    public QdrantVectorStore(String host, int port, String collectionName, int vectorSize) {
        this.client = new QdrantClient(QdrantGrpcClient.newBuilder(host, port, false).build());
        this.collectionName = collectionName;
        this.vectorSize = vectorSize;
        initCollection();
    }

    /** Create the collection (cosine distance) if it does not exist yet. */
    private void initCollection() {
        try {
            boolean exists = client.collectionExistsAsync(collectionName).get();
            if (!exists) {
                client.createCollectionAsync(collectionName,
                        VectorParams.newBuilder()
                                .setSize(vectorSize)
                                .setDistance(Distance.Cosine)
                                .build()).get();
                log.info("Created collection: {}", collectionName);
            } else {
                log.info("Collection already exists: {}", collectionName);
            }
        } catch (InterruptedException | ExecutionException e) {
            log.error("Failed to initialize collection", e);
            throw new RuntimeException(e);
        }
    }

    /** Insert or overwrite chunks; each chunk ID must be a valid UUID string. */
    public void upsert(List<DocumentChunk> chunks) {
        try {
            List<PointStruct> points = new ArrayList<>();
            for (DocumentChunk chunk : chunks) {
                points.add(PointStruct.newBuilder()
                        .setId(id(UUID.fromString(chunk.getId())))
                        .setVectors(vectors(chunk.getEmbedding()))
                        .putAllPayload(convertMetadataToPayload(chunk.getMetadata()))
                        .putPayload("content", value(chunk.getContent()))
                        .build());
            }
            client.upsertAsync(collectionName, points).get();
            log.info("Upserted {} chunks", points.size());
        } catch (InterruptedException | ExecutionException e) {
            log.error("Upsert failed", e);
            throw new RuntimeException(e);
        }
    }

    /** Return the topK most similar chunks together with their payloads. */
    public List<SearchResult> search(List<Float> queryVector, int topK) {
        try {
            List<ScoredPoint> points = client.searchAsync(SearchPoints.newBuilder()
                    .setCollectionName(collectionName)
                    .addAllVector(queryVector)
                    .setLimit(topK)
                    .setWithPayload(enable(true))
                    .build()).get();

            return points.stream().map(scoredPoint -> {
                Map<String, Value> payload = scoredPoint.getPayloadMap();
                SearchResult result = new SearchResult();
                result.setScore(scoredPoint.getScore());
                result.setContent(payload.containsKey("content") ? payload.get("content").getStringValue() : "");

                Map<String, Object> metadata = new HashMap<>();
                payload.forEach((key, v) -> {
                    if (!"content".equals(key)) {
                        metadata.put(key, extractValue(v));
                    }
                });
                result.setMetadata(metadata);
                return result;
            }).collect(Collectors.toList());
        } catch (InterruptedException | ExecutionException e) {
            log.error("Search failed", e);
            throw new RuntimeException(e);
        }
    }

    /** Number of points (chunks) currently stored in the collection. */
    public long getCollectionCount() {
        try {
            return client.getCollectionInfoAsync(collectionName).get().getPointsCount();
        } catch (InterruptedException | ExecutionException e) {
            log.error("Failed to get point count", e);
            return 0;
        }
    }

    /** Payload values use JsonWithInt.Value (not Points.Value); unknown types fall back to strings. */
    private Map<String, Value> convertMetadataToPayload(Map<String, Object> metadata) {
        Map<String, Value> payload = new HashMap<>();
        if (metadata == null) {
            return payload;
        }
        for (var entry : metadata.entrySet()) {
            Object v = entry.getValue();
            if (v instanceof String s) {
                payload.put(entry.getKey(), value(s));
            } else if (v instanceof Integer || v instanceof Long) {
                // Long must be handled too, otherwise the long "timestamp" metadata is silently dropped
                payload.put(entry.getKey(), value(((Number) v).longValue()));
            } else if (v instanceof Float || v instanceof Double) {
                payload.put(entry.getKey(), value(((Number) v).doubleValue()));
            } else if (v instanceof Boolean b) {
                payload.put(entry.getKey(), value(b));
            } else if (v != null) {
                payload.put(entry.getKey(), value(v.toString()));
            }
        }
        return payload;
    }

    private Object extractValue(Value value) {
        switch (value.getKindCase()) {
            case STRING_VALUE:
                return value.getStringValue();
            case INTEGER_VALUE:
                return value.getIntegerValue();
            case DOUBLE_VALUE:
                return value.getDoubleValue();
            case BOOL_VALUE:
                return value.getBoolValue();
            default:
                return null;
        }
    }
}
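Conceptually, `search` asks Qdrant to score every stored vector against the query with cosine similarity, cos(a, b) = a·b / (|a| |b|), and to keep the best `topK`. The brute-force version below makes that explicit for a handful of in-memory vectors. Qdrant answers the same question approximately with an HNSW index, which lets it scale to millions of points. `InMemoryCosineSearch` is an illustrative name, and its O(n) scan suits only tests or tiny corpora.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Brute-force cosine top-K search over in-memory vectors: a conceptual stand-in for Qdrant. */
final class InMemoryCosineSearch {
    /** Index of the stored vector and its cosine similarity to the query. */
    record Hit(int index, double score) {}

    private InMemoryCosineSearch() {}

    static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        // Treat zero vectors as having no similarity instead of dividing by zero
        return (normA == 0 || normB == 0) ? 0 : dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Scores every stored vector against the query and returns the topK best, highest first. */
    static List<Hit> search(List<float[]> vectors, float[] query, int topK) {
        List<Hit> hits = new ArrayList<>();
        for (int i = 0; i < vectors.size(); i++) {
            hits.add(new Hit(i, cosine(vectors.get(i), query)));
        }
        hits.sort(Comparator.comparingDouble(Hit::score).reversed());
        return hits.subList(0, Math.min(topK, hits.size()));
    }
}
```

For stored vectors `[1,0]`, `[0,1]` and `[1,1]` with query `[1,0]`, the scores are 1, 0 and about 0.707. The top two results are therefore indices 0 and 2.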

3.7 LLM Client

java

package com.rag.llm;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.rag.model.SearchResult;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;

import java.io.IOException;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

@Slf4j
public class LlmClient {
    private final OkHttpClient httpClient;
    private final ObjectMapper objectMapper;
    private final String llmUrl;
    private final String apiKey;
    private final String model;

    public LlmClient(String llmUrl, String apiKey, String model) {
        this.httpClient = new OkHttpClient.Builder()
                .connectTimeout(60, TimeUnit.SECONDS)
                .readTimeout(120, TimeUnit.SECONDS)
                .build();
        this.objectMapper = new ObjectMapper();
        this.llmUrl = llmUrl;
        this.apiKey = apiKey;
        this.model = model;
    }

    /**
     * Generate an answer grounded in the retrieved contexts.
     */
    public String generateAnswer(String question, List<String> contexts, List<SearchResult> sources) throws IOException {
        // Number the retrieved chunks so the model can cite them
        StringBuilder contextBuilder = new StringBuilder();
        for (int i = 0; i < contexts.size(); i++) {
            contextBuilder.append("[Reference ").append(i + 1).append("]\n");
            contextBuilder.append(contexts.get(i)).append("\n\n");
        }

        String systemPrompt = "You are a professional assistant. Answer the user's question based on the provided references. "
                + "If the references do not contain the answer, state clearly: 'This cannot be answered from the available material.' "
                + "Cite the specific references you used.";

        String userPrompt = String.format("""
            References:
            %s

            Question: %s

            Please answer the question based on the references.
            """, contextBuilder, question);

        // Build an OpenAI-compatible chat completions request
        ObjectNode requestBody = objectMapper.createObjectNode();
        requestBody.put("model", model);
        requestBody.put("temperature", 0.3);
        requestBody.put("max_tokens", 1024);

        ArrayNode messages = requestBody.putArray("messages");
        messages.addObject().put("role", "system").put("content", systemPrompt);
        messages.addObject().put("role", "user").put("content", userPrompt);

        Request request = new Request.Builder()
                .url(llmUrl)
                .addHeader("Authorization", "Bearer " + apiKey)
                // The RequestBody media type already sets the Content-Type header
                .post(RequestBody.create(objectMapper.writeValueAsString(requestBody),
                        MediaType.get("application/json; charset=utf-8")))
                .build();

        try (Response response = httpClient.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                log.error("LLM API call failed: HTTP {}", response.code());
                return "The AI service is temporarily unavailable. Please try again later.";
            }

            JsonNode jsonNode = objectMapper.readTree(response.body().string());
            String answer = jsonNode.path("choices").path(0).path("message").path("content").asText();

            // Append each source once, in relevance order (several chunks often share one source)
            Set<String> sourceNames = new LinkedHashSet<>();
            for (SearchResult source : sources) {
                sourceNames.add(String.valueOf(source.getMetadata().getOrDefault("source", "Unknown source")));
            }
            StringBuilder finalAnswer = new StringBuilder(answer);
            finalAnswer.append("\n\n---\n**Sources:**\n");
            for (String name : sourceNames) {
                finalAnswer.append("- ").append(name).append("\n");
            }
            return finalAnswer.toString();
        }
    }
}
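The two prompt-side steps are easier to test when you pull them out of the HTTP code. The first numbers the retrieved chunks so that the model can cite them. The second keeps each source name once, in relevance order, using a `LinkedHashSet`; without that, `topK` chunks from the same file would list the same source several times. The sketch below isolates both steps in a hypothetical `PromptParts` class. The metadata key `source` matches what `RagService` stores.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Sketch: builds the numbered context block and a de-duplicated source list for the prompt. */
final class PromptParts {
    private PromptParts() {}

    /** "[Reference 1]\n<chunk>\n\n[Reference 2]\n<chunk>\n\n..." */
    static String context(List<String> chunks) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < chunks.size(); i++) {
            sb.append("[Reference ").append(i + 1).append("]\n").append(chunks.get(i)).append("\n\n");
        }
        return sb.toString();
    }

    /** Keeps each source name once, in first-seen (that is, relevance) order. */
    static Set<String> distinctSources(List<? extends Map<String, ?>> metadataList) {
        Set<String> sources = new LinkedHashSet<>();
        for (Map<String, ?> metadata : metadataList) {
            Object source = metadata.get("source");
            sources.add(source == null ? "Unknown source" : source.toString());
        }
        return sources;
    }
}
```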

3.8 RAG Service

java

package com.rag.service;

import com.rag.config.RagConfig;
import com.rag.embedding.EmbeddingClient;
import com.rag.llm.LlmClient;
import com.rag.loader.DocumentLoader;
import com.rag.model.DocumentChunk;
import com.rag.model.SearchResult;
import com.rag.splitter.TextSplitter;
import com.rag.vectorstore.QdrantVectorStore;
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

@Slf4j
public class RagService {
    private final RagConfig config;
    private final DocumentLoader documentLoader;
    private final TextSplitter textSplitter;
    private final EmbeddingClient embeddingClient;
    private final QdrantVectorStore vectorStore;
    private final LlmClient llmClient;

    public RagService(RagConfig config) {
        this.config = config;
        this.documentLoader = new DocumentLoader();
        this.textSplitter = new TextSplitter(config.getChunkSize(), config.getChunkOverlap());
        this.embeddingClient = new EmbeddingClient(config.getEmbeddingUrl(), config.getEmbeddingModel());
        // The vector size must match the embedding model (bge-m3:latest produces 1024-dimensional vectors)
        this.vectorStore = new QdrantVectorStore(
            config.getQdrantHost(),
            config.getQdrantPort(),
            config.getCollectionName(),
            1024
        );
        this.llmClient = new LlmClient(config.getLlmUrl(), config.getLlmApiKey(), config.getLlmModel());
    }

    /** Exposes the vector store, e.g. so Application can print how many chunks are stored. */
    public QdrantVectorStore getVectorStore() {
        return vectorStore;
    }

    /**
     * Build the knowledge base: load document → split → embed → store in the vector database.
     */
    public void buildKnowledgeBase(String documentPath, String sourceName) throws IOException {
        log.info("Building knowledge base from: {}", documentPath);

        // 1. Load the document
        String fullText = documentLoader.loadDocument(documentPath);
        log.info("Document loaded, {} characters", fullText.length());

        // 2. Split into chunks
        List<String> chunks = textSplitter.splitText(fullText);
        log.info("Document split into {} chunks", chunks.size());
        if (chunks.isEmpty()) {
            log.warn("No text extracted from {}; nothing to store", documentPath);
            return;
        }

        // 3. Embed the chunks (EmbeddingClient sends one request per chunk)
        List<List<Float>> embeddings = embeddingClient.getEmbeddings(chunks);
        log.info("Embedding finished");

        // 4. Assemble DocumentChunk objects with metadata
        List<DocumentChunk> documentChunks = new ArrayList<>();
        for (int i = 0; i < chunks.size(); i++) {
            Map<String, Object> metadata = new HashMap<>();
            metadata.put("source", sourceName);
            metadata.put("chunk_index", i);
            metadata.put("timestamp", System.currentTimeMillis());

            DocumentChunk chunk = new DocumentChunk();
            // Random IDs: building the same document twice stores its chunks twice
            chunk.setId(UUID.randomUUID().toString());
            chunk.setContent(chunks.get(i));
            chunk.setMetadata(metadata);
            chunk.setEmbedding(embeddings.get(i));

            documentChunks.add(chunk);
        }

        // 5. Store in the vector database
        vectorStore.upsert(documentChunks);
        log.info("Knowledge base built; the collection now holds {} chunks", vectorStore.getCollectionCount());
    }

    /**
     * Q&A: receive question → embed → retrieve → generate the answer with the LLM.
     */
    public String ask(String question) throws IOException {
        log.info("User question: {}", question);

        // 1. Embed the question
        List<Float> queryVector = embeddingClient.getEmbedding(question);

        // 2. Vector search
        List<SearchResult> searchResults = vectorStore.search(queryVector, config.getTopK());
        log.info("Retrieved {} relevant chunks", searchResults.size());

        if (searchResults.isEmpty()) {
            return "No relevant content found. Please try a different question.";
        }

        // 3. Extract the context texts
        List<String> contexts = searchResults.stream()
                .map(SearchResult::getContent)
                .collect(Collectors.toList());

        // 4. Call the LLM to generate the answer
        String answer = llmClient.generateAnswer(question, contexts, searchResults);
        log.info("Answer generated, {} characters", answer.length());

        return answer;
    }

    /**
     * Build the knowledge base from several documents (key = source name, value = file path).
     */
    public void buildKnowledgeBase(Map<String, String> documents) throws IOException {
        for (var entry : documents.entrySet()) {
            buildKnowledgeBase(entry.getValue(), entry.getKey());
        }
    }
}
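`buildKnowledgeBase` gives every chunk a `UUID.randomUUID()`. Running it twice on the same file therefore stores every chunk twice, and duplicate passages can then crowd other results out of the top K. If a source name plus a chunk index identifies a chunk well enough, one option is a name-based UUID, so that `upsert` overwrites existing points instead of adding new ones. `ChunkIds` is a hypothetical helper. If a document shrinks, chunks with higher indices from the old version remain, and you would need to delete them separately (for example with a filter on `source`).

```java
import java.nio.charset.StandardCharsets;
import java.util.UUID;

/** Hypothetical helper: stable point IDs derived from (source name, chunk index). */
final class ChunkIds {
    private ChunkIds() {}

    static String idFor(String sourceName, int chunkIndex) {
        // nameUUIDFromBytes builds a name-based (version 3, MD5) UUID: same input, same UUID
        byte[] name = (sourceName + "#" + chunkIndex).getBytes(StandardCharsets.UTF_8);
        return UUID.nameUUIDFromBytes(name).toString();
    }
}
```

To use it, replace `chunk.setId(UUID.randomUUID().toString())` with `chunk.setId(ChunkIds.idFor(sourceName, i))`. The result is a valid UUID string, which is what `QdrantVectorStore.upsert` expects.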

3.9 Entry Point and Testing

java

package com.rag;

import com.rag.config.RagConfig;
import com.rag.service.RagService;
import lombok.extern.slf4j.Slf4j;

import java.util.Scanner;

@Slf4j
public class Application {
    public static void main(String[] args) {
        try {
            // Initialize the configuration
            RagConfig config = new RagConfig();
            // When using DeepSeek, set your API key
            // config.setLlmApiKey("sk-your-deepseek-api-key");

            // When using a local Ollama LLM, point to Ollama's OpenAI-compatible endpoint
            // (LlmClient sends OpenAI-style chat requests, which /api/generate does not accept)
            // config.setLlmUrl("http://localhost:11434/v1/chat/completions");
            // config.setLlmModel("qwen2:7b");

            RagService ragService = new RagService(config);

            // Show how many chunks the knowledge base already holds
            log.info("Chunks currently in the knowledge base: {}", ragService.getVectorStore().getCollectionCount());

            // Uncomment to (re)build the knowledge base
            // log.info("Building knowledge base...");
            // ragService.buildKnowledgeBase("./docs/product_manual.pdf", "product_manual.pdf");

            // Interactive Q&A
            Scanner scanner = new Scanner(System.in);
            System.out.println("RAG knowledge base Q&A started (type exit to quit)");
            System.out.println("=".repeat(50));

            while (true) {
                System.out.print("\nQuestion: ");
                if (!scanner.hasNextLine()) {
                    break; // stdin was closed (e.g. Ctrl+D)
                }
                String question = scanner.nextLine().trim();

                if ("exit".equalsIgnoreCase(question)) {
                    System.out.println("Goodbye!");
                    break;
                }

                if (question.isEmpty()) {
                    continue;
                }

                try {
                    long startTime = System.currentTimeMillis();
                    String answer = ragService.ask(question);
                    long endTime = System.currentTimeMillis();

                    System.out.println("\nAnswer:");
                    System.out.println(answer);
                    System.out.println("\n[Elapsed: " + (endTime - startTime) + "ms]");
                } catch (Exception e) {
                    log.error("Q&A failed", e);
                    System.out.println("Q&A failed: " + e.getMessage());
                }
            }

            scanner.close();

        } catch (Exception e) {
            log.error("Application failed to start", e);
        }
    }
}

4. Usage Guide

4.1 Environment Setup

bash

# 1. Start Qdrant (REST on 6333, gRPC on 6334)
docker run -d --name qdrant -p 6333:6333 -p 6334:6334 qdrant/qdrant:latest

# 2. Start Ollama and pull the models
ollama pull bge-m3:latest   # embedding model
ollama pull qwen2:7b        # optional: local LLM

# 3. Compile the Java project
mvn clean compile

# 4. Run
mvn exec:java -Dexec.mainClass="com.rag.Application"

4.2 Exposing a REST API (Spring Boot)

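If you want to serve the system over HTTP, a thin Spring Boot controller is enough. This requires spring-boot-starter-web on the classpath. RagService carries no Spring annotations, so you also have to register it as a bean yourself (for example with a @Bean method in a @Configuration class):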

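java

import com.rag.service.RagService;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

@RestController
@RequestMapping("/api/rag")
public class RagController {

    private final RagService ragService;

    // Constructor injection; RagService must be available as a Spring bean
    public RagController(RagService ragService) {
        this.ragService = ragService;
    }

    @PostMapping("/ask")
    public ResponseEntity<AskResponse> ask(@RequestBody AskRequest request) {
        try {
            String answer = ragService.ask(request.getQuestion());
            return ResponseEntity.ok(new AskResponse(answer));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(new AskResponse("Error: " + e.getMessage()));
        }
    }

    // Reads a file path chosen by the caller; in production, restrict it to an upload directory
    @PostMapping("/build")
    public ResponseEntity<String> build(@RequestBody BuildRequest request) {
        try {
            ragService.buildKnowledgeBase(request.getFilePath(), request.getSourceName());
            return ResponseEntity.ok("Knowledge base built successfully");
        } catch (Exception e) {
            return ResponseEntity.status(500).body("Build failed: " + e.getMessage());
        }
    }
}

@Data
class AskRequest {
    private String question;
}

@Data
class BuildRequest {
    private String filePath;
    private String sourceName;
}

@Data
@NoArgsConstructor
@AllArgsConstructor
class AskResponse {
    private String answer;
}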

5. Common Issues and Optimizations

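5.1 Vector Dimension Mismatch

java

// bge-m3 outputs 1024-dimensional vectors
// OpenAI text-embedding-3-small outputs 1536
// The collection size must match the model: Qdrant rejects vectors of the wrong dimension,
// and an existing collection is not resized when you switch models
int vectorSize = 1024;  // bge-m3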
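Instead of discovering a mismatch through an upsert error, you can fail fast by comparing each embedding's length with the size the collection was created with. The check below is a hypothetical sketch. A related option is to derive the size once at startup from the model itself, e.g. `embeddingClient.getEmbedding("probe").size()`, rather than hard-coding 1024.

```java
import java.util.List;

/** Hypothetical fail-fast check: embedding length must equal the collection's vector size. */
final class DimensionGuard {
    private DimensionGuard() {}

    static void check(List<Float> embedding, int expectedSize) {
        if (embedding.size() != expectedSize) {
            throw new IllegalStateException("Embedding has " + embedding.size()
                    + " dimensions but the collection expects " + expectedSize
                    + "; recreate the collection or switch back to the original model.");
        }
    }
}
```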

5.2 Performance Tips

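| Problem | Suggested fix |
|---|---|
| Slow embedding | Batch the texts: EmbeddingClient sends one request per chunk, while Ollama's /api/embed endpoint accepts a list of inputs in one call |
| High retrieval latency | Give Qdrant enough memory; it already builds an HNSW index by default, so tune the HNSW parameters (m, ef_construct) or enable quantization |
| OOM on large documents | Stream the input and process chunks in batches instead of holding the whole document and all its vectors in memory |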
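Two of these fixes come down to working in bounded batches. You send N texts per embedding request (for endpoints that accept a list) and upsert N points per call, instead of everything at once or one item at a time. A generic partition helper covers both cases. The batch size is a tuning assumption, and `Batches` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: splits a list into consecutive batches of at most batchSize elements. */
final class Batches {
    private Batches() {}

    static <T> List<List<T>> of(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int from = 0; from < items.size(); from += batchSize) {
            int to = Math.min(from + batchSize, items.size());
            // Copy the view so each batch stays valid even if the source list changes later
            batches.add(new ArrayList<>(items.subList(from, to)));
        }
        return batches;
    }
}
```

In `buildKnowledgeBase`, for example, you could loop over `Batches.of(documentChunks, 64)` and call `vectorStore.upsert(batch)` once per batch.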

5.3 Production Configuration

yaml

# application-prod.yml
rag:
  qdrant:
    host: ${QDRANT_HOST:localhost}
    port: ${QDRANT_PORT:6334}          # gRPC port used by the Java client
    collection: ${QDRANT_COLLECTION:knowledge_base}
  embedding:
    url: ${EMBEDDING_URL:http://ollama:11434/api/embeddings}
    model: ${EMBEDDING_MODEL:bge-m3:latest}
  llm:
    url: ${LLM_URL:https://api.deepseek.com/v1/chat/completions}
    api-key: ${DEEPSEEK_API_KEY}
    model: deepseek-chat
  retrieval:
    top-k: 5
    score-threshold: 0.7               # not applied by the code in section 3; must be wired into the search
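Nothing in the section 3 code reads the `score-threshold` setting above. Applied on the client side, it amounts to dropping results below the threshold before they reach the prompt. The sketch below shows that, with a hypothetical `Scored` record standing in for `SearchResult`. Qdrant can also filter on the server side through the `score_threshold` field of `SearchPoints` (`setScoreThreshold(...)` on the Java builder). Treat 0.7 as a starting point only: suitable thresholds depend on the embedding model and should be checked against real queries.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Sketch: drops retrieved chunks whose similarity score is below a threshold. */
final class ScoreFilter {
    /** Hypothetical stand-in for SearchResult: chunk text plus cosine score. */
    record Scored(String content, float score) {}

    private ScoreFilter() {}

    static List<Scored> apply(List<Scored> results, float threshold) {
        return results.stream()
                .filter(r -> r.score() >= threshold)   // keep scores at or above the threshold
                .collect(Collectors.toList());
    }
}
```

If every result is filtered out, `RagService.ask` should take its existing "no relevant content" branch rather than call the LLM with an empty context.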

Final Thoughts

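This article presented a complete Java implementation of a RAG knowledge base, covering:

  1. Document loading: PDF, Word (.docx) and TXT

  2. Text chunking: configurable chunk size and overlap

  3. Embedding: Ollama + bge-m3

  4. Vector storage: Qdrant (a high-performance vector database)

  5. Retrieval + generation: the complete RAG flow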

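The whole setup runs on an ordinary server without a GPU when the LLM is called through the DeepSeek API. bge-m3 runs on CPU, only more slowly, while a local 7B model such as qwen2:7b will be noticeably slow without a GPU. It suits internal enterprise knowledge bases, customer-support bots and similar scenarios.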
