缓存工具类封装:内存与Redis无缝切换

python 复制代码
"""
缓存工具模块

提供:
1. MemoryCache - 内存缓存类(支持TTL过期、线程安全)
2. RedisCache - Redis缓存类(与MemoryCache接口完全一致)
3. CacheManager - 缓存管理器(支持动态切换内存/Redis后端)

使用示例:
    from utils.cache import cache

    # 设置缓存(带TTL,单位:秒)
    cache.set('sms:login:13800138000', {'code': '123456'}, ttl=300)

    # 获取缓存
    data = cache.get('sms:login:13800138000')

    # 删除缓存
    cache.delete('sms:login:13800138000')

    # 检查是否存在
    exists = cache.exists('sms:login:13800138000')

Redis切换(在app.py中初始化时调用):
    from utils.cache import cache
    import redis
    # 根据配置选择缓存后端
    if app.config.get('CACHE_BACKEND') == 'redis':
        try:
            import redis as redis_module
            redis_client = redis_module.Redis(
                host=app.config['REDIS_HOST'],
                port=app.config['REDIS_PORT'],
                password=app.config['REDIS_PASSWORD'] or None,
                db=app.config['REDIS_DB'],
                decode_responses=True,
                socket_connect_timeout=3,
                socket_timeout=3,
                retry_on_timeout=True
            )
            redis_client.ping()
            cache.init_redis(redis_client)
            app.logger.info('缓存已切换到 Redis 后端')
        except ImportError:
            app.logger.warning('CACHE_BACKEND=redis 但 redis 模块未安装,使用内存缓存')
        except Exception as e:
            app.logger.warning(f'Redis 连接失败,使用内存缓存: {e}')
    else:
        app.logger.info('使用内存缓存(如需切换请设置 CACHE_BACKEND=redis)')
"""
import json
import threading
import time
from typing import Any, Optional


