大模型做导师之开源项目学习(lightRAG)

Q:我是一个初级程序员,请你帮助我学习这个项目的架构设计

好的,我来帮你分析这个项目的架构设计。这是一个基于 RAG(检索增强生成)的知识库问答系统。让我们从整体架构开始分析:

1. 核心架构组件

存储层抽象

定义了三个基础存储接口:

python 复制代码
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    embedding_func: EmbeddingFunc

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """return un-exist keys"""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError
  • KV存储(BaseKVStorage):用于存储文档和文本块
  • 向量存储(BaseVectorStorage):用于存储和检索向量化的内容
  • 图存储(BaseGraphStorage):用于存储实体和关系

存储实现

支持多种存储后端:

  • Oracle
  • TiDB
  • MongoDB
  • Milvus
  • 本地JSON文件存储等

LLM 集成

支持多个大语言模型:

python 复制代码
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    base_url=None,
    api_key=None,
    **kwargs,
) -> str:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # 添加日志输出
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")
    logger.debug("Full context:")
    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
        )
    else:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    if hasattr(response, "__aiter__"):

        async def inner():
            async for chunk in response:
                content = chunk.choices[0].delta.content
                if content is None:
                    continue
                if r"\u" in content:
                    content = safe_unicode_decode(content.encode("utf-8"))
                yield content

        return inner()
    else:
        content = response.choices[0].message.content
        if r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))
        return content

OpenAI

Zhipu AI

Ollama 等

2. 核心功能流程

文档处理流程:

文档分块

向量化

实体抽取

构建知识图谱

查询模式:

python 复制代码
class QueryParam:
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
    # top_n: int = 10
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions

naive: 简单文本匹配

local: 基于向量相似度的局部检索

global: 基于知识图谱的全局检索

hybrid: 混合检索模式

3. 可扩展性设计

模块化设计:

存储层可以轻松扩展新的存储后端

LLM层支持集成不同的模型

查询模式可以自定义扩展

异步处理:

