milvus+langchain实现RAG应用

一、手动实现数据处理、流程编排

python 复制代码
from glob import glob
import os
from openai import OpenAI
from pymilvus import MilvusClient
from tqdm import tqdm
import json
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# 使用硅基流动的免费embedding模型
openai_client = OpenAI(
    api_key="***",
    base_url="https://api.siliconflow.cn/v1",
)
milvus_client = MilvusClient(uri="./milvus_demo.db")
collection_name = "my_rag_collection"

# 使用智谱的免费文本生成模型
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)


def emb_long_text(text):
    chunk_size = 512
    if len(text) <= chunk_size:
        return emb_text(text)
    embeddings = []
    for i in range(0, len(text), chunk_size):
        chunk = text[i : i + chunk_size]
        embedding = emb_text(chunk)
        embeddings.append(embedding)
    # Average the embeddings of all chunks
    avg_embedding = [sum(x) / len(embeddings) for x in zip(*embeddings)]
    return avg_embedding


def emb_text(text):
    return (
        openai_client.embeddings.create(input=text, model="BAAI/bge-m3")
        .data[0]
        .embedding
    )


def create_data_if_need():
    if milvus_client.has_collection(collection_name):
        print(milvus_client.describe_collection(collection_name))
        return
    text_lines = []
    for file_path in glob(
        os.path.expanduser("~/Desktop/milvus_docs/**/*.md"), recursive=True
    ):
        with open(file_path, "r", encoding="utf-8") as file:
            file_text = file.read()
        text_lines += file_text.split("# ")
    embedding_dim = 1024

    milvus_client.create_collection(
        collection_name=collection_name,
        dimension=embedding_dim,
        metric_type="IP",  # Inner product distance
        consistency_level="Bounded",
    )
    data = []
    for i, line in enumerate(tqdm(text_lines, desc="Creating embeddings")):
        if not line.strip():
            continue
        vector = emb_long_text(line)
        if not vector:
            print(f"Failed to embed line: {line}")
            continue
        data.append({"id": i, "vector": vector, "text": line, "text_len": len(line)})
    milvus_client.insert(collection_name=collection_name, data=data)


def do_chat(context, question):
    SYSTEM_PROMPT = """
    你是一名AI助手,你将根据提供的上下文信息回答用户的问题。如果上下文中没有相关信息,请诚实地告诉用户你不知道答案,而不是编造答案。
    你必须严格根据上下文信息作答,不能凭空添加任何信息。
    """
    USER_PROMPT = f"""
    使用下面的context标签中的信息用中文回答用户question标签中的问题。
    <context>
    {context}
    </context>
    <question>
    {question}
    </question>
    """

    # 创建消息
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=USER_PROMPT),
    ]

    # 调用模型
    response = llm.invoke(messages)
    print(response.content)


if __name__ == "__main__":
    create_data_if_need()
    while True:
        question = input("请输入你的问题: ")
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[emb_text(question)],
            limit=35,
            filter="text_len > 500",
            search_params={"metric_type": "IP", "params": {}},
            output_fields=["text"],
        )
        retrieved_lines_with_distances = [
            (res["entity"]["text"], res["distance"]) for res in search_res[0]
        ]
        print(json.dumps(retrieved_lines_with_distances, indent=4))
        context = "\n".join(
            [
                line_with_distance[0]
                for line_with_distance in retrieved_lines_with_distances
            ]
        )
        do_chat(context, question)

二、通过langchain自动进行数据处理、流程编排

python 复制代码
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_milvus import Milvus
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pymilvus import MilvusClient
from glob import glob
from tqdm import tqdm
import os, json

# 创建嵌入模型
embeddings = OpenAIEmbeddings(
    model="BAAI/bge-m3",
    openai_api_key="***",
    openai_api_base="https://api.siliconflow.cn/v1",
)
# 创建语言模型
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5-flash",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)
URI = "./milvus_demo_v2.db"
collection_name = "my_rag_collection_v2"

