本文章介绍并演示了用 LangChain 开发 RAG Agent 时常用的 Retriver,并进行了简单对比。
全部代码附在最后,其中 Embedding 模型使用 ollama 运行。
检索器类型
1. BM25 稀疏检索
原理 :基于 TF-IDF 改进的经典关键词匹配算法,通过计算文档与查询的词频相似度排序
实现:
python
from langchain_community.retrievers import BM25Retriever
retriever = BM25Retriever.from_documents(docs)
优势:
- 对关键词匹配敏感,适合事实性问答
- 计算速度快,资源消耗低
- 无需嵌入模型,避免 embedding 计算成本
劣势:
- 无法理解语义相似性(如"汽车"和"车辆"视为不同)
- 对长文档和复杂查询效果有限
2. 向量稠密检索
原理 :使用 embedding 模型将文本转换为向量,通过余弦相似度匹配语义相似的文档
实现:
python
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings
embeddings = OllamaEmbeddings(model="qwen3-embedding:0.6b")
vectorstore = FAISS.from_documents(docs, embeddings)
retriever = vectorstore.as_retriever()
优势:
- 理解语义相似性,支持同义词和上下文匹配
- 适合处理模糊查询和概念性问题
- 可捕捉文档深层语义关系
劣势:
- 需要 embedding 模型,计算成本高
- 对硬件资源有一定要求
- 可能受限于 embedding 模型的质量
3. 混合检索(EnsembleRetriever)
原理 :结合 BM25 稀疏检索和向量稠密检索的优势,通过权重融合结果
实现:
python
from langchain_classic.retrievers import EnsembleRetriever
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_retriever],
weights=[0.5, 0.5] # 权重可根据场景调整
)
注意,新版 LangChain 中,EnsembleRetriever 要去langchain_classic 包中找。
优势:
- 兼顾关键词匹配和语义理解
- 鲁棒性强,适应不同类型查询
- 可通过调整权重优化特定场景
劣势:
- 计算成本高于单一检索器
- 实现复杂度增加
- 需要调优权重参数
检索策略对比
| 特性 | BM25 稀疏检索 | 向量稠密检索 | 混合检索 |
|---|---|---|---|
| 核心原理 | 关键词频率统计 | 语义向量相似度 | 加权融合两种策略 |
| 语义理解能力 | 低(字面匹配) | 高(上下文理解) | 高 |
| 计算效率 | 高 | 中(需 embedding) | 低(双重计算) |
| 资源需求 | 低 | 中(需模型) | 高 |
| 适合场景 | 事实性查询 | 概念性/模糊查询 | 复杂场景/混合需求 |
| 错误恢复能力 | 无 | 有(降级机制) | 有 |
降级与重试机制
向量检索实现了自动降级和重试机制,增强系统稳定性:
python
dense_with_fallback = dense_retriever.with_retry(stop_after_attempt=2).with_fallbacks([sparse_retriever])
- 当向量检索失败时(如模型不可用),自动切换到 BM25 检索
- 支持最多 2 次重试,提高成功概率
使用示例
python
# 并发测试三种检索器
parallel_retrievers = RunnableParallel(
bm25=sparse | format_results("BM25 稀疏检索"),
dense=dense_with_fallback | format_results("向量稠密检索"),
ensemble=ensemble | format_results("混合检索 (Ensemble)"),
)
完整示例代码
python
from langchain_classic.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel
from langchain_ollama import OllamaEmbeddings
DOCS = [
"LangChain 是一个用于构建大语言模型应用的框架。",
"BM25 是一种经典的稀疏检索算法,适合关键词匹配。",
"Python 是一门流行的编程语言,常用于 AI 开发。",
"LangGraph 是 LangChain 推出的用于构建 Agent 的库。",
"向量数据库用于存储文本的嵌入表示,支持语义搜索。",
"RAG (检索增强生成) 结合检索和生成,提升回答质量。",
"FAISS 是 Meta 开发的高效向量相似度搜索库。",
"Embedding 模型将文本转换为向量表示,捕捉语义信息。"
]
class RetrievalCallbackHandler(BaseCallbackHandler):
def __init__(self):
self.current_query = None
self.results = {}
def on_retriever_start(self, serialized, query, **kwargs):
self.current_query = query
print(f"\n{'='*50}")
print(f"[检索开始] 问题: {query}")
print(f"{'='*50}")
def on_retriever_end(self, documents, **kwargs):
if self.current_query:
print(f"[检索完成] 找到 {len(documents)} 个结果:")
for i, doc in enumerate(documents):
print(f" [{i+1}] {doc.page_content}")
print(f"{'='*50}")
def on_chain_start(self, serialized, inputs, **kwargs):
if isinstance(inputs, str):
self.current_query = inputs
print(f"\n{'='*60}")
print(f"[链式调用开始] 问题: {inputs}")
print(f"{'='*60}")
elif isinstance(inputs, dict) and "query" in inputs:
self.current_query = inputs["query"]
print(f"\n{'='*60}")
print(f"[链式调用开始] 问题: {inputs['query']}")
print(f"{'='*60}")
def on_chain_end(self, outputs, **kwargs):
if isinstance(outputs, str):
print("\n[链式调用完成] 聚合结果:")
print(outputs)
print(f"\n{'='*60}")
def on_chain_error(self, error, **kwargs):
print(f"\n[错误] 链式调用失败: {str(error)}")
print(f"{'='*60}")
def on_retriever_error(self, error, **kwargs):
print(f"\n[错误] 检索失败: {str(error)}")
print(f"{'='*50}")
def format_docs(docs: list[Document]) -> str:
return "\n".join(f"[{i+1}] {d.page_content}" for i, d in enumerate(docs))
def format_results(name: str):
def _format(docs: list[Document]) -> str:
lines = [f"【{name}】"]
for i, doc in enumerate(docs):
lines.append(f" [{i+1}] {doc.page_content}")
return "\n".join(lines)
return _format
def aggregate_results(results: dict) -> str:
lines = []
for name, output in results.items():
lines.append(output)
lines.append("")
return "\n".join(lines)
def create_sparse_retriever(docs: list[Document], k: int = 3):
retriever = BM25Retriever.from_documents(docs)
retriever.k = k
return retriever
def create_dense_retriever(docs: list[Document], embeddings, k: int = 3):
vectorstore = FAISS.from_documents(docs, embeddings)
return vectorstore.as_retriever(search_kwargs={"k": k})
def create_retriever_with_fallback(dense_retriever, sparse_retriever):
return (
dense_retriever
.with_retry(stop_after_attempt=2)
.with_fallbacks([sparse_retriever], exceptions_to_handle=(Exception,))
)
def main():
docs = [Document(page_content=text) for text in DOCS]
embeddings = OllamaEmbeddings(model="qwen3-embedding:0.6b")
callback = RetrievalCallbackHandler()
sparse = create_sparse_retriever(docs)
dense = create_dense_retriever(docs, embeddings)
dense_with_fallback = create_retriever_with_fallback(dense, sparse)
ensemble = EnsembleRetriever(retrievers=[sparse, dense], weights=[0.5, 0.5])
parallel_retrievers = RunnableParallel(
bm25=sparse | RunnableLambda(format_results("BM25 稀疏检索")),
dense=dense_with_fallback | RunnableLambda(format_results("向量稠密检索")),
ensemble=ensemble | RunnableLambda(format_results("混合检索 (Ensemble)")),
)
chain = (
{"results": parallel_retrievers, "query": RunnablePassthrough()}
| RunnableLambda(lambda x: f"查询: {x['query']}\n\n{aggregate_results(x['results'])}")
)
print("=" * 60)
print("RAG 检索器并发测试 (LCEL + Callback)")
print("=" * 60)
queries = [
"LangChain 是用来做什么的?",
"有什么 AI 相关的工具?",
"Python 和 LangChain 的关系",
]
for query in queries:
result = chain.invoke(query, config={"callbacks": [callback]})
print(result)
print("\n" + "=" * 60)
print("演示完成!")
print("=" * 60)
if __name__ == "__main__":
main()