`

  1. 异步处理:
python 复制代码
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
    """Add restriction of maximum async calling times for a async func"""

    def final_decro(func):
        """Not using async.Semaphore to aovid use nest-asyncio"""
        __current_size = 0

        @wraps(func)
        async def wait_func(*args, **kwargs):
            nonlocal __current_size
            while __current_size >= max_size:
                await asyncio.sleep(waitting_time)
            __current_size += 1
            result = await func(*args, **kwargs)
            __current_size -= 1
            return result

        return wait_func

    return final_decro

使用 asyncio 实现异步操作,提高性能

4. API 设计

提供了多种接口形式:

python 复制代码
class QueryRequest(BaseModel):
    query: str
    mode: str = "hybrid"
    only_need_context: bool = False
    only_need_prompt: bool = False

REST API

Python SDK

命令行接口

5. 配置管理

采用分层配置:

全局配置

存储配置

模型配置

运行时配置

实用建议

  • 如果你想学习这个项目:
    先从基础的 JSON 存储实现开始
    理解核心的文档处理流程
    然后再研究更复杂的存储后端实现
  • 如果你想使用这个项目:
    从简单的 demo 开始
    根据需求选择合适的存储后端
    选择合适的 LLM 模型
  • 如果你想贡献代码:
    了解项目的接口设计
    遵循现有的代码风格
    添加完整的测试用例

这个项目的架构设计非常清晰,采用了很多优秀的设计模式,值得学习。你有任何具体的问题都可以继续问我。

Q: 我想学习这个项目,请你带领我开始学习

好的,让我带你从基础开始学习这个项目。我们先从一个简单的示例开始:

1. 基本概念

LightRAG 是一个 RAG(检索增强生成)系统,主要用于:

文档的存储和检索

知识图谱构建

智能问答

2. 快速开始

让我们看一个最简单的示例,使用智谱AI的模型:

python 复制代码
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_complete,
    llm_model_name="glm-4-flashx",  # Using the most cost/performance balance model, but you can change it here.
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    embedding_func=EmbeddingFunc(
        embedding_dim=2048,  # Zhipu embedding-3 dimension
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts),
    ),
)

这个示例展示了 LightRAG 的基本用法:

初始化配置:

working_dir: 工作目录

llm_model_func: 使用的大语言模型

embedding_func: 文本向量化函数

基本使用流程:

python 复制代码
# 1. 创建 RAG 实例
rag = LightRAG(...)
# 2. 插入文档
rag.insert("你的文档内容")
# 3. 查询
rag.query("你的问题", param=QueryParam(mode="hybrid"))

3. 核心组件

LightRAG 有三个核心存储组件:

python 复制代码
@dataclass
class BaseVectorStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use 'content' field from value for embedding, use key as id.
        If embedding_func is None, use 'embedding' field from value
        """
        raise NotImplementedError


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    embedding_func: EmbeddingFunc

    async def all_keys(self) -> list[str]:
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """return un-exist keys"""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        raise NotImplementedError

    async def drop(self):
        raise NotImplementedError

BaseVectorStorage: 向量存储

用于存储文档的向量表示

支持相似度检索

BaseKVStorage: KV存储

用于存储原始文档和分块

支持按ID检索

BaseGraphStorage: 图存储

用于存储实体和关系

支持知识图谱查询

4. 查询模式

支持多种查询模式:

python 复制代码
class QueryParam:
    mode: Literal["local", "global", "hybrid", "naive", "mix"] = "global"
    only_need_context: bool = False
    only_need_prompt: bool = False
    response_type: str = "Multiple Paragraphs"
    stream: bool = False
    # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
    top_k: int = 60
    # Number of document chunks to retrieve.
    # top_n: int = 10
    # Number of tokens for the original chunks.
    max_token_for_text_unit: int = 4000
    # Number of tokens for the relationship descriptions
    max_token_for_global_context: int = 4000
    # Number of tokens for the entity descriptions
    max_token_for_local_context: int = 4000

naive: 简单文本匹配

local: 基于向量相似度的局部检索

global: 基于知识图谱的全局检索

hybrid: 混合检索模式

5. 实践练习

让我们从一个简单的例子开始:

首先安装依赖:

创建测试文件 test.py:

python 复制代码
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc

# 初始化 RAG
rag = LightRAG(
    working_dir="./test_rag",
    llm_model_func=openai_complete_if_cache,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=8192,
        func=openai_embedding
    )
)

# 插入文档
with open("your_doc.txt", "r") as f:
    rag.insert(f.read())

# 查询
result = rag.query(
    "你的问题?",
    param=QueryParam(mode="hybrid")
)
print(result)

下一步学习建议

  1. 先熟悉基本用法:
    尝试不同的查询模式
    理解文档处理流程
  2. 深入存储实现:
    从简单的 JSON 存储开始
    逐步尝试其他存储后端
  3. 探索高级特性:
    自定义 embedding 模型
    优化检索参数
    集成不同的 LLM

你想先深入了解哪个部分?我可以为你详细讲解。

Q:我希望理解文档处理流程

让我帮你理解 LightRAG 的文档处理流程。这个过程主要包含以下几个步骤:

1. 文档插入流程

当你调用 rag.insert() 或 rag.ainsert() 时,会经过以下处理步骤:

  1. 文档分块 (Chunking)
python 复制代码
async def kg_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # Set mode
    if query_param.mode not in ["local", "global", "hybrid"]:
        logger.error(f"Unknown mode {query_param.mode} in kg_query")
        return PROMPTS["fail_response"]

    # LLM generate keywords
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
    result = await use_model_func(kw_prompt, keyword_extraction=True)
    logger.info("kw_prompt result:")

文档会被切分成较小的块,使用 chunking_by_token_size 函数处理:

控制每块大小 (chunk_token_size)

设置重叠部分 (chunk_overlap_token_size)

生成唯一ID

保存原文引用关系

向量化处理

每个文本块会被转换为向量:

python 复制代码
    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = embeddings[i]
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results

主要步骤:

批量处理文本

调用 embedding 函数生成向量

存储向量和元数据

实体抽取和关系构建

这个过程包括:

python 复制代码
    async def ainsert_custom_kg(self, custom_kg: dict):
        update_storage = False
        try:
            # Insert chunks into vector storage
            all_chunks_data = {}
            chunk_to_source_map = {}
            for chunk_data in custom_kg.get("chunks", []):
                chunk_content = chunk_data["content"]
                source_id = chunk_data["source_id"]
                chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")

                chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
                all_chunks_data[chunk_id] = chunk_entry
                chunk_to_source_map[source_id] = chunk_id
                update_storage = True

            if self.chunks_vdb is not None and all_chunks_data:
                await self.chunks_vdb.upsert(all_chunks_data)
            if self.text_chunks is not None and all_chunks_data:
                await self.text_chunks.upsert(all_chunks_data)

            # Insert entities into knowledge graph
            all_entities_data = []

                ...

                edge_data = {
                    "src_id": src_id,
                    "tgt_id": tgt_id,
                    "description": description,
                    "keywords": keywords,
                }
                all_relationships_data.append(edge_data)
                update_storage = True

            # Insert entities into vector storage if needed
            if self.entities_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
                        "content": dp["entity_name"] + dp["description"],
                        "entity_name": dp["entity_name"],
                    }
                    for dp in all_entities_data
                }
                await self.entities_vdb.upsert(data_for_vdb)

            # Insert relationships into vector storage if needed
            if self.relationships_vdb is not None:
                data_for_vdb = {
                    compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
                        "src_id": dp["src_id"],
                        "tgt_id": dp["tgt_id"],
                        "content": dp["keywords"]
                        + dp["src_id"]
                        + dp["tgt_id"]
                        + dp["description"],
                    }
                    for dp in all_relationships_data
                }
                await self.relationships_vdb.upsert(data_for_vdb)

抽取实体和关系

构建知识图谱

存储实体和关系的向量表示

2. 存储架构

系统使用三种存储:

  1. KV存储 - 存储原始文本:
python 复制代码
@dataclass
class JsonKVStorage(BaseKVStorage):
    def __post_init__(self):
        working_dir = self.global_config["working_dir"]
        self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
        self._data = load_json(self._file_name) or {}
        logger.info(f"Load KV {self.namespace} with {len(self._data)} data")

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def index_done_callback(self):
        write_json(self._data, self._file_name)

    async def get_by_id(self, id):
        return self._data.get(id, None)

    async def get_by_ids(self, ids, fields=None):
        if fields is None:
            return [self._data.get(id, None) for id in ids]
        return [
            (
                {k: v for k, v in self._data[id].items() if k in fields}
                if self._data.get(id, None)
                else None
            )
            for id in ids
        ]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return set([s for s in data if s not in self._data])

    async def upsert(self, data: dict[str, dict]):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        return left_data

    async def drop(self):
        self._data = {}
  1. 向量存储 - 存储向量化后的内容:
python 复制代码
@dataclass
class BaseVectorStorage(StorageNameSpace):
    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Use 'content' field from value for embedding, use key as id.
        If embedding_func is None, use 'embedding' field from value
        """
        raise NotImplementedError
  1. 图存储 - 存储知识图谱:
python 复制代码
class OracleVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        pass

    async def upsert(self, data: dict[str, dict]):
        """向向量数据库中插入数据"""
        pass

    async def index_done_callback(self):
        pass

    #################### query method ###############
    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
        """从向量数据库中查询数据"""
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
        # 转换精度
        dtype = str(embedding.dtype).upper()
        dimension = embedding.shape[0]
        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
        params = {
            "embedding_string": embedding_string,
            "workspace": self.db.workspace,
            "top_k": top_k,
            "better_than_threshold": self.cosine_better_than_threshold,
        }
        # print(SQL)
        results = await self.db.query(SQL, params=params, multirows=True)
        # print("vector search result:",results)
        return results

3. 查询处理流程

根据不同的查询模式:

naive 模式:

python 复制代码
async def naive_query(
    query,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
):
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    results = await chunks_vdb.query(query, top_k=query_param.top_k)
    if not len(results):
        return PROMPTS["fail_response"]

    chunks_ids = [r["id"] for r in results]
    chunks = await text_chunks_db.get_by_ids(chunks_ids)

    # Filter out invalid chunks
    valid_chunks = [
        chunk for chunk in chunks if chunk is not None and "content" in chunk
    ]

    if not valid_chunks:
        logger.warning("No valid chunks found after filtering")
        return PROMPTS["fail_response"]

    maybe_trun_chunks = truncate_list_by_token_size(
        valid_chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.max_token_for_text_unit,
    )

    if not maybe_trun_chunks:
        logger.warning("No chunks left after truncation")
        return PROMPTS["fail_response"]

    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    section = "\n--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])

    if query_param.only_need_context:
        return section

    sys_prompt_temp = PROMPTS["naive_rag_response"]
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )

    if query_param.only_need_prompt:
        return sys_prompt

    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
    )

直接通过向量相似度检索

返回相关文本块

kg_query 模式:

python 复制代码
async def kg_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    # Handle cache
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash(query_param.mode, query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, query_param.mode
    )
    if cached_response is not None:
        return cached_response

    example_number = global_config["addon_params"].get("example_number", None)
    if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
        examples = "\n".join(
            PROMPTS["keywords_extraction_examples"][: int(example_number)]
        )
    else:
        examples = "\n".join(PROMPTS["keywords_extraction_examples"])
    language = global_config["addon_params"].get(
        "language", PROMPTS["DEFAULT_LANGUAGE"]
    )

    # Set mode
    if query_param.mode not in ["local", "global", "hybrid"]:
        logger.error(f"Unknown mode {query_param.mode} in kg_query")
        return PROMPTS["fail_response"]

    # LLM generate keywords
    kw_prompt_temp = PROMPTS["keywords_extraction"]
    kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
    result = await use_model_func(kw_prompt, keyword_extraction=True)
    logger.info("kw_prompt result:")
    print(result)

提取查询关键词

检索相关实体和关系

组合上下文信息

mix 模式:

python 复制代码
async def mix_kg_vector_query(
    query,
    knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
    relationships_vdb: BaseVectorStorage,
    chunks_vdb: BaseVectorStorage,
    text_chunks_db: BaseKVStorage[TextChunkSchema],
    query_param: QueryParam,
    global_config: dict,
    hashing_kv: BaseKVStorage = None,
) -> str:
    """
    Hybrid retrieval implementation combining knowledge graph and vector search.

    This function performs a hybrid search by:
    1. Extracting semantic information from knowledge graph
    2. Retrieving relevant text chunks through vector similarity
    3. Combining both results for comprehensive answer generation
    """
    # 1. Cache handling
    use_model_func = global_config["llm_model_func"]
    args_hash = compute_args_hash("mix", query)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, query, "mix"
    )
    if cached_response is not None:
        return cached_response

结合向量检索和图检索

融合多种相关信息

实践建议

开始使用时,可以先尝试最简单的配置:

python 复制代码
1. 开始使用时,可以先尝试最简单的配置:
```python
from lightrag import LightRAG, QueryParam

