Spring AI + pgvector + Ollama 实现 RAG

首先在 Ollama 中安装 mofanke/dmeta-embedding-zh:latest，执行 ollama pull mofanke/dmeta-embedding-zh（embedding 模型只用于把文本转化为向量数据，不支持对话，因此用 pull 拉取即可）。

接着安装 pgvector（建议使用 pgAdmin 4 作为可视化工具，用 Navicat 会出现表不显示的问题）

安装好需要的软件后我们开始编码操作。

1:在pom文件中加入:

xml 复制代码
        <!-- JDBC support, used to connect to PostgreSQL -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-jdbc</artifactId>
        </dependency>
        <!-- pgvector store: persists and searches embeddings in PostgreSQL -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pgvector-store-spring-boot-starter</artifactId>
        </dependency>
        <!-- PDF parsing -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-pdf-document-reader</artifactId>
        </dependency>
        <!-- Generic document parsing via Apache Tika -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-tika-document-reader</artifactId>
        </dependency>

2:在yml中配置:

yaml 复制代码
# pgvector lives inside a regular PostgreSQL database, so a normal datasource is used.
spring:
  datasource:
    url: jdbc:postgresql://127.0.0.1:5432/postgres
    username: postgres
    password: password
  ai:
    vectorstore:
      pgvector:
        dimensions: 768   # embedding dimension; must match the embedding model's output size
    ollama:
      base-url: http://127.0.0.1:11434
      chat:
        enabled: true
        options:
          model: qwen2:7b
      embedding:
        model: mofanke/dmeta-embedding-zh

3:在controller中加入:

java 复制代码
   /**
     * Embed an uploaded file into the vector store.
     *
     * @param file the file to embed
     * @return the document chunks that were persisted
     */
    @SneakyThrows
    @PostMapping("embedding")
    public List<Document> embedding(@RequestParam MultipartFile file) {

        // Parse the raw upload through Tika from its input stream.
        InputStreamResource resource = new InputStreamResource(file.getInputStream());
        List<Document> parsed = new TikaDocumentReader(resource).read();
        // Break the text into smaller token-bounded chunks.
        List<Document> chunks = new TokenTextSplitter().apply(parsed);
        // Persisting triggers the embedding model: each chunk is vectorized before storage.
        vector.add(chunks);
        return chunks;
    }

调用上方的接口可以将文档转为向量数据存入到pgvector中

4:请求聊天,先根据聊天内容通过pgvector获取对应的数据,并将结果丢到qwen2模型中进行数据分析并返回结果

java 复制代码
/**
     * 获取prompt
     *
     * @param message 提问内容
     * @param context 上下文
     * @return prompt
     */
    /**
     * Build the final prompt sent to the chat model.
     *
     * @param message the user's question
     * @param context the text retrieved from the vector store
     * @return the formatted prompt constraining the model to the given context
     */
    private String getChatPrompt2String(String message, String context) {
        // The template forces the model to answer strictly from the supplied context.
        String template = """
                请用仅用以下内容回答"%s" ,输出结果仅在以下内容中,输出内容仅以下内容,不需要其他描述词:
                %s
                """;
        return template.formatted(message, context);
    }

    /**
     * Answer a question using RAG: retrieve similar documents from pgvector,
     * then ask the chat model to answer from that context only.
     *
     * @param message the user's question
     * @return the chat model's answer
     */
    @GetMapping("chatToPgVector")
    public String chatToPgVector(String message) {

        // 1. Prompt template; {question_answer_context} would be replaced with the
        //    documents found in the vector store (used by the advisor variant below).
        String promptWithContext = """
                你是一个代码程序,你需要在文本中获取信息并输出成json格式的数据,下面是上下文信息
                ---------------------
                {question_answer_context}
                ---------------------
                给定的上下文和提供的历史信息,而不是事先的知识,回复用户的意见。如果答案不在上下文中,告诉用户你不能回答这个问题。
                """;
        // Fetch the documents most similar to the question.
        List<Document> documents = vector.similaritySearch(message,"test_store");
        // Concatenate their text into a single context block.
        String context = documents.stream()
                .map(Document::getContent)
                .collect(Collectors.joining("\n"));
        System.out.println(context);
        // Wrap question + context into a prompt and call the LLM.
        return ollamaChatModel.call(getChatPrompt2String(message, context));
   /*     return ChatClient.create(ollamaChatModel).prompt()
                .user(message)
                // 2. QuestionAnswerAdvisor replaces the `question_answer_context` placeholder
                //    at runtime with the documents retrieved from the vector store.
                .advisors(new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults(), promptWithContext))
                .call().content();*/
    }

