A Source-Level Analysis of LangChain Parser Integration with Downstream Tasks
1. LangChain Parser Architecture: Source Analysis
1.1 Core Implementation of Document Loaders
LangChain's document loaders are built around the abstract base class BaseLoader; every concrete loader must inherit from it and implement the load method. The following is a key excerpt of the loader source:
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any

class BaseLoader(ABC):
    """Abstract base class for document loaders."""

    @abstractmethod
    def load(self) -> List[Dict[str, Any]]:
        """Load documents and return them as a list."""
        pass

class TextLoader(BaseLoader):
    """Loader for plain-text files."""

    def __init__(self, file_path: str, encoding: str = "utf-8"):
        self.file_path = file_path
        self.encoding = encoding

    def load(self) -> List[Dict[str, Any]]:
        """Implement the abstract method: load a text file."""
        try:
            with open(self.file_path, "r", encoding=self.encoding) as f:
                content = f.read()
            # Return a document object containing the text and its metadata
            return [{"content": content, "metadata": {"source": self.file_path}}]
        except Exception as e:
            # Error handling
            raise ValueError(f"Error loading {self.file_path}: {e}")

class PDFLoader(BaseLoader):
    """Loader for PDF files; depends on the PyMuPDF library."""

    def __init__(self, file_path: str):
        self.file_path = file_path

    def load(self) -> List[Dict[str, Any]]:
        """Implement the abstract method: load a PDF file."""
        try:
            import fitz  # PyMuPDF
            documents = []
            with fitz.open(self.file_path) as doc:
                for page_num in range(len(doc)):
                    page = doc[page_num]
                    text = page.get_text()
                    # Create one document object per page, with the page number in metadata
                    documents.append({
                        "content": text,
                        "metadata": {"source": self.file_path, "page": page_num + 1}
                    })
            return documents
        except ImportError:
            raise ValueError("PyMuPDF is required to load PDF files. Install with `pip install pymupdf`")
        except Exception as e:
            raise ValueError(f"Error loading {self.file_path}: {e}")
```
The loader design follows the strategy pattern: the abstract base class defines a uniform interface, and each subclass supplies a concrete implementation. Adding a new loader is therefore simple: inherit from BaseLoader and implement load.
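A minimal usage sketch of the loaders above (the file paths are hypothetical):

```python
# Hypothetical paths; any text file and PDF will do.
text_docs = TextLoader("sample.txt").load()
pdf_docs = PDFLoader("report.pdf").load()

print(text_docs[0]["metadata"])  # {'source': 'sample.txt'}
print(len(pdf_docs))             # one document per PDF page
```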
1.2 Text Splitter Source Analysis
The core of text splitting is the TextSplitter abstract class, which defines the basic splitting interface. The key implementation follows:
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any

class TextSplitter(ABC):
    """Abstract base class for text splitters."""

    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200):
        """Initialize splitter parameters."""
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # Validate parameters
        if chunk_overlap > chunk_size:
            raise ValueError(f"Chunk overlap {chunk_overlap} must be less than chunk size {chunk_size}")

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple chunks."""
        pass

    def create_documents(self, texts: List[str], metadata: Dict[str, Any] = None) -> List[Dict[str, Any]]:
        """Create a list of document objects."""
        metadata = metadata or {}
        documents = []
        for i, text in enumerate(texts):
            # Add a chunk-index entry to each chunk's metadata
            doc_metadata = {**metadata, "chunk_index": i}
            documents.append({"content": text, "metadata": doc_metadata})
        return documents

class CharacterTextSplitter(TextSplitter):
    """Character-based text splitter."""

    def __init__(self, separator: str = "\n\n", **kwargs):
        """Initialize the separator."""
        super().__init__(**kwargs)
        self.separator = separator

    def split_text(self, text: str) -> List[str]:
        """Character-based text splitting."""
        if not text:
            return []
        # Split the text on the separator
        splits = text.split(self.separator)
        # Merge small pieces so each chunk approaches chunk_size
        merged_splits = []
        current_chunk = ""
        for split in splits:
            if len(current_chunk) + len(split) < self.chunk_size:
                if current_chunk:
                    current_chunk += self.separator + split
                else:
                    current_chunk = split
            else:
                merged_splits.append(current_chunk)
                current_chunk = split
        # Append the final chunk
        if current_chunk:
            merged_splits.append(current_chunk)
        # Apply chunk overlap
        if self.chunk_overlap > 0 and len(merged_splits) > 1:
            final_splits = []
            for i in range(len(merged_splits)):
                if i == 0:
                    # The first chunk needs no leading overlap
                    final_splits.append(merged_splits[i])
                else:
                    # Take the tail of the previous chunk as the overlap
                    overlap = merged_splits[i - 1][-self.chunk_overlap:]
                    # Prepend the overlap to the current chunk
                    final_splits.append(overlap + merged_splits[i])
            return final_splits
        return merged_splits
```
The splitter design accounts for several concerns, such as chunk-size control and chunk overlap. CharacterTextSplitter is the most basic implementation; more advanced splitters exist that work at the sentence or token level.
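A short usage sketch of CharacterTextSplitter as defined above:

```python
splitter = CharacterTextSplitter(separator="\n\n", chunk_size=100, chunk_overlap=20)
chunks = splitter.split_text("First paragraph.\n\nSecond paragraph.\n\nThird paragraph.")
docs = splitter.create_documents(chunks, metadata={"source": "example"})
for doc in docs:
    print(doc["metadata"]["chunk_index"], repr(doc["content"][:30]))
```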
1.3 Embedding Generator Source Implementation
Embedding generators turn text into vector representations. The core is the Embeddings abstract class:
```python
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import numpy as np

class Embeddings(ABC):
    """Abstract base class for embedding generators."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embedding vectors for multiple documents."""
        pass

    def embed_query(self, text: str) -> np.ndarray:
        """Generate the embedding vector for a single query."""
        return self.embed_documents([text])[0]

class OpenAIEmbeddings(Embeddings):
    """Embedding generator backed by the OpenAI API."""

    def __init__(self, api_key: str = None, model_name: str = "text-embedding-ada-002"):
        """Initialize the OpenAI API client (legacy, pre-1.0 SDK interface)."""
        import openai
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key not provided")
        self.model_name = model_name
        self.client = openai.Embedding

    def embed_documents(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings via OpenAI."""
        # Truncate over-long inputs (OpenAI API limit)
        texts = [text[:8191] for text in texts]
        try:
            # Call the OpenAI API in a single batch
            response = self.client.create(
                input=texts,
                model=self.model_name
            )
            # Extract the embeddings and convert them to numpy arrays
            embeddings = [np.array(data["embedding"]) for data in response["data"]]
            return embeddings
        except Exception as e:
            raise ValueError(f"Error generating embeddings: {e}")

class HuggingFaceEmbeddings(Embeddings):
    """Embedding generator backed by Hugging Face models."""

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 model_kwargs: Dict[str, Any] = None):
        """Initialize the Hugging Face model."""
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ValueError("sentence-transformers package is required. Install with `pip install sentence-transformers`")
        self.model_name = model_name
        self.model_kwargs = model_kwargs or {}
        # Load the pretrained model
        self.model = SentenceTransformer(model_name, **self.model_kwargs)

    def embed_documents(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings via Hugging Face."""
        try:
            # Encode the texts with the model
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            return [embedding for embedding in embeddings]
        except Exception as e:
            raise ValueError(f"Error generating embeddings: {e}")
```
Embedding generators follow the adapter pattern: a uniform interface wraps different embedding providers, so swapping models requires no changes to upstream code.
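To illustrate that interchangeability, a sketch in which either provider backs the same pipeline (the Hugging Face variant runs locally and needs no API key):

```python
embedder = HuggingFaceEmbeddings()            # local model
# embedder = OpenAIEmbeddings(api_key="...")  # hosted alternative, same interface

vectors = embedder.embed_documents(["hello world", "goodbye world"])
query_vector = embedder.embed_query("hello")
print(len(vectors), query_vector.shape)
```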
2. Parser Integration with Vector Databases: Source Analysis
2.1 Vector Store Interface Design
LangChain defines a uniform interface for talking to vector databases through the VectorStore abstract class:
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import numpy as np

class VectorStore(ABC):
    """Abstract base class for vector stores."""

    @abstractmethod
    def add_texts(self, texts: List[str], embeddings: List[np.ndarray],
                  metadatas: Optional[List[Dict[str, Any]]] = None) -> List[str]:
        """Add texts and their embedding vectors to the store."""
        pass

    @abstractmethod
    def similarity_search(self, query_embedding: np.ndarray, k: int = 4) -> List[Dict[str, Any]]:
        """Retrieve documents by vector similarity."""
        pass

    @classmethod
    @abstractmethod
    def from_texts(cls, texts: List[str], embeddings: Embeddings,
                   metadatas: Optional[List[Dict[str, Any]]] = None,
                   **kwargs) -> "VectorStore":
        """Create a vector store instance from a collection of texts."""
        pass
```
This interface defines the basic vector-store operations: adding texts with their embedding vectors, retrieving documents by similarity, and a class method for building a store from a collection of texts.
2.2 Chroma Vector Database Integration
Chroma is a lightweight vector database for which LangChain provides full integration support:
```python
import uuid
from typing import List, Dict, Any, Optional

import chromadb
import numpy as np
from chromadb.config import Settings
from chromadb.api.types import Documents

class Chroma(VectorStore):
    """Chroma vector database integration."""

    def __init__(self, client: Optional[chromadb.Client] = None,
                 collection_name: str = "langchain",
                 embedding_function: Optional[Embeddings] = None):
        """Initialize the Chroma client and collection."""
        self.client = client or chromadb.Client(Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=".chromadb"
        ))
        self.collection_name = collection_name
        self.embedding_function = embedding_function
        # Create or fetch the collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self._embedding_function if embedding_function else None
        )

    def _embedding_function(self, texts: Documents) -> List[List[float]]:
        """Wrap the embedding function for Chroma's internal use."""
        return [self.embedding_function.embed_query(text).tolist() for text in texts]

    def add_texts(self, texts: List[str], embeddings: List[np.ndarray],
                  metadatas: Optional[List[Dict[str, Any]]] = None) -> List[str]:
        """Add texts and embedding vectors to Chroma."""
        # Generate unique IDs
        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
        # Normalize metadata
        metadatas = metadatas or [{} for _ in range(len(texts))]
        # Add everything to the Chroma collection
        self.collection.add(
            ids=ids,
            embeddings=[embedding.tolist() for embedding in embeddings],
            documents=texts,
            metadatas=metadatas
        )
        return ids

    def similarity_search(self, query_embedding: np.ndarray, k: int = 4) -> List[Dict[str, Any]]:
        """Retrieve documents by vector similarity."""
        # Run the similarity search; Chroma expects a list of query embeddings
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=k
        )
        # Unpack the results
        documents = []
        if "ids" in results and results["ids"]:
            for i in range(len(results["ids"][0])):
                documents.append({
                    "content": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    # Note: Chroma returns distances here (smaller = more similar)
                    "similarity": results["distances"][0][i]
                })
        return documents

    @classmethod
    def from_texts(cls, texts: List[str], embeddings: Embeddings,
                   metadatas: Optional[List[Dict[str, Any]]] = None,
                   collection_name: str = "langchain",
                   client: Optional[chromadb.Client] = None,
                   **kwargs) -> "Chroma":
        """Create a Chroma vector store from a collection of texts."""
        # Generate the embeddings
        embeddings_list = embeddings.embed_documents(texts)
        # Create the Chroma instance
        chroma = cls(
            client=client,
            collection_name=collection_name,
            embedding_function=embeddings
        )
        # Add texts and embeddings
        chroma.add_texts(texts=texts, embeddings=embeddings_list, metadatas=metadatas)
        return chroma
```
The Chroma integration implements every method of the VectorStore interface; by wrapping Chroma's Python API it offers the same operations as any other vector store.
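A minimal end-to-end sketch using the classes defined above (the sample texts are made up):

```python
texts = ["LangChain is a framework for LLM apps.", "Chroma is a vector database."]
embedder = HuggingFaceEmbeddings()

store = Chroma.from_texts(texts, embedder, collection_name="demo")
query_vec = embedder.embed_query("What is Chroma?")
for hit in store.similarity_search(query_vec, k=1):
    print(hit["content"], hit["similarity"])
```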
2.3 Pinecone Vector Database Integration
Pinecone is a cloud-native vector database; LangChain likewise provides full integration:
```python
import os
import uuid
from typing import List, Dict, Any, Optional

import numpy as np
import pinecone

class Pinecone(VectorStore):
    """Pinecone vector database integration."""

    def __init__(self, index: pinecone.Index, embedding_function: Embeddings,
                 text_key: str = "text", namespace: str = ""):
        """Initialize the Pinecone index and configuration."""
        self.index = index
        self.embedding_function = embedding_function
        self.text_key = text_key
        self.namespace = namespace

    def add_texts(self, texts: List[str], embeddings: List[np.ndarray],
                  metadatas: Optional[List[Dict[str, Any]]] = None) -> List[str]:
        """Add texts and embedding vectors to Pinecone."""
        # Generate unique IDs
        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
        # Normalize metadata and store the raw text under text_key
        metadatas = metadatas or [{} for _ in range(len(texts))]
        for i, metadata in enumerate(metadatas):
            metadata[self.text_key] = texts[i]
        # Prepare the vector payloads
        vectors = []
        for i in range(len(texts)):
            vectors.append({
                "id": ids[i],
                "values": embeddings[i].tolist(),
                "metadata": metadatas[i]
            })
        # Upsert the batch into Pinecone
        self.index.upsert(vectors=vectors, namespace=self.namespace)
        return ids

    def similarity_search(self, query_embedding: np.ndarray, k: int = 4) -> List[Dict[str, Any]]:
        """Retrieve documents by vector similarity."""
        # Run the similarity search
        results = self.index.query(
            vector=query_embedding.tolist(),
            top_k=k,
            include_metadata=True,
            namespace=self.namespace
        )
        # Unpack the matches
        documents = []
        for match in results["matches"]:
            metadata = match["metadata"]
            documents.append({
                "content": metadata.get(self.text_key, ""),
                "metadata": metadata,
                "similarity": match["score"]
            })
        return documents

    @classmethod
    def from_texts(cls, texts: List[str], embeddings: Embeddings,
                   metadatas: Optional[List[Dict[str, Any]]] = None,
                   index_name: str = None, embedding_dim: int = None,
                   text_key: str = "text", namespace: str = "",
                   **kwargs) -> "Pinecone":
        """Create a Pinecone vector store from a collection of texts."""
        # Initialize the Pinecone client (classic `pinecone.init` API)
        api_key = kwargs.pop("api_key", os.environ.get("PINECONE_API_KEY"))
        environment = kwargs.pop("environment", os.environ.get("PINECONE_ENVIRONMENT"))
        if not api_key or not environment:
            raise ValueError("Pinecone API key and environment are required")
        pinecone.init(api_key=api_key, environment=environment)
        # Get or create the index
        if index_name not in pinecone.list_indexes():
            if not embedding_dim:
                # Compute one sample embedding to determine the dimension
                sample_embedding = embeddings.embed_query(texts[0])
                embedding_dim = len(sample_embedding)
            pinecone.create_index(index_name, dimension=embedding_dim, **kwargs)
        index = pinecone.Index(index_name)
        # Generate the embeddings
        embeddings_list = embeddings.embed_documents(texts)
        # Create the Pinecone instance and add the data
        pinecone_store = cls(
            index=index,
            embedding_function=embeddings,
            text_key=text_key,
            namespace=namespace
        )
        pinecone_store.add_texts(texts=texts, embeddings=embeddings_list, metadatas=metadatas)
        return pinecone_store
```
The Pinecone integration also implements the VectorStore interface, wrapping Pinecone's Python API behind the same operations as the other stores. This design lets developers switch between vector databases with little effort.
3. Parser Integration with Retrieval Systems: Source Analysis
3.1 Retriever Interface Design
LangChain defines a uniform retriever interface through the BaseRetriever abstract class:
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any

class BaseRetriever(ABC):
    """Abstract base class for retrievers."""

    @abstractmethod
    def get_relevant_documents(self, query: str) -> List[Dict[str, Any]]:
        """Fetch documents relevant to the query."""
        pass

    async def aget_relevant_documents(self, query: str) -> List[Dict[str, Any]]:
        """Async variant; the default implementation delegates to the sync method."""
        return self.get_relevant_documents(query)
```
The interface is deliberately minimal: a single core method, get_relevant_documents, fetches documents relevant to a query. An async variant is also provided, defaulting to a call into the sync method.
3.2 Vector Retriever Implementation
The most common retriever is based on vector similarity:
```python
from typing import List, Dict, Any, Optional
import numpy as np

class VectorStoreRetriever(BaseRetriever):
    """Retriever backed by a vector store."""

    def __init__(self, vector_store: VectorStore, search_kwargs: Dict[str, Any] = None):
        """Initialize the vector retriever."""
        self.vector_store = vector_store
        self.search_kwargs = search_kwargs or {"k": 4}

    def get_relevant_documents(self, query: str) -> List[Dict[str, Any]]:
        """Retrieve documents by vector similarity."""
        # Embed the query
        query_embedding = self.vector_store.embedding_function.embed_query(query)
        # Run the similarity search
        docs = self.vector_store.similarity_search(
            query_embedding=query_embedding,
            **self.search_kwargs
        )
        return docs

    async def aget_relevant_documents(self, query: str) -> List[Dict[str, Any]]:
        """Async retrieval by vector similarity."""
        # Note: most vector database clients are synchronous,
        # so run the sync call in a thread pool
        import asyncio
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.get_relevant_documents(query)
        )
```
VectorStoreRetriever wraps the VectorStore interface to implement similarity-based document retrieval: it embeds the query text and looks up the most similar documents in the vector database.
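A usage sketch, reusing the store built in the Chroma example above:

```python
retriever = VectorStoreRetriever(vector_store=store, search_kwargs={"k": 2})
for doc in retriever.get_relevant_documents("What is Chroma?"):
    print(doc["metadata"], doc["content"][:60])
```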
3.3 Hybrid Retriever Implementation
LangChain also supports hybrid retrievers that combine multiple retrieval strategies:
```python
from typing import List, Dict, Any, Optional

class CombinedRetriever(BaseRetriever):
    """Combines the results of multiple retrievers."""

    def __init__(self, retrievers: List[BaseRetriever], weights: Optional[List[float]] = None):
        """Initialize the hybrid retriever."""
        self.retrievers = retrievers
        self.weights = weights or [1.0 / len(retrievers)] * len(retrievers)
        # Validate the weights
        if len(self.weights) != len(self.retrievers):
            raise ValueError("Number of weights must match number of retrievers")
        total = sum(self.weights)
        if total != 1.0:
            # Normalize so the weights sum to 1
            self.weights = [w / total for w in self.weights]

    def get_relevant_documents(self, query: str) -> List[Dict[str, Any]]:
        """Combine the results of all retrievers."""
        # Collect results from each retriever
        all_results = []
        for retriever in self.retrievers:
            results = retriever.get_relevant_documents(query)
            all_results.append(results)
        # Merge and rank the results by weight
        merged_results = self._merge_results(all_results, self.weights)
        return merged_results

    def _merge_results(self, results_list: List[List[Dict[str, Any]]],
                       weights: List[float]) -> List[Dict[str, Any]]:
        """Merge multiple result lists and sort them by weighted score."""
        # Simplified implementation; production systems may need a richer fusion strategy
        merged = []
        for i, results in enumerate(results_list):
            weight = weights[i]
            for result in results:
                # Copy the result and scale its similarity score
                new_result = {**result}
                if "similarity" in new_result:
                    new_result["similarity"] *= weight
                merged.append(new_result)
        # Sort by similarity
        merged.sort(key=lambda x: x.get("similarity", 0), reverse=True)
        return merged
```
CombinedRetriever can combine several different retrievers, for example a vector retriever alongside a keyword-based one, merging their results by weight to improve retrieval accuracy.
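A brief sketch; vector_retriever is the retriever from Section 3.2, and keyword_retriever stands in for any other hypothetical BaseRetriever implementation:

```python
hybrid = CombinedRetriever(
    retrievers=[vector_retriever, keyword_retriever],
    weights=[0.7, 0.3]  # favor the vector retriever
)
results = hybrid.get_relevant_documents("vector databases")
```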
4. Parser Integration with Question-Answering Systems: Source Analysis
4.1 QA Chain Foundations
LangChain defines the basic framework for processing pipelines through the Chain abstract class; the QA chain is one concrete implementation:
```python
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional

class Chain(ABC):
    """Abstract base class for processing chains."""

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Return the chain's input keys."""
        pass

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Return the chain's output keys."""
        pass

    @abstractmethod
    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the chain's core logic."""
        pass

    def run(self, **kwargs: Any) -> Any:
        """Run the chain and return the result."""
        # Validate the inputs
        for key in self.input_keys:
            if key not in kwargs:
                raise ValueError(f"Missing input key: {key}")
        # Execute the chain
        output = self._call(kwargs)
        # Return the result
        if len(self.output_keys) == 1:
            return output[self.output_keys[0]]
        return output
```
The QA chain inherits from Chain and implements retrieval-based question answering:
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import BaseLLM

class RetrievalQA(Chain):
    """Retrieval-based question-answering chain."""

    def __init__(self, retriever: BaseRetriever, llm: BaseLLM,
                 prompt: PromptTemplate = None,
                 return_source_documents: bool = False):
        """Initialize the QA chain."""
        self.retriever = retriever
        self.llm = llm
        self.return_source_documents = return_source_documents
        # Set the default prompt template
        if prompt is None:
            prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
            self.prompt = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )
        else:
            self.prompt = prompt
        # Create the LLM chain
        self.llm_chain = LLMChain(llm=llm, prompt=self.prompt)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        keys = ["answer"]
        if self.return_source_documents:
            keys.append("source_documents")
        return keys

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the QA chain's core logic."""
        question = inputs["question"]
        # Retrieve relevant documents
        docs = self.retriever.get_relevant_documents(question)
        # Assemble the context
        context = "\n\n".join([doc["content"] for doc in docs])
        # Generate the answer
        answer = self.llm_chain.run(context=context, question=question)
        # Prepare the output
        output = {"answer": answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        return output
```
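A usage sketch, assuming the retriever from Part 3 and any LangChain-compatible LLM (OpenAI shown as a placeholder, requiring OPENAI_API_KEY):

```python
from langchain.llms import OpenAI

qa = RetrievalQA(retriever=retriever, llm=OpenAI(temperature=0),
                 return_source_documents=True)
result = qa.run(question="What is Chroma?")
print(result["answer"])
for doc in result["source_documents"]:
    print("-", doc["metadata"].get("source"))
```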
4.2 Prompt Template Design
Prompt templates play a key role in QA systems: they determine how retrieved documents and the user's question are combined into input the model can work with:
```python
from langchain.prompts import PromptTemplate

# Basic QA prompt template
QA_PROMPT = PromptTemplate(
    template="""Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:""",
    input_variables=["context", "question"]
)

# A richer template that includes instructions and few-shot examples
ADVANCED_QA_PROMPT = PromptTemplate(
    template="""You are an AI assistant specialized in answering technical questions.
Your answers should be accurate, concise, and based on the provided context.
If the context does not provide enough information to answer the question,
indicate that the answer is not available in the provided context.
Examples:
Context: "Python is a high-level programming language. It was created by Guido van Rossum."
Question: "Who created Python?"
Answer: "Python was created by Guido van Rossum."
Context: "LangChain is a framework for developing applications powered by language models."
Question: "What is LangChain used for?"
Answer: "LangChain is used for developing applications powered by language models."
Current context:
{context}
Question: {question}
Helpful Answer:""",
    input_variables=["context", "question"]
)
```
Designing a prompt template involves several factors: the model's characteristics, domain knowledge, and user expectations. A good template can significantly improve QA performance.
4.3 Streaming QA Implementation
LangChain supports streaming QA, returning the answer incrementally while it is being generated:
```python
import queue
import threading

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

class StreamingHandler(BaseCallbackHandler):
    """Streaming callback handler."""

    def __init__(self, queue):
        self.queue = queue

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Handle each newly generated token."""
        self.queue.put(token)

    def on_llm_end(self, *args, **kwargs) -> None:
        """Handle the end of generation."""
        self.queue.put(None)

def get_streaming_qa_chain(retriever, model_name="gpt-3.5-turbo"):
    """Create a QA chain that supports streaming output."""
    # Create a streaming LLM
    llm = OpenAI(
        model_name=model_name,
        streaming=True,
        temperature=0
    )
    # Create the QA chain
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever
    )
    return chain

def stream_answer(chain, question):
    """Generate the answer as a stream of tokens."""
    # Queue that receives the streamed tokens
    output_queue = queue.Queue()
    handler = StreamingHandler(output_queue)

    # Run the chain in a background thread
    # (LangChain's RetrievalQA expects its input under the "query" key)
    def task():
        chain.run(query=question, callbacks=[handler])

    thread = threading.Thread(target=task)
    thread.start()
    # Yield the streamed tokens as a generator
    while True:
        token = output_queue.get()
        if token is None:
            break
        yield token
```
Streaming QA is valuable for long answers or scenarios that demand real-time interaction, and noticeably improves the user experience.
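A usage sketch of the streaming generator (assumes `retriever` is a LangChain-native retriever and OPENAI_API_KEY is configured):

```python
chain = get_streaming_qa_chain(retriever)
for token in stream_answer(chain, "What is LangChain?"):
    print(token, end="", flush=True)
```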
5. Parser Integration with Text Generation: Source Analysis
5.1 Text Generation Chain
The text generation chain is LangChain's core component for generation tasks:
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import BaseLLM

class TextGenerationChain(LLMChain):
    """Text generation chain."""

    def __init__(self, llm: BaseLLM, prompt: PromptTemplate = None,
                 output_key: str = "text", **kwargs):
        """Initialize the text generation chain."""
        # Set a pass-through default prompt template
        if prompt is None:
            prompt = PromptTemplate(
                input_variables=["input"],
                template="{input}"
            )
        super().__init__(llm=llm, prompt=prompt, output_key=output_key, **kwargs)

    @classmethod
    def from_llm_and_prompt(cls, llm: BaseLLM, prompt: PromptTemplate, **kwargs):
        """Create a text generation chain from an LLM and a prompt template."""
        return cls(llm=llm, prompt=prompt, **kwargs)
```
TextGenerationChain inherits from LLMChain and offers a simpler interface for generation tasks: given a prompt template and an LLM, it produces text that satisfies the prompt.
5.2 Summarization Chain
Summarization is an important text-generation use case, and LangChain provides a dedicated summarization chain:
```python
from langchain.chains.summarize import load_summarize_chain
from langchain.llms import BaseLLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

class SummarizationChain:
    """Summarization chain."""

    def __init__(self, llm: BaseLLM, chain_type: str = "stuff",
                 text_splitter: RecursiveCharacterTextSplitter = None):
        """Initialize the summarization chain."""
        self.llm = llm
        self.chain_type = chain_type
        # Set the default text splitter
        self.text_splitter = text_splitter or RecursiveCharacterTextSplitter(
            chunk_size=4000,
            chunk_overlap=200
        )
        # Load the summarize chain
        self.chain = load_summarize_chain(
            llm=llm,
            chain_type=chain_type
        )

    def summarize(self, text: str) -> str:
        """Generate a summary of the text."""
        # Split the text
        docs = [Document(page_content=text)]
        split_docs = self.text_splitter.split_documents(docs)
        # Generate the summary
        summary = self.chain.run(split_docs)
        return summary

    async def asummarize(self, text: str) -> str:
        """Generate a summary asynchronously."""
        import asyncio
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.summarize(text)
        )
```
The summarization chain supports several strategies, such as "stuff" (put all text into a single prompt) and "map_reduce" (summarize each chunk first, then summarize the summaries).
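A usage sketch, assuming an OpenAI key is configured; long_text stands for any lengthy document string:

```python
from langchain.llms import OpenAI

summarizer = SummarizationChain(llm=OpenAI(temperature=0), chain_type="map_reduce")
long_text = "..."  # any lengthy document string
print(summarizer.summarize(long_text))
```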
5.3 Translation Chain
Translation is another common text-generation application; LangChain-style chains handle it as follows:
```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import BaseLLM

class TranslationChain(LLMChain):
    """Translation chain."""

    def __init__(self, llm: BaseLLM, source_language: str = "English",
                 target_language: str = "Chinese", **kwargs):
        """Initialize the translation chain."""
        # Build the translation prompt template
        prompt = PromptTemplate(
            input_variables=["text"],
            template=f"Translate the following text from {source_language} to {target_language}:\n\n{{text}}"
        )
        super().__init__(llm=llm, prompt=prompt, **kwargs)

    @classmethod
    def from_llm(cls, llm: BaseLLM, source_language: str = "English",
                 target_language: str = "Chinese", **kwargs):
        """Create a translation chain from an LLM."""
        return cls(llm=llm, source_language=source_language, target_language=target_language, **kwargs)

    def translate(self, text: str) -> str:
        """Run the translation."""
        return self.run(text=text)

    async def atranslate(self, text: str) -> str:
        """Run the translation asynchronously."""
        return await self.arun(text=text)
```
The translation chain implements translation by designing a suitable prompt template and leaning on the LLM's capabilities; the source and target languages can be specified as needed.
6. Parser Integration with Knowledge Graphs: Source Analysis
6.1 Knowledge Extractor
The knowledge extractor pulls entities, relations, and attributes out of text, forming the foundation of a knowledge graph:
```python
from typing import List, Dict, Any
import spacy

class KnowledgeExtractor:
    """Knowledge extractor."""

    def __init__(self, model_name: str = "en_core_web_sm"):
        """Initialize the knowledge extractor."""
        # Load the spaCy model (pretrained pipelines already include an NER component)
        self.nlp = spacy.load(model_name)

    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from the text."""
        doc = self.nlp(text)
        entities = []
        for ent in doc.ents:
            entities.append({
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char
            })
        return entities

    def extract_relations(self, text: str) -> List[Dict[str, Any]]:
        """Extract relations from the text."""
        doc = self.nlp(text)
        relations = []
        # Simple relation-extraction strategy based on dependency parsing
        for token in doc:
            if token.dep_ in ("attr", "dobj"):
                subject = [w for w in token.head.lefts if w.dep_ == "nsubj"]
                if subject:
                    subject = subject[0]
                    relations.append({
                        "subject": subject.text,
                        "relation": token.head.text,
                        "object": token.text
                    })
        return relations

    def extract_knowledge(self, text: str) -> Dict[str, Any]:
        """Extract knowledge (entities and relations) from the text."""
        entities = self.extract_entities(text)
        relations = self.extract_relations(text)
        return {
            "entities": entities,
            "relations": relations,
            "text": text
        }

    def batch_extract_knowledge(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Extract knowledge from a batch of texts."""
        results = []
        for doc in self.nlp.pipe(texts):
            entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
            relations = []
            # Dependency-based relation extraction
            for token in doc:
                if token.dep_ == "dobj" or token.dep_ == "pobj":
                    subject = next((w for w in token.head.lefts if w.dep_ == "nsubj"), None)
                    if subject:
                        relations.append({
                            "subject": subject.text,
                            "relation": token.head.text,
                            "object": token.text
                        })
            results.append({
                "entities": entities,
                "relations": relations,
                "text": doc.text
            })
        return results
```
The knowledge extractor is built on the spaCy library, which provides strong natural-language-processing capabilities. The extract_entities method uses named-entity recognition (NER) to pull entities out of text, and extract_relations derives relations between entities from dependency parses.
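A quick usage sketch (requires `python -m spacy download en_core_web_sm` beforehand); the exact output depends on the spaCy model:

```python
extractor = KnowledgeExtractor()
knowledge = extractor.extract_knowledge("Guido van Rossum created Python at CWI.")
print(knowledge["entities"])   # e.g. [{'text': 'Guido van Rossum', 'label': 'PERSON', ...}, ...]
print(knowledge["relations"])  # e.g. [{'subject': 'Guido van Rossum', 'relation': 'created', 'object': 'Python'}]
```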
6.2 Knowledge Graph Construction
The knowledge graph builder assembles extracted knowledge into a graph structure:
```python
from typing import List, Dict, Any
from rdflib import Graph, Literal, Namespace
from rdflib.namespace import RDF, RDFS

class KnowledgeGraph:
    """Knowledge graph."""

    def __init__(self, namespace: str = "http://example.org/"):
        """Initialize the knowledge graph."""
        self.graph = Graph()
        self.ns = Namespace(namespace)
        # Bind the common namespaces
        self.graph.bind("rdf", RDF)
        self.graph.bind("rdfs", RDFS)
        self.graph.bind("ex", self.ns)

    def add_entity(self, entity_id: str, entity_type: str, label: str) -> None:
        """Add an entity to the graph."""
        entity_uri = self.ns[entity_id]
        self.graph.add((entity_uri, RDF.type, self.ns[entity_type]))
        self.graph.add((entity_uri, RDFS.label, Literal(label)))

    def add_relation(self, subject_id: str, relation_type: str, object_id: str) -> None:
        """Add a relation to the graph (relation_type should be URI-safe, i.e. no spaces)."""
        subject_uri = self.ns[subject_id]
        object_uri = self.ns[object_id]
        self.graph.add((subject_uri, self.ns[relation_type], object_uri))

    def add_text_metadata(self, text_id: str, text: str, source: str = None) -> None:
        """Add text metadata to the graph."""
        text_uri = self.ns[text_id]
        self.graph.add((text_uri, RDF.type, self.ns["Text"]))
        self.graph.add((text_uri, RDFS.label, Literal(text)))
        if source:
            self.graph.add((text_uri, self.ns["source"], Literal(source)))

    def from_knowledge_extraction(self, extractions: List[Dict[str, Any]],
                                  source: str = None) -> None:
        """Build the knowledge graph from knowledge-extraction results."""
        for i, extraction in enumerate(extractions):
            text_id = f"text_{i}"
            self.add_text_metadata(text_id, extraction["text"], source)
            # Add the entities
            entity_map = {}
            for j, entity in enumerate(extraction["entities"]):
                entity_id = f"entity_{i}_{j}"
                entity_map[entity["text"]] = entity_id
                self.add_entity(entity_id, entity["label"], entity["text"])
                # Link the entity back to its source text
                self.add_relation(entity_id, "extractedFrom", text_id)
            # Add the relations
            for relation in extraction["relations"]:
                if relation["subject"] in entity_map and relation["object"] in entity_map:
                    self.add_relation(
                        entity_map[relation["subject"]],
                        relation["relation"],
                        entity_map[relation["object"]]
                    )

    def serialize(self, format: str = "turtle") -> str:
        """Serialize the knowledge graph."""
        return self.graph.serialize(format=format)

    def query(self, sparql_query: str) -> List[Dict[str, Any]]:
        """Run a SPARQL query."""
        results = []
        query_result = self.graph.query(sparql_query)
        for row in query_result:
            result_row = {}
            for var in query_result.vars:
                result_row[str(var)] = str(row[var])
            results.append(result_row)
        return results
```
The knowledge graph builder is implemented with rdflib, which supplies RDF graph operations. The add_entity and add_relation methods insert entities and relations into the graph, and from_knowledge_extraction builds the graph in bulk from knowledge-extraction results.
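A small usage sketch of the graph API above; the IDs and relation name are arbitrary:

```python
kg = KnowledgeGraph()
kg.add_entity("guido", "Person", "Guido van Rossum")
kg.add_entity("python", "Language", "Python")
kg.add_relation("guido", "creatorOf", "python")

print(kg.serialize(format="turtle"))
rows = kg.query("""
    PREFIX ex: <http://example.org/>
    SELECT ?s ?o WHERE { ?s ex:creatorOf ?o . }
""")
print(rows)  # [{'s': 'http://example.org/guido', 'o': 'http://example.org/python'}]
```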
6.3 Integrating the Knowledge Graph with Parsers
The following ties the knowledge graph into the parsing pipeline:
```python
from typing import List, Dict, Any
# Reuses the BaseLoader, TextSplitter, and Embeddings abstractions defined earlier,
# together with the KnowledgeExtractor and KnowledgeGraph classes from this section.

class KnowledgeGraphIntegrator:
    """Integrates parsers with the knowledge graph."""

    def __init__(self, extractor: KnowledgeExtractor,
                 graph: KnowledgeGraph,
                 text_splitter: TextSplitter = None,
                 embeddings: Embeddings = None):
        """Initialize the integrator."""
        self.extractor = extractor
        self.graph = graph
        self.text_splitter = text_splitter
        self.embeddings = embeddings

    def process_documents(self, documents: List[Dict[str, Any]],
                          source: str = None) -> None:
        """Process documents and build the knowledge graph."""
        texts = [doc["content"] for doc in documents]
        if self.text_splitter:
            # Split the texts
            split_texts = []
            for text in texts:
                split_texts.extend(self.text_splitter.split_text(text))
            texts = split_texts
        # Extract knowledge
        knowledge_extractions = self.extractor.batch_extract_knowledge(texts)
        # Build the knowledge graph
        self.graph.from_knowledge_extraction(knowledge_extractions, source)

    def process_loader(self, loader: BaseLoader, source: str = None) -> None:
        """Run a document loader and build the knowledge graph from its output."""
        documents = loader.load()
        self.process_documents(documents, source)

    def get_related_entities(self, entity: str, relation_type: str = None) -> List[Dict[str, Any]]:
        """Fetch entities related to the given entity."""
        # Simple SPARQL query example
        query = f"""
        PREFIX ex: <{self.graph.ns}>
        SELECT ?relatedEntity ?relation ?label
        WHERE {{
            ?entity rdfs:label "{entity}" .
            ?entity ?relation ?relatedEntity .
            ?relatedEntity rdfs:label ?label .
            FILTER(?relation != rdf:type && ?relation != rdfs:label)
            {f'FILTER(?relation = ex:{relation_type})' if relation_type else ''}
        }}
        """
        return self.graph.query(query)

    def answer_question_with_kg(self, question: str, llm_chain=None) -> str:
        """Answer a question using the knowledge graph."""
        # Extract entities from the question
        entities = self.extractor.extract_entities(question)
        if not entities:
            return "No entities found in the question"
        # Query using the first entity as an example
        main_entity = entities[0]["text"]
        related_entities = self.get_related_entities(main_entity)
        if not related_entities:
            return f"No related entities found for {main_entity}"
        # Build the context
        context = f"Knowledge about {main_entity}:\n"
        for entity in related_entities:
            context += f"- {entity['label']} is {entity['relation'].split('/')[-1]} of {main_entity}\n"
        # If an LLM chain is supplied, use it to generate the answer
        if llm_chain:
            return llm_chain.run(context=context, question=question)
        # Otherwise return the raw context
        return context
```
The integrator processes documents and builds a knowledge graph from them. It can cooperate with document loaders, text splitters, and embedding generators to form a complete knowledge-processing pipeline.
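Putting the pieces together, using the TextLoader and CharacterTextSplitter from Part 1; notes.txt is a hypothetical input file:

```python
extractor = KnowledgeExtractor()
graph = KnowledgeGraph()
integrator = KnowledgeGraphIntegrator(
    extractor, graph,
    text_splitter=CharacterTextSplitter(chunk_size=500, chunk_overlap=0)
)
integrator.process_loader(TextLoader("notes.txt"), source="notes.txt")
print(integrator.answer_question_with_kg("Who created Python?"))
```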
7. Parser Integration with Reasoning Systems: Source Analysis
7.1 Reasoning Engine Interface Design
A uniform interface for reasoning engines can be defined through a ReasoningEngine abstract class:
```python
from abc import ABC, abstractmethod
from typing import List, Dict, Any

class ReasoningEngine(ABC):
    """Abstract base class for reasoning engines."""

    @abstractmethod
    def reason(self, premises: List[Dict[str, Any]],
               rules: List[Dict[str, Any]],
               query: Dict[str, Any]) -> Dict[str, Any]:
        """Reason over the given premises and rules."""
        pass

    @abstractmethod
    def explain(self, result: Dict[str, Any]) -> str:
        """Produce an explanation of the reasoning process."""
        pass
```
The interface defines two core methods: reason performs the inference, and explain produces an explanation of the reasoning process.
7.2 Rule-Based Reasoning Engine
Below is a rule-based reasoning engine implementation:
```python
from typing import List, Dict, Any, Optional

class RuleBasedReasoningEngine(ReasoningEngine):
    """Rule-based reasoning engine."""

    def __init__(self, knowledge_graph: KnowledgeGraph = None):
        """Initialize the reasoning engine."""
        self.knowledge_graph = knowledge_graph
        self.facts = []
        self.rules = []

    def add_fact(self, fact: Dict[str, Any]) -> None:
        """Add a fact."""
        self.facts.append(fact)

    def add_rule(self, rule: Dict[str, Any]) -> None:
        """Add a rule."""
        self.rules.append(rule)

    def reason(self, premises: List[Dict[str, Any]],
               rules: List[Dict[str, Any]],
               query: Dict[str, Any]) -> Dict[str, Any]:
        """Reason over premises and rules."""
        # Merge the premises with the stored facts
        all_facts = self.facts + premises
        # Merge the rules
        all_rules = self.rules + rules
        # Run forward-chaining inference
        inferred_facts = self._forward_chaining(all_facts, all_rules)
        # Check whether the query was inferred
        result = {"is_true": False, "explanation": "", "facts": inferred_facts}
        for fact in inferred_facts:
            if self._match_fact(fact, query):
                result["is_true"] = True
                break
        return result

    def _forward_chaining(self, facts: List[Dict[str, Any]],
                          rules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Run forward-chaining inference."""
        inferred_facts = facts.copy()
        new_facts_added = True
        while new_facts_added:
            new_facts_added = False
            for rule in rules:
                # Check whether the rule's premises are satisfied
                premises = rule.get("premises", [])
                all_premises_matched = True
                bindings = {}
                for premise in premises:
                    matched = False
                    for fact in inferred_facts:
                        fact_bindings = self._match_fact_with_bindings(fact, premise)
                        if fact_bindings is not None:
                            # Check that the bindings are consistent
                            if self._check_bindings_consistency(bindings, fact_bindings):
                                bindings.update(fact_bindings)
                                matched = True
                                break
                    if not matched:
                        all_premises_matched = False
                        break
                # If every premise matched, apply the rule
                if all_premises_matched:
                    conclusion = rule.get("conclusion", {})
                    # Apply the bindings to the conclusion
                    grounded_conclusion = self._apply_bindings(conclusion, bindings)
                    # Only add genuinely new facts
                    if not any(self._match_fact(fact, grounded_conclusion) for fact in inferred_facts):
                        inferred_facts.append(grounded_conclusion)
                        new_facts_added = True
        return inferred_facts

    def _match_fact(self, fact: Dict[str, Any], pattern: Dict[str, Any]) -> bool:
        """Check whether a fact matches a (ground) pattern."""
        for key, value in pattern.items():
            if key not in fact or fact[key] != value:
                return False
        return True

    def _match_fact_with_bindings(self, fact: Dict[str, Any], pattern: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Match a fact against a pattern and return the variable bindings."""
        bindings = {}
        for key, value in pattern.items():
            if key not in fact:
                return None
            if isinstance(value, str) and value.startswith("?"):
                # Variable
                var_name = value[1:]
                if var_name in bindings:
                    if bindings[var_name] != fact[key]:
                        return None
                else:
                    bindings[var_name] = fact[key]
            else:
                # Constant
                if fact[key] != value:
                    return None
        return bindings

    def _check_bindings_consistency(self, existing_bindings: Dict[str, Any],
                                    new_bindings: Dict[str, Any]) -> bool:
        """Check whether the new bindings are consistent with the existing ones."""
        for var_name, value in new_bindings.items():
            if var_name in existing_bindings and existing_bindings[var_name] != value:
                return False
        return True

    def _apply_bindings(self, expression: Dict[str, Any],
                        bindings: Dict[str, Any]) -> Dict[str, Any]:
        """Apply variable bindings to an expression."""
        result = {}
        for key, value in expression.items():
            if isinstance(value, str) and value.startswith("?"):
                var_name = value[1:]
                result[key] = bindings.get(var_name, value)
            else:
                result[key] = value
        return result

    def explain(self, result: Dict[str, Any]) -> str:
        """Generate an explanation of the reasoning process."""
        if not result["is_true"]:
            return "The query could not be inferred from the given premises and rules."
        explanation = "The query is true based on the following reasoning:\n"
        # Simplified explanation generation
        for fact in result["facts"]:
            explanation += f"- {fact}\n"
        return explanation

    def reason_with_kg(self, query: str) -> Dict[str, Any]:
        """Reason in combination with the knowledge graph.

        Note: this assumes the knowledge graph object exposes the extractor used to
        build it and a get_related_entities helper (compare the integrator in Part 6).
        """
        if not self.knowledge_graph:
            raise ValueError("Knowledge graph is required for reasoning with KG")
        # Extract entities from the query
        entities = self.knowledge_graph.extractor.extract_entities(query)
        if not entities:
            return {"is_true": False, "explanation": "No entities found in the query"}
        # Build premises from the knowledge graph
        premises = []
        for entity in entities:
            related_entities = self.knowledge_graph.get_related_entities(entity["text"])
            for rel_entity in related_entities:
                premises.append({
                    "subject": rel_entity["label"],
                    "relation": rel_entity["relation"].split("/")[-1],
                    "object": entity["text"]
                })
        # A simple example inference rule
        rules = [
            {
                "premises": [
                    {"subject": "?x", "relation": "authorOf", "object": "?book"},
                    {"subject": "?book", "relation": "topic", "object": "AI"}
                ],
                "conclusion": {"subject": "?x", "relation": "expertIn", "object": "AI"}
            }
        ]
        # Build the query fact
        query_fact = {"subject": entities[0]["text"], "relation": "expertIn", "object": "AI"}
        # Run the inference
        return self.reason(premises, rules, query_fact)
```
The rule-based engine performs inference by forward chaining: it works over facts and rules and derives logical conclusions from them. The reason_with_kg method shows how inference can be combined with a knowledge graph.
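A self-contained sketch of the forward-chaining engine on toy facts:

```python
engine = RuleBasedReasoningEngine()
premises = [
    {"subject": "Alice", "relation": "authorOf", "object": "DeepBook"},
    {"subject": "DeepBook", "relation": "topic", "object": "AI"},
]
rules = [{
    "premises": [
        {"subject": "?x", "relation": "authorOf", "object": "?book"},
        {"subject": "?book", "relation": "topic", "object": "AI"},
    ],
    "conclusion": {"subject": "?x", "relation": "expertIn", "object": "AI"},
}]
query = {"subject": "Alice", "relation": "expertIn", "object": "AI"}

result = engine.reason(premises, rules, query)
print(result["is_true"])       # True: the rule fires with ?x=Alice, ?book=DeepBook
print(engine.explain(result))
```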
7.3 Reasoning Chain
The reasoning chain combines the parser, the retriever, and the reasoning engine into a complete inference pipeline:
```python
from typing import List, Dict, Any
# Reuses the Chain and BaseRetriever abstractions defined earlier in this article.

class ReasoningChain(Chain):
    """Reasoning chain."""

    def __init__(self, retriever: BaseRetriever,
                 reasoning_engine: ReasoningEngine,
                 llm_chain=None):
        """Initialize the reasoning chain."""
        self.retriever = retriever
        self.reasoning_engine = reasoning_engine
        self.llm_chain = llm_chain

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys."""
        return ["query"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys."""
        return ["answer", "explanation"]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the reasoning chain."""
        query = inputs["query"]
        # Retrieve relevant documents
        docs = self.retriever.get_relevant_documents(query)
        # Extract premises from the documents
        premises = []
        for doc in docs:
            # Simplified; real applications may need more elaborate extraction
            premises.append({"text": doc["content"], "source": doc["metadata"].get("source", "")})
        # Define some simple inference rules
        rules = [
            {
                "premises": [{"text": "?x is the author of ?y", "variable": "?x", "relation": "authorOf", "object": "?y"}],
                "conclusion": {"subject": "?x", "relation": "expertIn", "object": "?y"}
            }
        ]
        # Run the inference
        result = self.reasoning_engine.reason(premises, rules, {"query": query})
        # Generate the answer
        if self.llm_chain:
            context = "\n".join([premise["text"] for premise in premises])
            answer = self.llm_chain.run(context=context, question=query)
        else:
            answer = "Based on the reasoning, the query result is: " + ("yes" if result["is_true"] else "no")
        return {
            "answer": answer,
            "explanation": self.reasoning_engine.explain(result)
        }
```
The reasoning chain ties a retriever to the reasoning engine: it first retrieves documents relevant to the query, extracts premises from them, and finally runs the engine to reason and produce an answer.
8. Parser Integration with Multimodal Processing: Source Analysis
8.1 Multimodal Parser Design
A multimodal parser can handle different media types, including text, images, and audio:
```python
from typing import Dict, Any

class MultiModalParser:
    """Multimodal parser."""

    def __init__(self, text_parser=None, image_parser=None, audio_parser=None):
        """Initialize the multimodal parser."""
        self.text_parser = text_parser
        self.image_parser = image_parser
        self.audio_parser = audio_parser

    def parse(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Parse multimodal data."""
        result = {}
        # Parse the text
        if "text" in data and self.text_parser:
            text_result = self.text_parser.parse(data["text"])
            result["text"] = text_result
        # Parse the images
        if "images" in data and self.image_parser:
            image_results = []
            for image in data["images"]:
                image_result = self.image_parser.parse(image)
                image_results.append(image_result)
            result["images"] = image_results
        # Parse the audio
        if "audio" in data and self.audio_parser:
            audio_result = self.audio_parser.parse(data["audio"])
            result["audio"] = audio_result
        # Combine the multimodal results
        if len(result) > 1:
            result["combined"] = self._combine_results(result)
        return result

    def _combine_results(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Combine the parsing results from all modalities."""
        # Simple implementation; real applications may need a richer fusion strategy
        combined = {}
        # Collect the key entities and relations
        entities = []
        relations = []
        if "text" in results:
            entities.extend(results["text"].get("entities", []))
            relations.extend(results["text"].get("relations", []))
        if "images" in results:
            for image_result in results["images"]:
                entities.extend(image_result.get("objects", []))
                image_relations = image_result.get("relations", [])
                for rel in image_relations:
                    # Tag image-derived relations with their source
                    rel["source"] = "image"
                relations.extend(image_relations)
        combined["entities"] = entities
        combined["relations"] = relations
        return combined
```
The multimodal parser composes per-modality parsers (text, image, audio) to handle inputs containing several media types, merging the per-modality results into a unified representation.
8.2 Image Parser
The image parser processes image content, extracting objects and relations:
```python
from typing import List, Dict, Any
import torch
from PIL import Image
from transformers import DetrFeatureExtractor, DetrForObjectDetection

class ImageParser:
    """Image parser."""

    def __init__(self, model_name: str = "facebook/detr-resnet-50"):
        """Initialize the image parser."""
        self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)

    def parse(self, image: Image.Image) -> Dict[str, Any]:
        """Parse an image."""
        # Preprocess the image
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # Run the model
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Post-process the results
        target_sizes = torch.tensor([image.size[::-1]])
        results = self.feature_extractor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9
        )[0]
        # Build the parsing result
        objects = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            box = [round(i, 2) for i in box.tolist()]
            objects.append({
                "label": self.model.config.id2label[label.item()],
                "confidence": float(score),
                "bounding_box": box
            })
        # Extract relations between objects (simplified)
        relations = self._extract_relations(objects)
        return {
            "objects": objects,
            "relations": relations,
            "metadata": {"width": image.width, "height": image.height}
        }

    def _extract_relations(self, objects: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Extract relations between the objects in an image."""
        # Simplified; real applications may need a more sophisticated algorithm
        relations = []
        # Example: if two bounding boxes overlap beyond a threshold, treat the objects as related
        for i in range(len(objects)):
            for j in range(i + 1, len(objects)):
                box1 = objects[i]["bounding_box"]
                box2 = objects[j]["bounding_box"]
                # Compute the IoU (Intersection over Union)
                x1 = max(box1[0], box2[0])
                y1 = max(box1[1], box2[1])
                x2 = min(box1[2], box2[2])
                y2 = min(box1[3], box2[3])
                intersection = max(0, x2 - x1) * max(0, y2 - y1)
                area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
                area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
                union = area1 + area2 - intersection
                iou = intersection / union if union > 0 else 0
                if iou > 0.1:  # adjustable threshold
                    relations.append({
                        "subject": objects[i]["label"],
                        "relation": "near",
                        "object": objects[j]["label"],
                        "confidence": iou
                    })
        return relations
```
The image parser detects objects with a DETR model and derives relations between them, returning a dictionary of detected objects and relations.
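A usage sketch (the model downloads on first use; street.jpg is a hypothetical path):

```python
from PIL import Image

parser = ImageParser()
image = Image.open("street.jpg")
result = parser.parse(image)
for obj in result["objects"]:
    print(obj["label"], round(obj["confidence"], 3))
print(result["relations"])
```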
8.3 Multimodal Embedding Generator
A multimodal embedding generator produces a unified vector representation for different content types:
```python
from typing import Dict, Any
import numpy as np
import torch
from PIL import Image
from langchain.embeddings import OpenAIEmbeddings
from transformers import CLIPProcessor, CLIPModel

class MultiModalEmbeddings:
    """Multimodal embedding generator."""

    def __init__(self, text_embeddings=None, image_embeddings=None):
        """Initialize the multimodal embedding generator."""
        # Text embedding model
        self.text_embeddings = text_embeddings or OpenAIEmbeddings()
        # Image embedding model
        if image_embeddings is None:
            self.image_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
            self.image_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        else:
            self.image_model = image_embeddings

    def embed_text(self, text: str) -> np.ndarray:
        """Generate a text embedding."""
        return self.text_embeddings.embed_query(text)

    def embed_image(self, image: Image.Image) -> np.ndarray:
        """Generate an image embedding."""
        inputs = self.image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            # CLIP's image-feature helper projects the image into the shared embedding space
            image_features = self.image_model.get_image_features(**inputs)
        return image_features.numpy().flatten()

    def embed_multi_modal(self, data: Dict[str, Any]) -> np.ndarray:
        """Generate a joint embedding for multimodal content."""
        embeddings = []
        # Text embedding
        if "text" in data:
            text_embedding = self.embed_text(data["text"])
            embeddings.append(text_embedding)
        # Image embedding
        if "image" in data:
            image_embedding = self.embed_image(data["image"])
            embeddings.append(image_embedding)
        # With multiple modalities, simply concatenate their embeddings
        if len(embeddings) > 1:
            return np.concatenate(embeddings)
        # With a single modality, return its embedding directly
        if embeddings:
            return embeddings[0]
        raise ValueError("No data provided for embedding")
```
The multimodal embedding generator combines text and image embedding models to produce unified vector representations for heterogeneous content, which is very useful for multimodal retrieval and QA.
9. Performance Optimization and Tuning: Source Analysis
9.1 Caching
Caching avoids repeated computation and improves system performance:
```python
import hashlib
import json
import os
from functools import wraps
from typing import Any, Callable, Optional

class Cache:
    """Caching mechanism."""

    def __init__(self, cache_dir: str = ".cache", max_size: int = 1000):
        """Initialize the cache."""
        self.cache_dir = cache_dir
        self.max_size = max_size
        self.cache = {}       # in-memory cache
        self.cache_keys = []  # insertion order, for eviction
        # Create the cache directory
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

    def get_key(self, *args, **kwargs) -> str:
        """Generate a cache key."""
        # Serialize the arguments
        args_str = json.dumps(args, sort_keys=True)
        kwargs_str = json.dumps(kwargs, sort_keys=True)
        combined = f"{args_str}_{kwargs_str}"
        # Hash the serialized arguments to form the key
        return hashlib.sha256(combined.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Fetch a cached value."""
        # Check the in-memory cache first
        if key in self.cache:
            return self.cache[key]
        # Fall back to the disk cache
        file_path = os.path.join(self.cache_dir, key)
        if os.path.exists(file_path):
            try:
                with open(file_path, "r") as f:
                    return json.load(f)
            except Exception:
                return None
        return None

    def set(self, key: str, value: Any) -> None:
        """Store a value in the cache (values must be JSON-serializable)."""
        # Update the in-memory cache
        self.cache[key] = value
        self.cache_keys.append(key)
        # Evict the oldest entry if the cache is over capacity
        if len(self.cache) > self.max_size:
            oldest_key = self.cache_keys.pop(0)
            self.cache.pop(oldest_key, None)
            # Remove it from disk as well
            file_path = os.path.join(self.cache_dir, oldest_key)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Persist to disk
        file_path = os.path.join(self.cache_dir, key)
        try:
            with open(file_path, "w") as f:
                json.dump(value, f)
        except Exception:
            # A failed cache write must not break the main flow
            pass

    def cache_decorator(self, func: Callable) -> Callable:
        """Caching decorator."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Generate the cache key
            key = self.get_key(*args, **kwargs)
            # Check the cache
            cached_result = self.get(key)
            if cached_result is not None:
                return cached_result
            # Run the function
            result = func(*args, **kwargs)
            # Cache the result
            self.set(key, result)
            return result
        return wrapper
```
Caching applies to all sorts of expensive operations, such as embedding generation and document parsing; the decorator makes it trivial to add caching to a function.
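A usage sketch of the decorator; the decorated function must return JSON-serializable data:

```python
cache = Cache(cache_dir=".cache", max_size=100)

@cache.cache_decorator
def expensive_embed(text: str) -> list:
    # Stand-in for a slow embedding call
    return [float(ord(c)) for c in text]

expensive_embed("hello")  # computed and written to the cache
expensive_embed("hello")  # served from the cache
```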
9.2 Parallel Processing
Parallel processing exploits multi-core CPUs to increase throughput:
```python
import concurrent.futures
from typing import Any, Callable, List

from tqdm import tqdm

class ParallelProcessor:
    """Parallel processor."""

    def __init__(self, max_workers: int = None):
        """Initialize the parallel processor."""
        self.max_workers = max_workers

    def process(self, func: Callable, items: List[Any],
                desc: str = "Processing",
                progress: bool = True) -> List[Any]:
        """Process items in parallel, preserving input order in the results."""
        # With a single item, process it directly
        if len(items) == 1:
            return [func(items[0])]
        ordered_results = [None] * len(items)
        # Process in parallel with a thread pool
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Map each future back to the index of the item it processes
            future_to_index = {executor.submit(func, item): i for i, item in enumerate(items)}
            futures_iter = concurrent.futures.as_completed(future_to_index)
            if progress:
                futures_iter = tqdm(futures_iter, total=len(items), desc=desc)
            # Collect the results in their original positions
            for future in futures_iter:
                index = future_to_index[future]
                try:
                    ordered_results[index] = future.result()
                except Exception as e:
                    # Record the error but keep processing the remaining items
                    print(f"Error processing item {index}: {e}")
        return ordered_results

    def process_batch(self, func: Callable, items: List[Any],
                      batch_size: int = 32, desc: str = "Processing") -> List[Any]:
        """Process items in parallel, batch by batch."""
        results = []
        # Process batch by batch
        for i in tqdm(range(0, len(items), batch_size), desc=desc):
            batch = items[i:i + batch_size]
            batch_results = self.process(func, batch, desc=f"Batch {i // batch_size + 1}", progress=False)
            results.extend(batch_results)
        return results
```
The parallel processor fits many scenarios, such as batch document parsing or parallel embedding generation; setting the maximum worker count appropriately makes full use of system resources.
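A brief sketch with a hypothetical I/O-bound task:

```python
def count_lines(path: str) -> int:
    # Hypothetical I/O-bound task
    with open(path) as f:
        return sum(1 for _ in f)

processor = ParallelProcessor(max_workers=8)
line_counts = processor.process_batch(count_lines, ["a.txt", "b.txt"], batch_size=32)
```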
9.3 Asynchronous Processing
Asynchronous processing suits I/O-bound tasks and raises the system's concurrency:
```python
import asyncio
from typing import Any, Callable, List

class AsyncProcessor:
    """Asynchronous processor."""

    async def process(self, func: Callable, items: List[Any],
                      concurrency: int = 10) -> List[Any]:
        """Process items asynchronously with bounded concurrency."""
        semaphore = asyncio.Semaphore(concurrency)

        async def process_item(item):
            async with semaphore:
                return await func(item)

        # Create one task per item
        tasks = [process_item(item) for item in items]
        # Run them all, capturing exceptions instead of aborting
        results = await asyncio.gather(*tasks, return_exceptions=True)
        # Replace failures with None so the output stays aligned with the input
        final_results = []
        for result in results:
            if isinstance(result, Exception):
                print(f"Error processing item: {result}")
                final_results.append(None)
            else:
                final_results.append(result)
        return final_results

    async def stream_results(self, func: Callable, items: List[Any],
                             concurrency: int = 10):
        """Process items asynchronously, yielding results as they complete."""
        semaphore = asyncio.Semaphore(concurrency)
        queue = asyncio.Queue()

        async def process_and_enqueue(item):
            async with semaphore:
                result = await func(item)
                await queue.put(result)

        # Launch all producer tasks
        producers = asyncio.gather(*(process_and_enqueue(item) for item in items))
        try:
            # Consume exactly one result per item, yielding as they arrive
            for _ in range(len(items)):
                yield await queue.get()
        finally:
            # Make sure all producers have finished (and surface their errors)
            await producers
```
The async processor offers two modes: batch processing and streaming. For I/O-bound work such as API calls or file reads, asynchronous execution can improve performance considerably. The stream_results method additionally yields results as they arrive, suiting scenarios that need real-time handling.
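A brief sketch with a hypothetical coroutine:

```python
import asyncio

async def fetch(url: str) -> int:
    # Hypothetical I/O-bound coroutine
    await asyncio.sleep(0.1)
    return len(url)

async def main():
    processor = AsyncProcessor()
    print(await processor.process(fetch, ["https://a.example", "https://b.example"]))

asyncio.run(main())
```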
9.4 Quantization and Model Optimization
Model quantization shrinks models and speeds up inference:
```python
import torch
from transformers import AutoModel

class ModelOptimizer:
    """Model optimizer."""

    def quantize_model(self, model: torch.nn.Module,
                       quantization_method: str = "dynamic") -> torch.nn.Module:
        """Quantize a model to reduce its size and speed up inference."""
        if quantization_method == "dynamic":
            # Dynamic quantization
            quantized_model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
            return quantized_model
        elif quantization_method == "static":
            # Static quantization
            model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            model_prepared = torch.quantization.prepare(model)
            # Calibration data is needed to determine the quantization parameters
            # calibration_data = ...
            # model_prepared(calibration_data)
            quantized_model = torch.quantization.convert(model_prepared)
            return quantized_model
        else:
            raise ValueError(f"Unsupported quantization method: {quantization_method}")

    def optimize_model_for_inference(self, model: torch.nn.Module) -> torch.nn.Module:
        """Optimize a model for faster inference."""
        # Compile the model with TorchScript
        try:
            scripted_model = torch.jit.script(model)
            optimized_model = torch.jit.optimize_for_inference(scripted_model)
            return optimized_model
        except Exception as e:
            print(f"Failed to script model: {e}")
            return model

    def load_optimized_model(self, model_name: str,
                             quantization: bool = True,
                             torchscript: bool = True) -> torch.nn.Module:
        """Load a model with optimizations applied."""
        # Load the base model
        model = AutoModel.from_pretrained(model_name)
        # Apply the optimizations
        if quantization:
            model = self.quantize_model(model)
        if torchscript:
            model = self.optimize_model_for_inference(model)
        return model
```
The model optimizer offers several techniques, including dynamic quantization, static quantization, and TorchScript compilation. These can substantially speed up inference and reduce memory consumption without significantly hurting model quality.
9.5 Index Optimization
Tuning the vector database index improves retrieval performance:
```python
import time
from typing import List, Dict, Any

class IndexOptimizer:
    """Index optimizer."""

    def __init__(self, vector_store):
        """Initialize the index optimizer."""
        self.vector_store = vector_store

    def optimize_index(self, nlist: int = 100, nprobes: int = 10,
                       ef_construction: int = 200, ef_search: int = 50):
        """Tune the vector store's index parameters.

        Note: which hooks exist depends heavily on the backend and its version;
        the branches below are best-effort probes, not a stable API.
        """
        if hasattr(self.vector_store, "collection") and hasattr(self.vector_store.collection, "modify"):
            # Chroma-style optimization
            try:
                self.vector_store.collection.modify(
                    index_params={
                        "nlist": nlist,
                        "nprobes": nprobes
                    }
                )
                return True
            except Exception as e:
                print(f"Failed to optimize Chroma index: {e}")
                return False
        elif hasattr(self.vector_store, "index") and hasattr(self.vector_store.index, "set_extra_param"):
            # FAISS-style optimization
            try:
                # Set the index parameters
                self.vector_store.index.nprobe = nprobes
                return True
            except Exception as e:
                print(f"Failed to optimize FAISS index: {e}")
                return False
        elif hasattr(self.vector_store, "index") and hasattr(self.vector_store.index, "config"):
            # HNSW-style optimization
            try:
                # Update the HNSW parameters
                self.vector_store.index.config["ef_construction"] = ef_construction
                self.vector_store.index.config["ef_search"] = ef_search
                return True
            except Exception as e:
                print(f"Failed to optimize HNSW index: {e}")
                return False
        return False

    def rebuild_index(self, batch_size: int = 1000):
        """Rebuild the vector index to improve performance."""
        try:
            # Collect all vectors
            all_vectors = []
            all_metadatas = []
            all_texts = []
            # Fetch vectors in batches (Chroma only returns embeddings when asked)
            offset = 0
            while True:
                results = self.vector_store.collection.get(
                    limit=batch_size,
                    offset=offset,
                    include=["embeddings", "metadatas", "documents"]
                )
                if not results["ids"]:
                    break
                all_vectors.extend(results["embeddings"])
                all_metadatas.extend(results["metadatas"])
                all_texts.extend(results["documents"])
                offset += batch_size
            # Clear the current index
            self.vector_store.collection.delete(ids=self.vector_store.collection.get()["ids"])
            # Re-add all vectors, triggering an index rebuild
            self.vector_store.collection.add(
                embeddings=all_vectors,
                metadatas=all_metadatas,
                documents=all_texts,
                ids=[str(i) for i in range(len(all_vectors))]
            )
            return True
        except Exception as e:
            print(f"Failed to rebuild index: {e}")
            return False

    def analyze_index_performance(self, queries: List[str], k: int = 4) -> Dict[str, Any]:
        """Analyze index performance."""
        results = {
            "query_times": [],
            "avg_query_time": 0,
            "total_time": 0,
            "recall@k": 0,
            "precision@k": 0
        }
        # Assume some ground-truth results exist;
        # in a real application these must be supplied
        ground_truth = [[] for _ in range(len(queries))]
        # Run the queries and record timings
        start_time = time.time()
        for i, query in enumerate(queries):
            query_start = time.time()
            query_embedding = self.vector_store.embedding_function.embed_query(query)
            docs = self.vector_store.similarity_search(query_embedding, k=k)
            query_time = time.time() - query_start
            results["query_times"].append(query_time)
            # Compute precision and recall
            retrieved_ids = [doc["metadata"].get("id", "") for doc in docs]
            relevant_docs = ground_truth[i]
            # Intersection of retrieved and relevant documents
            intersection = set(retrieved_ids).intersection(set(relevant_docs))
            precision = len(intersection) / k if k > 0 else 0
            recall = len(intersection) / len(relevant_docs) if len(relevant_docs) > 0 else 0
            results["precision@k"] += precision
            results["recall@k"] += recall
        # Compute the averages
        results["total_time"] = time.time() - start_time
        results["avg_query_time"] = sum(results["query_times"]) / len(queries) if queries else 0
        results["precision@k"] /= len(queries) if queries else 1
        results["recall@k"] /= len(queries) if queries else 1
        return results
```
The index optimizer covers several techniques: tuning index parameters, rebuilding the index, and performance analysis, all of which can markedly improve retrieval speed and accuracy.
10. Security and Privacy Protection: Source Analysis
10.1 Data Encryption
Encryption protects sensitive information from unauthorized access:
```python
from cryptography.fernet import Fernet

class DataEncryptor:
    """Data encryptor."""

    def __init__(self, encryption_key: str = None):
        """Initialize the encryptor.

        Keys must be 32 url-safe base64-encoded bytes, as produced by Fernet.generate_key().
        """
        if encryption_key:
            self.key = encryption_key.encode()
        else:
            # Generate a fresh key
            self.key = Fernet.generate_key()
        self.cipher_suite = Fernet(self.key)

    def get_key(self) -> str:
        """Return the encryption key."""
        return self.key.decode()

    def encrypt(self, data: str) -> str:
        """Encrypt data."""
        encrypted = self.cipher_suite.encrypt(data.encode())
        return encrypted.decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Decrypt data."""
        decrypted = self.cipher_suite.decrypt(encrypted_data.encode())
        return decrypted.decode()

    def encrypt_file(self, file_path: str, output_path: str = None) -> str:
        """Encrypt a file."""
        if not output_path:
            output_path = file_path + ".encrypted"
        with open(file_path, "rb") as f:
            data = f.read()
        encrypted_data = self.cipher_suite.encrypt(data)
        with open(output_path, "wb") as f:
            f.write(encrypted_data)
        return output_path

    def decrypt_file(self, encrypted_file_path: str, output_path: str = None) -> str:
        """Decrypt a file."""
        if not output_path:
            # Strip the .encrypted suffix
            output_path = encrypted_file_path
            if output_path.endswith(".encrypted"):
                output_path = output_path[:-10]
        with open(encrypted_file_path, "rb") as f:
            encrypted_data = f.read()
        decrypted_data = self.cipher_suite.decrypt(encrypted_data)
        with open(output_path, "wb") as f:
            f.write(decrypted_data)
        return output_path
```
The encryptor uses symmetric encryption to protect data; it can encrypt and decrypt strings as well as files. Key management is the crux of any encryption system: store and transmit keys securely.
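A usage sketch; in practice the key would be loaded from a secrets manager rather than kept in memory:

```python
encryptor = DataEncryptor()   # generates a fresh Fernet key
key = encryptor.get_key()     # persist securely; losing it loses the data

token = encryptor.encrypt("user query containing PII")
print(encryptor.decrypt(token))

# The same key must be supplied to decrypt elsewhere
restored = DataEncryptor(encryption_key=key)
print(restored.decrypt(token))
```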
10.2 Access Control
Access control restricts who may touch which system resources:
```python
from enum import Enum
from typing import Callable, Dict, List

class Role(Enum):
    """User roles."""
    ADMIN = "admin"
    USER = "user"
    GUEST = "guest"

class AccessControl:
    """Access control."""

    def __init__(self):
        """Initialize the access-control table."""
        self.permissions = {
            Role.ADMIN: {
                "read": ["*"],
                "write": ["*"],
                "delete": ["*"],
                "manage": ["*"]
            },
            Role.USER: {
                "read": ["documents", "embeddings", "queries"],
                "write": ["documents", "queries"],
                "delete": ["queries"],
                "manage": []
            },
            Role.GUEST: {
                "read": ["documents"],
                "write": [],
                "delete": [],
                "manage": []
            }
        }

    def check_permission(self, role: Role, action: str, resource: str) -> bool:
        """Check whether a role may perform an action on a resource."""
        # Look up the role's permissions
        role_permissions = self.permissions.get(role, {})
        # Resources allowed for this action
        allowed_resources = role_permissions.get(action, [])
        # Wildcard permission
        if "*" in allowed_resources:
            return True
        # Specific resource permission
        return resource in allowed_resources

    def require_permission(self, action: str, resource: str):
        """Permission-checking decorator (for methods of objects that expose
        a get_current_user_role() method and an access_control attribute)."""
        def decorator(func: Callable):
            def wrapper(self, *args, **kwargs):
                # Fetch the current user's role from the object's context
                user_role = self.get_current_user_role()
                # Check the permission
                if not self.access_control.check_permission(user_role, action, resource):
                    raise PermissionError(f"User {user_role} does not have permission to {action} {resource}")
                # Run the original function
                return func(self, *args, **kwargs)
            return wrapper
        return decorator

    def get_user_permissions(self, role: Role) -> Dict[str, List[str]]:
        """Return all permissions of a role."""
        return self.permissions.get(role, {})

    def set_permissions(self, role: Role, permissions: Dict[str, List[str]]) -> None:
        """Set a role's permissions."""
        self.permissions[role] = permissions
```
This is a role-based access-control system: it defines the permissions of each role (admin, user, guest) and supplies a permission check plus a decorator for enforcing it in code.
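A usage sketch of the permission table:

```python
ac = AccessControl()
print(ac.check_permission(Role.USER, "write", "documents"))   # True
print(ac.check_permission(Role.GUEST, "write", "documents"))  # False

# Grant guests an extra permission at runtime
perms = ac.get_user_permissions(Role.GUEST)
perms["write"] = ["queries"]
ac.set_permissions(Role.GUEST, perms)
```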
10.3 Privacy Protection
Privacy mechanisms keep user data from leaking:
```python
import re
from typing import List, Dict, Any

class PrivacyProtector:
    """Privacy protector."""

    def __init__(self, sensitive_patterns: Dict[str, str] = None):
        """Initialize the privacy protector."""
        # Predefined patterns for sensitive information
        self.sensitive_patterns = sensitive_patterns or {
            "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
            "phone": r"\b(?:\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b",
            "credit_card": r"\b(?:\d{4}[-.\s]?){3}\d{4}\b",
            "ssn": r"\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b",
            "ip_address": r"\b(?:\d{1,3}\.){3}\d{1,3}\b"
        }
        # Replacement marker
        self.redaction_pattern = "[REDACTED]"

    def redact_sensitive_data(self, text: str) -> str:
        """Replace sensitive information in the text."""
        for pattern_name, pattern_regex in self.sensitive_patterns.items():
            text = re.sub(pattern_regex, self.redaction_pattern, text)
        return text

    def process_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Process a list of documents, redacting sensitive information."""
        processed_docs = []
        for doc in documents:
            processed_doc = {**doc}
            # Redact the content
            if "content" in processed_doc:
                processed_doc["content"] = self.redact_sensitive_data(processed_doc["content"])
            # Redact the metadata
            if "metadata" in processed_doc:
                processed_metadata = {}
                for key, value in processed_doc["metadata"].items():
                    if isinstance(value, str):
                        processed_metadata[key] = self.redact_sensitive_data(value)
                    else:
                        processed_metadata[key] = value
                processed_doc["metadata"] = processed_metadata
            processed_docs.append(processed_doc)
        return processed_docs

    def anonymize_embeddings(self, embeddings: List[List[float]],
                             noise_level: float = 0.01) -> List[List[float]]:
        """Add noise to embedding vectors to protect privacy."""
        import numpy as np
        # Convert to a numpy array
        embeddings_array = np.array(embeddings)
        # Generate Gaussian noise
        noise = np.random.normal(0, noise_level, embeddings_array.shape)
        # Add the noise
        anonymized_embeddings = embeddings_array + noise
        # Convert back to lists
        return anonymized_embeddings.tolist()
```
The privacy protector offers several techniques, including detection and redaction of sensitive information and data anonymization, covering both document content and metadata. Adding noise to embedding vectors lowers the risk of re-identification while largely preserving vector semantics.
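A quick sketch of redaction in action:

```python
protector = PrivacyProtector()
print(protector.redact_sensitive_data(
    "Contact me at alice@example.com or 555-123-4567."
))
# -> "Contact me at [REDACTED] or [REDACTED]."
```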
11. Monitoring and Logging Source Code Analysis
11.1 Logging Implementation
Logging helps track the system's runtime state and supports troubleshooting:
python
import logging
from typing import Dict, Any

class Logger:
    """Thin wrapper around the standard logging module."""
    def __init__(self, name: str = "langchain",
                 level: int = logging.INFO,
                 log_file: str = None):
        """Initialize the logger with console and optional file output."""
        self.logger = logging.getLogger(name)
        self.logger.setLevel(level)
        # Avoid attaching duplicate handlers when the same name is reused
        if not self.logger.handlers:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            # Console handler
            console_handler = logging.StreamHandler()
            console_handler.setLevel(level)
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)
            # Optional file handler
            if log_file:
                file_handler = logging.FileHandler(log_file)
                file_handler.setLevel(level)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)
        # Do not propagate records to ancestor loggers
        self.logger.propagate = False

    # Note: keys in `extra` must not collide with built-in LogRecord
    # attributes such as `message` or `asctime`.
    def info(self, message: str, extra: Dict[str, Any] = None) -> None:
        """Log at INFO level."""
        self.logger.info(message, extra=extra)

    def warning(self, message: str, extra: Dict[str, Any] = None) -> None:
        """Log at WARNING level."""
        self.logger.warning(message, extra=extra)

    def error(self, message: str, extra: Dict[str, Any] = None) -> None:
        """Log at ERROR level."""
        self.logger.error(message, extra=extra)

    def exception(self, message: str, extra: Dict[str, Any] = None) -> None:
        """Log an ERROR with the current exception's traceback attached."""
        self.logger.exception(message, extra=extra)

    def debug(self, message: str, extra: Dict[str, Any] = None) -> None:
        """Log at DEBUG level."""
        self.logger.debug(message, extra=extra)
The Logger class wraps Python's logging module behind a uniform interface. It supports console and file output at every standard log level, which simplifies system monitoring and troubleshooting.
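A short usage sketch (rag-pipeline and app.log are arbitrary names):
python
log = Logger(name="rag-pipeline", log_file="app.log")
log.info("Pipeline started")
try:
    raise ValueError("bad chunk size")
except ValueError:
    log.exception("Failed to split document")  # includes the traceback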
11.2 Performance Monitoring Implementation
Performance monitoring helps identify bottlenecks and directions for optimization:
python
import time
from typing import Dict, Any, Callable
from functools import wraps

class PerformanceMonitor:
    """Performance monitor that collects per-function timing metrics."""
    def __init__(self, logger=None):
        """Initialize with an optional logger."""
        self.metrics = {}
        self.logger = logger

    def time_function(self, func: Callable) -> Callable:
        """Decorator that records a function's execution time."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is a monotonic clock, suited to measuring durations
            start_time = time.perf_counter()
            # Run the original function
            result = func(*args, **kwargs)
            # Compute the elapsed time
            execution_time = time.perf_counter() - start_time
            # Store the measurement
            self._record_metric(func.__name__, "execution_time", execution_time)
            # Optionally log it
            if self.logger:
                self.logger.info(f"Function {func.__name__} executed in {execution_time:.4f} seconds")
            return result
        return wrapper

    def _record_metric(self, name: str, metric_type: str, value: float) -> None:
        """Append a value under metrics[name][metric_type]."""
        if name not in self.metrics:
            self.metrics[name] = {}
        if metric_type not in self.metrics[name]:
            self.metrics[name][metric_type] = []
        self.metrics[name][metric_type].append(value)

    def get_metrics(self) -> Dict[str, Any]:
        """Return all recorded metrics."""
        return self.metrics

    def get_average_metric(self, name: str, metric_type: str) -> float:
        """Return the mean of a metric, or 0 if nothing was recorded."""
        if name in self.metrics and metric_type in self.metrics[name]:
            values = self.metrics[name][metric_type]
            return sum(values) / len(values) if values else 0
        return 0

    def reset_metrics(self) -> None:
        """Discard all recorded metrics."""
        self.metrics = {}
PerformanceMonitor provides a decorator for measuring function execution time, plus metric recording and query helpers that aid in analyzing performance bottlenecks.
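A usage sketch with a stand-in workload; embed_batch is a hypothetical function that merely sleeps to simulate an embedding call:
python
import time

monitor = PerformanceMonitor()

@monitor.time_function
def embed_batch(texts):
    # Stand-in for a real embedding call
    time.sleep(0.05)
    return [[0.0] * 3 for _ in texts]

for _ in range(3):
    embed_batch(["a", "b"])

print(monitor.get_average_metric("embed_batch", "execution_time"))  # ~0.05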
11.3 Error Tracking Implementation
Error tracking helps locate and fix system problems quickly:
python
import time
import traceback
from typing import Dict, Any, Callable, List
from functools import wraps

class ErrorTracker:
    """Error tracker that records exceptions raised by decorated functions."""
    def __init__(self, logger=None, error_handler=None):
        """Initialize with an optional logger and error-handler callback."""
        self.logger = logger
        self.error_handler = error_handler
        self.errors = []

    def track_errors(self, func: Callable) -> Callable:
        """Decorator that captures and records any exception, then re-raises it."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # Run the original function
                return func(*args, **kwargs)
            except Exception as e:
                # Collect structured information about the failure
                error_info = {
                    "function": func.__name__,
                    "error_type": type(e).__name__,
                    "error_message": str(e),
                    "traceback": traceback.format_exc(),
                    "timestamp": time.time()
                }
                # Record the error
                self._record_error(error_info)
                # Optionally log it
                if self.logger:
                    self.logger.error(f"Error in {func.__name__}: {str(e)}", extra=error_info)
                # Give the optional handler a chance to react (alerting, etc.)
                if self.error_handler:
                    self.error_handler(error_info)
                # Preserve the original control flow
                raise
        return wrapper

    def _record_error(self, error_info: Dict[str, Any]) -> None:
        """Append one error record."""
        self.errors.append(error_info)

    def get_errors(self) -> List[Dict[str, Any]]:
        """Return all recorded errors."""
        return self.errors

    def get_recent_errors(self, count: int = 10) -> List[Dict[str, Any]]:
        """Return the most recent errors."""
        return self.errors[-count:]

    def clear_errors(self) -> None:
        """Discard all recorded errors."""
        self.errors = []
ErrorTracker provides a decorator that captures exceptions raised during function execution and records detailed error information, along with query and management helpers that help developers diagnose and resolve problems quickly.
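A usage sketch; parse_document is a hypothetical function that always fails:
python
tracker = ErrorTracker()

@tracker.track_errors
def parse_document(path):
    raise FileNotFoundError(path)

try:
    parse_document("missing.txt")
except FileNotFoundError:
    pass  # the exception is re-raised after being recorded

print(tracker.get_recent_errors(1)[0]["error_type"])  # FileNotFoundError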
12. Integration Testing and Deployment Source Code Analysis
12.1 Integration Test Framework Implementation
Integration tests verify that the system's components cooperate correctly:
python
import os
import unittest
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

class TestLangChainIntegration(unittest.TestCase):
    """LangChain integration tests."""
    def setUp(self):
        """Create a test document on disk."""
        self.test_text = "LangChain is a framework for developing applications powered by language models. " \
                         "It enables applications that: \n" \
                         "- Connect language models to sources of data\n" \
                         "- Connect language models to other computation"
        self.test_file = "test_doc.txt"
        with open(self.test_file, "w") as f:
            f.write(self.test_text)

    def tearDown(self):
        """Remove the test document."""
        if os.path.exists(self.test_file):
            os.remove(self.test_file)

    @unittest.skipUnless(os.environ.get("OPENAI_API_KEY"), "requires an OpenAI API key")
    def test_end_to_end_workflow(self):
        """Exercise the full load -> split -> embed -> store -> QA pipeline."""
        # 1. Load the document
        loader = TextLoader(self.test_file)
        documents = loader.load()
        # 2. Split the text
        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        # 3. Create the embedding model
        embeddings = OpenAIEmbeddings()
        # 4. Build the vector store
        vectorstore = Chroma.from_documents(texts, embeddings)
        # 5. Create a retriever
        retriever = vectorstore.as_retriever()
        # 6. Create the QA chain
        llm = OpenAI()
        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
        # 7. Run a query
        query = "What is LangChain?"
        result = qa.run(query)
        # 8. Check the answer
        self.assertIsNotNone(result)
        self.assertIn("LangChain", result)

    def test_multi_modal_workflow(self):
        """Outline of a multi-modal test; it needs image data to run."""
        # This test is a structural example only; a real implementation
        # would supply actual image data.
        # 1. Initialize the multi-modal parser
        # multi_modal_parser = MultiModalParser(...)
        # 2. Prepare multi-modal data
        # data = {
        #     "text": "This is a test image with a cat",
        #     "image": Image.open("test_image.jpg")
        # }
        # 3. Parse the data
        # result = multi_modal_parser.parse(data)
        # 4. Check the result
        # self.assertIsNotNone(result)
        # self.assertIn("cat", result["combined"]["entities"])
        # Skipped until real image data is available
        self.skipTest("Multi-modal test requires image data")

if __name__ == '__main__':
    unittest.main()
The integration test suite is built on Python's unittest module and exercises the complete workflow from document loading and text splitting through embedding generation to the QA chain. These tests confirm that the components work together, underpinning the system's stability.
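Assuming the test module is saved under a name matching test_*.py, the suite can also be run programmatically with the standard unittest discovery API:
python
import unittest

# Discover every test module matching test_*.py in the current directory
suite = unittest.defaultTestLoader.discover(".", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)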
12.2 Deployment Configuration Implementation
Deployment configuration adapts the system to different target environments:
python
import os
from typing import Dict, Any, List

class DeploymentConfig:
    """Environment-aware deployment configuration."""
    def __init__(self, env: str = None):
        """Resolve the target environment (ENV variable, defaulting to development)."""
        self.env = env or os.environ.get("ENV", "development")
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Build the merged configuration for the current environment."""
        config = {
            # Settings shared by all environments
            "common": {
                "log_level": "INFO",
                "cache_dir": ".cache",
                "temp_dir": ".temp"
            },
            "development": {
                "llm_model": "gpt-3.5-turbo",
                "embedding_model": "text-embedding-ada-002",
                "vector_db": {
                    "type": "chroma",
                    "persist_directory": ".chromadb"
                },
                "api_keys": {
                    "openai": os.environ.get("OPENAI_API_KEY", "your_openai_api_key")
                }
            },
            "production": {
                "llm_model": "gpt-4",
                "embedding_model": "text-embedding-ada-002",
                "vector_db": {
                    "type": "pinecone",
                    "index_name": "langchain-prod",
                    "api_key": os.environ.get("PINECONE_API_KEY", "your_pinecone_api_key"),
                    "environment": os.environ.get("PINECONE_ENVIRONMENT", "your_pinecone_environment")
                },
                "api_keys": {
                    "openai": os.environ.get("OPENAI_API_KEY", "your_openai_api_key")
                },
                "security": {
                    "encryption_key": os.environ.get("ENCRYPTION_KEY", "your_encryption_key"),
                    "access_control": True
                }
            },
            "testing": {
                "llm_model": "gpt-3.5-turbo",
                "embedding_model": "text-embedding-ada-002",
                "vector_db": {
                    "type": "chroma",
                    "persist_directory": ".chromadb-test"
                },
                "api_keys": {
                    "openai": os.environ.get("OPENAI_API_KEY", "your_openai_api_key")
                }
            }
        }
        # Environment-specific values override the common ones; an unknown
        # environment falls back to the common settings alone
        env_config = {**config["common"], **config.get(self.env, {})}
        return env_config

    def get(self, key: str, default: Any = None) -> Any:
        """Return a top-level configuration value."""
        return self.config.get(key, default)

    def get_nested(self, keys: List[str], default: Any = None) -> Any:
        """Walk a list of keys into nested dictionaries."""
        value = self.config
        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default
        return value

    def set(self, key: str, value: Any) -> None:
        """Override a top-level configuration value."""
        self.config[key] = value

    def get_env(self) -> str:
        """Return the active environment name."""
        return self.env
DeploymentConfig implements environment-based configuration management for development, production, and testing. Secrets are read from environment variables rather than hard-coded, which keeps the configuration safe, and a uniform accessor interface hides the differences between environments.
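A minimal usage sketch:
python
config = DeploymentConfig(env="development")
print(config.get("llm_model"))                                # gpt-3.5-turbo
print(config.get_nested(["vector_db", "persist_directory"]))  # .chromadb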
12.3 Containerized Deployment Implementation
Containerization simplifies deployment and operations:
dockerfile
# Dockerfile
# Base image
FROM python:3.9-slim
# Working directory
WORKDIR /app
# System dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*
# Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Application code
COPY . .
# Environment variables (secrets are injected at run time)
ENV ENV=production
ENV OPENAI_API_KEY=""
ENV PINECONE_API_KEY=""
ENV PINECONE_ENVIRONMENT=""
ENV ENCRYPTION_KEY=""
# Exposed port
EXPOSE 8000
# Entry point
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
The Dockerfile describes how to build the application image: install system and Python dependencies, copy the application code, declare environment variables, and define the start command. Docker ensures the application behaves consistently across environments.
yaml
# docker-compose.yml
version: '3'
services:
langchain-app:
build: .
ports:
- "8000:8000"
environment:
- ENV=production
- OPENAI_API_KEY=${OPENAI_API_KEY}
- PINECONE_API_KEY=${PINECONE_API_KEY}
- PINECONE_ENVIRONMENT=${PINECONE_ENVIRONMENT}
- ENCRYPTION_KEY=${ENCRYPTION_KEY}
volumes:
- ./data:/app/data
depends_on:
- vector-db
restart: always
vector-db:
image: chromadb/chroma
ports:
- "8001:8000"
volumes:
- ./chroma-data:/data
restart: always
docker-compose.yml describes how to deploy the application and its dependencies with Docker Compose. Two services are defined here, the application and the vector database, connected over Compose's internal network so they can work together.
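Inside the application container, the database is reachable under its Compose service name. A minimal connection sketch, assuming the chromadb client package is installed and a recent Chroma server image (the collection name is arbitrary):
python
import chromadb

# "vector-db" resolves via Docker Compose's internal network;
# the Chroma server listens on its container port 8000
client = chromadb.HttpClient(host="vector-db", port=8000)
collection = client.get_or_create_collection("documents")
print(collection.count())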
13. Application Cases and Best Practices Source Code Analysis
13.1 Document QA Application Implementation
Document question answering is a canonical LangChain application:
python
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.callbacks import get_openai_callback

class DocumentQAApp:
    """Document question-answering application."""
    def __init__(self, documents_dir: str = None,
                 persist_dir: str = ".chromadb",
                 embedding_model: str = "text-embedding-ada-002",
                 llm_model: str = "gpt-3.5-turbo"):
        """Initialize the application."""
        self.documents_dir = documents_dir
        self.persist_dir = persist_dir
        self.embedding_model = embedding_model
        self.llm_model = llm_model
        # Initialize model components
        self.embeddings = OpenAIEmbeddings(model=self.embedding_model)
        self.llm = OpenAI(model_name=self.llm_model)
        # If a documents directory is given, (re)build the vector store from it
        if self.documents_dir:
            self.vectorstore = self._load_documents()
        else:
            # Otherwise try to load a previously persisted vector store
            self.vectorstore = self._load_vectorstore()
        # Build the QA chain
        self.qa_chain = self._create_qa_chain()

    def _load_documents(self) -> Chroma:
        """Load, split, and index the documents."""
        # Load every .txt file under the directory
        loader = DirectoryLoader(self.documents_dir, glob="**/*.txt", loader_cls=TextLoader)
        documents = loader.load()
        # Split into overlapping chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        texts = text_splitter.split_documents(documents)
        # Build the vector store
        vectorstore = Chroma.from_documents(
            documents=texts,
            embedding=self.embeddings,
            persist_directory=self.persist_dir
        )
        # Persist it to disk for later reuse
        vectorstore.persist()
        return vectorstore

    def _load_vectorstore(self) -> Chroma:
        """Load a persisted vector store."""
        try:
            vectorstore = Chroma(
                embedding_function=self.embeddings,
                persist_directory=self.persist_dir
            )
            return vectorstore
        except Exception:
            raise ValueError("No vector store found. Please provide a documents directory to initialize.")

    def _create_qa_chain(self) -> RetrievalQA:
        """Create the QA chain."""
        # Build the prompt template
prompt_template = """Use the following pieces of context to answer the question at the end.