14. LangChain项目实战1——基于公司制度RAG回答机器人

教学视频:

12. 基于Gradio搭建基于公司制度RAG_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV11VXRYTErZ/

环境配置:

python版本:3.10.8

服务器:Ubuntu

依赖包requirements.txt文件内容:

python 复制代码
aiofiles==23.2.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.9
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.0
asgiref==3.8.1
asttokens==2.4.1
async-timeout==4.0.3
attrs==24.2.0
backoff==2.2.1
bcrypt==4.2.0
beautifulsoup4==4.12.3
build==1.2.2
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.5.0
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
coloredlogs==15.0.1
comm==0.2.2
contourpy==1.3.0
cryptography==43.0.1
cycler==0.12.1
dataclasses-json==0.6.7
debugpy==1.8.6
decorator==5.1.1
deepdiff==8.0.1
Deprecated==1.2.14
distro==1.9.0
durationpy==0.9
emoji==2.14.0
exceptiongroup==1.2.2
executing==2.1.0
fastapi==0.115.0
ffmpy==0.4.0
filelock==3.16.1
filetype==1.2.0
flatbuffers==24.3.25
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.9.0
google-auth==2.35.0
googleapis-common-protos==1.65.0
gradio==4.44.1
gradio_client==1.3.0
greenlet==3.1.1
grpcio==1.66.2
h11==0.14.0
httpcore==1.0.6
httptools==0.6.1
httpx==0.27.2
huggingface-hub==0.25.1
humanfriendly==10.0
idna==3.10
importlib_metadata==8.4.0
importlib_resources==6.4.5
ipykernel==6.29.5
ipython==8.28.0
jedi==0.19.1
Jinja2==3.1.4
jiter==0.5.0
joblib==1.4.2
jsonpatch==1.33
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.3
jupyter_core==5.7.2
kiwisolver==1.4.7
kubernetes==31.0.0
langchain==0.2.16
langchain-chroma==0.2.2
langchain-community==0.2.13
langchain-core==0.2.41
langchain-openai==0.1.8
langchain-text-splitters==0.2.4
langdetect==1.0.9
langsmith==0.1.131
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.22.0
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mmh3==5.0.1
monotonic==1.6
mpmath==1.3.0
multidict==6.1.0
mypy-extensions==1.0.0
narwhals==1.9.1
nest-asyncio==1.6.0
networkx==3.3
nltk==3.9.1
numpy==1.26.4
oauthlib==3.2.2
onnxruntime==1.19.2
openai==1.51.0
opentelemetry-api==1.27.0
opentelemetry-exporter-otlp-proto-common==1.27.0
opentelemetry-exporter-otlp-proto-grpc==1.27.0
opentelemetry-instrumentation==0.48b0
opentelemetry-instrumentation-asgi==0.48b0
opentelemetry-instrumentation-fastapi==0.48b0
opentelemetry-proto==1.27.0
opentelemetry-sdk==1.27.0
opentelemetry-semantic-conventions==0.48b0
opentelemetry-util-http==0.48b0
orderly-set==5.2.2
orjson==3.10.7
overrides==7.7.0
packaging==24.1
pandas==2.2.3
parso==0.8.4
pillow==10.4.0
pip==24.2
platformdirs==4.3.6
posthog==3.7.0
prompt_toolkit==3.0.48
protobuf==4.25.5
psutil==6.0.0
pure_eval==0.2.3
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.4
pypdf==5.0.1
PyPika==0.48.9
pyproject_hooks==1.2.0
pyreadline3==3.5.4
python-dateutil==2.9.0.post0
python-docx==1.1.2
python-dotenv==1.0.1
python-iso639==2024.4.27
python-magic==0.4.27
python-multipart==0.0.12
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
rank-bm25==0.2.2
RapidFuzz==3.10.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==13.9.2
rpds-py==0.20.0
rsa==4.9
ruff==0.6.9
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
semantic-version==2.10.0
sentence-transformers==3.1.0
setuptools==75.1.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
spyder-kernels==2.5.2
SQLAlchemy==2.0.35
stack-data==0.6.3
starlette==0.38.6
sympy==1.13.3
tabulate==0.9.0
tenacity==8.5.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.0
tomli==2.0.2
tomlkit==0.12.0
torch==2.4.1
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.45.1
typer==0.12.5
typing_extensions==4.12.2
typing-inspect==0.9.0
tzdata==2024.2
unstructured==0.14.6
unstructured-client==0.25.9
urllib3==2.2.3
uvicorn==0.31.0
watchfiles==0.24.0
wcwidth==0.2.13
websocket-client==1.8.0
websockets==11.0.3
wheel==0.44.0
wrapt==1.16.0
yarl==1.13.1
zipp==3.20.2

.env文件

