Python责任链模式:中间件的灵魂
前言
用过 FastAPI 或 Django 的中间件吗?请求进来先过认证,再过日志,再过限流,最后才到业务逻辑------这就是责任链模式。
责任链的精髓是:把处理逻辑拆成一个个独立的"处理器",串成链条,请求沿着链条传递,每个处理器决定自己处理还是交给下一个。
这篇文章从最简单的 Python 实现开始,然后深入 FastAPI 中间件的源码,看看框架是怎么玩的。最后手撸一个完整的中间件系统。
🏠个人主页:山沐与山
文章目录
一、责任链模式是什么
1.1 生活中的例子
想象一下报销流程:
员工提交报销 → 组长审批(≤500) → 经理审批(≤2000) → 总监审批(≤10000) → 财务总监
每一级只处理自己权限内的金额,超出就传给下一级。员工不需要知道该找谁审批,只管提交,系统自动找到合适的人。
1.2 不用责任链的噩梦
python
def process_request(request):
# 认证
if not request.headers.get("Authorization"):
return {"error": "未认证"}
token = request.headers["Authorization"]
if not verify_token(token):
return {"error": "token无效"}
# 限流
client_ip = request.client.host
if is_rate_limited(client_ip):
return {"error": "请求太频繁"}
# 日志
log_request(request)
# 权限检查
user = get_user_from_token(token)
if not user.has_permission(request.path):
return {"error": "无权限"}
# 参数校验
if not validate_params(request):
return {"error": "参数错误"}
# 终于到业务逻辑了...
return handle_business(request)
问题:所有逻辑揉在一起,难以复用、难以测试、难以维护。想加个新检查?在这一大坨里面找位置插入。
二、Python基础实现
2.1 经典实现
代码来自 basic_chain.py:
python
from abc import ABC, abstractmethod
from typing import Optional, Any
from dataclasses import dataclass
@dataclass
class Request:
"""请求对象"""
user: str
amount: float
description: str
class Handler(ABC):
"""处理器抽象基类"""
def __init__(self):
self._next_handler: Optional[Handler] = None
def set_next(self, handler: "Handler") -> "Handler":
"""设置下一个处理器,返回下一个处理器以支持链式调用"""
self._next_handler = handler
return handler
def handle(self, request: Request) -> str:
"""处理请求"""
# 先尝试自己处理
result = self.process(request)
# 如果自己不处理,传给下一个
if result is None and self._next_handler:
return self._next_handler.handle(request)
return result or "请求无法处理"
@abstractmethod
def process(self, request: Request) -> Optional[str]:
"""具体处理逻辑,返回None表示自己不处理"""
pass
# ========== 具体处理器 ==========
class TeamLeader(Handler):
"""组长:处理500以下"""
def process(self, request: Request) -> Optional[str]:
if request.amount <= 500:
return f"[OK] 组长批准了 {request.user} 的报销 ¥{request.amount}"
print(f" [组长] 金额超出权限,转交上级")
return None
class Manager(Handler):
"""经理:处理2000以下"""
def process(self, request: Request) -> Optional[str]:
if request.amount <= 2000:
return f"[OK] 经理批准了 {request.user} 的报销 ¥{request.amount}"
print(f" [经理] 金额超出权限,转交上级")
return None
class Director(Handler):
"""总监:处理10000以下"""
def process(self, request: Request) -> Optional[str]:
if request.amount <= 10000:
return f"[OK] 总监批准了 {request.user} 的报销 ¥{request.amount}"
print(f" [总监] 金额超出权限,转交上级")
return None
class CFO(Handler):
"""财务总监:处理所有"""
def process(self, request: Request) -> Optional[str]:
return f"[OK] CFO批准了 {request.user} 的报销 ¥{request.amount}"
# ========== 构建责任链 ==========
def build_chain() -> Handler:
"""构建审批链"""
team_leader = TeamLeader()
manager = Manager()
director = Director()
cfo = CFO()
# 链式设置
team_leader.set_next(manager).set_next(director).set_next(cfo)
return team_leader # 返回链头
# ========== 使用 ==========
chain = build_chain()
requests = [
Request("张三", 300, "办公用品"),
Request("李四", 1500, "团建费用"),
Request("王五", 8000, "设备采购"),
Request("赵六", 50000, "项目外包"),
]
for req in requests:
print(f"\n[申请] {req.user} 申请报销 ¥{req.amount} ({req.description})")
result = chain.handle(req)
print(f" [结果] {result}")
输出:
[申请] 张三 申请报销 ¥300 (办公用品)
[结果] [OK] 组长批准了 张三 的报销 ¥300
[申请] 李四 申请报销 ¥1500 (团建费用)
[组长] 金额超出权限,转交上级
[结果] [OK] 经理批准了 李四 的报销 ¥1500
[申请] 王五 申请报销 ¥8000 (设备采购)
[组长] 金额超出权限,转交上级
[经理] 金额超出权限,转交上级
[结果] [OK] 总监批准了 王五 的报销 ¥8000
[申请] 赵六 申请报销 ¥50000 (项目外包)
[组长] 金额超出权限,转交上级
[经理] 金额超出权限,转交上级
[总监] 金额超出权限,转交上级
[结果] [OK] CFO批准了 赵六 的报销 ¥50000
看到没?每个处理器只关心自己能处理的范围,超出就传给下一个。
2.2 函数式实现
Python 不一定要用类,函数也能搞。代码来自 functional_chain.py:
python
from typing import Callable, Optional, Any
# 处理器类型:接收请求和下一个处理器,返回结果
Handler = Callable[[dict, Optional[Callable]], Any]
def auth_handler(request: dict, next_handler: Optional[Callable]) -> Any:
"""认证处理器"""
token = request.get("token")
if not token:
return {"error": "未提供token", "code": 401}
if token != "valid_token":
return {"error": "token无效", "code": 401}
request["user"] = "authenticated_user"
print("[+] 认证通过")
return next_handler(request, None) if next_handler else request
def logging_handler(request: dict, next_handler: Optional[Callable]) -> Any:
"""日志处理器"""
print(f"[日志] 记录请求: {request.get('path', '/')}")
result = next_handler(request, None) if next_handler else request
print(f"[日志] 记录响应: {result}")
return result
def rate_limit_handler(request: dict, next_handler: Optional[Callable]) -> Any:
"""限流处理器"""
client_ip = request.get("client_ip", "unknown")
# 简化:实际应该用Redis等存储计数
if client_ip == "blocked_ip":
return {"error": "请求太频繁", "code": 429}
print("[+] 限流检查通过")
return next_handler(request, None) if next_handler else request
def build_chain(*handlers: Handler) -> Callable[[dict], Any]:
"""构建处理链"""
def chain(request: dict) -> Any:
def call_chain(handlers_list, req):
if not handlers_list:
return req
current = handlers_list[0]
remaining = handlers_list[1:]
# 把剩余的处理器包装成下一个处理器
next_handler = (lambda r, _: call_chain(remaining, r)) if remaining else None
return current(req, next_handler)
return call_chain(list(handlers), request)
return chain
# ========== 使用 ==========
# 构建链
process = build_chain(
logging_handler,
auth_handler,
rate_limit_handler,
)
# 测试
print("=== 正常请求 ===")
result = process({"path": "/api/users", "token": "valid_token", "client_ip": "192.168.1.1"})
print(f"[最终结果] {result}\n")
print("=== 未认证请求 ===")
result = process({"path": "/api/users", "client_ip": "192.168.1.1"})
print(f"[最终结果] {result}")
函数式的写法更简洁,适合简单场景。你想想,不用定义一堆类,直接写函数就完事了。
三、FastAPI中间件源码剖析
FastAPI 的中间件就是责任链模式的典型应用。来看看它是怎么实现的。
3.1 中间件的基本用法
代码来自 fastapi_middleware.py:
python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
app = FastAPI()
class TimingMiddleware(BaseHTTPMiddleware):
"""计时中间件"""
async def dispatch(self, request: Request, call_next):
start = time.time()
response = await call_next(request) # 调用下一个处理器
duration = time.time() - start
response.headers["X-Process-Time"] = str(duration)
print(f"[计时] {request.url.path} 耗时 {duration:.3f}s")
return response
class AuthMiddleware(BaseHTTPMiddleware):
"""认证中间件"""
async def dispatch(self, request: Request, call_next):
# 白名单路径
if request.url.path in ["/", "/health", "/docs", "/openapi.json"]:
return await call_next(request)
token = request.headers.get("Authorization")
if not token:
from fastapi.responses import JSONResponse
return JSONResponse({"error": "未认证"}, status_code=401)
print(f"[认证] 认证通过: {token[:20]}...")
return await call_next(request)
# 注册中间件(注意:后添加的先执行)
app.add_middleware(AuthMiddleware)
app.add_middleware(TimingMiddleware)
@app.get("/")
async def root():
return {"message": "Hello"}
@app.get("/api/users")
async def get_users():
return {"users": []}
3.2 中间件执行流程
请求进入
│
▼
┌─────────────────────┐
│ TimingMiddleware │ ← 记录开始时间
│ (最外层) │
└─────────┬───────────┘
│ call_next(request)
▼
┌─────────────────────┐
│ AuthMiddleware │ ← 检查认证
│ (中间层) │
└─────────┬───────────┘
│ call_next(request)
▼
┌─────────────────────┐
│ 路由处理函数 │ ← 业务逻辑
│ (最内层) │
└─────────┬───────────┘
│ response
▼
┌─────────────────────┐
│ AuthMiddleware │ ← 后置处理(如果有)
└─────────┬───────────┘
│ response
▼
┌─────────────────────┐
│ TimingMiddleware │ ← 计算耗时,添加header
└─────────┬───────────┘
│
▼
返回响应
这是一个洋葱模型:请求从外到内穿过每一层,响应再从内到外穿回来。
3.3 Starlette 的中间件实现原理
python
# 简化版的 Starlette 中间件机制
class App:
def __init__(self):
self.middleware_stack = []
self.routes = {}
def add_middleware(self, middleware_class, **options):
"""添加中间件"""
self.middleware_stack.append((middleware_class, options))
def route(self, path):
"""路由装饰器"""
def decorator(func):
self.routes[path] = func
return func
return decorator
def build_middleware_stack(self):
"""构建中间件链"""
# 最内层是路由处理
app = self._route_handler
# 从后往前包装中间件
for middleware_class, options in reversed(self.middleware_stack):
app = middleware_class(app, **options)
return app
async def _route_handler(self, request):
"""路由处理器"""
handler = self.routes.get(request.path)
if handler:
return await handler(request)
return {"error": "Not Found"}
async def __call__(self, request):
"""处理请求"""
handler = self.build_middleware_stack()
return await handler(request)
关键点:中间件是层层包装的,类似装饰器模式,但每层都可以决定是否继续传递。
四、实战:完整的中间件系统
来构建一个生产级的中间件系统。
4.1 项目结构
myapp/
├── main.py
├── middleware/
│ ├── __init__.py
│ ├── base.py # 中间件基类
│ ├── timing.py # 计时
│ ├── logging.py # 日志
│ ├── auth.py # 认证
│ ├── rate_limit.py # 限流
│ └── error_handler.py # 异常处理
└── config.py
4.2 中间件实现
4.2.1 基类实现
代码来自 middleware/base.py:
python
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
import logging
logger = logging.getLogger(__name__)
class BaseMiddleware(BaseHTTPMiddleware):
"""中间件基类"""
# 跳过中间件的路径
skip_paths: list = []
def should_skip(self, request: Request) -> bool:
"""判断是否跳过该中间件"""
return request.url.path in self.skip_paths
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if self.should_skip(request):
return await call_next(request)
return await self.process(request, call_next)
async def process(self, request: Request, call_next: Callable) -> Response:
"""子类实现具体逻辑"""
return await call_next(request)
4.2.2 计时中间件
代码来自 middleware/timing.py:
python
from fastapi import Request, Response
from .base import BaseMiddleware
import time
import logging
logger = logging.getLogger(__name__)
class TimingMiddleware(BaseMiddleware):
"""请求计时中间件"""
skip_paths = ["/health", "/metrics"]
async def process(self, request: Request, call_next) -> Response:
start_time = time.perf_counter()
response = await call_next(request)
duration = time.perf_counter() - start_time
response.headers["X-Process-Time"] = f"{duration:.4f}"
# 慢请求告警
if duration > 1.0:
logger.warning(f"[慢请求] {request.method} {request.url.path} 耗时 {duration:.2f}s")
return response
4.2.3 日志中间件
代码来自 middleware/logging.py:
python
from fastapi import Request, Response
from .base import BaseMiddleware
import logging
import uuid
from contextvars import ContextVar
# 请求ID上下文
request_id_var: ContextVar[str] = ContextVar("request_id", default="")
logger = logging.getLogger(__name__)
class LoggingMiddleware(BaseMiddleware):
"""日志中间件"""
skip_paths = ["/health"]
async def process(self, request: Request, call_next) -> Response:
# 生成请求ID
request_id = request.headers.get("X-Request-ID", str(uuid.uuid4())[:8])
request_id_var.set(request_id)
# 记录请求
logger.info(f"[{request_id}] --> {request.method} {request.url.path}")
response = await call_next(request)
# 记录响应
logger.info(f"[{request_id}] <-- {response.status_code}")
# 响应头加上请求ID
response.headers["X-Request-ID"] = request_id
return response
4.2.4 认证中间件
代码来自 middleware/auth.py:
python
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from .base import BaseMiddleware
from typing import Optional
import logging
logger = logging.getLogger(__name__)
class AuthMiddleware(BaseMiddleware):
"""认证中间件"""
skip_paths = ["/", "/health", "/docs", "/openapi.json", "/redoc", "/login", "/register"]
def __init__(self, app, secret_key: str = "default_secret"):
super().__init__(app)
self.secret_key = secret_key
async def process(self, request: Request, call_next) -> Response:
token = self._extract_token(request)
if not token:
return JSONResponse(
{"error": "未提供认证token", "code": "UNAUTHORIZED"},
status_code=401
)
# 验证token
user = self._verify_token(token)
if not user:
return JSONResponse(
{"error": "token无效或已过期", "code": "INVALID_TOKEN"},
status_code=401
)
# 将用户信息存入request.state
request.state.user = user
logger.debug(f"[认证] 用户认证成功: {user['username']}")
return await call_next(request)
def _extract_token(self, request: Request) -> Optional[str]:
"""从请求中提取token"""
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:]
return None
def _verify_token(self, token: str) -> Optional[dict]:
"""验证token(简化版,实际用JWT)"""
# 模拟验证
if token == "valid_token":
return {"id": 1, "username": "test_user", "role": "admin"}
return None
4.2.5 限流中间件
代码来自 middleware/rate_limit.py:
python
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from .base import BaseMiddleware
from collections import defaultdict
import time
import asyncio
import logging
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseMiddleware):
"""限流中间件(令牌桶算法)"""
skip_paths = ["/health"]
def __init__(
self,
app,
requests_per_second: int = 10,
burst: int = 20
):
super().__init__(app)
self.rate = requests_per_second
self.burst = burst
self.buckets: dict = defaultdict(lambda: {"tokens": burst, "last_update": time.time()})
self._lock = asyncio.Lock()
async def process(self, request: Request, call_next) -> Response:
client_ip = self._get_client_ip(request)
allowed = await self._is_allowed(client_ip)
if not allowed:
logger.warning(f"[限流] {client_ip}")
return JSONResponse(
{"error": "请求太频繁,请稍后再试", "code": "RATE_LIMITED"},
status_code=429,
headers={"Retry-After": "1"}
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""获取客户端IP"""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"
async def _is_allowed(self, client_ip: str) -> bool:
"""检查是否允许请求"""
async with self._lock:
now = time.time()
bucket = self.buckets[client_ip]
# 补充令牌
time_passed = now - bucket["last_update"]
bucket["tokens"] = min(
self.burst,
bucket["tokens"] + time_passed * self.rate
)
bucket["last_update"] = now
# 消耗令牌
if bucket["tokens"] >= 1:
bucket["tokens"] -= 1
return True
return False
4.2.6 异常处理中间件
代码来自 middleware/error_handler.py:
python
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from .base import BaseMiddleware
import logging
import traceback
logger = logging.getLogger(__name__)
class ErrorHandlerMiddleware(BaseMiddleware):
"""全局异常处理中间件"""
skip_paths = [] # 不跳过任何路径
def __init__(self, app, debug: bool = False):
super().__init__(app)
self.debug = debug
async def process(self, request: Request, call_next) -> Response:
try:
return await call_next(request)
except ValueError as e:
logger.warning(f"[参数错误] {e}")
return JSONResponse(
{"error": str(e), "code": "INVALID_PARAMETER"},
status_code=400
)
except PermissionError as e:
logger.warning(f"[权限不足] {e}")
return JSONResponse(
{"error": "权限不足", "code": "FORBIDDEN"},
status_code=403
)
except Exception as e:
logger.error(f"[未处理异常] {e}\n{traceback.format_exc()}")
error_detail = {
"error": "服务器内部错误",
"code": "INTERNAL_ERROR"
}
if self.debug:
error_detail["detail"] = str(e)
error_detail["traceback"] = traceback.format_exc()
return JSONResponse(error_detail, status_code=500)
4.3 组装应用
代码来自 main.py:
python
from fastapi import FastAPI, Request
from middleware.timing import TimingMiddleware
from middleware.logging import LoggingMiddleware
from middleware.auth import AuthMiddleware
from middleware.rate_limit import RateLimitMiddleware
from middleware.error_handler import ErrorHandlerMiddleware
import logging
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
app = FastAPI(title="责任链模式示例")
# 添加中间件(注意顺序:后添加的先执行)
# 执行顺序:Error -> Timing -> Logging -> RateLimit -> Auth -> 路由
app.add_middleware(AuthMiddleware, secret_key="my_secret")
app.add_middleware(RateLimitMiddleware, requests_per_second=10, burst=20)
app.add_middleware(LoggingMiddleware)
app.add_middleware(TimingMiddleware)
app.add_middleware(ErrorHandlerMiddleware, debug=True)
# ========== 路由 ==========
@app.get("/")
async def root():
return {"message": "Hello World"}
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.get("/api/users")
async def get_users(request: Request):
user = request.state.user
return {"message": f"Hello {user['username']}", "users": []}
@app.get("/api/error")
async def trigger_error():
raise ValueError("这是一个测试错误")
@app.get("/api/crash")
async def trigger_crash():
raise RuntimeError("服务器崩溃了!")
看到没?通过责任链,我们把认证、限流、日志、计时、异常处理全部解耦了。每个中间件只管自己的事,想加新功能?写个新中间件加进去就完事了。
4.4 中间件执行流程可视化
想看看执行顺序?加个调试中间件:
python
from .base import BaseMiddleware
from fastapi import Request
class DebugMiddleware(BaseMiddleware):
def __init__(self, app, name: str):
super().__init__(app)
self.name = name
async def process(self, request: Request, call_next):
print(f" [进入] {self.name}")
response = await call_next(request)
print(f" [返回] {self.name} (status={response.status_code})")
return response
# 添加调试中间件
app.add_middleware(DebugMiddleware, name="Layer 1")
app.add_middleware(DebugMiddleware, name="Layer 2")
app.add_middleware(DebugMiddleware, name="Layer 3")
请求时输出:
[进入] Layer 3
[进入] Layer 2
[进入] Layer 1
[返回] Layer 1 (status=200)
[返回] Layer 2 (status=200)
[返回] Layer 3 (status=200)
这就是洋葱模型的执行顺序。
五、责任链的变体
5.1 纯链式(处理后停止)
python
class Handler:
def handle(self, request) -> bool:
"""返回True表示已处理,链中断"""
if self._can_handle(request):
self._do_handle(request)
return True
elif self._next:
return self._next.handle(request)
return False
适用场景:审批流程、异常处理(找到第一个能处理的就停止)
5.2 过滤器链(全部执行)
python
class Filter:
def do_filter(self, request, chain):
# 前置处理
self._before(request)
# 继续链(必须调用)
chain.do_filter(request)
# 后置处理
self._after(request)
适用场景:中间件、日志、事务管理(每个都要执行)
5.3 拦截器模式
python
class Interceptor:
def pre_handle(self, request) -> bool:
"""前置处理,返回False则中断"""
return True
def post_handle(self, request, response):
"""后置处理"""
pass
def after_completion(self, request, response, exception):
"""完成后处理(无论成功失败)"""
pass
适用场景:Spring MVC 风格的拦截器
5.4 管道模式
python
from typing import Callable, List, Any
class Pipeline:
"""管道:数据依次流经每个处理器"""
def __init__(self):
self.pipes: List[Callable] = []
def pipe(self, handler: Callable) -> "Pipeline":
self.pipes.append(handler)
return self
def process(self, data: Any) -> Any:
result = data
for pipe in self.pipes:
result = pipe(result)
return result
# 使用
pipeline = Pipeline()
pipeline.pipe(lambda x: x.strip()) # 去空格
pipeline.pipe(lambda x: x.lower()) # 转小写
pipeline.pipe(lambda x: x.replace(" ", "_")) # 空格转下划线
result = pipeline.process(" Hello World ")
print(result) # "hello_world"
适用场景:数据处理、ETL、编译器
六、常见应用场景
| 场景 | 说明 | 示例 |
|---|---|---|
| Web中间件 | HTTP 请求处理链 |
认证、日志、限流、压缩 |
| 审批流程 | 多级审批 | 报销、请假、合同审批 |
| 异常处理 | 分层捕获异常 | 业务异常→系统异常→兜底 |
| 过滤器 | 数据过滤链 | 敏感词过滤、XSS过滤 |
| 命令处理 | 命令分发 | CLI工具、聊天机器人 |
| 日志处理 | 多级日志 | DEBUG→INFO→WARN→ERROR |
| 缓存策略 | 多级缓存 | L1→L2→L3→数据库 |
七、总结
责任链模式的核心是解耦请求的发送者和处理者,让多个对象都有机会处理请求。
两种典型形态:
- 纯责任链:找到能处理的就停止(审批流程)
- 过滤器链:每个都执行,层层包装(中间件)
FastAPI 中间件就是过滤器链的实现,采用洋葱模型,请求从外到内,响应从内到外。
设计要点:
- 每个处理器职责单一
- 处理器之间互不依赖
- 链的组装和处理器实现分离
- 考虑链断开的情况
关键概念对比表:
| 概念 | 说明 | 适用场景 | 注意事项 |
|---|---|---|---|
纯责任链 |
找到就停止 | 审批流程、异常处理 | 要有兜底处理器 |
过滤器链 |
全部执行 | 中间件、日志 | 注意执行顺序 |
洋葱模型 |
双向穿透 | Web框架 |
前后置处理 |
管道模式 |
数据转换 | ETL、编译器 |
数据流式处理 |
和之前讲的模式对比:
- 装饰器模式:增强单个对象的功能
- 责任链模式:多个对象依次处理请求
- 观察者模式:一对多的通知,所有观察者都收到
- 责任链模式:一对多的传递,可能只有一个处理
下一步 :下一篇可以讲讲状态模式,它和责任链有些相似,但关注的是对象内部状态的转换。
热门专栏推荐
- Agent小册
- 服务器部署
- Java基础合集
- Python基础合集
- Go基础合集
- 大数据合集
- 前端小册
- 数据库合集
- Redis 合集
- Spring 全家桶
- 微服务全家桶
- 数据结构与算法合集
- 设计模式小册
- 消息队列合集
等等等还有许多优秀的合集在主页等着大家的光顾,感谢大家的支持
文章到这里就结束了,如果有什么疑问的地方请指出,诸佬们一起来评论区一起讨论😊
希望能和诸佬们一起努力,今后我们一起观看感谢您的阅读🙏
如果帮助到您不妨3连支持一下,创造不易您们的支持是我的动力🌟