至此一个简单的rag搜索增强demo就完成了。接下来我们来看看PgVectorStore为我们做了什么

java 复制代码
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package org.springframework.ai.vectorstore;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.IntStream;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

/**
 * Decompiled Spring AI {@code PgVectorStore}, shown for reference: a VectorStore
 * backed by PostgreSQL with the pgvector extension. Note that the table name
 * ("vector_store") and index name appear as hard-coded strings in every SQL
 * statement below.
 */
public class PgVectorStore implements VectorStore, InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(PgVectorStore.class);
    // 1536 is the OpenAI embedding dimension (per the constant's name); used as fallback.
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String VECTOR_TABLE_NAME = "vector_store";
    public static final String VECTOR_INDEX_NAME = "spring_ai_vector_index";
    // Translates Spring AI filter expressions into Postgres jsonpath predicates.
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    // Vector dimension for table creation; -1 means "ask the embedding model".
    private int dimensions;
    private PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    private boolean removeExistingVectorStoreTable;
    private PgIndexType createIndexMethod;
    // When true, afterPropertiesSet() creates extensions, table and index.
    private final boolean initializeSchema;

    // Convenience constructor: cosine distance, no index, no schema initialization.
    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    // Convenience constructor with an explicit vector dimension.
    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    // Full constructor; all other constructors delegate here.
    public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Upserts the documents into the hard-coded "vector_store" table. Each document's
     * content is embedded via the EmbeddingModel inside the batch setter.
     */
    public void add(final List<Document> documents) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                // Embedding computed once, reused for both the INSERT and UPDATE branches.
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                // Parameters 1-4 feed the INSERT, 5-7 the ON CONFLICT ... DO UPDATE clause.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    // Serializes document metadata to a JSON string for the jsonb column.
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException var3) {
            throw new RuntimeException(var3);
        }
    }

    // The pgvector client type takes float[]; narrow the model's doubles.
    // (Decompiler-style loop: assigns d, then stores d.floatValue() in the update clause.)
    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;

        Double d;
        for(Iterator var4 = embeddingDouble.iterator(); var4.hasNext(); embeddingFloat[i++] = d.floatValue()) {
            d = (Double)var4.next();
        }

        return embeddingFloat;
    }

    /**
     * Deletes the given ids one by one; returns true iff every id deleted exactly one row.
     */
    public Optional<Boolean> delete(List<String> idList) {
        int updateCount = 0;

        int count;
        for(Iterator var3 = idList.iterator(); var3.hasNext(); updateCount += count) {
            String id = (String)var3.next();
            count = this.jdbcTemplate.update("DELETE FROM vector_store WHERE id = ?", new Object[]{UUID.fromString(id)});
        }

        return Optional.of(updateCount == idList.size());
    }

    /**
     * Similarity search over "vector_store" using the configured distance type's SQL
     * template. The request's similarity threshold is converted to a distance bound
     * (distance = 1 - similarity) because the SQL filters on distance.
     */
    public List<Document> similaritySearch(SearchRequest request) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            // Optional metadata filter applied as a jsonpath predicate.
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, "vector_store", jsonPathFilter), new DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    /**
     * Returns the distance between the query's embedding and every row in "vector_store".
     */
    public List<Double> embeddingDistance(String query) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM vector_store", new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    // Embeds the query text and wraps it in the pgvector client type.
    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    // SQL distance operator for the configured distance type (e.g. <=> for cosine).
    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    /**
     * Runs after bean properties are set. When initializeSchema is true, installs the
     * required extensions, optionally drops the old table, then creates the
     * "vector_store" table and (optionally) its vector index.
     */
    public void afterPropertiesSet() throws Exception {
        if (this.initializeSchema) {
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            // Column dimension comes from embeddingDimensions(): configured value, the
            // model's reported value, or the 1536 fallback.
            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

    // Resolution order: explicit dimensions > model-reported dimensions > 1536 fallback.
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        } else {
            try {
                int embeddingDimensions = this.embeddingModel.dimensions();
                if (embeddingDimensions > 0) {
                    return embeddingDimensions;
                }
            } catch (Exception var2) {
                logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", var2);
            }

            return 1536;
        }
    }

    // Distance strategies: SQL operator, pgvector index opclass, and search SQL template.
    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    // Supported pgvector index methods; NONE skips index creation.
    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    // Maps a result row back to a Document, putting the distance into metadata.
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private ObjectMapper objectMapper;

        public DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString("id");
            String content = rs.getString("content");
            PGobject pgMetadata = (PGobject)rs.getObject("metadata", PGobject.class);
            PGobject embedding = (PGobject)rs.getObject("embedding", PGobject.class);
            Float distance = rs.getFloat("distance");
            Map<String, Object> metadata = this.toMap(pgMetadata);
            // Distance is exposed to callers through the metadata map.
            metadata.put("distance", distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        // Parses the stored vector text back into a boxed Double list.
        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length).mapToDouble((i) -> {
                return (double)floatArray[i];
            }).boxed().toList();
        }

        // Deserializes the metadata JSON column into a Map.
        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();

            try {
                return (Map)this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException var4) {
                throw new RuntimeException(var4);
            }
        }
    }
}

