LangChain Retriver 例子

本文章介绍并演示了用 LangChain 开发 RAG Agent 时常用的 Retriver,并进行了简单对比。

全部代码附在最后,其中 Embedding 模型使用 ollama 运行。

检索器类型

1. BM25 稀疏检索

原理 :基于 TF-IDF 改进的经典关键词匹配算法,通过计算文档与查询的词频相似度排序
实现

python 复制代码
from langchain_community.retrievers import BM25Retriever
retriever = BM25Retriever.from_documents(docs)

优势

  • 对关键词匹配敏感,适合事实性问答
  • 计算速度快,资源消耗低
  • 无需嵌入模型,避免 embedding 计算成本

劣势

  • 无法理解语义相似性(如"汽车"和"车辆"视为不同)
  • 对长文档和复杂查询效果有限

2. 向量稠密检索

原理 :使用 embedding 模型将文本转换为向量,通过余弦相似度匹配语义相似的文档
实现

python 复制代码
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="qwen3-embedding:0.6b")
vectorstore = FAISS.from_documents(docs, embeddings)
retriever = vectorstore.as_retriever()

优势

  • 理解语义相似性,支持同义词和上下文匹配
  • 适合处理模糊查询和概念性问题
  • 可捕捉文档深层语义关系

劣势

  • 需要 embedding 模型,计算成本高
  • 对硬件资源有一定要求
  • 可能受限于 embedding 模型的质量

3. 混合检索(EnsembleRetriever)

原理 :结合 BM25 稀疏检索和向量稠密检索的优势,通过权重融合结果
实现

python 复制代码
from langchain_classic.retrievers import EnsembleRetriever

ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, vector_retriever],
    weights=[0.5, 0.5]  # 权重可根据场景调整
)

注意,新版 LangChain 中,EnsembleRetriever 要去langchain_classic 包中找。

优势

  • 兼顾关键词匹配和语义理解
  • 鲁棒性强,适应不同类型查询
  • 可通过调整权重优化特定场景

劣势

  • 计算成本高于单一检索器
  • 实现复杂度增加
  • 需要调优权重参数

检索策略对比

特性 BM25 稀疏检索 向量稠密检索 混合检索
核心原理 关键词频率统计 语义向量相似度 加权融合两种策略
语义理解能力 低(字面匹配) 高(上下文理解)
计算效率 中(需 embedding) 低(双重计算)
资源需求 中(需模型)
适合场景 事实性查询 概念性/模糊查询 复杂场景/混合需求
错误恢复能力 有(降级机制)

降级与重试机制

向量检索实现了自动降级和重试机制,增强系统稳定性:

python 复制代码
dense_with_fallback = dense_retriever.with_retry(stop_after_attempt=2).with_fallbacks([sparse_retriever])
  • 当向量检索失败时(如模型不可用),自动切换到 BM25 检索
  • 支持最多 2 次重试,提高成功概率

使用示例

python 复制代码
# 并发测试三种检索器
parallel_retrievers = RunnableParallel(
    bm25=sparse | format_results("BM25 稀疏检索"),
    dense=dense_with_fallback | format_results("向量稠密检索"),
    ensemble=ensemble | format_results("混合检索 (Ensemble)"),
)

完整示例代码

python 复制代码
from langchain_classic.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel
from langchain_ollama import OllamaEmbeddings

DOCS = [
    "LangChain 是一个用于构建大语言模型应用的框架。",
    "BM25 是一种经典的稀疏检索算法,适合关键词匹配。",
    "Python 是一门流行的编程语言,常用于 AI 开发。",
    "LangGraph 是 LangChain 推出的用于构建 Agent 的库。",
    "向量数据库用于存储文本的嵌入表示,支持语义搜索。",
    "RAG (检索增强生成) 结合检索和生成,提升回答质量。",
    "FAISS 是 Meta 开发的高效向量相似度搜索库。",
    "Embedding 模型将文本转换为向量表示,捕捉语义信息。"
]


class RetrievalCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        self.current_query = None
        self.results = {}
    
    def on_retriever_start(self, serialized, query, **kwargs):
        self.current_query = query
        print(f"\n{'='*50}")
        print(f"[检索开始] 问题: {query}")
        print(f"{'='*50}")
    
    def on_retriever_end(self, documents, **kwargs):
        if self.current_query:
            print(f"[检索完成] 找到 {len(documents)} 个结果:")
            for i, doc in enumerate(documents):
                print(f"  [{i+1}] {doc.page_content}")
            print(f"{'='*50}")
    
    def on_chain_start(self, serialized, inputs, **kwargs):
        if isinstance(inputs, str):
            self.current_query = inputs
            print(f"\n{'='*60}")
            print(f"[链式调用开始] 问题: {inputs}")
            print(f"{'='*60}")
        elif isinstance(inputs, dict) and "query" in inputs:
            self.current_query = inputs["query"]
            print(f"\n{'='*60}")
            print(f"[链式调用开始] 问题: {inputs['query']}")
            print(f"{'='*60}")
    
    def on_chain_end(self, outputs, **kwargs):
        if isinstance(outputs, str):
            print("\n[链式调用完成] 聚合结果:")
            print(outputs)
            print(f"\n{'='*60}")
    
    def on_chain_error(self, error, **kwargs):
        print(f"\n[错误] 链式调用失败: {str(error)}")
        print(f"{'='*60}")
    
    def on_retriever_error(self, error, **kwargs):
        print(f"\n[错误] 检索失败: {str(error)}")
        print(f"{'='*50}")


def format_docs(docs: list[Document]) -> str:
    return "\n".join(f"[{i+1}] {d.page_content}" for i, d in enumerate(docs))


def format_results(name: str):
    def _format(docs: list[Document]) -> str:
        lines = [f"【{name}】"]
        for i, doc in enumerate(docs):
            lines.append(f"  [{i+1}] {doc.page_content}")
        return "\n".join(lines)
    return _format


def aggregate_results(results: dict) -> str:
    lines = []
    for name, output in results.items():
        lines.append(output)
        lines.append("")
    return "\n".join(lines)


def create_sparse_retriever(docs: list[Document], k: int = 3):
    retriever = BM25Retriever.from_documents(docs)
    retriever.k = k
    return retriever


def create_dense_retriever(docs: list[Document], embeddings, k: int = 3):
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore.as_retriever(search_kwargs={"k": k})


def create_retriever_with_fallback(dense_retriever, sparse_retriever):
    return (
        dense_retriever
        .with_retry(stop_after_attempt=2)
        .with_fallbacks([sparse_retriever], exceptions_to_handle=(Exception,))
    )


def main():
    docs = [Document(page_content=text) for text in DOCS]
    embeddings = OllamaEmbeddings(model="qwen3-embedding:0.6b")
    
    callback = RetrievalCallbackHandler()
    
    sparse = create_sparse_retriever(docs)
    dense = create_dense_retriever(docs, embeddings)
    dense_with_fallback = create_retriever_with_fallback(dense, sparse)
    ensemble = EnsembleRetriever(retrievers=[sparse, dense], weights=[0.5, 0.5])
    
    parallel_retrievers = RunnableParallel(
        bm25=sparse | RunnableLambda(format_results("BM25 稀疏检索")),
        dense=dense_with_fallback | RunnableLambda(format_results("向量稠密检索")),
        ensemble=ensemble | RunnableLambda(format_results("混合检索 (Ensemble)")),
    )
    
    chain = (
        {"results": parallel_retrievers, "query": RunnablePassthrough()}
        | RunnableLambda(lambda x: f"查询: {x['query']}\n\n{aggregate_results(x['results'])}")
    )
    
    print("=" * 60)
    print("RAG 检索器并发测试 (LCEL + Callback)")
    print("=" * 60)
    
    queries = [
        "LangChain 是用来做什么的?",
        "有什么 AI 相关的工具?",
        "Python 和 LangChain 的关系",
    ]
    
    for query in queries:
        result = chain.invoke(query, config={"callbacks": [callback]})
        print(result)
    
    print("\n" + "=" * 60)
    print("演示完成!")
    print("=" * 60)


if __name__ == "__main__":
    main()
相关推荐
FserSuN2 小时前
LangChain V1 create_agent 与 DeepAgents create_deep_agent 对比学习
langchain
心在飞扬3 小时前
我把本地文档 RAG 做成了可用系统:Flask + Vue3 + LangChain + FAISS(多知识库 + 流式输出)
langchain·openai·ai编程
知秋丶5 小时前
LangGraph 实战:如何用“双图编排”将多模态 OCR-RAG 做到生产级落地
人工智能·langchain·ocr
Java咩7 小时前
LangChain 之 LCEL表达式语法
python·langchain·lcel
来一斤小鲜肉7 小时前
Spring AI核心:高阶API之ChatMemory
langchain·aigc
王解8 小时前
开源与第三方视角:Thoughtworks、LangChain等如何看待Harness Engineering?
langchain·开源·ai agent
武汉知识图谱科技8 小时前
超越预测性维护:基于知识超图与根因推理的能源电力“免疫系统”构建
人工智能·物联网·langchain·能源·知识图谱·embedding
西西弗Sisyphus10 小时前
使用 langchain 的 PromptTemplate 处理多变量提示词
langchain·agent