class MemoryCache:
    """
    内存缓存实现(支持TTL过期机制)

    特性:
    - 线程安全:使用 threading.Lock 保证并发访问安全
    - 自动过期:支持 TTL(Time To Live)自动清理过期数据
    - 命名空间隔离:通过前缀区分不同业务场景的缓存
    """

    def __init__(self):
        self._store = {}
        self._lock = threading.Lock()

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        设置缓存值

        Args:
            key: 缓存键
            value: 缓存值
            ttl: 过期时间(秒),None 表示永不过期

        Returns:
            bool: 是否设置成功
        """
        with self._lock:
            expire_at = time.time() + ttl if ttl else None
            self._store[key] = {
                'value': value,
                'expire_at': expire_at
            }
            return True

    def get(self, key: str) -> Optional[Any]:
        """
        获取缓存值

        Args:
            key: 缓存键

        Returns:
            缓存值,不存在或已过期返回 None
        """
        with self._lock:
            if key not in self._store:
                return None

            item = self._store[key]

            if item['expire_at'] is not None and time.time() > item['expire_at']:
                del self._store[key]
                return None

            return item['value']

    def delete(self, key: str) -> bool:
        """
        删除缓存

        Args:
            key: 缓存键

        Returns:
            bool: 是否删除成功
        """
        with self._lock:
            if key in self._store:
                del self._store[key]
                return True
            return False

    def exists(self, key: str) -> bool:
        """
        检查键是否存在且未过期

        Args:
            key: 缓存键

        Returns:
            bool: 是否存在
        """
        with self._lock:
            if key not in self._store:
                return False

            item = self._store[key]

            if item['expire_at'] is not None and time.time() > item['expire_at']:
                del self._store[key]
                return False

            return True

    def keys(self, prefix: str = '') -> list:
        """
        获取所有匹配前缀的键列表(自动过滤过期项)

        Args:
            prefix: 键前缀筛选条件

        Returns:
            list: 匹配的键列表
        """
        with self._lock:
            now = time.time()
            valid_keys = []

            expired_keys = []
            for key, item in self._store.items():
                if item['expire_at'] is not None and now > item['expire_at']:
                    expired_keys.append(key)
                elif prefix == '' or key.startswith(prefix):
                    valid_keys.append(key)

            for k in expired_keys:
                del self._store[k]

            return valid_keys

    def clear(self, prefix: str = '') -> int:
        """
        清除缓存

        Args:
            prefix: 如果指定,只清除匹配前缀的缓存;否则清除全部

        Returns:
            int: 清除的数量
        """
        with self._lock:
            if prefix == '':
                count = len(self._store)
                self._store.clear()
                return count

            keys_to_delete = [k for k in self._store if k.startswith(prefix)]
            for k in keys_to_delete:
                del self._store[k]
            return len(keys_to_delete)

    def cleanup_expired(self) -> int:
        """
        手动清理所有过期的缓存项

        Returns:
            int: 清理的数量
        """
        with self._lock:
            now = time.time()
            expired_keys = [
                k for k, v in self._store.items()
                if v['expire_at'] is not None and now > v['expire_at']
            ]
            for k in expired_keys:
                del self._store[k]
            return len(expired_keys)

    def size(self) -> int:
        """获取当前缓存大小(包含未清理的过期项)"""
        with self._lock:
            return len(self._store)

    def get_ttl(self, key: str) -> Optional[float]:
        """
        获取剩余过期时间

        Args:
            key: 缓存键

        Returns:
            float: 剩余秒数,不存在返回 None
        """
        with self._lock:
            if key not in self._store:
                return None

            item = self._store[key]
            if item['expire_at'] is None:
                return float('inf')

            remaining = item['expire_at'] - time.time()
            return max(0, remaining) if remaining > 0 else None

    def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
        """
        原子递增计数器

        适用于限流、统计等场景

        Args:
            key: 缓存键
            amount: 递增步长
            ttl: 首次设置时的过期时间

        Returns:
            int: 递增后的值
        """
        with self._lock:
            current = self.get(key)
            if current is None:
                new_value = amount
                self.set(key, new_value, ttl)
            else:
                new_value = current + amount
                self._store[key]['value'] = new_value
            return new_value

    def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
        """
        检查频率限制

        Args:
            key: 限制标识键
            interval: 时间窗口(秒)
            max_count: 允许的最大次数

        Returns:
            tuple: (allowed: bool, remaining_time: int)
                   allowed=True 表示允许执行
                   remaining_time=0 表示无等待时间
        """
        last_time = self.get(f"rate:{key}")
        current_time = time.time()

        if last_time is None:
            self.set(f"rate:{key}", current_time, interval)
            return (True, 0)

        elapsed = current_time - last_time
        if elapsed < interval:
            remaining = int(interval - elapsed)
            return (False, remaining)

        self.set(f"rate:{key}", current_time, interval)
        return (True, 0)


class RedisCache:
    """
    Redis 缓存实现(支持 TTL 过期机制)

    与 MemoryCache 接口完全一致,业务代码无需任何修改。
    使用 Redis 的 EXPIRE / TTL 原生过期机制,支持多进程/多服务器共享缓存。
    """

    def __init__(self, redis_client):
        """
        Args:
            redis_client: redis.Redis 连接实例
        """
        self._redis = redis_client
        self._prefix = 'cache:'

    def _make_key(self, key: str) -> str:
        """拼接完整键名"""
        return f"{self._prefix}{key}"

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        设置缓存值(自动序列化为 JSON)

        Args:
            key: 缓存键
            value: 缓存值
            ttl: 过期时间(秒),None 表示永不过期

        Returns:
            bool: 是否设置成功
        """
        try:
            serialized = json.dumps(value, ensure_ascii=False)
            self._redis.set(self._make_key(key), serialized, ex=ttl)
            return True
        except Exception as e:
            return False

    def get(self, key: str) -> Optional[Any]:
        """
        获取缓存值(自动反序列化)

        Args:
            key: 缓存键

        Returns:
            缓存值,不存在或已过期返回 None
        """
        try:
            data = self._redis.get(self._make_key(key))
            if data is None:
                return None
            return json.loads(data)
        except Exception:
            return None

    def delete(self, key: str) -> bool:
        """
        删除缓存

        Args:
            key: 缓存键

        Returns:
            bool: 是否删除成功
        """
        try:
            return bool(self._redis.delete(self._make_key(key)))
        except Exception:
            return False

    def exists(self, key: str) -> bool:
        """
        检查键是否存在

        Args:
            key: 缓存键

        Returns:
            bool: 是否存在
        """
        try:
            return bool(self._redis.exists(self._make_key(key)))
        except Exception:
            return False

    def keys(self, prefix: str = '') -> list:
        """
        获取匹配前缀的所有键

        Args:
            prefix: 键前缀筛选条件

        Returns:
            list: 匹配的键列表
        """
        try:
            pattern = f"{self._make_key(prefix)}*"
            keys = self._redis.keys(pattern)
            result = []
            for k in keys:
                key_str = k.decode() if isinstance(k, bytes) else k
                if key_str.startswith(self._prefix):
                    result.append(key_str[len(self._prefix):])
            return result
        except Exception:
            return []

    def clear(self, prefix: str = '') -> int:
        """
        清除缓存

        Args:
            prefix: 如果指定,只清除匹配前缀的缓存;否则清除全部

        Returns:
            int: 清除的数量
        """
        try:
            keys = self.keys(prefix)
            if keys:
                full_keys = [self._make_key(k) for k in keys]
                return self._redis.delete(*full_keys)
            return 0
        except Exception:
            return 0

    def get_ttl(self, key: str) -> Optional[float]:
        """
        获取剩余过期时间

        Args:
            key: 缓存键

        Returns:
            float: 剩余秒数,不存在返回 None
        """
        try:
            ttl = self._redis.ttl(self._make_key(key))
            return ttl if ttl >= 0 else None
        except Exception:
            return None

    def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
        """
        原子递增计数器(使用 Redis INCRBY)

        Args:
            key: 缓存键
            amount: 递增步长
            ttl: 首次设置时的过期时间

        Returns:
            int: 递增后的值
        """
        try:
            full_key = self._make_key(key)
            value = self._redis.incr(full_key, amount)
            if ttl:
                self._redis.expire(full_key, ttl)
            return value
        except Exception:
            return 0

    def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
        """
        检查频率限制(使用 Redis 原子操作)

        Args:
            key: 限制标识键
            interval: 时间窗口(秒)
            max_count: 允许的最大次数

        Returns:
            tuple: (allowed: bool, remaining_time: int)
                   allowed=True 表示允许执行
                   remaining_time=0 表示无等待时间
        """
        try:
            rate_key = self._make_key(f"rate:{key}")
            last_time = self._redis.get(rate_key)

            if last_time is None:
                self._redis.set(rate_key, time.time(), ex=interval)
                return (True, 0)

            elapsed = time.time() - float(last_time)
            if elapsed < interval:
                remaining = int(interval - elapsed)
                return (False, remaining)

            self._redis.set(rate_key, time.time(), ex=interval)
            return (True, 0)
        except Exception:
            return (True, 0)


