Deploying GraphRAG 2.7.0 with Docker Compose

I. Background

GraphRAG (Microsoft GraphRAG) is a knowledge system for large language models built on graph-structured retrieval-augmented generation.

The official project provides:

  • an index-building toolchain (the index pipeline)

  • a query CLI: graphrag query

However:

  • there is no official server-side API service

  • the official indexer depends on many NLP libraries (spaCy, umap, LanceDB), so the deployment footprint is large

  • the default local path of Milvus Lite is not suited to Docker

This project therefore needs to accomplish the following on the server.

Goals

  • Start a GraphRAG query service (a /query API) in a Docker container

  • Use the already-generated output data (entities, communities, text_units, ...)

  • Perform embedding retrieval against a local Milvus Lite file

II. Implementation

This project implements a custom server.py that:

  • loads the GraphRAG configuration (settings.yaml)

  • automatically reads the output data (entities, relationships, text_units, ...)

  • builds the request structures for local/global/drift/basic searches

  • supports SSE streaming output (suited to front-end chat systems)

  • uses FastAPI/uvicorn as the web server

1. Create a graphrag_server package in the graphrag repo root
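Assuming the current working directory is the graphrag repo root, the package can be created with:

```shell
mkdir -p graphrag_server
# an empty __init__.py makes the directory an importable Python package
touch graphrag_server/__init__.py
```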

2. Create the server.py module (under graphrag_server)

python
import os
import signal
from pathlib import Path
from typing import Literal, Any, Dict

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.config.enums import SearchMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.utils.api import create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, storage_has_table
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks

# Load environment variables from .env (optional)
load_dotenv()

# Default project root inside the container; can be overridden via an env var
DEFAULT_PROJECT_ROOT = Path(
    os.getenv("GRAPHRAG_PROJECT_ROOT", "/workspace/rag_milvus")
)

app = FastAPI(title="GraphRAG Query Server")

# Default query parameters
DEFAULT_COMMUNITY_LEVEL = 2
DEFAULT_RESPONSE_TYPE = "Multiple Paragraphs"

# Simple cache: avoid reloading DataFrames from disk / Milvus on every request
_LOCAL_CACHE: Dict[str, Any] = {}
_GLOBAL_CACHE: Dict[str, Any] = {}
_DRIFT_CACHE: Dict[str, Any] = {}
_BASIC_CACHE: Dict[str, Any] = {}


class GraphRAGQueryRequest(BaseModel):
    """
    请求体:
    - root: GraphRAG 项目根目录(包含 settings.yaml)
    - method: local / global / drift / basic
    - query: 问题
    - streaming: 是否使用 SSE 流式返回
    """
    root: str | None = None
    method: Literal["local", "global", "drift", "basic"]
    query: str
    streaming: bool = False


@app.get("/healthz")
async def healthz():
    """简单健康检查."""
    return {"status": "ok", "message": "GraphRAG query server is running"}


@app.post("/query")
async def query(req: GraphRAGQueryRequest):
    # 1. Resolve root
    root_str = (req.root or "").strip()
    if not root_str:
        root_path = DEFAULT_PROJECT_ROOT
    else:
        # Note: use container-absolute paths inside the container
        root_path = Path(root_str).resolve()

    if not root_path.exists():
        raise HTTPException(status_code=400, detail=f"root 路径不存在: {root_path}")

    # 2. Validate query
    if not req.query.strip():
        raise HTTPException(status_code=400, detail="query 不能为空")

    # 3. Parse the method enum
    try:
        method_enum = SearchMethod(req.method)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"无效的 method: {req.method}")

    print(f"[QUERY] method={req.method}, root={root_path}, query={req.query[:100]}...")

    # 4. Dispatch to the matching search implementation
    return await run_query(
        root=root_path,
        query=req.query,
        method=method_enum,
        streaming=req.streaming,
    )


async def run_query(
    root: Path,
    query: str,
    method: SearchMethod,
    streaming: bool,
):
    if method == SearchMethod.LOCAL:
        return await run_local_search(root, query, streaming)
    elif method == SearchMethod.GLOBAL:
        return await run_global_search(root, query, streaming)
    elif method == SearchMethod.DRIFT:
        return await run_drift_search(root, query, streaming)
    elif method == SearchMethod.BASIC:
        return await run_basic_search(root, query, streaming)
    else:
        raise HTTPException(status_code=400, detail=f"不支持的查询方法: {method}")


# =========================
# Local Search
# =========================

async def build_local_request(root_dir: Path):
    """
    构建 local_search 所需的 config + 各类 DataFrame。
    为了减小每次请求的初始化开销,这里做了简单缓存:
    - key = root_dir.resolve(),一个项目只加载一次。
    """
    cache_key = str(root_dir.resolve())
    if cache_key in _LOCAL_CACHE:
        return _LOCAL_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"加载 GraphRAG 配置失败: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "communities",
            "community_reports",
            "text_units",
            "relationships",
            "entities",
        ],
        optional_list=["covariates"],
    )

    result = (
        config,
        dataframe_dict["communities"],
        dataframe_dict["community_reports"],
        dataframe_dict["text_units"],
        dataframe_dict["relationships"],
        dataframe_dict["entities"],
        dataframe_dict.get("covariates"),
    )
    _LOCAL_CACHE[cache_key] = result
    return result


async def run_local_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_communities,
        final_community_reports,
        final_text_units,
        final_relationships,
        final_entities,
        final_covariates,
    ) = await build_local_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.local_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                # forward each chunk as it arrives
                yield chunk

            # explicitly signal end-of-stream (a convention with the front end)
            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    # Non-streaming path
    response, context_data = await api.local_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        text_units=final_text_units,
        relationships=final_relationships,
        covariates=final_covariates,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")


# =========================
# Global Search
# =========================

async def build_global_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _GLOBAL_CACHE:
        return _GLOBAL_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"加载 GraphRAG 配置失败: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "entities",
            "communities",
            "community_reports",
        ],
        optional_list=[],
    )

    result = (
        config,
        dataframe_dict["entities"],
        dataframe_dict["communities"],
        dataframe_dict["community_reports"],
    )
    _GLOBAL_CACHE[cache_key] = result
    return result


async def run_global_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_entities,
        final_communities,
        final_community_reports,
    ) = await build_global_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.global_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                dynamic_community_selection=False,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                yield chunk

            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.global_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        dynamic_community_selection=False,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")


# =========================
# DRIFT Search
# =========================

async def build_drift_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _DRIFT_CACHE:
        return _DRIFT_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"加载 GraphRAG 配置失败: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "communities",
            "community_reports",
            "text_units",
            "relationships",
            "entities",
        ],
        optional_list=[],
    )

    result = (
        config,
        dataframe_dict["communities"],
        dataframe_dict["entities"],
        dataframe_dict["community_reports"],
        dataframe_dict["text_units"],
        dataframe_dict["relationships"],
    )
    _DRIFT_CACHE[cache_key] = result
    return result


async def run_drift_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_communities,
        final_entities,
        final_community_reports,
        final_text_units,
        final_relationships,
    ) = await build_drift_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.drift_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                yield chunk

            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.drift_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        text_units=final_text_units,
        relationships=final_relationships,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")


# =========================
# BASIC Search (text_units only)
# =========================

async def build_basic_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _BASIC_CACHE:
        return _BASIC_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"加载 GraphRAG 配置失败: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=["text_units"],
        optional_list=[],
    )

    result = (config, dataframe_dict["text_units"])
    _BASIC_CACHE[cache_key] = result
    return result


async def run_basic_search(root_dir: Path, query: str, streaming: bool):
    config, final_text_units = await build_basic_request(root_dir)

    if streaming:
        async def streaming_search():
            async for chunk in api.basic_search_streaming(
                config=config,
                text_units=final_text_units,
                query=query,
            ):
                yield chunk

            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.basic_search(
        config=config,
        text_units=final_text_units,
        query=query,
    )
    return Response(content=response, media_type="text/plain")


# =========================
# Shared helper for loading output tables (single-index / multi-index)
# =========================

async def _resolve_output_files(
    config: GraphRagConfig,
    output_list: list[str],
    optional_list: list[str] | None = None,
) -> dict[str, Any]:
    """
    从 GraphRAG 的 output 存储中读取需要的表,支持:
    - 单 index: config.output
    - 多 index: config.outputs
    注意:
    - 这里只处理「结构化表」(parquet/csv 等),
    - 向量检索走的是 config 里的 vector_store / embedding_store,
      不在这个函数里处理,也不会动你的 Milvus 配置。
    """
    dataframe_dict: dict[str, Any] = {}

    # Multi-index mode
    if config.outputs:
        dataframe_dict["multi-index"] = True
        dataframe_dict["num_indexes"] = len(config.outputs)
        dataframe_dict["index_names"] = list(config.outputs.keys())

        for output in config.outputs.values():
            storage_obj = create_storage_from_config(output)

            # required tables
            for name in output_list:
                if name not in dataframe_dict:
                    dataframe_dict[name] = []
                df_value = await load_table_from_storage(name=name, storage=storage_obj)
                dataframe_dict[name].append(df_value)

            # optional tables
            if optional_list:
                for optional_name in optional_list:
                    if optional_name not in dataframe_dict:
                        dataframe_dict[optional_name] = []
                    exists = await storage_has_table(optional_name, storage_obj)
                    if exists:
                        df_value = await load_table_from_storage(
                            name=optional_name,
                            storage=storage_obj,
                        )
                        dataframe_dict[optional_name].append(df_value)

        return dataframe_dict

    # Single-index mode
    dataframe_dict["multi-index"] = False
    storage_obj = create_storage_from_config(config.output)

    for name in output_list:
        df_value = await load_table_from_storage(name=name, storage=storage_obj)
        dataframe_dict[name] = df_value

    if optional_list:
        for optional_name in optional_list:
            exists = await storage_has_table(optional_name, storage_obj)
            if exists:
                df_value = await load_table_from_storage(
                    name=optional_name,
                    storage=storage_obj,
                )
                dataframe_dict[optional_name] = df_value
            else:
                dataframe_dict[optional_name] = None

    return dataframe_dict


# =========================
# Server entrypoint (uvicorn)
# =========================

if __name__ == "__main__":
    def handle_signal(sig, frame):
        print("[Main] shutting down...")

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    host = os.getenv("SERVER_HOST", "0.0.0.0")
    port = int(os.getenv("SERVER_PORT", "8000"))

    print(f"[GraphRAG] starting API server on {host}:{port} ...")
    uvicorn.run(app, host=host, port=port)

3. Create the Dockerfile (under graphrag_server)

dockerfile
# /root/code/graphrag_milvus_2.7.0/graphrag_server/Dockerfile

FROM python:3.11-slim

WORKDIR /app

# Install basic build dependencies (adjust as needed)
RUN apt-get update && apt-get install -y \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy the server-side requirements first
COPY graphrag_server/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Copy the modified graphrag source into the image
# Note: paths are relative to the build context (/root/code/graphrag_milvus_2.7.0)
COPY . /opt/graphrag

# Install graphrag from source (editable mode)
RUN pip install --no-cache-dir -e /opt/graphrag

# Copy server.py into the image
COPY graphrag_server/server.py /app/server.py

# Listen on port 8000 by default
EXPOSE 8000

# These can be overridden via docker-compose or environment variables
ENV SERVER_HOST=0.0.0.0
ENV SERVER_PORT=8000
# Default GraphRAG project root (inside the container)
ENV GRAPHRAG_PROJECT_ROOT=/workspace/rag_milvus

CMD ["python", "server.py"]

4. Create the requirements.txt file (under graphrag_server)

text
fastapi
uvicorn[standard]
sse-starlette
python-dotenv
pandas
pymilvus[milvus_lite]

5. Create the docker-compose.yml file (in the graphrag repo root)

yaml
version: "3.9"

services:
  graphrag-server:
    build:
      context: /root/code/graphrag_milvus_2.7.0          # the root directory containing graphrag / rag_milvus / graphrag_server
      dockerfile: graphrag_server/Dockerfile
    container_name: graphrag_server
    environment:
      # LLM / embedding env vars; settings.yaml can read them via ${XXX}
      GRAPHRAG_API_KEY: "api_key"
      GRAPHRAG_API_BASE: "api_base"
      GRAPHRAG_API_BASE_EMBEDDING: "api_base"
      SERVER_HOST: "0.0.0.0"
      SERVER_PORT: "8000"
      # Default project root (keep consistent with the Dockerfile)
      GRAPHRAG_PROJECT_ROOT: "/workspace/rag_milvus"
    volumes:
      # Mount the rag_milvus project so settings.yaml, output, etc. are readable
      - /root/code/graphrag_milvus_2.7.0/rag_milvus:/workspace/rag_milvus
    ports:
      - "8000:8000"
    restart: unless-stopped
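The compose file can be syntax-checked without Docker by parsing it with PyYAML; this validates only the YAML syntax, not Compose semantics (docker compose config does the full check). The fragment below is an abridged copy of the file above:

```python
import yaml  # PyYAML

compose_src = """
services:
  graphrag-server:
    build:
      context: /root/code/graphrag_milvus_2.7.0
      dockerfile: graphrag_server/Dockerfile
    ports:
      - "8000:8000"
    restart: unless-stopped
"""

doc = yaml.safe_load(compose_src)
print(sorted(doc["services"]))                      # ['graphrag-server']
print(doc["services"]["graphrag-server"]["ports"])  # ['8000:8000']
```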

III. Overall directory structure

text
graphrag_milvus_2.7.0/
│
├── docker-compose.yml
├── graphrag_server/
│   ├── Dockerfile
│   └── server.py
│
├── rag_milvus/   <-- the built GraphRAG project
│    ├── settings.yaml
│    ├── output/
│    │    ├── communities.parquet
│    │    ├── entities.parquet
│    │    ├── relationships.parquet
│    │    ├── milvus_lite.db
│    │    ...
│
└── graphrag/            <-- modified GraphRAG source

IV. Startup and verification

1. Start the service

bash
docker-compose build graphrag-server
docker-compose up -d

2. Check the logs

bash
docker logs -f graphrag_server

3. Test a query

POST /query

bash
curl -X POST http://127.0.0.1:8000/query \
  -H "Content-Type: application/json" \
  -d '{
        "root": "/workspace/rag_milvus",
        "method": "local",
        "query": "失效模式为F043-扭转变形所对应的征兆有哪些?",
        "streaming": false
      }'
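With "streaming": true the endpoint returns an SSE stream instead (add -N to curl to disable buffering). Each data: line carries a text chunk, and the server finishes with a literal [DONE] sentinel. A minimal client-side parsing sketch (the sample lines are illustrative, not a real response):

```python
def collect_sse(lines):
    """Collect 'data:' payloads from an SSE stream, stopping at the [DONE] sentinel."""
    chunks = []
    for line in lines:
        if not line.startswith("data:"):
            continue  # skip blank separator lines and SSE comments
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunks.append(payload)
    return chunks


sample = ["data: partial", "", "data: answer", "", "data: [DONE]"]
print(collect_sse(sample))  # ['partial', 'answer']
```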

GET /healthz

bash 复制代码
curl http://127.0.0.1:8000/healthz