python
复制代码
"""
缓存工具模块
提供:
1. MemoryCache - 内存缓存类(支持TTL过期、线程安全)
2. RedisCache - Redis缓存类(与MemoryCache接口完全一致)
3. CacheManager - 缓存管理器(支持动态切换内存/Redis后端)
使用示例:
from utils.cache import cache
# 设置缓存(带TTL,单位:秒)
cache.set('sms:login:13800138000', {'code': '123456'}, ttl=300)
# 获取缓存
data = cache.get('sms:login:13800138000')
# 删除缓存
cache.delete('sms:login:13800138000')
# 检查是否存在
exists = cache.exists('sms:login:13800138000')
Redis切换(在app.py中初始化时调用):
from utils.cache import cache
import redis
# 根据配置选择缓存后端
if app.config.get('CACHE_BACKEND') == 'redis':
try:
import redis as redis_module
redis_client = redis_module.Redis(
host=app.config['REDIS_HOST'],
port=app.config['REDIS_PORT'],
password=app.config['REDIS_PASSWORD'] or None,
db=app.config['REDIS_DB'],
decode_responses=True,
socket_connect_timeout=3,
socket_timeout=3,
retry_on_timeout=True
)
redis_client.ping()
cache.init_redis(redis_client)
app.logger.info('缓存已切换到 Redis 后端')
except ImportError:
app.logger.warning('CACHE_BACKEND=redis 但 redis 模块未安装,使用内存缓存')
except Exception as e:
app.logger.warning(f'Redis 连接失败,使用内存缓存: {e}')
else:
app.logger.info('使用内存缓存(如需切换请设置 CACHE_BACKEND=redis)')
"""
import json
import threading
import time
from typing import Any, Optional
class MemoryCache:
"""
内存缓存实现(支持TTL过期机制)
特性:
- 线程安全:使用 threading.Lock 保证并发访问安全
- 自动过期:支持 TTL(Time To Live)自动清理过期数据
- 命名空间隔离:通过前缀区分不同业务场景的缓存
"""
def __init__(self):
self._store = {}
self._lock = threading.Lock()
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
设置缓存值
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒),None 表示永不过期
Returns:
bool: 是否设置成功
"""
with self._lock:
expire_at = time.time() + ttl if ttl else None
self._store[key] = {
'value': value,
'expire_at': expire_at
}
return True
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值
Args:
key: 缓存键
Returns:
缓存值,不存在或已过期返回 None
"""
with self._lock:
if key not in self._store:
return None
item = self._store[key]
if item['expire_at'] is not None and time.time() > item['expire_at']:
del self._store[key]
return None
return item['value']
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存键
Returns:
bool: 是否删除成功
"""
with self._lock:
if key in self._store:
del self._store[key]
return True
return False
def exists(self, key: str) -> bool:
"""
检查键是否存在且未过期
Args:
key: 缓存键
Returns:
bool: 是否存在
"""
with self._lock:
if key not in self._store:
return False
item = self._store[key]
if item['expire_at'] is not None and time.time() > item['expire_at']:
del self._store[key]
return False
return True
def keys(self, prefix: str = '') -> list:
"""
获取所有匹配前缀的键列表(自动过滤过期项)
Args:
prefix: 键前缀筛选条件
Returns:
list: 匹配的键列表
"""
with self._lock:
now = time.time()
valid_keys = []
expired_keys = []
for key, item in self._store.items():
if item['expire_at'] is not None and now > item['expire_at']:
expired_keys.append(key)
elif prefix == '' or key.startswith(prefix):
valid_keys.append(key)
for k in expired_keys:
del self._store[k]
return valid_keys
def clear(self, prefix: str = '') -> int:
"""
清除缓存
Args:
prefix: 如果指定,只清除匹配前缀的缓存;否则清除全部
Returns:
int: 清除的数量
"""
with self._lock:
if prefix == '':
count = len(self._store)
self._store.clear()
return count
keys_to_delete = [k for k in self._store if k.startswith(prefix)]
for k in keys_to_delete:
del self._store[k]
return len(keys_to_delete)
def cleanup_expired(self) -> int:
"""
手动清理所有过期的缓存项
Returns:
int: 清理的数量
"""
with self._lock:
now = time.time()
expired_keys = [
k for k, v in self._store.items()
if v['expire_at'] is not None and now > v['expire_at']
]
for k in expired_keys:
del self._store[k]
return len(expired_keys)
def size(self) -> int:
"""获取当前缓存大小(包含未清理的过期项)"""
with self._lock:
return len(self._store)
def get_ttl(self, key: str) -> Optional[float]:
"""
获取剩余过期时间
Args:
key: 缓存键
Returns:
float: 剩余秒数,不存在返回 None
"""
with self._lock:
if key not in self._store:
return None
item = self._store[key]
if item['expire_at'] is None:
return float('inf')
remaining = item['expire_at'] - time.time()
return max(0, remaining) if remaining > 0 else None
def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
"""
原子递增计数器
适用于限流、统计等场景
Args:
key: 缓存键
amount: 递增步长
ttl: 首次设置时的过期时间
Returns:
int: 递增后的值
"""
with self._lock:
current = self.get(key)
if current is None:
new_value = amount
self.set(key, new_value, ttl)
else:
new_value = current + amount
self._store[key]['value'] = new_value
return new_value
def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
"""
检查频率限制
Args:
key: 限制标识键
interval: 时间窗口(秒)
max_count: 允许的最大次数
Returns:
tuple: (allowed: bool, remaining_time: int)
allowed=True 表示允许执行
remaining_time=0 表示无等待时间
"""
last_time = self.get(f"rate:{key}")
current_time = time.time()
if last_time is None:
self.set(f"rate:{key}", current_time, interval)
return (True, 0)
elapsed = current_time - last_time
if elapsed < interval:
remaining = int(interval - elapsed)
return (False, remaining)
self.set(f"rate:{key}", current_time, interval)
return (True, 0)
class RedisCache:
"""
Redis 缓存实现(支持 TTL 过期机制)
与 MemoryCache 接口完全一致,业务代码无需任何修改。
使用 Redis 的 EXPIRE / TTL 原生过期机制,支持多进程/多服务器共享缓存。
"""
def __init__(self, redis_client):
"""
Args:
redis_client: redis.Redis 连接实例
"""
self._redis = redis_client
self._prefix = 'cache:'
def _make_key(self, key: str) -> str:
"""拼接完整键名"""
return f"{self._prefix}{key}"
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
设置缓存值(自动序列化为 JSON)
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒),None 表示永不过期
Returns:
bool: 是否设置成功
"""
try:
serialized = json.dumps(value, ensure_ascii=False)
self._redis.set(self._make_key(key), serialized, ex=ttl)
return True
except Exception as e:
return False
def get(self, key: str) -> Optional[Any]:
"""
获取缓存值(自动反序列化)
Args:
key: 缓存键
Returns:
缓存值,不存在或已过期返回 None
"""
try:
data = self._redis.get(self._make_key(key))
if data is None:
return None
return json.loads(data)
except Exception:
return None
def delete(self, key: str) -> bool:
"""
删除缓存
Args:
key: 缓存键
Returns:
bool: 是否删除成功
"""
try:
return bool(self._redis.delete(self._make_key(key)))
except Exception:
return False
def exists(self, key: str) -> bool:
"""
检查键是否存在
Args:
key: 缓存键
Returns:
bool: 是否存在
"""
try:
return bool(self._redis.exists(self._make_key(key)))
except Exception:
return False
def keys(self, prefix: str = '') -> list:
"""
获取匹配前缀的所有键
Args:
prefix: 键前缀筛选条件
Returns:
list: 匹配的键列表
"""
try:
pattern = f"{self._make_key(prefix)}*"
keys = self._redis.keys(pattern)
result = []
for k in keys:
key_str = k.decode() if isinstance(k, bytes) else k
if key_str.startswith(self._prefix):
result.append(key_str[len(self._prefix):])
return result
except Exception:
return []
def clear(self, prefix: str = '') -> int:
"""
清除缓存
Args:
prefix: 如果指定,只清除匹配前缀的缓存;否则清除全部
Returns:
int: 清除的数量
"""
try:
keys = self.keys(prefix)
if keys:
full_keys = [self._make_key(k) for k in keys]
return self._redis.delete(*full_keys)
return 0
except Exception:
return 0
def get_ttl(self, key: str) -> Optional[float]:
"""
获取剩余过期时间
Args:
key: 缓存键
Returns:
float: 剩余秒数,不存在返回 None
"""
try:
ttl = self._redis.ttl(self._make_key(key))
return ttl if ttl >= 0 else None
except Exception:
return None
def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
"""
原子递增计数器(使用 Redis INCRBY)
Args:
key: 缓存键
amount: 递增步长
ttl: 首次设置时的过期时间
Returns:
int: 递增后的值
"""
try:
full_key = self._make_key(key)
value = self._redis.incr(full_key, amount)
if ttl:
self._redis.expire(full_key, ttl)
return value
except Exception:
return 0
def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
"""
检查频率限制(使用 Redis 原子操作)
Args:
key: 限制标识键
interval: 时间窗口(秒)
max_count: 允许的最大次数
Returns:
tuple: (allowed: bool, remaining_time: int)
allowed=True 表示允许执行
remaining_time=0 表示无等待时间
"""
try:
rate_key = self._make_key(f"rate:{key}")
last_time = self._redis.get(rate_key)
if last_time is None:
self._redis.set(rate_key, time.time(), ex=interval)
return (True, 0)
elapsed = time.time() - float(last_time)
if elapsed < interval:
remaining = int(interval - elapsed)
return (False, remaining)
self._redis.set(rate_key, time.time(), ex=interval)
return (True, 0)
except Exception:
return (True, 0)
class CacheManager:
"""
缓存管理器(单例模式)
提供统一的缓存操作接口,支持按业务场景分类管理缓存。
默认使用内存缓存,可通过 init_redis() 切换到 Redis 后端。
命名空间规范:
- sms:{scene}:{phone} - 短信验证码
- email:{type}:{email} - 邮箱验证码
- rate_limit:{key} - 频率限制
- session:{session_id} - 会话数据
"""
_instance = None
_initialized = False
def __new__(cls):
"""单例模式实现"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if CacheManager._initialized:
return
CacheManager._initialized = True
self._cache = MemoryCache()
self._backend = 'memory'
def init_redis(self, redis_client):
"""
切换到 Redis 后端
在 app.py 中初始化时调用:
from utils.cache import cache
import redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)
cache.init_redis(redis_client)
Args:
redis_client: redis.Redis 连接实例
"""
self._cache = RedisCache(redis_client)
self._backend = 'redis'
@property
def backend(self) -> str:
"""获取当前缓存后端类型:'memory' 或 'redis'"""
return self._backend
@classmethod
def reset_instance(cls):
"""重置单例(用于测试)"""
cls._instance = None
cls._initialized = False
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""设置缓存"""
return self._cache.set(key, value, ttl)
def get(self, key: str) -> Optional[Any]:
"""获取缓存"""
return self._cache.get(key)
def delete(self, key: str) -> bool:
"""删除缓存"""
return self._cache.delete(key)
def exists(self, key: str) -> bool:
"""检查键是否存在"""
return self._cache.exists(key)
def keys(self, prefix: str = '') -> list:
"""获取匹配前缀的所有键"""
return self._cache.keys(prefix)
def clear(self, prefix: str = '') -> int:
"""清除缓存"""
return self._cache.clear(prefix)
def set_with_timestamp(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
设置缓存并自动记录创建时间戳
适用于需要记录发送/创建时间的场景(如验证码频率限制)
Args:
key: 缓存键
value: 缓存值(字典类型会自动添加时间戳)
ttl: 过期时间(秒)
Returns:
bool: 是否设置成功
"""
if isinstance(value, dict):
value['_created_at'] = time.time()
return self.set(key, value, ttl)
def get_ttl(self, key: str) -> Optional[float]:
"""
获取剩余过期时间
Args:
key: 缓存键
Returns:
float: 剩余秒数,不存在返回 None
"""
return self._cache.get_ttl(key)
def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
"""
原子递增计数器
适用于限流、统计等场景
Args:
key: 缓存键
amount: 递增步长
ttl: 首次设置时的过期时间
Returns:
int: 递增后的值
"""
return self._cache.increment(key, amount, ttl)
def check_rate_limit(self, key: str, interval: int, max_count: int = 1) -> tuple:
"""
检查频率限制
Args:
key: 限制标识键
interval: 时间窗口(秒)
max_count: 允许的最大次数
Returns:
tuple: (allowed: bool, remaining_time: int)
allowed=True 表示允许执行
remaining_time=0 表示无等待时间
"""
return self._cache.check_rate_limit(key, interval, max_count)
# 全局缓存实例(单例)
cache = CacheManager()