原理:在构建大语言模型智能体时,有时需要给大模型提供外部文本资料。需要将待查询文本,切片后存储进向量数据库。通过余弦相似度匹配,找到与用户问题语义最接近的文本片段。
步骤:切片、索引,召回、重排
切片
将文本切片后保存到列表中即可,切片方式根据实际情况确定。
python
with open(file_dir,"r") as file
content = file.read()
contents = content.split("\n\n")
索引,通过算法生成文本的向量表示,一般来说,一个任意长度的文本片段都能生成一个固定长度的向量,不同文本片段语义越相近,它们的向量距离也越相近。在Python中,可以使用SentenceTransformer库来处理。
SentenceTransformer库算法原理:
1. SentenceTransformer 概述
SentenceTransformer 不是一种全新的架构,而是一个封装好的框架 。它把一个预训练的 Transformer 编码器(比如 BERT、RoBERTa)和一个池化策略 组合起来,让模型能把任意长度的句子映射成固定长度的稠密向量,并且这向量能保留语义------意思相近的句子,向量在空间中距离近。
核心公式可以理解为:
text
sentence_vector = pooling( transformer(sentence) ) + optional_normalization
它基于 暹罗网络(Siamese Network) 的思路,训练时用成对或三元组的句子,让模型学会拉近相似句、推远不相似句。
2. 管线分解:从文本到向量
2.1 Tokenization(分词)
输入一个句子,例如 "你好世界",首先会经过一个与底层 Transformer 配套的 Tokenizer。
-
WordPiece (BERT 类)或 SentencePiece 分词。
-
加上特殊 token:
[CLS](分类标记),[SEP](分隔符)。 -
映射成
input_ids(每个 token 在词表中的整数 ID)。 -
生成
attention_mask(真实 token 为 1,填充为 0),用于后续忽略 padding。 -
对于中文模型
shibing624/text2vec-base-chinese,它的底层一般是 BERT 架构 或 MiniLM 结构,词汇表包含简繁中文、英文等。
代码层面,库内部调用 transformers.AutoTokenizer.from_pretrained() 加载。
2.2 Transformer 编码
input_ids 和 attention_mask 送入 Transformer 的 Encoder 堆栈。以 BERT-base 为例:
-
词嵌入 + 位置嵌入 + 段落嵌入 → 输入向量 (batch_size, seq_len, 768)
-
经过 12 层 Transformer Block(Self-Attention + FFN + 残差 + LayerNorm)
-
输出 hidden states ,形状仍是
(batch_size, seq_len, 768)
每个 token 位置对应一个 768 维的上下文表示。但是句子级别的表示不能直接用这整个矩阵,需要池化成一个向量。
2.3 池化(Pooling)------关键细节
这是 SentenceTransformer 与直接用分类模型最大的区别。常见池化方式:
-
Mean Pooling(默认)
对所有 token 的 hidden states 求平均,但 只考虑 attention_mask=1 的位置 ,忽略
[PAD]。伪代码:
python
# token_embeddings: (batch_size, seq_len, 768) # attention_mask: (batch_size, seq_len) , 扩展维度后相乘 input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1) sum_mask = input_mask_expanded.sum(dim=1).clamp(min=1e-9) sentence_embeddings = sum_embeddings / sum_mask这样得到的向量表达了整个句子的"平均含义"。目前大多数 SOTA 模型用 Mean Pooling。
-
CLS Pooling
只取
[CLS]位置的输出向量。BERT 预训练时[CLS]用于 NSP,但对于语义相似度并不总是最优。 -
Max Pooling
取每个维度在所有 token 中的最大值,较早期使用,现在较少。
你加载的 shibing624/text2vec-base-chinese 使用的就是 Mean Pooling 。它的模型配置里有一个 pooling_mode_mean_tokens: true。
2.4 向量后处理
-
归一化 :很多模型(包括这个)会对池化后的向量做 L2 归一化 ,使向量模长为 1。
这样两个向量的点积 就等于余弦相似度,计算更快。
-
可选的 Dense 层 :有些模型在池化后会加一个线性层将维度从 768 压缩到比如 512 或 256(用于减小存储),但
text2vec-base-chinese没有,直接输出 768 维。
3. 训练时如何让向量具有语义?
模型不是天生就输出好的句向量。预训练的 BERT 只做 token 级别的 MLM 和 NSP,句向量通常很糟糕(各向异性,分布狭窄)。SentenceTransformers 的价值在于提供了一套微调优化目标:
-
Contrastive Loss(对比损失)
输入一对句子
(A, B),还有标签0/1(不相似/相似)。计算它们的向量距离(如欧氏距离或余弦距离),让相似对距离小,不相似对距离 > margin。 -
Triplet Loss(三元组损失)
输入
(anchor, positive, negative),让dist(anchor, pos)<dist(anchor, neg) - margin。 -
CoSENT Loss 或其他
专为中文优化的损失函数,直接用余弦相似度排序。
-
训练数据 :通常来自 NLI(自然语言推理)、QA pair、同义句数据集等。
例如
shibing624/text2vec-base-chinese使用了大量中文 STS(Semantic Textual Similarity)、NLI、以及 paraphrase 数据,以 CoSENT 目标训练。
于是微调后,句子向量在空间中呈球形均匀分布,且一个方向对应一种语义。
4. 关键功能与对应的技术实现
4.1 编码(encode)
python
embedding_model.encode(["句子1", "句子2"],
batch_size=32,
show_progress_bar=True,
convert_to_numpy=True)
内部流程:
-
对每个列表元素 tokenize,动态 padding 到 batch 内最长。
-
送入模型(Transformer + Pooling)。
-
对输出向量做 L2 归一化(如果模型设置了
normalize_embeddings: true)。 -
返回
numpy.ndarray或 list。
你代码里的 embed_chunk 就是简单封装,把一大段文本传入,返回一个 List[float],这是一条 embedding。
4.2 语义相似度计算
得到向量 u 和 v 后(均已归一化),余弦相似度就是 dot(u, v)。库提供了 util.cos_sim,但本质上就是矩阵乘法。
4.3 语义搜索
给定一个查询向量 q 和大量的文档向量矩阵 M(shape: (num_docs, dim)),计算 q @ M.T,然后 torch.topk 取出最大得分的几个索引,对应最相关文档。
SentenceTransformers 提供的 util.semantic_search 一次可以处理多个查询,支持按组划分(比如每个文档有多个 chunk,可以做 max pooling 取最佳 chunk 得分)。
4.4 聚类 / 分类
向量可以直接喂给 KMeans 或 ANN 索引(如 Faiss、HNSWlib)做聚类,或者训练一个简单的逻辑回归做文本分类。
5. 深入 text2vec-base-chinese 的技术细节
模型卡信息:
-
基座:
nreimers/MiniLM-L6-H384-uncased的蒸馏版?其实shibing624/text2vec-base-chinese是基于bert-base-chinese架构,12 层,768 维度 ,用 Mean Pooling。 -
训练数据:包含中文 STS-B,LCQMC,QQP 等,还有自建的口语相似数据。
-
损失:CoSENT Loss + 对比学习。
-
配置:
sentence_bert_config.json中指定了pooling_mode_mean_tokens: true和pooling_mode_cls_token: false。 -
输出:768 维,已 L2 归一化。
这意味着 :你调用 embed_chunk(chunk) 后得到的每个向量可以直接用来比较余弦相似度,值域 [-1,1]。在 RAG 场景中,你可以构建一个 numpy 矩阵或 Faiss 索引,查询时计算相似度并取 Top-K。
6. 工程上的一些细节
-
批处理 :
encode支持多句子同时处理,可充分利用 GPU 并行。如果句子太长,会截断到max_seq_length(默认 256 或 512,可通过model.get_max_seq_length()查看,一般在配置文件定义)。 -
长文本处理 :你之前的
split_into_chunks按空行切块,但每个块仍需保证不超过最大长度,否则会自动截断,丢失尾部信息。更好的做法是使用重叠滑动窗口切分。 -
归一化 :如果模型本身做了归一化,存储向量时无需再归一化;如果没有,可以在检索前
sklearn.preprocessing.normalize。 -
微调自己的模型 :可以使用 SentenceTransformers 的
InputExample和SentenceTransformerTrainer继续训练,以适应特定领域。
如何使用SentenceTransformers
1. 它能做什么
把任意文本(单词、句子、段落)转成一个固定长度的浮点数数组(向量)。语义越相近的文本,它们的向量点积/余弦相似度越高。
2. 快速上手的最小工程单元
安装
bash
pip install sentence-transformers
第一次加载模型时会自动下载,后续使用缓存。
加载模型(只需一次)
python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("shibing624/text2vec-base-chinese")
# 可加参数 device='cuda' 指定 GPU,不加会自动选。
模型选型建议:
-
中文通用:
shibing624/text2vec-base-chinese(平衡质量与速度) -
轻量快速:
all-MiniLM-L6-v2(英文为主,中文一般) -
多语言:
paraphrase-multilingual-MiniLM-L12-v2(中英混合场景) -
你本地也可提前把模型下载到
./local_model,加载时传路径即可。
3. 编码(encode)------ 你最常用的方法
python
embeddings = model.encode(
["文本1", "文本2", ...],
batch_size=32, # 一次处理多少条,按显存/内存调
show_progress_bar=True, # 数据量大时看进度
convert_to_numpy=True, # 返回 numpy 数组,方便存盘和计算
normalize_embeddings=True # 强制输出 L2 归一化向量(推荐)
)
参数要点:
-
batch_size:越大越快(GPU),但别爆显存;纯 CPU 推理设 8~16 即可。 -
normalize_embeddings:最好设为True(模型配置里可能已默认归一化,但显式开启更稳妥)。归一化后,直接用点积算余弦相似度。 -
convert_to_numpy:如果你后面接 Faiss/NumPy 存索引,返回 numpy 更方便。
python
chunks = split_into_chunks(doc_file) # 得到 List[str]
千万别一条一条调 embed_chunk ,那样效率极低。正确做法是批量编码:
python
def embed_chunks(chunks: List[str], model) -> np.ndarray:
return model.encode(chunks, batch_size=64, show_progress_bar=True,
normalize_embeddings=True, convert_to_numpy=True)
一次把整个 chunks 列表丢进去,内部自动分批、动态 padding,比循环快几十倍。
5. 向量存哪里?怎么检索?
存储
最简单的做法,生成立即存成 .npy:
python
import numpy as np
np.save("chunk_embeddings.npy", embeddings)
同时把 chunks 原文用 JSON lines 存好,保证索引对齐。
检索(单机小规模,<10万条)
直接用库自带的语义搜索:
python
from sentence_transformers import util
query_emb = model.encode(["你的问题"], convert_to_numpy=True, normalize_embeddings=True)
# corpus_embeddings 是 (num_chunks, dim) 的 numpy 数组
hits = util.semantic_search(query_emb, corpus_embeddings, top_k=5)
# hits[0] 是列表,每个元素 {'corpus_id': idx, 'score': 余弦相似度}
内部就是矩阵乘法,速度很快。
大规模检索(百万级+)
把 embeddings 喂给 Faiss:
python
import faiss
index = faiss.IndexFlatIP(dim) # IP = 内积,因为归一化了等价余弦
index.add(embeddings)
D, I = index.search(query_emb, 5)
faiss 可建 GPU 索引、IVF 倒排索引压缩,是生产环境标配。
6. 长文本截断问题(工程老坑)
模型有最大输入长度限制,text2vec-base-chinese 通常为 512 个 token(一个汉字大致 1~2 个 token)。
-
若你的
chunk过长(按空行分块后某个段落特别大),超长部分会被直接截断丢弃。 -
解决方案:
-
分块时增加长度检查,用
len(model.tokenizer.encode(chunk))判断 token 数。 -
对超长块做滑动窗口切分(如重叠 50 个 token)。
-
或使用支持长文本的模型(如
text2vec-large-chinese支持到 1024),但资源消耗更大。
-
7. 性能优化清单
-
固定模型对象:加载一次,全局复用,不要反复加载。
-
批处理 :永远攒一批再
encode,不要单条处理。 -
GPU 加速 :
model = SentenceTransformer('...', device='cuda'),比 CPU 快 10~50 倍。 -
数据预处理:对原文做清洗(去多余空白、特殊符号)能提高向量质量。
-
向量维度选择:768 维通常够用,存储敏感可换小模型(如 384 维),牺牲一点精度。
8. 工程化使用模板(完整流程)
python
from sentence_transformers import SentenceTransformer, util
import numpy as np
import json
# 1. 加载模型(全局单例)
model = SentenceTransformer("shibing624/text2vec-base-chinese", device="cuda")
# 2. 处理文档
with open("doc.txt", "r") as f:
text = f.read()
chunks = [c.strip() for c in text.split("\n\n") if c.strip()]
# 可选:过滤过短或过长块
# 3. 批量向量化
embeddings = model.encode(chunks, batch_size=64, show_progress_bar=True,
normalize_embeddings=True, convert_to_numpy=True)
# 4. 本地持久化
np.save("embeddings.npy", embeddings)
with open("chunks.json", "w", encoding="utf-8") as f:
json.dump(chunks, f, ensure_ascii=False)
# 5. 查询
query = "你的问题"
q_emb = model.encode([query], normalize_embeddings=True, convert_to_numpy=True)
hits = util.semantic_search(q_emb, embeddings, top_k=3)
for hit in hits[0]:
idx = hit['corpus_id']
print(f"相似度:{hit['score']:.4f}")
print(chunks[idx][:200])
这个模板已经可以直接用在你的项目中,把你的 embed_chunk 替换成批量逻辑即可。
前面的工作是把文档切成了块,并用 SentenceTransformer 把这些块变成了向量(一长串浮点数)。接下来你需要存起来,方便以后用问题去搜出最相关的块。
chromadb 就是专门干这件事的:向量数据库。它把你的文本和对应的向量一起存好,并内置了快速的相似度搜索。
我帮你把整个工程思路和代码补全,解释都写在注释里,看完就能用。
1. 为什么需要 ChromaDB,而不是直接用 .npy 存?
少量数据时用 numpy 存向量再自己写个循环算相似度是可以的。但一旦数据量变大,或者你希望有更工程化的功能,ChromaDB 的价值就出来了:
-
自带的快速检索:不用自己每次从头扫描全部向量,内部有索引优化(比如 HNSW)。
-
元数据管理:除了向量,还能顺便存原文、标题、来源等过滤信息,以后可以按条件筛。
-
持久化:默认用内存模式(重启就没了),但换个 Client 就能存到磁盘,服务重启后数据还在。
-
简单 API:增、查、删就几个方法,比手写一堆矩阵运算代码方便很多。
你的代码里用了 EphemeralClient,这是临时模式,数据只存在内存中,进程结束就消失------适合学习和原型验证。如果要落地到可保存的项目,后面会讲怎么改成持久化。
2. 代码里这两行做了什么?
python
chromadb_client = chromadb.EphemeralClient()
chromadb_collection = chromadb_client.get_or_create_collection(name="default")
-
EphemeralClient():创建一个"一次性"客户端,数据存在内存里,程序一停就没了。 -
get_or_create_collection("default"):拿到一个叫"default"的集合。可以理解成数据库里的一张表,专门存一类文档的向量。如果这个集合不存在,它会自动创建;如果已经存在,就直接用。
一个集合就像一个 Excel 表格,每一行是一条数据,包含:ID、文本(document)、向量(embedding)、可选的元数据(metadata)。
3. save_embeddings 函数怎么写?
这个函数的作用是把刚刚切好的文本块(chunks)和它们对应的向量(embeddings)存到 ChromaDB 里。
python
from typing import List
def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
"""将文本块和对应的向量存入 ChromaDB 集合"""
# 1. 为每个块生成唯一的 ID
ids = [f"chunk_{i}" for i in range(len(chunks))]
# 2. 批量添加
chromadb_collection.add(
documents=chunks, # 原始文本
embeddings=embeddings, # 对应的向量,形状: (数量, 维度)
ids=ids # 每个块的唯一标识符
)
注意:
-
embeddings参数是你用model.encode()得到的List[List[float]]或numpy数组(ChromaDB 两种都接受)。 -
ids必须每个都不同,否则会覆盖旧数据。常用uuid或chunk_序号。 -
你也可以顺手加上
metadatas参数,放一些比如"来源文件"、"页码"等辅助信息,以后可以按条件过滤。
4. 如何用存储好的数据做查询?
等你存好了,就可以用自然语言问题去搜索最相关的文本块了。
python
def query_chunks(question: str, model, n_results: int = 3):
"""输入问题,返回最相关的 n 个文本块"""
# 1. 先把问题转成向量
query_emb = model.encode([question], normalize_embeddings=True).tolist()
# 2. 在集合里搜索
results = chromadb_collection.query(
query_embeddings=query_emb, # 注意这里是列表的列表
n_results=n_results, # 要返回几条结果
include=["documents", "distances"] # 要带回哪些字段
)
# 3. 解析结果
# results 是一个字典,结构如下:
# {
# 'ids': [['chunk_0', 'chunk_5', ...]], # 最外层的列表对应一次查询的 batch
# 'documents': [['文本1', '文本2', ...]], # 因为我们只查了一个问题,所以取 [0]
# 'distances': [[0.2, 0.5, ...]]
# }
docs = results['documents'][0] # 返回的文本列表
distances = results['distances'][0] # 距离(默认是余弦距离,越小越相关)
return docs, distances
返回值解释:
-
distances:ChromaDB 默认用余弦距离(= 1 - 余弦相似度),所以值越小表示越相似(0 表示完全一样)。 -
documents:对应的原始文本块,直接可以喂给大模型当 prompt 上下文。
5. 从临时模式升级到持久化存储
如果你希望重启程序后数据还在,只需要把客户端换成 PersistentClient,并给它一个本地文件夹路径。
python
import chromadb
# 持久化到 ./my_vector_db 文件夹
chromadb_client = chromadb.PersistentClient(path="./my_vector_db")
chromadb_collection = chromadb_client.get_or_create_collection(name="default")
以后每次启动程序,都能拿到以前存过的集合和数据。
6. 把你的整个流程串起来
结合你之前所有代码,一个完整的"文档入库→查询"例子如下:
python
from typing import List
from sentence_transformers import SentenceTransformer
import chromadb
import os
# ---- 初始化 ----
# 1. 模型(加载一次,全局复用)
model = SentenceTransformer("shibing624/text2vec-base-chinese")
# 2. 向量数据库(改一行就能换成持久化)
# chromadb_client = chromadb.PersistentClient(path="./db") # 永久保存
chromadb_client = chromadb.EphemeralClient() # 临时测试
collection = chromadb_client.get_or_create_collection("my_docs")
# ---- 你之前的切块函数 ----
def split_into_chunks(doc_file: str) -> List[str]:
with open(doc_file, 'r', encoding='utf-8') as f:
content = f.read()
return [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]
# ---- 批量向量化并存储 ----
def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
ids = [f"chunk_{i}" for i in range(len(chunks))]
collection.add(documents=chunks, embeddings=embeddings, ids=ids)
# ---- 查询 ----
def search(question: str, top_k: int = 3):
query_emb = model.encode([question], normalize_embeddings=True).tolist()
results = collection.query(
query_embeddings=query_emb,
n_results=top_k,
include=["documents", "distances"]
)
for doc, dist in zip(results['documents'][0], results['distances'][0]):
print(f"[相似度: {1-dist:.3f}] {doc[:100]}...") # 把距离转回相似度显示
return results
# ---- 演示流程 ----
if __name__ == "__main__":
# 假设你有一个 test.txt,里面是几个用空行隔开的段落
chunks = split_into_chunks("test.txt")
print(f"切出了 {len(chunks)} 个块")
# 批量得到向量并直接存库
embeddings = model.encode(chunks, normalize_embeddings=True).tolist()
save_embeddings(chunks, embeddings)
# 模拟搜索
search("你的问题")
7. 两个你需要小心的地方
-
embeddings的格式
model.encode()返回numpy.ndarray,但 ChromaDB 的add方法也能接收list。为了通用,我上面用.tolist()转成 Python 列表。如果你用PersistentClient,它会直接存为高效的内部格式,没问题。 -
集合重复添加
如果你多次运行脚本,每次都往同一个 collection 添加同样的数据,ID 会重复(可能报错或覆盖)。通常的做法是:
-
检查集合里是否已有该来源的数据,有则先删再加,或跳过。
-
或者每次创建一个带时间戳的新集合。
python
# 删除整个集合重新来(谨慎,会删光) chromadb_client.delete_collection("my_docs") collection = chromadb_client.get_or_create_collection("my_docs") -
8. 下一步你可以尝试什么
-
加上元数据过滤 :比如
metadatas=[{"source": "文档A.pdf", "page": 3}],查询时可以where={"source": "文档A.pdf"}只搜某个文档。 -
使用 Embedding 模型内置 (ChromaDB 0.4+ 自带 embedding function 集成,省去手动调
model.encode),但你现在自己处理向量,更灵活,也更能理解流程,先保持这种方式就好。 -
接入大语言模型:把
search返回的文本块拼成 prompt 上下文,让 ChatGPT 或本地模型基于原文回答。
用 CrossEncoder(交叉编码器) 对向量数据库返回的结果做重排序(Rerank)。我把它的作用、原理、以及怎么跟你前面的 ChromaDB 配合使用,用工程化的方式讲清楚。
1. 为什么需要重排序?它不是已经有相似度了吗?
你用 ChromaDB 搜索后,返回的 top_k 已经按向量相似度排好了。但向量相似度(余弦距离)有时会"看走眼",比如:
-
向量模型更关注宏观主题,不太擅长抓细节语义、否定词、数字等。
-
返回的 top 20 里可能有噪音,靠前的并不一定是最契合问题的。
CrossEncoder 就是用来"二次筛选"的,它会把问题和每个候选句子拼在一起,直接预测一个精确的相关性分数。虽然速度比向量检索慢很多,但准确率高得多。
所以一个完整的 RAG 检索链路通常是:
text
问题 → 向量库快速初筛(如 top 20) → CrossEncoder 精排 → 取 top 3 送给大模型
2. 代码拆解 ------ 每一行在做什么
python
from sentence_transformers import CrossEncoder
def rerank(query: str, retrieved_chunks: List[str], top_k: int) -> List[str]:
# 1. 加载一个专门用于排序的 CrossEncoder 模型
cross_encoder = CrossEncoder('cross-encoder/mmarc0-mmMiniLMv2-L12-H384-v1')
# 2. 构造一对一的输入:[ (问题, 候选1), (问题, 候选2), ... ]
pairs = [(query, chunk) for chunk in retrieved_chunks]
# 3. 一次性预测所有 pair 的相似度分数(越高越相关)
scores = cross_encoder.predict(pairs)
# 4. 把 chunk 和分数打包,然后按分数降序排序
chunk_with_score_list = [(chunk, score) for chunk, score in zip(retrieved_chunks, scores)]
chunk_with_score_list.sort(key=lambda pair: pair[1], reverse=True)
# 5. 只返回重排后前面的 top_k 个文本
return [chunk for chunk, _ in chunk_with_score_list][:top_k]
# 使用示例
reranked_chunks = rerank(query, retrieved_chunks, 3)
关键细节:
-
CrossEncoder不是把句子变成向量,而是把一对句子拼起来做二分类(相似/不相似)。它的输入是一对,输出是单个浮点分数。 -
这里的模型
mmarc0-mmMiniLMv2-L12-H384-v1是一个轻量级的英文排序模型,如果你是中文场景,换成中文的交叉编码器(如BAAI/bge-reranker-v2-m3或cross-encoder/stsb-roberta-base-chinese)会准得多。 -
predict一次可以接收很多 pairs,内部会批处理,比循环调快。
3. 工程上怎么把它接到 ChromaDB 后面?
假设你之前的查询函数已经返回了 20 个候选块(为了给精排留余地,初筛要多取一些):
python
def search_and_rerank(question: str, model, collection, top_k=3, initial_top=20):
# 第一步:用向量库快速初筛出 initial_top 个候选
query_emb = model.encode([question], normalize_embeddings=True).tolist()
results = collection.query(
query_embeddings=query_emb,
n_results=initial_top, # 多取一些
include=["documents"]
)
candidates = results['documents'][0] # 列表,长度 20
# 第二步:用 CrossEncoder 精选出最相关的 top_k
reranked = rerank(question, candidates, top_k)
return reranked
注意 :CrossEncoder 的加载成本比较高,一定不能写在函数内部每次都加载。要把它提到全局,只加载一次:
python
cross_encoder = CrossEncoder('BAAI/bge-reranker-v2-m3') # 全局单例
def rerank(query, candidates, top_k):
pairs = [(query, c) for c in candidates]
scores = cross_encoder.predict(pairs)
pairs_with_scores = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
return [c for c, _ in pairs_with_scores[:top_k]]
4. 模型选型建议(中文场景)
你目前代码里用的 mmarc0-mmMiniLMv2-L12-H384-v1 是英文为主的模型。如果处理中文文档,推荐换成:
-
BAAI/bge-reranker-v2-m3(通用多语言,效果很好,BGE 系列) -
cross-encoder/stsb-roberta-base-chinese(中文语义相似度基座) -
maidalun1020/bce-reranker-base_v1(也为中文设计)
这些模型在 HuggingFace 上都能直接用 CrossEncoder(模型名) 加载。
5. 性能优化与工程坑
-
不要精排太多候选 :初筛的
initial_top设 20~30 比较合理,太多了精排会变慢,而且边际收益低。 -
批量预测 :代码里的
cross_encoder.predict(pairs)就是批处理,不要自己写循环。 -
GPU 加速 :CrossEncoder 也支持
device='cuda',能显著加快重排序速度。加载时指定:
cross_encoder = CrossEncoder('模型名', device='cuda') -
重排序的耗时:通常在几十到几百毫秒(取决于候选数量和模型大小),相比于几十毫秒的向量检索要慢不少,但为提升最终答案质量是值得的。
6. 你的完整链路现在长这样
text
原始文档 → 按空行切块 → 句向量模型编码 → 存入 ChromaDB
↓
用户问题 → 句向量模型编码 → ChromaDB 快速检索(取 20 个)
↓
CrossEncoder 精排(取 3 个)
↓
把 3 个块拼成上下文 + 问题 → LLM 生成答案