Pydantic 数据校验 & 限流中间件(限制每个 IP 的请求频率,防止接口被刷爆)

一、Pydantic 介绍

Pydantic 主要使用场景:

  • API 接口请求校验。
  • LLM 结构化输出:对模型返回的 JSON 进行带验证的反序列化,得到 Pydantic 模型实例(避免出现的幻觉数据影响业务)。

Pydantic 类比 JAVA

  • @RequestBody 接收JSON数据绑定到对象
  • Lombok 生成方法
  • 数据类型验证(对请求的数据类型做校验)

Pydantic作用

• 数据验证与约束:用BaseModel定义数据模型

• 数据转换与解析:支持从dict、JSON字符串等多种输入解析为强类型对象

• 模型序列化:用model_dump()将模型安全地序列化为字典

• 标准化与清洗

• 生成Schema

二、Pydantic 完整使用案例

2.1. FastAPI + Pydantic 自动校验 API 请求体的案例

展示如何用 Pydantic 模型定义请求格式,FastAPI 自动完成类型检查、约束校验和异常返回。

python 复制代码
# 导入 FastAPI 核心类和 HTTP 异常
from fastapi import FastAPI, HTTPException
# 导入 Pydantic 数据校验组件
from pydantic import BaseModel, Field, field_validator

# 创建 FastAPI 应用实例
app = FastAPI()


# ── 定义请求体数据模型(Pydantic) ──
class ProductItem(BaseModel):
    """产品创建请求体模型,FastAPI 会自动用它校验请求 JSON"""

    # id:必须 > 0 的正整数,通过 Field 约束范围
    id: int = Field(gt=0)

    # name:字符串,长度限制 1~100
    name: str = Field(min_length=1, max_length=100)

    # price:必须 > 0 的浮点数
    price: float = Field(gt=0)

    # quantity:必须 >= 0 的库存数量
    quantity: int = Field(ge=0)

    # 自定义字段验证器:在 name 赋值时自动去除首尾空格
    @field_validator("name")
    @classmethod
    def strip_name(cls, v: str) -> str:
        return v.strip()


# ── 路由:接收 POST 请求,自动用 Pydantic 模型校验 ──
@app.post("/products")
def create_product(product: ProductItem):
    """
    FastAPI 看到 product: ProductItem 这个参数类型时,会自动:
    1. 读取请求体的 JSON
    2. 调用 ProductItem.model_validate() 做类型校验 + 约束检查 + 自定义验证器
    3. 校验通过 → 注入 ProductItem 实例到 product 参数
    4. 校验失败 → 自动返回 422 错误,不会进入这个函数
    """
    # 能走到这里说明数据完全合法,直接操作实例即可
    return {
        "msg": "创建成功",
        "product": product.model_dump(),  # 序列化为字典,返回给客户端
    }


# ── 自定义校验异常处理器(可选,覆盖 FastAPI 默认的 422 格式) ──
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """自定义 422 校验失败的返回格式,方便前端统一解析"""
    return JSONResponse(
        status_code=422,
        content={
            "code": 422,
            "msg": "请求参数错误",
            "errors": exc.errors(),  # 校验失败的详细列表(字段名、错误原因)
        },
    )

2.2. LLM 结构化输出为 Pydantic 数据模型的案例

把 Pydantic 模型导出为 JSON Schema 传给 LLM,让 LLM 按要求输出 JSON,再用 model_validate_json() 把 LLM 返回的 JSON 反序列化成带校验的 Pydantic 实例 。


流程:

定义模型 ProductItem

导出 JSON Schema 传给 LLM(告诉 LLM 要输出什么结构) 【ProductItem.model_json_schema()】

LLM 按要求返回 JSON 【call_qwen_model】

model_validate_json() 校验 LLM 返回的JSON,并转成 ProductItem 模型实例【ProductItem.model_validate_json(cleaned_json)】

model_dump() 导出干净的结构化数据【product.model_dump()】

python 复制代码
import os
import json
import re
import logging
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings
from openai import OpenAI

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EnvSettings(BaseSettings):
    """环境配置"""
    DASHSCOPE_API_KEY: str = ""
    DASHSCOPE_API_BASE: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"


# 初始化API客户端
env_settings = EnvSettings()
client = OpenAI(
    api_key=env_settings.DASHSCOPE_API_KEY or os.getenv("DASHSCOPE_API_KEY"),
    base_url=env_settings.DASHSCOPE_API_BASE,
)


class ProductItem(BaseModel):
    """ 
        定义一个 Pydantic 数据模型(产品项模型 - 主要验证模型) 
        
        - 整个 ProductItem 就是一个 带类型检查 + 取值范围约束 + 字段白名单 + 自定义清洗逻辑 的"数据契约"------定义好之后, 
        - 任何要进入系统的产品数据,都必须严格符合这个结构,否则就报错 ,而不是让脏数据悄悄流过去。
    """

    # ── 字段定义(schema) ──
    id: int = Field(gt=0) # 定义参数类型,通过 Field 规定参数范围(大于 0 的整数)
    name: str # 定义参数类型
    description: str  
    price: float = Field(gt=0)
    quantity: int = Field(ge=0)
    category: str

    # ── 模型配置 ──
    model_config = {
        "extra": "forbid",  # 禁止多余字段,保证结构严格
    }

    # ── 自定义验证器 ── 移除字符串首尾空格,当这两个字段被赋值或反序列化时,会自动触发此方法对参数进行增强处理。
    @field_validator("name", "category") # 声明该验证器作用于 name 和 category 这两个字段
    @classmethod # 类方法装饰器
    def strip_text(cls, v: str) -> str:
        return v.strip()


