springai+pgvector+ollama实现rag

首先在ollama中安装mofanke/dmeta-embedding-zh:latest。执行ollama run mofanke/dmeta-embedding-zh 。实现将文本转化为向量数据

接着安装pgvector(建议使用pgAdmin4作为可视化工具,用Navicat会出现表不显示的问题)

安装好需要的软件后我们开始编码操作。

1:在pom文件中加入:

xml 复制代码
        <!--用于连接pgsql-->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>
        <!--用于使用pgvector来操作向量数据库-->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
        </dependency>
        <!--pdf解析-->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
        </dependency>
        <!--文档解析l-->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-tika-document-reader</artifactId>
        </dependency>

2:在yml中配置:

yaml 复制代码
spring:
  datasource:
    url: jdbc:postgresql://127.0.0.1:5432/postgres
    username: postgres
    password: password
  ai:
    vectorstore:
      pgvector:
        dimensions: 768   #不同的embeddingmodel对应的值
    ollama:
      base-url: http://127.0.0.1:11434
      chat:
        enabled: true
        options:
          model: qwen2:7b
      embedding:
        model: mofanke/dmeta-embedding-zh

3:在controller中加入:

java 复制代码
    /**
     * Embeds an uploaded file into the vector store.
     * Reads the file with Tika, splits the text into token-sized chunks and
     * stores the chunks; the vector store invokes the embedding model for each
     * chunk on insert.
     *
     * @param file the file to embed (any format Tika can parse)
     * @return the list of document chunks that were stored
     */
    @SneakyThrows
    @PostMapping("embedding")
    public List<Document> embedding(@RequestParam MultipartFile file) {

        // Read the raw text out of the uploaded stream via Apache Tika.
        TikaDocumentReader tikaDocumentReader = new TikaDocumentReader(new InputStreamResource(file.getInputStream()));
        // Split the text into smaller chunks so each fits the embedding model.
        List<Document> splitDocuments = new TokenTextSplitter()
                .apply(tikaDocumentReader.read());
        // Store in the vector database; this automatically calls the
        // embedding model to turn each chunk into a vector before insert.
        // NOTE(review): `vector` is a field injected elsewhere in this class.
        vector.add(splitDocuments);
        return splitDocuments;
    }

调用上方的接口可以将文档转为向量数据存入到pgvector中

4:请求聊天,先根据聊天内容通过pgvector获取对应的数据,并将结果丢到qwen2模型中进行数据分析并返回结果

java 复制代码
/**
     * 获取prompt
     *
     * @param message 提问内容
     * @param context 上下文
     * @return prompt
     */
    private String getChatPrompt2String(String message, String context) {
        String promptText = """
				请用仅用以下内容回答"%s" ,输出结果仅在以下内容中,输出内容仅以下内容,不需要其他描述词:
				%s
				""";
        return String.format(promptText, message, context);
    }

    /**
     * RAG chat endpoint: retrieves documents similar to {@code message} from
     * pgvector, then asks the chat model to answer using only that context.
     *
     * @param message the user's question
     * @return the chat model's answer
     */
    @GetMapping("chatToPgVector")
    public String chatToPgVector(String message) {

        // 1. Prompt template; question_answer_context would be replaced with the
        //    documents found in the vector store. Only used by the commented-out
        //    QuestionAnswerAdvisor variant at the bottom of this method.
        String promptWithContext = """
                你是一个代码程序,你需要在文本中获取信息并输出成json格式的数据,下面是上下文信息
                ---------------------
                {question_answer_context}
                ---------------------
                给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
                """;
        // Similarity search against the "test_store" table.
        List<Document> documents = vector.similaritySearch(message,"test_store");
        // Join the retrieved chunks into a single context string.
        String content = documents.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n"));
        System.out.println(content);
        // Build the grounded prompt and call the LLM.
        String chatResponse = ollamaChatModel.call(getChatPrompt2String(message, content));
        return chatResponse;
   /*     return ChatClient.create(ollamaChatModel).prompt()
                .user(message)
                // 2. QuestionAnswerAdvisor replaces the `question_answer_context`
                //    placeholder at runtime with documents from the vector store;
                //    the final query = the user's question + the filled template.
                .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), promptWithContext))
                .call().content();*/
    }

