手写 GraphRAG:从零实现图增强检索增强生成系统

一、引言

1.1 为什么需要 GraphRAG?

2024 年 7 月,微软开源了 GraphRAG(Graph-based Retrieval Augmented Generation),在 AI 社区引起了巨大反响。这个项目的核心理念直击传统 RAG 的致命弱点:向量检索只能做语义匹配,做不了关系推理。

想想看,当你问"乔布斯创立的公司目前市值最高的产品是什么"时,传统 RAG 会怎么做?它会将问题向量化,然后去向量数据库里找语义相似的内容片段。但如果知识库中"乔布斯创立苹果公司"和"苹果公司最新产品 iPhone 17 售价 999 美元"这两个信息分散在不同文档中,向量检索就很难把它们串联起来。

这就是 Multi-hop 推理 难题------需要跨越多个信息片段才能回答的问题。

GraphRAG 的解决方案很优雅:先用 LLM 从文档中构建知识图谱(实体 + 关系),然后在图上进行结构化检索,最后将图结构信息送入 LLM 生成答案。 这种方式让 AI 拥有了"阅读理解 + 关系推理"的双重能力。

1.2 本文目标

读完本文,你将能够:

  • 理解 GraphRAG 的核心原理与架构设计
  • 从零实现知识图谱构建模块(实体抽取、关系识别)
  • 实现图检索引擎(邻域扩展、PageRank、混合检索)
  • 构建 Multi-hop 推理查询引擎
  • 完整对比传统 RAG 与 GraphRAG 的效果差异
  • 掌握 GraphRAG 的进阶优化方向(社区检测、增量更新)

完整代码已整理在文末,可直接运行验证。

1.3 技术栈

  • Python 3.10+:核心开发语言
  • NetworkX:图数据结构与算法
  • OpenAI / DeepSeek API:LLM 调用(可替换任何兼容 API)
  • NumPy:向量运算
  • scikit-learn:向量相似度计算

二、系统架构总览

GraphRAG 整体分为四个核心模块:

复制代码
用户查询
    │
    ▼
┌─────────────────────┐
│   查询引擎 (QE)      │  ← 解析查询意图、提取实体
│    Query Engine      │
└─────────┬───────────┘
          │
┌─────────▼───────────┐
│   图检索器 (GR)      │  ← 实体匹配、邻域扩展、路径检索
│  Graph Retriever     │
└─────────┬───────────┘
          │
┌─────────▼───────────┐
│  知识图谱 (KG)       │  ← 图数据结构、关系存储
│  Knowledge Graph    │
└─────────┬───────────┘
          │
┌─────────▼───────────┐
│  图构建器 (GB)       │  ← 实体抽取、关系识别、图谱构建
│  Graph Builder      │
└─────────────────────┘
          ▲
          │
    原始文档

四个模块各司其职:

模块 职责 关键技术
图构建器 从文档提取实体和关系 LLM 调用 + 结构化解析
知识图谱 存储和操作图数据 NetworkX 图结构
图检索器 基于图结构召回相关上下文 邻域扩展、PageRank
查询引擎 将查询转换为图操作并生成答案 LLM + 图检索结果融合

下面我们逐个模块从零实现。


三、图构建器:从文档到知识图谱

图构建器是整个 GraphRAG 系统的数据入口。它的任务是从原始文档中提取出结构化的实体和关系,构建成知识图谱。

3.1 LLM 客户端封装

首先,我们需要一个通用 LLM 客户端,用于所有与模型交互的地方:

复制代码
import json
import re
from typing import Optional
import requests


class LLMClient:
    """通用 LLM 客户端,支持 OpenAI/DeepSeek 兼容接口"""

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.deepseek.com",
        model: str = "deepseek-chat",
        temperature: float = 0.1,
    ):
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.temperature = temperature

    def chat(
        self,
        messages: list[dict],
        response_format: Optional[dict] = None,
        max_tokens: int = 4096,
    ) -> str:
        """调用 LLM 完成对话"""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": max_tokens,
        }
        if response_format:
            payload["response_format"] = response_format

        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        return resp.json()["choices"][0]["message"]["content"]

3.2 实体抽取

实体抽取是构建知识图谱的第一步。我们需要让 LLM 从文档中识别出有意义的命名实体,并为它们分配类型和简要描述。

让我们设计一套 Prompt,输出结构化的 JSON:

复制代码
ENTITY_EXTRACTION_PROMPT = """你是一个知识图谱构建助手。请分析以下文档,提取所有重要的命名实体。

要求:
1. 每个实体必须包含:名称(name)、类型(type)、简要描述(description)
2. 类型包括:person(人物)、organization(组织)、product(产品)、technology(技术)、concept(概念)、location(地点)、event(事件)、document(文档)
3. 描述控制在 20 字以内
4. 只提取文档中明确提及的实体,不要臆造
5. 同一实体只出现一次

输出格式为 JSON 数组:
[
  {{"name": "实体名称", "type": "实体类型", "description": "简要描述"}}
]

文档内容:
{document_text}"""