复制代码
HF_ENDPOINT=https://hf-mirror.com
HF_TOKEN=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
HF_CACHE_HOME=/root/autodl-fs/hugging_face
OPENAI_API_KEY=sk-a612336087ac408588d3e26cb47a5d9d
OPENAI_BASE_URL=https://api.deepseek.com/v1
OPENAI_MODEL=deepseek-chat
ZHIPU_BASE_URL=https://open.bigmodel.cn/api/paas/v4/
ZHIPU_API_KEY=4923c4daef034dc79086b53d60a87b2b.T5iESEb4WKlTGEx7
EMBEDDING_MODEL=embedding-3

下载 嵌入(embed)和 重排序(rerank)大模型

嵌入:BAAI/bge-large-zh-v1.5

python 复制代码
import os

from dotenv import load_dotenv, find_dotenv
from huggingface_hub import snapshot_download
from langchain_community.embeddings import HuggingFaceBgeEmbeddings

load_dotenv(find_dotenv())
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
HUGGING_FACE_ACCESS_TOKEN = os.getenv('HUGGING_FACE_ACCESS_TOKEN')


def download_model(model_name="BAAI/bge-large-zh-v1.5"):
    local_dir = f"../{model_name}"

    snapshot_download(repo_id=model_name, local_dir=local_dir, token=HUGGING_FACE_ACCESS_TOKEN)


def test_model(model_path="../BAAI/bge-large-zh-v1.5"):
    model_kwargs = {'device': 'cuda'}

    embeddings = HuggingFaceBgeEmbeddings(model_name=model_path,
                                          model_kwargs=model_kwargs)

    r = embeddings.embed_query("测试结核杆菌该回家回家过节国家机关和监管机构好几个好几个机会跟好几个机会和加官晋爵回家好不好举火炬计划v环境和v就不会进步v家")
    print(r)


# download_model()
test_model()

下载后测试效果如下:

重排序:BAAI/bge-reranker-large

python 复制代码
import os

from dotenv import load_dotenv, find_dotenv
from huggingface_hub import snapshot_download
from sentence_transformers import CrossEncoder


load_dotenv(find_dotenv())
#huggingface token
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
HUGGING_FACE_ACCESS_TOKEN = os.getenv('HUGGING_FACE_ACCESS_TOKEN')

# https://huggingface.co/BAAI/bge-reranker-large
def download_model(model_name="BAAI/bge-reranker-large"):
    local_dir = f"../{model_name}"

    snapshot_download(repo_id=model_name, local_dir=local_dir, token=HUGGING_FACE_ACCESS_TOKEN)


def test_model(model_patch="../BAAI/bge-reranker-large"):
    model = CrossEncoder(model_patch,device="cuda")

    pairs = [["孙悟空是谁", "孙悟空是你大爷"], ["孙悟空是谁", "孙悟空是唐僧徒弟"]]
    r = model.predict(pairs)
    print(r)


# download_model()
test_model()

下载后测试效果如下:

llm.py

初始化

python 复制代码
import hashlib
import os.path
from typing import Optional, Iterable
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.indexes import SQLRecordManager
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever, RePhraseQueryRetriever
from langchain.retrievers.document_compressors import LLMChainFilter, CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_core.document_loaders import BaseLoader
from langchain_core.embeddings import Embeddings
from langchain_core.indexing import index
from langchain_core.messages import AIMessageChunk
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import AddableDict
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import  ChatOpenAI, OpenAI
from unstructured.file_utils.filetype import FileType, detect_filetype
from langchain_community.document_loaders import PyPDFLoader, CSVLoader, TextLoader, UnstructuredWordDocumentLoader,UnstructuredMarkdownLoader
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from dotenv import load_dotenv
import logging
import shutil

# 这行代码配置了日志记录的基本设置。它调用 logging.basicConfig(),这会对日志记录进行基本配置,例如设置日志记录格式、日志文件等。
# 这里没有提供具体参数,所以使用默认配置,这通常包括在控制台输出日志消息。
logging.basicConfig()

# 这行代码获取名为 "langchain.retrievers.multi_query" 的日志记录器,并将其日志级别设置为 INFO。
# 这样,任何由这个记录器产生的 INFO 级别及以上的日志消息(INFO、WARNING、ERROR、CRITICAL)都会被输出。
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

# 这行代码获取名为 "langchain.retrievers.re_phraser" 的日志记录器,并将其日志级别设置为 INFO。
# 同样,任何由这个记录器产生的 INFO 级别及以上的日志消息都会被输出。
logging.getLogger("langchain.retrievers.re_phraser").setLevel(logging.INFO)


# 加载.env文件中的环境变量
load_dotenv()
ai_model = os.getenv("OPENAI_MODEL")

# 设置知识库 向量模型 重排序模型的路径
KNOWLEDGE_DIR = './chroma/knowledge/'
embedding_model = './BAAI/bge-large-zh-v1.5'
rerank_model = './BAAI/bge-reranker-large'
model_kwargs = {'device': 'cuda'}


# 知识库问答指令
qa_system_prompt = (
    "你叫瓜皮,一个帮助人们解答各种问题的助手。 "
    "使用检索到的上下文来回答问题。如果你不知道答案,就说你不知道。 "
    "\n\n"
    "{context}"
)

qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)

# 正常聊天指令
normal_system_prompt = (
    "你叫瓜皮,一个帮助人们解答各种问题的助手。"
)