至此一个简单的rag搜索增强demo就完成了。接下来我们来看看PgVectorStore为我们做了什么

java 复制代码
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package org.springframework.ai.vectorstore;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.IntStream;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

/**
 * Spring AI's pgvector-backed {@code VectorStore} (decompiled reproduction,
 * shown here for explanation). Documents are embedded via the configured
 * {@link EmbeddingModel} and stored in the hard-coded "vector_store" table;
 * similarity search delegates to pgvector's distance operators.
 *
 * NOTE(review): table name, index name and all SQL are fixed constants —
 * this is the limitation the article works around below.
 */
public class PgVectorStore implements VectorStore, InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class);
    // Fallback embedding width when the model's dimensions cannot be determined.
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String VECTOR_TABLE_NAME = "vector_store";
    public static final String VECTOR_INDEX_NAME = "spring_ai_vector_index";
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    // Configured embedding width; <= 0 means "ask the embedding model".
    private int dimensions;
    private PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    private boolean removeExistingVectorStoreTable;
    private PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    // Convenience constructor: unknown dimensions, cosine distance, no index,
    // no schema initialization.
    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    // Canonical constructor; the others delegate here.
    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Upserts documents into the hard-coded vector_store table. Each document
     * is embedded (the embedding model is called once per document inside the
     * batch setter) and written with one batched
     * INSERT ... ON CONFLICT (id) DO UPDATE statement.
     */
    public void add(final List<Document> documents) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                // The embedding happens here, at insert time.
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                // Integer.MIN_VALUE == SqlTypeValue.TYPE_UNKNOWN: let Spring
                // infer the SQL type from the Java value.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                // Parameters 5-7 repeat the values for the DO UPDATE branch.
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    // Serializes document metadata to a JSON string for the jsonb column.
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException var3) {
            throw new RuntimeException(var3);
        }
    }

    // Converts the model's List<Double> embedding to the float[] PGvector
    // expects. (The unusual loop shape is a decompiler artifact.)
    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;

        Double d;
        for(Iterator var4 = embeddingDouble.iterator(); var4.hasNext(); embeddingFloat[i++] = d.floatValue()) {
            d = (Double)var4.next();
        }

        return embeddingFloat;
    }

    /**
     * Deletes documents by id, one statement per id.
     *
     * @return true iff every id deleted exactly one row
     */
    public Optional<Boolean> delete(List<String> idList) {
        int updateCount = 0;

        int count;
        for(Iterator var3 = idList.iterator(); var3.hasNext(); updateCount += count) {
            String id = (String)var3.next();
            count = this.jdbcTemplate.update("DELETE FROM vector_store WHERE id = ?", new Object[]{UUID.fromString(id)});
        }

        return Optional.of(updateCount == idList.size());
    }

    /**
     * pgvector similarity search: embeds the query, applies the optional
     * metadata jsonpath filter, and returns the topK closest documents whose
     * distance is below the converted threshold.
     */
    public List<Document> similaritySearch(SearchRequest request) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        // pgvector operators return a distance, so the similarity threshold is
        // converted to a maximum allowed distance.
        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        // The query embedding is bound twice: once in SELECT, once in WHERE.
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, "vector_store", jsonPathFilter), new DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    // Distance between the query embedding and every stored embedding.
    public List<Double> embeddingDistance(String query) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM vector_store", new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    // Embeds the query text and wraps it for the pgvector JDBC driver.
    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    /**
     * InitializingBean callback, run once after properties are set. When
     * schema initialization is enabled it installs the required Postgres
     * extensions, optionally drops the old table, then creates the
     * vector_store table and (if configured) the similarity index.
     */
    public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            // Schema: id uuid PK, content text, metadata json,
            // embedding vector(N) where N = embeddingDimensions().
            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

    /**
     * Embedding width for the vector column: the configured value when
     * positive, otherwise the model's own reported dimensions, otherwise 1536.
     */
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        } else {
            try {
                int embeddingDimensions = this.embeddingModel.dimensions();
                if (embeddingDimensions > 0) {
                    return embeddingDimensions;
                }
            } catch (Exception var2) {
                logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", var2);
            }

            return 1536;
        }
    }

    /**
     * Distance metric: pgvector SQL operator, index operator class, and the
     * similarity-search SQL template (placeholders: table name, extra filter).
     */
    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    // Approximate-nearest-neighbor index type for the embedding column.
    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    /** Maps a result row to a Document, exposing the distance via metadata. */
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString("id");
            String content = rs.getString("content");
            PGobject pgMetadata = (PGobject)rs.getObject("metadata", PGobject.class);
            PGobject embedding = (PGobject)rs.getObject("embedding", PGobject.class);
            Float distance = rs.getFloat("distance");
            Map<String, Object> metadata = this.toMap(pgMetadata);
            // The computed distance is surfaced to callers through metadata.
            metadata.put("distance", distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        // Parses the pgvector text representation back into List<Double>.
        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length).mapToDouble((i) -> {
                return (double)floatArray[i];
            }).boxed().toList();
        }

        // Deserializes the metadata json column into a Map.
        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();

            try {
                return (Map)this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException var4) {
                throw new RuntimeException(var4);
            }
        }
    }
}

