RAG详解

原理:在构建大语言模型智能体时,有时需要给大模型提供外部文本资料。需要将待查询文本,切片后存储进向量数据库。通过余弦相似度匹配,找到与用户问题语义最接近的文本片段。

步骤:切片、索引,召回、重排

切片

将文本切片后保存到列表中即可,切片方式根据实际情况确定。

python 复制代码
with open(file_dir,"r") as file
    content = file.read()
    contents = content.split("\n\n")

索引,通过算法生成文本的向量表示,一般来说,一个任意长度的文本片段都能生成一个固定长度的向量,不同文本片段语义越相近,它们的向量距离也越相近。在Python中,可以使用SentenceTransformer库来处理。

SentenceTransformer库算法原理:

1. SentenceTransformer 概述

SentenceTransformer 不是一种全新的架构,而是一个封装好的框架 。它把一个预训练的 Transformer 编码器(比如 BERT、RoBERTa)和一个池化策略 组合起来,让模型能把任意长度的句子映射成固定长度的稠密向量,并且这向量能保留语义------意思相近的句子,向量在空间中距离近。

核心公式可以理解为:

text

复制代码
sentence_vector = pooling( transformer(sentence) )  + optional_normalization

它基于 暹罗网络(Siamese Network) 的思路,训练时用成对或三元组的句子,让模型学会拉近相似句、推远不相似句。


2. 管线分解:从文本到向量

2.1 Tokenization(分词)

输入一个句子,例如 "你好世界",首先会经过一个与底层 Transformer 配套的 Tokenizer

  • WordPiece (BERT 类)或 SentencePiece 分词。

  • 加上特殊 token:[CLS](分类标记),[SEP](分隔符)。

  • 映射成 input_ids(每个 token 在词表中的整数 ID)。

  • 生成 attention_mask(真实 token 为 1,填充为 0),用于后续忽略 padding。

  • 对于中文模型 shibing624/text2vec-base-chinese,它的底层一般是 BERT 架构MiniLM 结构,词汇表包含简繁中文、英文等。

代码层面,库内部调用 transformers.AutoTokenizer.from_pretrained() 加载。

2.2 Transformer 编码

input_idsattention_mask 送入 Transformer 的 Encoder 堆栈。以 BERT-base 为例:

  • 词嵌入 + 位置嵌入 + 段落嵌入 → 输入向量 (batch_size, seq_len, 768)

  • 经过 12 层 Transformer Block(Self-Attention + FFN + 残差 + LayerNorm)

  • 输出 hidden states ,形状仍是 (batch_size, seq_len, 768)

每个 token 位置对应一个 768 维的上下文表示。但是句子级别的表示不能直接用这整个矩阵,需要池化成一个向量。

2.3 池化(Pooling)------关键细节

这是 SentenceTransformer 与直接用分类模型最大的区别。常见池化方式:

  • Mean Pooling(默认)

    对所有 token 的 hidden states 求平均,但 只考虑 attention_mask=1 的位置 ,忽略 [PAD]

    伪代码:

    python

    复制代码
    # token_embeddings: (batch_size, seq_len, 768)
    # attention_mask: (batch_size, seq_len)  , 扩展维度后相乘
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
    sum_mask = input_mask_expanded.sum(dim=1).clamp(min=1e-9)
    sentence_embeddings = sum_embeddings / sum_mask

    这样得到的向量表达了整个句子的"平均含义"。目前大多数 SOTA 模型用 Mean Pooling。

  • CLS Pooling

    只取 [CLS] 位置的输出向量。BERT 预训练时 [CLS] 用于 NSP,但对于语义相似度并不总是最优。

  • Max Pooling

    取每个维度在所有 token 中的最大值,较早期使用,现在较少。

你加载的 shibing624/text2vec-base-chinese 使用的就是 Mean Pooling 。它的模型配置里有一个 pooling_mode_mean_tokens: true

2.4 向量后处理

  • 归一化 :很多模型(包括这个)会对池化后的向量做 L2 归一化 ,使向量模长为 1。

    这样两个向量的点积 就等于余弦相似度,计算更快。

  • 可选的 Dense 层 :有些模型在池化后会加一个线性层将维度从 768 压缩到比如 512 或 256(用于减小存储),但 text2vec-base-chinese 没有,直接输出 768 维。