# 加载或创建向量存储
def load_vectorstore():
    vector_store = Milvus(
        collection_name=collection_name,
        embedding_function=embeddings,
        connection_args={"uri": URI},
    )
    milvus_client = MilvusClient(uri=URI)
    if milvus_client.has_collection(collection_name):
        print("向量存储已存在,直接加载...")
        return vector_store
    print("创建新的向量存储...")
    # 加载文档
    documents = []
    for file_path in glob(
        # 文档下载地址:https://github.com/milvus-io/milvus-docs/releases/download/v2.4.6-preview/milvus_docs_2.4.x_en.zip
        os.path.expanduser("~/Desktop/milvus_docs/**/*.md"),
        recursive=True,
    ):
        try:
            loader = TextLoader(file_path, encoding="utf-8")
            documents.extend(loader.load())
        except Exception as e:
            print(f"加载文件 {file_path} 时出错: {e}")

    if not documents:
        raise ValueError("未找到任何文档!")

    # milvus-docs已经按文件进行了分割,不再进行文本分割
    # text_splitter = RecursiveCharacterTextSplitter(
    #     chunk_size=800,
    #     chunk_overlap=100,
    #     length_function=len,
    # )
    # splits = text_splitter.split_documents(documents)

    # 补充文本长度元数据,方便后续检索过滤
    for doc in documents:
        doc.metadata["text_len"] = len(doc.page_content)
    print(f"共 {len(documents)} 个document需要插入")
    # 步长过长可能会导致达到接口限制
    stride = 10
    for i in tqdm(range(0, len(documents), stride), desc="添加文档到向量存储ing..."):
        sub_splits = documents[i : i + stride]
        vector_store.add_documents(sub_splits)
    return vector_store


def format_docs(docs):
    for doc in docs:
        print(doc)
    print("===========================================================================")
    print("===========================================================================")
    print("===================向量库检索完成,等待大模型响应ing....===================")
    print("===========================================================================")
    print("===========================================================================")
    # 合并多个关联文档,以提交给大模型
    return "\n\n".join(doc.page_content for doc in docs)


if __name__ == "__main__":
    # 创建或加载向量存储
    vector_store = load_vectorstore()
    while True:
        query = input("请输入您的问题:")
        # 定义提示模板
        prompt_template = """你是一名AI助手,你将根据提供的上下文信息回答用户的问题。
如果上下文中没有相关信息,请诚实地告诉用户你不知道答案,而不是编造答案。
你必须严格根据上下文信息作答,不能凭空添加任何信息。

使用下面的context标签中的信息用中文回答用户question标签中的问题。
<context>
{context}
</context>
<question>
{question}
</question>"""
        prompt = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        retriever = vector_store.as_retriever(
            # 取top10个关联文档,再通过过滤器筛选
            search_kwargs=dict(k=10, expr="text_len > 300")
        )
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        final_res = rag_chain.invoke(query)
        print(final_res)
相关推荐
王国强20097 小时前
Workflows vs Agents:如何选择你的 LLM 应用架构?
langchain
~kiss~7 小时前
Milvus-云原生和分布式的开源向量数据库-介绍
分布式·云原生·milvus
玲小珑9 小时前
LangChain.js 完全开发手册(十九)前端 AI 开发进阶技巧
前端·langchain·ai编程
大模型真好玩9 小时前
LangChain1.0实战之多模态RAG系统(一)——多模态RAG系统核心架构及智能问答功能开发
人工智能·langchain·agent
zhangbaolin19 小时前
深度智能体-长短期记忆
langchain·大模型·长期记忆·深度智能体·短期记忆
Geo_V1 天前
LangChain Memory 使用示例
人工智能·python·chatgpt·langchain·openai·大模型应用·llm 开发
小程故事多_801 天前
LangChain1.0系列:中间件深度解析,让 AI智能体上下文控制不失控
人工智能·中间件·langchain
汗流浃背了吧,老弟!1 天前
采用Langchain调用LLM完成简单翻译任务
langchain
Miku161 天前
LangGraph+BrightData+PaperSearch的研究助理
爬虫·langchain·mcp