normal_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", normal_system_prompt),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ]
)

create_indexes 函数

python 复制代码
def create_indexes(collection_name: str, loader: BaseLoader, embedding_function: Optional[Embeddings] = None):
    db = Chroma(collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=os.path.join('./chroma', collection_name))

    # https://python.langchain.com/v0.2/docs/how_to/indexing/
    record_manager = SQLRecordManager(
        f"chromadb/{collection_name}", db_url="sqlite:///record_manager_cache.sql"
    )
    print('record_manager: ',record_manager)
    record_manager.create_schema()
    print('record_manager: ',record_manager)
    print('record_manager.create_schema: ',record_manager.create_schema())
    documents = loader.load()
    print('documents: ',documents)

    r = index(documents, record_manager, db, cleanup="full", source_id_key="source")
    print('r: ',r)
    '''混合检索,将稀疏检索器(如BM25)与密集检索器(如嵌入相似性)相结合。
    稀疏检索器擅长根据关键字查找相关文档,而密集检索器擅长根据语义相似性查找相关文档。'''
    ensemble_retriever = EnsembleRetriever(
        retrievers=[db.as_retriever(search_kwargs={"k": 3}), BM25Retriever.from_documents(documents)]
    )
    print('ensemble_retriever: ',ensemble_retriever)

    return ensemble_retriever

create_indexes 函数讲解

函数目的

create_indexes 函数旨在为一个指定的文档集合创建索引,以便后续能够快速检索相关文档。它结合了稀疏检索(如BM25)和密集检索(如基于嵌入的相似性)的优势,通过混合检索技术提高了检索的准确性和效率。

参数说明
  • collection_name: 字符串类型,表示文档集合的名称。
  • loader: BaseLoader 类型,负责加载文档集合。
  • embedding_function: 可选参数,Embeddings 类型,用于将文档转换为嵌入向量。如果未提供,则可能无法执行基于嵌入的检索。
函数流程
  1. 初始化 Chroma 实例‌:

    • 使用 collection_nameembedding_function 初始化 Chroma 实例,该实例负责处理文档的嵌入和索引。
    • persist_directory 设置为 './chroma/' 加上 collection_name,用于持久化存储索引数据。
  2. 创建 SQLRecordManager 实例‌:

    • 实例化 SQLRecordManager,用于管理记录,确保索引中的文档是最新的且没有重复。
    • 命名空间设置为 "chromadb/" 加上 collection_name,数据库URL指向 SQLite 数据库文件 "sqlite:///record_manager_cache.sql"
    • 调用 create_schema 方法创建数据库架构。
  3. 加载文档‌:

    • 使用 loader 加载文档集合。
  4. 索引文档‌:

    • 调用 index 函数,将文档集合、record_managerChroma 实例以及清理模式("full")和源ID键("source")作为参数,对文档进行索引。
  5. 创建混合检索器‌:

    • 实例化 EnsembleRetriever,将 Chroma 实例作为密集检索器(配置为返回前3个结果)和基于文档的BM25检索器作为稀疏检索器组合在一起。
    • 这样,混合检索器可以结合两种检索技术的优势,提供更准确的检索结果。
返回值
  • 返回 ensemble_retriever 实例,即混合检索器,可用于后续的文档检索任务。

解析AIMessage

python 复制代码
def get_md5(input_string):
    # 创建一个 md5 哈希对象
    hash_md5 = hashlib.md5()

    # 需要确保输入字符串是字节串,因此如果它是字符串,则需要编码为字节串
    hash_md5.update(input_string.encode('utf-8'))

    # 获取十六进制的哈希值
    return hash_md5.hexdigest()


def streaming_parse(chunks: Iterable[AIMessageChunk]):
    for chunk in chunks:
        yield AddableDict({'answer': chunk.content})

AIMessageChunk 的作用

AIMessageChunk 是 LangChain 库中的一个类,它主要用于处理和表示消息或文档中的分块数据。在 LangChain 的上下文中,AIMessageChunk 可能有以下几个关键作用:

  1. 数据分块‌:

    • AIMessageChunk 可以将较大的消息或文档分割成较小的、更易于管理的块或片段。这种分块处理在处理长文本或大型文档时特别有用,可以提高处理效率和准确性。
  2. 处理文本数据‌:

    • 类可能提供了一系列方法来处理这些分块数据,例如提取关键信息、进行文本分析或转换等。这对于实现复杂的文本处理任务非常有帮助。
  3. 支持索引和检索‌:

    • 在 LangChain 的索引和检索系统中,AIMessageChunk 可能作为索引和检索的基本单位。通过将文档分块,可以更精细地控制索引的粒度和检索的准确性。
  4. 数据整合‌:

    • 在某些情况下,AIMessageChunk 还可以用于整合来自不同来源或格式的数据。通过将数据转换为统一的分块格式,可以更容易地进行比较、分析和处理。
  5. 支持异步处理‌:

    • 考虑到 LangChain 的设计哲学,AIMessageChunk 可能也支持异步处理方法,以充分利用现代硬件和并发处理技术的优势。

