SpringAI之RAG

文章目录

由于训练大模型非常耗时,再加上训练语料本身比较滞后,所以大模型存在 知识限制 问题:

  • 知识数据比较落后,往往是几个月之前的
  • 不包含太过专业领域 或者企业私有的数据

为了解决这些问题,就需要用到 RAG 知识库。

1 RAG介绍

要解决大模型的知识限制问题,其实并不复杂。

解决的思路就是给大模型外挂一个知识库,可以是专业领域知识,也可以是企业私有的数据。

不过,知识库不能简单的直接拼接在提示词中。

因为通常知识库数据量都是非常大的,而大模型的上下文是有大小限制的。

那怎么办?

思路很简单,庞大的知识库中与用户问题相关的其实并不多。

所以,我们需要想办法从庞大的知识库中找到与用户问题相关的一小部分,组装成提示词,发送给大模型就可以了。

那么问题来了,该如何从知识库中找到与用户问题相关的内容呢?

可能有人会相想到全文检索,但是在这里是不合适的,因为全文检索是文字匹配,这里要求的是内容上的相似度。

而要从内容相似度来判断,这就不得不提到向量模型了。

1_向量模型

先说说向量,向量是空间中有方向和长度的量,空间可以是二维,也可以是多维。

向量既然是在空间中,两个向量之间就一定能计算距离。

我们以二维向量为例,向量之间的距离有两种计算方法:

通常,两个向量之间欧式距离越近 ,便可以认为两个向量的相似度越高。(余弦距离相反,越大相似度越高)

所以,如果我们能把文本转为向量 ,就可以通过向量距离来判断文本的相似度了。

现在,有不少的专门的向量模型,就可以实现将文本向量化。

一个好的向量模型,就是要尽可能让文本含义相似的向量,在空间中距离更近

接下来,准备一个向量模型,用于将文本向量化。

可以使用阿里云百炼平台提供的向量模型:

修改application.yaml,添加向量模型配置:

yaml 复制代码
spring:
  ai:
    openai:
      base-url: https://dashscope.aliyuncs.com/compatible-mode
      api-key: ${API-KEY}
      embedding:
        options:
          model: text-embedding-v4 # 向量模型名称
          dimensions: 1024 # 向量维度

2_测试向量模型

文本向量化以后,可以通过向量之间的距离来判断文本相似度。

在项目中写一个工具类,用以计算向量之间的欧氏距离余弦距离

java 复制代码
public class VectorDistanceUtils {
    
    // 防止实例化
    private VectorDistanceUtils() {}

    // 浮点数计算精度阈值
    private static final double EPSILON = 1e-12;

    /**
     * 计算欧氏距离
     * @param vectorA 向量A(非空且与B等长)
     * @param vectorB 向量B(非空且与A等长)
     * @return 欧氏距离
     * @throws IllegalArgumentException 参数不合法时抛出
     */
    public static double euclideanDistance(float[] vectorA, float[] vectorB) {
        validateVectors(vectorA, vectorB);
        double sum = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            double diff = vectorA[i] - vectorB[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * 计算余弦距离
     * @param vectorA 向量A(非空且与B等长)
     * @param vectorB 向量B(非空且与A等长)
     * @return 余弦距离,范围[0, 2]
     * @throws IllegalArgumentException 参数不合法或零向量时抛出
     */
    public static double cosineDistance(float[] vectorA, float[] vectorB) {
        validateVectors(vectorA, vectorB);
        
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += vectorA[i] * vectorB[i];
            normA += vectorA[i] * vectorA[i];
            normB += vectorB[i] * vectorB[i];
        }
        
        normA = Math.sqrt(normA);
        normB = Math.sqrt(normB);
        
        // 处理零向量情况
        if (normA < EPSILON || normB < EPSILON) {
            throw new IllegalArgumentException("Vectors cannot be zero vectors");
        }
        
        // 处理浮点误差,确保结果在[-1,1]范围内
        double similarity =  dotProduct / (normA * normB);
        similarity = Math.max(Math.min(similarity, 1.0), -1.0);
        
        return similarity;
    }

    // 参数校验统一方法
    private static void validateVectors(float[] a, float[] b) {
        if (a == null || b == null) {
            throw new IllegalArgumentException("Vectors cannot be null");
        }
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have same dimension");
        }
        if (a.length == 0) {
            throw new IllegalArgumentException("Vectors cannot be empty");
        }
    }
}