# 初始化
rag = LightRAG(
    working_dir="./test_dir",
    chunk_token_size=512,  # 较小的块大小,便于测试
    embedding_func=your_embedding_func,
    llm_model_func=your_llm_func
)

# 插入文档
await rag.ainsert("你的文档内容")

# 测试不同查询模式
modes = ["naive", "local", "global", "hybrid"]
for mode in modes:
    result = await rag.aquery(
        "你的问题?",
        param=QueryParam(mode=mode)
    )
    print(f"Mode {mode}:", result)

注意事项:

  • 合理设置分块大小
  • 选择适合的 embedding 模型
  • 根据数据量选择合适的存储后端
  • 监控向量化和实体抽取的质量

你想深入了解哪个具体环节?我可以为你详细解释。

Q:详细讲解文本分块

好的,让我详细讲解 LightRAG 中的文本分块(Chunking)机制。

  1. 分块的目的
    文本分块是 RAG 系统的重要预处理步骤,主要有以下目的:
    将长文档切分成适合 LLM 处理的小块
    保持语义完整性
    便于向量化和检索
  2. 分块实现
    LightRAG 使用 chunking_by_token_size 函数来实现分块:
python 复制代码
def chunking_by_token_size(
    content: str, 
    overlap_token_size=128,  # 重叠token数
    max_token_size=1024,     # 最大token数
    tiktoken_model="gpt-4o"  # 使用的分词模型
):
    # 1. 将文本转换为 tokens
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    
    results = []
    # 2. 按照滑动窗口进行分块
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        # 3. 将 tokens 转回文本
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], 
            model_name=tiktoken_model
        )
        
        # 4. 构建分块信息
        results.append({
            "tokens": min(max_token_size, len(tokens) - start),
            "content": chunk_content.strip(),
            "chunk_order_index": index,
        })
    return results
  1. 关键参数说明
    overlap_token_size (重叠大小)
    默认值: 128 tokens
    作用: 确保分块之间有重叠,避免切断语义
    建议: 根据文档特点调整,通常设置为 max_token_size 的 10%-20%
    max_token_size (最大分块大小)
    默认值: 1024 tokens
    作用: 控制单个分块的最大长度
    建议: 根据 LLM 的上下文窗口大小设置
    tiktoken_model (分词模型)
    默认值: "gpt-4o"
    作用: 指定使用哪个模型的分词器
    建议: 与最终使用的 LLM 模型保持一致
  2. 分块处理流程
    在 LightRAG 中,分块是插入文档时的第一步处理:
python 复制代码
    async def ainsert(self, string_or_strings):
        """Insert documents with checkpoint support

        ...

        for i in range(0, len(new_docs), batch_size):
            batch_docs = dict(list(new_docs.items())[i : i + batch_size])

            for doc_id, doc in tqdm_async(
                batch_docs.items(), desc=f"Processing batch {i//batch_size + 1}"
            ):
                try:
                    # Update status to processing
                    doc_status = {
                        "content_summary": doc["content_summary"],
                        "content_length": doc["content_length"],
                        "status": DocStatus.PROCESSING,
                        "created_at": doc["created_at"],
                        "updated_at": datetime.now().isoformat(),
                    }
                    await self.doc_status.upsert({doc_id: doc_status})

                    # Generate chunks from document
                    chunks = {
                        compute_mdhash_id(dp["content"], prefix="chunk-"): {
                            **dp,
                            "full_doc_id": doc_id,
                        }
                        for dp in chunking_by_token_size(
                            doc["content"],
                            overlap_token_size=self.chunk_overlap_token_size,
                            max_token_size=self.chunk_token_size,
                            tiktoken_model=self.tiktoken_model_name,
                        )
                    }

                    # Update status with chunks information
                    doc_status.update(
                        {
                            "chunks_count": len(chunks),
                            "updated_at": datetime.now().isoformat(),
                        }
                    )


                    try:
                        # Store chunks in vector database
                        await self.chunks_vdb.upsert(chunks)

                        # Extract and store entities and relationships
                        maybe_new_kg = await extract_entities(
                            chunks,
                            knowledge_graph_inst=self.chunk_entity_relation_graph,
                            entity_vdb=self.entities_vdb,
                            relationships_vdb=self.relationships_vdb,
                            global_config=asdict(self),
                        )

                        if maybe_new_kg is None:
                            raise Exception(
                                "Failed to extract entities and relationships"
                            )

                        self.chunk_entity_relation_graph = maybe_new_kg

                        # Store original document and chunks
                        await self.full_docs.upsert(
                            {doc_id: {"content": doc["content"]}}
                        )
                        await self.text_chunks.upsert(chunks)

主要步骤:

文档分块

向量化存储

实体抽取

关系构建

存储原始文档和分块

  1. 存储设计

分块后的内容会被存储到多个位置:

向量数据库 (chunks_vdb):

python 复制代码
1. **向量数据库** (chunks_vdb):

```42:75:lightrag/kg/milvus_impl.py
    async def upsert(self, data: dict[str, dict]):
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["vector"] = embeddings[i]
        results = self._client.upsert(collection_name=self.namespace, data=list_data)
        return results
KV存储 (text_chunks):
存储原始文本内容
保存分块元信息
图存储 (chunk_entity_relation_graph):
存储分块之间的关系
构建知识图谱
6. 使用建议
分块大小选择:

```javascript
rag = LightRAG(
    chunk_token_size=512,      # 较小文档用小块
    chunk_overlap_size=64,     # 设置适当重叠
    embedding_func=your_func,
    llm_model_func=your_llm
)

自定义分块策略:

对于特殊格式文档,可以先进行预处理

可以按段落或章节等自然边界分块

性能优化:

使用批量处理减少 API 调用

启用异步处理提高效率

合理设置缓存策略

监控建议:

跟踪分块质量

监控向量化效果

评估检索准确性

  1. 实践示例

一个完整的文档处理示例:

javascript 复制代码
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc

# 1. 初始化 RAG
rag = LightRAG(
    working_dir="./test",
    chunk_token_size=512,
    chunk_overlap_size=64,
    embedding_func=EmbeddingFunc(
        embedding_dim=1536,
        max_token_size=512,
        func=your_embedding_func
    )
)

# 2. 插入文档
async def process_doc():
    with open("your_doc.txt", "r") as f:
        content = f.read()
        await rag.ainsert(content)
        
    # 3. 测试检索效果
    result = await rag.aquery(
        "你的问题?",
        param=QueryParam(mode="hybrid")
    )
    print(result)

通过合理的分块策略,可以显著提升 RAG 系统的检索效果。

相关推荐
leegong231112 小时前
学习PostgreSQL专家认证
数据库·学习·postgresql
Moonnnn.3 小时前
51单片机学习——动态数码管显示
笔记·嵌入式硬件·学习·51单片机
南宫生3 小时前
力扣每日一题【算法学习day.132】
java·学习·算法·leetcode
山河已无恙3 小时前
基于 DeepSeek LLM 本地知识库搭建开源方案(AnythingLLM、Cherry、Ragflow、Dify)认知
开源·知识库·deepseek
技术小齐4 小时前
网络运维学习笔记 016网工初级(HCIA-Datacom与CCNA-EI)PPP点对点协议和PPPoE以太网上的点对点协议(此处只讲华为)
运维·网络·学习
竹言笙熙4 小时前
代码审计初探
学习·web安全
日记成书4 小时前
物联网智能项目
物联网·学习
虾球xz4 小时前
游戏引擎学习第118天
学习·游戏引擎
AI服务老曹5 小时前
运用先进的智能算法和优化模型,进行科学合理调度的智慧园区开源了
运维·人工智能·安全·开源·音视频