class CacheManager:
    """
    缓存管理器(单例模式)

    提供统一的缓存操作接口,支持按业务场景分类管理缓存。
    默认使用内存缓存,可通过 init_redis() 切换到 Redis 后端。

    命名空间规范:
    - sms:{scene}:{phone} - 短信验证码
    - email:{type}:{email} - 邮箱验证码
    - rate_limit:{key} - 频率限制
    - session:{session_id} - 会话数据
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        """单例模式实现"""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if CacheManager._initialized:
            return
        CacheManager._initialized = True
        self._cache = MemoryCache()
        self._backend = 'memory'

    def init_redis(self, redis_client):
        """
        切换到 Redis 后端

        在 app.py 中初始化时调用:
            from utils.cache import cache
            import redis
            redis_client = redis.Redis(host='localhost', port=6379, db=0)
            cache.init_redis(redis_client)

        Args:
            redis_client: redis.Redis 连接实例
        """
        self._cache = RedisCache(redis_client)
        self._backend = 'redis'

    @property
    def backend(self) -> str:
        """获取当前缓存后端类型:'memory' 或 'redis'"""
        return self._backend

    @classmethod
    def reset_instance(cls):
        """重置单例(用于测试)"""
        cls._instance = None
        cls._initialized = False

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """设置缓存"""
        return self._cache.set(key, value, ttl)

    def get(self, key: str) -> Optional[Any]:
        """获取缓存"""
        return self._cache.get(key)

    def delete(self, key: str) -> bool:
        """删除缓存"""
        return self._cache.delete(key)

    def exists(self, key: str) -> bool:
        """检查键是否存在"""
        return self._cache.exists(key)

    def keys(self, prefix: str = '') -> list:
        """获取匹配前缀的所有键"""
        return self._cache.keys(prefix)

    def clear(self, prefix: str = '') -> int:
        """清除缓存"""
        return self._cache.clear(prefix)

    def set_with_timestamp(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        设置缓存并自动记录创建时间戳

        适用于需要记录发送/创建时间的场景(如验证码频率限制)

        Args:
            key: 缓存键
            value: 缓存值(字典类型会自动添加时间戳)
            ttl: 过期时间(秒)

        Returns:
            bool: 是否设置成功
        """
        if isinstance(value, dict):
            value['_created_at'] = time.time()
        return self.set(key, value, ttl)

    def get_ttl(self, key: str) -> Optional[float]:
        """
        获取剩余过期时间

        Args:
            key: 缓存键

        Returns:
            float: 剩余秒数,不存在返回 None
        """
        return self._cache.get_ttl(key)

    def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
        """
        原子递增计数器

        适用于限流、统计等场景

        Args:
            key: 缓存键
            amount: 递增步长
            ttl: 首次设置时的过期时间

        Returns:
            int: 递增后的值
        """
        return self._cache.increment(key, amount, ttl)

    def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
        """
        检查频率限制

        Args:
            key: 限制标识键
            interval: 时间窗口(秒)
            max_count: 允许的最大次数

        Returns:
            tuple: (allowed: bool, remaining_time: int)
                   allowed=True 表示允许执行
                   remaining_time=0 表示无等待时间
        """
        return self._cache.check_rate_limit(key, interval, max_count)