我们可以看到PgVectorStore实现了InitializingBean并实现了afterPropertiesSet方法。它会在属性设置完成后执行。

java 复制代码
 /**
  * InitializingBean callback, run once after dependency injection. When
  * initializeSchema is true it installs the vector/hstore/uuid-ossp
  * extensions, optionally drops the existing table, then creates the
  * vector_store table and (if configured) the similarity index.
  */
 public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            // Schema: id uuid PK, content text, metadata json,
            // embedding vector(N) where N = embeddingDimensions().
            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

这里它会根据initializeSchema(在PgVectorStoreProperties中,默认为true,我们可以在yml中配置spring:ai:vectorstore:pgvector:initialize-schema:false来禁用)来判断是否帮我们建表。这里它会帮我们建一个叫vector_store的表,其中包含id(uuid),metadata(json),content(text),embedding(vector(1536))。这里1536指的就是dimensions的值。当我们用默认建的表去做pgvector的数据存储时会出现 ERROR: expected 1536 dimensions, not 768这样的报错,表示我们ollama中的embedding模型输出的dimensions是768,而pgvector中embedding字段的维度是1536,它们不匹配所以无法存储。这时我们需要去pgvector中把embedding字段的维度修改为768即可(不同模型返回的dimensions值不一样,可以根据报错信息自行调整)

接下来我们看一下核心的操作方法-向数据库中插入数据

java 复制代码
 /**
  * Upserts documents into the hard-coded vector_store table. Each document is
  * embedded inside the batch setter, then written with one batched
  * INSERT ... ON CONFLICT (id) DO UPDATE statement.
  */
 public void add(final List<Document> documents) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                // The embedding model is invoked here, once per document.
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                // Integer.MIN_VALUE == SqlTypeValue.TYPE_UNKNOWN (type inferred).
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                // Parameters 5-7 repeat the values for the DO UPDATE branch.
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

这里因为Springai刚出,也不是稳定版的,它在代码中直接写死了操作表。我们使用pgvectorStore时只能对vector_store进行操作,这在实际应用场景中可能会造成一定的局限性。所以我们可以自己写一个扩展操作类来替换它。如下:

java 复制代码
package com.lccloud.tenderdocument.vector;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.IntStream;

/**
 * Table-aware variant of Spring AI's {@code PgVectorStore}.
 *
 * <p>The stock store hard-codes the {@code vector_store} table in its SQL, so
 * every operation here takes an explicit {@code tableName} argument instead,
 * allowing several vector tables (e.g. with different dimensions) to coexist
 * in one application.
 *
 * <p>Because an identifier cannot be bound as a JDBC parameter, the table
 * name is concatenated into the SQL text; {@link #validateTableName(String)}
 * guards that concatenation against SQL injection.
 */