3. 训练时如何让向量具有语义?

模型不是天生就输出好的句向量。预训练的 BERT 只做 token 级别的 MLM 和 NSP,句向量通常很糟糕(各向异性,分布狭窄)。SentenceTransformers 的价值在于提供了一套微调优化目标

  • Contrastive Loss(对比损失)

    输入一对句子 (A, B),还有标签 0/1(不相似/相似)。计算它们的向量距离(如欧氏距离或余弦距离),让相似对距离小,不相似对距离 > margin。

  • Triplet Loss(三元组损失)

    输入 (anchor, positive, negative),让 dist(anchor, pos) < dist(anchor, neg) - margin

  • CoSENT Loss 或其他

    专为中文优化的损失函数,直接用余弦相似度排序。

  • 训练数据 :通常来自 NLI(自然语言推理)、QA pair、同义句数据集等。

    例如 shibing624/text2vec-base-chinese 使用了大量中文 STS(Semantic Textual Similarity)、NLI、以及 paraphrase 数据,以 CoSENT 目标训练。

于是微调后,句子向量在空间中呈球形均匀分布,且一个方向对应一种语义。


4. 关键功能与对应的技术实现

4.1 编码(encode

python

复制代码
embedding_model.encode(["句子1", "句子2"], 
                       batch_size=32, 
                       show_progress_bar=True,
                       convert_to_numpy=True)

内部流程:

  1. 对每个列表元素 tokenize,动态 padding 到 batch 内最长。

  2. 送入模型(Transformer + Pooling)。

  3. 对输出向量做 L2 归一化(如果模型设置了 normalize_embeddings: true)。

  4. 返回 numpy.ndarray 或 list。

你代码里的 embed_chunk 就是简单封装,把一大段文本传入,返回一个 List[float],这是一条 embedding。

4.2 语义相似度计算

得到向量 uv 后(均已归一化),余弦相似度就是 dot(u, v)。库提供了 util.cos_sim,但本质上就是矩阵乘法。

4.3 语义搜索

给定一个查询向量 q 和大量的文档向量矩阵 M(shape: (num_docs, dim)),计算 q @ M.T,然后 torch.topk 取出最大得分的几个索引,对应最相关文档。

SentenceTransformers 提供的 util.semantic_search 一次可以处理多个查询,支持按组划分(比如每个文档有多个 chunk,可以做 max pooling 取最佳 chunk 得分)。

4.4 聚类 / 分类

向量可以直接喂给 KMeans 或 ANN 索引(如 Faiss、HNSWlib)做聚类,或者训练一个简单的逻辑回归做文本分类。


5. 深入 text2vec-base-chinese 的技术细节

模型卡信息

  • 基座:nreimers/MiniLM-L6-H384-uncased 的蒸馏版?其实 shibing624/text2vec-base-chinese 是基于 bert-base-chinese 架构,12 层,768 维度 ,用 Mean Pooling

  • 训练数据:包含中文 STS-B,LCQMC,QQP 等,还有自建的口语相似数据。

  • 损失:CoSENT Loss + 对比学习。

  • 配置:sentence_bert_config.json 中指定了 pooling_mode_mean_tokens: truepooling_mode_cls_token: false

  • 输出:768 维,已 L2 归一化。

这意味着 :你调用 embed_chunk(chunk) 后得到的每个向量可以直接用来比较余弦相似度,值域 [-1,1]。在 RAG 场景中,你可以构建一个 numpy 矩阵或 Faiss 索引,查询时计算相似度并取 Top-K。


6. 工程上的一些细节

  • 批处理encode 支持多句子同时处理,可充分利用 GPU 并行。如果句子太长,会截断到 max_seq_length(默认 256 或 512,可通过 model.get_max_seq_length() 查看,一般在配置文件定义)。

  • 长文本处理 :你之前的 split_into_chunks 按空行切块,但每个块仍需保证不超过最大长度,否则会自动截断,丢失尾部信息。更好的做法是使用重叠滑动窗口切分。

  • 归一化 :如果模型本身做了归一化,存储向量时无需再归一化;如果没有,可以在检索前 sklearn.preprocessing.normalize

  • 微调自己的模型 :可以使用 SentenceTransformers 的 InputExampleSentenceTransformerTrainer 继续训练,以适应特定领域。

如何使用SentenceTransformers

1. 它能做什么

把任意文本(单词、句子、段落)转成一个固定长度的浮点数数组(向量)。语义越相近的文本,它们的向量点积/余弦相似度越高。


2. 快速上手的最小工程单元

安装

bash

复制代码
pip install sentence-transformers

第一次加载模型时会自动下载,后续使用缓存。

加载模型(只需一次)

python

复制代码
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("shibing624/text2vec-base-chinese")
# 可加参数 device='cuda' 指定 GPU,不加会自动选。

模型选型建议

  • 中文通用:shibing624/text2vec-base-chinese(平衡质量与速度)

  • 轻量快速:all-MiniLM-L6-v2(英文为主,中文一般)

  • 多语言:paraphrase-multilingual-MiniLM-L12-v2(中英混合场景)

  • 你本地也可提前把模型下载到 ./local_model,加载时传路径即可。


3. 编码(encode)------ 你最常用的方法

python

复制代码
embeddings = model.encode(
    ["文本1", "文本2", ...],
    batch_size=32,          # 一次处理多少条,按显存/内存调
    show_progress_bar=True, # 数据量大时看进度
    convert_to_numpy=True,  # 返回 numpy 数组,方便存盘和计算
    normalize_embeddings=True  # 强制输出 L2 归一化向量(推荐)
)

参数要点

  • batch_size:越大越快(GPU),但别爆显存;纯 CPU 推理设 8~16 即可。

  • normalize_embeddings:最好设为 True(模型配置里可能已默认归一化,但显式开启更稳妥)。归一化后,直接用点积算余弦相似度。

  • convert_to_numpy:如果你后面接 Faiss/NumPy 存索引,返回 numpy 更方便。


python

复制代码
chunks = split_into_chunks(doc_file)  # 得到 List[str]

千万别一条一条调 embed_chunk ,那样效率极低。正确做法是批量编码

python

复制代码
def embed_chunks(chunks: List[str], model) -> np.ndarray:
    return model.encode(chunks, batch_size=64, show_progress_bar=True,
                        normalize_embeddings=True, convert_to_numpy=True)

一次把整个 chunks 列表丢进去,内部自动分批、动态 padding,比循环快几十倍。


5. 向量存哪里?怎么检索?

存储

最简单的做法,生成立即存成 .npy

python

复制代码
import numpy as np
np.save("chunk_embeddings.npy", embeddings)

同时把 chunks 原文用 JSON lines 存好,保证索引对齐。

检索(单机小规模,<10万条)

直接用库自带的语义搜索:

python

复制代码
from sentence_transformers import util

query_emb = model.encode(["你的问题"], convert_to_numpy=True, normalize_embeddings=True)
# corpus_embeddings 是 (num_chunks, dim) 的 numpy 数组
hits = util.semantic_search(query_emb, corpus_embeddings, top_k=5)
# hits[0] 是列表,每个元素 {'corpus_id': idx, 'score': 余弦相似度}

内部就是矩阵乘法,速度很快。

大规模检索(百万级+)

embeddings 喂给 Faiss:

python

复制代码
import faiss
index = faiss.IndexFlatIP(dim)  # IP = 内积,因为归一化了等价余弦
index.add(embeddings)
D, I = index.search(query_emb, 5)

faiss 可建 GPU 索引、IVF 倒排索引压缩,是生产环境标配。


6. 长文本截断问题(工程老坑)

模型有最大输入长度限制,text2vec-base-chinese 通常为 512 个 token(一个汉字大致 1~2 个 token)。

  • 若你的 chunk 过长(按空行分块后某个段落特别大),超长部分会被直接截断丢弃

  • 解决方案

    • 分块时增加长度检查,用 len(model.tokenizer.encode(chunk)) 判断 token 数。

    • 对超长块做滑动窗口切分(如重叠 50 个 token)。

    • 或使用支持长文本的模型(如 text2vec-large-chinese 支持到 1024),但资源消耗更大。


7. 性能优化清单

  1. 固定模型对象:加载一次,全局复用,不要反复加载。

  2. 批处理 :永远攒一批再 encode,不要单条处理。

  3. GPU 加速model = SentenceTransformer('...', device='cuda'),比 CPU 快 10~50 倍。

  4. 数据预处理:对原文做清洗(去多余空白、特殊符号)能提高向量质量。

  5. 向量维度选择:768 维通常够用,存储敏感可换小模型(如 384 维),牺牲一点精度。


8. 工程化使用模板(完整流程)

python

复制代码
from sentence_transformers import SentenceTransformer, util
import numpy as np
import json

# 1. 加载模型(全局单例)
model = SentenceTransformer("shibing624/text2vec-base-chinese", device="cuda")

# 2. 处理文档
with open("doc.txt", "r") as f:
    text = f.read()
chunks = [c.strip() for c in text.split("\n\n") if c.strip()]
# 可选:过滤过短或过长块

# 3. 批量向量化
embeddings = model.encode(chunks, batch_size=64, show_progress_bar=True,
                          normalize_embeddings=True, convert_to_numpy=True)

# 4. 本地持久化
np.save("embeddings.npy", embeddings)
with open("chunks.json", "w", encoding="utf-8") as f:
    json.dump(chunks, f, ensure_ascii=False)

# 5. 查询
query = "你的问题"
q_emb = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)
hits = util.semantic_search(q_emb, embeddings, top_k=3)

