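# A2A × FastAPI × JSON‑RPC Q&A Agent · Project Scaffold
> A runnable project skeleton built on **FastAPI + Pydantic v2 + JSON‑RPC 2.0** that implements the core methods of the A2A specification:
> - `message/send` (synchronous short Q&A)
> - `message/stream` (SSE, multiple success responses)
> - `tasks/get`, `tasks/cancel`, `tasks/resubscribe` (resume after a disconnect by replaying past events)
> - `agent/getAuthenticatedExtendedCard` (agent self-description)
>
> The storage backend is switchable between **memory / redis / pg**. The default, memory, runs out of the box; pg is currently a placeholder that falls back to memory.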
---
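## Directory structure
a2a-fastapi/
├─ README.md
├─ requirements.txt
├─ .env.example
├─ docker-compose.yml      # optional: Redis + Postgres (development)
└─ app/
   ├─ main.py              # routes and JSON-RPC dispatch
   ├─ models.py            # A2A and JSON-RPC Pydantic models
   ├─ jsonrpc.py           # success/error response helpers + error codes
   ├─ llm.py               # model calls (simulated streaming; swap in a real SDK)
   ├─ tasks.py             # TaskManager: state machine and streaming execution
   ├─ security/
   │  └─ auth.py           # API key / Bearer validation hook
   └─ storage/
      ├─ base.py           # abstract interface (Storage)
      ├─ memory.py         # in-memory implementation (default)
      ├─ redis.py          # Redis implementation (optional)
      └─ pg.py             # Postgres implementation (example/placeholder)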
---
## requirements.txt
```txt
fastapi==0.112.2
uvicorn[standard]==0.30.6
pydantic==2.8.2
redis==5.0.7
SQLAlchemy==2.0.32
asyncpg==0.29.0
python-dotenv==1.0.1
fastmcp>=2.0.0
```
.env.example
env
# Storage backend: memory | redis | pg
STORAGE_BACKEND=memory
# Redis (used when STORAGE_BACKEND=redis)
REDIS_URL=redis://localhost:6379/0
# Postgres (used when STORAGE_BACKEND=pg)
PG_DSN=postgresql+asyncpg://postgres:postgres@localhost:5432/a2a
# Security: validation is enabled only when these are set (optional)
API_KEY=
BEARER_TOKEN=
# MCP short-term memory service (when the fastmcp server.py is exposed over HTTP/SSE)
MCP_SERVER_URL=http://127.0.0.1:8010/mcp
# Server
HOST=0.0.0.0
PORT=8000
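
The variables above are read individually with `os.getenv` throughout the code. As a rough illustration, they could be gathered into one typed object; `Settings` and `from_env` are hypothetical names that do not exist in the scaffold, and empty values are treated as "not set" to mirror how the auth hook treats an empty `API_KEY`.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class Settings:
    """Hypothetical typed view of the variables in .env.example."""
    storage_backend: str
    redis_url: str
    api_key: Optional[str]
    bearer_token: Optional[str]
    mcp_server_url: Optional[str]
    port: int

    @classmethod
    def from_env(cls, env: Mapping[str, str] = os.environ) -> "Settings":
        def optional(name: str) -> Optional[str]:
            # An empty string (e.g. `API_KEY=`) counts as "not configured".
            return env.get(name) or None

        return cls(
            storage_backend=env.get("STORAGE_BACKEND", "memory"),
            redis_url=env.get("REDIS_URL", "redis://localhost:6379/0"),
            api_key=optional("API_KEY"),
            bearer_token=optional("BEARER_TOKEN"),
            mcp_server_url=optional("MCP_SERVER_URL"),
            port=int(env.get("PORT", "8000")),
        )


settings = Settings.from_env({"STORAGE_BACKEND": "redis", "API_KEY": "", "PORT": "9000"})
print(settings.storage_backend, settings.api_key, settings.port)
```

Passing a plain mapping instead of `os.environ` keeps the sketch easy to exercise without touching the real environment.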
docker-compose.yml (optional, for development)
yaml
version: "3.9"
services:
  redis:
    image: redis:7-alpine
    ports: ["6379:6379"]
  postgres:
    image: postgres:16-alpine
    environment:
      POSTGRES_PASSWORD: postgres
      POSTGRES_USER: postgres
      POSTGRES_DB: a2a
    ports: ["5432:5432"]
README.md (excerpt)
markdown
### Run
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
cp .env.example .env # uses the memory backend by default
# The code reads settings with os.getenv only, so pass the file to uvicorn explicitly
uvicorn app.main:app --reload --env-file .env
### cURL examples
# (1) Synchronous Q&A: message/send
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' -d '{
  "jsonrpc":"2.0","id":"1","method":"message/send",
  "params":{"message":{"kind":"message","messageId":"m1","role":"user","parts":[{"kind":"text","text":"What is Bailu (White Dew)?"}]}}
}' | jq
# (2) Streaming Q&A (SSE): message/stream (several success responses share the same id)
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"42","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m2","role":"user","parts":[{"kind":"text","text":"Tell me about the 24 solar terms"}]}}
}'
# (3) Cancel a task (replace TASK_ID with a task id returned by the server,
#     e.g. the id of the Task snapshot that opens a message/stream response)
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' -d '{
  "jsonrpc":"2.0","id":"3","method":"tasks/cancel","params":{"id":"TASK_ID"}
}' | jq
# (4) Fetch a task
aid="TASK_ID"
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' -d "{\"jsonrpc\":\"2.0\",\"id\":\"4\",\"method\":\"tasks/get\",\"params\":{\"id\":\"$aid\"}}" | jq
# (5) Resubscribe to the stream (replays stored events)
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"5","method":"tasks/resubscribe",
  "params":{"id":"TASK_ID","fromSeq":0}
}'
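
The streaming calls in (2) and (5) return a series of `data:` frames, each carrying one JSON-RPC success response with the same `id`. Below is a minimal sketch of how a client might split such a body into responses, assuming frames are separated by a blank line as the server code in this scaffold emits them. `parse_sse_frames` is a hypothetical helper and not part of the project.

```python
import json
from typing import Any, Dict, List


def parse_sse_frames(raw: str) -> List[Dict[str, Any]]:
    """Split an SSE body into JSON-RPC responses, one per `data:` frame."""
    responses: List[Dict[str, Any]] = []
    for frame in raw.split("\n\n"):
        # Collect the payload of every `data:` line; ignore other SSE fields.
        data_lines = [
            line[len("data:"):].lstrip()
            for line in frame.splitlines()
            if line.startswith("data:")
        ]
        if data_lines:
            responses.append(json.loads("\n".join(data_lines)))
    return responses


sample = (
    'data: {"jsonrpc": "2.0", "id": "42", "result": {"kind": "task", "id": "t1"}}\n\n'
    'data: {"jsonrpc": "2.0", "id": "42", "result": {"kind": "status-update", "final": true}}\n\n'
)
frames = parse_sse_frames(sample)
print([f["result"]["kind"] for f in frames])
```

A client would typically stop reading once it sees a `status-update` whose `final` flag is true.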
app/models.py
python
from __future__ import annotations
from typing import List, Optional, Literal, Dict, Any, Annotated, Union
from pydantic import BaseModel, Field


# ===== Parts (discriminated union on `kind`) =====
class TextPart(BaseModel):
    kind: Literal["text"]
    text: str
    metadata: Optional[Dict[str, Any]] = None


class DataPart(BaseModel):
    kind: Literal["data"]
    data: Dict[str, Any]
    metadata: Optional[Dict[str, Any]] = None


Part = Annotated[Union[TextPart, DataPart], Field(discriminator="kind")]


# ===== Message =====
class Message(BaseModel):
    kind: Literal["message"] = "message"
    messageId: str
    role: Literal["user", "agent"]
    parts: List[Part]
    contextId: Optional[str] = None
    taskId: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    extensions: Optional[List[str]] = None
    referenceTaskIds: Optional[List[str]] = None


# ===== Task / Status =====
class TaskStatus(BaseModel):
    state: Literal[
        "submitted", "working", "input-required", "completed", "canceled",
        "failed", "rejected", "auth-required", "unknown"
    ]
    timestamp: Optional[str] = None
    message: Optional[Message] = None


class Task(BaseModel):
    kind: Literal["task"] = "task"
    id: str
    contextId: str
    status: TaskStatus
    history: List[Message] = Field(default_factory=list)
    artifacts: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None


# ===== Events =====
class TaskStatusUpdateEvent(BaseModel):
    kind: Literal["status-update"] = "status-update"
    taskId: str
    contextId: str
    status: TaskStatus
    final: bool = False
    metadata: Optional[Dict[str, Any]] = None


class TaskArtifactUpdateEvent(BaseModel):
    kind: Literal["artifact-update"] = "artifact-update"
    taskId: str
    contextId: str
    artifact: Dict[str, Any]
    append: Optional[bool] = None
    lastChunk: Optional[bool] = None
    metadata: Optional[Dict[str, Any]] = None


Event = Annotated[
    Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent, Message],
    Field(discriminator="kind")
]


# ===== JSON-RPC =====
class JSONRPCRequest(BaseModel):
    jsonrpc: Literal["2.0"]
    method: str
    id: str | int | None = None
    params: Optional[Dict[str, Any]] = None


class JSONRPCSuccessResponse(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: str | int | None
    result: Any


class JSONRPCError(BaseModel):
    code: int
    message: str
    data: Optional[Any] = None


class JSONRPCErrorResponse(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: str | int | None
    error: JSONRPCError


# ===== Send params =====
class MessageSendConfiguration(BaseModel):
    acceptedOutputModes: Optional[List[str]] = None
    blocking: Optional[bool] = None
    historyLength: Optional[int] = None


class MessageSendParams(BaseModel):
    message: Message
    configuration: Optional[MessageSendConfiguration] = None
    metadata: Optional[Dict[str, Any]] = None


# ===== AgentCard (simplified) =====
class AgentCapabilities(BaseModel):
    streaming: Optional[bool] = True
    pushNotifications: Optional[bool] = False


class AgentInterface(BaseModel):
    transport: Literal["JSONRPC"] = "JSONRPC"
    url: str


class AgentCard(BaseModel):
    name: str
    description: str
    version: str
    protocolVersion: str = "0.3.0"
    url: str
    preferredTransport: Literal["JSONRPC"] = "JSONRPC"
    capabilities: AgentCapabilities
    defaultInputModes: list[str] = ["text/plain"]
    defaultOutputModes: list[str] = ["text/plain"]
    skills: list[Dict[str, Any]] = Field(default_factory=list)
    additionalInterfaces: list[AgentInterface] = Field(default_factory=list)
app/jsonrpc.py
python
from __future__ import annotations
from typing import Any, Dict
from datetime import datetime, timezone
from app.models import JSONRPCSuccessResponse, JSONRPCErrorResponse, JSONRPCError

# Error codes (a commonly used subset of A2A / JSON-RPC)
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
TASK_NOT_FOUND = -32001
TASK_NOT_CANCELABLE = -32002
UNSUPPORTED_OPERATION = -32004


def iso() -> str:
    """Current UTC time as an ISO 8601 string with second precision."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def ok(id_: str | int | None, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC success response as a plain dict."""
    return JSONRPCSuccessResponse(id=id_, result=result).model_dump()


def err(id_: str | int | None, code: int, message: str, data: Any | None = None) -> Dict[str, Any]:
    """Build a JSON-RPC error response as a plain dict."""
    return JSONRPCErrorResponse(id=id_, error=JSONRPCError(code=code, message=message, data=data)).model_dump()
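
To make the two envelope shapes concrete, here is a dependency-free sketch that builds them from plain dicts instead of the Pydantic models. `ok_envelope` and `err_envelope` are hypothetical names used only for illustration; unlike `err` above, the sketch leaves out `data` when it is `None`.

```python
from typing import Any, Dict, Union

RequestId = Union[str, int, None]
METHOD_NOT_FOUND = -32601  # same value as in app/jsonrpc.py


def ok_envelope(id_: RequestId, result: Any) -> Dict[str, Any]:
    """A success response carries `result` and never `error`."""
    return {"jsonrpc": "2.0", "id": id_, "result": result}


def err_envelope(id_: RequestId, code: int, message: str, data: Any = None) -> Dict[str, Any]:
    """An error response carries `error` and never `result`."""
    error: Dict[str, Any] = {"code": code, "message": message}
    if data is not None:
        error["data"] = data
    return {"jsonrpc": "2.0", "id": id_, "error": error}


success = ok_envelope("1", {"kind": "message"})
failure = err_envelope("2", METHOD_NOT_FOUND, "Method not found")
print(success)
print(failure)
```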
app/llm.py
python
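from __future__ import annotations
import asyncio
from typing import AsyncGenerator, Optional


# Replace this with the streaming interface of a real LLM SDK (OpenAI, Alibaba, Qwen, DeepSeek, etc.),
# for example by yielding the content fragment of each delta/choice.
async def stream_answer(prompt: str, *, delay: float = 0.25) -> AsyncGenerator[str, None]:
    # Demo only: fixed chunks plus a delay, which makes the SSE behavior easy to observe.
    # The prompt is ignored here.
    pieces = [
        "Bailu (White Dew) is one of the 24 solar terms; ",
        "it marks cooler weather and dew condensing overnight, ",
        "and usually falls in early September on the Gregorian calendar. ",
        "A folk saying goes: 'Once White Dew arrives, keep your body covered.'",
    ]
    for p in pieces:
        await asyncio.sleep(delay)
        yield p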
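
For readers new to async generators, the sketch below shows how a caller can drain a chunked stream like the one above into a single string. `fake_stream` is a stand-in defined here so the snippet runs on its own, and `collect` is a hypothetical helper name.

```python
import asyncio
from typing import AsyncGenerator, AsyncIterator, List, Sequence


async def fake_stream(pieces: Sequence[str], delay: float = 0.0) -> AsyncGenerator[str, None]:
    """Stand-in for an LLM stream: yields fixed chunks with an optional delay."""
    for piece in pieces:
        await asyncio.sleep(delay)
        yield piece


async def collect(stream: AsyncIterator[str]) -> str:
    """Consume the stream chunk by chunk and join the chunks at the end."""
    chunks: List[str] = []
    async for chunk in stream:
        chunks.append(chunk)
    return "".join(chunks)


answer = asyncio.run(collect(fake_stream(["a", "b"])))
print(answer)
```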
app/storage/base.py
python
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from app.models import Task, TaskStatus, Message, Event


@dataclass
class EventRecord:
    seq: int
    payload: Dict[str, Any]  # serialized Event/Message


class Storage(ABC):
    @abstractmethod
    async def create_task(self, first_message: Message) -> Task: ...

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional[Task]: ...

    @abstractmethod
    async def set_status(self, task_id: str, status: TaskStatus) -> None: ...

    @abstractmethod
    async def append_history(self, task_id: str, msg: Message) -> None: ...

    @abstractmethod
    async def cancel(self, task_id: str) -> Optional[Task]: ...

    @abstractmethod
    async def is_canceled(self, task_id: str) -> bool: ...

    @abstractmethod
    async def append_event(self, task_id: str, payload: Dict[str, Any]) -> int: ...

    # Declared without `async`: implementations are async generator functions,
    # so calling them returns the generator directly and callers use `async for`.
    @abstractmethod
    def iter_events(self, task_id: str, from_seq: int = 0) -> AsyncGenerator[EventRecord, None]: ...
app/storage/memory.py
python
from __future__ import annotations
import uuid
from typing import Dict, Any, AsyncGenerator, Optional, List
from datetime import datetime, timezone
from app.storage.base import Storage, EventRecord
from app.models import Task, TaskStatus, Message


def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


class MemoryStorage(Storage):
    def __init__(self) -> None:
        self.tasks: Dict[str, Task] = {}
        self.cancel_flags: Dict[str, bool] = {}
        self.events: Dict[str, List[Dict[str, Any]]] = {}

    async def create_task(self, first_message: Message) -> Task:
        task_id = str(uuid.uuid4())
        ctx_id = str(uuid.uuid4())
        t = Task(id=task_id, contextId=ctx_id,
                 status=TaskStatus(state="submitted", timestamp=_iso()),
                 history=[first_message])
        self.tasks[task_id] = t
        self.cancel_flags[task_id] = False
        self.events[task_id] = []
        return t

    async def get_task(self, task_id: str) -> Optional[Task]:
        return self.tasks.get(task_id)

    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        t = self.tasks[task_id]
        t.status = status

    async def append_history(self, task_id: str, msg: Message) -> None:
        t = self.tasks[task_id]
        t.history.append(msg)

    async def cancel(self, task_id: str) -> Optional[Task]:
        t = self.tasks.get(task_id)
        if not t:
            return None
        self.cancel_flags[task_id] = True
        t.status = TaskStatus(state="canceled", timestamp=_iso())
        return t

    async def is_canceled(self, task_id: str) -> bool:
        return self.cancel_flags.get(task_id, False)

    async def append_event(self, task_id: str, payload: Dict[str, Any]) -> int:
        lst = self.events[task_id]
        lst.append(payload)
        return len(lst)  # seq is 1-based

    async def iter_events(self, task_id: str, from_seq: int = 0):
        # from_seq is the last seq already seen; replay starts at from_seq + 1.
        lst = self.events.get(task_id, [])
        for i, p in enumerate(lst[from_seq:], start=from_seq + 1):
            yield EventRecord(seq=i, payload=p)
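
The sequence-number convention is easy to get wrong, so here is a small standalone sketch of the same idea: `append` returns a 1-based sequence number, and replaying with `from_seq=k` yields only events whose sequence number is greater than `k`. `EventLog` is a hypothetical stand-in for the per-task list kept by `MemoryStorage`.

```python
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Tuple


class EventLog:
    """Append-only log of events for a single task."""

    def __init__(self) -> None:
        self._events: List[Dict[str, Any]] = []

    async def append(self, payload: Dict[str, Any]) -> int:
        self._events.append(payload)
        return len(self._events)  # 1-based sequence number of the new event

    async def replay(self, from_seq: int = 0) -> AsyncGenerator[Tuple[int, Dict[str, Any]], None]:
        # Skip the first `from_seq` events and number the rest from from_seq + 1.
        for seq, payload in enumerate(self._events[from_seq:], start=from_seq + 1):
            yield seq, payload


async def demo() -> List[int]:
    log = EventLog()
    for kind in ("status-update", "message", "status-update"):
        await log.append({"kind": kind})
    # The client has already seen event 1, so it asks for everything after it.
    return [seq async for seq, _ in log.replay(from_seq=1)]


replayed = asyncio.run(demo())
print(replayed)  # [2, 3]
```

A client that remembers the last sequence number it received can pass that number as `fromSeq` to `tasks/resubscribe` and pick up exactly where it left off.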
app/storage/redis.py
python
from __future__ import annotations
import json
import uuid
from typing import Any, AsyncGenerator, Optional
from datetime import datetime, timezone
import redis.asyncio as redis
from app.storage.base import Storage, EventRecord
from app.models import Task, TaskStatus, Message


def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


class RedisStorage(Storage):
    def __init__(self, url: str):
        self.r = redis.from_url(url, decode_responses=True)

    def _k_task(self, task_id: str) -> str:
        return f"a2a:task:{task_id}"

    def _k_events(self, task_id: str) -> str:
        return f"a2a:events:{task_id}"

    def _k_cancel(self, task_id: str) -> str:
        return f"a2a:cancel:{task_id}"

    async def create_task(self, first_message: Message) -> Task:
        task_id = str(uuid.uuid4())
        ctx_id = str(uuid.uuid4())
        t = Task(id=task_id, contextId=ctx_id,
                 status=TaskStatus(state="submitted", timestamp=_iso()),
                 history=[first_message])
        await self.r.hset(self._k_task(task_id), mapping={
            "task": t.model_dump_json()
        })
        await self.r.set(self._k_cancel(task_id), "0")
        return t

    async def get_task(self, task_id: str) -> Optional[Task]:
        d = await self.r.hget(self._k_task(task_id), "task")
        if not d:
            return None
        return Task.model_validate_json(d)

    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        # Read-modify-write without locking; adequate for a demo.
        t = await self.get_task(task_id)
        if not t:
            return
        t.status = status
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})

    async def append_history(self, task_id: str, msg: Message) -> None:
        t = await self.get_task(task_id)
        if not t:
            return
        t.history.append(msg)
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})

    async def cancel(self, task_id: str) -> Optional[Task]:
        t = await self.get_task(task_id)
        if not t:
            return None
        await self.r.set(self._k_cancel(task_id), "1")
        t.status = TaskStatus(state="canceled", timestamp=_iso())
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})
        return t

    async def is_canceled(self, task_id: str) -> bool:
        v = await self.r.get(self._k_cancel(task_id))
        return v == "1"

    async def append_event(self, task_id: str, payload: dict) -> int:
        # Events live in a Redis list; the new list length serves as the seq.
        return await self.r.rpush(self._k_events(task_id), json.dumps(payload))

    async def iter_events(self, task_id: str, from_seq: int = 0) -> AsyncGenerator[EventRecord, None]:
        # Reads the rest of the list in one call (fine for development);
        # production could switch to Redis Streams (XREAD).
        items = await self.r.lrange(self._k_events(task_id), from_seq, -1)
        seq = from_seq
        for raw in items:
            seq += 1
            yield EventRecord(seq=seq, payload=json.loads(raw))
app/storage/pg.py
python
# Note: a minimal async SQLAlchemy connection plus placeholders, so that a real
# persistence layer can be added later.
from __future__ import annotations
from typing import Optional, AsyncGenerator, Dict
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Text
from app.storage.base import Storage, EventRecord
from app.models import Task, TaskStatus, Message


class Base(DeclarativeBase):
    pass


class TaskRow(Base):
    __tablename__ = "tasks"
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    blob: Mapped[str] = mapped_column(Text)  # stores the whole Task as JSON (demo)


# For brevity the pg backend is only an example: every method raises
# NotImplementedError and can be filled in as needed.
class PostgresStorage(Storage):
    def __init__(self, dsn: str):
        self.engine = create_async_engine(dsn, echo=False)
        self.Session = async_sessionmaker(self.engine, expire_on_commit=False)

    async def create_task(self, first_message: Message) -> Task:
        raise NotImplementedError("Fill with real implementation")

    async def get_task(self, task_id: str) -> Optional[Task]:
        raise NotImplementedError()

    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        raise NotImplementedError()

    async def append_history(self, task_id: str, msg: Message) -> None:
        raise NotImplementedError()

    async def cancel(self, task_id: str) -> Optional[Task]:
        raise NotImplementedError()

    async def is_canceled(self, task_id: str) -> bool:
        raise NotImplementedError()

    async def append_event(self, task_id: str, payload: Dict) -> int:
        raise NotImplementedError()

    async def iter_events(self, task_id: str, from_seq: int = 0):
        raise NotImplementedError()
app/tasks.py
python
from __future__ import annotations
import uuid
from typing import AsyncGenerator, Dict, Optional, List
from datetime import datetime, timezone
from app.models import (
    Message, TextPart, Task, TaskStatus,
    TaskStatusUpdateEvent
)
from app.llm import stream_answer
from app.storage.base import Storage


# Added: interface of the MCP memory client (passed in from main.py)
class MemoryClientProto:
    async def append(self, session_id: str, role: str, text: str, max_items: int = 200, ttl_s: int = 3600): ...
    async def recent(self, session_id: str, n: int = 10) -> List[dict]: ...


def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")


def _extract_text(msg: Message) -> str:
    """Return the text of the first TextPart, or an empty string."""
    parts = [p for p in msg.parts if isinstance(p, TextPart)]
    return parts[0].text if parts else ""


class TaskManager:
    def __init__(self, store: Storage, memory: Optional[MemoryClientProto] = None) -> None:
        self.store = store
        self.memory = memory

    async def create_from_message(self, msg: Message) -> Task:
        task = await self.store.create_task(msg)
        # Write the user input to short-term memory (sessions are isolated by contextId)
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="user", text=_extract_text(msg))
            except Exception:
                pass  # memory is best-effort; never fail the request because of it
        return task

    async def run_stream(self, task: Task) -> AsyncGenerator[Dict, None]:
        # Enter the working state
        await self.store.set_status(task.id, TaskStatus(state="working", timestamp=_iso()))
        evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                    status=TaskStatus(state="working", timestamp=_iso()))
        await self.store.append_event(task.id, evt.model_dump())
        yield evt.model_dump()

        # Take the user's question
        user_msg = task.history[-1]
        user_text = _extract_text(user_msg)

        # Fetch the most recent memory entries and prepend them to the prompt
        memory_snippets: List[dict] = []
        if self.memory:
            try:
                memory_snippets = await self.memory.recent(task.contextId, n=10)
            except Exception:
                memory_snippets = []
        if memory_snippets:
            mem_ctx = "\n".join(f"{it.get('role')}: {it.get('text')}" for it in memory_snippets)
            prompt = f"[Recent conversation memory]\n{mem_ctx}\n\n[Current question]\n{user_text}"
        else:
            prompt = user_text

        # Streaming generation
        accum = []
        async for ch in stream_answer(prompt):
            if await self.store.is_canceled(task.id):
                # Emit the final canceled status and stop
                await self.store.set_status(task.id, TaskStatus(state="canceled", timestamp=_iso()))
                final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                                  status=TaskStatus(state="canceled", timestamp=_iso()), final=True)
                await self.store.append_event(task.id, final_evt.model_dump())
                yield final_evt.model_dump()
                return
            accum.append(ch)

        # Final answer
        answer = "".join(accum)
        agent_msg = Message(kind="message", messageId=str(uuid.uuid4()), role="agent",
                            parts=[TextPart(kind="text", text=answer)],
                            contextId=task.contextId, taskId=task.id)
        await self.store.append_history(task.id, agent_msg)
        await self.store.append_event(task.id, agent_msg.model_dump())
        yield agent_msg.model_dump()

        # Write the agent reply to short-term memory
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="agent", text=answer)
            except Exception:
                pass

        # Completed status
        await self.store.set_status(task.id, TaskStatus(state="completed", timestamp=_iso()))
        final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                          status=TaskStatus(state="completed", timestamp=_iso()), final=True)
        await self.store.append_event(task.id, final_evt.model_dump())
        yield final_evt.model_dump()
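
`app/jsonrpc.py` defines `TASK_NOT_CANCELABLE`, but nothing in the scaffold uses it yet: a finished task can still be flipped to `canceled`. One possible guard is sketched below, under the assumption that `completed`, `canceled`, `failed` and `rejected` are treated as terminal states. `TERMINAL_STATES`, `can_cancel` and `cancel_error_code` are hypothetical names.

```python
from typing import Optional

TASK_NOT_CANCELABLE = -32002  # same value as in app/jsonrpc.py

# Assumption: these states are final and must not be overwritten by a cancel.
TERMINAL_STATES = frozenset({"completed", "canceled", "failed", "rejected"})


def can_cancel(state: str) -> bool:
    """A task is cancelable only while it has not reached a terminal state."""
    return state not in TERMINAL_STATES


def cancel_error_code(state: str) -> Optional[int]:
    """Return the JSON-RPC error code to send, or None when cancel may proceed."""
    return None if can_cancel(state) else TASK_NOT_CANCELABLE


print(cancel_error_code("working"), cancel_error_code("completed"))
```

A `tasks/cancel` handler could call `cancel_error_code(task.status.state)` before invoking `store.cancel`.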
app/security/auth.py
python
from __future__ import annotations
import os
from fastapi import Header, HTTPException


async def require_auth(x_api_key: str | None = Header(default=None), authorization: str | None = Header(default=None)):
    """Check X-API-Key and/or Bearer token; each check runs only if its env var is set."""
    api_key = os.getenv("API_KEY")
    bearer = os.getenv("BEARER_TOKEN")
    if api_key:
        if x_api_key != api_key:
            raise HTTPException(status_code=401, detail="invalid api key")
    if bearer:
        token = None
        if authorization and authorization.lower().startswith("bearer "):
            token = authorization[7:].strip()
        if token != bearer:
            raise HTTPException(status_code=401, detail="invalid bearer token")
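
The hook above compares secrets with `!=`. A common hardening step is a constant-time comparison via the standard library's `hmac.compare_digest`. The sketch below shows the Bearer case only; `bearer_matches` is a hypothetical helper, not part of the scaffold.

```python
import hmac
from typing import Optional


def bearer_matches(authorization: Optional[str], expected: str) -> bool:
    """Check an Authorization header against the expected token in constant time."""
    if not authorization or not authorization.lower().startswith("bearer "):
        return False
    token = authorization[len("bearer "):].strip()
    # Compare bytes so that non-ASCII input cannot raise a TypeError.
    return hmac.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))


print(bearer_matches("Bearer s3cret", "s3cret"))
print(bearer_matches("Bearer wrong", "s3cret"))
```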
app/main.py
python
from __future__ import annotations
import os, json
from typing import AsyncGenerator
from fastapi import FastAPI, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from app.models import JSONRPCRequest, MessageSendParams, AgentCard, AgentCapabilities
from app.jsonrpc import ok, err, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, TASK_NOT_FOUND
from app.security.auth import require_auth
from app.storage.base import Storage
from app.storage.memory import MemoryStorage
from app.storage.redis import RedisStorage
# from app.storage.pg import PostgresStorage
from app.tasks import TaskManager

app = FastAPI(title="A2A JSON-RPC Q&A Agent")

# Select the storage backend
backend = os.getenv("STORAGE_BACKEND", "memory")
if backend == "redis":
    store: Storage = RedisStorage(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
elif backend == "pg":
    # store = PostgresStorage(os.getenv("PG_DSN", "postgresql+asyncpg://postgres:postgres@localhost:5432/a2a"))
    store = MemoryStorage()  # placeholder: not implemented yet, fall back to memory
else:
    store = MemoryStorage()

tm = TaskManager(store)


@app.post("/rpc")
async def rpc(request: Request, _=Depends(require_auth)):
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(err(None, PARSE_ERROR, "Invalid JSON payload"))
    try:
        rpc_req = JSONRPCRequest.model_validate(payload)
    except Exception as e:
        return JSONResponse(err(None, INVALID_REQUEST, "Invalid Request", str(e)))

    rid = rpc_req.id
    method = rpc_req.method
    params = rpc_req.params or {}

    if method == "message/send":
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))
        task = await tm.create_from_message(p.message)
        # Synchronous: run to the end and return the final message
        last_msg = None
        async for item in tm.run_stream(task):
            if item.get("kind") == "message":
                last_msg = item
        return JSONResponse(ok(rid, last_msg or task.model_dump()))

    # Dispatch on the method name only. Also matching on the Accept header would
    # capture tasks/resubscribe requests that send Accept: text/event-stream.
    if method == "message/stream":
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))
        task = await tm.create_from_message(p.message)

        async def gen() -> AsyncGenerator[bytes, None]:
            # Send the Task snapshot first
            first = {"jsonrpc": "2.0", "id": rid, "result": task.model_dump()}
            yield ("data: " + json.dumps(first, ensure_ascii=False) + "\n\n").encode()
            async for item in tm.run_stream(task):
                out = {"jsonrpc": "2.0", "id": rid, "result": item}
                yield ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()

        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "tasks/get":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        t = await store.get_task(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/cancel":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        t = await store.cancel(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/resubscribe":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        try:
            from_seq = int(params.get("fromSeq", 0))
        except (TypeError, ValueError):
            return JSONResponse(err(rid, INVALID_PARAMS, "'fromSeq' must be an integer"))

        # SSE replay of stored events
        async def gen() -> AsyncGenerator[bytes, None]:
            async for rec in store.iter_events(tid, from_seq=from_seq):
                out = {"jsonrpc": "2.0", "id": rid, "result": rec.payload}
                yield ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()

        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "agent/getAuthenticatedExtendedCard":
        card = AgentCard(
            name="Q&A Agent",
            description="A2A JSON-RPC demo agent for Q&A",
            version="1.0.0",
            url="http://localhost:8000/rpc",
            capabilities=AgentCapabilities(streaming=True, pushNotifications=False),
            skills=[{"id": "qa", "name": "Question Answering", "tags": ["qa", "chat"], "description": "answer general questions"}],
            additionalInterfaces=[],
        )
        return JSONResponse(ok(rid, card.model_dump()))

    return JSONResponse(err(rid, METHOD_NOT_FOUND, "Method not found"))
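
Both streaming branches build their frames the same way: a JSON-RPC success response prefixed with `data: ` and terminated by a blank line. The sketch below pulls that step into one function to make the wire format explicit; `sse_frame` is a hypothetical helper name.

```python
import json
from typing import Any, Union


def sse_frame(rid: Union[str, int, None], result: Any) -> bytes:
    """Encode one JSON-RPC success response as a single SSE `data:` frame."""
    body = {"jsonrpc": "2.0", "id": rid, "result": result}
    # ensure_ascii=False keeps non-ASCII text readable on the wire.
    return ("data: " + json.dumps(body, ensure_ascii=False) + "\n\n").encode("utf-8")


frame = sse_frame("42", {"kind": "status-update", "final": True})
print(frame)
```

Because `json.dumps` never emits a raw newline, each response stays on a single `data:` line.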
---
## app/mcp.py
```python
from __future__ import annotations
import json
from typing import Any, Optional, List

try:
    from fastmcp import Client
except Exception:  # placeholder when fastmcp is not installed
    Client = None


def _result_text(res: Any) -> Optional[str]:
    """Best-effort extraction of the first text block from a tool result.

    Assumption: depending on the fastmcp version, `call_tool` returns either a
    result object with a `content` list or the list of content blocks itself.
    """
    blocks = getattr(res, "content", res)
    if isinstance(blocks, (list, tuple)):
        for block in blocks:
            text = getattr(block, "text", None)
            if text is not None:
                return text
    return getattr(res, "text", None)


class MemoryClient:
    """Short-term memory client built on fastmcp 2.0.

    Uses the tools exposed by the MCP server:
    mem_put / mem_get / mem_append / mem_recent / mem_clear
    """

    def __init__(self, url: Optional[str]) -> None:
        self.url = url

    async def _call(self, name: str, params: dict):
        if not self.url or Client is None:
            raise RuntimeError("MCP client unavailable")
        async with Client(self.url) as c:
            return await c.call_tool(name, params)

    async def append(self, session_id: str, role: str, text: str, max_items: int = 200, ttl_s: int = 3600):
        await self._call("mem_append", {"session_id": session_id, "role": role, "text": text,
                                        "max_items": max_items, "ttl_s": ttl_s})

    async def recent(self, session_id: str, n: int = 10) -> List[dict]:
        res = await self._call("mem_recent", {"session_id": session_id, "n": n})
        try:
            return json.loads(_result_text(res) or "[]")
        except Exception:
            return []

    async def put(self, session_id: str, key: str, value: str, ttl_s: int = 3600):
        await self._call("mem_put", {"session_id": session_id, "key": key, "value": value, "ttl_s": ttl_s})

    async def get(self, session_id: str, key: str, default: Optional[str] = None, ttl_s: int = 3600) -> Optional[str]:
        res = await self._call("mem_get", {"session_id": session_id, "key": key, "default": default, "ttl_s": ttl_s})
        return _result_text(res)

    async def clear(self, session_id: str):
        await self._call("mem_clear", {"session_id": session_id})


class NullMemory(MemoryClient):
    """No-op implementation, used when MCP_SERVER_URL is not configured or fastmcp is not installed."""

    def __init__(self):
        super().__init__(url=None)

    async def _call(self, *a, **kw):
        return None

    async def append(self, *a, **kw):
        return None

    async def recent(self, *a, **kw):
        return []

    async def put(self, *a, **kw):
        return None

    async def get(self, *a, **kw):
        return None

    async def clear(self, *a, **kw):
        return None
```
UPDATED app/main.py (initialize the MCP memory client and inject it into TaskManager)
python
from __future__ import annotations
import os, json
from typing import AsyncGenerator
from fastapi import FastAPI, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from app.models import JSONRPCRequest, MessageSendParams, AgentCard, AgentCapabilities
from app.jsonrpc import ok, err, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, TASK_NOT_FOUND
from app.security.auth import require_auth
from app.storage.base import Storage
from app.storage.memory import MemoryStorage
from app.storage.redis import RedisStorage
# from app.storage.pg import PostgresStorage
from app.tasks import TaskManager
# === MCP memory client ===
from app.mcp import MemoryClient, NullMemory

app = FastAPI(title="A2A JSON-RPC Q&A Agent")

# Select the storage backend
backend = os.getenv("STORAGE_BACKEND", "memory")
if backend == "redis":
    store: Storage = RedisStorage(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
elif backend == "pg":
    # store = PostgresStorage(os.getenv("PG_DSN", "postgresql+asyncpg://postgres:postgres@localhost:5432/a2a"))
    store = MemoryStorage()  # placeholder: not implemented yet, fall back to memory
else:
    store = MemoryStorage()

# MCP short-term memory (optional)
mcp_url = os.getenv("MCP_SERVER_URL")
memory = MemoryClient(mcp_url) if mcp_url else NullMemory()
tm = TaskManager(store, memory)

# The /rpc handler is the same as in the app/main.py listing above.
if not tid:
return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
# SSE 回放历史事件
async def gen() -> AsyncGenerator[bytes, None]:
async for rec in store.iter_events(tid, from_seq=from_seq):
out = {"jsonrpc": "2.0", "id": rid, "result": rec.payload}
yield ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()
return StreamingResponse(gen(), media_type="text/event-stream")
if method == "agent/getAuthenticatedExtendedCard":
card = AgentCard(
name="Q&A Agent",
description="A2A JSON-RPC demo agent for Q&A (with MCP memory)",
version="1.1.0",
url="http://localhost:8000/rpc",
capabilities=AgentCapabilities(streaming=True, pushNotifications=False),
skills=[{"id":"qa","name":"Question Answering","tags":["qa","chat"],"description":"answer general questions"}],
additionalInterfaces=[],
)
return JSONResponse(ok(rid, card.model_dump()))
return JSONResponse(err(rid, METHOD_NOT_FOUND, "Method not found"))
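
The streaming branches in `app/main.py` send each JSON-RPC response as one SSE `data:` frame terminated by a blank line. As a minimal sketch of that wire format, the snippet below builds frames the same way the handler does and parses them back, which is roughly what a client consuming `message/stream` or `tasks/resubscribe` has to do. The helper names `sse_frame` and `parse_sse` and the sample payloads are illustrative and not part of the scaffold.

```python
import json
from typing import Any, Iterator


def sse_frame(rid: Any, result: dict) -> bytes:
    """Encode one JSON-RPC success response as a single SSE data frame."""
    out = {"jsonrpc": "2.0", "id": rid, "result": result}
    return ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()


def parse_sse(raw: bytes) -> Iterator[dict]:
    """Yield the JSON-RPC objects carried by a buffer of complete SSE frames."""
    for block in raw.decode().split("\n\n"):
        # A frame may hold several "data:" lines; they are joined with newlines.
        data_lines = [ln[5:].lstrip(" ") for ln in block.split("\n") if ln.startswith("data:")]
        if data_lines:
            yield json.loads("\n".join(data_lines))


# Two made-up events: a Task snapshot followed by a final status update.
stream = sse_frame("s1", {"kind": "task", "id": "t1"}) + sse_frame(
    "s1", {"kind": "status-update", "taskId": "t1", "final": True}
)
frames = list(parse_sse(stream))
for frame in frames:
    print(frame["id"], frame["result"]["kind"])
```

Every frame repeats the request `id`, so a client can correlate all events of one stream with the call that opened it and stop reading once it sees `final: true`.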
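I've integrated the **fastmcp 2.0 "short-term memory (Redis)"** service into your existing A2A × FastAPI × JSON-RPC scaffold: a new `app/mcp.py`, a reworked `app/tasks.py` (reads and writes session memory) and `app/main.py` (injects `MemoryClient`), plus updated `requirements.txt` and `.env.example`. The full code is in the sections marked new/UPDATED.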
Quick local run:

1. Start Redis and the MCP short-term memory service

```bash
docker run -d --name redis -p 6379:6379 redis:7-alpine
# In another terminal, start the fastmcp Redis memory server (server.py) you got earlier
# Make sure it is serving at http://127.0.0.1:8010/mcp
python server.py
```

2. Configure and start the agent

```bash
# In the project root
cp .env.example .env
# Open .env and confirm:
# MCP_SERVER_URL=http://127.0.0.1:8010/mcp
pip install -r requirements.txt
uvicorn app.main:app --reload
```
3. Try a streaming conversation with short-term memory in effect (A2A)

```bash
# The first message is written to memory
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"s1","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m1","role":"user",
    "parts":[{"kind":"text","text":"My name is Xiao Wang, please remember me"}]}}
}'
# When the second reply is generated, recent memory is spliced into the prompt
# (the answer should show more awareness of the context)
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"s2","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m2","role":"user",
    "parts":[{"kind":"text","text":"Remind me, what is my name?"}]}}
}'
```
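
Whether the second request sees the first one's memory depends on both sharing a `contextId`. Neither curl call sets one, so if `create_from_message` assigns a fresh `contextId` to every message that lacks one (an assumption, since that code is not shown here), the follow-up has to echo the `contextId` from the first response. A minimal sketch of building such a follow-up is below. `build_stream_request` and `context_id_from` are hypothetical helpers, the sample frame is made up, and the sketch assumes the scaffold's `Message` model accepts the optional A2A `contextId` field.

```python
import json
from typing import Optional


def build_stream_request(text: str, message_id: str, rpc_id: str,
                         context_id: Optional[str] = None) -> dict:
    """Build a message/stream JSON-RPC payload, optionally pinned to a context."""
    message = {"kind": "message", "messageId": message_id, "role": "user",
               "parts": [{"kind": "text", "text": text}]}
    if context_id:
        message["contextId"] = context_id  # continue the same memory session
    return {"jsonrpc": "2.0", "id": rpc_id, "method": "message/stream",
            "params": {"message": message}}


def context_id_from(sse_line: str) -> Optional[str]:
    """Return the contextId carried by one 'data: {...}' line, if any."""
    if not sse_line.startswith("data:"):
        return None
    result = json.loads(sse_line[5:]).get("result", {})
    return result.get("contextId")


# Made-up first frame of the first stream (the Task snapshot).
first_line = ('data: {"jsonrpc":"2.0","id":"s1",'
              '"result":{"kind":"task","id":"t1","contextId":"ctx-1"}}')
ctx = context_id_from(first_line)
follow_up = build_stream_request("Remind me, what is my name?", "m2", "s2", context_id=ctx)
print(json.dumps(follow_up, ensure_ascii=False, indent=2))
```

The printed payload can be passed to `curl -d` in place of the second request body above.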
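Key points:
- I pass the A2A `contextId` to the fastmcp memory tools as `session_id`, so sessions are isolated automatically;
- every user input and agent output is stored with `mem_append`, and `mem_recent` is fetched and spliced into the prompt before generation;
- when `MCP_SERVER_URL` is not set, memory degrades to a no-op implementation, so the agent still runs.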
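
To make the flow in these points concrete, here is a minimal in-process sketch. It is not the scaffold's `app/mcp.py`: `InProcessMemory` stands in for the fastmcp-backed client, and the `recent` method, its `limit` parameter, and `build_prompt` are assumed names. Only `append(context_id, role=..., text=...)` mirrors the call made in `app/tasks.py`.

```python
import asyncio
from collections import deque


class NullMemory:
    """No-op fallback used when no memory server is configured."""

    async def append(self, context_id, *, role, text):
        return None

    async def recent(self, context_id, limit=6):
        return []


class InProcessMemory:
    """Illustrative stand-in for the MCP client: one bounded log per context."""

    def __init__(self, max_turns=20):
        self._max_turns = max_turns
        self._turns = {}

    async def append(self, context_id, *, role, text):
        log = self._turns.setdefault(context_id, deque(maxlen=self._max_turns))
        log.append({"role": role, "text": text})

    async def recent(self, context_id, limit=6):
        return list(self._turns.get(context_id, ()))[-limit:]


async def build_prompt(memory, context_id, question):
    """Prefix the question with recent turns; fall back to the bare question."""
    try:
        turns = await memory.recent(context_id)
    except Exception:
        turns = []  # memory is best-effort
    history = "\n".join(f"{t['role']}: {t['text']}" for t in turns)
    return f"{history}\nuser: {question}" if history else f"user: {question}"


async def demo():
    mem = InProcessMemory()
    await mem.append("ctx-1", role="user", text="My name is Xiao Wang")
    await mem.append("ctx-1", role="agent", text="Nice to meet you, Xiao Wang")
    same = await build_prompt(mem, "ctx-1", "What is my name?")
    other = await build_prompt(mem, "ctx-2", "What is my name?")
    null = await build_prompt(NullMemory(), "ctx-1", "What is my name?")
    return same, other, null


same_session, other_session, degraded = asyncio.run(demo())
print(same_session)
```

Only the `ctx-1` prompt carries the earlier turns. A different `contextId` and the `NullMemory` fallback both yield the bare question, which matches the isolation and degradation behaviour described above.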