python缓存装饰器实现方案

写python的时候突然想着能不能用注解于是就写了个这个


文章目录

原始版

py 复制代码
import os
import pickle
import hashlib
import inspect
import functools


def _generate_cache_filename(func, *args, **kwargs):
    """生成缓存文件名的内部函数"""
    # 获取调用来源文件的绝对路径
    caller_frame = inspect.stack()[2]  # 注意调整为2,跳过当前函数和调用者
    caller_file = os.path.abspath(caller_frame.filename)

    # 生成调用文件路径的短哈希
    file_hash = hashlib.md5(caller_file.encode()).hexdigest()[:8]

    # 生成参数签名
    args_repr = "_".join([repr(arg) for arg in args])
    kwargs_repr = "_".join([f"{k}={repr(v)}" for k, v in kwargs.items()])

    # 处理无参数情况
    param_repr = f"{args_repr}_{kwargs_repr}" if (args or kwargs) else "no_params"

    # 组合最终缓存文件名
    return f"{func.__name__}_{param_repr}_{file_hash}.pkl"


def cache(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # 使用共享函数生成缓存文件名
        cache_file = _generate_cache_filename(func, *args, **kwargs)
        # 缓存逻辑
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                return pickle.load(f)

        result = func(*args, **kwargs)
        with open(cache_file, 'wb') as f:
            pickle.dump(result, f)
        print(f'缓存已保存:{cache_file}')
        return result

    return wrapper


def clear_cache(func, *args, **kwargs):
    """手动清除缓存文件"""
    # 使用共享函数生成缓存文件名
    cache_file = _generate_cache_filename(func, *args, **kwargs)

    # 删除缓存文件
    if os.path.exists(cache_file):
        os.remove(cache_file)
        print(f"缓存已删除: {cache_file}")
    else:
        print(f"缓存文件不存在: {cache_file}")


# 测试用例
@cache
def get_data(a, b):
    print("计算数据")
    return a + b


if __name__ == "__main__":
    # 第一次调用(创建缓存)
    print(get_data(1, 2))  # 输出: 计算数据 和 3

    # 第二次调用(读取缓存)
    print(get_data(1, 2))  # 无"计算数据"输出

    # 清除缓存
    clear_cache(get_data, 1, 2)  # 成功删除

    # 再次调用(重新计算)
    print(get_data(1, 2))  # 再次输出"计算数据"

1._generate_cache_filename用于生成缓存文件名字,inspect.stack()[2]获取调用栈中的当前使用的文件名字,提取调用文件的绝对路径并转换为8位MD5哈希值。

*args和**kwargs分别转换为字符串表示,用于区分不同参数的同名函数,当函数无参数时,使用"no_params"。

【这里需要todo一下:传入的参数判断是否能做为合法的文件名字】

最终生成"函数名_参数签名_调用文件哈希.pkl"。

【todo:最终的文件名称不能超过系统保存的最大长度】

需要确保_generate_cache_filename函数能生成唯一且合法的文件名

2.def cache(func)

简单的缓存装饰器,将函数的计算结果持久化到文件中

使用装饰器模式,外层函数接受被装饰函数作为参数

functools.wraps保留原函数的元信息

内层wrapper函数处理实际调用逻辑

通过_generate_cache_filename函数生成唯一的缓存文件名

检查缓存文件是否存在,存在则直接读取并返回缓存结果

否则调用原始函数获取计算结果,使用pickle模块序列化结果到文件,打印缓存保存信息,返回计算结果

注意:

被缓存函数的返回值必须可被pickle序列化

在多进程环境中使用时需注意文件锁问题

缓存文件需要定期清理以避免存储空间占用

需要todo改进:

添加缓存过期机制

支持自定义序列化方法 todo

添加缓存命中率统计

支持分布式缓存存储 todo

3.def clear_cache(func, *args, **kwargs)

用于手动清除特定函数的缓存文件。

检查缓存文件是否存在,若存在则删除并打印确认信息;若不存在则提示文件不存在的状态。

文件删除操作不可逆,需谨慎调用。

改进点

1、合法文件名处理:

使用正则表达式移除非法字符:re.sub(r'[<>:"/\|?*\x00-\x1F]', '_', name)

处理特殊字符和不可打印字符

2、文件名长度截断:

限制文件名最大长度(255字符)

对长文件名进行智能截断(保留首尾部分)

3、缓存过期机制:

添加expire_seconds参数控制缓存有效期

基于文件修改时间检查过期状态

默认过期时间为24小时

4、日志系统:

使用Python标准logging模块

不同级别的日志(DEBUG、INFO、WARNING、ERROR)

格式化的日志输出

5、异常处理:

捕获并记录文件操作中的所有异常

提供有意义的错误信息

缓存失败时不影响主程序运行

6、自定义缓存目录:

可配置的缓存目录参数

自动创建不存在的目录

默认目录为./.cache

7、缓存统计:

跟踪命中、未命中和过期次数

计算命中率

线程安全的统计计数器

按函数名查看统计信息

7、多线程/进程安全:

使用filelock库实现跨进程文件锁

为每个缓存文件创建对应的锁文件

设置锁超时时间(10秒)

8、增强的缓存清除:

清除特定参数的缓存

清除函数的所有缓存

批量删除操作

9、附加功能:

添加了清理缓存的方法(clear_cache和clear_all_cache)

统计信息查看函数(get_cache_stats和print_cache_stats)

智能缓存路径管理

py 复制代码
import os
import pickle
import hashlib
import inspect
import functools
import time
import re
import logging
import threading
from collections import defaultdict
from filelock import FileLock

# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('cache_decorator')

# 缓存统计
cache_stats = defaultdict(lambda: {'hits': 0, 'misses': 0, 'expired': 0, 'deleted': 0})
stats_lock = threading.Lock()

# 默认配置
DEFAULT_CACHE_DIR = os.path.join(os.getcwd(), '.cache')
DEFAULT_EXPIRE_SECONDS = 24 * 60 * 60  # 默认过期时间: 24小时
MAX_FILENAME_LENGTH = 200  # 最大文件名长度


def _sanitize_filename(name):
    """移除文件名中的非法字符并截断长度"""
    # 替换非法字符
    sanitized = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', name)

    # 截断文件名
    if len(sanitized) > MAX_FILENAME_LENGTH:
        prefix = sanitized[:MAX_FILENAME_LENGTH // 2]
        suffix = sanitized[-MAX_FILENAME_LENGTH // 2:]
        sanitized = prefix + '...' + suffix
        # 确保最终长度不超过限制
        sanitized = sanitized[:MAX_FILENAME_LENGTH]

    return sanitized


def _generate_cache_filename(func, *args, **kwargs):
    """生成缓存文件名的内部函数"""
    # 获取调用来源文件的绝对路径
    caller_frame = inspect.stack()[2]  # 调整堆栈深度
    caller_file = os.path.abspath(caller_frame.filename)

    # 生成调用文件路径的短哈希
    file_hash = hashlib.md5(caller_file.encode()).hexdigest()[:8]

    # 生成参数签名
    args_repr = "_".join([repr(arg) for arg in args])
    kwargs_repr = "_".join([f"{k}={repr(v)}" for k, v in sorted(kwargs.items())])

    # 处理无参数情况
    param_repr = f"{args_repr}_{kwargs_repr}" if (args or kwargs) else "no_params"

    # 组合并清理文件名
    raw_filename = f"{func.__name__}_{param_repr}_{file_hash}"
    return _sanitize_filename(raw_filename) + ".pkl"


def _get_cache_file_path(cache_dir, cache_file):
    """获取缓存文件完整路径,确保目录存在"""
    # 创建缓存目录(如果不存在)
    os.makedirs(cache_dir, exist_ok=True)
    return os.path.join(cache_dir, cache_file)


def cache(expire_seconds=DEFAULT_EXPIRE_SECONDS, cache_dir=DEFAULT_CACHE_DIR):
    """带参数的缓存装饰器

    Args:
        expire_seconds (int): 缓存过期时间(秒)
        cache_dir (str): 缓存文件存储目录
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # 生成缓存文件名
            cache_file = _generate_cache_filename(func, *args, **kwargs)
            cache_path = _get_cache_file_path(cache_dir, cache_file)
            lock_path = cache_path + ".lock"

            # 使用文件锁确保线程/进程安全
            with FileLock(lock_path, timeout=10):
                # 检查缓存是否存在且未过期
                if os.path.exists(cache_path):
                    file_age = time.time() - os.path.getmtime(cache_path)

                    if expire_seconds is None or file_age < expire_seconds:
                        # 缓存命中
                        try:
                            with open(cache_path, 'rb') as f:
                                result = pickle.load(f)
                            with stats_lock:
                                cache_stats[func.__name__]['hits'] += 1
                            logger.debug(f'缓存命中: {cache_path}')
                            return result
                        except Exception as e:
                            logger.warning(f'加载缓存失败: {e}')

                    # 缓存过期
                    with stats_lock:
                        cache_stats[func.__name__]['expired'] += 1
                    logger.debug(f'缓存已过期: {cache_path}')

                # 缓存未命中或过期,重新计算
                with stats_lock:
                    cache_stats[func.__name__]['misses'] += 1
                result = func(*args, **kwargs)

                # 保存结果到缓存
                try:
                    with open(cache_path, 'wb') as f:
                        pickle.dump(result, f)
                    logger.debug(f'缓存已保存: {cache_path}')
                except Exception as e:
                    logger.error(f'保存缓存失败: {e}')

                return result

        # 为包装的函数添加清除缓存的方法
        def clear_cache(*args, **kwargs):
            """清除特定参数的缓存"""
            cache_file = _generate_cache_filename(func, *args, **kwargs)
            cache_path = _get_cache_file_path(cache_dir, cache_file)

            if os.path.exists(cache_path):
                try:
                    os.remove(cache_path)
                    logger.info(f'缓存已删除: {cache_path}')
                    with stats_lock:
                        cache_stats[func.__name__]['deleted'] += 1
                    return True
                except Exception as e:
                    logger.error(f'删除缓存失败: {e}')
                    return False
            else:
                logger.warning(f'缓存文件不存在: {cache_path}')
                return False

        def clear_all_cache():
            """清除该函数的所有缓存"""
            pattern = re.compile(f"^{func.__name__}_.*\\.pkl$")
            cleared = 0
            total = 0

            for filename in os.listdir(cache_dir):
                if pattern.match(filename):
                    total += 1
                    file_path = os.path.join(cache_dir, filename)
                    try:
                        os.remove(file_path)
                        cleared += 1
                        with stats_lock:
                            cache_stats[func.__name__]['deleted'] += 1
                    except Exception as e:
                        logger.error(f'删除缓存失败 {filename}: {e}')

            logger.info(f'已清除 {cleared}/{total} 个缓存文件')
            return cleared

        def clear_expired_cache(expire_seconds=expire_seconds):
            """清除该函数的所有过期缓存"""
            pattern = re.compile(f"^{func.__name__}_.*\\.pkl$")
            current_time = time.time()
            removed = 0
            total = 0

            for filename in os.listdir(cache_dir):
                if pattern.match(filename):
                    total += 1
                    file_path = os.path.join(cache_dir, filename)
                    try:
                        # 检查文件是否过期
                        mtime = os.path.getmtime(file_path)
                        if current_time - mtime > expire_seconds:
                            os.remove(file_path)
                            removed += 1
                            with stats_lock:
                                cache_stats[func.__name__]['deleted'] += 1
                            logger.debug(f'已删除过期缓存: {filename}')
                    except Exception as e:
                        logger.error(f'处理缓存文件 {filename} 失败: {e}')

            logger.info(f'已删除 {removed}/{total} 个过期缓存文件')
            return removed

        wrapper.clear_cache = clear_cache
        wrapper.clear_all_cache = clear_all_cache
        wrapper.clear_expired_cache = clear_expired_cache
        return wrapper

    return decorator


def get_cache_stats(func_name=None):
    """获取缓存统计信息

    Args:
        func_name (str): 函数名,None 表示所有函数

    Returns:
        dict: 缓存统计信息
    """
    with stats_lock:
        if func_name:
            return cache_stats.get(func_name, {'hits': 0, 'misses': 0, 'expired': 0, 'deleted': 0})

        # 计算总命中率
        total_stats = {'hits': 0, 'misses': 0, 'expired': 0, 'deleted': 0}
        for stats in cache_stats.values():
            for k in total_stats:
                total_stats[k] += stats[k]

        # 添加命中率百分比
        total = total_stats['hits'] + total_stats['misses'] + total_stats['expired']
        if total > 0:
            total_stats['hit_rate'] = total_stats['hits'] / total * 100
        else:
            total_stats['hit_rate'] = 0.0

        return total_stats


def print_cache_stats(func_name=None):
    """打印缓存统计信息"""
    stats = get_cache_stats(func_name)

    if func_name:
        print(f"\n缓存统计 - {func_name}:")
    else:
        print("\n全局缓存统计:")

    print(f"命中次数: {stats['hits']}")
    print(f"未命中次数: {stats['misses']}")
    print(f"过期次数: {stats['expired']}")
    print(f"删除次数: {stats['deleted']}")

    if 'hit_rate' in stats:
        print(f"命中率: {stats['hit_rate']:.2f}%")
    else:
        total = stats['hits'] + stats['misses'] + stats['expired']
        if total > 0:
            hit_rate = stats['hits'] / total * 100
            print(f"命中率: {hit_rate:.2f}%")


def clear_all_expired_cache(cache_dir=DEFAULT_CACHE_DIR, expire_seconds=DEFAULT_EXPIRE_SECONDS):
    """清除缓存目录中所有过期的缓存文件

    Args:
        cache_dir (str): 缓存目录
        expire_seconds (int): 过期时间(秒)
    """
    current_time = time.time()
    removed = 0
    total = 0

    if not os.path.exists(cache_dir):
        logger.warning(f"缓存目录不存在: {cache_dir}")
        return 0

    for filename in os.listdir(cache_dir):
        if filename.endswith('.pkl'):
            total += 1
            file_path = os.path.join(cache_dir, filename)
            try:
                # 检查文件是否过期
                mtime = os.path.getmtime(file_path)
                if current_time - mtime > expire_seconds:
                    os.remove(file_path)
                    removed += 1
                    with stats_lock:
                        # 尝试找出对应的函数名
                        func_name = filename.split('_')[0]
                        if func_name in cache_stats:
                            cache_stats[func_name]['deleted'] += 1
                    logger.debug(f'已删除过期缓存: {filename}')
            except Exception as e:
                logger.error(f'处理缓存文件 {filename} 失败: {e}')

    logger.info(f'已删除 {removed}/{total} 个过期缓存文件')
    return removed


# 测试用例
@cache(expire_seconds=2, cache_dir="./test_cache")
def get_data(a, b):
    print("计算数据")
    return a + b


if __name__ == "__main__":
    # 确保测试缓存目录存在
    os.makedirs("./test_cache", exist_ok=True)

    # 第一次调用(创建缓存)
    print(get_data(1, 2))  # 输出: 计算数据 和 3

    # 第二次调用(读取缓存)
    print(get_data(1, 2))  # 无"计算数据"输出

    # 等待缓存过期
    time.sleep(3)

    # 第三次调用(缓存过期后重新计算)
    print(get_data(1, 2))  # 再次输出"计算数据"

    # 清除特定参数缓存
    get_data.clear_cache(1, 2)

    # 第四次调用(清除后重新计算)
    print(get_data(1, 2))  # 输出"计算数据"

    # 创建另一个缓存
    print(get_data(3, 4))

    # 等待缓存过期
    time.sleep(3)

    # 清除过期缓存(仅限get_data函数)
    get_data.clear_expired_cache()

    # 清除整个缓存目录中的过期缓存
    clear_all_expired_cache("./test_cache", expire_seconds=1)

    # 清除所有缓存
    get_data.clear_all_cache()

    # 打印缓存统计
    print_cache_stats()
相关推荐
先睡1 分钟前
Redis的缓存击穿和缓存雪崩
redis·spring·缓存
Python×CATIA工业智造1 小时前
Frida RPC高级应用:动态模拟执行Android so文件实战指南
开发语言·python·pycharm
onceco2 小时前
领域LLM九讲——第5讲 为什么选择OpenManus而不是QwenAgent(附LLM免费api邀请码)
人工智能·python·深度学习·语言模型·自然语言处理·自动化
我叫小白菜2 小时前
【Java_EE】单例模式、阻塞队列、线程池、定时器
java·开发语言
狐凄2 小时前
Python实例题:基于 Python 的简单聊天机器人
开发语言·python
weixin_446122463 小时前
JAVA内存区域划分
java·开发语言·redis
悦悦子a啊3 小时前
Python之--基本知识
开发语言·前端·python
QuantumStack4 小时前
【C++ 真题】P1104 生日
开发语言·c++·算法
whoarethenext4 小时前
使用 C++/OpenCV 和 MFCC 构建双重认证智能门禁系统
开发语言·c++·opencv·mfcc
笑稀了的野生俊5 小时前
在服务器中下载 HuggingFace 模型:终极指南
linux·服务器·python·bash·gpu算力