for hit in hits[0]:
    idx = hit['corpus_id']
    print(f"相似度:{hit['score']:.4f}")
    print(chunks[idx][:200])

这个模板已经可以直接用在你的项目中,把你的 embed_chunk 替换成批量逻辑即可。

前面的工作是把文档切成了块,并用 SentenceTransformer 把这些块变成了向量(一长串浮点数)。接下来你需要存起来,方便以后用问题去搜出最相关的块。

chromadb 就是专门干这件事的:向量数据库。它把你的文本和对应的向量一起存好,并内置了快速的相似度搜索。

我帮你把整个工程思路和代码补全,解释都写在注释里,看完就能用。


1. 为什么需要 ChromaDB,而不是直接用 .npy 存?

少量数据时用 numpy 存向量再自己写个循环算相似度是可以的。但一旦数据量变大,或者你希望有更工程化的功能,ChromaDB 的价值就出来了:

  • 自带的快速检索:不用自己每次从头扫描全部向量,内部有索引优化(比如 HNSW)。

  • 元数据管理:除了向量,还能顺便存原文、标题、来源等过滤信息,以后可以按条件筛。

  • 持久化:默认用内存模式(重启就没了),但换个 Client 就能存到磁盘,服务重启后数据还在。

  • 简单 API:增、查、删就几个方法,比手写一堆矩阵运算代码方便很多。

