# ---- Script 1: RAG over Milvus docs using raw OpenAI + Milvus clients ----
from glob import glob
import os
from openai import OpenAI
from pymilvus import MilvusClient
from tqdm import tqdm
import json
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
# Embedding model: SiliconFlow's free tier, via its OpenAI-compatible API.
openai_client = OpenAI(
    api_key="***",
    base_url="https://api.siliconflow.cn/v1",
)
# Milvus Lite: persists the collection in a local file, no server required.
milvus_client = MilvusClient(uri="./milvus_demo.db")
collection_name = "my_rag_collection"
# Chat model: Zhipu AI's free GLM model, via its OpenAI-compatible endpoint.
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)
def emb_long_text(text):
    """Embed *text* of arbitrary length.

    Texts up to 512 characters are embedded directly. Longer texts are cut
    into 512-character chunks, each chunk is embedded separately, and the
    chunk vectors are averaged dimension-wise into a single vector.
    """
    max_len = 512
    if len(text) <= max_len:
        return emb_text(text)
    chunks = [text[pos : pos + max_len] for pos in range(0, len(text), max_len)]
    chunk_vectors = [emb_text(chunk) for chunk in chunks]
    # Dimension-wise mean over all chunk embeddings.
    count = len(chunk_vectors)
    return [sum(dim_values) / count for dim_values in zip(*chunk_vectors)]
def emb_text(text):
    """Return the embedding vector for *text* using the BAAI/bge-m3 model."""
    response = openai_client.embeddings.create(input=text, model="BAAI/bge-m3")
    return response.data[0].embedding
def create_data_if_need():
    """Create and populate the Milvus collection from local markdown docs.

    No-op (apart from printing the schema) if the collection already exists.
    Reads every ``*.md`` file under ``~/Desktop/milvus_docs`` — assumes the
    milvus-docs archive has been unpacked there (TODO confirm path).
    """
    if milvus_client.has_collection(collection_name):
        print(milvus_client.describe_collection(collection_name))
        return
    text_lines = []
    for file_path in glob(
        os.path.expanduser("~/Desktop/milvus_docs/**/*.md"), recursive=True
    ):
        with open(file_path, "r", encoding="utf-8") as file:
            file_text = file.read()
            # Split on markdown headers so each section becomes one chunk.
            text_lines += file_text.split("# ")
    # Must match the embedding model's output dimensionality.
    embedding_dim = 1024
    milvus_client.create_collection(
        collection_name=collection_name,
        dimension=embedding_dim,
        metric_type="IP",  # Inner product distance
        consistency_level="Bounded",
    )
    data = []
    for i, line in enumerate(tqdm(text_lines, desc="Creating embeddings")):
        if not line.strip():
            continue
        vector = emb_long_text(line)
        if not vector:
            print(f"Failed to embed line: {line}")
            continue
        # "text_len" is stored so searches can filter out short fragments.
        data.append({"id": i, "vector": vector, "text": line, "text_len": len(line)})
    milvus_client.insert(collection_name=collection_name, data=data)
def do_chat(context, question):
    """Ask the LLM to answer *question* grounded strictly in *context*.

    Prints the model's reply (as before) and now also returns it, so the
    answer can be reused programmatically. Existing callers that ignore the
    return value are unaffected.
    """
    SYSTEM_PROMPT = """
你是一名AI助手,你将根据提供的上下文信息回答用户的问题。如果上下文中没有相关信息,请诚实地告诉用户你不知道答案,而不是编造答案。
你必须严格根据上下文信息作答,不能凭空添加任何信息。
"""
    USER_PROMPT = f"""
使用下面的context标签中的信息用中文回答用户question标签中的问题。
<context>
{context}
</context>
<question>
{question}
</question>
"""
    # Build the chat transcript: system instruction + user prompt.
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=USER_PROMPT),
    ]
    # Invoke the model and surface the answer.
    response = llm.invoke(messages)
    print(response.content)
    return response.content