# 全局缓存实例(单例)
cache = CacheManager()
相关推荐
m0_609160491 小时前
Go语言Beego框架如何用_Go语言Beego框架入门教程【高效】
jvm·数据库·python
闵孚龙1 小时前
Claude Code 缓存优化模式全解析:AI Agent 上下文工程、Prompt Cache、工具 Schema 缓存、Token 成本优化
人工智能·缓存·prompt
阿正的梦工坊7 小时前
深入理解 PyTorch 中的 unsqueeze 操作
人工智能·pytorch·python
FreakStudio8 小时前
硬件版【Cursor】?aily blockly IDE尝鲜封神,实战硬伤尽显
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机
测试员周周10 小时前
【Appium 系列】第06节-页面对象实现 — LoginPage 实战
开发语言·前端·人工智能·python·功能测试·appium·测试用例
2301_7838486510 小时前
优化文本分类中堆叠模型的网格搜索性能:避免训练卡顿的实战指南
jvm·数据库·python
CLX050511 小时前
如何安装Oracle 12c Cloud Control_OMS服务端组件与Agent部署
jvm·数据库·python
老纪12 小时前
SQL中如何查找特定的空值行:WHERE IS NULL深度解析
jvm·数据库·python
噜噜噜阿鲁~12 小时前
python学习笔记 | 10.0、面向对象编程
笔记·python·学习