FastAPI 系列·(七):Redis 集成——缓存、分布式锁与 Session 管理

FastAPI 系列 · 第 7 篇:Redis 集成------缓存、分布式锁与 Session 管理

适合人群 :熟悉 Java Spring Boot,已完成第 01--06 篇的后端工程师
阅读时间 :约 60 分钟
一句话定位 :本篇为 shop-api 引入 Redis 层------用 Cache Aside 模式加速商品详情查询,用分布式锁保障库存扣减的安全性,并将第 05 篇的内存 Refresh Token 升级为 Redis 存储,同时剖析缓存穿透、击穿、雪崩三大经典问题的工程解法。


一、redis.asyncio 连接配置

Redis 对后端系统的价值,类似 Java 世界里 Caffeine + Redisson 的组合:本地缓存解决热点读,分布式锁解决并发写。在 FastAPI 的异步生态里,官方 redis 库自 4.2 版本起内置了 redis.asyncio 模块,无需额外安装异步驱动,一个包搞定所有场景。

1.1 安装依赖

bash 复制代码
pip install redis>=4.2.0

与第 04 篇引入 SQLAlchemy 异步驱动类似,Redis 的异步客户端也基于 asyncio,所有 IO 操作都是协程,不会阻塞 Event Loop。

1.2 连接池配置

在 Java Spring Data Redis 里,配置 Redis 需要声明 LettuceConnectionFactoryJedisConnectionFactory,然后通过 RedisTemplate 操作。FastAPI 里对应的概念是 ConnectionPool + Redis 客户端实例。

Spring Data Redis 概念 FastAPI + redis.asyncio 对应 说明
RedisConnectionFactory ConnectionPool.from_url() 连接工厂,管理连接的创建和回收
RedisTemplate<String, String> aioredis.Redis(connection_pool=pool) 操作客户端
StringRedisValueOperations redis.get() / redis.set() 字符串操作
@Bean 单例 模块级全局变量 _redis 整个应用共享同一个实例
@PreDestroy lifespan 关闭阶段 应用退出时释放连接

连接池的关键参数说明:

python 复制代码
import redis.asyncio as aioredis

pool = aioredis.ConnectionPool.from_url(
    "redis://:password@localhost:6379/0",
    max_connections=20,        # 最大并发连接数,类比 HikariCP 的 maximumPoolSize
    decode_responses=True,     # bytes → str,省去手动 decode,推荐开启
    socket_connect_timeout=5,  # 建连超时(秒)
    socket_timeout=5,          # 读写超时(秒)
    retry_on_timeout=True,     # 超时自动重试
)

💡 max_connections=20 是经验值起点。Redis 是单线程处理命令,连接数太多并不会提升吞吐,反而增加上下文切换开销。对于中小型服务,20 个连接足够支撑数千 QPS。

1.3 在 lifespan 中初始化和关闭连接池

第 01 篇介绍了 FastAPI 的 lifespan 上下文管理器,用来替代已废弃的 @app.on_event("startup") / @app.on_event("shutdown")。Redis 连接池的生命周期管理同样放在这里。

app/config.py --- 新增 Redis 配置项:

python 复制代码
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    DATABASE_URL: str = "mysql+aiomysql://root:password@localhost/shop"
    SECRET_KEY: str = "your-super-secret-key-change-this-in-production"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7

    # 新增:Redis 配置
    REDIS_URL: str = "redis://localhost:6379/0"
    REDIS_MAX_CONNECTIONS: int = 20

    class Config:
        env_file = ".env"

settings = Settings()

app/redis_client.py --- Redis 客户端单例(类比 Spring 的 RedisTemplate Bean):

python 复制代码
from typing import Optional
import redis.asyncio as aioredis
from app.config import settings

# 全局连接池(类比 Spring 的 RedisConnectionFactory Bean)
_redis_pool: Optional[aioredis.ConnectionPool] = None
_redis: Optional[aioredis.Redis] = None


async def init_redis() -> None:
    """初始化 Redis 连接池(在 lifespan 中调用)"""
    global _redis_pool, _redis
    _redis_pool = aioredis.ConnectionPool.from_url(
        settings.REDIS_URL,
        max_connections=settings.REDIS_MAX_CONNECTIONS,
        decode_responses=True,  # 自动将 bytes 解码为 str,避免手动 b"xxx".decode()
        socket_connect_timeout=5,
        socket_timeout=5,
        retry_on_timeout=True,
    )
    _redis = aioredis.Redis(connection_pool=_redis_pool)


async def close_redis() -> None:
    """关闭 Redis 连接(在 lifespan 中调用)"""
    if _redis:
        await _redis.aclose()  # 关闭客户端连接
    if _redis_pool:
        await _redis_pool.aclose()  # 关闭连接池中所有连接


def get_redis() -> aioredis.Redis:
    """
    依赖注入函数,获取 Redis 客户端。
    类比 Spring 中直接注入 RedisTemplate:
        @Autowired
        private RedisTemplate<String, String> redisTemplate;
    """
    if _redis is None:
        raise RuntimeError("Redis 未初始化,请确保 lifespan 已运行")
    return _redis

app/main.py --- 在 lifespan 中挂载 Redis 初始化:

python 复制代码
from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.database import init_db
from app.redis_client import init_redis, close_redis

@asynccontextmanager
async def lifespan(app: FastAPI):
    # ---- 启动阶段 ----
    await init_db()        # 第 04 篇:初始化数据库连接池
    await init_redis()     # 本篇新增:初始化 Redis 连接池
    print("✅ 数据库和 Redis 连接池初始化完成")

    yield  # 应用运行期间

    # ---- 关闭阶段 ----
    await close_redis()    # 先关 Redis
    print("Redis 连接池已关闭")

app = FastAPI(title="Shop API", lifespan=lifespan)

1.4 Redis 依赖注入

有了 get_redis() 函数,在任何路由中都可以通过 Depends() 注入 Redis 客户端:

python 复制代码
from fastapi import Depends
from redis.asyncio import Redis
from app.redis_client import get_redis

@router.get("/products/{product_id}")
async def get_product(
    product_id: int,
    redis: Redis = Depends(get_redis),        # 注入 Redis 客户端
    db: AsyncSession = Depends(get_db),       # 注入数据库会话
    current_user: User = Depends(get_current_user),  # 注入当前用户
):
    ...

💡 get_redis() 返回的是模块级单例,不像 get_db() 每次请求创建新会话。Redis 连接池内部维护了连接的分配和回收,多个协程可以安全地并发使用同一个 Redis 实例。


二、Cache Aside 模式

2.1 什么是 Cache Aside

Cache Aside (旁路缓存)是分布式系统中最经典、最安全的缓存使用模式。它的核心思路是:缓存永远是数据库的快照,应用代码负责维护两者的一致性

与之对比的是 Read Through / Write Through 模式------那些模式把缓存更新的职责委托给缓存层本身(如 ORM 插件、缓存代理),对应用透明。Cache Aside 则更直白:应用自己决定何时读缓存、何时写缓存、何时删缓存。

2.2 完整流程图

命中
未命中


客户端请求
缓存命中?
从 Redis 返回数据
查询数据库
查到数据?
将数据写入 Redis

设置 TTL
写入空值标记

NULL

防止穿透
返回数据给客户端
返回 404 给客户端

写操作流程(更新/删除数据时):
写请求
更新数据库
删除 Redis 缓存
返回成功

2.3 为什么写操作要"删缓存"而不是"更新缓存"

这是 Cache Aside 最容易被误解的地方。直觉上,更新数据库后同步更新缓存似乎更合理,但实际上存在竞态条件:

复制代码
线程 A:更新 DB(product.stock = 100)
线程 B:更新 DB(product.stock = 99)
线程 B:更新缓存(stock = 99)   ← B 先写缓存
线程 A:更新缓存(stock = 100)  ← A 后写缓存,覆盖了 B 的结果
结果:DB 里是 99,缓存里是 100,数据不一致!

先删后写同样有问题:

复制代码
线程 A:删缓存
线程 B:读缓存(未命中)→ 查 DB(旧值 100)→ 写缓存(100)
线程 A:更新 DB(99)
结果:DB 里是 99,缓存里是 100,数据不一致!

正确做法是先写 DB,再删缓存。即使在极端情况下,最坏结果是下次读请求回源 DB 拿到最新值,缓存会自然更新,不一致窗口极短。

2.4 通用 cache_get_or_set 工具函数

将 Cache Aside 的读取逻辑封装成通用工具,类比 Spring 的 @Cacheable 注解:

app/cache/utils.py

python 复制代码
import json
from typing import Any, Callable, Optional, TypeVar
import redis.asyncio as aioredis

T = TypeVar("T")

# 空值标记,防止缓存穿透(缓存穿透详见第七章)
NULL_SENTINEL = "__NULL__"


async def cache_get_or_set(
    redis: aioredis.Redis,
    key: str,
    fetch_func: Callable,
    ttl: int = 300,
    null_ttl: int = 60,          # 空值的 TTL 比正常值短(避免长期缓存不存在的数据)
    serializer=json.dumps,
    deserializer=json.loads,
) -> Optional[Any]:
    """
    Cache Aside 通用封装。

    类比 Spring @Cacheable:
        @Cacheable(value = "products", key = "#id", unless = "#result == null")
        public Product getProduct(Long id) { ... }

    用法示例:
        result = await cache_get_or_set(
            redis, "shop:v1:product:123",
            fetch_func=lambda: product_repo.find_by_id(123),
            ttl=300
        )
    """
    # 1. 先查缓存
    cached = await redis.get(key)
    if cached is not None:
        if cached == NULL_SENTINEL:
            return None  # 空值缓存命中,直接返回 None,不查 DB
        return deserializer(cached)

    # 2. 缓存未命中,查数据源(DB / 外部接口 / 计算结果)
    result = await fetch_func()

    # 3. 写入缓存(包括空值缓存)
    if result is None:
        # 空值也缓存,防止同一个不存在的 key 反复穿透到 DB
        await redis.setex(key, null_ttl, NULL_SENTINEL)
    else:
        await redis.setex(key, ttl, serializer(result))

    return result


async def cache_delete(redis: aioredis.Redis, key: str) -> None:
    """删除缓存(写操作后调用,类比 Spring @CacheEvict)"""
    await redis.delete(key)


async def cache_delete_pattern(redis: aioredis.Redis, pattern: str) -> int:
    """
    按模式批量删除缓存(如 'shop:v1:product:*')。
    ⚠️ 生产环境慎用 KEYS 命令,它会阻塞 Redis,推荐改用 SCAN。
    """
    cursor = 0
    deleted = 0
    while True:
        cursor, keys = await redis.scan(cursor, match=pattern, count=100)
        if keys:
            await redis.delete(*keys)
            deleted += len(keys)
        if cursor == 0:
            break
    return deleted

三、商品详情缓存实战

3.1 缓存 Key 设计规范

好的 Key 设计应该清晰、可追溯、易于批量操作。推荐格式:

复制代码
{namespace}:{version}:{resource}:{id}
字段 示例 说明
namespace shop 项目/服务标识,防止多服务共用 Redis 时 Key 冲突
version v1 数据格式版本,升级 Schema 时只需改版本号即可全量失效
resource product 资源类型
id 123 资源 ID

实际示例:

python 复制代码
# 商品详情
KEY_PRODUCT = "shop:v1:product:{product_id}"

# 商品列表(按分页参数区分)
KEY_PRODUCT_LIST = "shop:v1:product:list:{page}:{page_size}"

# 用户购物车
KEY_CART = "shop:v1:cart:{user_id}"

# Refresh Token(第五章升级到此)
KEY_REFRESH_TOKEN = "shop:v1:refresh:{token_hash}"