我们可以看到PgVectorStore实现了InitializingBean并实现了afterPropertiesSet方法。它会在属性设置完成后执行。

java 复制代码
 public void afterPropertiesSet() throws Exception {
        // Only touches the schema when initialize-schema is enabled (default true).
        if (this.initializeSchema) {
            // Required Postgres extensions: vector type, hstore, uuid generation.
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
            this.jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
            if (this.removeExistingVectorStoreTable) {
                this.jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
            }

            // Creates "vector_store"; the vector column's dimension comes from
            // embeddingDimensions() (configured value, model value, or 1536 fallback).
            this.jdbcTemplate.execute(String.format("CREATE TABLE IF NOT EXISTS %s (\n\tid uuid DEFAULT uuid_generate_v4() PRIMARY KEY,\n\tcontent text,\n\tmetadata json,\n\tembedding vector(%d)\n)\n", "vector_store", this.embeddingDimensions()));
            if (this.createIndexMethod != PgVectorStore.PgIndexType.NONE) {
                // Optional ANN index using the distance type's operator class.
                this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (embedding %s)\n", "spring_ai_vector_index", "vector_store", this.createIndexMethod, this.getDistanceType().index));
            }

        }
    }

这里它会根据 initializeSchema（在 PgVectorStoreProperties 中默认为 true，我们可以在 yml 中配置 spring.ai.vectorstore.pgvector.initialize-schema: false 来禁用）来判断是否帮我们建表。它会帮我们建一个叫 vector_store 的表，其中包含 id(uuid)、metadata(json)、content(text)、embedding(vector(1536))，这里的 1536 指的就是 dimensions 的值。当我们用默认建的表去做 pgvector 的数据存储时会出现 ERROR: expected 1536 dimensions, not 768 这样的报错，表示 Ollama 中的 embedding 模型输出的维度是 768，而表中 embedding 字段的维度是 1536，二者不匹配所以无法存储。这时我们需要把 pgvector 表中 embedding 字段的向量维度改为 768 即可（不同模型返回的维度不一样，可以根据报错信息自行调整）。