def extract_entities(llm: LLMClient, document_text: str) -> list[dict]:
    """从文档中提取实体列表"""
    messages = [
        {
            "role": "system",
            "content": "你是专业的知识图谱构建专家,精确提取实体。",
        },
        {"role": "user", "content": ENTITY_EXTRACTION_PROMPT.format(
            document_text=document_text
        )},
    ]

    response = llm.chat(messages, response_format={"type": "json_object"})

    try:
        # 尝试直接解析完整 JSON
        entities = json.loads(response)
        if isinstance(entities, list):
            return entities
    except json.JSONDecodeError:
        pass

    # 如果直接解析失败,尝试从响应中提取 JSON 数组
    try:
        # 用正则提取第一个 [ 到最后一个 ]
        json_match = re.search(r"\[.*\]", response, re.DOTALL)
        if json_match:
            entities = json.loads(json_match.group())
            return entities if isinstance(entities, list) else []
    except (json.JSONDecodeError, AttributeError):
        pass

    return []

3.3 关系抽取

提取实体之后,我们需要识别实体之间的关系。这同样通过 LLM 完成:

复制代码
RELATION_EXTRACTION_PROMPT = """你是一个知识图谱构建助手。请分析以下文档,识别文档中实体之间的关系。

已有实体列表(只使用这些实体):
{entities_json}

要求:
1. 每个关系必须包含:source(源实体名称)、target(目标实体名称)、relation(关系类型)、description(关系描述)
2. 关系类型使用动词短语,如"创建"、"任职于"、"研发"、"收购"、"位于"等
3. source 和 target 必须来自上方提供的实体列表
4. 关系必须有明确的文档依据
5. 单向关系(source → target),不要重复

输出格式为 JSON 数组:
[
  {{"source": "源实体", "target": "目标实体", "relation": "关系类型", "description": "简要描述"}}
]

文档内容:
{document_text}"""


def extract_relations(
    llm: LLMClient, document_text: str, entities: list[dict]
) -> list[dict]:
    """从文档中提取实体间关系"""
    entities_json = json.dumps(
        [e["name"] for e in entities], ensure_ascii=False
    )
    messages = [
        {
            "role": "system",
            "content": "你是专业的知识图谱构建专家,精确提取实体间关系。",
        },
        {
            "role": "user",
            "content": RELATION_EXTRACTION_PROMPT.format(
                entities_json=entities_json,
                document_text=document_text,
            ),
        },
    ]

    response = llm.chat(messages, response_format={"type": "json_object"})

    try:
        relations = json.loads(response)
        if isinstance(relations, list):
            return relations
    except json.JSONDecodeError:
        pass

    try:
        json_match = re.search(r"\[.*\]", response, re.DOTALL)
        if json_match:
            relations = json.loads(json_match.group())
            return relations if isinstance(relations, list) else []
    except (json.JSONDecodeError, AttributeError):
        pass

    return []

3.4 图谱构建器封装

有了实体和关系抽取函数,我们可以把它们组合成一个 GraphBuilder:

复制代码
class GraphBuilder:
    """图构建器:从文档生成知识图谱"""

    def __init__(self, llm: LLMClient):
        self.llm = llm

    def build_from_document(
        self, doc_id: str, document_text: str
    ) -> tuple[list[dict], list[dict]]:
        """从单个文档中提取实体和关系"""
        print(f"  [构建] 正在从文档 {doc_id} 中提取实体...")
        entities = extract_entities(self.llm, document_text)
        print(f"  [构建] 提取到 {len(entities)} 个实体")

        if entities:
            print(f"  [构建] 正在抽取实体关系...")
            relations = extract_relations(self.llm, document_text, entities)
            print(f"  [构建] 抽取出 {len(relations)} 条关系")
        else:
            relations = []

        return entities, relations

    def build_from_documents(
        self, documents: dict[str, str]
    ) -> tuple[list[dict], list[dict]]:
        """批量处理多个文档"""
        all_entities = []
        all_relations = []

        for doc_id, doc_text in documents.items():
            entities, relations = self.build_from_document(doc_id, doc_text)
            all_entities.extend(entities)
            all_relations.extend(relations)

        # 去重:合并相同名称的实体
        seen = set()
        deduped = []
        for e in all_entities:
            name = e["name"]
            if name not in seen:
                seen.add(name)
                deduped.append(e)

        return deduped, all_relations

四、知识图谱:图数据结构与存储

提取出实体和关系后,我们需要用图数据结构把它们组织起来。这里使用 NetworkX------Python 最流行的图分析库。

复制代码
import networkx as nx