3.2 TTL 设计策略

TTL(Time To Live,生存时间)决定了缓存数据的新鲜度和缓存命中率之间的平衡:

python 复制代码
# app/cache/keys.py

class CacheTTL:
    """
    TTL 常量定义。
    设计原则:
    - 数据变更频率越高,TTL 越短
    - 数据越重要(一致性要求高),TTL 越短
    - 热门数据适当缩短 TTL,冷数据延长 TTL
    """
    HOT_PRODUCT = 60        # 热门商品:60 秒,频繁变动(秒杀、库存)
    NORMAL_PRODUCT = 300    # 普通商品:5 分钟
    COLD_PRODUCT = 3600     # 冷门商品:1 小时
    PRODUCT_LIST = 120      # 商品列表:2 分钟(列表变动频率高于详情)
    USER_INFO = 600         # 用户信息:10 分钟
    NULL_VALUE = 60         # 空值缓存:1 分钟(不要太长)


def product_cache_key(product_id: int) -> str:
    return f"shop:v1:product:{product_id}"


def product_list_cache_key(page: int, page_size: int) -> str:
    return f"shop:v1:product:list:{page}:{page_size}"

3.3 序列化方案选择

缓存需要将 Python 对象序列化为字符串存入 Redis,取出时再反序列化。有两种常见方案:

方案 优点 缺点 推荐场景
json.dumps / json.loads 跨语言、可读、安全 不支持复杂类型(datetime、自定义类) 首选,99% 场景够用
pickle.dumps / pickle.loads 支持任意 Python 对象 ⚠️ 安全风险(反序列化攻击)、语言绑定、版本耦合 内部工具,不对外暴露
python 复制代码
import json
from datetime import datetime
from decimal import Decimal
from typing import Any

# ✅ 自定义 JSON 编码器处理特殊类型
class ShopJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()       # datetime → "2024-01-01T12:00:00"
        if isinstance(obj, Decimal):
            return str(obj)              # Decimal → "99.99"
        return super().default(obj)


def serialize(obj) -> str:
    return json.dumps(obj, cls=ShopJSONEncoder, ensure_ascii=False)


def deserialize(s: str) -> Any:
    return json.loads(s)

3.4 商品详情接口完整实现

app/routers/products.py --- 带缓存的商品详情接口:

python 复制代码
from fastapi import APIRouter, Depends, HTTPException, status
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from redis.asyncio import Redis

from app.database import get_db
from app.redis_client import get_redis
from app.models import Product
from app.schemas import ProductResponse
from app.cache.keys import product_cache_key, CacheTTL
from app.cache.utils import cache_get_or_set, cache_delete

router = APIRouter(prefix="/products", tags=["商品"])


async def _fetch_product_from_db(
    product_id: int,
    db: AsyncSession,
) -> Optional[dict]:
    """从数据库查询商品,返回可序列化的 dict"""
    product = await db.get(Product, product_id)
    if product is None:
        return None
    return {
        "id": product.id,
        "name": product.name,
        "price": str(product.price),  # Decimal → str 便于 JSON 序列化
        "stock": product.stock,
        "description": product.description,
        "created_at": product.created_at.isoformat(),
    }


@router.get("/{product_id}", response_model=ProductResponse)
async def get_product(
    product_id: int,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
):
    """
    获取商品详情(Cache Aside 模式)。
    1. 查 Redis 缓存
    2. 未命中 → 查 DB → 写缓存
    3. 缓存穿透防护(空值缓存)
    """
    cache_key = product_cache_key(product_id)

    result = await cache_get_or_set(
        redis=redis,
        key=cache_key,
        fetch_func=lambda: _fetch_product_from_db(product_id, db),
        ttl=CacheTTL.NORMAL_PRODUCT,
        null_ttl=CacheTTL.NULL_VALUE,
    )

    if result is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"商品 {product_id} 不存在",
        )

    return result


@router.put("/{product_id}", response_model=ProductResponse)
async def update_product(
    product_id: int,
    product_data: ProductUpdate,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_role(Role.ADMIN)),  # 需要管理员权限
):
    """
    更新商品信息。
    写操作流程:先写 DB,再删缓存(Cache Aside 写操作标准流程)。
    类比 Spring @CacheEvict(value = "products", key = "#id")
    """
    product = await db.get(Product, product_id)
    if product is None:
        raise HTTPException(status_code=404, detail="商品不存在")

    # 1. 写入数据库
    for field, value in product_data.model_dump(exclude_unset=True).items():
        setattr(product, field, value)
    await db.commit()
    await db.refresh(product)

    # 2. 删除缓存(不更新缓存,防止竞态条件)
    await cache_delete(redis, product_cache_key(product_id))

    return product

🤔 为什么 fetch_funclambda: _fetch_product_from_db(product_id, db) 而不是直接调用?因为 cache_get_or_set 需要在缓存未命中时才调用 fetch_func,使用 lambda 实现惰性求值------只有真正需要查 DB 时才执行,这是函数式编程里"延迟计算"的经典模式。


四、分布式锁

4.1 场景:库存扣减的超卖问题

假设某商品库存为 1,两个用户同时下单。在没有锁的情况下:

复制代码
时间线:
T1: 用户 A 查库存 → 库存=1,可以购买
T1: 用户 B 查库存 → 库存=1,可以购买
T2: 用户 A 扣减库存 → stock = stock - 1 = 0
T2: 用户 B 扣减库存 → stock = stock - 1 = 0(基于 T1 时刻的旧值)
结果:两个用户都成功下单,但库存只有 1 个!

这就是超卖 问题。在单体应用里可以用数据库的 SELECT FOR UPDATE 行锁解决。但在多实例部署场景下,多个 FastAPI 进程并发运行,数据库行锁只能保证单个事务内的原子性,无法跨越不同进程的"查库存→判断→扣减"这段业务逻辑。

分布式锁的核心思路:在 Redis 里设置一个"令牌",谁持有令牌谁才能执行关键逻辑,执行完毕释放令牌。