def build_prompt(product_description: str, schema: dict) -> str:
    """构造严格的提示词,要求输出纯 JSON。"""
    return (
        "请根据以下产品描述,严格输出符合 JSON Schema 的纯 JSON 数据:\n\n"
        f"产品描述:{product_description}\n\n"
        "JSON Schema:\n"
        f"{json.dumps(schema, indent=2, ensure_ascii=False)}\n\n"
        "输出要求:\n"
        "- 仅输出可以被直接解析的纯 JSON 字符串。\n"
        "- 禁止输出 Markdown 代码块(例如 ```json )。\n"
        "- 不要包含任何额外文字、解释或前后缀。\n"
        "- 对描述中未明确给出的字段合理补全:\n"
        "  - id 使用任意正整数。\n"
        "  - quantity 为库存数量(整数)。若描述为'库存充足',请设为 100。\n"
    )


def call_qwen_model(prompt: str) -> str:
    """调用模型并返回原始文本结果。"""
    try:
        logger.info("正在调用模型生成结构化 JSON...")

        # 确保API密钥已配置
        if not env_settings.DASHSCOPE_API_KEY and not os.getenv("DASHSCOPE_API_KEY"):
            raise ValueError("未配置 DASHSCOPE_API_KEY 环境变量")

        response = client.chat.completions.create(
            model="qwen-turbo",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "你是一个结构化数据的处理器,精通 JSON。"
                        "请严格按给定 JSON Schema 输出纯 JSON。"
                        "输出将被直接解析,禁止代码块与解释文本。"
                    ),
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.2,
        )

        result = response.choices[0].message.content
        logger.info("模型调用成功")
        return result

    except Exception as e:
        logger.error(f"调用模型时发生错误: {e}")
        raise


def extract_pure_json(text: str) -> str:
    """去除可能的代码块或前后缀,仅保留 JSON 字符串。"""
    if not text:
        raise ValueError("空响应,无法解析 JSON")

    # 捕获 ```json ... ```或 ```... ```中的内容
    fence_match = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
    if fence_match:
        text = fence_match.group(1).strip()

    # 若包含多余文字,尝试截取第一个 '{' 到最后一个 '}'
    if text.strip()[0] != "{" or text.strip()[-1] != "}":
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            text = text[start : end + 1]

    return text.strip()


def main() -> None:
    product_description = (
        "新款智能手机,6.1英寸OLED屏幕,A15仿生芯片,128GB存储,售价4999元,"
        "库存充足,属于电子产品类别。"
    )
    schema = ProductItem.model_json_schema()
    prompt = build_prompt(product_description, schema)

    # 调用模型生成 JSON 格式的产品信息
    raw_text = call_qwen_model(prompt)
    logger.info("模型返回的原始文本:%s", raw_text)

    # 清洗并校验
    cleaned_json = extract_pure_json(raw_text)
    logger.info("清洗后的 JSON 文本:%s", cleaned_json)

    try:
        product = ProductItem.model_validate_json(cleaned_json)
    except Exception:
        # 回退:若为 Python 字典,先 loads 后再校验
        data = json.loads(cleaned_json)
        product = ProductItem.model_validate(data)

    logger.info(
        "校验后的产品信息:%s",
        json.dumps(product.model_dump(), ensure_ascii=False, indent=2),
    )


if __name__ == "__main__":
    main()

三、限流中间件(限制每个 IP 的请求频率,防止接口被刷爆)

小项目临时可以这么写,企业场景推荐用Nginx、Redis + Lua等方式做限流。

FastAPI 限流中间件,实现案例 :

  • 拦截所有请求,按客户端 IP 限流
  • 令牌不够返回 429(请求太频繁)
  • 响应头带上剩余次数、重试时间
  • 支持配置豁免路径(如文档页面不限流)
  • 自动清理长时间不活跃的 IP 记录
python 复制代码
import asyncio
import time
from typing import Callable, Dict
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class TokenBucket:
    """令牌桶算法:按速率填充令牌,支持突发容量。"""

    def __init__(self, rate: float, capacity: int):
        self.rate = float(rate)
        self.capacity = int(capacity)
        self.tokens = float(capacity)
        self.updated_at = time.monotonic()
        self._lock = asyncio.Lock()

    async def consume(self, cost: float = 1.0) -> tuple[bool, int, float]:
        """消费令牌。

        Returns:
            (allowed, remaining, retry_after)
        """
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self.updated_at
            if elapsed > 0:
                self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
                self.updated_at = now

            if self.tokens >= cost:
                self.tokens -= cost
                return True, max(0, int(self.tokens)), 0.0

            need = cost - self.tokens
            retry_after = need / self.rate if self.rate > 0 else float("inf")
            return False, 0, retry_after

    def time_to_full(self) -> float:
        return (self.capacity - self.tokens) / self.rate if self.rate > 0 else float("inf")