编写测试类,计算两个文本的距离

java 复制代码
@Resource
private OpenAiEmbeddingModel embeddingModel;
@Test
public void testEmbedding() {
    // 1.测试数据
    // 1.1.用来查询的文本,国际冲突
    String query = "global conflicts";
    // 1.2.用来做比较的文本
    String[] texts = new String[]{
            "哈马斯称加沙下阶段停火谈判仍在进行 以方尚未做出承诺",
            "土耳其、芬兰、瑞典与北约代表将继续就瑞典"入约"问题进行谈判",
            "日本航空基地水井中检测出有机氟化物超标",
            "国家游泳中心(水立方):恢复游泳、嬉水乐园等水上项目运营",
            "我国首次在空间站开展舱外辐射生物学暴露实验",
    };
    // 2.向量化
    // 2.1.先将查询文本向量化
    float[] queryVector = embeddingModel.embed(query);
    // 2.2.再将比较文本向量化,放到一个数组
    List<float[]> textVectors = embeddingModel.embed(Arrays.asList(texts));
    // 3.比较欧氏距离
    // 3.1.把查询文本自己与自己比较,肯定是相似度最高的
    System.out.println("与自己的距离: " + VectorDistanceUtils.euclideanDistance(queryVector, queryVector));
    // 3.2.把查询文本与其它文本比较
    for (float[] textVector : textVectors) {
        System.out.println("与其他文本的距离: " + VectorDistanceUtils.euclideanDistance(queryVector, textVector));
    }
    System.out.println("------------------");
    // 4.比较余弦距离
    // 4.1.把查询文本自己与自己比较,肯定是相似度最高的
    System.out.println("与自己的距离: " + VectorDistanceUtils.cosineDistance(queryVector, queryVector));
    // 4.2.把查询文本与其它文本比较
    for (float[] textVector : textVectors) {
        System.out.println("与其他文本的距离: " + VectorDistanceUtils.cosineDistance(queryVector, textVector));
    }
}

运行结果:

plain 复制代码
与自己的距离: 0.0
与其他文本的距离: 1.277985806334919
与其他文本的距离: 1.217696088331691
与其他文本的距离: 1.3344384543780141
与其他文本的距离: 1.3342534594876638
与其他文本的距离: 1.3400395683070097
------------------
与自己的距离: 1.0
与其他文本的距离: 0.18337628987446866
与其他文本的距离: 0.25860824145232714
与其他文本的距离: 0.1096371227131696
与其他文本的距离: 0.10988406960580344
与其他文本的距离: 0.10214705075234658

可以看到,向量相似度确实符合我们的预期。

有了比较文本相似度的办法,知识库的问题就可以解决了。

只需利用向量数据库从庞大的知识库中比较和检索数据找到与用户问题相关的一小部分**,组装成提示词**,发送给大模型就可以了。

3_向量数据库

向量数据库的主要作用有两个:

  • 存储向量数据
  • 基于相似度检索数据SpringAI支持很多向量数据库,并且都进行了封装,可以用统一的API去访问:

实现:

这些库都实现了统一的接口:VectorStore,因此操作方式一模一样。

java 复制代码
public interface VectorStore extends DocumentWriter {

    default String getName() {
		return this.getClass().getSimpleName();
	}
    // 保存文档到向量库
    void add(List<Document> documents);
    // 根据文档id删除文档
    void delete(List<String> idList);

    void delete(Filter.Expression filterExpression);

    default void delete(String filterExpression) { ... };
    // 根据文档id删除文档
    List<Document> similaritySearch(String query);
    // 根据条件检索文档
    List<Document> similaritySearch(SearchRequest request);

    default <T> Optional<T> getNativeClient() {
		return Optional.empty();
	}
}

VectorStore 操作向量化的基本单位是 Document,在使用时需要将知识库分割转换为一个个的 Document ,然后再写入 VectorStore。

检索数据时使用构建器构建 SearchRequest 查询条件对象。

我们以 Redis 为例实现向量数据库。

xml 复制代码
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-vector-store-redis</artifactId>
</dependency>

需要部署 Redis 向量版本(docker):

bash 复制代码
docker run -d --name redis-stack \
  -p 6379:6379 -p 8001:8001 \
  -e REDIS_ARGS="--requirepass 123456" \
  redis/redis-stack:latest

配置向量数据库