4.2 SET NX EX 原子命令原理

Redis 的 SET key value NX EX seconds 是实现分布式锁的原语:

  • NX(Not eXists):只有 key 不存在时才设置,存在则返回 nil(失败)
  • EX seconds:同时设置过期时间,防止持锁进程崩溃导致死锁
  • 原子性NXEX 在同一条命令内完成,不会出现"设置成功但过期时间没设上"的情况

⚠️ 这是一个常见的历史错误写法:

python 复制代码
# ❌ 错误:SETNX 和 EXPIRE 不是原子的
# 如果两行之间进程崩溃,锁永远不会过期 → 死锁
await redis.setnx("lock:stock:123", "1")
await redis.expire("lock:stock:123", 30)

# ✅ 正确:SET NX EX 是一条原子命令
acquired = await redis.set("lock:stock:123", "token", nx=True, ex=30)

4.3 Lua 脚本原子解锁

解锁时必须保证"查 token + 删 key"的原子性,否则会有这样的问题:

复制代码
T1: 进程 A 的锁即将过期(但还没过期)
T2: 进程 A 查到 token 匹配,准备删锁
T3: 锁过期,进程 B 成功加锁(B 的 token 写入)
T4: 进程 A 删锁(误删了 B 的锁!)

解决方案是用 Lua 脚本,Redis 保证 Lua 脚本的原子执行:

lua 复制代码
-- 解锁脚本:只有持有锁的客户端(token 匹配)才能解锁
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
else
    return 0
end

这个脚本对应 Spring 生态里 Redisson 的 RLock.unlock() 内部实现。

4.4 RedisLock 上下文管理器完整实现

app/cache/redis_lock.py

python 复制代码
import asyncio
import secrets
import redis.asyncio as aioredis
from typing import Optional

# Lua 脚本:原子解锁(只有持有锁的客户端才能解锁)
UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
else
    return 0
end
"""


class RedisLock:
    """
    Redis 分布式锁,实现为异步上下文管理器。

    类比 Spring/Redisson:
        RLock lock = redissonClient.getLock("product:stock:123");
        lock.lock(30, TimeUnit.SECONDS);
        try { ... } finally { lock.unlock(); }

    FastAPI 用法:
        async with RedisLock(redis, "product:stock:123", expire=30) as acquired:
            if not acquired:
                raise HTTPException(429, "服务繁忙,请稍后重试")
            # 执行关键业务逻辑
    """

    def __init__(
        self,
        redis: aioredis.Redis,
        key: str,
        expire: int = 30,
        retry_times: int = 3,
        retry_delay: float = 0.1,
    ):
        self.redis = redis
        self.key = f"lock:{key}"   # 统一加 lock: 前缀,与业务 Key 区分
        self.expire = expire
        self.retry_times = retry_times
        self.retry_delay = retry_delay
        self.token = secrets.token_hex(16)  # 随机 token,确保只有自己能解锁
        self._acquired = False

    async def __aenter__(self) -> bool:
        """尝试获取锁,支持重试。返回 True 表示成功获取锁。"""
        for attempt in range(self.retry_times):
            acquired = await self.redis.set(
                self.key,
                self.token,
                nx=True,          # NX: 只有 key 不存在时才设置
                ex=self.expire,   # EX: 过期时间(秒),防止死锁
            )
            if acquired:
                self._acquired = True
                return True  # 成功获取锁

            # 未获取到锁,等待后重试
            if attempt < self.retry_times - 1:
                await asyncio.sleep(self.retry_delay)

        return False  # 重试耗尽,获取锁失败

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """退出上下文时释放锁(Lua 脚本原子解锁)"""
        if self._acquired:
            await self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.token)
        return False  # 不吞掉异常,让调用方感知到异常


class RedisLockError(Exception):
    """获取分布式锁失败"""
    pass

4.5 库存扣减接口实战

app/routers/orders.py --- 带分布式锁的库存扣减:

python 复制代码
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from redis.asyncio import Redis

from app.database import get_db
from app.redis_client import get_redis
from app.models import Product, Order
from app.cache.redis_lock import RedisLock
from app.cache.utils import cache_delete
from app.cache.keys import product_cache_key

router = APIRouter(prefix="/orders", tags=["订单"])


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_order(
    order_data: OrderCreate,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    创建订单(含库存扣减)。
    分布式锁保证同一商品的并发下单不会超卖。
    """
    product_id = order_data.product_id
    quantity = order_data.quantity
    lock_key = f"product:stock:{product_id}"

    async with RedisLock(redis, lock_key, expire=10) as acquired:
        if not acquired:
            # 获取锁失败,说明同一商品正有其他请求在处理
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="请求过于频繁,请稍后重试",
            )

        # ---- 以下代码在分布式锁保护下串行执行 ----

        # 1. 查库存(在锁保护下查,确保最新值)
        product = await db.get(Product, product_id)
        if product is None:
            raise HTTPException(status_code=404, detail="商品不存在")

        if product.stock < quantity:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"库存不足,当前库存:{product.stock}",
            )

        # 2. 扣减库存
        product.stock -= quantity
        await db.flush()   # flush 到当前事务,但还未 commit

        # 3. 创建订单记录
        order = Order(
            user_id=current_user.id,
            product_id=product_id,
            quantity=quantity,
            total_price=product.price * quantity,
        )
        db.add(order)
        await db.commit()  # 原子提交库存扣减 + 订单创建

        # 4. 删除商品缓存(库存变了,缓存失效)
        await cache_delete(redis, product_cache_key(product_id))

    return {"order_id": order.id, "message": "下单成功"}

📝 lock_key = f"product:stock:{product_id}" 是针对每个商品的独立锁,而不是一个全局锁。这样不同商品的下单请求可以并发执行,只有同一商品的下单请求会互斥,保证了系统吞吐量。


五、Refresh Token 存 Redis(接第 05 篇)

5.1 第 05 篇的遗留问题

