Rag优化 - 如何提升首字响应速度

Rag性能公式

  • Tembed(query):查询向量

  • Tretrieval:向量检索

  • Tprompt-build:上下文构建

  • Tinfra-overhead:网络/队列/IO

Embedding阶段

|----------------|-----------------|
| 问题 | 加速手段 |
| 网络往返多、一次只算一条 | 批量请求 (batch) |
| I/O 等待白白浪费 CPU | 异步并发 |
| 重复文本反复算 | Redis/KV 缓存 |
| 模型本身推理慢 | 小模型 / 量化 / 本地部署 |

  • 批量请求Embedding: 利用OpenAI嵌入接口支持批处理的特性,将多个文本一次性发送以减少总请求量。 例如可以把待嵌入的多个查询或文本段组合成数组传入单次API调用,避免逐条请求所带来的网络开销。需要注意每次请求的最大token限制并在超出时拆分批次,避免报错。

  • 并发调用与 异步处理 合理利用并发提高吞吐量。通过Go的WaitGroup等异步框架,可以在等待一个Embedding结果时并行触发其他请求。实践中OpenAI对并发请求有一定限制,过高并发可能引起排队延迟,经验上将并发控制在单-digit数量比较稳妥(例如同时5~10个请求)并监控接口返回的速率限制信息。当高并发场景下,可以采用指数退避重试策略来应对429/Rate Limit错误。异步非阻塞调用能让 CPU 空闲时间用于处理其它任务, 从而提升整体吞吐。(资源多、直接k8s搞无数个节点的可以忽略

  • 缓存Embedding结果 针对重复出现的文本,缓存其Embedding以避免重复计算。例如,对于常见问题或频繁查询,可以在首次获取Embedding后将query->embedding键值对存入内存或Redis缓存。下次遇到相同查询时直接复用缓存向量,跳过API调用,从而显著降低延迟。需要设计缓存键(可用查询字符串或其哈希)并考虑到语义相近但不完全相同的查询不会命中缓存的情况。对于文档语料,尽量预先计算并存储Embedding,避免在查询时现算。

Go 复制代码
package main

import (
        "context"
        "crypto/sha1"
        "encoding/hex"
        "encoding/json"
        "fmt"
        "os"
        "strings"
        "sync"

        "github.com/openai/openai-go"
        "github.com/openai/openai-go/option"
        "github.com/redis/go-redis/v9"
        "github.com/pkoukk/tiktoken-go"
        "golang.org/x/sync/semaphore"
)

const (
        model      = "text-embedding-3-small"
        dim        = 1536
        tokenCap   = 8191
        cacheExpiry = 86400 // 缓存过期时间(秒)
)

var (
        openaiClient *openai.Client
        redisClient  *redis.Client
        encoder      *tiktoken.Encoding
)

// 初始化客户端
func init() {
        // 初始化OpenAI客户端
        apiKey := os.Getenv("OPENAI_API_KEY")
        if apiKey == "" {
                panic("OPENAI_API_KEY环境变量未设置")
        }
        openaiClient = openai.NewClient(option.WithAPIKey(apiKey))

        // 初始化Redis客户端
        redisClient = redis.NewClient(&redis.Options{
                Addr:     "localhost:6379",
                Password: "", // 无密码
                DB:       0,  // 默认DB
        })

        // 初始化tiktoken编码器
        var err error
        encoder, err = tiktoken.EncodingForModel(model)
        if err != nil {
                panic(fmt.Sprintf("初始化编码器失败: %v", err))
        }
}

// 生成缓存键
func key(text string) string {
        // 处理文本:去重空格、转为小写
        processed := strings.ToLower(strings.Join(strings.Fields(text), " "))
        // 计算SHA1哈希
        h := sha1.New()
        h.Write([]byte(processed))
        return "emb:" + hex.EncodeToString(h.Sum(nil))
}

// 批量生成嵌入向量
func embedBatch(ctx context.Context, batch []string) ([][]float32, error) {
        resp, err := openaiClient.Embeddings.Create(ctx, openai.EmbeddingCreateParams{
                Model: openai.F(model),
                Input: batch,
        })
        if err != nil {
                return nil, fmt.Errorf("嵌入API调用失败: %v", err)
        }

        embeddings := make([][]float32, len(batch))
        for i, data := range resp.Data {
                embeddings[i] = data.Embedding
        }
        return embeddings, nil
}

// 生成文本嵌入(带缓存和并发控制)
func Embed(ctx context.Context, texts []string, concurrency int, tokenCap int) ([][]float32, error) {
        out := make([][]float32, len(texts)) // 结果集(按输入顺序)
        unprocessed := make([]string, 0)     // 未命中缓存的文本
        unprocessedIndices := make([]int, 0) // 未命中缓存的文本在原数组中的索引

        // 先检查缓存
        for i, txt := range texts {
                key := key(txt)
                val, err := redisClient.Get(ctx, key).Result()
                if err == nil {
                        // 缓存命中,解析结果
                        var vec []float32
                        if err := json.Unmarshal([]byte(val), &vec); err != nil {
                                return nil, fmt.Errorf("解析缓存数据失败: %v", err)
                        }
                        out[i] = vec
                        continue
                }
                if err != redis.Nil {
                        // 非空错误(Redis异常)
                        return nil, fmt.Errorf("Redis查询失败: %v", err)
                }
                // 缓存未命中,加入待处理列表
                unprocessed = append(unprocessed, txt)
                unprocessedIndices = append(unprocessedIndices, i)
        }

        if len(unprocessed) == 0 {
                // 全部命中缓存,直接返回
                return out, nil
        }

        // 批量分组(按token限制)
        var batches [][]string
        var batchIndices [][]int // 每个批次对应的原索引
        currentBatch := make([]string, 0)
        currentIndices := make([]int, 0)
        currentTokens := 0

        for i, txt := range unprocessed {
                tokens := len(encoder.Encode(txt, nil))
                // 如果加上当前文本超过token限制,且当前批次非空,则新建批次
                if currentTokens+tokens > tokenCap && len(currentBatch) > 0 {
                        batches = append(batches, currentBatch)
                        batchIndices = append(batchIndices, currentIndices)
                        currentBatch = []string{txt}
                        currentIndices = []int{unprocessedIndices[i]}
                        currentTokens = tokens
                } else {
                        currentBatch = append(currentBatch, txt)
                        currentIndices = append(currentIndices, unprocessedIndices[i])
                        currentTokens += tokens
                }
        }
        // 加入最后一个批次
        if len(currentBatch) > 0 {
                batches = append(batches, currentBatch)
                batchIndices = append(batchIndices, currentIndices)
        }

        // 并发处理批次(使用信号量控制并发数)
        sem := semaphore.NewWeighted(int64(concurrency))
        var wg sync.WaitGroup
        var errMu sync.Mutex
        var globalErr error

        for i, batch := range batches {
                indices := batchIndices[i]
                wg.Add(1)
                // 申请信号量
                if err := sem.Acquire(ctx, 1); err != nil {
                        errMu.Lock()
                        globalErr = fmt.Errorf("信号量获取失败: %v", err)
                        errMu.Unlock()
                        wg.Done()
                        break
                }

                go func(batch []string, indices []int) {
                        defer wg.Done()
                        defer sem.Release(1)

                        if globalErr != nil {
                                return
                        }

                        // 生成嵌入向量
                        embeddings, err := embedBatch(ctx, batch)
                        if err != nil {
                                errMu.Lock()
                                globalErr = err
                                errMu.Unlock()
                                return
                        }

                        // 写入结果集并缓存
                        for j, idx := range indices {
                                vec := embeddings[j]
                                out[idx] = vec

                                // 写入Redis缓存
                                key := key(batch[j])
                                data, err := json.Marshal(vec)
                                if err != nil {
                                        errMu.Lock()
                                        globalErr = fmt.Errorf("向量序列化失败: %v", err)
                                        errMu.Unlock()
                                        return
                                }
                                if err := redisClient.Set(ctx, key, data, cacheExpiry).Err(); err != nil {
                                        errMu.Lock()
                                        globalErr = fmt.Errorf("Redis写入失败: %v", err)
                                        errMu.Unlock()
                                        return
                                }
                        }
                }(batch, indices)
        }

        wg.Wait()
        if globalErr != nil {
                return nil, globalErr
        }

        return out, nil
}

func main() {
        // 示例用法
        ctx := context.Background()
        texts := []string{"hello world", "golang async example"}
        embeddings, err := Embed(ctx, texts, 5, 8191)
        if err != nil {
                fmt.Printf("错误: %v\n", err)
                return
        }
        fmt.Printf("生成的嵌入向量数量: %d\n", len(embeddings))
}

向量检索阶段(Milvus)

|------------------|--------------------|
| 痛点 | 加速手段 |
| 全库暴力扫 | ANN 索引(HNSW / IVF) |
| 海量数据串行查 | 批量 search + 多副本加载 |
| query 多但每次只看少量数据 | 分区 / 过滤 |
| CPU 饱和 | GPU or 水平扩容 |

  • 使用近似邻居索引 ( ANN ): 避免对大型语料库进行逐条精确暴力搜索,可改用近似最近邻算法构建索引,例如IVF、HNSW等,以大幅提升检索速度。实践表明,对于百万级向量数据,HNSW索引在保持较高召回的同时能将查询延迟降低到毫秒级。Milvus官方也推荐在需要高性能检索时选用HNSW索引。如果使用IVF索引,可调节细分簇数量(nlist)和查询探测范围(nprobe):增大 nlist 提高召回率,减少 nprobe 缩短查询时间,从而在速度与准确率间取得平衡。索引构建时的参数(如HNSW的efConstruction或IVF的分桶参

  • 优化数据分片与过滤:利用 Milvus 的分区和过滤功能缩小检索范围,从而减少每次查询需要遍历的向量数量。如果先验知道查询只涉及某部分语料(例如按来源、时间分区的数据),可将向量集合按属性切分成分区,查询时指定相应分区检索,避免全库扫描。对于规模超大的向量集合,合理分片(sharding)有助于降低单机内检索延迟。同时剔除过期或低相关的向量(例如对知识库定期清理无用数据)可减小索引规模,使查询更高效。

  • 批量查询与并发连接:Milvus 支持在一次请求中执行批量搜索(即传入多个查询向量一起检索),这相比逐一查询能减少网络开销和调度开销,适用于需要同时回答多子问题或多用户批量请求的场景。对于并发请求量高的系统,可在客户端维护连接池或使用多线程 / 协程并发查询 Milvus。Milvus 2.x 的无锁架构对并发查询有良好支持,但仍需确保后端资源充足(CPU / 内存不成为瓶颈)。如果 QPS 需求特别高,增加检索副本:Milvus 允许在内存中加载数据的多个副本来提高并行查询能力。通过在Collection.load()时设置replica_number>1,可以启用多副本使查询负载分摊到不同 Query Node,从而提升整体吞吐。例如,将副本数设为 4 可显著提高 QPS 上限。同样,需要搭配增加 Milvus 后端的 QueryNode 实例数和计算资源,以充分利用副本带来的并行度。

  • 系统配置与硬件加速:调整 Milvus 的配置以匹配性能需求。例如,在保证召回的前提下将搜索参数efSearch(对 HNSW)或nprobe(对 IVF)设为较小值以加快查询。确保在查询前调用collection.Load()将数据加载至内存,并设置合适的cache_config(Milvus 会将常用数据页缓存在内存)。如果数据规模巨大或需要亚毫秒级查询延迟,可考虑 GPU 加速:使用 Milvus 的 GPU 版本或将向量数据托管到支持 GPU 的向量引擎上,以利用 GPU 的并行计算能力执行向量点积运算。不过 GPU 方案需要权衡部署成本,通常在超大规模或低延迟(如实时推荐)场景才需要。总体而言,充分利用 Milvus 的并行和内存特性。

Go 复制代码
package main

import (
        "context"
        "fmt"

        "github.com/milvus-io/milvus-sdk-go/v2/client"
        "github.com/milvus-io/milvus-sdk-go/v2/entity"
        "github.com/milvus-io/milvus-sdk-go/v2/schema"
)

const (
        // 与Python代码中的DIM保持一致(原代码中为1536,需根据实际场景确认)
        dim        = 1536
        collection = "rag_docs" // 集合名
)

func main() {
        // 初始化上下文(可用于超时控制)
        ctx := context.Background()

        // 1. 连接Milvus服务(对应Python的connections.connect)
        c, err := client.NewClient(
                context.Background(),
                client.Config{
                        Address: "127.0.0.1:19530", // Milvus服务地址
                },
        )
        if err != nil {
                panic(fmt.Sprintf("连接Milvus失败: %v", err))
        }
        defer c.Close() // 程序退出时关闭连接

        // 2. 定义集合结构(对应Python的FieldSchema和CollectionSchema)
        // 字段定义:id(主键,自增)、vec(向量)、txt(文本)
        fields := []schema.Field{
                // id字段:INT64类型,主键,自动生成
                schema.NewField().
                        WithName("id").
                        WithDataType(entity.FieldTypeInt64).
                        WithIsPrimaryKey(true).
                        WithAutoID(true),
                // vec字段:FLOAT_VECTOR类型,维度为dim
                schema.NewField().
                        WithName("vec").
                        WithDataType(entity.FieldTypeFloatVector).
                        WithTypeParams(map[string]string{
                                "dim": fmt.Sprintf("%d", dim), // 向量维度
                        }),
                // txt字段:VARCHAR类型,最大长度1024
                schema.NewField().
                        WithName("txt").
                        WithDataType(entity.FieldTypeVarChar).
                        WithTypeParams(map[string]string{
                                "max_length": "1024", // 最大长度
                        }),
        }

        // 创建集合(如果不存在)
        if err := c.CreateCollection(
                ctx,
                schema.NewSchema().WithName(collection).WithFields(fields...),
                1, // 分片数量(与Python默认一致)
        ); err != nil {
                // 忽略"集合已存在"的错误,其他错误需处理
                if !client.IsCollectionExistError(err) {
                        panic(fmt.Sprintf("创建集合失败: %v", err))
                }
        }

        // 3. 创建HNSW索引(仅创建一次,对应Python的create_index)
        // 先检查是否已存在索引
        hasIndex, err := hasIndexOnField(ctx, c, collection, "vec")
        if err != nil {
                panic(fmt.Sprintf("检查索引失败: %v", err))
        }
        if !hasIndex {
                // 定义HNSW索引参数:index_type=HNSW,metric_type=IP,M=16,efConstruction=128
                indexParams := map[string]string{
                        "index_type":      string(entity.IndexTypeHNSW),
                        "metric_type":     string(entity.MetricTypeIP), // 内积(IP)
                        "M":               "16",                        // HNSW的M参数
                        "efConstruction":  "128",                       // 构建时的ef参数
                }
                if err := c.CreateIndex(
                        ctx,
                        collection,   // 集合名
                        "vec",        // 向量字段名
                        indexParams,  // 索引参数
                        client.WithIndexName("vec_idx"), // 索引名(可选)
                ); err != nil {
                        panic(fmt.Sprintf("创建HNSW索引失败: %v", err))
                }
                fmt.Println("HNSW索引创建成功")
        } else {
                fmt.Println("HNSW索引已存在,跳过创建")
        }

        // 4. 加载集合到内存,并设置4个副本(对应Python的col.load(replica_number=4))
        if err := c.LoadCollection(
                ctx,
                collection,
                client.WithReplicaNumber(4), // 副本数量
        ); err != nil {
                panic(fmt.Sprintf("加载集合到内存失败: %v", err))
        }
        fmt.Println("集合加载成功,副本数: 4")

        // 5. 实现搜索功能(对应Python的search函数)
        // 示例:搜索向量
        testVecs := [][]float32{
                make([]float32, dim), // 示例向量(实际使用时替换为真实向量)
        }
        result, err := Search(ctx, c, testVecs, 5, 64)
        if err != nil {
                panic(fmt.Sprintf("搜索失败: %v", err))
        }
        fmt.Printf("搜索结果: %+v\n", result)
}

// 检查指定字段是否已创建索引
func hasIndexOnField(ctx context.Context, c client.Client, collName, fieldName string) (bool, error) {
        indexes, err := c.DescribeIndex(ctx, collName, fieldName)
        if err != nil {
                // 如果索引不存在,返回特定错误,这里视为"未创建索引"
                if client.IsIndexNotExistError(err) {
                        return false, nil
                }
                return false, err
        }
        return len(indexes) > 0, nil
}

// Search 搜索向量,返回匹配的文本列表(对应Python的search函数)
// vecs: 待搜索的向量列表
// k: 返回TopK结果
// ef: HNSW搜索时的ef参数
func Search(ctx context.Context, c client.Client, vecs [][]float32, k, ef int) ([][]string, error) {
        // 构建搜索参数:metric_type=IP,ef=ef
        searchParams := map[string]string{
                "metric_type": string(entity.MetricTypeIP),
                "ef":          fmt.Sprintf("%d", ef),
        }

        // 执行搜索
        req := client.SearchReq{
                CollectionName: collection,
                PartitionNames: []string{}, // 搜索所有分区
                Expr:           "",         // 无过滤条件
                OutputFields:   []string{"txt"}, // 需要返回的字段(文本内容)
                 vectors:        entity.FloatVector(vecs), // 待搜索的向量
                VecFieldName:   "vec",       // 向量字段名
                Params:         searchParams,
                Limit:          k,           // 返回TopK
        }

        resp, err := c.Search(ctx, req)
        if err != nil {
                return nil, fmt.Errorf("搜索请求失败: %v", err)
        }

        // 解析结果:提取每个向量的匹配文本
        result := make([][]string, 0, len(resp))
        for _, hits := range resp { // hits是单个查询向量的匹配结果
                texts := make([]string, 0, len(hits))
                for _, hit := range hits { // hit是单个匹配项
                        // 从实体中获取"txt"字段的值
                        txt, ok := hit.Entity.GetField("txt").(string)
                        if !ok {
                                return nil, fmt.Errorf("获取字段txt失败,类型不匹配")
                        }
                        texts = append(texts, txt)
                }
                result = append(result, texts)
        }

        return result, nil
}

缓存优化

除了上文说的Embedding缓存优化,还有很多地方可以需要:

引入缓存层(Redis等): 在系统中增加缓存机制,用空间换时间,避免重复计算开销。缓存可存在多个层次:

  1. Embedding缓存 **:**缓存常见查询文本的向量表示,下次出现直接复用;缓存文档向量同样重要,静态语料库可以离线算好全部向量并存入Milvus或KV存储。
  2. 检索结果缓存:对于经常被查询的问题,其检索到的文档列表往往相同,可缓存这些文档ID列表,下次查询时直接使用缓存结果而无需访问向量库。
  3. 答案缓存 :对于高度重复且答案固定的提问(如FAQ),可以直接缓存上一次的完整回答文本。下次相同提问立即返回缓存答案,实现近乎零延迟响应。需要注意对于有时效性的数据(如新闻、股价),缓存过久可能失准,需设置适当 TTL **或在数据更新时主动清除相关缓存。**使用Redis这类内存KV存储可以提供毫秒级的读取性能,适合做共享缓存蹭。同时通过哈希Key(例如将query字符串规范化后哈希)索引缓存内容,并采用LRU等策略淘汰冷门条目。总之,缓存系统能大幅度减少重复调用OpenAI API和向量库的次数,从架构上加快响应。
python 复制代码
# Embedding缓存
EMB_TTL = timedelta(days=30)       # 静态文档可更长

async def get_embed_cached(text: str):
    key = f"emb:{_hash(text)}"
    if (vec := _get(key)):
        return vec                 # 命中缓存
    vec = (await embed([text]))[0]
    _set(key, vec, EMB_TTL)
    return vec

# 检索结果缓存
SEARCH_TTL = timedelta(days=1)     # 语料相对稳定,可按需调整

def search_cached(question: str, q_vec, k=3):
    key = f"srch:{_hash(question)}:{k}"
    if (hits := _get(key)):
        return hits
    hits = search([q_vec], k=k)[0]         # 调 Milvus
    _set(key, hits, SEARCH_TTL)
    return hits

# 答案缓存
ANS_TTL = timedelta(days=7)        # FAQ 可更长;时效数据可减小

async def answer_cached(question: str):
    key = f"ans:{_hash(question)}"
    if (ans := _get(key)):
        return ans                 # 秒级返回

    # ------ 缓存未命中:正常 RAG 流程 ------
    q_vec  = await get_embed_cached(question)
    docs   = search_cached(question, q_vec, k=3)
    prompt = build_prompt(question, docs)

    # 不需要流式时可直接用 openai.ChatCompletion
    chunks = []
    async for tok in stream_chat(prompt):   # 自行实现 yield token
        chunks.append(tok)
    answer = "".join(chunks)

    _set(key, answer, ANS_TTL)
    return answer
相关推荐
紫荆鱼7 小时前
设计模式-命令模式(Command)
c++·后端·设计模式·命令模式
编码追梦人7 小时前
深耕 Rust:核心技术解析、生态实践与高性能开发指南
开发语言·后端·rust
朝新_8 小时前
【SpringBoot】详解Maven的操作与配置
java·spring boot·笔记·后端·spring·maven·javaee
绝无仅有8 小时前
某教育大厂面试题解析:MySQL索引、Redis缓存、Dubbo负载均衡等
vue.js·后端·面试
sean8 小时前
开发一个自己的 claude code
前端·后端·ai编程
追逐时光者9 小时前
C#/.NET/.NET Core技术前沿周刊 | 第 59 期(2025年10.20-10.26)
后端·.net
盖世英雄酱5813610 小时前
java深度调试【第三章内存分析和堆内存设置】
java·后端
007php00710 小时前
京东面试题解析:同步方法、线程池、Spring、Dubbo、消息队列、Redis等
开发语言·后端·百度·面试·职场和发展·架构·1024程序员节
程序定小飞10 小时前
基于springboot的电影评论网站系统设计与实现
java·spring boot·后端