你的代码里用了 EphemeralClient,这是临时模式,数据只存在内存中,进程结束就消失------适合学习和原型验证。如果要落地到可保存的项目,后面会讲怎么改成持久化。


2. 代码里这两行做了什么?

python

复制代码
chromadb_client = chromadb.EphemeralClient()
chromadb_collection = chromadb_client.get_or_create_collection(name="default")
  • EphemeralClient():创建一个"一次性"客户端,数据存在内存里,程序一停就没了。

  • get_or_create_collection("default") :拿到一个叫 "default"集合。可以理解成数据库里的一张表,专门存一类文档的向量。如果这个集合不存在,它会自动创建;如果已经存在,就直接用。

一个集合就像一个 Excel 表格,每一行是一条数据,包含:ID、文本(document)、向量(embedding)、可选的元数据(metadata)。


3. save_embeddings 函数怎么写?

这个函数的作用是把刚刚切好的文本块(chunks)和它们对应的向量(embeddings)存到 ChromaDB 里。

python

复制代码
from typing import List

def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
    """将文本块和对应的向量存入 ChromaDB 集合"""
    # 1. 为每个块生成唯一的 ID
    ids = [f"chunk_{i}" for i in range(len(chunks))]
    
    # 2. 批量添加
    chromadb_collection.add(
        documents=chunks,          # 原始文本
        embeddings=embeddings,     # 对应的向量,形状: (数量, 维度)
        ids=ids                    # 每个块的唯一标识符
    )