第 05 篇实现 Refresh Token 时,使用了一个简单的内存字典存储:

python 复制代码
# app/routers/auth.py(第 05 篇的实现)
_refresh_token_store: dict[str, int] = {}  # token_hash → user_id

# 问题 1:重启后所有用户需重新登录(内存数据丢失)
# 问题 2:多实例部署时,实例 A 颁发的 Refresh Token,实例 B 无法识别
# 问题 3:无法实现服务端主动吊销 Token(踢人下线)

本篇将其升级为 Redis 存储,解决以上三个问题。

5.2 Redis 存储方案设计

复制代码
Key:shop:v1:refresh:{token_hash}
Value:用户 ID(字符串)
TTL:7天(对应 REFRESH_TOKEN_EXPIRE_DAYS)

使用 SETEX(SET + EXpire)一条命令完成写入和设置过期时间,与 Refresh Token 的有效期绑定------Token 过期时,Redis 里的记录自动消失,无需定期清理。

5.3 修改 auth.py

app/routers/auth.py --- 完整改造:

python 复制代码
import hashlib
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from redis.asyncio import Redis

from app.database import get_db
from app.redis_client import get_redis
from app.models import User
from app.auth import (
    verify_password,
    create_access_token,
    create_refresh_token,
    decode_token,
)
from app.config import settings
# 以下函数来自第 05 篇定义的 auth 模块
from app.core.security import verify_password, create_access_token, create_refresh_token
from app.dependencies.auth import get_current_user  # 第 05 篇定义

router = APIRouter(prefix="/auth", tags=["认证"])


def _refresh_token_key(token: str) -> str:
    """
    不直接用 token 原文作 Key,而是用 SHA256 哈希。
    原因:Redis Key 会记录在日志里,哈希后即使日志泄露也不会暴露 token 原文。
    """
    token_hash = hashlib.sha256(token.encode()).hexdigest()
    return f"shop:v1:refresh:{token_hash}"


@router.post("/token")
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: AsyncSession = Depends(get_db),
    redis: Redis = Depends(get_redis),
):
    """登录,颁发 Access Token + Refresh Token"""
    # 1. 校验用户名密码(同第 05 篇,内联实现)
    from sqlalchemy import select
    result = await db.execute(select(User).where(User.username == form_data.username))
    user = result.scalar_one_or_none()
    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # 2. 生成 Token
    access_token = create_access_token({"sub": str(user.id), "role": user.role})
    refresh_token = create_refresh_token({"sub": str(user.id)})

    # 3. 将 Refresh Token 存入 Redis(替代第 05 篇的内存字典)
    expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    await redis.setex(
        _refresh_token_key(refresh_token),
        expire_seconds,
        str(user.id),  # value 为 user_id(Redis 存字符串)
    )

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
        "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    }