if __name__ == "__main__":
    create_data_if_need()
    # Interactive retrieval-augmented QA loop.
    while True:
        question = input("请输入你的问题: ").strip()
        # Robustness: don't spend an embedding API call on an empty question.
        if not question:
            continue
        search_res = milvus_client.search(
            collection_name=collection_name,
            data=[emb_text(question)],
            limit=35,
            # Only consider reasonably long chunks (text_len stored at insert).
            filter="text_len > 500",
            search_params={"metric_type": "IP", "params": {}},
            output_fields=["text"],
        )
        retrieved_lines_with_distances = [
            (res["entity"]["text"], res["distance"]) for res in search_res[0]
        ]
        # ensure_ascii=False keeps Chinese text readable instead of \uXXXX escapes.
        print(json.dumps(retrieved_lines_with_distances, indent=4, ensure_ascii=False))
        # Concatenate the retrieved chunks into one context for the LLM.
        context = "\n".join(
            line for line, _distance in retrieved_lines_with_distances
        )
        do_chat(context, question)
# ---- Script 2: the same RAG flow rebuilt on LangChain + langchain-milvus ----
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_milvus import Milvus
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pymilvus import MilvusClient
from glob import glob
from tqdm import tqdm
import os, json
# Embedding model (SiliconFlow's OpenAI-compatible endpoint).
embeddings = OpenAIEmbeddings(
    model="BAAI/bge-m3",
    openai_api_key="***",
    openai_api_base="https://api.siliconflow.cn/v1",
)
# Chat model (Zhipu AI's OpenAI-compatible endpoint).
llm = ChatOpenAI(
    temperature=0.6,
    model="glm-4.5-flash",
    openai_api_key="***",
    openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
)
# Milvus Lite database file and the collection used by this script.
URI = "./milvus_demo_v2.db"
collection_name = "my_rag_collection_v2"
# 加载或创建向量存储
def load_vectorstore():
    """Return the Milvus-backed vector store, building it on first run.

    If the collection already exists it is loaded as-is; otherwise every
    markdown file under ``~/Desktop/milvus_docs`` is embedded and inserted.
    """
    vector_store = Milvus(
        collection_name=collection_name,
        embedding_function=embeddings,
        connection_args={"uri": URI},
    )
    if MilvusClient(uri=URI).has_collection(collection_name):
        print("向量存储已存在,直接加载...")
        return vector_store
    print("创建新的向量存储...")
    # Corpus download:
    # https://github.com/milvus-io/milvus-docs/releases/download/v2.4.6-preview/milvus_docs_2.4.x_en.zip
    pattern = os.path.expanduser("~/Desktop/milvus_docs/**/*.md")
    documents = []
    for file_path in glob(pattern, recursive=True):
        try:
            documents.extend(TextLoader(file_path, encoding="utf-8").load())
        except Exception as e:
            # Best-effort load: report the broken file and keep going.
            print(f"加载文件 {file_path} 时出错: {e}")
    if not documents:
        raise ValueError("未找到任何文档!")
    # The milvus-docs corpus is already one topic per file, so no further
    # text splitting is applied. Record each document's length so retrieval
    # can filter out very short fragments later.
    for doc in documents:
        doc.metadata["text_len"] = len(doc.page_content)
    print(f"共 {len(documents)} 个document需要插入")
    # Insert in small batches; a larger batch may hit embedding API limits.
    batch = 10
    for start in tqdm(range(0, len(documents), batch), desc="添加文档到向量存储ing..."):
        vector_store.add_documents(documents[start : start + batch])
    return vector_store
def format_docs(docs):
    """Print each retrieved document, then join their contents for the LLM."""
    for doc in docs:
        print(doc)
    bar = "==========================================================================="
    print(bar)
    print(bar)
    print("===================向量库检索完成,等待大模型响应ing....===================")
    print(bar)
    print(bar)
    # Merge the related documents into a single context string.
    return "\n\n".join(doc.page_content for doc in docs)
if __name__ == "__main__":
    # Create or load the vector store once at startup.
    vector_store = load_vectorstore()
    # The prompt, retriever, and chain are loop-invariant: build them once
    # instead of reconstructing them on every question (the original rebuilt
    # all three inside the input loop).
    prompt_template = """你是一名AI助手,你将根据提供的上下文信息回答用户的问题。
如果上下文中没有相关信息,请诚实地告诉用户你不知道答案,而不是编造答案。
你必须严格根据上下文信息作答,不能凭空添加任何信息。
使用下面的context标签中的信息用中文回答用户question标签中的问题。
<context>
{context}
</context>
<question>
{question}
</question>"""
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    retriever = vector_store.as_retriever(
        # Take the top-10 candidates, then filter out short fragments.
        search_kwargs=dict(k=10, expr="text_len > 300")
    )
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    # Interactive QA loop.
    while True:
        query = input("请输入您的问题:")
        final_res = rag_chain.invoke(query)
        print(final_res)