AddableDict 的作用

AddableDict 是 LangChain 库中 langchain_core.runnables 模块下的一个类,它主要用于提供一个具有额外加法操作的字典类型。以下是 AddableDict 的几个关键作用:

  1. 字典功能‌:

    • AddableDict 首先是一个字典,因此它继承了标准字典的所有功能,如存储键值对、访问和修改值等。
  2. 加法操作‌:

    • 与普通字典不同的是,AddableDict 支持加法操作。这意味着你可以将两个 AddableDict 实例相加,得到一个新的 AddableDict 实例,其中包含两个原始字典中所有键的值之和(假设值是数值类型)。如果某个键只在一个字典中存在,则其结果将包含该键及其对应的值。
  3. 方便数据聚合‌:

    • 在处理需要聚合多个字典数据的场景时,AddableDict 提供了极大的便利。通过简单的加法操作,可以快速合并多个字典,并对相应的值进行求和,而无需手动遍历和累加。
  4. 灵活性和可扩展性‌:

    • AddableDict 的设计使其具有灵活性和可扩展性。虽然它主要支持数值类型的加法操作,但在实际应用中,可以通过继承和重写相关方法来扩展其功能,以适应不同的数据类型和操作需求。

MyCustomLoader

python 复制代码
class MyCustomLoader(BaseLoader):
    # 支持加载的文件类型
    file_type = {
        FileType.CSV: (CSVLoader, {'autodetect_encoding': True}),
        FileType.TXT: (TextLoader, {'autodetect_encoding': True}),
        FileType.DOC: (UnstructuredWordDocumentLoader, {}),
        FileType.DOCX: (UnstructuredWordDocumentLoader, {}),
        FileType.PDF: (PyPDFLoader, {}),
        FileType.MD: (UnstructuredMarkdownLoader, {})
    }
    # 初始化方法  将加载的文件进行切分
    def __init__(self, file_path: str):
        loader_class, params = self.file_type[detect_filetype(file_path)]
        print('loader_class:',loader_class)
        print('params:',params)
        self.loader: BaseLoader = loader_class(file_path, **params)
        print('self.loader:',self.loader)
        self.text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def lazy_load(self):
        # 懒惰切分加载
        return self.loader.load_and_split(self.text_splitter)

    def load(self):
        # 加载
        return self.lazy_load()

MyCustomLoader 类是一个自定义的文档加载器,它继承自一个假设存在的 BaseLoader 类。该类的主要目的是根据文件的类型选择合适的加载器,并对加载的内容进行切分处理。它支持多种文件类型,包括 CSV、TXT、DOC、DOCX、PDF 和 MD。

属性与方法
  1. 属性

    • file_type: 一个字典,定义了支持的文件类型及其对应的加载器和参数。键是 FileType 枚举值,值是一个元组,包含加载器类和参数字典。
    • loader: 一个 BaseLoader 类型的实例,根据文件类型动态创建,用于加载和初步处理文件内容。
    • text_splitter: 一个 RecursiveCharacterTextSplitter 实例,用于对加载的内容进行切分。
  2. 方法

    • __init__(self, file_path: str): 初始化方法。接受一个文件路径作为参数,根据文件类型选择合适的加载器和参数,创建 loader 实例。同时,初始化 text_splitter 实例。
    • lazy_load(self): 懒惰加载方法。调用 loaderload_and_split 方法,传入 text_splitter,对加载的内容进行切分处理。返回切分后的结果。
    • load(self): 加载方法。直接调用 lazy_load 方法,返回其结果。这个方法的存在主要是为了提供一个简单的加载接口。
工作流程
  1. 创建一个 MyCustomLoader 实例,传入文件路径。
  2. 在初始化过程中,根据文件路径检测文件类型,并从 file_type 字典中选择合适的加载器和参数。
  3. 创建 loader 实例,用于加载文件内容。
  4. 初始化 text_splitter 实例,用于对加载的内容进行切分。
  5. 调用 loadlazy_load 方法,加载并切分文件内容,返回切分后的结果。

MyKnowledge