yaml 复制代码
spring:
  data:
    redis:
      host: 192.168.200.129
      port: 6379
      password: 123456
  ai:
    vectorstore:
      redis:
        initialize-schema: true # 是否初始化向量索引结构
        index-name: ai-index # 向量索引库名称 custom-index
        prefix: "doc:" # 向量 key 前缀 custom-prefix

注意:引入多个模型可能会导致 RedisVectorStore 自动装配失败

4_文件读取和转换

前面说过,知识库太大,是需要拆分成文档片段,然后再做向量化的。

而 Spring AI 中向量库接收的是 Document 类型的文档,也就是说,处理后的文档还要转成 Document 格式。

在 Spring AI 中提供了各种文档读取的工具,可以进行文档读取、拆分、转换的动作(也可以自己尝试实现):

https://docs.spring.io/spring-ai/reference/api/etl-pipeline.html#pdfparagraph

比如 PDF 文档读取和拆分,Spring AI 提供了两种默认的拆分原则:

  • PagePdfDocumentReader:按页拆分,推荐使用(示例)。
  • ParagraphPdfDocumentReader:按 pdf 的目录拆分,不推荐,因为很多PDF不规范,没有章节标签。

引入依赖:

xml 复制代码
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-pdf-document-reader</artifactId>
</dependency>

测试

java 复制代码
@Test
public void testVectorStore(){
    Resource resource = new ClassPathResource("context.pdf");
    // 1.创建PDF的读取器
    PagePdfDocumentReader reader = new PagePdfDocumentReader(resource, PdfDocumentReaderConfig.builder()
            .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
            .withPagesPerDocument(1)// 每1页PDF作为一个Document
            .build());
    // 2.读取PDF文档,拆分为 Document
    List<Document> documents = reader.read();
    // 3.写入向量库
    vectorStore.add(documents);
    // 4.搜索
    SearchRequest request = SearchRequest.builder()
            .query("论语中教育的目的是什么")
            .topK(1) //返回一个相似结果
            .similarityThreshold(0.6)//筛选相似性阈值
            .filterExpression("file_name == 'context.pdf'")//元数据筛选过滤条件,指定pdf文档
            .build();
    List<Document> docs = vectorStore.similaritySearch(request);
    if (docs == null) {
        System.out.println("没有搜索到任何内容");
        return;
    }
    for (Document doc : docs) {
        System.out.println(doc.getId());
        System.out.println(doc.getScore());
        System.out.println(doc.getText());
    }
}

元数据过滤说明:
https://docs.spring.io/spring-ai/reference/api/vectordbs.html#metadata-filters

5_总结

工具:

  • PDFReader:读取文档并拆分为片段
  • 向量大模型:将文本片段向量化
  • 向量数据库:存储向量,检索向量

解决的问题和解决思路:

  • 要解决大模型的知识限制问题,需要外挂知识库
  • 受到大模型上下文限制,知识库不能简单的直接拼接在提示词中
  • 需要从庞大的知识库中找到与用户问题相关的一小部分,再组装成提示词
  • 这些可以利用文档读取器向量大模型向量数据库来解决。

所以 RAG 要做的事情就是将知识库分割,然后利用向量模型做向量化,存入向量数据库,然后查询的时候去检索:

第一阶段(存储知识库):

  • 将知识库内容切片,分为一个个片段
  • 将每个片段利用向量模型向量化
  • 将所有向量化后的片段写入向量数据库

第二阶段(检索知识库):

  • 每当用户询问AI时,将用户问题向量化
  • 拿着问题向量去向量数据库检索最相关的片段

第三阶段(对话大模型):

  • 将检索到的片段、用户的问题一起拼接为提示词
  • 发送提示词给大模型,得到响应

2 实现 ChatPDF

目标:实现一个非常火爆的个人知识库 AI 应用,ChatPDF,原网站如下

这个网站其实就是把个人的 PDF 文件作为知识库,让 AI 基于 PDF 内容来回答问题,对于大学生、研究人员、专业人士来说,非常方便。

既然是 ChatPDF,也就是说所有知识库都是 PDF 形式的,由用户提交给应用。所以,需要先实现一个上传 PDF 的接口,在接口中实现下列功能:

  • 校验文件格式是否为 PDF

  • 保存文件信息

    • 保存文件(可以是 oss 或本地保存)
    • 保存会话ID和文件路径的映射关系(方便查询会话历史的时候再次读取文件)
  • 文档拆分和向量化(文档太大,需要拆分为一个个片段,分别向量化)

