实现知识问题系统

前言

以下是在 Java 中基于已部署的大模型(IP: 192.193.1.133)和 Elasticsearch 集群(IP: 192.193.1.132)实现 ** 知识问答系统** 的完整源码。包含:文档向量化存储、混合检索、大模型调用、多轮对话管理。

一、技术栈与依赖(Maven pom.xml 核心)

xml 复制代码
<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
        <version>2.7.18</version>
    </dependency>
    <!-- Elasticsearch Rest Client -->
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>elasticsearch-rest-high-level-client</artifactId>
        <version>7.17.22</version>
    </dependency>
    <!-- OkHttp (used to call the LLM HTTP API) -->
    <dependency>
        <groupId>com.squareup.okhttp3</groupId>
        <artifactId>okhttp</artifactId>
        <version>4.12.0</version>
    </dependency>
    <!-- Gson -->
    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.10.1</version>
    </dependency>
    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.30</version>
        <scope>provided</scope>
    </dependency>
</dependencies>

二、配置类(连接 ES 和 DeepSeek)

java 复制代码
package com.rag.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
 * Type-safe binding of the {@code rag.*} configuration namespace.
 *
 * <p>This class is deliberately NOT annotated with {@code @Configuration} or
 * {@code @Component}: it is registered as a bean by
 * {@code @EnableConfigurationProperties(RagProperties.class)} on the application class.
 * Registering it a second time through component scanning would create two beans of the
 * same type and make by-type constructor injection ambiguous.
 */
@Data
@ConfigurationProperties(prefix = "rag")
public class RagProperties {
    /** Settings under {@code rag.elasticsearch}. */
    private ElasticsearchConfig elasticsearch = new ElasticsearchConfig();
    /** Settings under {@code rag.deepseek}. */
    private DeepSeekConfig deepseek = new DeepSeekConfig();

    /** Connection and index settings for the Elasticsearch cluster. */
    @Data
    public static class ElasticsearchConfig {
        /** Elasticsearch address in {@code host:port} form. */
        private String hosts = "192.193.1.132:9200";
        /** Name of the index that stores the knowledge chunks. */
        private String indexName = "knowledge_base";
        /** Dimension used for the {@code dense_vector} field of the index mapping. */
        private int vectorDimension = 768;
    }

    /** Endpoints and generation parameters for the deployed model server. */
    @Data
    public static class DeepSeekConfig {
        /** Chat-completions endpoint. */
        private String url = "http://192.193.1.133:8000/v1/chat/completions";
        /** Embeddings endpoint; assumes the model server exposes one. */
        private String embeddingUrl = "http://192.193.1.133:8000/v1/embeddings";
        /** Bearer token sent with every request; only meaningful if the server requires one. */
        private String apiKey = "your-api-key";
        /** Value sent as {@code max_tokens} in chat requests. */
        private int maxTokens = 1024;
        /** Value sent as {@code temperature} in chat requests. */
        private double temperature = 0.1;
    }
}
yaml 复制代码
# application.yml
# Values below are bound to RagProperties through the "rag" prefix.
rag:
  elasticsearch:
    # Elasticsearch address in host:port form
    hosts: 192.193.1.132:9200
    index-name: knowledge_base
    # Used as "dims" of the dense_vector field when the index mapping is created
    vector-dimension: 768
  deepseek:
    url: http://192.193.1.133:8000/v1/chat/completions
    embedding-url: http://192.193.1.133:8000/v1/embeddings
    # Placeholder only; supply the real key through an environment variable or a config center
    api-key: sk-xxx
    max-tokens: 1024
    temperature: 0.1
java 复制代码
package com.rag.config;

