Q: I'm a junior programmer. Please help me learn this project's architecture and design.
Sure, let me walk you through this project's architecture. It is a knowledge-base question-answering system built on RAG (retrieval-augmented generation). Let's start from the overall architecture:
1. Core Architectural Components
Storage-layer abstraction
Three base storage interfaces are defined:
python
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
embedding_func: EmbeddingFunc
async def all_keys(self) -> list[str]:
raise NotImplementedError
async def get_by_id(self, id: str) -> Union[T, None]:
raise NotImplementedError
async def get_by_ids(
self, ids: list[str], fields: Union[set[str], None] = None
) -> list[Union[T, None]]:
raise NotImplementedError
async def filter_keys(self, data: list[str]) -> set[str]:
"""return un-exist keys"""
raise NotImplementedError
async def upsert(self, data: dict[str, T]):
raise NotImplementedError
- KV storage (BaseKVStorage): stores documents and text chunks
- Vector storage (BaseVectorStorage): stores and retrieves vectorized content
- Graph storage (BaseGraphStorage): stores entities and relationships
Storage implementations
Multiple storage backends are supported:
- Oracle
- TiDB
- MongoDB
- Milvus
- Local JSON file storage, and others
LLM integration
Multiple large language models are supported:
python
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
**kwargs,
) -> str:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
openai_async_client = (
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
    # Log the query input
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
logger.debug("Full context:")
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
- OpenAI
- Zhipu AI
- Ollama, and others
2. Core Functional Flows
Document processing flow:
Document chunking
Vectorization (embedding)
Entity extraction
Knowledge-graph construction
Query modes:
python
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
only_need_context: bool = False
only_need_prompt: bool = False
response_type: str = "Multiple Paragraphs"
stream: bool = False
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
# Number of document chunks to retrieve.
# top_n: int = 10
# Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 4000
naive: simple text matching
local: local retrieval based on vector similarity
global: global retrieval based on the knowledge graph
hybrid: hybrid retrieval combining both
3. Extensibility
Modular design:
The storage layer can easily be extended with new backends (see the sketch below)
The LLM layer supports integrating different models
Query modes can be customized and extended
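For example, adding a new backend mostly means implementing one of these interfaces. Below is a minimal, hypothetical sketch of an in-memory key-value store that mirrors the BaseKVStorage methods quoted above; it is not part of LightRAG, and in the real project you would subclass BaseKVStorage itself (the exact import path depends on your version) and add persistence.
python
from typing import Generic, TypeVar, Union

T = TypeVar("T")


class InMemoryKVStorage(Generic[T]):
    """Hypothetical sketch of a KV backend implementing the interface shown above."""

    def __init__(self) -> None:
        self._data: dict[str, T] = {}

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def get_by_id(self, id: str) -> Union[T, None]:
        return self._data.get(id)

    async def get_by_ids(self, ids: list[str], fields=None) -> list[Union[T, None]]:
        # field filtering is omitted in this sketch
        return [self._data.get(i) for i in ids]

    async def filter_keys(self, data: list[str]) -> set[str]:
        # return the keys that are NOT stored yet
        return {k for k in data if k not in self._data}

    async def upsert(self, data: dict[str, T]):
        self._data.update(data)
        return data

    async def drop(self):
        self._data = {}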
Asynchronous processing:
python
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
"""Add restriction of maximum async calling times for a async func"""
def final_decro(func):
"""Not using async.Semaphore to aovid use nest-asyncio"""
__current_size = 0
@wraps(func)
async def wait_func(*args, **kwargs):
nonlocal __current_size
while __current_size >= max_size:
await asyncio.sleep(waitting_time)
__current_size += 1
result = await func(*args, **kwargs)
__current_size -= 1
return result
return wait_func
return final_decro
asyncio is used throughout for asynchronous operation, which improves throughput. A usage sketch of this limiter follows.
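As a usage sketch (assuming limit_async_func_call is importable from lightrag.utils, which is where the utility lives at the time of writing), the decorator caps how many calls of the wrapped coroutine run at the same time:
python
import asyncio

from lightrag.utils import limit_async_func_call  # import path is an assumption


@limit_async_func_call(max_size=4)
async def fake_embedding_call(i: int) -> int:
    # stand-in for an expensive LLM or embedding request
    await asyncio.sleep(0.1)
    return i


async def main() -> None:
    # 20 calls are issued, but at most 4 of them run concurrently
    results = await asyncio.gather(*(fake_embedding_call(i) for i in range(20)))
    print(results)


asyncio.run(main())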
4. API Design
Several interface forms are provided:
python
class QueryRequest(BaseModel):
query: str
mode: str = "hybrid"
only_need_context: bool = False
only_need_prompt: bool = False
REST API
Python SDK
Command-line interface (CLI)
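For example, with the bundled API server running, a query can be issued over HTTP. This is only a hedged sketch: the route, port, and payload depend on which API example you start, so verify them against the api examples in the repository before relying on this.
python
import requests

# Endpoint and port are assumptions; check the API server script you actually run.
resp = requests.post(
    "http://localhost:8020/query",
    json={"query": "What does this project do?", "mode": "hybrid"},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())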
5. Configuration Management
Configuration is layered:
Global configuration
Storage configuration
Model configuration
Runtime configuration
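A hedged sketch of how these layers typically surface in code: most of them are constructor arguments on LightRAG, and addon_params carries runtime options such as the prompt language and the number of few-shot examples (both keys appear in the query code later in this conversation). The placeholder functions are stand-ins; double-check argument names against your installed version.
python
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    # global configuration
    working_dir="./rag_storage",
    # model configuration
    llm_model_func=your_llm_func,        # placeholder
    llm_model_max_async=4,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=your_embedding_func,        # placeholder
    ),
    # chunking / storage-related configuration
    chunk_token_size=1024,
    chunk_overlap_token_size=128,
    # runtime configuration
    addon_params={"language": "English", "example_number": 3},
)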
Practical Advice
- If you want to learn this project:
  Start with the basic JSON storage implementation
  Understand the core document-processing flow
  Then study the more complex storage backends
- If you want to use this project:
  Start from a simple demo
  Pick a storage backend that fits your needs
  Choose a suitable LLM
- If you want to contribute code:
  Understand the project's interface design
  Follow the existing code style
  Add complete test cases
The project's architecture is clean and uses several solid design patterns, so it is well worth studying. Feel free to ask me about anything specific.
Q: I want to learn this project. Please guide me through getting started.
Sure, let's start from the basics with a simple example:
1. Basic Concepts
LightRAG is a RAG (retrieval-augmented generation) system, mainly used for:
Document storage and retrieval
Knowledge-graph construction
Intelligent question answering
2. Quick Start
Let's look at the simplest example, using Zhipu AI's models:
python
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=zhipu_complete,
llm_model_name="glm-4-flashx", # Using the most cost/performance balance model, but you can change it here.
llm_model_max_async=4,
llm_model_max_token_size=32768,
embedding_func=EmbeddingFunc(
embedding_dim=2048, # Zhipu embedding-3 dimension
max_token_size=8192,
func=lambda texts: zhipu_embedding(texts),
),
)
This example shows LightRAG's basic usage:
Initialization parameters:
working_dir: the working directory
llm_model_func: the large language model to use
embedding_func: the text-embedding function
Basic usage flow:
python
# 1. Create the RAG instance
rag = LightRAG(...)
# 2. Insert documents
rag.insert("Your document content")
# 3. Query
rag.query("Your question", param=QueryParam(mode="hybrid"))
3. Core Components
LightRAG has three core storage components:
python
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict]:
raise NotImplementedError
async def upsert(self, data: dict[str, dict]):
"""Use 'content' field from value for embedding, use key as id.
If embedding_func is None, use 'embedding' field from value
"""
raise NotImplementedError
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
embedding_func: EmbeddingFunc
async def all_keys(self) -> list[str]:
raise NotImplementedError
async def get_by_id(self, id: str) -> Union[T, None]:
raise NotImplementedError
async def get_by_ids(
self, ids: list[str], fields: Union[set[str], None] = None
) -> list[Union[T, None]]:
raise NotImplementedError
async def filter_keys(self, data: list[str]) -> set[str]:
"""return un-exist keys"""
raise NotImplementedError
async def upsert(self, data: dict[str, T]):
raise NotImplementedError
async def drop(self):
raise NotImplementedError
BaseVectorStorage: vector storage
Stores the vector representations of documents
Supports similarity search
BaseKVStorage: key-value storage
Stores the original documents and chunks
Supports lookup by ID
BaseGraphStorage: graph storage (an illustrative sketch of this interface follows below)
Stores entities and relationships
Supports knowledge-graph queries
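The graph-storage interface is not quoted here, so the sketch below only illustrates the shape of the operations such a backend needs: node and edge upserts plus lookups. The method names are representative, not an exact copy of the project's BaseGraphStorage; consult the base definitions in the source for the authoritative signatures.
python
from typing import Optional


class IllustrativeGraphStorage:
    """Sketch only: the shape of a graph-storage backend, not LightRAG's actual class."""

    async def has_node(self, node_id: str) -> bool: ...
    async def get_node(self, node_id: str) -> Optional[dict]: ...
    async def upsert_node(self, node_id: str, node_data: dict) -> None: ...

    async def has_edge(self, source_id: str, target_id: str) -> bool: ...
    async def get_edge(self, source_id: str, target_id: str) -> Optional[dict]: ...
    async def upsert_edge(self, source_id: str, target_id: str, edge_data: dict) -> None: ...

    async def get_node_edges(self, node_id: str) -> Optional[list[tuple[str, str]]]: ...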
4. Query Modes
Several query modes are supported:
python
class QueryParam:
mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
only_need_context: bool = False
only_need_prompt: bool = False
response_type: str = "Multiple Paragraphs"
stream: bool = False
# Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
top_k: int = 60
# Number of document chunks to retrieve.
# top_n: int = 10
# Number of tokens for the original chunks.
max_token_for_text_unit: int = 4000
# Number of tokens for the relationship descriptions
max_token_for_global_context: int = 4000
# Number of tokens for the entity descriptions
max_token_for_local_context: int = 4000
naive: simple text matching
local: local retrieval based on vector similarity
global: global retrieval based on the knowledge graph
hybrid: hybrid retrieval combining both
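Two of the QueryParam flags above are handy for debugging retrieval: only_need_context returns just the retrieved context, and only_need_prompt returns the fully assembled prompt without calling the LLM (both paths are visible in the query code quoted later). A small sketch, continuing from the test.py example below so that rag is already configured:
python
from lightrag import QueryParam

# Inspect what would be retrieved, without generating an answer
context = rag.query("Your question?", param=QueryParam(mode="naive", only_need_context=True))
print(context)

# Inspect the final prompt that would be sent to the LLM
prompt = rag.query("Your question?", param=QueryParam(mode="naive", only_need_prompt=True))
print(prompt)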
5. Hands-On Exercise
Let's start with a simple example.
First install the project's dependencies.
Then create a test file test.py. Note that openai_complete_if_cache takes the model name as its first argument (see its signature earlier), so we wrap it in a small llm_model_func before handing it to LightRAG:
python
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc

# Wrap the OpenAI helper so that LightRAG can call it with (prompt, ...) only
async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )

# Initialize the RAG instance
rag = LightRAG(
    working_dir="./test_rag",
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=openai_embedding,
    ),
)

# Insert a document
with open("your_doc.txt", "r") as f:
    rag.insert(f.read())

# Query
result = rag.query(
    "Your question?",
    param=QueryParam(mode="hybrid"),
)
print(result)
Suggested Next Steps
- Get familiar with the basic usage first:
  Try the different query modes
  Understand the document-processing flow
- Dig into the storage implementations:
  Start with the simple JSON storage
  Then try the other storage backends
- Explore advanced features:
  Custom embedding models
  Tuning retrieval parameters
  Integrating different LLMs
Which part would you like to explore first? I can walk you through it in detail.
Q: I'd like to understand the document-processing flow
Let me help you understand LightRAG's document-processing flow. It consists of the following steps:
1. Document insertion flow
When you call rag.insert() or rag.ainsert(), the document goes through these steps:
- Document chunking
python
def chunking_by_token_size(
    content: str,
    overlap_token_size=128,
    max_token_size=1024,
    tiktoken_model="gpt-4o",
):
    # Convert the text into tokens
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    # Slide a window over the tokens to form chunks
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        # Decode the token slice back into text
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], model_name=tiktoken_model
        )
        # Build the chunk record
        results.append(
            {
                "tokens": min(max_token_size, len(tokens) - start),
                "content": chunk_content.strip(),
                "chunk_order_index": index,
            }
        )
    return results
The document is split into smaller chunks by the chunking_by_token_size function shown above. This step:
Controls the size of each chunk (chunk_token_size)
Sets the overlap between chunks (chunk_overlap_token_size)
Generates a unique ID for each chunk
Keeps a reference back to the source document
Vectorization
Each text chunk is then converted into a vector:
python
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result
embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results
Main steps:
Batch the texts
Call the embedding function to generate vectors
Store the vectors together with their metadata
Entity extraction and relationship construction
This process includes:
python
async def ainsert_custom_kg(self, custom_kg: dict):
update_storage = False
try:
# Insert chunks into vector storage
all_chunks_data = {}
chunk_to_source_map = {}
for chunk_data in custom_kg.get("chunks", []):
chunk_content = chunk_data["content"]
source_id = chunk_data["source_id"]
chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
all_chunks_data[chunk_id] = chunk_entry
chunk_to_source_map[source_id] = chunk_id
update_storage = True
if self.chunks_vdb is not None and all_chunks_data:
await self.chunks_vdb.upsert(all_chunks_data)
if self.text_chunks is not None and all_chunks_data:
await self.text_chunks.upsert(all_chunks_data)
# Insert entities into knowledge graph
all_entities_data = []
...
edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"keywords": keywords,
}
all_relationships_data.append(edge_data)
update_storage = True
# Insert entities into vector storage if needed
if self.entities_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
"content": dp["entity_name"] + dp["description"],
"entity_name": dp["entity_name"],
}
for dp in all_entities_data
}
await self.entities_vdb.upsert(data_for_vdb)
# Insert relationships into vector storage if needed
if self.relationships_vdb is not None:
data_for_vdb = {
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
"src_id": dp["src_id"],
"tgt_id": dp["tgt_id"],
"content": dp["keywords"]
+ dp["src_id"]
+ dp["tgt_id"]
+ dp["description"],
}
for dp in all_relationships_data
}
await self.relationships_vdb.upsert(data_for_vdb)
Extract entities and relationships
Build the knowledge graph
Store vector representations of the entities and relationships (the ID scheme is illustrated below)
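A small illustration of the ID scheme used above: every chunk, entity, and relationship gets a deterministic ID made of a prefix plus a hash of its content, which is what makes upserts idempotent. This assumes compute_mdhash_id can be imported from lightrag.utils; adjust the path if your version differs.
python
from lightrag.utils import compute_mdhash_id  # import path is an assumption

chunk_id = compute_mdhash_id("some chunk text".strip(), prefix="chunk-")
entity_id = compute_mdhash_id("Alan Turing", prefix="ent-")
relation_id = compute_mdhash_id("Alan Turing" + "Enigma", prefix="rel-")

# The same content always produces the same ID
print(chunk_id, entity_id, relation_id)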
2. Storage architecture
The system uses three kinds of storage:
- KV storage - stores the original text:
python
@dataclass
class JsonKVStorage(BaseKVStorage):
def __post_init__(self):
working_dir = self.global_config["working_dir"]
self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
self._data = load_json(self._file_name) or {}
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
async def all_keys(self) -> list[str]:
return list(self._data.keys())
async def index_done_callback(self):
write_json(self._data, self._file_name)
async def get_by_id(self, id):
return self._data.get(id, None)
async def get_by_ids(self, ids, fields=None):
if fields is None:
return [self._data.get(id, None) for id in ids]
return [
(
{k: v for k, v in self._data[id].items() if k in fields}
if self._data.get(id, None)
else None
)
for id in ids
]
async def filter_keys(self, data: list[str]) -> set[str]:
return set([s for s in data if s not in self._data])
async def upsert(self, data: dict[str, dict]):
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
return left_data
async def drop(self):
self._data = {}
- Vector storage - stores the vectorized content:
python
@dataclass
class BaseVectorStorage(StorageNameSpace):
embedding_func: EmbeddingFunc
meta_fields: set = field(default_factory=set)
async def query(self, query: str, top_k: int) -> list[dict]:
raise NotImplementedError
async def upsert(self, data: dict[str, dict]):
"""Use 'content' field from value for embedding, use key as id.
If embedding_func is None, use 'embedding' field from value
"""
raise NotImplementedError
- Graph storage - stores the knowledge graph
As a further example of a database-backed implementation of these storage interfaces, here is the Oracle vector-storage backend:
python
class OracleVectorDBStorage(BaseVectorStorage):
cosine_better_than_threshold: float = 0.2
def __post_init__(self):
pass
async def upsert(self, data: dict[str, dict]):
"""向向量数据库中插入数据"""
pass
async def index_done_callback(self):
pass
#################### query method ###############
async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
"""从向量数据库中查询数据"""
embeddings = await self.embedding_func([query])
embedding = embeddings[0]
        # Record the embedding precision (dtype) for the SQL template
dtype = str(embedding.dtype).upper()
dimension = embedding.shape[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
params = {
"embedding_string": embedding_string,
"workspace": self.db.workspace,
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
}
# print(SQL)
results = await self.db.query(SQL, params=params, multirows=True)
# print("vector search result:",results)
return results
3. Query-processing flow
Depending on the query mode:
naive mode:
python
async def naive_query(
query,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
):
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
results = await chunks_vdb.query(query, top_k=query_param.top_k)
if not len(results):
return PROMPTS["fail_response"]
chunks_ids = [r["id"] for r in results]
chunks = await text_chunks_db.get_by_ids(chunks_ids)
# Filter out invalid chunks
valid_chunks = [
chunk for chunk in chunks if chunk is not None and "content" in chunk
]
if not valid_chunks:
logger.warning("No valid chunks found after filtering")
return PROMPTS["fail_response"]
maybe_trun_chunks = truncate_list_by_token_size(
valid_chunks,
key=lambda x: x["content"],
max_token_size=query_param.max_token_for_text_unit,
)
if not maybe_trun_chunks:
logger.warning("No chunks left after truncation")
return PROMPTS["fail_response"]
logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
if query_param.only_need_context:
return section
sys_prompt_temp = PROMPTS["naive_rag_response"]
sys_prompt = sys_prompt_temp.format(
content_data=section, response_type=query_param.response_type
)
if query_param.only_need_prompt:
return sys_prompt
response = await use_model_func(
query,
system_prompt=sys_prompt,
)
Retrieves chunks directly by vector similarity
Returns the relevant text chunks
kg_query mode (used for local, global, and hybrid):
python
async def kg_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
# Handle cache
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash(query_param.mode, query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, query_param.mode
)
if cached_response is not None:
return cached_response
example_number = global_config["addon_params"].get("example_number", None)
if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
examples = "\n".join(
PROMPTS["keywords_extraction_examples"][: int(example_number)]
)
else:
examples = "\n".join(PROMPTS["keywords_extraction_examples"])
language = global_config["addon_params"].get(
"language", PROMPTS["DEFAULT_LANGUAGE"]
)
# Set mode
if query_param.mode not in ["local", "global", "hybrid"]:
logger.error(f"Unknown mode {query_param.mode} in kg_query")
return PROMPTS["fail_response"]
# LLM generate keywords
kw_prompt_temp = PROMPTS["keywords_extraction"]
kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
result = await use_model_func(kw_prompt, keyword_extraction=True)
logger.info("kw_prompt result:")
print(result)
Extracts keywords from the query
Retrieves the relevant entities and relationships
Assembles the context
mix mode:
python
async def mix_kg_vector_query(
query,
knowledge_graph_inst: BaseGraphStorage,
entities_vdb: BaseVectorStorage,
relationships_vdb: BaseVectorStorage,
chunks_vdb: BaseVectorStorage,
text_chunks_db: BaseKVStorage[TextChunkSchema],
query_param: QueryParam,
global_config: dict,
hashing_kv: BaseKVStorage = None,
) -> str:
"""
Hybrid retrieval implementation combining knowledge graph and vector search.
This function performs a hybrid search by:
1. Extracting semantic information from knowledge graph
2. Retrieving relevant text chunks through vector similarity
3. Combining both results for comprehensive answer generation
"""
# 1. Cache handling
use_model_func = global_config["llm_model_func"]
args_hash = compute_args_hash("mix", query)
cached_response, quantized, min_val, max_val = await handle_cache(
hashing_kv, args_hash, query, "mix"
)
if cached_response is not None:
return cached_response
Combines vector retrieval with graph retrieval
Fuses several kinds of relevant information (a short usage sketch follows)
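Continuing from the earlier test.py example (so rag is already configured), switching to this combined mode is just a QueryParam change:
python
from lightrag import QueryParam

# "mix" runs knowledge-graph retrieval and chunk-level vector retrieval,
# then fuses the two contexts before generating the answer
result = rag.query("Your question?", param=QueryParam(mode="mix"))
print(result)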
Practical Suggestions
To get started, try the simplest configuration first:
python
import asyncio

from lightrag import LightRAG, QueryParam


async def main():
    # Initialize
    rag = LightRAG(
        working_dir="./test_dir",
        chunk_token_size=512,  # smaller chunks are easier to inspect while testing
        embedding_func=your_embedding_func,
        llm_model_func=your_llm_func,
    )
    # Insert a document
    await rag.ainsert("Your document content")
    # Try the different query modes
    modes = ["naive", "local", "global", "hybrid"]
    for mode in modes:
        result = await rag.aquery(
            "Your question?",
            param=QueryParam(mode=mode),
        )
        print(f"Mode {mode}:", result)


asyncio.run(main())
Notes:
- Set a reasonable chunk size
- Choose a suitable embedding model
- Pick a storage backend that matches your data volume
- Monitor the quality of vectorization and entity extraction
Which specific step would you like to dive into? I can explain it in detail.
Q: Explain text chunking in detail
Sure. Let me explain LightRAG's text chunking mechanism in detail.
1. Purpose of chunking
Text chunking is an important preprocessing step in a RAG system. Its main goals are:
Splitting long documents into pieces small enough for the LLM to handle
Preserving semantic integrity
Making vectorization and retrieval easier
2. Chunking implementation
LightRAG implements chunking with the chunking_by_token_size function:
python
def chunking_by_token_size(
    content: str,
    overlap_token_size=128,   # number of overlapping tokens between chunks
    max_token_size=1024,      # maximum number of tokens per chunk
    tiktoken_model="gpt-4o"   # tokenizer model to use
):
    # 1. Convert the text into tokens
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    # 2. Slide a window over the tokens to form chunks
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        # 3. Decode the token slice back into text
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size],
            model_name=tiktoken_model
        )
        # 4. Build the chunk record
        results.append({
            "tokens": min(max_token_size, len(tokens) - start),
            "content": chunk_content.strip(),
            "chunk_order_index": index,
        })
    return results
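A quick, hedged usage sketch of this function, assuming it can be imported from lightrag.operate (where it lives at the time of writing; adjust the import if the module layout has changed):
python
from lightrag.operate import chunking_by_token_size  # import path is an assumption

text = "LightRAG is a retrieval-augmented generation system. " * 200  # toy long document
chunks = chunking_by_token_size(text, overlap_token_size=32, max_token_size=256)

print(len(chunks), "chunks")
for c in chunks[:2]:
    # each chunk records its token count, content, and order index
    print(c["chunk_order_index"], c["tokens"], c["content"][:60], "...")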
3. Key parameters
overlap_token_size (overlap size)
Default: 128 tokens
Purpose: ensures adjacent chunks overlap so that meaning is not cut off at chunk boundaries
Suggestion: tune per document; typically 10%-20% of max_token_size
max_token_size (maximum chunk size)
Default: 1024 tokens
Purpose: caps the length of a single chunk
Suggestion: set it according to the LLM's context-window size
tiktoken_model (tokenizer model)
Default: "gpt-4o"
Purpose: selects which model's tokenizer to use
Suggestion: keep it consistent with the LLM you ultimately use
4. Chunk processing flow
In LightRAG, chunking is the first step when a document is inserted:
python
async def ainsert(self, string_or_strings):
"""Insert documents with checkpoint support
...
for i in range(0, len(new_docs), batch_size):
batch_docs = dict(list(new_docs.items())[i : i + batch_size])
for doc_id, doc in tqdm_async(
batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
):
try:
# Update status to processing
doc_status = {
"content_summary": doc["content_summary"],
"content_length": doc["content_length"],
"status": DocStatus.PROCESSING,
"created_at": doc["created_at"],
"updated_at": datetime.now().isoformat(),
}
await self.doc_status.upsert({doc_id: doc_status})
# Generate chunks from document
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_id,
}
for dp in chunking_by_token_size(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
# Update status with chunks information
doc_status.update(
{
"chunks_count": len(chunks),
"updated_at": datetime.now().isoformat(),
}
)
try:
# Store chunks in vector database
await self.chunks_vdb.upsert(chunks)
# Extract and store entities and relationships
maybe_new_kg = await extract_entities(
chunks,
knowledge_graph_inst=self.chunk_entity_relation_graph,
entity_vdb=self.entities_vdb,
relationships_vdb=self.relationships_vdb,
global_config=asdict(self),
)
if maybe_new_kg is None:
raise Exception(
"Failed to extract entities and relationships"
)
self.chunk_entity_relation_graph = maybe_new_kg
# Store original document and chunks
await self.full_docs.upsert(
{doc_id: {"content": doc["content"]}}
)
await self.text_chunks.upsert(chunks)
Main steps:
Document chunking
Vectorized storage
Entity extraction
Relationship construction
Storing the original document and its chunks
5. Storage design
The chunked content is stored in several places:
Vector database (chunks_vdb), e.g. the Milvus implementation (lightrag/kg/milvus_impl.py):
python
async def upsert(self, data: dict[str, dict]):
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
if not len(data):
logger.warning("You insert an empty data to vector DB")
return []
list_data = [
{
"id": k,
**{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
async def wrapped_task(batch):
result = await self.embedding_func(batch)
pbar.update(1)
return result
embedding_tasks = [wrapped_task(batch) for batch in batches]
pbar = tqdm_async(
total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
)
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["vector"] = embeddings[i]
results = self._client.upsert(collection_name=self.namespace, data=list_data)
return results
KV storage (text_chunks):
Stores the original chunk text
Keeps the chunk metadata
Graph storage (chunk_entity_relation_graph):
Stores the relationships derived from the chunks
Builds the knowledge graph
6. Usage suggestions
Chunk-size selection:
python
rag = LightRAG(
    chunk_token_size=512,          # smaller chunks for smaller documents
    chunk_overlap_token_size=64,   # a modest overlap between chunks
    embedding_func=your_func,
    llm_model_func=your_llm
)
Custom chunking strategies (see the sketch after this list):
For documents with special formats, preprocess them first
You can also split along natural boundaries such as paragraphs or sections
Performance optimization:
Batch requests to reduce API calls
Use asynchronous processing for higher throughput
Configure caching sensibly
Monitoring:
Track chunking quality
Monitor embedding quality
Evaluate retrieval accuracy
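Here is a minimal sketch of such a paragraph-aware strategy. It assumes you produce the same chunk dictionaries ({"tokens", "content", "chunk_order_index"}) as chunking_by_token_size and wire them into your own insertion code; it is an illustration, not a drop-in replacement for LightRAG's built-in chunker, and the tokenizer model name must be one your tiktoken version supports.
python
import tiktoken


def chunk_by_paragraph(content: str, max_token_size: int = 1024, model: str = "gpt-4o") -> list[dict]:
    """Pack whole paragraphs into chunks of at most max_token_size tokens."""
    enc = tiktoken.encoding_for_model(model)
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]

    results: list[dict] = []
    current: list[str] = []
    current_tokens = 0
    for para in paragraphs:
        n_tokens = len(enc.encode(para))
        # Start a new chunk if adding this paragraph would exceed the budget.
        # (A single paragraph longer than the budget becomes its own oversized chunk.)
        if current and current_tokens + n_tokens > max_token_size:
            results.append({
                "tokens": current_tokens,
                "content": "\n\n".join(current),
                "chunk_order_index": len(results),
            })
            current, current_tokens = [], 0
        current.append(para)
        current_tokens += n_tokens
    if current:
        results.append({
            "tokens": current_tokens,
            "content": "\n\n".join(current),
            "chunk_order_index": len(results),
        })
    return results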
7. A complete example
A complete document-processing example:
python
import asyncio

from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc

# 1. Initialize the RAG instance
rag = LightRAG(
    working_dir="./test",
    chunk_token_size=512,
    chunk_overlap_token_size=64,
    llm_model_func=your_llm_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=512,
        func=your_embedding_func
    )
)


async def process_doc():
    # 2. Insert a document
    with open("your_doc.txt", "r") as f:
        content = f.read()
    await rag.ainsert(content)
    # 3. Test retrieval
    result = await rag.aquery(
        "Your question?",
        param=QueryParam(mode="hybrid")
    )
    print(result)


asyncio.run(process_doc())
A well-chosen chunking strategy significantly improves a RAG system's retrieval quality.