注意

  • embeddings 参数是你用 model.encode() 得到的 List[List[float]]numpy 数组(ChromaDB 两种都接受)。

  • ids 必须每个都不同,否则会覆盖旧数据。常用 uuidchunk_序号

  • 你也可以顺手加上 metadatas 参数,放一些比如"来源文件"、"页码"等辅助信息,以后可以按条件过滤。


4. 如何用存储好的数据做查询?

等你存好了,就可以用自然语言问题去搜索最相关的文本块了。

python

复制代码
def query_chunks(question: str, model, n_results: int = 3):
    """输入问题,返回最相关的 n 个文本块"""
    # 1. 先把问题转成向量
    query_emb = model.encode([question], normalize_embeddings=True).tolist()
    
    # 2. 在集合里搜索
    results = chromadb_collection.query(
        query_embeddings=query_emb,  # 注意这里是列表的列表
        n_results=n_results,         # 要返回几条结果
        include=["documents", "distances"]  # 要带回哪些字段
    )
    
    # 3. 解析结果
    # results 是一个字典,结构如下:
    # {
    #     'ids': [['chunk_0', 'chunk_5', ...]],   # 最外层的列表对应一次查询的 batch
    #     'documents': [['文本1', '文本2', ...]],  # 因为我们只查了一个问题,所以取 [0]
    #     'distances': [[0.2, 0.5, ...]]
    # }
    
    docs = results['documents'][0]     # 返回的文本列表
    distances = results['distances'][0] # 距离(默认是余弦距离,越小越相关)
    
    return docs, distances

返回值解释

  • distances:ChromaDB 默认用余弦距离(= 1 - 余弦相似度),所以值越小表示越相似(0 表示完全一样)。

  • documents:对应的原始文本块,直接可以喂给大模型当 prompt 上下文。


5. 从临时模式升级到持久化存储

如果你希望重启程序后数据还在,只需要把客户端换成 PersistentClient,并给它一个本地文件夹路径。

python

复制代码
import chromadb

# 持久化到 ./my_vector_db 文件夹
chromadb_client = chromadb.PersistentClient(path="./my_vector_db")
chromadb_collection = chromadb_client.get_or_create_collection(name="default")

以后每次启动程序,都能拿到以前存过的集合和数据。


6. 把你的整个流程串起来

结合你之前所有代码,一个完整的"文档入库→查询"例子如下:

python

复制代码
from typing import List
from sentence_transformers import SentenceTransformer
import chromadb
import os

# ---- 初始化 ----
# 1. 模型(加载一次,全局复用)
model = SentenceTransformer("shibing624/text2vec-base-chinese")

# 2. 向量数据库(改一行就能换成持久化)
# chromadb_client = chromadb.PersistentClient(path="./db")  # 永久保存
chromadb_client = chromadb.EphemeralClient()                # 临时测试
collection = chromadb_client.get_or_create_collection("my_docs")

# ---- 你之前的切块函数 ----
def split_into_chunks(doc_file: str) -> List[str]:
    with open(doc_file, 'r', encoding='utf-8') as f:
        content = f.read()
    return [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]

# ---- 批量向量化并存储 ----
def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
    ids = [f"chunk_{i}" for i in range(len(chunks))]
    collection.add(documents=chunks, embeddings=embeddings, ids=ids)

# ---- 查询 ----
def search(question: str, top_k: int = 3):
    query_emb = model.encode([question], normalize_embeddings=True).tolist()
    results = collection.query(
        query_embeddings=query_emb,
        n_results=top_k,
        include=["documents", "distances"]
    )
    for doc, dist in zip(results['documents'][0], results['distances'][0]):
        print(f"[相似度: {1-dist:.3f}] {doc[:100]}...")  # 把距离转回相似度显示
    return results

# ---- 演示流程 ----
if __name__ == "__main__":
    # 假设你有一个 test.txt,里面是几个用空行隔开的段落
    chunks = split_into_chunks("test.txt")
    print(f"切出了 {len(chunks)} 个块")

    # 批量得到向量并直接存库
    embeddings = model.encode(chunks, normalize_embeddings=True).tolist()
    save_embeddings(chunks, embeddings)

    # 模拟搜索
    search("你的问题")