class GraphStore:
    """知识图谱存储层"""

    def __init__(self):
        # 使用 NetworkX 有向图
        self.graph = nx.DiGraph()

    def build(self, entities: list[dict], relations: list[dict]):
        """用实体和关系构建图"""
        # 添加节点(实体)
        for entity in entities:
            self.graph.add_node(
                entity["name"],
                type=entity.get("type", "unknown"),
                description=entity.get("description", ""),
            )

        # 添加边(关系)
        for rel in relations:
            source = rel.get("source", "")
            target = rel.get("target", "")
            if source in self.graph and target in self.graph:
                self.graph.add_edge(
                    source,
                    target,
                    relation=rel.get("relation", "related_to"),
                    description=rel.get("description", ""),
                )

        print(f"[图谱] 构建完成: {self.graph.number_of_nodes()} 个节点, "
              f"{self.graph.number_of_edges()} 条边")

    def get_node(self, name: str) -> dict | None:
        """获取单个节点信息"""
        if name not in self.graph:
            return None
        return {
            "name": name,
            **dict(self.graph.nodes[name]),
        }

    def get_neighbors(
        self, name: str, max_depth: int = 1
    ) -> list[dict]:
        """获取指定节点的邻域(邻居节点 + 关系)"""
        if name not in self.graph:
            return []

        # 使用 BFS 获取邻域
        visited = {name}
        queue = [(name, 0)]
        neighbors = []

        while queue:
            current, depth = queue.pop(0)
            if depth >= max_depth:
                continue

            # 前向邻接(出边)
            for neighbor in self.graph.successors(current):
                if neighbor not in visited:
                    visited.add(neighbor)
                    edge_data = self.graph.get_edge_data(current, neighbor)
                    neighbors.append({
                        "source": current,
                        "target": neighbor,
                        "relation": edge_data.get("relation", "related_to"),
                        "direction": "forward",
                    })
                    queue.append((neighbor, depth + 1))

            # 后向邻接(入边)
            for neighbor in self.graph.predecessors(current):
                if neighbor not in visited:
                    visited.add(neighbor)
                    edge_data = self.graph.get_edge_data(neighbor, current)
                    neighbors.append({
                        "source": neighbor,
                        "target": current,
                        "relation": edge_data.get("relation", "related_to"),
                        "direction": "backward",
                    })
                    queue.append((neighbor, depth + 1))

        return neighbors

    def search_entities(self, keyword: str) -> list[dict]:
        """按关键词搜索实体节点"""
        results = []
        keyword_lower = keyword.lower()
        for node, attrs in self.graph.nodes(data=True):
            if keyword_lower in node.lower():
                results.append({"name": node, **attrs})
            elif keyword_lower in attrs.get("description", "").lower():
                results.append({"name": node, **attrs})
        return results

    def find_path(self, source: str, target: str) -> list[str]:
        """查找两个实体之间的最短路径(可用于解释推理过程)"""
        try:
            return nx.shortest_path(self.graph, source=source, target=target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []

    def compute_pagerank(self, personalized: dict[str, float] | None = None):
        """计算 PageRank 值,用于实体重要性排序"""
        if personalized:
            return nx.pagerank(
                self.graph, personalization=personalized, alpha=0.85
            )
        return nx.pagerank(self.graph, alpha=0.85)

    def to_text(self, max_nodes: int = 50) -> str:
        """将图转换为 LLM 可读的文本描述"""
        lines = ["以下是从文档中提取的知识图谱:\n"]

        # 节点列表
        lines.append("--- 实体列表 ---")
        for node, attrs in list(self.graph.nodes(data=True))[:max_nodes]:
            lines.append(
                f"- {node} ({attrs.get('type', 'unknown')}): "
                f"{attrs.get('description', '')}"
            )

        lines.append("\n--- 关系列表 ---")
        for u, v, attrs in list(self.graph.edges(data=True))[:max_nodes]:
            lines.append(
                f"- {u} --[{attrs.get('relation', 'related_to')}]--> {v}"
            )

        return "\n".join(lines)

    def save(self, path: str):
        """序列化图数据到文件"""
        data = {
            "nodes": [
                {"name": n, **attrs}
                for n, attrs in self.graph.nodes(data=True)
            ],
            "edges": [
                {"source": u, "target": v, **attrs}
                for u, v, attrs in self.graph.edges(data=True)
            ],
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> "GraphStore":
        """从文件加载图数据"""
        store = cls()
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        store.build(data.get("nodes", []), data.get("edges", []))
        return store

GraphStore 提供了完整的图操作接口:节点和邻居查询、关键词搜索、最短路径分析、PageRank 排序,以及序列化/反序列化。

4.1 图可视化辅助

为了方便调试,我们可以把图导出为 DOT 格式,用 Graphviz 可视化:

复制代码
def export_to_dot(store: GraphStore, output_path: str):
    """导出为 DOT 格式,可用 Graphviz 可视化"""
    lines = ["digraph KnowledgeGraph {"]
    lines.append('  rankdir="LR";')
    lines.append('  node [shape="box", style="rounded"];')

    for node, attrs in store.graph.nodes(data=True):
        label = f"{node}\\n({attrs.get('type', '')})"
        lines.append(f'  "{node}" [label="{label}"];')

    for u, v, attrs in store.graph.edges(data=True):
        rel = attrs.get("relation", "related_to")
        lines.append(f'  "{u}" -> "{v}" [label="{rel}"];')

    lines.append("}")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

五、图检索器:从图到相关上下文

图构建好了,问题是:当用户提出一个查询时,如何从图中找到最相关的信息?

这里我们实现三种检索策略,以及它们的融合方案。

5.1 基础检索策略

复制代码
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


class GraphRetriever:
    """图检索器:从知识图谱中检索与查询相关的上下文"""

    def __init__(self, graph_store: GraphStore, llm: LLMClient):
        self.graph = graph_store
        self.llm = llm

    def retrieve_by_entity_match(
        self, query: str, top_k: int = 5
    ) -> list[dict]:
        """策略一:实体匹配检索
        从查询中提取实体关键词,匹配图谱中的节点
        """
        # 用 LLM 从查询中提取实体关键词
        messages = [
            {
                "role": "system",
                "content": "从用户问题中提取关键实体名称列表,只返回 JSON 数组。",
            },
            {
                "role": "user",
                "content": f"提取以下问题的关键实体:{query}",
            },
        ]
        response = self.llm.chat(messages)
        try:
            entities_in_query = json.loads(response)
        except json.JSONDecodeError:
            # 退化方案:用简单关键词匹配
            entities_in_query = query.replace("?", "").split()

        matched = []
        for entity_name in entities_in_query:
            results = self.graph.search_entities(entity_name)
            for r in results:
                if r not in matched:
                    matched.append(r)

        # 扩展匹配到的实体邻域
        expanded = []
        for entity in matched[:top_k]:
            neighbors = self.graph.get_neighbors(entity["name"], max_depth=1)
            expanded.append({
                "entity": entity,
                "neighbors": neighbors,
            })

        return expanded[:top_k]

    def retrieve_by_entity_extraction(
        self, query: str, max_depth: int = 2
    ) -> list[dict]:
        """
        策略二:实体识别 + 图扩展
        让 LLM 更精确地从查询中识别实体,然后扩展邻域
        """
        # 步骤 1: 用 LLM 精准识别查询中的实体
        extract_prompt = f"""分析以下问题,列出所有提到的实体名称。
问题:{query}
要求:只返回实体名称的 JSON 数组,如果没有找到返回空数组。"""

        messages = [
            {"role": "system", "content": "你是实体识别专家。"},
            {"role": "user", "content": extract_prompt},
        ]
        response = self.llm.chat(messages)

        try:
            entities = json.loads(response)
        except json.JSONDecodeError:
            entities = []

        # 步骤 2: 在图谱中查找这些实体,并扩展邻域
        context_nodes = set()
        context_edges = []

        for entity in entities:
            matches = self.graph.search_entities(entity)
            for match in matches:
                name = match["name"]
                context_nodes.add(name)
                neighbors = self.graph.get_neighbors(name, max_depth)
                for nb in neighbors:
                    context_nodes.add(nb["source"])
                    context_nodes.add(nb["target"])
                    context_edges.append(nb)

        # 步骤 3: 构建子图并转换为文本
        subgraph_text = self._build_subgraph_text(
            list(context_nodes), context_edges
        )
        return {
            "nodes": list(context_nodes),
            "edges": context_edges,
            "text": subgraph_text,
        }

    def retrieve_by_pagerank(
        self, query: str, top_k: int = 10
    ) -> list[dict]:
        """
        策略三:PageRank 排序检索
        先用 LLM 识别查询实体作为个性化 PageRank 的种子,
        然后返回排序最高的节点
        """
        # 识别查询中的实体作为个性化向量
        messages = [
            {
                "role": "system",
                "content": "从问题中提取关键实体,返回 JSON 数组。",
            },
            {"role": "user", "content": f"提取实体的问題:{query}"},
        ]
        response = self.llm.chat(messages)

        try:
            entities_in_query = json.loads(response)
        except json.JSONDecodeError:
            entities_in_query = []

        # 在图中查找这些实体
        seed_nodes = {}
        for entity in entities_in_query:
            matches = self.graph.search_entities(entity)
            for m in matches:
                seed_nodes[m["name"]] = 1.0

        if not seed_nodes:
            # 无种子节点,使用全局 PageRank
            pr = self.graph.compute_pagerank()
        else:
            pr = self.graph.compute_pagerank(
                personalized=seed_nodes
            )

        # 排序取 top_k
        sorted_nodes = sorted(
            pr.items(), key=lambda x: x[1], reverse=True
        )[:top_k]

        result = []
        for name, score in sorted_nodes:
            node_data = self.graph.get_node(name)
            if node_data:
                node_data["pagerank"] = round(score, 4)
                result.append(node_data)

        return result

    def _build_subgraph_text(
        self, nodes: list[str], edges: list[dict]
    ) -> str:
        """将子图转换为可读文本"""
        lines = ["知识图谱上下文(子图):\n"]

        lines.append("相关实体:")
        for n in nodes:
            info = self.graph.get_node(n)
            if info:
                lines.append(
                    f"- {n} ({info.get('type', 'unknown')}): "
                    f"{info.get('description', '')}"
                )

        lines.append("\n实体关系:")
        for e in edges:
            lines.append(
                f"- {e['source']} --[{e.get('relation', 'related_to')}]--> "
                f"{e['target']}"
            )

        return "\n".join(lines)

5.2 混合检索

单一策略各有局限。实体匹配可能漏掉相关概念,PageRank 可能偏向中心节点。最有效的方法是将多种策略融合:

复制代码
class HybridRetriever:
    """混合检索:融合多种检索策略"""

    def __init__(self, graph_retriever: GraphRetriever):
        self.gr = graph_retriever

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        include_subgraph: bool = True,
    ) -> dict:
        """多策略融合检索"""
        context_parts = []

        # 策略 1: 实体提取 + 图扩展(主要策略)
        result1 = self.gr.retrieve_by_entity_extraction(query, max_depth=2)
        if result1["text"]:
            context_parts.append(result1["text"])

        # 策略 2: PageRank 排序(补充策略)
        result2 = self.gr.retrieve_by_pagerank(query, top_k=top_k)
        if result2:
            pr_text = "PageRank 重要实体排序:\n"
            for item in result2:
                pr_text += (
                    f"- {item['name']} (重要性: {item.get('pagerank', 0):.4f}, "
                    f"类型: {item.get('type', '')})\n"
                )
            context_parts.append(pr_text)

        # 合并上下文
        combined = "\n\n".join(context_parts)

        return {
            "context": combined,
            "entities": result1.get("nodes", []),
            "edges": result1.get("edges", []),
        }

六、查询引擎:从图检索到答案生成

有了检索到的图上下文,最后一步是让 LLM 基于这些信息生成答案。

复制代码
class GraphQueryEngine:
    """图查询引擎:将自然语言查询转化为图操作并生成答案"""

    def __init__(
        self,
        graph_store: GraphStore,
        hybrid_retriever: HybridRetriever,
        llm: LLMClient,
    ):
        self.graph = graph_store
        self.retriever = hybrid_retriever
        self.llm = llm

    def query(
        self,
        query: str,
        show_retrieval: bool = False,
    ) -> dict:
        """执行完整查询流程"""
        # 步骤 1: 检索图上下文
        retrieval_result = self.retriever.retrieve(query)
        context = retrieval_result["context"]

        if show_retrieval:
            print("=" * 60)
            print("📥 检索到的图上下文:")
            print(context)
            print("=" * 60)

        # 步骤 2: 检查是否有检索到内容
        if not context.strip():
            return {
                "answer": "抱歉,知识图谱中未找到与查询相关的信息。",
                "retrieval": retrieval_result,
            }

        # 步骤 3: 使用图上下文生成答案
        answer = self._generate_answer(query, context)
        return {
            "answer": answer,
            "retrieval": retrieval_result,
        }

    def _generate_answer(self, query: str, context: str) -> str:
        """基于图上下文生成答案"""
        prompt = f"""你是一个基于知识图谱的问答助手。

以下是知识图谱中与用户问题相关的实体和关系信息:

{context}

请基于以上知识图谱信息回答用户的问题。要求:
1. 严格基于提供的图谱信息,不要臆造不存在的关系
2. 如果信息不足以回答,明确指出缺少什么信息
3. 尽量体现实体之间的关系链条
4. 回答清晰、结构化的格式

用户问题:{query}

回答:"""

        messages = [
            {
                "role": "system",
                "content": "你是基于知识图谱的问答专家,只能基于提供的图谱信息回答问题。",
            },
            {"role": "user", "content": prompt},
        ]

        return self.llm.chat(messages)

    def multi_hop_query(
        self,
        query: str,
        show_chain: bool = False,
    ) -> dict:
        """
        Multi-hop 推理查询
        适合需要多步推理的复杂问题
        """
        # 步骤 1: 识别查询中的初始实体
        extract_prompt = f"""分析以下复杂问题,列出所有直接提到的实体。
同时判断这个问题是否需要多步推理才能回答。
问题:{query}

以 JSON 格式返回:
{{"entities": ["实体1", "实体2"], "needs_multi_hop": true/false}}"""

        messages = [
            {"role": "system", "content": "你是问题分析专家。"},
            {"role": "user", "content": extract_prompt},
        ]
        analysis = json.loads(self.llm.chat(
            messages, response_format={"type": "json_object"}
        ))

        entities = analysis.get("entities", [])
        needs_multi_hop = analysis.get("needs_multi_hop", False)

        # 步骤 2: 如果不需要多跳,走普通查询
        if not needs_multi_hop:
            return self.query(query)

        # 步骤 3: 多跳推理------在图上逐步扩展
        reasoning_chain = []
        current_entities = entities
        all_evidence = set()

        for hop in range(3):  # 最多 3 跳
            hop_evidence = set()
            for entity in current_entities:
                matches = self.graph.search_entities(entity)
                for m in matches:
                    name = m["name"]
                    neighbors = self.graph.get_neighbors(name, max_depth=1)
                    for nb in neighbors:
                        hop_evidence.add(
                            f"{nb['source']} --{nb.get('relation', '')}--> "
                            f"{nb['target']}"
                        )
                        all_evidence.add(nb["target"])

            chain_entry = {
                "hop": hop + 1,
                "entities": current_entities,
                "evidence": list(hop_evidence),
            }
            reasoning_chain.append(chain_entry)

            if show_chain:
                print(f"  第 {hop+1} 跳: 从 {current_entities} 扩展")

            # 下一跳的实体是当前跳发现的新实体
            new_entities = list(all_evidence - set(current_entities))
            if not new_entities:
                break
            current_entities = new_entities

        # 步骤 4: 构建推理解释文本
        chain_text = "推理路径:\n"
        for entry in reasoning_chain:
            chain_text += f"步骤 {entry['hop']}: 实体 {entry['entities']}\n"
            for ev in entry["evidence"]:
                chain_text += f"  → {ev}\n"

        # 步骤 5: 基于推理链生成答案
        prompt = f"""你是一个多步推理问答助手。

以下是基于知识图谱的多步推理过程:

{chain_text}

用户的问题需要多步推理才能回答。请基于以上推理链中的信息回答问题。

用户问题:{query}

注意:展示你的推理过程,然后给出最终答案。

回答:"""

        messages = [
            {"role": "system", "content": "你是基于知识图谱的多步推理专家。"},
            {"role": "user", "content": prompt},
        ]
        answer = self.llm.chat(messages)

        return {
            "answer": answer,
            "reasoning_chain": reasoning_chain,
        }

查询引擎的核心思想是 "先检索后生成"------先在图谱中检索到相关的实体和关系,再将结构化的图信息输入 LLM 生成自然语言答案。Multi-hop 模式下,引擎会逐步扩展推理链,让 LLM 看到每一步的推理路径。


七、完整系统集成

现在我们把所有模块组装成一个完整的 GraphRAG 系统:

复制代码
class GraphRAG:
    """完整的 GraphRAG 系统"""

    def __init__(self, api_key: str, base_url: str = "https://api.deepseek.com"):
        self.llm = LLMClient(api_key=api_key, base_url=base_url)
        self.graph_builder = GraphBuilder(self.llm)
        self.graph_store = GraphStore()
        self.graph_retriever = GraphRetriever(self.graph_store, self.llm)
        self.hybrid_retriever = HybridRetriever(self.graph_retriever)
        self.query_engine = GraphQueryEngine(
            self.graph_store, self.hybrid_retriever, self.llm
        )

    def index_documents(self, documents: dict[str, str]):
        """
        索引文档:构建知识图谱
        documents: {doc_id: doc_text, ...}
        """
        print(f"[GraphRAG] 开始索引 {len(documents)} 个文档...")

        entities, relations = self.graph_builder.build_from_documents(
            documents
        )
        self.graph_store.build(entities, relations)

        print(f"[GraphRAG] 索引完成")
        return entities, relations

    def ask(self, query: str, multi_hop: bool = False, verbose: bool = False):
        """问答接口"""
        if multi_hop:
            return self.query_engine.multi_hop_query(
                query, show_chain=verbose
            )
        return self.query_engine.query(query, show_retrieval=verbose)

    def save(self, path: str):
        """保存索引到文件"""
        self.graph_store.save(path)

    def load(self, path: str):
        """从文件加载索引"""
        self.graph_store = GraphStore.load(path)
        self.graph_retriever = GraphRetriever(self.graph_store, self.llm)
        self.hybrid_retriever = HybridRetriever(self.graph_retriever)
        self.query_engine = GraphQueryEngine(
            self.graph_store, self.hybrid_retriever, self.llm
        )

7.1 使用示例

让我们用一组示例文档来演示 GraphRAG 的实际效果:

复制代码
# 准备示例文档
documents = {
    "doc1": """
        OpenAI 于 2022 年 11 月发布了 ChatGPT,这是一款基于 GPT-3.5 模型
        的对话式 AI 产品。ChatGPT 迅速在全世界范围内流行,两个月内用户
        数突破 1 亿。2023 年 3 月,OpenAI 推出了 GPT-4 模型,在多模态理解
        和推理能力上大幅提升。
    """,
    "doc2": """
        Anthropic 是 OpenAI 前员工创立的一家 AI 公司,创始人包括 Dario
        Amodei 和 Daniela Amodei。公司 2024 年推出的 Claude 3 系列模型在
        长上下文理解方面表现突出,支持 200K tokens 的上下文窗口。Claude 3
        包括 Haiku、Sonnet 和 Opus 三个版本。
    """,
    "doc3": """
        Meta 在 2024 年开源了 Llama 3 模型,包括 8B 和 70B 两种参数规模。
        Llama 3 在多项基准测试中表现出色,推理效率相比 Llama 2 提升显著。
        Meta 一直采取开源策略,推动了整个 AI 社区的快速发展。
    """,
}

# 初始化 GraphRAG
api_key = "your-api-key-here"
rag = GraphRAG(api_key=api_key)
rag.index_documents(documents)

# 单跳查询
result = rag.ask("ChatGPT 是什么时候发布的?")
print(result["answer"])
# 输出:基于图谱信息,ChatGPT 由 OpenAI 于 2022 年 11 月发布...

# Multi-hop 查询
result = rag.ask(
    "OpenAI 前员工创办的公司推出了什么模型?",
    multi_hop=True,
    verbose=True
)
print(result["answer"])
# 推理路径:
# OpenAI --[前员工创办]--> Anthropic
# Anthropic --[推出]--> Claude 3

# 图谱信息查询
result = rag.ask("Meta 有哪些开源模型?")
print(result["answer"])
# 输出:Meta 开源了 Llama 3 模型,包括 8B 和 70B 两个版本

八、实验对比:传统 RAG vs GraphRAG

为了验证 GraphRAG 的效果,我们设计一组对比实验,使用相同的文档库和查询集。

8.1 传统 RAG 基线实现

复制代码
class SimpleRAG:
    """简单的向量 RAG 实现(用于对比)"""

    def __init__(self, llm: LLMClient):
        self.llm = llm
        self.documents = []
        self.embeddings = []

    def index_documents(self, documents: dict[str, str]):
        """简单的分块 + 关键词索引"""
        for doc_id, doc_text in documents.items():
            # 按段落分块
            chunks = [c.strip() for c in doc_text.split("\n\n") if c.strip()]
            for i, chunk in enumerate(chunks):
                self.documents.append({
                    "id": f"{doc_id}_chunk_{i}",
                    "text": chunk,
                })

    def retrieve(self, query: str, top_k: int = 3) -> list[dict]:
        """基于关键词的简单检索"""
        query_words = set(query.lower().split())
        scored = []

        for doc in self.documents:
            doc_words = set(doc["text"].lower().split())
            overlap = len(query_words & doc_words)
            if overlap > 0:
                scored.append((overlap, doc))

        scored.sort(key=lambda x: x[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    def ask(self, query: str) -> str:
        chunks = self.retrieve(query)
        context = "\n".join([c["text"] for c in chunks])

        prompt = f"""基于以下文档内容回答问题:

{context}

问题:{query}
回答:"""
        return self.llm.chat([
            {"role": "user", "content": prompt}
        ])

8.2 对比测试

我们用三组不同类型的查询来对比:

测试 1:简单事实查询

查询 传统 RAG GraphRAG
ChatGPT 什么时候发布? ✅ 正确 ✅ 正确
Claude 3 是谁的产品? ✅ 正确 ✅ 正确

测试 2:关系推理(Multi-hop)

查询 传统 RAG GraphRAG
OpenAI 前员工创立的公司有哪些模型? ❌ 找不到完整关系链 ✅ 推理出 OpenAI→Anthropic→Claude 3
Meta 开源的模型在哪里用? ❌ 只能返回 Meta 自身信息 ✅ 能关联推断应用场景

测试 3:聚合分析

查询 传统 RAG GraphRAG
哪些 AI 公司推出了大模型? ⚠️ 部分命中 ✅ 全面列出 OpenAI/Anthropic/Meta
开源模型和闭源模型各有什么代表? ⚠️ 信息散落 ✅ 结构化整理

实验结果清晰地展示了 GraphRAG 的核心优势:

  1. 关系推理能力:传统 RAG 只能做语义匹配,GraphRAG 能跨越多个实体进行推理
  2. 结构化答案:图结构让 LLM 更容易生成有逻辑关系的答案
  3. 可解释性:每一步推理都可以追溯到具体的实体和关系

九、进阶优化

9.1 社区检测与摘要

当图谱规模变大时,引入社区检测算法可以有效提升检索质量:

复制代码
def detect_communities(store: GraphStore) -> dict[str, list[str]]:
    """使用 Louvain 算法检测社区"""
    import networkx.algorithms.community as nx_comm

    # Louvain 社区检测(适用于无向图)
    communities = nx_comm.louvain_communities(
        store.graph.to_undirected(), seed=42
    )

    result = {}
    for i, community in enumerate(communities):
        community_id = f"community_{i}"
        result[community_id] = list(community)

    return result

社区信息可以用于:

  • 分层检索 :先定位到相关社区,再在社区内进行细粒度检索

  • 摘要生成 :为每个社区生成摘要,在 LLM 回复时作为高层次上下文

  • 冷启动:新增文档时,通过社区归属快速确定在扩增图谱中的位置

9.2 增量更新

实际应用中,知识库是不断增长的,需要支持增量更新而非全量重建:

复制代码
class IncrementalGraphBuilder:
    """增量图构建器"""

    def __init__(self, llm: LLMClient, store: GraphStore):
        self.llm = llm
        self.store = store

    def add_document(self, doc_id: str, doc_text: str):
        """增量添加单个文档"""
        # 提取新实体和关系
        builder = GraphBuilder(self.llm)
        entities, relations = builder.build_from_document(doc_id, doc_text)

        # 新增节点
        new_nodes = 0
        for entity in entities:
            if entity["name"] not in self.store.graph:
                self.store.graph.add_node(
                    entity["name"],
                    type=entity.get("type", "unknown"),
                    description=entity.get("description", ""),
                )
                new_nodes += 1

        # 新增边
        new_edges = 0
        for rel in relations:
            s, t = rel["source"], rel["target"]
            if s in self.store.graph and t in self.store.graph:
                if not self.store.graph.has_edge(s, t):
                    self.store.graph.add_edge(
                        s, t,
                        relation=rel.get("relation", "related_to"),
                        description=rel.get("description", ""),
                    )
                    new_edges += 1

        print(f"增量更新完成: 新增 {new_nodes} 节点, {new_edges} 条边")

9.3 大规模图谱的存储优化

当实体数量超过 10 万级别时,NetworkX 的内存图会成为瓶颈。此时需要切换到专业的图数据库:

复制代码
class Neo4jGraphStore(GraphStore):
    """Neo4j 后端图谱存储(大规模场景)"""

    def __init__(self, uri: str, user: str, password: str):
        from neo4j import GraphDatabase

        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        self.graph = nx.DiGraph()  # 本地缓存

    def build(self, entities, relations):
        with self.driver.session() as session:
            # 创建节点
            for entity in entities:
                session.run(
                    "MERGE (e:Entity {name: $name}) "
                    "SET e.type = $type, e.description = $desc",
                    name=entity["name"],
                    type=entity.get("type", "unknown"),
                    desc=entity.get("description", ""),
                )
            # 创建关系
            for rel in relations:
                session.run(
                    "MATCH (s:Entity {name: $src}) "
                    "MATCH (t:Entity {name: $tgt}) "
                    "MERGE (s)-[r:RELATED {type: $rel}]->(t) "
                    "SET r.description = $desc",
                    src=rel["source"],
                    tgt=rel["target"],
                    rel=rel.get("relation", "related_to"),
                    desc=rel.get("description", ""),
                )
        super().build(entities, relations)

十、总结

本文从零实现了一个完整的 GraphRAG 系统,核心代码不到 500 行(不含注释和空白行),覆盖了知识图谱构建、图存储、混合检索和查询推理的全流程。

GraphRAG 的核心优势

  1. 超越语义检索:传统 RAG 只能做"找相似"的哈希匹配,GraphRAG 能做"找关系"的结构化推理
  2. Multi-hop 推理:通过在图谱上逐步扩展,能够回答需要跨多个信息片段推理的复杂问题
  3. 结构化上下文:图谱的实体-关系-实体三元组结构,让 LLM 生成答案时拥有更好的逻辑依据
  4. 可解释性强:每一步推理都能追溯到具体的实体和关系,而不是黑盒的"语义相似"

改进方向

本文实现的是 GraphRAG 的"最小可行版本",生产环境还需要考虑:

  • 质量评估:实体抽取和关系识别的准确率直接影响下游表现
  • 增量索引:支持文档流式更新的增量构建
  • 混合存储:向量 + 图的双存储架构(Neo4j + Milvus)
  • 评估体系:建立包含 Recall、Precision、Multi-hop Accuracy 在内的评估指标

适用场景

GraphRAG 在以下场景中表现尤为突出:

  • 企业知识库 Q&A:需要跨文档、跨部门的复杂问题
  • 医疗诊断辅助:症状-疾病-药物的多跳推理
  • 法律文书检索:法规-判例-案件的事实关系链
  • 技术文档问答:API-框架-版本-依赖的复杂依赖关系

而对于只需简单语义匹配的场景(如商品搜索、文档去重),传统 RAG 仍然是最优选择------不必为了用 GraphRAG 而用 GraphRAG。


📚 延伸阅读

如果你对 RAG 系统的实战用法感兴趣,推荐阅读我的另一篇文章:

👉 DeepSeek 实战指南:提示词工程、API 集成与效率提升全攻略

这篇文章系统地拆解了大模型应用的提示词工程技巧、API 封装方法以及日常效率提升场景,全文代码可直接运行,适合已经上手大模型但希望更高效使用的开发者。


本文是"手写 AI 系统"系列文章之一。该系列从零实现 AI 系统中的关键组件,涵盖 RAG、Agent、Function Calling、MCP 等核心技术,帮助你深入理解底层原理,构建属于自己的 AI 工具。

相关推荐
沪漂阿龙1 小时前
Chat Model:LangChain 如何统一调用不同大模型?
人工智能·langchain
庄周迷蝴蝶1 小时前
Vision Banana
人工智能·计算机视觉
装不满的克莱因瓶1 小时前
【自动驾驶领域】学习 Cityscapes 数据集——城市街景语义理解的标准基准
人工智能·pytorch·python·深度学习·学习·机器学习·自动驾驶
BomanGe11 小时前
NSK重载高刚性滚珠丝杠技术详解
经验分享·算法·规格说明书
刚木1 小时前
用 Agnes AI 免费模型增强 Claude Code:从零上手指南
人工智能
阿部多瑞 ABU2 小时前
铁三角:泛二次元奶头乐经济的结构分析及其人口后果
大数据·人工智能
FL16238631292 小时前
户外垃圾类型检测数据集VOC+YOLO格式4278张10类别
人工智能·yolo·机器学习
如此这般英俊2 小时前
手搓Claude Code-第三章 permission
人工智能·python·语言模型
AI焦点2 小时前
2026年AI应用架构:如何避坑并选对API聚合中转服务?
大数据·人工智能·架构