I. Background
GraphRAG (Microsoft GraphRAG) is an LLM knowledge system built on graph-structured retrieval-augmented generation.
The official project provides:
- an index-building toolchain (the index pipeline)
- a query CLI (graphrag query)
However:
❌ there is no official server-side API service
❌ the official indexer depends on many NLP libraries (spaCy, umap, lanceDB), so the deployment footprint is large
❌ Milvus Lite's default local path does not fit Docker
This project therefore needs to do the following on the server.
Goals:
- Start a GraphRAG query service (a /query API) in a Docker container.
- Use the already-generated output data (entities, communities, text_units, ...).
- Run embedding retrieval against a local Milvus Lite file.
II. Implementation
This project implements the service through a custom server.py that does the following (a minimal client sketch of the resulting API follows this list):
- loads the GraphRAG configuration (settings.yaml)
- automatically reads the output data (entities, relationships, text_units, ...)
- automatically builds the request structures for local / global / drift / basic search
- supports SSE streaming output (a good fit for front-end chat systems)
- uses fastapi/uvicorn as the web server
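As a quick preview of the contract this server exposes, here is a minimal non-streaming client sketch. It assumes the server is reachable at 127.0.0.1:8000 (matching the compose file later in this post) and uses the requests library, which is not part of the project:

```python
import requests  # assumed HTTP client; any client works

# Hypothetical non-streaming call against the /query endpoint defined below.
resp = requests.post(
    "http://127.0.0.1:8000/query",
    json={
        "root": "/workspace/rag_milvus",  # project root inside the container
        "method": "local",                # one of: local / global / drift / basic
        "query": "What is GraphRAG?",     # sample question
        "streaming": False,               # plain text response
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.text)
```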
1. Create a graphrag_server package in the graphrag root directory

2. Create the server.py module (under graphrag_server)
```python
import os
import signal
from pathlib import Path
from typing import Literal, Any, Dict

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

import graphrag.api as api
from graphrag.config.load_config import load_config
from graphrag.config.enums import SearchMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.utils.api import create_storage_from_config
from graphrag.utils.storage import load_table_from_storage, storage_has_table
from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks

# Load environment variables from .env (optional)
load_dotenv()

# Default project root inside the container; can be overridden via env var
DEFAULT_PROJECT_ROOT = Path(
    os.getenv("GRAPHRAG_PROJECT_ROOT", "/workspace/rag_milvus")
)

app = FastAPI(title="GraphRAG Query Server")

# Default query parameters
DEFAULT_COMMUNITY_LEVEL = 2
DEFAULT_RESPONSE_TYPE = "Multiple Paragraphs"

# Simple caches: avoid reloading DataFrames from disk / Milvus on every request
_LOCAL_CACHE: Dict[str, Any] = {}
_GLOBAL_CACHE: Dict[str, Any] = {}
_DRIFT_CACHE: Dict[str, Any] = {}
_BASIC_CACHE: Dict[str, Any] = {}
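# NOTE: these caches are keyed by the resolved project root and are never
# invalidated; after re-running the index pipeline, restart the server (or
# clear these dicts) so the new output tables are picked up.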


class GraphRAGQueryRequest(BaseModel):
    """
    Request body:
    - root: GraphRAG project root directory (contains settings.yaml)
    - method: local / global / drift / basic
    - query: the question
    - streaming: whether to return an SSE stream
    """
    root: str | None = None
    method: Literal["local", "global", "drift", "basic"]
    query: str
    streaming: bool = False


@app.get("/healthz")
async def healthz():
    """Simple health check."""
    return {"status": "ok", "message": "GraphRAG query server is running"}

@app.post("/query")
async def query(req: GraphRAGQueryRequest):
    # 1. Resolve root
    root_str = (req.root or "").strip()
    if not root_str:
        root_path = DEFAULT_PROJECT_ROOT
    else:
        # Note: paths are resolved inside the container, so use
        # container-absolute paths
        root_path = Path(root_str).resolve()
    if not root_path.exists():
        raise HTTPException(status_code=400, detail=f"root path does not exist: {root_path}")

    # 2. Validate query
    if not req.query.strip():
        raise HTTPException(status_code=400, detail="query must not be empty")

    # 3. Parse the method enum
    try:
        method_enum = SearchMethod(req.method)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"invalid method: {req.method}")

    print(f"[QUERY] method={req.method}, root={root_path}, query={req.query[:100]}...")

    # 4. Dispatch to the matching search implementation
    return await run_query(
        root=root_path,
        query=req.query,
        method=method_enum,
        streaming=req.streaming,
    )

async def run_query(
    root: Path,
    query: str,
    method: SearchMethod,
    streaming: bool,
):
    if method == SearchMethod.LOCAL:
        return await run_local_search(root, query, streaming)
    elif method == SearchMethod.GLOBAL:
        return await run_global_search(root, query, streaming)
    elif method == SearchMethod.DRIFT:
        return await run_drift_search(root, query, streaming)
    elif method == SearchMethod.BASIC:
        return await run_basic_search(root, query, streaming)
    else:
        raise HTTPException(status_code=400, detail=f"unsupported search method: {method}")

# =========================
# Local Search
# =========================
async def build_local_request(root_dir: Path):
    """
    Build the config + DataFrames required by local_search.

    To cut per-request initialization cost, results are cached:
    - key = root_dir.resolve(), so each project is loaded only once.
    """
    cache_key = str(root_dir.resolve())
    if cache_key in _LOCAL_CACHE:
        return _LOCAL_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"failed to load GraphRAG config: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "communities",
            "community_reports",
            "text_units",
            "relationships",
            "entities",
        ],
        optional_list=["covariates"],
    )

    result = (
        config,
        dataframe_dict["communities"],
        dataframe_dict["community_reports"],
        dataframe_dict["text_units"],
        dataframe_dict["relationships"],
        dataframe_dict["entities"],
        dataframe_dict.get("covariates"),
    )
    _LOCAL_CACHE[cache_key] = result
    return result

async def run_local_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_communities,
        final_community_reports,
        final_text_units,
        final_relationships,
        final_entities,
        final_covariates,
    ) = await build_local_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.local_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                # Push each chunk straight to the client
                yield chunk
            # Explicitly tell the client the stream is over (front-end convention)
            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    # Non-streaming
    response, context_data = await api.local_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        text_units=final_text_units,
        relationships=final_relationships,
        covariates=final_covariates,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")

# =========================
# Global Search
# =========================
async def build_global_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _GLOBAL_CACHE:
        return _GLOBAL_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"failed to load GraphRAG config: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "entities",
            "communities",
            "community_reports",
        ],
        optional_list=[],
    )

    result = (
        config,
        dataframe_dict["entities"],
        dataframe_dict["communities"],
        dataframe_dict["community_reports"],
    )
    _GLOBAL_CACHE[cache_key] = result
    return result

async def run_global_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_entities,
        final_communities,
        final_community_reports,
    ) = await build_global_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.global_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                dynamic_community_selection=False,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                yield chunk
            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.global_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        dynamic_community_selection=False,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")

# =========================
# DRIFT Search
# =========================
async def build_drift_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _DRIFT_CACHE:
        return _DRIFT_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"failed to load GraphRAG config: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=[
            "communities",
            "community_reports",
            "text_units",
            "relationships",
            "entities",
        ],
        optional_list=[],
    )

    result = (
        config,
        dataframe_dict["communities"],
        dataframe_dict["entities"],
        dataframe_dict["community_reports"],
        dataframe_dict["text_units"],
        dataframe_dict["relationships"],
    )
    _DRIFT_CACHE[cache_key] = result
    return result

async def run_drift_search(root_dir: Path, query: str, streaming: bool):
    (
        config,
        final_communities,
        final_entities,
        final_community_reports,
        final_text_units,
        final_relationships,
    ) = await build_drift_request(root_dir)

    if streaming:
        async def streaming_search():
            context_data: Any = {}

            def on_context(context: Any) -> None:
                nonlocal context_data
                context_data = context

            callbacks = NoopQueryCallbacks()
            callbacks.on_context = on_context

            async for chunk in api.drift_search_streaming(
                config=config,
                entities=final_entities,
                communities=final_communities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                community_level=DEFAULT_COMMUNITY_LEVEL,
                response_type=DEFAULT_RESPONSE_TYPE,
                query=query,
                callbacks=[callbacks],
            ):
                yield chunk
            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.drift_search(
        config=config,
        entities=final_entities,
        communities=final_communities,
        community_reports=final_community_reports,
        text_units=final_text_units,
        relationships=final_relationships,
        community_level=DEFAULT_COMMUNITY_LEVEL,
        response_type=DEFAULT_RESPONSE_TYPE,
        query=query,
    )
    return Response(content=response, media_type="text/plain")

# =========================
# BASIC Search (text_units only)
# =========================
async def build_basic_request(root_dir: Path):
    cache_key = str(root_dir.resolve())
    if cache_key in _BASIC_CACHE:
        return _BASIC_CACHE[cache_key]

    try:
        config = load_config(root_dir)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"failed to load GraphRAG config: {e}",
        )

    dataframe_dict = await _resolve_output_files(
        config=config,
        output_list=["text_units"],
        optional_list=[],
    )

    result = (config, dataframe_dict["text_units"])
    _BASIC_CACHE[cache_key] = result
    return result

async def run_basic_search(root_dir: Path, query: str, streaming: bool):
    config, final_text_units = await build_basic_request(root_dir)

    if streaming:
        async def streaming_search():
            async for chunk in api.basic_search_streaming(
                config=config,
                text_units=final_text_units,
                query=query,
            ):
                yield chunk
            yield "[DONE]"

        return EventSourceResponse(streaming_search(), media_type="text/event-stream")

    response, context_data = await api.basic_search(
        config=config,
        text_units=final_text_units,
        query=query,
    )
    return Response(content=response, media_type="text/plain")

# =========================
# Shared helper for reading output tables (single index / multi-index)
# =========================
async def _resolve_output_files(
    config: GraphRagConfig,
    output_list: list[str],
    optional_list: list[str] | None = None,
) -> dict[str, Any]:
    """
    Read the required tables from GraphRAG's output storage. Supports:
    - single index: config.output
    - multi-index: config.outputs

    Notes:
    - Only "structured" tables (parquet/csv, etc.) are handled here.
    - Vector retrieval goes through the vector_store / embedding_store in the
      config; it is not handled in this function, and your Milvus
      configuration is left untouched.
    """
    dataframe_dict: dict[str, Any] = {}

    # Multi-index mode
    if config.outputs:
        dataframe_dict["multi-index"] = True
        dataframe_dict["num_indexes"] = len(config.outputs)
        dataframe_dict["index_names"] = list(config.outputs.keys())
        for output in config.outputs.values():
            storage_obj = create_storage_from_config(output)
            # Required tables
            for name in output_list:
                if name not in dataframe_dict:
                    dataframe_dict[name] = []
                df_value = await load_table_from_storage(name=name, storage=storage_obj)
                dataframe_dict[name].append(df_value)
            # Optional tables
            if optional_list:
                for optional_name in optional_list:
                    if optional_name not in dataframe_dict:
                        dataframe_dict[optional_name] = []
                    exists = await storage_has_table(optional_name, storage_obj)
                    if exists:
                        df_value = await load_table_from_storage(
                            name=optional_name,
                            storage=storage_obj,
                        )
                        dataframe_dict[optional_name].append(df_value)
        return dataframe_dict

    # Single-index mode
    dataframe_dict["multi-index"] = False
    storage_obj = create_storage_from_config(config.output)
    for name in output_list:
        df_value = await load_table_from_storage(name=name, storage=storage_obj)
        dataframe_dict[name] = df_value
    if optional_list:
        for optional_name in optional_list:
            exists = await storage_has_table(optional_name, storage_obj)
            if exists:
                df_value = await load_table_from_storage(
                    name=optional_name,
                    storage=storage_obj,
                )
                dataframe_dict[optional_name] = df_value
            else:
                dataframe_dict[optional_name] = None
    return dataframe_dict
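# Shape of the returned dict, for reference:
#   single index:  {"multi-index": False, "entities": DataFrame, ...,
#                   "covariates": None if the optional table is absent}
#   multi-index:   {"multi-index": True, "num_indexes": N, "index_names": [...],
#                   "entities": [DataFrame, ...], ...}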

# =========================
# Service entry point (uvicorn)
# =========================
if __name__ == "__main__":
    def handle_signal(sig, frame):
        print("[Main] shutting down...")

    # NOTE: uvicorn installs its own signal handlers once it starts; these
    # only cover the brief window before startup.
    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    host = os.getenv("SERVER_HOST", "0.0.0.0")
    port = int(os.getenv("SERVER_PORT", "8000"))
    print(f"[GraphRAG] starting API server on {host}:{port} ...")
    uvicorn.run(app, host=host, port=port)
```
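For streaming=true, the server returns Server-Sent Events and ends the stream with a [DONE] sentinel (see streaming_search above). A minimal consumer sketch, assuming httpx is installed (it is not part of the project) and the server runs on 127.0.0.1:8000:

```python
import asyncio

import httpx  # assumed async HTTP client; not part of the project


async def stream_query() -> None:
    payload = {
        "root": "/workspace/rag_milvus",
        "method": "local",
        "query": "What is GraphRAG?",  # sample question
        "streaming": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://127.0.0.1:8000/query", json=payload
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                # sse-starlette frames every yielded chunk as "data: <chunk>"
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].strip()
                if data == "[DONE]":  # end-of-stream marker from server.py
                    break
                print(data, end="", flush=True)


asyncio.run(stream_query())
```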
3. Create the Dockerfile (under graphrag_server)
```dockerfile
# /root/code/graphrag_milvus_2.7.0/graphrag_server/Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install basic build dependencies (add or remove as needed)
RUN apt-get update && apt-get install -y \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy the server-side requirements first
COPY graphrag_server/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Copy the modified graphrag source into the image.
# Note: the build context is /root/code/graphrag_milvus_2.7.0 (see the
# docker-compose file below), so paths here are relative to that directory.
COPY . /opt/graphrag

# Install graphrag from source (editable mode)
RUN pip install --no-cache-dir -e /opt/graphrag

# Copy server.py into the image
COPY graphrag_server/server.py /app/server.py

# Listen on 8000 by default
EXPOSE 8000

# These can be overridden via docker-compose or environment variables
ENV SERVER_HOST=0.0.0.0
ENV SERVER_PORT=8000

# Default GraphRAG project root (inside the container)
ENV GRAPHRAG_PROJECT_ROOT=/workspace/rag_milvus

CMD ["python", "server.py"]
```
4. Create the requirements.txt file (under graphrag_server)
```text
fastapi
uvicorn[standard]
sse-starlette
python-dotenv
pandas
pymilvus[milvus_lite]
```
graphrag itself and its dependencies are installed from source by the Dockerfile (pip install -e /opt/graphrag), so they are not listed here.
5. Create the docker-compose.yml file (in the graphrag root directory)
```yaml
version: "3.9"
services:
  graphrag-server:
    build:
      context: /root/code/graphrag_milvus_2.7.0  # the root containing graphrag / rag_milvus / graphrag_server
      dockerfile: graphrag_server/Dockerfile
    container_name: graphrag_server
    environment:
      # LLM / embedding env vars; settings.yaml can read them via ${XXX}
      GRAPHRAG_API_KEY: "api_key"
      GRAPHRAG_API_BASE: "api_base"
      GRAPHRAG_API_BASE_EMBEDDING: "api_base"
      SERVER_HOST: "0.0.0.0"
      SERVER_PORT: "8000"
      # Default project root (keep consistent with the Dockerfile)
      GRAPHRAG_PROJECT_ROOT: "/workspace/rag_milvus"
    volumes:
      # Mount your rag_milvus project so settings.yaml, output, etc. are readable
      - /root/code/graphrag_milvus_2.7.0/rag_milvus:/workspace/rag_milvus
    ports:
      - "8000:8000"
    restart: unless-stopped
```
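Before bringing the stack up, it can help to verify that the host directory being mounted actually contains everything server.py will read. A small optional pre-flight sketch (not part of the project; paths follow the directory structure in the next section, so adjust them to your host):

```python
from pathlib import Path

# Host-side path that docker-compose mounts to /workspace/rag_milvus
root = Path("/root/code/graphrag_milvus_2.7.0/rag_milvus")

required = [
    "settings.yaml",
    "output/communities.parquet",
    "output/entities.parquet",
    "output/relationships.parquet",
    "output/milvus_lite.db",
]
missing = [p for p in required if not (root / p).exists()]
print("OK" if not missing else f"missing from {root}: {missing}")
```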
III. Overall Directory Structure
```text
graphrag_milvus_2.7.0/
│
├── docker-compose.yml
├── graphrag_server/
│   ├── Dockerfile
│   ├── requirements.txt
│   └── server.py
│
├── rag_milvus/            <-- the built GraphRAG project
│   ├── settings.yaml
│   ├── output/
│   │   ├── communities.parquet
│   │   ├── entities.parquet
│   │   ├── relationships.parquet
│   │   ├── milvus_lite.db
│   │   ...
│
└── graphrag/              <-- the modified GraphRAG source
```
IV. Startup and Verification
1. Start the service
```bash
docker-compose build graphrag-server
docker-compose up -d
```
2. Check the runtime logs
```bash
docker logs -f graphrag_server
```
3. Test a query
POST /query
```bash
curl -X POST http://127.0.0.1:8000/query \
  -H "Content-Type: application/json" \
  -d '{
    "root": "/workspace/rag_milvus",
    "method": "local",
    "query": "Which symptoms correspond to failure mode F043 (torsional deformation)?",
    "streaming": false
  }'
```
GET /healthz
```bash
curl http://127.0.0.1:8000/healthz
```
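If you would rather script the readiness check, a small polling sketch (again assuming the requests library and the 8000 port mapping above):

```python
import time

import requests  # assumed HTTP client; not part of the project

# Poll /healthz until the container reports ready, or give up after ~60s.
for _ in range(30):
    try:
        r = requests.get("http://127.0.0.1:8000/healthz", timeout=2)
        if r.ok:
            print(r.json())
            break
    except requests.ConnectionError:
        pass
    time.sleep(2)
else:
    print("server did not become healthy in time")
```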