7. 两个你需要小心的地方

  1. embeddings 的格式
    model.encode() 返回 numpy.ndarray,但 ChromaDB 的 add 方法也能接收 list。为了通用,我上面用 .tolist() 转成 Python 列表。如果你用 PersistentClient,它会直接存为高效的内部格式,没问题。

  2. 集合重复添加

    如果你多次运行脚本,每次都往同一个 collection 添加同样的数据,ID 会重复(可能报错或覆盖)。通常的做法是:

    • 检查集合里是否已有该来源的数据,有则先删再加,或跳过。

    • 或者每次创建一个带时间戳的新集合。

    python

    复制代码
    # 删除整个集合重新来(谨慎,会删光)
    chromadb_client.delete_collection("my_docs")
    collection = chromadb_client.get_or_create_collection("my_docs")

8. 下一步你可以尝试什么

  • 加上元数据过滤 :比如 metadatas=[{"source": "文档A.pdf", "page": 3}],查询时可以 where={"source": "文档A.pdf"} 只搜某个文档。

  • 使用 Embedding 模型内置 (ChromaDB 0.4+ 自带 embedding function 集成,省去手动调 model.encode),但你现在自己处理向量,更灵活,也更能理解流程,先保持这种方式就好。

  • 接入大语言模型:把 search 返回的文本块拼成 prompt 上下文,让 ChatGPT 或本地模型基于原文回答。

CrossEncoder(交叉编码器) 对向量数据库返回的结果做重排序(Rerank)。我把它的作用、原理、以及怎么跟你前面的 ChromaDB 配合使用,用工程化的方式讲清楚。


1. 为什么需要重排序?它不是已经有相似度了吗?

你用 ChromaDB 搜索后,返回的 top_k 已经按向量相似度排好了。但向量相似度(余弦距离)有时会"看走眼",比如:

  • 向量模型更关注宏观主题,不太擅长抓细节语义、否定词、数字等。

  • 返回的 top 20 里可能有噪音,靠前的并不一定是最契合问题的。

CrossEncoder 就是用来"二次筛选"的,它会把问题和每个候选句子拼在一起,直接预测一个精确的相关性分数。虽然速度比向量检索慢很多,但准确率高得多。

所以一个完整的 RAG 检索链路通常是:

text

复制代码
问题 → 向量库快速初筛(如 top 20) → CrossEncoder 精排 → 取 top 3 送给大模型

2. 代码拆解 ------ 每一行在做什么

python

复制代码
from sentence_transformers import CrossEncoder

def rerank(query: str, retrieved_chunks: List[str], top_k: int) -> List[str]:
    # 1. 加载一个专门用于排序的 CrossEncoder 模型
    cross_encoder = CrossEncoder('cross-encoder/mmarc0-mmMiniLMv2-L12-H384-v1')
    
    # 2. 构造一对一的输入:[ (问题, 候选1), (问题, 候选2), ... ]
    pairs = [(query, chunk) for chunk in retrieved_chunks]
    
    # 3. 一次性预测所有 pair 的相似度分数(越高越相关)
    scores = cross_encoder.predict(pairs)

    # 4. 把 chunk 和分数打包,然后按分数降序排序
    chunk_with_score_list = [(chunk, score) for chunk, score in zip(retrieved_chunks, scores)]
    chunk_with_score_list.sort(key=lambda pair: pair[1], reverse=True)

    # 5. 只返回重排后前面的 top_k 个文本
    return [chunk for chunk, _ in chunk_with_score_list][:top_k]

# 使用示例
reranked_chunks = rerank(query, retrieved_chunks, 3)

关键细节

  • CrossEncoder 不是把句子变成向量,而是把一对句子拼起来做二分类(相似/不相似)。它的输入是一对,输出是单个浮点分数。

  • 这里的模型 mmarc0-mmMiniLMv2-L12-H384-v1 是一个轻量级的英文排序模型,如果你是中文场景,换成中文的交叉编码器(如 BAAI/bge-reranker-v2-m3cross-encoder/stsb-roberta-base-chinese)会准得多。

  • predict 一次可以接收很多 pairs,内部会批处理,比循环调快。