class TokenBucketRateLimiter(BaseHTTPMiddleware):
    """基于令牌桶的限流中间件。默认按客户端IP限流。
        【限制单个 IP 的请求速率】        
    """

    def __init__(
        self,
        app: FastAPI,
        rate_per_sec: float,
        burst_capacity: int,
        key_func: Callable[[Request], str] | None = None,
        tokens_per_request: float = 1.0,
        exempt_paths: set[str] | None = None,
        ttl_seconds: int = 600,
    ) -> None:
        super().__init__(app)
        self.rate = float(rate_per_sec)
        self.capacity = int(burst_capacity)
        self.tokens_per_request = float(tokens_per_request)
        self.key_func = key_func or self._default_key
        self.exempt_paths = exempt_paths or set()
        self.ttl_seconds = int(ttl_seconds)
        self.buckets: Dict[str, TokenBucket] = {}
        self.last_seen: Dict[str, float] = {}
        self._global_lock = asyncio.Lock()

    def _default_key(self, request: Request) -> str:
        xff = request.headers.get("x-forwarded-for")
        if xff:
            ip = xff.split(",")[0].strip()
        else:
            ip = request.headers.get("x-real-ip") or (request.client.host if request.client else "unknown")
        return ip

    async def dispatch(self, request: Request, call_next):
        if request.url.path in self.exempt_paths:
            return await call_next(request)

        key = self.key_func(request)
        now = time.monotonic()

        async with self._global_lock:
            bucket = self.buckets.get(key)
            if bucket is None:
                bucket = TokenBucket(rate=self.rate, capacity=self.capacity)
                self.buckets[key] = bucket
            self.last_seen[key] = now

            if len(self.last_seen) > 1000:
                cutoff = now - self.ttl_seconds
                stale_keys = [k for k, t in self.last_seen.items() if t < cutoff]
                for k in stale_keys:
                    self.buckets.pop(k, None)
                    self.last_seen.pop(k, None)

        allowed, remaining, retry_after = await bucket.consume(self.tokens_per_request)

        policy = f"token_bucket; rate={self.rate}/s; burst={self.capacity}"
        if not allowed:
            reset = max(0, int(retry_after))
            return JSONResponse(
                status_code=429,
                content={"detail": "请求过于频繁,请稍后重试", "retry_after": reset},
                headers={
                    "Retry-After": str(reset),
                    "X-RateLimit-Policy": policy,
                    "X-RateLimit-Limit": str(self.capacity),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(reset),
                },
            )

        response = await call_next(request)
        reset = max(0, int(bucket.time_to_full()))
        response.headers["X-RateLimit-Policy"] = policy
        response.headers["X-RateLimit-Limit"] = str(self.capacity)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset)
        return response

app = FastAPI(title="限流中间件演示(令牌桶)")

"""
配置限流中间件
【限制单个 IP 的请求速率】为5次/秒,每个请求消耗1个令牌,最大令牌数10个,每个请求5秒内只能请求10次
"""
app.add_middleware(
    TokenBucketRateLimiter,
    burst_capacity=10, # 初始状态,最大令牌数10个, 桶的最大容量是 10 个令牌
    rate_per_sec=5.0, # 填充速率 :每秒自动补充 5 个令牌(但不超过 10 个)
    tokens_per_request=1.0, # 每个请求消耗1个令牌
    exempt_paths={"/docs", "/openapi.json"}, # 免许文档路径和OpenAPI路径通过
)


@app.get("/")
async def root():
    return {"message": "OK"}

@app.get("/ping")
async def ping():
    return {"message": "pong"}

@app.get("/work")
async def work():
    import asyncio
    await asyncio.sleep(0.2)
    return {"message": "done"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("限流中间件:app", host="0.0.0.0", port=8000, reload=True)
相关推荐
上海合宙LuatOS1 小时前
Air8000多网通信-NTP
服务器·arm开发·物联网·网络协议·luatos
Betelgeuse763 小时前
Django 中间件 4 大钩子 & CBV vs FBV 对比实战
python·中间件·django
哈里谢顿10 小时前
no_proxy介绍
网络协议
Oflycomm13 小时前
工业以太网四大主流协议(EtherCAT/PROFINET/EtherNet/IP/Modbus)技术参数深度对比
网络·网络协议·tcp/ip·欧飞信·plc模组
wangl_9214 小时前
Modbus RTU 与 Modbus TCP 深入指南-现代替代协议
网络·网络协议·tcp/ip·tcp·modbus·rtu
java资料站16 小时前
常用中间件快速搭建
docker·中间件
霸道流氓气质17 小时前
SpringAIAlibaba整合 Streamable HTTP 调用免费 MCP Server 实战全解
网络·网络协议·http
想唱rap17 小时前
传输层协议TCP
linux·运维·服务器·网络·c++·tcp/ip
一只小白00019 小时前
一篇讲清TCP的三次握手&四次挥手
服务器·网络·tcp/ip