python 复制代码
class MyKnowledge:
    # 向量化模型
    __embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model, model_kwargs=model_kwargs)
    print('__embeddings:',__embeddings)

    __retrievers = {}
    __llm = ChatOpenAI(
        model=ai_model,
        temperature=0
    )

    def upload_knowledge(self, temp_file):
        # 获取上传文件名的名称,不包括路径
        file_name = os.path.basename(temp_file)
        # 生成存储知识库的完整路径
        file_path = os.path.join(KNOWLEDGE_DIR, file_name)
        # 如果文件不存在就copy
        if not os.path.exists(file_path):
            # 如果文件不存在,那么就创建目录
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            # 将temp_file复制到知识库目录下file_path
            shutil.copy(temp_file, file_path)

        import gradio as gr
        # 返回None 和 更新gradio组件的选项
        return None, gr.update(choices=self.load_knowledge())

    def load_knowledge(self):
        # exist_ok=True目标目录已存在的情况下不会抛出异常。
        # 这意味着如果目录已经存在,os.makedirs不会做任何事情,也不会报错
        os.makedirs(os.path.dirname(KNOWLEDGE_DIR), exist_ok=True)

        # 知识库默认为空
        collections = [None]
        print('os.listdir(KNOWLEDGE_DIR):',os.listdir(KNOWLEDGE_DIR))

        for file in os.listdir(KNOWLEDGE_DIR):
            # 将知识库进行添加
            collections.append(file)

            # 得到知识库的路径
            file_path = os.path.join(KNOWLEDGE_DIR, file)
            print('file_path:', file_path)

            # 知识库文件名进行md5编码,对某一个知识库进行唯一标识
            # collection_name1
            # collection_name2
            collection_name = get_md5(file)
            print('collection_name:',collection_name)

            print('self.__retrievers:',self.__retrievers)
            if collection_name in self.__retrievers:
                continue
            # 创建对应加载器
            loader = MyCustomLoader(file_path)
            print('loader:',loader)
            self.__retrievers[collection_name] = create_indexes(collection_name, loader, self.__embeddings)
            print('collections:',collections)
        return collections

    def get_retrievers(self, collection):
        collection_name = get_md5(collection)
        print('知识库名字md5:',collection_name)
        if collection_name not in self.__retrievers:
            print('self.__retrievers:',self.__retrievers)
            print('True')
            return None

        retriever = self.__retrievers[collection_name]
        print('get_retrievers中:',retriever)
        ''' LLMChainFilter:过滤,对寻回的文本进行过滤。它的主要目的是根据一定的条件或规则筛选和过滤文本内容。'''
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=LLMChainFilter.from_llm(self.__llm),
            # https://python.langchain.com/v0.2/docs/integrations/retrievers/re_phrase/#setting-up
            # 提取问题关键元素
            base_retriever= RePhraseQueryRetriever.from_llm(retriever, self.__llm)
        )

        '''rerank https://python.langchain.com/v0.2/docs/integrations/document_transformers/cross_encoder_reranker/'''
        model = HuggingFaceCrossEncoder(model_name=rerank_model,model_kwargs=model_kwargs)
        compressor = CrossEncoderReranker(model=model, top_n=3)

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=compression_retriever
        )

        print('compression_retriever:',compression_retriever)

        return compression_retriever

MyKnowledge 类介绍

MyKnowledge 类是一个设计用于管理和检索知识库的系统。它结合了文件上传、知识加载、以及基于特定集合的检索器获取等功能。以下是该类的主要组成部分及其功能详解:

属性
  • __embeddings: 利用Hugging Face的BGE(Big Giant Embeddings)技术,为知识库中的文档生成向量表示。这有助于后续的相似度计算和文档检索。
  • __retrievers: 一个字典,用于存储不同知识集合对应的检索器。这些检索器能够基于用户的查询返回相关的文档。
  • __llm: 一个基于OpenAI的聊天模型实例,用于执行语言理解和生成任务,如查询重述和结果过滤。
方法
  1. ‌**upload_knowledge(self, temp_file)**‌

    • 功能:将临时文件上传到知识库目录。
    • 实现:提取文件名,构建存储路径,复制文件到目标位置,并更新Gradio组件的选项以反映新的知识集合。
  2. ‌**load_knowledge(self)**‌

    • 功能:加载知识库中的所有知识集合,并为每个集合创建检索器。
    • 实现:遍历知识库目录,为每个文件创建一个唯一的MD5标识作为集合名。若集合名已存在于__retrievers中,则跳过。否则,使用MyCustomLoader加载文件,并通过create_indexes函数(未在代码中定义,可能是一个自定义函数)结合向量表示创建检索器,存入__retrievers
  3. ‌**get_retrievers(self, collection)**‌

    • 功能:根据集合名获取对应的检索器,并对其进行增强处理。
    • 实现:首先,将集合名转换为MD5标识。若该标识不存在于__retrievers中,则返回None。否则,从__retrievers中获取基础检索器,并使用ContextualCompressionRetriever进行封装。封装过程中,先通过LLMChainFilter(基于__llm)进行结果过滤,再通过RePhraseQueryRetriever(同样基于__llm)进行查询重述。最后,还可以选择性地添加一个基于HuggingFaceCrossEncoderCrossEncoderReranker进行结果重排序。

引入的Retriever介绍