3. 工程上怎么把它接到 ChromaDB 后面?

假设你之前的查询函数已经返回了 20 个候选块(为了给精排留余地,初筛要多取一些):

python

复制代码
def search_and_rerank(question: str, model, collection, top_k=3, initial_top=20):
    # 第一步:用向量库快速初筛出 initial_top 个候选
    query_emb = model.encode([question], normalize_embeddings=True).tolist()
    results = collection.query(
        query_embeddings=query_emb,
        n_results=initial_top,          # 多取一些
        include=["documents"]
    )
    candidates = results['documents'][0]  # 列表,长度 20

    # 第二步:用 CrossEncoder 精选出最相关的 top_k
    reranked = rerank(question, candidates, top_k)
    return reranked

注意CrossEncoder 的加载成本比较高,一定不能写在函数内部每次都加载。要把它提到全局,只加载一次:

python

复制代码
cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3')  # 全局单例

def rerank(query, candidates, top_k):
    pairs = [(query, c) for c in candidates]
    scores = cross_encoder.predict(pairs)
    pairs_with_scores = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
    return [c for c, _ in pairs_with_scores[:top_k]]

4. 模型选型建议(中文场景)

你目前代码里用的 mmarc0-mmMiniLMv2-L12-H384-v1 是英文为主的模型。如果处理中文文档,推荐换成:

  • BAAI/bge-reranker-v2-m3(通用多语言,效果很好,BGE 系列)

  • cross-encoder/stsb-roberta-base-chinese(中文语义相似度基座)

  • maidalun1020/bce-reranker-base_v1(也为中文设计)

这些模型在 HuggingFace 上都能直接用 CrossEncoder(模型名) 加载。


5. 性能优化与工程坑

  • 不要精排太多候选 :初筛的 initial_top 设 20~30 比较合理,太多了精排会变慢,而且边际收益低。

  • 批量预测 :代码里的 cross_encoder.predict(pairs) 就是批处理,不要自己写循环。

  • GPU 加速 :CrossEncoder 也支持 device='cuda',能显著加快重排序速度。加载时指定:
    cross_encoder = CrossEncoder('模型名', device='cuda')

  • 重排序的耗时:通常在几十到几百毫秒(取决于候选数量和模型大小),相比于几十毫秒的向量检索要慢不少,但为提升最终答案质量是值得的。


6. 你的完整链路现在长这样

text

复制代码
原始文档 → 按空行切块 → 句向量模型编码 → 存入 ChromaDB
                                              ↓
用户问题 → 句向量模型编码 → ChromaDB 快速检索(取 20 个)
                                              ↓
                                  CrossEncoder 精排(取 3 个)
                                              ↓
                                    把 3 个块拼成上下文 + 问题 → LLM 生成答案
相关推荐
傲笑风1 小时前
jupyter转PDF教程
python·jupyter
测试员周周1 小时前
【AI测试功能2】AI功能测试的“不可确定性“难题与应对思路:从精确断言到统计判定的完整方案
大数据·人工智能·python·功能测试·测试工具·单元测试·测试用例
薛定谔的猫3691 小时前
深入浅出 Model Context Protocol (MCP):连接 AI 与外部数据的桥梁
ai·llm·agent·mcp·modelcontextprotocol
szial2 小时前
uv 实战指南:用一个工具重塑 Python 开发工作流
开发语言·python·uv
Aision_2 小时前
为什么 CTI 场景需要知识图谱?
人工智能·python·安全·web安全·langchain·prompt·知识图谱
BU摆烂会噶2 小时前
【LangGraph】LangGraph 工具中访问运行时上下文——ToolRuntime
人工智能·python·langchain·人机交互
Karl_wei2 小时前
LangChain Agent 实战接入
aigc·agent·ai编程
MATLAB代码顾问10 小时前
5大智能算法优化标准测试函数对比(Python实现)
开发语言·python
ting945200010 小时前
Tornado 全栈技术深度指南:从原理到实战
人工智能·python·架构·tornado