另外,将来用户查询会话历史,我们还需要返回 pdf 文件给前端用于预览,所以需要实现一个下载 PDF 接口,包含下面功能:

  • 读取文件
  • 返回文件给前端

1_PDF文件管理

由于将来要实现 PDF 下载功能,所以需要记住每一个 chatId 对应的 PDF 文件名称。

所以,我们定义一个类,记录 chatId 与 pdf 文件的映射关系,同时实现基本的文件保存功能。

java 复制代码
import org.springframework.core.io.Resource;

public interface FileRepository {
    /**
     * 保存文件,还要记录chatId与文件的映射关系
     *
     * @param chatId   会话id
     * @param resource 文件
     * @return 上传成功,返回true; 否则返回false
     */
    boolean save(String chatId, Resource resource);

    /**
     * 根据chatId获取文件
     *
     * @param chatId 会话id
     * @return 找到的文件
     */
    Resource getFile(String chatId);
}

实现类:

java 复制代码
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.time.LocalDateTime;
import java.util.Objects;
import java.util.Properties;

@Slf4j
@Component
@RequiredArgsConstructor
public class LocalPdfFileRepository implements FileRepository {

    private final VectorStore vectorStore;

    // 会话id 与 文件名的对应关系,方便查询会话历史时重新加载文件(Properties自带持久化的能力)
    private final Properties chatFiles = new Properties();

    @Override
    public boolean save(String chatId, Resource resource) {
        // 2.保存到本地磁盘
        String filename = resource.getFilename();
        File target = new File(Objects.requireNonNull(filename));
        if (!target.exists()) {
            try {
                Files.copy(resource.getInputStream(), target.toPath());
            } catch (IOException e) {
                log.error("Failed to save PDF resource.", e);
                return false;
            }
        }
        // 3.保存映射关系
        chatFiles.put(chatId, filename);
        return true;
    }

    @Override
    public Resource getFile(String chatId) {
        return new FileSystemResource(chatFiles.getProperty(chatId));
    }