import org.apache.http.HttpHost;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Creates the Elasticsearch {@link RestHighLevelClient} from {@code rag.elasticsearch.hosts}.
 *
 * <p>The property accepts one or more comma-separated entries. Each entry is
 * {@code host:port}, optionally prefixed with a scheme ({@code http://} or {@code https://}).
 * An entry without a port falls back to {@link #DEFAULT_PORT}.
 */
@Configuration
public class ElasticsearchConfig {

    /** Port used when an entry of {@code rag.elasticsearch.hosts} does not specify one. */
    private static final int DEFAULT_PORT = 9200;

    /**
     * Builds the shared high-level REST client.
     *
     * @param props bound {@code rag.*} configuration
     * @return a client connected to every configured node
     * @throws IllegalStateException if the hosts property is missing or contains an empty entry
     * @throws IllegalArgumentException if an entry cannot be parsed as a host address
     */
    @Bean
    public RestHighLevelClient restHighLevelClient(RagProperties props) {
        String configured = props.getElasticsearch().getHosts();
        if (configured == null || configured.isBlank()) {
            throw new IllegalStateException("rag.elasticsearch.hosts must not be empty");
        }

        String[] entries = configured.split(",");
        HttpHost[] nodes = new HttpHost[entries.length];
        for (int i = 0; i < entries.length; i++) {
            nodes[i] = toHttpHost(entries[i].trim());
        }
        return new RestHighLevelClient(RestClient.builder(nodes));
    }

    /** Parses a single {@code [scheme://]host[:port]} entry, applying the default port if absent. */
    private static HttpHost toHttpHost(String entry) {
        if (entry.isEmpty()) {
            throw new IllegalStateException("rag.elasticsearch.hosts contains an empty entry");
        }
        HttpHost parsed = HttpHost.create(entry);
        if (parsed.getPort() < 0) {
            // HttpHost reports -1 when no port was given; use the Elasticsearch HTTP port
            // rather than the scheme default.
            return new HttpHost(parsed.getHostName(), DEFAULT_PORT, parsed.getSchemeName());
        }
        return parsed;
    }
}

三、向量化 Embedding 服务(调用 DeepSeek 提供的 embedding 接口)

若你的 DeepSeek 部署不提供 /v1/embeddings,可以使用本地 sentence-transformers 服务(另行部署)。这里假设已提供。

java 复制代码
package com.rag.service;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.rag.config.RagProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Converts text into embedding vectors by calling the configured embeddings endpoint
 * ({@code rag.deepseek.embedding-url}). The request and response follow the
 * {@code {"model": ..., "input": [...]}} / {@code {"data": [{"embedding": [...]}]}} shape.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class EmbeddingService {

    private static final MediaType JSON = MediaType.parse("application/json");

    private final RagProperties props;
    private final OkHttpClient httpClient = new OkHttpClient.Builder()
            .connectTimeout(30, java.util.concurrent.TimeUnit.SECONDS)
            .readTimeout(60, java.util.concurrent.TimeUnit.SECONDS)
            .build();
    private final Gson gson = new Gson();

    /**
     * Embeds a single text.
     *
     * @param text text to embed; must not be {@code null}
     * @return the embedding returned by the server for {@code text}
     * @throws IllegalArgumentException if {@code text} is {@code null}
     * @throws RuntimeException if the HTTP call fails, the server answers with a non-2xx
     *         status, or the response does not contain an embedding
     */
    public float[] embed(String text) {
        if (text == null) {
            throw new IllegalArgumentException("text to embed must not be null");
        }

        JsonObject body = new JsonObject();
        body.addProperty("model", "text-embedding-ada-002"); // adjust to the deployed model name
        body.add("input", gson.toJsonTree(List.of(text)));

        Request request = new Request.Builder()
                .url(props.getDeepseek().getEmbeddingUrl())
                .addHeader("Content-Type", "application/json")
                .addHeader("Authorization", "Bearer " + props.getDeepseek().getApiKey())
                .post(RequestBody.create(gson.toJson(body), JSON))
                .build();

        try (Response response = httpClient.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                throw new RuntimeException("Embedding failed: " + response);
            }
            float[] vector = parseEmbedding(response.body().string());
            warnOnDimensionMismatch(vector.length);
            return vector;
        } catch (IOException e) {
            log.error("Embedding error", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Embeds several texts. Each text is sent in its own request, in list order.
     *
     * @param texts texts to embed
     * @return one vector per input text, in the same order
     */
    public List<float[]> embedBatch(List<String> texts) {
        List<float[]> vectors = new ArrayList<>(texts.size());
        for (String text : texts) {
            vectors.add(embed(text));
        }
        return vectors;
    }

    /** Extracts {@code data[0].embedding} from the response, failing with a descriptive message. */
    private float[] parseEmbedding(String payload) {
        JsonObject root = gson.fromJson(payload, JsonObject.class);
        if (root == null || !root.has("data") || !root.get("data").isJsonArray()) {
            throw new IllegalStateException("Embedding response has no 'data' array");
        }
        JsonArray data = root.getAsJsonArray("data");
        if (data.size() == 0 || !data.get(0).isJsonObject()) {
            throw new IllegalStateException("Embedding response 'data' array is empty or malformed");
        }
        JsonObject first = data.get(0).getAsJsonObject();
        if (!first.has("embedding") || !first.get("embedding").isJsonArray()) {
            throw new IllegalStateException("Embedding response entry has no 'embedding' array");
        }

        JsonArray values = first.getAsJsonArray("embedding");
        float[] vector = new float[values.size()];
        for (int i = 0; i < vector.length; i++) {
            vector[i] = values.get(i).getAsFloat();
        }
        return vector;
    }

    /** Logs a warning when the model output size differs from the configured index dimension. */
    private void warnOnDimensionMismatch(int actual) {
        int expected = props.getElasticsearch().getVectorDimension();
        if (actual != expected) {
            log.warn("Embedding dimension {} differs from configured rag.elasticsearch.vector-dimension {}",
                    actual, expected);
        }
    }
}

四、ES 向量存储操作(创建索引、存文档、向量检索)

4.1 定义文档实体

java 复制代码
package com.rag.entity;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * One chunk of a source document as stored in (and read back from) the knowledge index.
 *
 * <p>{@code List} and {@code Map} are written fully qualified because this file's import
 * list does not include {@code java.util}; without that the class does not compile.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeDocument {
    /** Unique id of this chunk. */
    private String chunkId;
    /** Id of the original file the chunk was cut from. */
    private String docId;
    /** Text content of the chunk. */
    private String content;
    /** Tags attached to the chunk. */
    private java.util.List<String> tags;
    /** Free-form metadata attached to the chunk. */
    private java.util.Map<String, Object> metadata;
    /** Embedding vector of {@link #content}. */
    private float[] vector;
}

4.2 ES 索引 Mapping 创建

java 复制代码
package com.rag.repository;

import com.rag.config.RagProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.common.xcontent.XContentType;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;

@Slf4j
@Component
@RequiredArgsConstructor
public class EsIndexInitializer {

    private final RestHighLevelClient esClient;
    private final RagProperties props;

    @PostConstruct
    public void initIndex() throws Exception {
        String indexName = props.getElasticsearch().getIndexName();
        String mapping = """
                {
                  "mappings": {
                    "properties": {
                      "chunkId": { "type": "keyword" },
                      "docId": { "type": "keyword" },
                      "content": { "type": "text", "analyzer": "ik_max_word" },
                      "tags": { "type": "keyword" },
                      "metadata": { "type": "object", "enabled": true },
                      "vector": {
                        "type": "dense_vector",
                        "dims": %d,
                        "index": true,
                        "similarity": "cosine"
                      }
                    }
                  }
                }
                """.formatted(props.getElasticsearch().getVectorDimension());

        CreateIndexRequest request = new CreateIndexRequest(indexName);
        request.source(mapping, XContentType.JSON);
        try {
            esClient.indices().create(request, RequestOptions.DEFAULT);
            log.info("Index {} created", indexName);
        } catch (Exception e) {
            log.warn("Index may already exist: {}", e.getMessage());
        }
    }
}

4.3 向量存储服务

java 复制代码
package com.rag.repository;

import com.rag.entity.KnowledgeDocument;
import com.rag.service.EmbeddingService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.common.xcontent.XContentType;
import org.springframework.stereotype.Repository;

import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

@Slf4j
@Repository
@RequiredArgsConstructor
public class VectorStoreRepository {

    private final RestHighLevelClient esClient;
    private final EmbeddingService embeddingService;
    private final RagProperties props;

    /**
     * 将解析后的文本分块向量化并存储到 ES
     * @param docId 原始文档ID
     * @param chunkText 分块文本
     * @param tags 标签列表
     * @param metadata 元数据
     */
    public String storeChunk(String docId, String chunkText, List<String> tags, Map<String, Object> metadata) {
        float[] vector = embeddingService.embed(chunkText);
        String chunkId = UUID.randomUUID().toString();

        Map<String, Object> source = new HashMap<>();
        source.put("chunkId", chunkId);
        source.put("docId", docId);
        source.put("content", chunkText);
        source.put("tags", tags);
        source.put("metadata", metadata);
        source.put("vector", vector);

        IndexRequest request = new IndexRequest(props.getElasticsearch().getIndexName())
                .id(chunkId)
                .source(source, XContentType.JSON);

        try {
            IndexResponse response = esClient.index(request, RequestOptions.DEFAULT);
            log.info("Stored chunk {} in ES", response.getId());
            return chunkId;
        } catch (Exception e) {
            log.error("Failed to store chunk", e);
            throw new RuntimeException(e);
        }
    }
}

五、混合检索服务(向量 + BM25)

java 复制代码
package com.rag.service;

import com.rag.config.RagProperties;
import com.rag.entity.KnowledgeDocument;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
import org.springframework.stereotype.Service;

import java.util.*;

/**
 * Hybrid retrieval over the knowledge index: one vector-similarity query and one BM25
 * keyword query are executed separately and their rankings are merged with
 * Reciprocal Rank Fusion (RRF).
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class HybridSearchService {

    private final RestHighLevelClient esClient;
    private final EmbeddingService embeddingService;
    private final RagProperties props;

    /** Number of hits fetched from each individual query. */
    private static final int RECALL_K = 50;
    /** Number of documents returned after fusion. */
    private static final int FINAL_K = 5;
    /** Constant added to the rank in the RRF score {@code 1 / (constant + rank)}. */
    private static final int RRF_RANK_CONSTANT = 60;
    /** Source fields fetched for every hit; the vector itself is not needed by callers. */
    private static final String[] RETURNED_FIELDS = {"chunkId", "docId", "content", "tags", "metadata"};

    /**
     * Runs the hybrid search for a user question.
     *
     * @param query user question
     * @return up to {@link #FINAL_K} chunks ordered by fused score; an empty list when
     *         {@code query} is {@code null} or blank
     * @throws Exception if embedding the query or either Elasticsearch request fails
     */
    public List<KnowledgeDocument> search(String query) throws Exception {
        if (query == null || query.isBlank()) {
            return List.of();
        }

        float[] queryVector = embeddingService.embed(query);

        // cosineSimilarity lies in [-1, 1]; +1.0 keeps the score non-negative as Elasticsearch requires.
        String vectorScript = "cosineSimilarity(params.query_vector, 'vector') + 1.0";
        Script script = new Script(ScriptType.INLINE, "painless", vectorScript,
                Map.of("query_vector", queryVector));

        // script_score (not the boolean "script" filter query) is the query type that lets a
        // script produce the relevance score.
        SearchSourceBuilder vectorSource = new SearchSourceBuilder()
                .size(RECALL_K)
                .query(QueryBuilders.scriptScoreQuery(QueryBuilders.matchAllQuery(), script))
                .fetchSource(RETURNED_FIELDS, null);
        List<KnowledgeDocument> vectorHits = execute(vectorSource);

        // BM25 keyword query on the analyzed content field.
        SearchSourceBuilder keywordSource = new SearchSourceBuilder()
                .size(RECALL_K)
                .query(QueryBuilders.matchQuery("content", query))
                .fetchSource(RETURNED_FIELDS, null);
        List<KnowledgeDocument> keywordHits = execute(keywordSource);

        return rrfMerge(vectorHits, keywordHits, FINAL_K);
    }

    /** Executes one search against the configured index and converts the hits. */
    private List<KnowledgeDocument> execute(SearchSourceBuilder source) throws Exception {
        SearchRequest request = new SearchRequest(props.getElasticsearch().getIndexName());
        request.source(source);
        SearchResponse response = esClient.search(request, RequestOptions.DEFAULT);
        return parseResponse(response);
    }

    /** Maps every hit's {@code _source} to a {@link KnowledgeDocument}, preserving hit order. */
    @SuppressWarnings("unchecked") // _source values have the shapes written by the indexing side
    private List<KnowledgeDocument> parseResponse(SearchResponse response) {
        List<KnowledgeDocument> docs = new ArrayList<>();
        for (var hit : response.getHits().getHits()) {
            Map<String, Object> src = hit.getSourceAsMap();
            docs.add(KnowledgeDocument.builder()
                    .chunkId((String) src.get("chunkId"))
                    .docId((String) src.get("docId"))
                    .content((String) src.get("content"))
                    .tags((List<String>) src.get("tags"))
                    .metadata((Map<String, Object>) src.get("metadata"))
                    .build());
        }
        return docs;
    }

    /**
     * Merges two ranked lists with Reciprocal Rank Fusion.
     *
     * @param listA first ranking, best hit first
     * @param listB second ranking, best hit first
     * @param k     maximum number of documents to return
     * @return documents ordered by descending fused score; ties keep first-seen order
     */
    private List<KnowledgeDocument> rrfMerge(List<KnowledgeDocument> listA, List<KnowledgeDocument> listB, int k) {
        // Insertion-ordered map plus the stable stream sort makes tie ordering deterministic.
        Map<String, Double> fusedScores = new LinkedHashMap<>();
        Map<String, KnowledgeDocument> docsById = new HashMap<>();
        accumulateRrf(listA, fusedScores, docsById);
        accumulateRrf(listB, fusedScores, docsById);

        return fusedScores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                .limit(k)
                .map(entry -> docsById.get(entry.getKey()))
                .toList();
    }

    /** Adds {@code 1 / (RRF_RANK_CONSTANT + rank)} for every document of one ranking (rank starts at 1). */
    private static void accumulateRrf(List<KnowledgeDocument> ranking,
                                      Map<String, Double> fusedScores,
                                      Map<String, KnowledgeDocument> docsById) {
        int rank = 1;
        for (KnowledgeDocument doc : ranking) {
            String id = doc.getChunkId();
            fusedScores.merge(id, 1.0 / (RRF_RANK_CONSTANT + rank), Double::sum);
            docsById.putIfAbsent(id, doc);
            rank++;
        }
    }
}

六、大模型调用服务(DeepSeek)

java 复制代码
package com.rag.service;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.rag.config.RagProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Thin client for the chat-completions endpoint configured under {@code rag.deepseek.url}.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class DeepSeekService {

    private static final MediaType JSON = MediaType.parse("application/json");

    private final RagProperties props;
    private final OkHttpClient httpClient = new OkHttpClient.Builder()
            .connectTimeout(60, TimeUnit.SECONDS)
            .readTimeout(120, TimeUnit.SECONDS)
            .build();
    private final Gson gson = new Gson();

    /**
     * Sends the conversation to the model and returns the complete reply.
     * The request is always sent with {@code stream=false}.
     *
     * @param messages conversation messages in order; must not be {@code null} or empty
     * @return the text content of the first choice
     * @throws IllegalArgumentException if {@code messages} is {@code null} or empty
     * @throws RuntimeException if the HTTP call fails, the server answers with a non-2xx
     *         status, or the response carries no usable content
     */
    public String chat(List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            throw new IllegalArgumentException("messages must not be empty");
        }

        JsonObject requestBody = new JsonObject();
        requestBody.addProperty("model", "deepseek-chat"); // adjust to the deployed model name
        requestBody.add("messages", gson.toJsonTree(messages));
        requestBody.addProperty("max_tokens", props.getDeepseek().getMaxTokens());
        requestBody.addProperty("temperature", props.getDeepseek().getTemperature());
        requestBody.addProperty("stream", false);

        Request request = new Request.Builder()
                .url(props.getDeepseek().getUrl())
                .addHeader("Content-Type", "application/json")
                .addHeader("Authorization", "Bearer " + props.getDeepseek().getApiKey())
                .post(RequestBody.create(gson.toJson(requestBody), JSON))
                .build();

        try (Response response = httpClient.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                throw new RuntimeException("DeepSeek API error: " + response);
            }
            return extractContent(response.body().string());
        } catch (IOException e) {
            log.error("Call DeepSeek failed", e);
            throw new RuntimeException(e);
        }
    }

    /** Reads {@code choices[0].message.content}, failing with a descriptive message if absent. */
    private String extractContent(String payload) {
        JsonObject root = gson.fromJson(payload, JsonObject.class);
        if (root == null || !root.has("choices") || !root.get("choices").isJsonArray()) {
            throw new IllegalStateException("DeepSeek response has no 'choices' array");
        }
        JsonArray choices = root.getAsJsonArray("choices");
        if (choices.size() == 0 || !choices.get(0).isJsonObject()) {
            throw new IllegalStateException("DeepSeek response 'choices' array is empty or malformed");
        }
        JsonObject firstChoice = choices.get(0).getAsJsonObject();
        if (!firstChoice.has("message") || !firstChoice.get("message").isJsonObject()) {
            throw new IllegalStateException("DeepSeek response choice has no 'message' object");
        }
        var content = firstChoice.getAsJsonObject("message").get("content");
        if (content == null || content.isJsonNull()) {
            throw new IllegalStateException("DeepSeek response message has no 'content'");
        }
        return content.getAsString();
    }

    /**
     * One chat message. Gson serializes the two fields by name, producing
     * {@code {"role": ..., "content": ...}}.
     */
    public static class Message {
        /** One of "system", "user", "assistant". */
        private final String role;
        private final String content;

        public Message(String role, String content) {
            this.role = role;
            this.content = content;
        }

        /** @return the role of the message author */
        public String getRole() {
            return role;
        }

        /** @return the message text */
        public String getContent() {
            return content;
        }
    }
}

七、RAG 问答服务(包含多轮对话管理)

java 复制代码
package com.rag.service;

import com.rag.entity.KnowledgeDocument;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * RAG question answering with per-session conversation history.
 *
 * <p>History is kept in memory as immutable snapshots, one per session id. A turn is
 * appended with {@link ConcurrentHashMap#compute}, so concurrent requests for the same
 * session never mutate a shared list. Sessions are only removed by
 * {@link #clearSession(String)}; there is no expiry.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RagQaService {

    /** 10 conversation rounds = 20 messages (one user + one assistant message per round). */
    private static final int MAX_HISTORY_MESSAGES = 20;

    private final HybridSearchService searchService;
    private final DeepSeekService deepSeekService;

    /** Conversation history: sessionId -> immutable list of past messages. */
    private final Map<String, List<DeepSeekService.Message>> sessionStore = new ConcurrentHashMap<>();

    /** System prompt sent as the first message of every request. */
    private static final String SYSTEM_PROMPT = """
            你是一个专业的知识问答助手,基于提供的【参考资料】回答问题。
            要求:
            1. 只根据参考资料回答,不要编造知识。
            2. 如果参考资料中没有相关信息,请回复"根据已有资料无法回答该问题"。
            3. 在回答末尾标注引用来源(如文档名称)。
            4. 保持回答简洁、准确。
            """;

    /**
     * Answers a question using retrieved reference material and the session's history.
     *
     * @param sessionId    session id used to look up and store the conversation history
     * @param userQuestion the user's question
     * @return the model's answer, or a fixed fallback message if any step fails
     */
    public String ask(String sessionId, String userQuestion) {
        try {
            // Fail before any remote call; the catch below turns this into the fallback answer.
            Objects.requireNonNull(sessionId, "sessionId");
            Objects.requireNonNull(userQuestion, "userQuestion");

            // 1. Retrieve relevant chunks.
            List<KnowledgeDocument> retrieved = searchService.search(userQuestion);
            log.info("Retrieved {} documents", retrieved.size());

            // 2. Build the user prompt from the retrieved context and the question.
            String userPrompt = String.format("""
                    【参考资料】
                    %s

                    【用户问题】
                    %s

                    请根据参考资料回答。
                    """, buildContext(retrieved), userQuestion);

            // 3. Assemble system prompt + history snapshot + current prompt.
            List<DeepSeekService.Message> messages = new ArrayList<>();
            messages.add(new DeepSeekService.Message("system", SYSTEM_PROMPT));
            messages.addAll(sessionStore.getOrDefault(sessionId, List.of()));
            messages.add(new DeepSeekService.Message("user", userPrompt));

            // 4. Call the model.
            String answer = deepSeekService.chat(messages);

            // 5. Record the turn. Only the raw question is stored, not the context-laden prompt.
            sessionStore.compute(sessionId, (id, previous) -> appendTurn(previous, userQuestion, answer));

            return answer;
        } catch (Exception e) {
            log.error("RAG问答失败", e);
            return "系统繁忙,请稍后再试。";
        }
    }

    /**
     * Returns a new immutable history consisting of {@code previous} plus one question/answer
     * pair, trimmed to the most recent {@link #MAX_HISTORY_MESSAGES} messages. The result is a
     * copy, so dropped messages are not retained by a backing list.
     */
    private static List<DeepSeekService.Message> appendTurn(List<DeepSeekService.Message> previous,
                                                            String question, String answer) {
        List<DeepSeekService.Message> updated = new ArrayList<>();
        if (previous != null) {
            updated.addAll(previous);
        }
        updated.add(new DeepSeekService.Message("user", question));
        updated.add(new DeepSeekService.Message("assistant", answer));

        int overflow = updated.size() - MAX_HISTORY_MESSAGES;
        if (overflow > 0) {
            return List.copyOf(updated.subList(overflow, updated.size()));
        }
        return List.copyOf(updated);
    }

    /** Formats the retrieved chunks as a numbered list with their source file name. */
    private String buildContext(List<KnowledgeDocument> docs) {
        if (docs.isEmpty()) {
            return "(无相关参考资料)";
        }
        StringBuilder sb = new StringBuilder();
        int idx = 1;
        for (KnowledgeDocument doc : docs) {
            String source = doc.getMetadata() != null ? doc.getMetadata().getOrDefault("fileName", "未知来源").toString() : "未知";
            sb.append(String.format("%d. 来源:%s\n   %s\n\n", idx++, source, doc.getContent()));
        }
        return sb.toString();
    }

    /**
     * Removes the stored history of a session.
     *
     * @param sessionId id of the session to clear
     */
    public void clearSession(String sessionId) {
        sessionStore.remove(sessionId);
    }
}

八、Controller 对外接口

java 复制代码
package com.rag.controller;

import com.rag.service.RagQaService;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;

import java.util.Map;

@RestController
@RequestMapping("/api/rag")
@RequiredArgsConstructor
public class RagController {

    private final RagQaService ragQaService;

    @PostMapping("/ask")
    public Map<String, String> ask(@RequestBody Map<String, String> request) {
        String sessionId = request.getOrDefault("sessionId", "default");
        String question = request.get("question");
        String answer = ragQaService.ask(sessionId, question);
        return Map.of("answer", answer);
    }

    @PostMapping("/clear")
    public Map<String, String> clear(@RequestBody Map<String, String> request) {
        String sessionId = request.getOrDefault("sessionId", "default");
        ragQaService.clearSession(sessionId);
        return Map.of("status", "cleared");
    }
}

九、启动类

java 复制代码
package com.rag;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import com.rag.config.RagProperties;

/**
 * Entry point of the RAG application. Registers {@link RagProperties} for
 * {@code rag.*} property binding.
 */
@SpringBootApplication
@EnableConfigurationProperties(RagProperties.class)
public class RagApplication {

    /**
     * Boots the Spring application context.
     *
     * @param args command-line arguments passed through to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(RagApplication.class);
        application.run(args);
    }
}

十、使用示例

10.1 存储文档向量(调用 VectorStoreRepository

java 复制代码
// Usage sketch: assumes the text chunk has already been parsed out of a file.
// NOTE(review): vectorStoreRepository is presumably an injected VectorStoreRepository bean — confirm in the caller.
String docId = "file_001";
String chunk = "本节课主要讲解Java并发编程中的锁机制...";
List<String> tags = List.of("Java", "并发");
Map<String, Object> metadata = new HashMap<>();
// "fileName" is the metadata key RagQaService reads when it labels the source of a chunk.
metadata.put("fileName", "java_concurrent.pdf");
metadata.put("page", 10);

// Embeds the chunk and indexes it; the call returns the generated chunk id.
vectorStoreRepository.storeChunk(docId, chunk, tags, metadata);

10.2 问答请求

bash 复制代码
POST /api/rag/ask
{
    "sessionId": "user_123",
    "question": "Java中的ReentrantLock和synchronized有什么区别?"
}

响应:
{
    "answer": "根据参考资料... ReentrantLock 提供可中断锁、超时锁等高级特性..."
}

注意事项

  1. Embedding 接口 :上述代码假设 DeepSeek 部署同时提供了 /v1/embeddings 端点,若没有,请单独部署一个 Embedding 服务(如 BGE、text2vec)并修改 EmbeddingService 的实现。
  2. Elasticsearch 版本 :使用 7.x 或 8.x 均可,需确保 dense_vector 类型支持。
  3. 中文分词 :建议安装 ik 分词器,否则使用 standard 亦可。
  4. 性能调优:实际生产环境建议增加重排序(Cross-Encoder)和向量缓存。
  5. 安全:API Key 等敏感信息应使用环境变量或配置中心。