在LangChain框架中,Retriever是用于根据非结构化查询返回相关文档的重要组件。以下是您所引用的几个Retriever的详细介绍:

  1. ‌**ContextualCompressionRetriever(上下文压缩检索器)**‌

    • 功能‌:上下文压缩检索器旨在解决在数据检索过程中,与查询最相关的信息可能被埋藏在大量无关文本中的问题。它利用给定查询的上下文对检索到的文档进行压缩,以便只返回相关信息。这里的"压缩"既指压缩单个文档的内容,也指整体过滤掉不相关的文档。
    • 工作原理‌:该检索器接收查询,传递给基础检索器,接收初步文档,并通过文档压缩器处理这些文档,以缩短它们并仅返回与查询最相关的信息。
  2. ‌**EnsembleRetriever(集成检索器)**‌

    • 功能‌:集成检索器是LangChain中一种强大的检索策略,它结合了多个检索器的结果,以提高检索的准确性和全面性。通过集成不同的检索器,可以利用它们各自的优点,从而得到更优质的检索结果。
    • 工作原理‌:集成检索器会对每个子检索器执行查询,并合并它们的结果。合并策略可能包括简单的结果汇总、基于排名的加权合并等。具体实现可能因LangChain的版本和配置而异。
  3. ‌**RePhraseQueryRetriever(查询重述检索器)**‌

    • 功能‌:查询重述检索器通过自动重述用户查询来提高检索效果。它尝试以不同的方式表达查询,以便更准确地匹配文档库中的相关内容。这对于处理复杂或模糊的查询特别有用。
    • 工作原理‌:该检索器接收原始查询,然后使用自然语言处理(NLP)技术生成查询的重述版本。这些重述版本随后被用于执行检索,并合并结果以返回给用户。重述策略可能包括同义词替换、句式变换等。

这些Retriever在LangChain框架中发挥着重要作用,它们共同构成了强大的信息检索系统,能够处理各种复杂的查询场景,并返回准确、全面的结果。通过合理利用这些Retriever,可以显著提高信息检索的效率和准确性‌

LLMChainFilter

LLMChainFilter 是 LangChain 中用于优化文档检索流程的组件,属于 document_compressors 模块,与 CrossEncoderReranker 协同工作,主要用于对检索到的文档进行智能过滤和压缩,以减少无关信息对后续处理的影响。其核心功能及特点如下:


1. 功能与作用

  • 动态过滤逻辑‌:基于大型语言模型(LLM)生成的条件或规则,自动筛选出与用户查询最相关的文档片段。例如,通过分析用户提问的语义,提取关键词或结构化条件(如时间范围、实体类型等),再对文档进行匹配‌14。
  • 文档压缩‌:通过删除冗余内容或保留关键段落,降低文档长度,提升后续处理(如重排序、答案生成)的效率‌24。
  • 上下文感知‌:结合 LLM 的语义理解能力,判断文档片段是否与当前查询的上下文相关,避免机械式关键词匹配的局限性‌

MyLLM

python 复制代码
class MyLLM(MyKnowledge):
    # 初始化聊天记录
    __chat_history = ChatMessageHistory()
    print('__chat_history:',__chat_history)
    #创建问答链(QA Chain),检索与过滤,设置聊天模型,返回一个带历史记录的可执行链
    def get_chain(self, collection, max_length, temperature):
        retriever = None
        print('collection:',collection)
        # 判断是否有 collection 知识库
        if collection:
            retriever = self.get_retrievers(collection) # 如果有知识库,调用 get_retrievers 方法检索相应的知识库
            print('retriever:',retriever)

        # 只保留3个记录
        print('len:',self.__chat_history.messages,'####:',len(self.__chat_history.messages))
        if len(self.__chat_history.messages) > 6:
            self.__chat_history.messages = self.__chat_history.messages[-6:]

        chat = ChatOpenAI(model=ai_model, max_tokens=max_length, temperature=temperature)

        if retriever:
            question_answer_chain = create_stuff_documents_chain(chat, qa_prompt) # 创建一个问答链,用于处理问题并生成回答
            print('question_answer_chain:',question_answer_chain)
            rag_chain = create_retrieval_chain(retriever, question_answer_chain) # 创建一个检索链,将检索器和问答链结合起来
            print('rag_chain:',rag_chain)
        else:
            rag_chain = normal_prompt | chat | streaming_parse
            print('rag_chain:',rag_chain)
        ''' 需要注意:output_messages_key,如果是无知识库的情况下是从AIMessageChunk的Content取,
            知识库是返回 AddableDict('answer') '''

        conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            lambda session_id: self.__chat_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
        print('conversational_rag_chain:',conversational_rag_chain)
        return conversational_rag_chain

    def invoke(self, question, collection, max_length=256, temperature=1):
        return self.get_chain(collection, max_length, temperature).invoke(
            {"input": question},
            {"configurable": {"session_id": "unused"}},
        )

    def stream(self, question, collection, max_length=256, temperature=1):
        return self.get_chain(collection, max_length, temperature).stream(
            {"input": question},
            {"configurable": {"session_id": "unused"}},
        )
    def clear_history(self) -> None:
        self.__chat_history.clear()
    def get_history_message(self):
        return self.__chat_history.messages

create_retrieval_chain

create_retrieval_chain 是 LangChain 库中的一个函数,用于创建一个信息检索链(Retrieval Chain)。这个链式操作允许你将多个组件组合起来,以实现从文档集合中检索相关信息的自动化流程。下面是对 create_retrieval_chain 的详细介绍:

create_retrieval_chain 函数允许你创建一个检索链,该链可以执行以下操作:

  1. 文档加载‌:从各种数据源加载文档,如本地文件、数据库或在线资源。
  2. 文档预处理‌:对加载的文档进行预处理,如文本清洗、分割或向量化。
  3. 检索‌:使用检索器(Retriever)在预处理后的文档集合中查找与查询相关的信息。
  4. 后处理‌:对检索结果进行进一步处理,如重排序、过滤或摘要生成。

RunnableWithMessageHistory

RunnableWithMessageHistory 是 LangChain 库中的一个类,它提供了一种简便的方法来为智能聊天机器人或其他可运行对象(Runnable)添加消息历史记录管理功能。通过使用 RunnableWithMessageHistory,你可以让聊天机器人记住并理解整个对话的上下文,从而提供更加连贯和个性化的响应。

  1. 消息历史加载与保存‌:

    • 在调用可运行对象之前,RunnableWithMessageHistory 会自动加载与会话相关的消息历史记录。
    • 在调用可运行对象之后,它会自动保存产生的响应到消息历史记录中。
  2. 多会话管理‌:

    • 通过使用会话 ID(session_id),RunnableWithMessageHistory 能够区分并管理多个会话的消息历史记录。
    • 这使得在同一时间内处理多个用户的对话成为可能。
  3. 灵活的消息存储‌:

    • RunnableWithMessageHistory 允许你选择一个消息存储后端,如 SQLite、Redis、Postgres 等,来存储和加载消息历史记录。
    • 你可以根据性能需求和数据持久性要求来选择最适合你的存储后端。

测试:

  1. 上传知识库
python 复制代码
if __name__ == "__main__":
    k = MyKnowledge()
    k.load_knowledge()

由于chroma使用sqlite作为底层数据库,导入后的数据如下

upsertion_record

继续测试

python 复制代码
    retriever = k.get_retrievers("中国人工智能系列白皮书.pdf")
    docs = retriever.base_retriever.invoke("确保大模型高效的技术有哪些?")
    print("rerank前:")
    for doc in docs:
        print(doc)
    docs = retriever.invoke("确保大模型高效的技术有哪些?")
    print("rerank后:")
    for doc in docs:
        print(doc)

主入口界面(main.py)

python 复制代码
# 导入Gradio库,用于创建交互式Web应用程序
import gradio as gr

# 从llm模块中导入MyLLM类,这是自定义的大型语言模型接口
from llm import MyLLM

# 实例化MyLLM类,用于后续的模型调用和处理
llm = MyLLM()

# 定义submit函数,用于处理用户提交的查询
# query => 是用户输入
# chat_history => 聊天的历史记录
def submit(query, chat_history):
    print('query:',query)
    print('chat_history:',chat_history)
    f""" 搜索知识库 """
    # 如果查询为空字符串,返回空字符串和当前的聊天记录
    if query == '':
        return '', chat_history
    # 如果查询不为空,将查询添加到聊天记录中,并返回更新后的聊天记录
    chat_history.append([query, None]) # 还没有使用到模型的回复
    print('chat_history####:',chat_history)
    print('('', chat_history):',('', chat_history))
    return '', chat_history

# 定义 change_collection 函数,用于在更改知识库时清除模型历史记录
def change_collection():
    llm.clear_history()

# 定义load_history函数,用于加载模型的历史消息
def load_history():
    # 将历史消息格式化为对话形式,每两条消息为一对
    history = llm.get_history_message()
    return [
        [history[i].content, history[i + 1].content]
        if i + 1 < len(history) else
        [history[i].content, None]
        for i in range(0, len(history), 2)
        ]

# 定义llm_reply函数,用于生成模型回复
def llm_reply(collection, chat_history, max_length=256, temperature=1):
    question = chat_history[-1][0]
    print('question:',question)
    # 使用流式生成方法从模型中获取回复
    response = llm.stream(question, collection, max_length=max_length, temperature=temperature)
    print('response:',response)
    # print('responselist:',list(response))
    chat_history[-1][1] = ""
    print('chat_history:',chat_history)

    # 逐块处理模型生成的回复
    for chunk in response:
        print('chunk:',chunk)
        if 'context' in chunk:
            # 如果块中包含上下文信息,则打印出来
            for doc in chunk['context']:
                print('doc:',doc)
        if 'answer' in chunk:
            # 如果块中包含答案,则将其追加到聊天记录的最后一个条目中
            chunk_content = chunk['answer']
            print('chunk_content:',chunk_content)
            if chunk_content is not None:
                chat_history[-1][1] += chunk_content
                # 返回更新后的聊天记录
                yield chat_history
    print('chat_history:', chat_history)