public class ExtendPgVectorStore {
    private static final Logger logger = LoggerFactory.getLogger(ExtendPgVectorStore.class);
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    // Configured embedding width; <= 0 means "ask the embedding model".
    private int dimensions;
    private PgVectorStore.PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    // Retained for constructor compatibility; they would only drive schema
    // initialization, which this extension deliberately does not perform
    // (per-table DDL is left to the application).
    private boolean removeExistingVectorStoreTable;
    private PgVectorStore.PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    /** Convenience constructor: cosine distance, no index, no schema init. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    /** Canonical constructor; the other constructors delegate here. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgVectorStore.PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgVectorStore.PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgVectorStore.PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Rejects table names that are not plain (optionally schema-qualified)
     * identifiers. This is the SQL-injection guard for every method that
     * splices {@code tableName} into a statement.
     *
     * @throws IllegalArgumentException if the name is null/empty or contains
     *         characters other than letters, digits, '_' or '.'
     */
    private static String validateTableName(String tableName) {
        if (tableName == null || tableName.isEmpty()) {
            throw new IllegalArgumentException("tableName must not be empty");
        }
        for (int i = 0; i < tableName.length(); i++) {
            char c = tableName.charAt(i);
            if (!Character.isLetterOrDigit(c) && c != '_' && c != '.') {
                throw new IllegalArgumentException("illegal table name: " + tableName);
            }
        }
        return tableName;
    }

    /**
     * Upserts the documents into {@code tableName}. Each document is embedded
     * (the embedding model is called once per document inside the batch
     * setter), then written with one batched
     * INSERT ... ON CONFLICT (id) DO UPDATE statement.
     */
    public void add(final List<Document> documents, String tableName) {
        final int size = documents.size();
        String sql = "INSERT INTO " + validateTableName(tableName)
                + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?)"
                + " ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
        this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = documents.get(i);
                String content = document.getContent();
                String json = ExtendPgVectorStore.this.toJson(document.getMetadata());
                // The embedding model is invoked here, once per document.
                PGvector pGvector = new PGvector(ExtendPgVectorStore.this.toFloatArray(ExtendPgVectorStore.this.embeddingModel.embed(document)));
                // Integer.MIN_VALUE == SqlTypeValue.TYPE_UNKNOWN: let Spring
                // infer the SQL type from the Java value.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                // Parameters 5-7 repeat the values for the DO UPDATE branch.
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    // Serializes document metadata to a JSON string for the jsonb column.
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    // Converts the model's List<Double> embedding to the float[] PGvector expects.
    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        for (int i = 0; i < embeddingFloat.length; i++) {
            embeddingFloat[i] = embeddingDouble.get(i).floatValue();
        }
        return embeddingFloat;
    }

    /**
     * Deletes the given ids from {@code tableName}, one statement per id.
     *
     * @return true iff every id deleted exactly one row
     */
    public Optional<Boolean> delete(List<String> idList, String tableName) {
        String sql = "DELETE FROM " + validateTableName(tableName) + " WHERE id = ?";
        int updateCount = 0;
        for (String id : idList) {
            updateCount += this.jdbcTemplate.update(sql, UUID.fromString(id));
        }
        return Optional.of(updateCount == idList.size());
    }

    /** Similarity search with default {@link SearchRequest} settings. */
    public List<Document> similaritySearch(String query, String tableName) {
        return this.similaritySearch(SearchRequest.query(query), tableName);
    }

    /**
     * pgvector similarity search against {@code tableName}: embeds the query,
     * applies the optional metadata jsonpath filter, and returns the topK
     * closest documents within the converted distance threshold.
     */
    public List<Document> similaritySearch(SearchRequest request, String tableName) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        // pgvector operators return a distance, so the similarity threshold is
        // converted to a maximum allowed distance.
        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        // The query embedding is bound twice: once in SELECT, once in WHERE.
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, validateTableName(tableName), jsonPathFilter), new ExtendPgVectorStore.DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    /**
     * Distance between the query embedding and every embedding stored in
     * {@code tableName}.
     *
     * <p>BUG FIX: the previous version ignored {@code tableName} and always
     * queried the hard-coded {@code vector_store} table.
     */
    public List<Double> embeddingDistance(String query, String tableName) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM " + validateTableName(tableName), new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    // Embeds the query text and wraps it for the pgvector JDBC driver.
    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    // NOTE: unlike PgVectorStore, this class intentionally has no
    // afterPropertiesSet/schema initialization: the stock DDL is tied to the
    // vector_store table, so creating per-table schemas is left to the caller.

    /**
     * Embedding width for the vector column: the configured value when
     * positive, otherwise the model's own reported dimensions, otherwise 1536.
     */
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        } catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", e);
        }
        return 1536;
    }

    /**
     * Mirror of {@code PgVectorStore.PgDistanceType}, kept public for
     * backward compatibility (internally this class uses the original enum).
     */
    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    /** Mirror of {@code PgVectorStore.PgIndexType}, kept for compatibility. */
    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    /** Maps a result row to a Document, exposing the distance via metadata. */
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString(COLUMN_ID);
            String content = rs.getString(COLUMN_CONTENT);
            PGobject pgMetadata = rs.getObject(COLUMN_METADATA, PGobject.class);
            PGobject embedding = rs.getObject(COLUMN_EMBEDDING, PGobject.class);
            Float distance = rs.getFloat(COLUMN_DISTANCE);
            Map<String, Object> metadata = this.toMap(pgMetadata);
            // The computed distance is surfaced to callers through metadata.
            metadata.put(COLUMN_DISTANCE, distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        // Parses the pgvector text representation back into List<Double>.
        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length)
                    .mapToDouble((i) -> (double) floatArray[i])
                    .boxed()
                    .toList();
        }

        // Deserializes the metadata json column into a Map.
        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();
            try {
                return (Map) this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

当我们要使用上面这个ExtendPgVectorStore进行操作时首先我们要排除掉原PgVectorStore的注入。

接着我们需要注入自己的ExtendPgVectorStore类

java 复制代码
import com.lccloud.tenderdocument.vector.ExtendPgVectorStore;
import org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

@Configuration
@AutoConfigureAfter(JdbcTemplateAutoConfiguration.class)
@EnableConfigurationProperties({PgVectorStoreProperties.class})
public class PgVectorConfig {

    public PgVectorConfig() {
    }
    /**
     * Table-aware vector store used for retrieval against pgvector. Reuses the
     * auto-configured PgVectorStoreProperties so the yml settings (dimensions,
     * distance type, initialize-schema, ...) still apply.
     *
     * @param jdbcTemplate   JDBC access to the pgvector database
     * @param embeddingModel model used to turn text into vectors
     * @param properties     Spring AI pgvector settings from application.yml
     * @return the extended store bean
     */
    @Bean
    public ExtendPgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, PgVectorStoreProperties properties) {
        boolean initializeSchema = properties.isInitializeSchema();
        return new ExtendPgVectorStore(jdbcTemplate, embeddingModel, properties.getDimensions(), properties.getDistanceType(), properties.isRemoveExistingVectorStoreTable(), properties.getIndexType(), initializeSchema);
    }
    /**
     * Shared token-based text splitter used for chunking documents.
     * @return the splitter bean
     */
    @Bean
    public TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }
}