接下来我们看一下核心的操作方法-向数据库中插入数据

java 复制代码
 public void add(final List<Document> documents) {
        final int size = documents.size();
        // Upsert: insert a new row, or overwrite content/metadata/embedding on id conflict.
        // Note the table name "vector_store" is hard-coded into the SQL.
        this.jdbcTemplate.batchUpdate("INSERT INTO vector_store (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = (Document)documents.get(i);
                String content = document.getContent();
                String json = PgVectorStore.this.toJson(document.getMetadata());
                // The embedding model is invoked here — this is where text becomes a vector.
                PGvector pGvector = new PGvector(PgVectorStore.this.toFloatArray(PgVectorStore.this.embeddingModel.embed(document)));
                // Parameters 1-4 feed the INSERT, 5-7 the ON CONFLICT ... DO UPDATE clause.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

这里因为 Spring AI 刚推出，还不是稳定版，它在代码中直接写死了要操作的表。我们使用 PgVectorStore 时只能对 vector_store 表进行操作，这在实际应用场景中可能会造成一定的局限性。所以我们可以自己写一个扩展操作类来替换它。如下:

java 复制代码
package com.lccloud.tenderdocument.vector;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import org.postgresql.util.PGobject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.converter.PgVectorFilterExpressionConverter;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.StatementCreatorUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.IntStream;

/**
 * Table-aware variant of Spring AI's {@code PgVectorStore}.
 *
 * <p>The stock PgVectorStore hard-codes the "vector_store" table into every SQL
 * statement; this class takes the target table name as a parameter on each
 * operation so several vector tables can coexist in one database.
 *
 * <p>Schema initialization is intentionally not performed here (the original
 * afterPropertiesSet logic was removed): target tables are expected to already
 * exist with the correct vector dimension.
 *
 * <p>SECURITY NOTE(review): {@code tableName} is concatenated directly into the
 * SQL text (identifiers cannot be bound as ? parameters). Only ever pass
 * trusted, application-defined table names here — never user-supplied input.
 */
public class ExtendPgVectorStore {
    private static final Logger logger = LoggerFactory.getLogger(ExtendPgVectorStore.class);
    // Translates Spring AI filter expressions into Postgres jsonpath predicates.
    public final FilterExpressionConverter filterExpressionConverter;
    private final JdbcTemplate jdbcTemplate;
    private final EmbeddingModel embeddingModel;
    // Vector dimension; values <= 0 mean "ask the embedding model" (see embeddingDimensions()).
    private int dimensions;
    private PgVectorStore.PgDistanceType distanceType;
    private ObjectMapper objectMapper;
    // Retained for constructor compatibility; schema management is disabled in this class.
    private boolean removeExistingVectorStoreTable;
    private PgVectorStore.PgIndexType createIndexMethod;
    private final boolean initializeSchema;

    /** Convenience constructor: cosine distance, no index, no schema initialization. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
        this(jdbcTemplate, embeddingModel, -1, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    /** Convenience constructor with an explicit vector dimension. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions) {
        this(jdbcTemplate, embeddingModel, dimensions, PgVectorStore.PgDistanceType.COSINE_DISTANCE, false, PgVectorStore.PgIndexType.NONE, false);
    }

    /** Full constructor; all other constructors delegate here. */
    public ExtendPgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgVectorStore.PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgVectorStore.PgIndexType createIndexMethod, boolean initializeSchema) {
        this.filterExpressionConverter = new PgVectorFilterExpressionConverter();
        this.objectMapper = new ObjectMapper();
        this.jdbcTemplate = jdbcTemplate;
        this.embeddingModel = embeddingModel;
        this.dimensions = dimensions;
        this.distanceType = distanceType;
        this.removeExistingVectorStoreTable = removeExistingVectorStoreTable;
        this.createIndexMethod = createIndexMethod;
        this.initializeSchema = initializeSchema;
    }

    public PgVectorStore.PgDistanceType getDistanceType() {
        return this.distanceType;
    }

    /**
     * Upserts the given documents into {@code tableName}. Each document's content
     * is embedded via the configured EmbeddingModel inside the batch setter.
     *
     * @param documents chunks to store
     * @param tableName target table (trusted identifier — see class security note)
     */
    public void add(final List<Document> documents, String tableName) {
        final int size = documents.size();
        this.jdbcTemplate.batchUpdate("INSERT INTO " + tableName + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) ON CONFLICT (id) DO UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ", new BatchPreparedStatementSetter() {
            public void setValues(PreparedStatement ps, int i) throws SQLException {
                Document document = documents.get(i);
                String content = document.getContent();
                String json = ExtendPgVectorStore.this.toJson(document.getMetadata());
                // Embedding is computed once and reused for both INSERT and UPDATE branches.
                PGvector pGvector = new PGvector(ExtendPgVectorStore.this.toFloatArray(ExtendPgVectorStore.this.embeddingModel.embed(document)));
                // Parameters 1-4 feed the INSERT, 5-7 the ON CONFLICT ... DO UPDATE clause.
                StatementCreatorUtils.setParameterValue(ps, 1, Integer.MIN_VALUE, UUID.fromString(document.getId()));
                StatementCreatorUtils.setParameterValue(ps, 2, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 3, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 4, Integer.MIN_VALUE, pGvector);
                StatementCreatorUtils.setParameterValue(ps, 5, Integer.MIN_VALUE, content);
                StatementCreatorUtils.setParameterValue(ps, 6, Integer.MIN_VALUE, json);
                StatementCreatorUtils.setParameterValue(ps, 7, Integer.MIN_VALUE, pGvector);
            }

            public int getBatchSize() {
                return size;
            }
        });
    }

    /** Serializes document metadata to a JSON string for the jsonb column. */
    private String toJson(Map<String, Object> map) {
        try {
            return this.objectMapper.writeValueAsString(map);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /** The pgvector client type takes float[]; narrow the model's doubles. */
    private float[] toFloatArray(List<Double> embeddingDouble) {
        float[] embeddingFloat = new float[embeddingDouble.size()];
        int i = 0;
        for (Double d : embeddingDouble) {
            embeddingFloat[i++] = d.floatValue();
        }
        return embeddingFloat;
    }

    /**
     * Deletes the given ids from {@code tableName}.
     *
     * @param idList    document ids (UUID strings)
     * @param tableName target table (trusted identifier — see class security note)
     * @return true iff every id deleted exactly one row
     */
    public Optional<Boolean> delete(List<String> idList, String tableName) {
        int updateCount = 0;
        for (String id : idList) {
            updateCount += this.jdbcTemplate.update("DELETE FROM " + tableName + " WHERE id = ?", new Object[]{UUID.fromString(id)});
        }
        return Optional.of(updateCount == idList.size());
    }

    /** Convenience overload: similarity search with default SearchRequest settings. */
    public List<Document> similaritySearch(String query, String tableName) {
        return this.similaritySearch(SearchRequest.query(query), tableName);
    }

    /**
     * Similarity search over {@code tableName} using the configured distance type's
     * SQL template. The request's similarity threshold is converted to a distance
     * bound (distance = 1 - similarity) because the SQL filters on distance.
     */
    public List<Document> similaritySearch(SearchRequest request, String tableName) {
        String nativeFilterExpression = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        String jsonPathFilter = "";
        if (StringUtils.hasText(nativeFilterExpression)) {
            // Optional metadata filter applied as a jsonpath predicate.
            jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
        }

        double distance = 1.0 - request.getSimilarityThreshold();
        PGvector queryEmbedding = this.getQueryEmbedding(request.getQuery());
        return this.jdbcTemplate.query(String.format(this.getDistanceType().similaritySearchSqlTemplate, tableName, jsonPathFilter), new ExtendPgVectorStore.DocumentRowMapper(this.objectMapper), new Object[]{queryEmbedding, queryEmbedding, distance, request.getTopK()});
    }

    /**
     * Returns the distance between {@code query}'s embedding and every row of
     * {@code tableName}.
     *
     * <p>FIX: the original version ignored {@code tableName} and always queried
     * the hard-coded "vector_store" table.
     */
    public List<Double> embeddingDistance(String query, String tableName) {
        return this.jdbcTemplate.query("SELECT embedding " + this.comparisonOperator() + " ? AS distance FROM " + tableName, new RowMapper<Double>() {
            @Nullable
            public Double mapRow(ResultSet rs, int rowNum) throws SQLException {
                return rs.getDouble("distance");
            }
        }, new Object[]{this.getQueryEmbedding(query)});
    }

    /** Embeds the query text and wraps it in the pgvector client type. */
    private PGvector getQueryEmbedding(String query) {
        List<Double> embedding = this.embeddingModel.embed(query);
        return new PGvector(this.toFloatArray(embedding));
    }

    /** SQL distance operator for the configured distance type (e.g. {@code <=>} for cosine). */
    private String comparisonOperator() {
        return this.getDistanceType().operator;
    }

    /** Resolution order: explicit dimensions > model-reported dimensions > 1536 fallback. */
    int embeddingDimensions() {
        if (this.dimensions > 0) {
            return this.dimensions;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        } catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:1536", e);
        }
        return 1536;
    }

    // NOTE(review): these two nested enums are unused copies of the ones in
    // PgVectorStore (this class references PgVectorStore.PgDistanceType /
    // PgIndexType everywhere). They are kept only so that any external code
    // referring to ExtendPgVectorStore.PgDistanceType keeps compiling.
    public static enum PgDistanceType {
        EUCLIDEAN_DISTANCE("<->", "vector_l2_ops", "SELECT *, embedding <-> ? AS distance FROM %s WHERE embedding <-> ? < ? %s ORDER BY distance LIMIT ? "),
        NEGATIVE_INNER_PRODUCT("<#>", "vector_ip_ops", "SELECT *, (1 + (embedding <#> ?)) AS distance FROM %s WHERE (1 + (embedding <#> ?)) < ? %s ORDER BY distance LIMIT ? "),
        COSINE_DISTANCE("<=>", "vector_cosine_ops", "SELECT *, embedding <=> ? AS distance FROM %s WHERE embedding <=> ? < ? %s ORDER BY distance LIMIT ? ");

        public final String operator;
        public final String index;
        public final String similaritySearchSqlTemplate;

        private PgDistanceType(String operator, String index, String sqlTemplate) {
            this.operator = operator;
            this.index = index;
            this.similaritySearchSqlTemplate = sqlTemplate;
        }
    }

    public static enum PgIndexType {
        NONE,
        IVFFLAT,
        HNSW;

        private PgIndexType() {
        }
    }

    /** Maps a result row back to a Document, exposing the distance via metadata. */
    private static class DocumentRowMapper implements RowMapper<Document> {
        private static final String COLUMN_EMBEDDING = "embedding";
        private static final String COLUMN_METADATA = "metadata";
        private static final String COLUMN_ID = "id";
        private static final String COLUMN_CONTENT = "content";
        private static final String COLUMN_DISTANCE = "distance";
        private final ObjectMapper objectMapper;

        DocumentRowMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
        }

        public Document mapRow(ResultSet rs, int rowNum) throws SQLException {
            String id = rs.getString(COLUMN_ID);
            String content = rs.getString(COLUMN_CONTENT);
            PGobject pgMetadata = rs.getObject(COLUMN_METADATA, PGobject.class);
            PGobject embedding = rs.getObject(COLUMN_EMBEDDING, PGobject.class);
            Float distance = rs.getFloat(COLUMN_DISTANCE);
            Map<String, Object> metadata = this.toMap(pgMetadata);
            // Distance is surfaced to callers through the metadata map.
            metadata.put(COLUMN_DISTANCE, distance);
            Document document = new Document(id, content, metadata);
            document.setEmbedding(this.toDoubleList(embedding));
            return document;
        }

        /** Parses the stored vector text back into a boxed Double list. */
        private List<Double> toDoubleList(PGobject embedding) throws SQLException {
            float[] floatArray = (new PGvector(embedding.getValue())).toArray();
            return IntStream.range(0, floatArray.length)
                    .mapToDouble((i) -> (double) floatArray[i])
                    .boxed()
                    .toList();
        }

        /** Deserializes the metadata JSON column into a Map. */
        private Map<String, Object> toMap(PGobject pgObject) {
            String source = pgObject.getValue();
            try {
                return (Map) this.objectMapper.readValue(source, Map.class);
            } catch (JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

当我们要使用上面这个ExtendPgVectorStore进行操作时首先我们要排除掉原PgVectorStore的注入。

接着我们需要注入自己的ExtendPgVectorStore类

java 复制代码
import com.lccloud.tenderdocument.vector.ExtendPgVectorStore;
import org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreProperties;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

/**
 * Wires up the table-aware {@link ExtendPgVectorStore} in place of the stock
 * PgVectorStore bean, reusing the standard pgvector configuration properties.
 */
@Configuration
@AutoConfigureAfter(JdbcTemplateAutoConfiguration.class)
@EnableConfigurationProperties({PgVectorStoreProperties.class})
public class PgVectorConfig {

    public PgVectorConfig() {
    }

    /**
     * Vector store used for storing and searching embeddings per table.
     *
     * @param jdbcTemplate   JDBC access to the PostgreSQL database
     * @param embeddingModel model that turns text into vectors
     * @param properties     standard pgvector store configuration properties
     * @return the extended vector store bean
     */
    @Bean
    public ExtendPgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, PgVectorStoreProperties properties) {
        return new ExtendPgVectorStore(
                jdbcTemplate,
                embeddingModel,
                properties.getDimensions(),
                properties.getDistanceType(),
                properties.isRemoveExistingVectorStoreTable(),
                properties.getIndexType(),
                properties.isInitializeSchema());
    }

    /**
     * Splitter that chunks documents by token count before embedding.
     *
     * @return a default {@link TokenTextSplitter}
     */
    @Bean
    public TokenTextSplitter tokenTextSplitter() {
        return new TokenTextSplitter();
    }
}

上面这里的PgVectorStoreProperties也可以换成我们自己的类方法(这里我懒得换就用pgVectorStore自带的了)。然后我们在使用的时候就可以注入ExtendPgVectorStore进行操作了。

相关推荐
AI大模型2 天前
基于 Ollama 本地 LLM 大语言模型实现 ChatGPT AI 聊天系统
程序员·llm·ollama
未来之窗软件服务3 天前
自建知识库,向量数据库 体系建设(二)之BERT 与.NET 8
人工智能·深度学习·bert·知识库·向量数据库·仙盟创梦ide·东方仙盟
李大腾腾4 天前
3、JSON处理( n8n 节点用于交换信息的语言)
openai·workflow·ollama
陈佬昔没带相机4 天前
ollama 终于有UI了,但我却想弃坑了
人工智能·llm·ollama
两棵雪松4 天前
为什么RAG技术可以缓解大模型知识固话和幻觉问题
人工智能·rag
占星安啦4 天前
【SpringAI】9.创建本地mcp服务(演示通过mcp实现联网搜索)
springai·mcp·联网搜索·searchapi
李大腾腾5 天前
2、n8n 构建你的第一个 AI Agent
openai·agent·ollama
真就死难5 天前
适用于个人开发、中小型项目的Embedding方案(配合ChromaDB)
python·embedding·rag
一包烟电脑面前做一天6 天前
RAG实现:.Net + Ollama + Qdrant 实现文本向量化,实现简单RAG
.net·向量数据库·ai大模型·rag·ollama·qdrant·文本分块
一包烟电脑面前做一天6 天前
MCP实现:.Net实现MCP服务端 + Ollama ,MCP服务端工具调用
.net·ai大模型·ollama·mcp·mcp服务端