# A2A × FastAPI × JSON‑RPC Q&A Agent · Project Scaffold

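> A ready-to-run project skeleton built on **FastAPI + Pydantic v2 + JSON‑RPC 2.0** that implements the core methods of the A2A specification:
> - `message/send` (synchronous short Q&A)
> - `message/stream` (SSE; several success responses share one request id)
> - `tasks/get`, `tasks/cancel`, `tasks/resubscribe` (resume after a disconnect by replaying stored events)
> - `agent/getAuthenticatedExtendedCard` (the agent's self-description)
>
> The storage backend can be switched between **memory / redis / pg**. The default, memory, runs with no external services; pg is currently a placeholder that falls back to memory.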

---

## Directory structure

    a2a-fastapi/
    ├─ README.md
    ├─ requirements.txt
    ├─ .env.example
    ├─ docker-compose.yml   # optional: Redis + Postgres (development)
    └─ app/
       ├─ main.py           # routes and JSON-RPC dispatch
       ├─ models.py         # A2A and JSON-RPC Pydantic models
       ├─ jsonrpc.py        # success/error response helpers + error codes
       ├─ llm.py            # model calls (simulated streaming; swap in a real SDK)
       ├─ tasks.py          # TaskManager: state machine and streaming execution
       ├─ security/
       │  └─ auth.py        # API key / Bearer validation hook
       └─ storage/
          ├─ base.py        # abstract interface (Storage)
          ├─ memory.py      # in-memory implementation (default)
          ├─ redis.py       # Redis implementation (optional)
          └─ pg.py          # Postgres implementation (example/placeholder)

---

## requirements.txt
```txt
fastapi==0.112.2
uvicorn[standard]==0.30.6
pydantic==2.8.2
redis==5.0.7
SQLAlchemy==2.0.32
asyncpg==0.29.0
python-dotenv==1.0.1
fastmcp>=2.0.0
```

## .env.example

```env
# Storage backend: memory | redis | pg
STORAGE_BACKEND=memory

# Redis (used when STORAGE_BACKEND=redis)
REDIS_URL=redis://localhost:6379/0

# Postgres (used when STORAGE_BACKEND=pg)
PG_DSN=postgresql+asyncpg://postgres:postgres@localhost:5432/a2a

# Security: checks are enabled only when a value is set (optional)
API_KEY=
BEARER_TOKEN=

# MCP short-term memory service (when the fastmcp server.py is exposed over HTTP/SSE)
MCP_SERVER_URL=http://127.0.0.1:8010/mcp

# Server
HOST=0.0.0.0
PORT=8000
```

## docker-compose.yml (optional, for development)

```yaml
version: "3.9"
services:
  redis:
    image: redis:7-alpine
    ports: ["6379:6379"]
  postgres:
    image: postgres:16-alpine
    environment:
      POSTGRES_PASSWORD: postgres
      POSTGRES_USER: postgres
      POSTGRES_DB: a2a
    ports: ["5432:5432"]
```

## README.md (excerpt)

### Run

```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
cp .env.example .env  # memory backend by default
uvicorn app.main:app --reload
```

### cURL examples

```bash
# (1) Synchronous Q&A: message/send
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' -d '{
  "jsonrpc":"2.0","id":"1","method":"message/send",
  "params":{"message":{"kind":"message","messageId":"m1","role":"user","parts":[{"kind":"text","text":"What is White Dew (Bailu)?"}]}}
}' | jq

# (2) Streaming Q&A (SSE): message/stream (several success responses with the same id)
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"42","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m2","role":"user","parts":[{"kind":"text","text":"Tell me about the 24 solar terms"}]}}
}'

# (3) Cancel a task (replace TASK_ID with the task id from the first message/stream event)
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' -d '{
  "jsonrpc":"2.0","id":"3","method":"tasks/cancel","params":{"id":"TASK_ID"}
}' | jq

# (4) Fetch a task (the JSON stays on one line: "\n" is not expanded inside a double-quoted shell string)
aid="TASK_ID"
curl -s http://localhost:8000/rpc -H 'Content-Type: application/json' \
  -d "{\"jsonrpc\":\"2.0\",\"id\":\"4\",\"method\":\"tasks/get\",\"params\":{\"id\":\"$aid\"}}" | jq

# (5) Resubscribe to a stream (replays stored events)
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"5","method":"tasks/resubscribe",
  "params":{"id":"TASK_ID","fromSeq":0}
}'
```
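The same requests can be driven from Python. The sketch below uses only the standard library and stops short of the HTTP call itself: it builds the JSON-RPC envelope used in the cURL examples and pulls each `result` out of an SSE body. `build_send_request` and `iter_sse_results` are hypothetical helper names, and the parser assumes the server writes every event as a single `data:` line, which is what this scaffold does.

```python
import json
from typing import Any, Dict, Iterable, Iterator


def build_send_request(request_id: str, text: str, message_id: str,
                       method: str = "message/send") -> Dict[str, Any]:
    """Build the JSON-RPC 2.0 envelope for message/send or message/stream."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": method,
        "params": {
            "message": {
                "kind": "message",
                "messageId": message_id,
                "role": "user",
                "parts": [{"kind": "text", "text": text}],
            }
        },
    }


def iter_sse_results(lines: Iterable[str]) -> Iterator[Dict[str, Any]]:
    """Yield the `result` of every JSON-RPC response found in `data:` lines."""
    for line in lines:
        line = line.rstrip("\r\n")
        if not line.startswith("data:"):
            continue  # blank separator lines and other SSE fields are skipped
        envelope = json.loads(line[len("data:"):].strip())
        if "error" in envelope:
            raise RuntimeError(envelope["error"])
        yield envelope["result"]
```

Any HTTP client that can iterate over response lines can feed `iter_sse_results`; the first result of a `message/stream` call is the task snapshot, whose `id` is the `TASK_ID` needed by the other methods.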

## app/models.py

```python
from __future__ import annotations
from typing import List, Optional, Literal, Dict, Any, Annotated, Union
from pydantic import BaseModel, Field

# ===== Parts (discriminated union on `kind`) =====
class TextPart(BaseModel):
    kind: Literal["text"]
    text: str
    metadata: Optional[Dict[str, Any]] = None

class DataPart(BaseModel):
    kind: Literal["data"]
    data: Dict[str, Any]
    metadata: Optional[Dict[str, Any]] = None

Part = Annotated[Union[TextPart, DataPart], Field(discriminator="kind")]

# ===== Message =====
class Message(BaseModel):
    kind: Literal["message"] = "message"
    messageId: str
    role: Literal["user", "agent"]
    parts: List[Part]
    contextId: Optional[str] = None
    taskId: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    extensions: Optional[List[str]] = None
    referenceTaskIds: Optional[List[str]] = None

# ===== Task / Status =====
class TaskStatus(BaseModel):
    state: Literal[
        "submitted","working","input-required","completed","canceled",
        "failed","rejected","auth-required","unknown"
    ]
    timestamp: Optional[str] = None
    message: Optional[Message] = None

class Task(BaseModel):
    kind: Literal["task"] = "task"
    id: str
    contextId: str
    status: TaskStatus
    history: List[Message] = Field(default_factory=list)
    artifacts: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None

# ===== Events =====
class TaskStatusUpdateEvent(BaseModel):
    kind: Literal["status-update"] = "status-update"
    taskId: str
    contextId: str
    status: TaskStatus
    final: bool = False
    metadata: Optional[Dict[str, Any]] = None

class TaskArtifactUpdateEvent(BaseModel):
    kind: Literal["artifact-update"] = "artifact-update"
    taskId: str
    contextId: str
    artifact: Dict[str, Any]
    append: Optional[bool] = None
    lastChunk: Optional[bool] = None
    metadata: Optional[Dict[str, Any]] = None

Event = Annotated[
    Union[TaskStatusUpdateEvent, TaskArtifactUpdateEvent, Message],
    Field(discriminator="kind")
]

# ===== JSON-RPC =====
class JSONRPCRequest(BaseModel):
    jsonrpc: Literal["2.0"]
    method: str
    id: str | int | None = None
    params: Optional[Dict[str, Any]] = None

class JSONRPCSuccessResponse(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: str | int | None
    result: Any

class JSONRPCError(BaseModel):
    code: int
    message: str
    data: Optional[Any] = None

class JSONRPCErrorResponse(BaseModel):
    jsonrpc: Literal["2.0"] = "2.0"
    id: str | int | None
    error: JSONRPCError

# ===== Send params =====
class MessageSendConfiguration(BaseModel):
    acceptedOutputModes: Optional[List[str]] = None
    blocking: Optional[bool] = None
    historyLength: Optional[int] = None

class MessageSendParams(BaseModel):
    message: Message
    configuration: Optional[MessageSendConfiguration] = None
    metadata: Optional[Dict[str, Any]] = None

# ===== AgentCard (simplified) =====
class AgentCapabilities(BaseModel):
    streaming: Optional[bool] = True
    pushNotifications: Optional[bool] = False

class AgentInterface(BaseModel):
    transport: Literal["JSONRPC"] = "JSONRPC"
    url: str

class AgentCard(BaseModel):
    name: str
    description: str
    version: str
    protocolVersion: str = "0.3.0"
    url: str
    preferredTransport: Literal["JSONRPC"] = "JSONRPC"
    capabilities: AgentCapabilities
    defaultInputModes: list[str] = ["text/plain"]
    defaultOutputModes: list[str] = ["text/plain"]
    skills: list[Dict[str, Any]] = Field(default_factory=list)
    additionalInterfaces: list[AgentInterface] = Field(default_factory=list)
```
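The `Part` and `Event` unions in `models.py` are tagged by `kind`, so validation picks exactly one model per payload instead of trying each union member in turn. The stdlib-only sketch below illustrates that dispatch idea. `TextPartSketch`, `DataPartSketch`, and `parse_part` are hypothetical names, and this is an illustration of the concept rather than a description of how Pydantic implements discriminated unions internally.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Union


@dataclass
class TextPartSketch:
    text: str


@dataclass
class DataPartSketch:
    data: Dict[str, Any]


PartSketch = Union[TextPartSketch, DataPartSketch]

# One builder per tag value: the tag alone decides which class is built.
_BUILDERS: Dict[str, Callable[[Dict[str, Any]], PartSketch]] = {
    "text": lambda raw: TextPartSketch(text=raw["text"]),
    "data": lambda raw: DataPartSketch(data=raw["data"]),
}


def parse_part(raw: Dict[str, Any]) -> PartSketch:
    """Select the part class from raw['kind'] and build it."""
    kind = raw.get("kind")
    try:
        build = _BUILDERS[kind]
    except KeyError:
        raise ValueError(f"unknown part kind: {kind!r}") from None
    return build(raw)
```

An unknown or missing `kind` fails immediately with one clear error, which is the practical benefit of a tagged union over an untagged one.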

## app/jsonrpc.py

```python
from __future__ import annotations
from typing import Any, Dict
from datetime import datetime, timezone
from app.models import JSONRPCSuccessResponse, JSONRPCErrorResponse, JSONRPCError

# Error codes (a commonly used subset of JSON-RPC and A2A)
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
TASK_NOT_FOUND = -32001
TASK_NOT_CANCELABLE = -32002
UNSUPPORTED_OPERATION = -32004

def iso() -> str:
    """Current UTC time formatted as YYYY-MM-DDTHH:MM:SSZ."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

def ok(id_: str|int|None, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC success response as a plain dict."""
    return JSONRPCSuccessResponse(id=id_, result=result).model_dump()

def err(id_: str|int|None, code: int, message: str, data: Any|None=None) -> Dict[str, Any]:
    """Build a JSON-RPC error response as a plain dict."""
    return JSONRPCErrorResponse(id=id_, error=JSONRPCError(code=code, message=message, data=data)).model_dump()
```

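## app/llm.py

```python
from __future__ import annotations
import asyncio
from typing import AsyncGenerator

# Swap this for the streaming interface of a real LLM SDK (OpenAI, Alibaba Qwen, DeepSeek, ...),
# for example by yielding the content fragment of each delta/choice.

async def stream_answer(prompt: str, *, delay: float = 0.25) -> AsyncGenerator[str, None]:
    # Demo only: fixed chunks plus a delay so the SSE behaviour is easy to observe.
    # The prompt is ignored by this stub.
    pieces = [
        "White Dew (Bailu) is one of the 24 solar terms; ",
        "it marks cooling weather and the first condensing dew; ",
        "it usually falls in early September; ",
        "a folk saying goes: 'After White Dew, keep your skin covered.'",
    ]
    for p in pieces:
        await asyncio.sleep(delay)
        yield p
```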

## app/storage/base.py

```python
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional

from app.models import Task, TaskStatus, Message

@dataclass
class EventRecord:
    seq: int
    payload: Dict[str, Any]  # the serialized Event/Message

class Storage(ABC):
    @abstractmethod
    async def create_task(self, first_message: Message) -> Task: ...

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional[Task]: ...

    @abstractmethod
    async def set_status(self, task_id: str, status: TaskStatus) -> None: ...

    @abstractmethod
    async def append_history(self, task_id: str, msg: Message) -> None: ...

    @abstractmethod
    async def cancel(self, task_id: str) -> Optional[Task]: ...

    @abstractmethod
    async def is_canceled(self, task_id: str) -> bool: ...

    @abstractmethod
    async def append_event(self, task_id: str, payload: Dict[str, Any]) -> int: ...

    # Declared without `async`: implementations are async generator functions,
    # so calling them returns the generator directly and callers use `async for`.
    @abstractmethod
    def iter_events(self, task_id: str, from_seq: int = 0) -> AsyncGenerator[EventRecord, None]: ...
```

## app/storage/memory.py

```python
from __future__ import annotations
import uuid
from typing import Dict, Any, AsyncGenerator, Optional, List
from datetime import datetime, timezone

from app.storage.base import Storage, EventRecord
from app.models import Task, TaskStatus, Message

def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

class MemoryStorage(Storage):
    def __init__(self) -> None:
        self.tasks: Dict[str, Task] = {}
        self.cancel_flags: Dict[str, bool] = {}
        self.events: Dict[str, List[Dict[str, Any]]] = {}

    async def create_task(self, first_message: Message) -> Task:
        task_id = str(uuid.uuid4())
        # Reuse the caller's contextId when present so follow-up messages stay in
        # the same conversation; otherwise start a new context.
        ctx_id = first_message.contextId or str(uuid.uuid4())
        t = Task(id=task_id, contextId=ctx_id,
                 status=TaskStatus(state="submitted", timestamp=_iso()),
                 history=[first_message])
        self.tasks[task_id] = t
        self.cancel_flags[task_id] = False
        self.events[task_id] = []
        return t

    async def get_task(self, task_id: str) -> Optional[Task]:
        return self.tasks.get(task_id)

    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        t = self.tasks[task_id]
        t.status = status

    async def append_history(self, task_id: str, msg: Message) -> None:
        t = self.tasks[task_id]
        t.history.append(msg)

    async def cancel(self, task_id: str) -> Optional[Task]:
        t = self.tasks.get(task_id)
        if not t:
            return None
        self.cancel_flags[task_id] = True
        t.status = TaskStatus(state="canceled", timestamp=_iso())
        return t

    async def is_canceled(self, task_id: str) -> bool:
        return self.cancel_flags.get(task_id, False)

    async def append_event(self, task_id: str, payload: Dict[str, Any]) -> int:
        lst = self.events[task_id]
        lst.append(payload)
        return len(lst)  # seq is 1-based

    async def iter_events(self, task_id: str, from_seq: int = 0) -> AsyncGenerator[EventRecord, None]:
        # from_seq is the last seq the client already has; replay starts at from_seq + 1
        lst = self.events.get(task_id, [])
        for i, p in enumerate(lst[from_seq:], start=from_seq+1):
            yield EventRecord(seq=i, payload=p)
```
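The sequence numbering is easy to get off by one: `append_event` hands out 1-based `seq` values, while `from_seq` is used as a list offset, which makes it behave as "the last `seq` already received". A minimal sketch of that arithmetic follows, using a hypothetical `replay` helper that mirrors the slice in `iter_events` on a plain list.

```python
from typing import Any, Dict, List, Tuple


def replay(events: List[Dict[str, Any]], from_seq: int = 0) -> List[Tuple[int, Dict[str, Any]]]:
    """Return (seq, payload) pairs for every event after `from_seq` (seq is 1-based)."""
    return [(seq, payload)
            for seq, payload in enumerate(events[from_seq:], start=from_seq + 1)]


log = [{"kind": "status-update"}, {"kind": "message"}, {"kind": "status-update"}]
everything = replay(log, 0)    # seqs 1, 2, 3
after_second = replay(log, 2)  # only seq 3
```

A client that reconnects should therefore send the highest `seq` it has processed as `fromSeq`; sending `0` replays the whole log.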

## app/storage/redis.py

```python
from __future__ import annotations
import json
import uuid
from typing import AsyncGenerator, Optional
from datetime import datetime, timezone
import redis.asyncio as redis

from app.storage.base import Storage, EventRecord
from app.models import Task, TaskStatus, Message

def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

class RedisStorage(Storage):
    def __init__(self, url: str):
        self.r = redis.from_url(url, decode_responses=True)

    def _k_task(self, task_id: str) -> str: return f"a2a:task:{task_id}"
    def _k_events(self, task_id: str) -> str: return f"a2a:events:{task_id}"
    def _k_cancel(self, task_id: str) -> str: return f"a2a:cancel:{task_id}"

    async def create_task(self, first_message: Message) -> Task:
        task_id = str(uuid.uuid4())
        # Reuse the caller's contextId when present; otherwise start a new context.
        ctx_id = first_message.contextId or str(uuid.uuid4())
        t = Task(id=task_id, contextId=ctx_id,
                 status=TaskStatus(state="submitted", timestamp=_iso()),
                 history=[first_message])
        await self.r.hset(self._k_task(task_id), mapping={
            "task": t.model_dump_json()
        })
        await self.r.set(self._k_cancel(task_id), "0")
        return t

    async def get_task(self, task_id: str) -> Optional[Task]:
        d = await self.r.hget(self._k_task(task_id), "task")
        if not d:
            return None
        return Task.model_validate_json(d)

    # Note: the updates below are read-modify-write without locking, which is
    # acceptable for development but can lose writes under concurrency.
    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        t = await self.get_task(task_id)
        if not t:
            return
        t.status = status
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})

    async def append_history(self, task_id: str, msg: Message) -> None:
        t = await self.get_task(task_id)
        if not t:
            return
        t.history.append(msg)
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})

    async def cancel(self, task_id: str) -> Optional[Task]:
        t = await self.get_task(task_id)
        if not t:
            return None
        await self.r.set(self._k_cancel(task_id), "1")
        t.status = TaskStatus(state="canceled", timestamp=_iso())
        await self.r.hset(self._k_task(task_id), mapping={"task": t.model_dump_json()})
        return t

    async def is_canceled(self, task_id: str) -> bool:
        v = await self.r.get(self._k_cancel(task_id))
        return v == "1"

    async def append_event(self, task_id: str, payload: dict) -> int:
        # Events live in a Redis list; RPUSH returns the new length, which serves as the 1-based seq
        return await self.r.rpush(self._k_events(task_id), json.dumps(payload))

    async def iter_events(self, task_id: str, from_seq: int = 0) -> AsyncGenerator[EventRecord, None]:
        # Reads the rest of the list in one call (enough for development);
        # production code could switch to Redis Streams (XREAD)
        items = await self.r.lrange(self._k_events(task_id), from_seq, -1)
        seq = from_seq
        for raw in items:
            seq += 1
            yield EventRecord(seq=seq, payload=json.loads(raw))
```

## app/storage/pg.py

```python
# Note: a minimal async SQLAlchemy setup plus placeholders, meant to be extended into a real persistence layer.
from __future__ import annotations
from typing import Optional, Dict
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Text

from app.storage.base import Storage
from app.models import Task, TaskStatus, Message

class Base(DeclarativeBase):
    pass

class TaskRow(Base):
    __tablename__ = "tasks"
    id: Mapped[str] = mapped_column(String(64), primary_key=True)
    blob: Mapped[str] = mapped_column(Text)  # stores the whole Task as JSON (demo)

# For brevity the pg backend is only a stub: every method raises
# NotImplementedError until a real implementation is filled in.
class PostgresStorage(Storage):
    def __init__(self, dsn: str):
        self.engine = create_async_engine(dsn, echo=False)
        self.Session = async_sessionmaker(self.engine, expire_on_commit=False)

    async def create_task(self, first_message: Message) -> Task:
        raise NotImplementedError("Fill with real implementation")
    async def get_task(self, task_id: str) -> Optional[Task]:
        raise NotImplementedError()
    async def set_status(self, task_id: str, status: TaskStatus) -> None:
        raise NotImplementedError()
    async def append_history(self, task_id: str, msg: Message) -> None:
        raise NotImplementedError()
    async def cancel(self, task_id: str) -> Optional[Task]:
        raise NotImplementedError()
    async def is_canceled(self, task_id: str) -> bool:
        raise NotImplementedError()
    async def append_event(self, task_id: str, payload: Dict) -> int:
        raise NotImplementedError()
    async def iter_events(self, task_id: str, from_seq: int = 0):
        raise NotImplementedError()
        yield  # unreachable; makes this an async generator so `async for` raises NotImplementedError
```

## app/tasks.py

```python
from __future__ import annotations
import uuid
from typing import AsyncGenerator, Dict, Optional, List, Protocol
from datetime import datetime, timezone

from app.models import (
    Message, TextPart, Task, TaskStatus,
    TaskStatusUpdateEvent
)
from app.llm import stream_answer
from app.storage.base import Storage

# Interface of the MCP memory client (an instance is passed in from main.py)
class MemoryClientProto(Protocol):
    async def append(self, session_id: str, role: str, text: str, max_items: int = 200, ttl_s: int = 3600): ...
    async def recent(self, session_id: str, n: int = 10) -> List[dict]: ...

def _iso() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

def _extract_text(msg: Message) -> str:
    """Return the text of the first TextPart, or an empty string if there is none."""
    parts = [p for p in msg.parts if isinstance(p, TextPart)]
    return parts[0].text if parts else ""

class TaskManager:
    def __init__(self, store: Storage, memory: Optional[MemoryClientProto] = None) -> None:
        self.store = store
        self.memory = memory

    async def create_from_message(self, msg: Message) -> Task:
        task = await self.store.create_task(msg)
        # Write the user input to short-term memory (sessions are isolated by contextId)
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="user", text=_extract_text(msg))
            except Exception:
                pass  # memory is best-effort; a failure must not break the request
        return task

    async def run_stream(self, task: Task) -> AsyncGenerator[Dict, None]:
        # Move to "working"
        await self.store.set_status(task.id, TaskStatus(state="working", timestamp=_iso()))
        evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                    status=TaskStatus(state="working", timestamp=_iso()))
        await self.store.append_event(task.id, evt.model_dump())
        yield evt.model_dump()

        # Take the user's question
        user_msg = task.history[-1]
        user_text = _extract_text(user_msg)

        # Fetch the most recent memory entries and prepend them to the prompt.
        # The current question was already appended in create_from_message,
        # so it may also appear among the snippets.
        memory_snippets: List[dict] = []
        if self.memory:
            try:
                memory_snippets = await self.memory.recent(task.contextId, n=10)
            except Exception:
                memory_snippets = []
        if memory_snippets:
            mem_ctx = "\n".join(f"{it.get('role')}: {it.get('text')}" for it in memory_snippets)
            prompt = f"[Short-term conversation memory]\n{mem_ctx}\n\n[Current question]\n{user_text}"
        else:
            prompt = user_text

        # Streamed generation: chunks are accumulated and emitted as one final message
        accum = []
        async for ch in stream_answer(prompt):
            if await self.store.is_canceled(task.id):
                # Emit the final "canceled" status
                await self.store.set_status(task.id, TaskStatus(state="canceled", timestamp=_iso()))
                final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                                  status=TaskStatus(state="canceled", timestamp=_iso()), final=True)
                await self.store.append_event(task.id, final_evt.model_dump())
                yield final_evt.model_dump()
                return
            accum.append(ch)

        # Final answer
        answer = "".join(accum)
        agent_msg = Message(kind="message", messageId=str(uuid.uuid4()), role="agent",
                            parts=[TextPart(kind="text", text=answer)],
                            contextId=task.contextId, taskId=task.id)
        await self.store.append_history(task.id, agent_msg)
        await self.store.append_event(task.id, agent_msg.model_dump())
        yield agent_msg.model_dump()

        # Write the agent's reply to short-term memory
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="agent", text=answer)
            except Exception:
                pass

        # Completed status
        await self.store.set_status(task.id, TaskStatus(state="completed", timestamp=_iso()))
        final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                          status=TaskStatus(state="completed", timestamp=_iso()), final=True)
        await self.store.append_event(task.id, final_evt.model_dump())
        yield final_evt.model_dump()
```
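Cancellation in `run_stream` is cooperative: nothing interrupts the generator, it simply checks a flag before accepting each chunk. The stdlib-only sketch below isolates that pattern. `fake_chunks` and `collect_until_canceled` are hypothetical names, and a plain callable stands in for `store.is_canceled`.

```python
import asyncio
from typing import AsyncIterator, Callable, List, Tuple


async def fake_chunks(pieces: List[str], delay: float = 0.0) -> AsyncIterator[str]:
    """Stand-in for a streaming model call: yields one piece at a time."""
    for piece in pieces:
        await asyncio.sleep(delay)
        yield piece


async def collect_until_canceled(pieces: List[str],
                                 is_canceled: Callable[[], bool]) -> Tuple[str, str]:
    """Accumulate chunks, checking the cancel flag before keeping each one.

    Returns (final_state, text_collected_so_far).
    """
    collected: List[str] = []
    async for piece in fake_chunks(pieces):
        if is_canceled():
            return "canceled", "".join(collected)
        collected.append(piece)
    return "completed", "".join(collected)
```

Because the flag is only read between chunks, a cancel request takes effect at the next chunk boundary, so the latency of cancellation is bounded by the time the model takes to produce one chunk.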

## app/security/auth.py

```python
from __future__ import annotations
import os
from fastapi import Header, HTTPException

async def require_auth(x_api_key: str | None = Header(default=None), authorization: str | None = Header(default=None)):
    api_key = os.getenv("API_KEY")
    bearer = os.getenv("BEARER_TOKEN")
    if api_key:
        if x_api_key != api_key:
            raise HTTPException(status_code=401, detail="invalid api key")
    if bearer:
        token = None
        if authorization and authorization.lower().startswith("bearer "):
            token = authorization[7:].strip()
        if token != bearer:
            raise HTTPException(status_code=401, detail="invalid bearer token")
```
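The hook compares secrets with `!=`, which is fine for a demo. A hardened variant would typically compare in constant time. The sketch below shows one way to do that with the standard library's `hmac.compare_digest`; `parse_bearer` and `tokens_match` are hypothetical helpers, not part of the scaffold.

```python
import hmac
from typing import Optional


def parse_bearer(authorization: Optional[str]) -> Optional[str]:
    """Extract the token from an 'Authorization: Bearer <token>' header value."""
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = token.strip()
    return token or None


def tokens_match(presented: Optional[str], expected: str) -> bool:
    """Compare two secrets without leaking, via timing, where they first differ."""
    if presented is None:
        return False
    # Encode to bytes so non-ASCII input is handled as well.
    return hmac.compare_digest(presented.encode("utf-8"), expected.encode("utf-8"))
```

Inside `require_auth`, the two `!=` checks could then become `not tokens_match(x_api_key, api_key)` and `not tokens_match(parse_bearer(authorization), bearer)`.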

## app/main.py

```python
from __future__ import annotations
import os, json
from typing import Any, AsyncGenerator
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse

from app.models import JSONRPCRequest, MessageSendParams, AgentCard, AgentCapabilities
from app.jsonrpc import (
    ok, err, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS,
    TASK_NOT_FOUND, TASK_NOT_CANCELABLE,
)
from app.security.auth import require_auth
from app.storage.base import Storage
from app.storage.memory import MemoryStorage
from app.storage.redis import RedisStorage
# from app.storage.pg import PostgresStorage
from app.tasks import TaskManager

# Load .env into the environment; a plain `uvicorn app.main:app` does not read it.
load_dotenv()

app = FastAPI(title="A2A JSON-RPC Q&A Agent")

# Select the storage backend
backend = os.getenv("STORAGE_BACKEND", "memory")
store: Storage
if backend == "redis":
    store = RedisStorage(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
elif backend == "pg":
    # store = PostgresStorage(os.getenv("PG_DSN", "postgresql+asyncpg://postgres:postgres@localhost:5432/a2a"))
    store = MemoryStorage()  # placeholder: pg is not implemented yet, so fall back to memory
else:
    store = MemoryStorage()

tm = TaskManager(store)

# States after which a task can no longer be canceled
TERMINAL_STATES = {"completed", "canceled", "failed", "rejected"}

def sse(rid: Any, result: Any) -> bytes:
    """Encode one JSON-RPC success response as a single SSE `data:` frame."""
    out = {"jsonrpc": "2.0", "id": rid, "result": result}
    return ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()

@app.post("/rpc")
async def rpc(request: Request, _=Depends(require_auth)):
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(err(None, PARSE_ERROR, "Invalid JSON payload"))

    try:
        rpc_req = JSONRPCRequest.model_validate(payload)
    except Exception as e:
        return JSONResponse(err(None, INVALID_REQUEST, "Invalid Request", str(e)))

    rid = rpc_req.id
    method = rpc_req.method
    params = rpc_req.params or {}

    if method == "message/send":
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))
        task = await tm.create_from_message(p.message)
        # Synchronous: run to the end and return the final message
        last_msg = None
        async for item in tm.run_stream(task):
            if item.get("kind") == "message":
                last_msg = item
        return JSONResponse(ok(rid, last_msg or task.model_dump()))

    # Route on the method only. Matching on the Accept header as well would
    # swallow tasks/resubscribe requests, which also ask for text/event-stream.
    if method == "message/stream":
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))

        task = await tm.create_from_message(p.message)

        async def gen() -> AsyncGenerator[bytes, None]:
            # Send the Task snapshot first
            yield sse(rid, task.model_dump())
            async for item in tm.run_stream(task):
                yield sse(rid, item)
        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "tasks/get":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        t = await store.get_task(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/cancel":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        current = await store.get_task(tid)
        if not current:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        if current.status.state in TERMINAL_STATES:
            return JSONResponse(err(rid, TASK_NOT_CANCELABLE, "Task cannot be canceled"))
        t = await store.cancel(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/resubscribe":
        tid = params.get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        try:
            from_seq = int(params.get("fromSeq", 0))
        except (TypeError, ValueError):
            return JSONResponse(err(rid, INVALID_PARAMS, "'fromSeq' must be an integer"))
        if not await store.get_task(tid):
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        # Replay stored events over SSE (history only; it does not follow a task that is still running)
        async def gen() -> AsyncGenerator[bytes, None]:
            async for rec in store.iter_events(tid, from_seq=from_seq):
                yield sse(rid, rec.payload)
        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "agent/getAuthenticatedExtendedCard":
        card = AgentCard(
            name="Q&A Agent",
            description="A2A JSON-RPC demo agent for Q&A",
            version="1.0.0",
            url="http://localhost:8000/rpc",
            capabilities=AgentCapabilities(streaming=True, pushNotifications=False),
            skills=[{"id":"qa","name":"Question Answering","tags":["qa","chat"],"description":"answer general questions"}],
            additionalInterfaces=[],
        )
        return JSONResponse(ok(rid, card.model_dump()))

    return JSONResponse(err(rid, METHOD_NOT_FOUND, "Method not found"))
```
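
As more methods are added, the `if method == ...` chain in `rpc` can be swapped for a lookup table. The sketch below is one possible shape, not part of the scaffold: `rpc_method`, `HANDLERS`, and `dispatch` are hypothetical names, and the handler returns a plain dict instead of the project's Pydantic models so that the example stays standard-library only.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

METHOD_NOT_FOUND = -32601  # same value as in app/jsonrpc.py

Handler = Callable[[Dict[str, Any]], Awaitable[Any]]
HANDLERS: Dict[str, Handler] = {}


def rpc_method(name: str) -> Callable[[Handler], Handler]:
    """Decorator that registers a coroutine under a JSON-RPC method name."""
    def register(fn: Handler) -> Handler:
        HANDLERS[name] = fn
        return fn
    return register


@rpc_method("tasks/get")
async def tasks_get(params: Dict[str, Any]) -> Dict[str, Any]:
    # Placeholder body: a real handler would read the task from storage.
    return {"kind": "task", "id": params["id"]}


async def dispatch(request: Dict[str, Any]) -> Dict[str, Any]:
    """Look up the handler for request['method'] and wrap its result."""
    rid = request.get("id")
    handler = HANDLERS.get(request.get("method"))
    if handler is None:
        return {"jsonrpc": "2.0", "id": rid,
                "error": {"code": METHOD_NOT_FOUND, "message": "Method not found"}}
    result = await handler(request.get("params") or {})
    return {"jsonrpc": "2.0", "id": rid, "result": result}
```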
---

## app/mcp.py
```python
from __future__ import annotations
import json
from typing import Any, Optional, List

try:
    from fastmcp import Client
except Exception:  # placeholder when fastmcp is not installed
    Client = None


def _first_text(res: Any) -> Optional[str]:
    """Return the text of the first text content block in a tool result.

    Assumption: depending on the fastmcp version, `call_tool` returns either a
    list of content blocks or a result object with a `.content` list. Neither
    exposes `.text` directly, so both shapes are handled here.
    """
    blocks = getattr(res, "content", res)
    if not isinstance(blocks, (list, tuple)):
        return None
    for block in blocks:
        text = getattr(block, "text", None)
        if text is not None:
            return text
    return None


class MemoryClient:
    """Short-term memory client built on fastmcp 2.0.
    Uses the tools exposed by the MCP server: mem_put / mem_get / mem_append / mem_recent / mem_clear
    """
    def __init__(self, url: Optional[str]) -> None:
        self.url = url

    async def _call(self, name: str, params: dict):
        if not self.url or Client is None:
            raise RuntimeError("MCP client unavailable")
        async with Client(self.url) as c:
            return await c.call_tool(name, params)

    async def append(self, session_id: str, role: str, text: str, max_items: int = 200, ttl_s: int = 3600):
        await self._call("mem_append", {"session_id": session_id, "role": role, "text": text,
                                         "max_items": max_items, "ttl_s": ttl_s})

    async def recent(self, session_id: str, n: int = 10) -> List[dict]:
        res = await self._call("mem_recent", {"session_id": session_id, "n": n})
        try:
            return json.loads(_first_text(res) or "[]")
        except Exception:
            return []

    async def put(self, session_id: str, key: str, value: str, ttl_s: int = 3600):
        await self._call("mem_put", {"session_id": session_id, "key": key, "value": value, "ttl_s": ttl_s})

    async def get(self, session_id: str, key: str, default: Optional[str] = None, ttl_s: int = 3600) -> Optional[str]:
        res = await self._call("mem_get", {"session_id": session_id, "key": key, "default": default, "ttl_s": ttl_s})
        return _first_text(res)

    async def clear(self, session_id: str):
        await self._call("mem_clear", {"session_id": session_id})

class NullMemory(MemoryClient):
    """No-op implementation, used when MCP_SERVER_URL is not configured or fastmcp is not installed."""
    def __init__(self):
        super().__init__(url=None)
    async def _call(self, *a, **kw):
        return None
    async def append(self, *a, **kw):
        return None
    async def recent(self, *a, **kw):
        return []
    async def put(self, *a, **kw):
        return None
    async def get(self, *a, **kw):
        return None
    async def clear(self, *a, **kw):
        return None
```

## UPDATED app/tasks.py (short-term memory integrated: write/read)

from __future__ import annotations
import uuid
from typing import AsyncGenerator, Dict, Optional, List
from datetime import datetime, timezone

from app.models import (
    Message, TextPart, Task, TaskStatus,
    TaskStatusUpdateEvent
)
from app.llm import stream_answer
from app.storage.base import Storage

# Added: interface of the MCP memory client (an instance is passed in from main.py)
class MemoryClientProto:
    async def append(self, session_id: str, role: str, text: str, max_items: int = 200, ttl_s: int = 3600): ...
    async def recent(self, session_id: str, n: int = 10) -> List[dict]: ...

_iso = lambda: datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

def _extract_text(msg: Message) -> str:
    parts = [p for p in msg.parts if isinstance(p, TextPart)]
    return parts[0].text if parts else ""

class TaskManager:
    def __init__(self, store: Storage, memory: Optional[MemoryClientProto] = None) -> None:
        self.store = store
        self.memory = memory

    async def create_from_message(self, msg: Message) -> Task:
        task = await self.store.create_task(msg)
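        # Write the user input to short-term memory (sessions are isolated by contextId).
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="user", text=_extract_text(msg))
            except Exception:
                pass  # memory is best-effort; a failure here must not fail the request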
        return task

    async def run_stream(self, task: Task) -> AsyncGenerator[Dict, None]:
        # Enter the working state
        await self.store.set_status(task.id, TaskStatus(state="working", timestamp=_iso()))
        evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                    status=TaskStatus(state="working", timestamp=_iso()))
        await self.store.append_event(task.id, evt.model_dump())
        yield evt.model_dump()

        # Get the user's question
        user_msg = task.history[-1]
        user_text = _extract_text(user_msg)

        # Fetch the most recent short-term memory entries and assemble them into the prompt
        memory_snippets: List[dict] = []
        if self.memory:
            try:
                memory_snippets = await self.memory.recent(task.contextId, n=10)
            except Exception:
                memory_snippets = []
        if memory_snippets:
            mem_ctx = "\n".join(f"{it.get('role')}: {it.get('text')}" for it in memory_snippets)
            prompt = f"[Session short-term memory]\n{mem_ctx}\n\n[Current question]\n{user_text}"
        else:
            prompt = user_text

        # Streamed generation
        accum = []
        async for ch in stream_answer(prompt):
            if await self.store.is_canceled(task.id):
                # Emit the final canceled event
                await self.store.set_status(task.id, TaskStatus(state="canceled", timestamp=_iso()))
                final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                                  status=TaskStatus(state="canceled", timestamp=_iso()), final=True)
                await self.store.append_event(task.id, final_evt.model_dump())
                yield final_evt.model_dump()
                return
            accum.append(ch)

        # Final answer
        answer = "".join(accum)
        agent_msg = Message(kind="message", messageId=str(uuid.uuid4()), role="agent",
                            parts=[TextPart(kind="text", text=answer)],
                            contextId=task.contextId, taskId=task.id)
        await self.store.append_history(task.id, agent_msg)
        await self.store.append_event(task.id, agent_msg.model_dump())
        yield agent_msg.model_dump()

        # Write the agent reply to short-term memory
        if self.memory:
            try:
                await self.memory.append(task.contextId, role="agent", text=answer)
            except Exception:
                pass

        # Completed status
        await self.store.set_status(task.id, TaskStatus(state="completed", timestamp=_iso()))
        final_evt = TaskStatusUpdateEvent(taskId=task.id, contextId=task.contextId,
                                          status=TaskStatus(state="completed", timestamp=_iso()), final=True)
        await self.store.append_event(task.id, final_evt.model_dump())
        yield final_evt.model_dump()

UPDATED app/main.py (initialize the MCP memory client and inject it into TaskManager)

python
from __future__ import annotations
import os, json
from typing import AsyncGenerator
from fastapi import FastAPI, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse

from app.models import JSONRPCRequest, MessageSendParams, AgentCard, AgentCapabilities
from app.jsonrpc import ok, err, PARSE_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, INVALID_PARAMS, TASK_NOT_FOUND
from app.security.auth import require_auth
from app.storage.base import Storage
from app.storage.memory import MemoryStorage
from app.storage.redis import RedisStorage
# from app.storage.pg import PostgresStorage
from app.tasks import TaskManager

# === MCP memory client ===
from app.mcp import MemoryClient, NullMemory

app = FastAPI(title="A2A JSON-RPC Q&A Agent")

# Select the storage backend
backend = os.getenv("STORAGE_BACKEND", "memory")
if backend == "redis":
    store: Storage = RedisStorage(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
elif backend == "pg":
    # store = PostgresStorage(os.getenv("PG_DSN", "postgresql+asyncpg://postgres:postgres@localhost:5432/a2a"))
    store = MemoryStorage()  # placeholder: pg is not implemented yet, fall back to memory for now
else:
    store = MemoryStorage()

# MCP short-term memory (optional)
mcp_url = os.getenv("MCP_SERVER_URL")
memory = MemoryClient(mcp_url) if mcp_url else NullMemory()

tm = TaskManager(store, memory)

@app.post("/rpc")
async def rpc(request: Request, _=Depends(require_auth)):
    try:
        payload = await request.json()
    except Exception:
        return JSONResponse(err(None, PARSE_ERROR, "Invalid JSON payload"))

    try:
        rpc_req = JSONRPCRequest.model_validate(payload)
    except Exception as e:
        return JSONResponse(err(None, INVALID_REQUEST, "Invalid Request", str(e)))

    rid = rpc_req.id
    method = rpc_req.method
    params = rpc_req.params or {}

    accept = request.headers.get("accept", "")

    if method == "message/send":
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))
        task = await tm.create_from_message(p.message)
        # Synchronous: run to completion and keep the final message
        last_msg = None
        async for item in tm.run_stream(task):
            if item.get("kind") == "message":
                last_msg = item
        return JSONResponse(ok(rid, last_msg or task.model_dump()))

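    # Dispatch on the JSON-RPC method only: keying on the Accept header would also
    # capture tasks/resubscribe and unknown methods sent with text/event-stream.
    if method == "message/stream":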
        try:
            p = MessageSendParams.model_validate(params)
        except Exception as e:
            return JSONResponse(err(rid, INVALID_PARAMS, "Invalid parameters", str(e)))

        task = await tm.create_from_message(p.message)

        async def gen() -> AsyncGenerator[bytes, None]:
            # Send the Task snapshot first
            yield f"data: {{\"jsonrpc\":\"2.0\",\"id\":{json.dumps(rid)},\"result\":{task.model_dump_json()}}}\n\n".encode()
            async for item in tm.run_stream(task):
                out = {"jsonrpc": "2.0", "id": rid, "result": item}
                yield ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()
        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "tasks/get":
        tid = (params or {}).get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        t = await store.get_task(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/cancel":
        tid = (params or {}).get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        t = await store.cancel(tid)
        if not t:
            return JSONResponse(err(rid, TASK_NOT_FOUND, "Task not found"))
        return JSONResponse(ok(rid, t.model_dump()))

    if method == "tasks/resubscribe":
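        tid = (params or {}).get("id")
        if not tid:
            return JSONResponse(err(rid, INVALID_PARAMS, "missing 'id'"))
        try:
            from_seq = int((params or {}).get("fromSeq", 0))
        except (TypeError, ValueError):
            # A non-numeric fromSeq is a client error, not a server failure
            return JSONResponse(err(rid, INVALID_PARAMS, "'fromSeq' must be an integer"))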
        # Replay historical events over SSE
        async def gen() -> AsyncGenerator[bytes, None]:
            async for rec in store.iter_events(tid, from_seq=from_seq):
                out = {"jsonrpc": "2.0", "id": rid, "result": rec.payload}
                yield ("data: " + json.dumps(out, ensure_ascii=False) + "\n\n").encode()
        return StreamingResponse(gen(), media_type="text/event-stream")

    if method == "agent/getAuthenticatedExtendedCard":
        card = AgentCard(
            name="Q&A Agent",
            description="A2A JSON-RPC demo agent for Q&A (with MCP memory)",
            version="1.1.0",
            url="http://localhost:8000/rpc",
            capabilities=AgentCapabilities(streaming=True, pushNotifications=False),
            skills=[{"id":"qa","name":"Question Answering","tags":["qa","chat"],"description":"answer general questions"}],
            additionalInterfaces=[],
        )
        return JSONResponse(ok(rid, card.model_dump()))

    return JSONResponse(err(rid, METHOD_NOT_FOUND, "Method not found"))
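
The two streaming branches above build the same `data: {...}` SSE frame by hand, once with an f-string and twice with `json.dumps`. A small helper could keep that framing in one place. The sketch below is only an illustration: `sse_frame` is a hypothetical name that is not part of the scaffold, and it assumes `result` is already JSON-serializable (for a Pydantic v2 model, something like `task.model_dump(mode="json")`).

```python
import json
from typing import Any, Union


def sse_frame(rid: Union[str, int, None], result: Any) -> bytes:
    """Encode one JSON-RPC success response as a single SSE `data:` event.

    Hypothetical helper (not in the scaffold). `result` must already be
    JSON-serializable, e.g. the output of `model.model_dump(mode="json")`.
    """
    envelope = {"jsonrpc": "2.0", "id": rid, "result": result}
    # The blank line ("\n\n") terminates the SSE event;
    # ensure_ascii=False keeps non-ASCII text readable on the wire.
    return ("data: " + json.dumps(envelope, ensure_ascii=False) + "\n\n").encode("utf-8")


if __name__ == "__main__":
    print(sse_frame("s1", {"kind": "message", "role": "agent"}).decode(), end="")
```

With such a helper, each `yield` inside `gen()` would reduce to `yield sse_frame(rid, item)`, and the Task snapshot would no longer need a hand-escaped f-string.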

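The fastmcp 2.0 **"short-term memory (Redis)"** service is now integrated into your existing A2A × FastAPI × JSON-RPC scaffold. The change adds app/mcp.py, reworks app/tasks.py (reading and writing session memory) and app/main.py (injecting MemoryClient), and updates requirements.txt and .env.example. The complete code is in the new and UPDATED sections above.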

Quick start (local):

1. Start Redis and the MCP short-term memory service

bash
docker run -d --name redis -p 6379:6379 redis:7-alpine

In a second terminal, start the fastmcp Redis memory server you obtained earlier (server.py), and make sure it is serving at http://127.0.0.1:8010/mcp:

bash
python server.py

2. Configure and start the Agent

From the project root:

bash
cp .env.example .env

Open .env and confirm that this line is set:

env
MCP_SERVER_URL=http://127.0.0.1:8010/mcp

Then install the dependencies and start the server:

bash
pip install -r requirements.txt
uvicorn app.main:app --reload

3. Try a streaming conversation (A2A) with short-term memory in effect

The first message is written to memory:

bash
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"s1","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m1","role":"user",
            "parts":[{"kind":"text","text":"My name is Xiao Wang, please remember it"}]}}
}'

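While the second reply is generated, the most recent memory entries are prepended to the prompt, so the answer should reflect the earlier context. Memory is keyed by contextId, so both requests must resolve to the same contextId:

bash
curl -N http://localhost:8000/rpc -H 'Content-Type: application/json' -H 'Accept: text/event-stream' -d '{
  "jsonrpc":"2.0","id":"s2","method":"message/stream",
  "params":{"message":{"kind":"message","messageId":"m2","role":"user",
            "parts":[{"kind":"text","text":"Remind me, what is my name?"}]}}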
}'
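
To consume these streams from Python rather than curl, a client only has to split the response on blank lines and decode each `data:` payload. The following is a minimal sketch, assuming the frame format emitted by `/rpc` (a `data:` line followed by a blank line) and success responses only. `iter_sse_results` is a hypothetical helper name, and the sample `kind` values are illustrative, not captured output.

```python
import json
from typing import Iterable, Iterator, List


def iter_sse_results(lines: Iterable[str]) -> Iterator[dict]:
    """Yield the JSON-RPC `result` of every SSE event found in `lines`."""
    buf: List[str] = []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line.startswith("data:"):
            # Drop the field name and any leading spaces before the payload.
            buf.append(line[5:].lstrip(" "))
        elif line == "" and buf:
            # A blank line closes the event; multi-line data is joined with "\n".
            yield json.loads("\n".join(buf))["result"]
            buf = []
    if buf:  # tolerate a stream that ends without the trailing blank line
        yield json.loads("\n".join(buf))["result"]


# Illustrative frames in the shape the server writes (not real output).
SAMPLE = [
    'data: {"jsonrpc":"2.0","id":"s1","result":{"kind":"task","id":"t1"}}\n',
    "\n",
    'data: {"jsonrpc":"2.0","id":"s1","result":{"kind":"status-update","final":true}}\n',
    "\n",
]

if __name__ == "__main__":
    for result in iter_sse_results(SAMPLE):
        print(result["kind"], result.get("final", False))
```

In a real client, `lines` would come from any HTTP library that exposes the response body line by line, and the loop could stop as soon as an event carries `final: true`.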

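Key points:

• The A2A contextId is passed to the fastmcp memory tools as session_id, so sessions are isolated automatically;

• Every user input and every Agent output triggers a mem_append, and mem_recent is fetched before generation and prepended to the prompt;

• When MCP_SERVER_URL is not configured, the Agent falls back to the no-op implementation and keeps running without memory.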
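
The three points above can be tried without Redis or an MCP server. The sketch below swaps in a hypothetical in-process `InMemoryMemory` (not part of the scaffold and not the fastmcp API) that keeps entries per `session_id`, together with a `build_prompt` function shaped like the prompt assembly in `run_stream`. The section labels and the per-instance `max_items` cap are assumptions made for the illustration.

```python
import asyncio
from collections import defaultdict, deque
from typing import Deque, Dict, List


class InMemoryMemory:
    """Hypothetical stand-in for the MCP memory client; one bounded queue per session_id."""

    def __init__(self, max_items: int = 200) -> None:
        self._items: Dict[str, Deque[dict]] = defaultdict(lambda: deque(maxlen=max_items))

    async def append(self, session_id: str, role: str, text: str) -> None:
        # Oldest entries are dropped automatically once max_items is reached.
        self._items[session_id].append({"role": role, "text": text})

    async def recent(self, session_id: str, n: int = 10) -> List[dict]:
        return list(self._items[session_id])[-n:]


def build_prompt(snippets: List[dict], user_text: str) -> str:
    """Prepend recent memory to the question; fall back to the bare question."""
    if not snippets:
        return user_text
    mem_ctx = "\n".join(f"{it.get('role')}: {it.get('text')}" for it in snippets)
    return f"[Session short-term memory]\n{mem_ctx}\n\n[Current question]\n{user_text}"


async def demo() -> str:
    mem = InMemoryMemory()
    await mem.append("ctx-A", role="user", text="My name is Xiao Wang")
    await mem.append("ctx-B", role="user", text="Unrelated session")
    # Only ctx-A entries are visible when answering inside ctx-A.
    return build_prompt(await mem.recent("ctx-A"), "What is my name?")


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Keying every call by the same identifier is what gives the isolation: an entry appended under one contextId never appears in the prompt built for another.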