# 创建一个Gradio Blocks应用,设置fill_height为True
with gr.Blocks(fill_height=True) as demo:
    # 在应用中添加一个HTML元素,显示标题
    gr.HTML("""<h1 align="center">Chat With AI</h1>""")

    # 创建一个新的行布局
    with gr.Row():
        # 创建一个占比为 4 的列布局
        with gr.Column(scale=4):
            # 创建一个聊天机器人界面
            chatbot = gr.Chatbot(show_label=False, scale=3, show_copy_button=True)

        # 创建一个占比为 1 的列布局,显示进度
        with gr.Column(scale=1, show_progress=True) as column_config:
            # 创建一个滑块,用于设置生成回复的最大长度
            max_length = gr.Slider(1, 4095, value=256, step=1.0, label="Maximum length", interactive=True)
            # 创建一个滑块,用于设置生成回复的温度
            temperature = gr.Slider(0, 2, value=1, step=0.01, label="Temperature", interactive=True)
            # 创建一个按钮,用于清除聊天记录
            clear = gr.Button("清除")
            # 创建一个下拉菜单,用于选择知识库
            collection = gr.Dropdown(choices=llm.load_knowledge(), label="知识库")
            # 创建一个文件上传控件,支持多种文件类型
            file = gr.File(label="上传文件", file_types=['doc', 'docx', 'csv', 'txt', 'pdf', 'md'])

    # 创建一个文本框,用于用户输入
    user_input = gr.Textbox(placeholder="Input...", show_label=False)
    # 创建一个按钮,用于提交用户输入
    user_submit = gr.Button("提交")

    # 绑定 clear 按钮的点击事件,清除模型历史记录,并更新聊天机器人界面
    clear.click(fn=llm.clear_history, inputs=None, outputs=[chatbot])

    # 当我们询问模型的时候,按的回车键,会触发 => submit 事件
    user_input.submit(fn=submit,
                      inputs=[user_input, chatbot], # 一个是用户输入 一个当前的聊天记录
                      outputs=[user_input, chatbot] # 一个是为了清空用户输入的文本框 另一个是更新后的聊天记录,将新的用户查询添加到聊天记录中。
                      ).then( # 执行代码的时候有先后顺序
        fn=llm_reply,
        inputs=[collection, chatbot, max_length, temperature],
                # collection: 用户选择的知识库。
                # chatbot: 当前的聊天记录(已经包含用户的新查询)。
                # max_length: 用户设置的生成回复的最大长度。
                # temperature: 用户设置的生成回复的温度。
        outputs=[chatbot]
        # 更新后的聊天记录,将模型生成的回复添加到聊天记录中。
    )
    # 绑定用户输入文本框的提交事件,
    # 先调用submit函数,
    # 然后调用llm_reply函数,
    # 并更新聊天机器人界面
    user_submit.click(fn=submit,
                      inputs=[user_input, chatbot],
                      outputs=[user_input, chatbot]
                      ).then(
        fn=llm_reply,
        inputs=[collection, chatbot, max_length, temperature],
        outputs=[chatbot]
    )
    # 绑定提交按钮的点击事件,先调用submit函数,
    # 然后调用llm_reply函数,并更新聊天机器人界面

    # 绑定文件上传控件的上传事件,调用upload_knowledge函数,并更新文件控件和知识库下拉菜单
    file.upload(fn=llm.upload_knowledge, inputs=[file], outputs=[file, collection])

    # 绑定知识库下拉菜单的更改事件,调用clear_history函数,并更新聊天机器人界面 也就是换一个知识库就清空当前的页面
    collection.change(fn=llm.clear_history, inputs=None, outputs=[chatbot])

    # 绑定应用加载事件,调用clear_history函数,并更新聊天机器人界面
    demo.load(fn=llm.clear_history, inputs=None, outputs=[chatbot])

# 启动 Gradio 应用
demo.launch()

以上代码构建一个基于Gradio库的交互式Web聊天应用程序,该程序将使用自定义的大型语言模型接口MyLLM来与用户进行交互。

以上也就完成了一个基于RAG的公司制度查询系统。完整代码和教学视频参照以下链接:

https://cloud.189.cn/t/UvMzA377JV3m(访问码:6feb)天翼云盘是中国电信推出的云存储服务,为用户提供跨平台的文件存储、备份、同步及分享服务,是国内领先的免费网盘,安全、可靠、稳定、快速。天翼云盘为用户守护数据资产。https://cloud.189.cn/t/UvMzA377JV3m(访问码:6feb)

相关推荐
ZWZhangYu10 小时前
LangChain 构建向量数据库和检索器
数据库·langchain·easyui
booooooty19 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
伊布拉西莫1 天前
LangChain 全面入门
langchain
AI大模型1 天前
LangGraph官方文档笔记(七)——Agent的输入输出
langchain·llm·agent
knqiufan1 天前
深度解析影响 RAG 召回率的四大支柱——模型、数据、索引与检索
llm·milvus·向量数据库·rag
AI大模型2 天前
LangGraph官方文档笔记(6)——时间旅行
程序员·langchain·llm
是小王同学啊~3 天前
(LangChain)RAG系统链路向量检索器之Retrievers(五)
python·算法·langchain
AIGC包拥它3 天前
提示技术系列——链式提示
人工智能·python·langchain·prompt
在未来等你3 天前
RAG实战指南 Day 4:LlamaIndex框架实战指南
大语言模型·rag·llamaindex·检索增强生成·ai开发
AI大模型3 天前
LangGraph官方文档笔记(4)——提示聊天机器人
程序员·langchain·llm