上面这里的PgVectorStoreProperties也可以换成我们自己的类方法(这里我懒得换就用pgVectorStore自带的了)。然后我们在使用的时候就可以注入ExtendPgVectorStore进行操作了。

相关推荐
Tiandaren8 小时前
自用提示词02 || Prompt Engineering || RAG数据切分 || 作用:通过LLM将文档切分成chunks
数据库·pytorch·深度学习·oracle·prompt·rag
胡桃姓胡,蝴蝶也姓胡1 天前
Rag优化 - 如何提升首字响应速度
后端·大模型·rag
曾经的三心草1 天前
SpringAI5-智能聊天机器⼈
java·springai
熊猫钓鱼>_>2 天前
基于模板提高垂直领域大模型应用场景的文字语言组织准确性
自动化·llm·多模态·模板·rag·垂直领域
深色風信子5 天前
SpringAI Redis RAG 搜索
springai·rag 搜索·spring redis·redis rag·springai rag·redis rag 搜索·springai rag 搜索
深色風信子5 天前
SpringAI 内嵌模型 ONNX
springai·springai onnx·embedding onnx·java springai·java onnx·java embedding
RAG专家5 天前
【Mixture-of-RAG】将文本和表格与大型语言模型相结合
人工智能·语言模型·rag·检索增强生成
Microsoft Word5 天前
向量数据库与RAG
数据库·人工智能·向量数据库·rag
大模型教程6 天前
Windows系统本地知识库构建:Cherry Studio+Ollama
llm·agent·ollama
Qiuner6 天前
快速入门LangChain4j Ollama本地部署与阿里百炼请求大模型
语言模型·langchain·nlp·llama·ollama