    @PostConstruct
    private void init() {
        // 属性文件
        FileSystemResource pdfResource = new FileSystemResource("chat-pdf.properties");
        if (pdfResource.exists()) {
            try {
                chatFiles.load(new BufferedReader(new InputStreamReader(pdfResource.getInputStream(), StandardCharsets.UTF_8)));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        // 内存向量数据库保存位置
        FileSystemResource vectorResource = new FileSystemResource("chat-pdf.json");
        // 只有基于内存的 SimpleVectorStore 实现才需要重新加载
        if (vectorResource.exists() && (vectorStore instanceof SimpleVectorStore simpleVectorStore)) {
            simpleVectorStore.load(vectorResource);
        }
    }

    @PreDestroy
    private void persistent() {
        try {
            // 属性文件
            chatFiles.store(new FileWriter("chat-pdf.properties"), LocalDateTime.now().toString());
            //内存实现则需要保存
            if (vectorStore instanceof SimpleVectorStore simpleVectorStore) {
                simpleVectorStore.save(new File("chat-pdf.json"));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}

注意

这里将 pdf 文件与 chatId 的对应关系持久化到了磁盘,至于 VectorStore 则进行了一个判断,如果选择了基于内存的 SimpleVectorStore,则将向量数据保存到一个 JSON 文件中。

实际开发中,如果你选择了 RedisVectorStore,或者 CassandraVectorStore,则无序自己持久化。

但是 chatId 和 PDF 文件之间的对应关系,还是需要自己维护的。

2_封装响应

返回给前端文件上传需要的响应结果,定义 Result 类:

java 复制代码
@Data
@NoArgsConstructor
public class Result {
    private Integer ok;
    private String msg;

    private Result(Integer ok, String msg) {
        this.ok = ok;
        this.msg = msg;
    }

    public static Result ok() {
        return new Result(1, "ok");
    }

    public static Result fail(String msg) {
        return new Result(0, msg);
    }
}

3_文件上传、下载

接口实现:

java 复制代码
@Slf4j
@RequiredArgsConstructor
@RestController
@RequestMapping("/ai/pdf")
public class PdfController {

    private final FileRepository fileRepository;

    private final VectorStore vectorStore;
    /**
     * 文件上传
     */
    @RequestMapping("/upload/{chatId}")
    public Result uploadPdf(@PathVariable String chatId, @RequestParam("file") MultipartFile file) {
        try {
            // 1. 校验文件是否为PDF格式
            if (!Objects.equals(file.getContentType(), "application/pdf")) {
                return Result.fail("只能上传PDF文件!");
            }
            // 2.保存文件
            boolean success = fileRepository.save(chatId, file.getResource());
            if(! success) {
                return Result.fail("保存文件失败!");
            }
            // 3.写入向量库
            this.writeToVectorStore(file.getResource());
            return Result.ok();
        } catch (Exception e) {
            log.error("Failed to upload PDF.", e);
            return Result.fail("上传文件失败!");
        }
    }

    /**
     * 文件下载
     */
    @GetMapping("/file/{chatId}")
    public ResponseEntity<Resource> download(@PathVariable("chatId") String chatId) throws IOException {
        // 1.读取文件
        Resource resource = fileRepository.getFile(chatId);
        if (!resource.exists()) {
            return ResponseEntity.notFound().build();
        }
        // 2.文件名编码,写入响应头
        String filename = URLEncoder.encode(Objects.requireNonNull(resource.getFilename()), StandardCharsets.UTF_8);
        // 3.返回文件
        return ResponseEntity.ok()
                .contentType(MediaType.APPLICATION_OCTET_STREAM)
                .header("Content-Disposition", "attachment; filename=\"" + filename + "\"")
                .body(resource);
    }

    private void writeToVectorStore(Resource resource) {
        // 1.创建PDF的读取器
        PagePdfDocumentReader reader = new PagePdfDocumentReader(
                resource, // 文件源
                PdfDocumentReaderConfig.builder()
                        .withPageExtractedTextFormatter(ExtractedTextFormatter.defaults())
                        .withPagesPerDocument(1) // 每1页PDF作为一个Document
                        .build()
        );
        // 2.读取PDF文档,拆分为Document
        List<Document> documents = reader.read();
        // 3.写入向量库
        vectorStore.add(documents);
    }
}

4_上传大小限制

SpringMVC 有默认的文件大小限制,只有10M,很多知识库文件都会超过这个值,因此需要修改配置,增加文件上传允许的上限。

yaml 复制代码
spring:
  servlet:
    multipart:
      max-file-size: 104857600
      max-request-size: 104857600

5_配置ChatClient

理论上来说,我们每次与AI对话的完整流程是这样的:

  • 将用户的问题利用向量大模型做向量化 OpenAiEmbeddingModel;
  • 去向量数据库检索相关的文档 VectorStore;
  • 拼接提示词,发送给大模型;
  • 解析响应结果;

不过,SpringAI 同样基于 AOP 技术帮我们完成了全部流程,用到的是一个名 QuestionAnswerAdvisor 的 Advisor,需要先引入依赖。

xml 复制代码
<dependency>
   <groupId>org.springframework.ai</groupId>
   <artifactId>spring-ai-advisors-vector-store</artifactId>
</dependency>

之后只需要把 VectorStore 配置到 Advisor 即可。

java 复制代码
@Bean
public ChatClient pdfChatClient(OpenAiChatModel model, ChatMemory chatMemory, VectorStore vectorStore) {
    return ChatClient.builder(model)
            .defaultSystem("请根据提供的上下文回答问题,不要自己猜测。")
            .defaultAdvisors(
                    MessageChatMemoryAdvisor.builder(chatMemory).build(), // CHAT MEMORY
                    new SimpleLoggerAdvisor(),
                    QuestionAnswerAdvisor.builder(vectorStore)
                            .searchRequest(
                                    SearchRequest.builder()
                                            .similarityThreshold(0.5d)
                                            .topK(1)
                                            .build()
                            ).build())
            .build();
}

也可以自己自定义 RAG 查询的流程,不使用 Advisor,具体可参考官网:https://docs.spring.io/spring-ai/reference/api

6_对话接口

对接前端,与大模型对话,在 PdfController 中添加如下内容:

java 复制代码
private final ChatHistoryRepository chatRepository;
private final ChatClient pdfChatClient;
@RequestMapping(value = "/chat", produces = "text/html;charset=UTF-8")
public Flux<String> chat(String prompt, String chatId) {
    chatRepository.save("pdf", chatId);
    Resource file = fileRepository.getFile(chatId);
    return pdfChatClient
            .prompt(prompt)
            .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, chatId))
            .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "file_name == '" + file.getFilename() + "'"))
            .stream()
            .content();
}

7_测试

上传一个 PDF 文件之后,就可以对 PDF 提问了,AI 也会根据文档来回答问题:

成功实现文档对话功能。