@router.post("/token/refresh")
async def refresh_access_token(
    refresh_token: str,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
):
    """
    用 Refresh Token 换新的 Access Token。
    从 Redis 而不是内存里校验,多实例友好。
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Refresh Token 无效或已过期",
    )

    # 1. 从 Redis 查 token 是否存在(同时校验了有效期,因为 Key 有 TTL)
    redis_key = _refresh_token_key(refresh_token)
    user_id_str = await redis.get(redis_key)
    if user_id_str is None:
        raise credentials_exception

    # 2. 验证 JWT 签名(防止伪造)
    payload = decode_token(refresh_token)
    if payload is None or str(payload.get("sub")) != user_id_str:
        raise credentials_exception

    # 3. 查询用户状态(确保用户未被禁用)
    user = await db.get(User, int(user_id_str))
    if user is None or not user.is_active:
        raise credentials_exception

    # 4. 颁发新 Access Token(Refresh Token 不轮换,避免客户端复杂性)
    new_access_token = create_access_token({"sub": str(user.id), "role": user.role})

    return {
        "access_token": new_access_token,
        "token_type": "bearer",
        "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    }


@router.post("/logout")
async def logout(
    refresh_token: str,
    redis: Redis = Depends(get_redis),
    current_user: User = Depends(get_current_user),
):
    """
    登出:主动删除 Redis 里的 Refresh Token。
    Access Token 无法吊销(JWT 无状态特性),依靠短有效期自然过期。
    如需立即吊销 Access Token,可用 Redis 维护黑名单(详见第六章)。
    """
    await redis.delete(_refresh_token_key(refresh_token))
    return {"message": "已成功登出"}

六、Session 管理方案对比

不同的 Session 管理方案各有适用场景,选型时需要权衡安全性、可扩展性和复杂度:

维度 JWT(无状态) Redis Session Cookie Session
状态存储 客户端(Token 内) 服务端(Redis) 服务端(Cookie + 签名)
服务端存储 不需要 需要 Redis 需要加密密钥
实时吊销 ❌ 需要黑名单 ✅ 直接删 Key ✅ 直接删 Session
水平扩展 ✅ 天然支持 ✅ 共享 Redis 即可 ⚠️ 需粘性会话或共享存储
Token 大小 较大(含 payload) 小(只有 session_id) 中等
适用场景 API 服务、移动端、微服务间调用 需要实时踢人、有状态会话 传统 Web 应用、表单提交
Spring 对应 Spring Security + JWT Spring Session + Redis HttpSession + Spring Security

6.1 何时选 Redis Session

当以下需求出现时,应该考虑 Redis Session:

  1. 需要实时踢人下线(后台封号、多设备互踢)
  2. 会话包含动态状态(如购物车、多步骤表单、临时权限)
  3. Token 太大导致传输开销(JWT payload 嵌入大量数据)

6.2 简单 Redis Session 实现

python 复制代码
# app/session.py --- 简单 Redis Session 管理

import secrets
import json
from typing import Optional
from redis.asyncio import Redis


class RedisSession:
    """
    基于 Redis 的简单 Session 实现。
    Spring 对应:@EnableRedisHttpSession + HttpSession
    """
    SESSION_KEY_PREFIX = "shop:v1:session:"
    DEFAULT_TTL = 3600  # 1 小时

    def __init__(self, redis: Redis, session_id: str):
        self.redis = redis
        self.session_id = session_id
        self._key = f"{self.SESSION_KEY_PREFIX}{session_id}"
        self._data: dict = {}

    @classmethod
    async def create(cls, redis: Redis, ttl: int = 3600) -> "RedisSession":  # 默认 TTL 1 小时(3600 秒)
        """创建新 Session"""
        session_id = secrets.token_urlsafe(32)
        session = cls(redis, session_id)
        await redis.setex(session._key, ttl, json.dumps({}))
        return session

    @classmethod
    async def load(cls, redis: Redis, session_id: str) -> Optional["RedisSession"]:
        """加载已有 Session(不存在返回 None)"""
        session = cls(redis, session_id)
        data = await redis.get(session._key)
        if data is None:
            return None  # Session 不存在或已过期
        session._data = json.loads(data)
        return session

    async def set(self, key: str, value, ttl: int = 3600) -> None:  # 默认 TTL 1 小时(3600 秒)
        """设置 Session 字段并刷新 TTL(滑动过期)"""
        self._data[key] = value
        await self.redis.setex(self._key, ttl, json.dumps(self._data))

    def get(self, key: str, default=None):
        """读取 Session 字段"""
        return self._data.get(key, default)

    async def destroy(self) -> None:
        """销毁 Session(登出时调用)"""
        await self.redis.delete(self._key)


# 在路由中使用
@router.get("/me/cart")
async def get_cart(
    session_id: str = Cookie(default=None),
    redis: Redis = Depends(get_redis),
):
    if not session_id:
        raise HTTPException(status_code=401, detail="未登录")

    session = await RedisSession.load(redis, session_id)
    if session is None:
        raise HTTPException(status_code=401, detail="Session 已过期")

    cart = session.get("cart", [])
    return {"cart": cart}

七、缓存三大问题及解法

缓存系统的三大经典问题:穿透(Penetration)击穿(Breakdown)雪崩(Avalanche)。每个问题都源于真实的生产故障,值得深入理解。

7.1 缓存穿透 --- 大量请求访问不存在的数据

现象 :攻击者或异常客户端反复请求一个不存在的商品 ID(如 product_id=-1),每次请求都会穿透缓存打到 DB,DB 被压垮。

复制代码
请求 → Redis 未命中 → DB 也没有 → 什么都没缓存 → 下次请求重复打 DB

解法一:空值缓存 (已在 cache_get_or_set 中实现)

python 复制代码
# 查 DB 为空,仍然写入缓存(值为特殊标记)
if result is None:
    await redis.setex(key, null_ttl, "__NULL__")

# 读缓存时识别空值标记
if cached == "__NULL__":
    return None  # 直接返回,不查 DB

解法二:布隆过滤器(进阶方案)

布隆过滤器(Bloom Filter)是一种空间效率极高的概率型数据结构:用极少内存判断"某个 ID 是否一定不存在于 DB 中"。

python 复制代码
# 需要安装:pip install pybloom-live 或使用 Redis 的 RedisBloom 模块

from redis.asyncio import Redis

# 使用 Redis BitMap 模拟简单的布隆过滤器
class SimpleBloomFilter:
    """
    简单布隆过滤器,基于 Redis BitMap。
    生产环境推荐使用 RedisBloom 模块(更精确,有专用命令)。
    """
    def __init__(self, redis: Redis, key: str, capacity: int = 1_000_000):
        self.redis = redis
        self.key = key
        self.capacity = capacity
        # 3 个哈希函数(位置不同,降低误判率)
        self.hash_seeds = [31, 37, 97]

    def _hash(self, value: str, seed: int) -> int:
        """简单哈希函数"""
        result = 0
        for char in value:
            result = seed * result + ord(char)
        return result % self.capacity

    async def add(self, value: str) -> None:
        """向布隆过滤器添加值(商品入库时调用)"""
        pipe = self.redis.pipeline()
        for seed in self.hash_seeds:
            bit_pos = self._hash(value, seed)
            pipe.setbit(self.key, bit_pos, 1)
        await pipe.execute()

    async def exists(self, value: str) -> bool:
        """
        判断值是否可能存在。
        返回 False:该值一定不存在(0% 误判)
        返回 True:该值可能存在(有一定概率误判)
        """
        pipe = self.redis.pipeline()
        for seed in self.hash_seeds:
            bit_pos = self._hash(value, seed)
            pipe.getbit(self.key, bit_pos)
        bits = await pipe.execute()
        return all(bits)  # 所有位都为 1 才认为可能存在


# 在商品详情接口中使用布隆过滤器
@router.get("/{product_id}")
async def get_product_with_bloom(
    product_id: int,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
):
    bloom = SimpleBloomFilter(redis, "shop:bloom:products")

    # 快速拦截不存在的商品 ID(布隆过滤器先行拦截)
    if not await bloom.exists(str(product_id)):
        raise HTTPException(status_code=404, detail="商品不存在")

    # 布隆过滤器说"可能存在",再走 Cache Aside 流程
    result = await cache_get_or_set(
        redis, product_cache_key(product_id),
        fetch_func=lambda: _fetch_product_from_db(product_id, db),
        ttl=CacheTTL.NORMAL_PRODUCT,
    )
    ...

7.2 缓存击穿 --- 热点 Key 过期瞬间的流量洪峰

现象:某个超热商品的缓存恰好在秒杀开始前过期,瞬间几千个请求同时穿透到 DB。

复制代码
Key 过期 → 1000 个请求同时未命中缓存 → 1000 个 DB 查询 → DB 压力暴增

解法一:分布式锁(互斥重建)

python 复制代码
async def cache_get_or_set_with_lock(
    redis: aioredis.Redis,
    key: str,
    fetch_func: Callable,
    ttl: int = 300,
) -> Optional[Any]:
    """
    带分布式锁的 Cache Aside,防止缓存击穿。
    只有一个请求能重建缓存,其他请求等待。
    """
    # 1. 先查缓存(大多数请求在这里返回)
    cached = await redis.get(key)
    if cached is not None:
        return json.loads(cached) if cached != "__NULL__" else None

    # 2. 缓存未命中,用分布式锁保护重建过程
    lock_key = f"rebuild:{key}"
    async with RedisLock(redis, lock_key, expire=5, retry_times=10, retry_delay=0.05) as acquired:
        if not acquired:
            # 获取锁失败,再读一次缓存(可能其他请求已经重建完了)
            cached = await redis.get(key)
            if cached is not None:
                return json.loads(cached) if cached != "__NULL__" else None
            return None  # 实在拿不到,返回 None(降级处理)

        # 3. 拿到锁,再次确认缓存(Double Check)
        cached = await redis.get(key)
        if cached is not None:
            return json.loads(cached) if cached != "__NULL__" else None

        # 4. 真正查 DB 并重建缓存
        result = await fetch_func()
        if result is None:
            await redis.setex(key, 60, "__NULL__")
        else:
            await redis.setex(key, ttl, json.dumps(result))
        return result

解法二:永不过期策略(逻辑过期)

python 复制代码
import time

async def get_hot_product(redis: aioredis.Redis, product_id: int, db: AsyncSession):
    """
    热点数据永不物理过期,逻辑过期时异步刷新。
    牺牲强一致性换取极低延迟(秒杀场景可接受短暂旧数据)。
    """
    key = f"shop:v1:hot:product:{product_id}"
    raw = await redis.get(key)

    if raw:
        data = json.loads(raw)
        # 检查逻辑过期时间
        if data["_expire_at"] > time.time():
            return data["value"]  # 未逻辑过期,直接返回
        else:
            # 逻辑已过期,异步触发刷新(当前请求先返回旧数据)
            asyncio.create_task(_async_refresh_product(redis, product_id, db))
            return data["value"]  # 返回旧数据,可接受短暂不一致

    # 冷启动:Key 不存在,同步查 DB
    result = await _fetch_product_from_db(product_id, db)
    if result:
        payload = json.dumps({
            "value": result,
            "_expire_at": time.time() + 300,  # 逻辑过期时间:5 分钟后
        })
        await redis.set(key, payload)  # 不设物理过期时间(永不过期)
    return result


async def _async_refresh_product(redis, product_id, db):
    """异步刷新热点商品缓存"""
    result = await _fetch_product_from_db(product_id, db)
    if result:
        payload = json.dumps({
            "value": result,
            "_expire_at": time.time() + 300,
        })
        await redis.set(f"shop:v1:hot:product:{product_id}", payload)

7.3 缓存雪崩 --- 大量 Key 同时过期

现象:系统重启或大规模更新后,大量 Key 的 TTL 相同,同时到期,导致瞬间大量请求打到 DB。

解法:TTL 随机化 + 多级缓存

python 复制代码
import random


def jittered_ttl(base_ttl: int, jitter_ratio: float = 0.2) -> int:
    """
    为 TTL 增加随机抖动,避免大量 Key 同时过期。
    base_ttl=300, jitter_ratio=0.2 → TTL 在 [240, 360] 随机分布
    """
    jitter = int(base_ttl * jitter_ratio)
    return base_ttl + random.randint(-jitter, jitter)


# 使用随机 TTL 写缓存
await redis.setex(key, jittered_ttl(300), value)


# 多级缓存:本地内存(L1)+ Redis(L2)+ DB(L3)
from functools import lru_cache
from datetime import datetime, timedelta

# 本地缓存(进程内,超低延迟,但实例间不共享)
_local_cache: dict[str, tuple[Any, datetime]] = {}
LOCAL_CACHE_TTL = 10  # 本地缓存 10 秒,容忍轻微不一致

async def get_with_multilevel_cache(
    redis: aioredis.Redis,
    key: str,
    fetch_func: Callable,
    ttl: int = 300,
):
    """三级缓存:L1 本地内存 → L2 Redis → L3 数据库"""
    now = datetime.now()

    # L1:本地内存缓存
    if key in _local_cache:
        value, expire_at = _local_cache[key]
        if expire_at > now:
            return value
        del _local_cache[key]

    # L2 + L3:Redis + DB(带随机 TTL 防雪崩)
    result = await cache_get_or_set(
        redis, key, fetch_func,
        ttl=jittered_ttl(ttl),
    )

    # 写入 L1 本地缓存
    if result is not None:
        _local_cache[key] = (result, now + timedelta(seconds=LOCAL_CACHE_TTL))

    return result

八、常见坑与最佳实践

❌ 坑 1:KEYS * 命令拖垮生产 Redis

python 复制代码
# ❌ 糟糕:KEYS 命令遍历所有 Key,Redis 单线程,大数据量时会阻塞数秒
keys = await redis.keys("shop:v1:product:*")
await redis.delete(*keys)

# ✅ 正确:用 SCAN 分批遍历,不阻塞 Redis
cursor = 0
while True:
    cursor, keys = await redis.scan(cursor, match="shop:v1:product:*", count=100)
    if keys:
        await redis.delete(*keys)
    if cursor == 0:
        break

❌ 坑 2:decode_responses=Truepickle 混用

python 复制代码
# ❌ 糟糕:decode_responses=True 会将所有响应解码为 str
# 但 pickle 序列化的是 bytes,解码后变成乱码,反序列化失败
pool = aioredis.ConnectionPool.from_url(url, decode_responses=True)
await redis.set("key", pickle.dumps(obj))   # bytes 被当作 str 存储
pickle.loads(await redis.get("key"))        # 报错:字符串无法 loads

# ✅ 正确方案一:不开 decode_responses,手动 decode
pool = aioredis.ConnectionPool.from_url(url, decode_responses=False)
data = await redis.get("key")
obj = pickle.loads(data)  # data 是 bytes,正常工作

# ✅ 正确方案二(推荐):开 decode_responses,用 json 序列化
pool = aioredis.ConnectionPool.from_url(url, decode_responses=True)
await redis.set("key", json.dumps(obj_dict))
obj_dict = json.loads(await redis.get("key"))

❌ 坑 3:分布式锁未设过期时间导致死锁

python 复制代码
# ❌ 糟糕:SETNX 不设过期时间,进程崩溃后锁永远不释放
await redis.setnx("lock:stock:123", "1")
# ... 进程崩溃了 ...
# 锁永远存在,所有后续请求都被阻塞 → 死锁

# ❌ 也糟糕:SETNX 和 EXPIRE 分两步,非原子
await redis.setnx("lock:stock:123", "1")
await redis.expire("lock:stock:123", 30)  # 如果此处崩溃,仍死锁

# ✅ 正确:SET NX EX 原子命令,一次性设置 Key + 过期时间
await redis.set("lock:stock:123", token, nx=True, ex=30)

❌ 坑 4:缓存更新用"更新"而不是"删除"

python 复制代码
# ❌ 糟糕:更新数据库后同步更新缓存(存在竞态)
product.stock = 99
await db.commit()
await redis.set(product_cache_key(product.id), json.dumps({...}))  # 危险!

# ✅ 正确:先写 DB,再删缓存(Cache Aside 标准写法)
product.stock = 99
await db.commit()
await redis.delete(product_cache_key(product.id))  # 删缓存,不更新
# 下次读请求自然回源 DB 并重建缓存

❌ 坑 5:Refresh Token 直接用原文作 Redis Key

python 复制代码
# ❌ 糟糕:直接用 token 原文作 Key
# Redis 的 Key 可能被打印到日志、监控系统、Redis 管理工具,导致 token 泄露
await redis.setex(f"refresh:{refresh_token}", expire, user_id)

# ✅ 正确:用哈希值作 Key,原文不出现在 Redis 里
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
await redis.setex(f"shop:v1:refresh:{token_hash}", expire, user_id)

❌ 坑 6:忘记处理 Redis 连接异常(缓存降级)

python 复制代码
# ❌ 糟糕:Redis 故障时整个接口 500
@router.get("/{product_id}")
async def get_product(product_id: int, redis: Redis = Depends(get_redis)):
    cached = await redis.get(product_cache_key(product_id))  # Redis 挂了就抛异常
    ...

# ✅ 正确:Redis 故障时降级为直接查 DB(缓存雪崩应急方案)
@router.get("/{product_id}")
async def get_product(
    product_id: int,
    redis: Redis = Depends(get_redis),
    db: AsyncSession = Depends(get_db),
):
    try:
        cached = await redis.get(product_cache_key(product_id))
        if cached:
            return json.loads(cached)
    except Exception as e:
        # 记录日志,降级为查 DB
        logger.warning(f"Redis 查询失败,降级查 DB: {e}")

    # 降级:直接查 DB
    return await _fetch_product_from_db(product_id, db)

九、总结

本篇为 shop-api 构建了完整的 Redis 集成层,覆盖了从基础连接配置到生产级缓存模式的完整链路:

知识点 FastAPI 实现 Spring 对应
Redis 连接池初始化 ConnectionPool.from_url() + lifespan LettuceConnectionFactory + @Bean
Redis 依赖注入 Depends(get_redis) @Autowired RedisTemplate
Cache Aside 读 cache_get_or_set() 工具函数 @Cacheable
Cache Aside 写(删缓存) cache_delete() @CacheEvict
分布式锁 RedisLock 上下文管理器 Redisson RLock
Refresh Token 存储 redis.setex() + SHA256 Key Spring Session + Redis
缓存穿透防护 空值缓存 __NULL__ + 布隆过滤器 Guava BloomFilter + 自定义实现
缓存击穿防护 分布式锁互斥重建 + 逻辑过期 Redisson 热点 Key 方案
缓存雪崩防护 随机 TTL + 多级缓存 Caffeine(L1)+ Redis(L2)

💡 金句:缓存不是银弹,它用最终一致性换取了性能。理解你的业务对"一致性时间窗口"的容忍度,才是缓存设计的核心。


参考资料


下期预告

第 8 篇:后台任务------从 BackgroundTasks 到 Celery

下一篇介绍 FastAPI 中处理后台任务的两种方案:

  • BackgroundTasks 内置轻量级任务(适合发送通知等低优先级操作)
  • Celery + Redis 任务队列(类比 Spring Batch + Quartz,支持重试、定时、结果存储)
  • 延迟任务:下单 30 分钟未付款自动取消
  • Celery Beat 定时任务:每日凌晨同步库存
  • AsyncResult 任务状态查询接口
相关推荐
姚不倒10 小时前
从「LeetCode LRU 缓存」到「生产级 Go Web 服务」:我如何迈出工程化第一步
leetcode·缓存·云原生·golang
努力努力再努力wz10 小时前
【Redis入门系列】:从 hashtable到 listpack:深入理解 Hash 底层编码、字段级过期、核心命令与缓存应用
开发语言·数据结构·数据库·c++·redis·算法·缓存
还是鼠鼠10 小时前
AI掘金头条新闻系统 (Toutiao News)-封装通用成功响应格式
数据库·后端·python·fastapi·web
止语Lab10 小时前
从 sync.Map 到 Redis:Go 缓存升级的三个拐点
redis·缓存·golang
yh弓长10 小时前
Redis的基础指令
redis
洛水水11 小时前
redis缓存:雪崩、穿透、击穿详解
数据库·redis·缓存
洛水水11 小时前
Redis 内存淘汰策略详解
数据库·redis·缓存
Mr. zhihao11 小时前
Redis 集群分区思想的演进:从哈希取余到虚拟槽
redis·哈希算法
yh弓长11 小时前
Redis的string类及基础指令
数据库·redis·缓存