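Contents
- Flask Extension Development: Writing Your Own Flask Extension from Scratch
  - [1. Introduction: Why Write a Custom Flask Extension?](#1-introduction-why-write-a-custom-flask-extension)
    - [1.1 The Value of Flask Extensions](#11-the-value-of-flask-extensions)
    - [1.2 Common Types of Flask Extensions](#12-common-types-of-flask-extensions)
  - [2. Flask Extension Development Basics](#2-flask-extension-development-basics)
    - [2.1 Structural Requirements for a Flask Extension](#21-structural-requirements-for-a-flask-extension)
    - [2.2 Naming Conventions for Flask Extensions](#22-naming-conventions-for-flask-extensions)
    - [2.3 The Extension Lifecycle](#23-the-extension-lifecycle)
  - [3. Developing Your First Flask Extension: Flask-Heartbeat](#3-developing-your-first-flask-extension-flask-heartbeat)
    - [3.1 Basic Structure of the Extension](#31-basic-structure-of-the-extension)
    - [3.2 Core Implementation of the Extension](#32-core-implementation-of-the-extension)
    - [3.3 Usage Example](#33-usage-example)
  - [4. An Advanced Extension: Flask-CacheManager](#4-an-advanced-extension-flask-cachemanager)
    - [4.1 Structure of the Cache Manager Extension](#41-structure-of-the-cache-manager-extension)
    - [4.2 Cache Manager Usage Example](#42-cache-manager-usage-example)
  - [5. Extension Testing and Quality Assurance](#5-extension-testing-and-quality-assurance)
    - [5.1 Writing Extension Tests](#51-writing-extension-tests)
    - [5.2 Publishing and Distributing Extensions](#52-publishing-and-distributing-extensions)
      - [5.2.1 Packaging the Extension](#521-packaging-the-extension)
      - [5.2.2 Writing Documentation](#522-writing-documentation)
  - [6. Advanced Extension Development Techniques](#6-advanced-extension-development-techniques)
    - [6.1 Extension Configuration Management](#61-extension-configuration-management)
    - [6.2 Dependency Injection in Extensions](#62-dependency-injection-in-extensions)
    - [6.3 The Extension Signal System](#63-the-extension-signal-system)
  - [7. Extension Development Best Practices](#7-extension-development-best-practices)
    - [7.1 Ensuring Code Quality](#71-ensuring-code-quality)
    - [7.2 Performance Optimization Tips](#72-performance-optimization-tips)
    - [7.3 Security Considerations](#73-security-considerations)
  - [8. Summary and Outlook](#8-summary-and-outlook)
    - [8.1 Key Points of Extension Development](#81-key-points-of-extension-development)
    - [8.2 The Extension Ecosystem](#82-the-extension-ecosystem)
    - [8.3 Future Trends in Extension Development](#83-future-trends-in-extension-development)
    - [8.4 Further Learning Resources](#84-further-learning-resources)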
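Flask Extension Development: Writing Your Own Flask Extension from Scratch
1. Introduction: Why Write a Custom Flask Extension?
Flask is a minimalist Python web framework known for its "micro" core design. Because the core stays small, developers add functionality only as they need it, and Flask extensions are the main mechanism for doing so. So why would you write a Flask extension of your own?
1.1 The Value of Flask Extensions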
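Figure (decision flow): business requirement → analyze the requirement → generic and reusable: develop a Flask extension → open-source sharing / team reuse; specific and one-off: implement directly → project-specific code; outcome: improved development efficiency.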
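Writing Flask extensions pays off in several ways:
- Code reuse: package common functionality as an extension instead of reinventing the wheel
- Standardized interfaces: a consistent API lowers the learning curve for the team
- Modular design: separating concerns keeps application code clean
- Open-source contribution: share your solutions and take part in the open-source community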
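1.2 Common Types of Flask Extensions
| Type | Example extension | Description |
|---|---|---|
| Database integration | Flask-SQLAlchemy | ORM integration |
| Form handling | Flask-WTF | Web form handling |
| User authentication | Flask-Login | User session management |
| Caching | Flask-Caching | Data caching |
| Email | Flask-Mail | Sending email |
| Configuration management | Flask-Config | Configuration tooling |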
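2. Flask Extension Development Basics
2.1 Structural Requirements for a Flask Extension
Flask does not enforce a particular layout, but a typical extension package is organized like this: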
```text
flask_myextension/
├── __init__.py      # main module of the extension
├── __about__.py     # metadata
├── core.py          # core functionality
├── decorators.py    # decorators
├── exceptions.py    # custom exceptions
├── mixins.py        # mixin classes
├── utils.py         # utility functions
├── static/          # static files
├── templates/       # template files
└── tests/           # tests
```
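2.2 Naming Conventions for Flask Extensions
- Package name: `flask_extensionname` (lowercase, words separated by underscores)
- Class name: `ExtensionName` (CapWords, also called PascalCase)
- Version number: follows Semantic Versioning (SemVer)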
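In practice, the SemVer rule also means versions must be compared numerically, not as strings. Compared lexically, `"1.10.0"` sorts before `"1.9.0"`. The stdlib-only sketch below shows the difference. `parse_version` is a hypothetical helper, and it deliberately handles only the `MAJOR.MINOR.PATCH` core, ignoring pre-release and build metadata such as `1.0.0-rc.1`:

```python
import re
from typing import Tuple

# Simplified: only the MAJOR.MINOR.PATCH core, with no leading zeros
_CORE = re.compile(r"(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)")


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse 'MAJOR.MINOR.PATCH' into a tuple of ints for numeric comparison."""
    match = _CORE.fullmatch(version)
    if match is None:
        raise ValueError(f"not a MAJOR.MINOR.PATCH version: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


# Lexical comparison gets this wrong; tuple comparison gets it right
print("1.10.0" > "1.9.0")                                # False
print(parse_version("1.10.0") > parse_version("1.9.0"))  # True
```

For real dependency handling, the `Version` class in the `packaging` library implements Python's full versioning scheme (PEP 440), so there is no need to hand-roll this.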
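2.3 The Extension Lifecycle
```python
from flask import Flask

# 1. Import the extension
from flask_myextension import MyExtension

# 2. Create the application
app = Flask(__name__)

# 3. Initialize the extension (pick ONE of the two styles)
# Style 1: pass the app directly to the constructor
ext = MyExtension(app)

# Style 2: deferred initialization (used with the application factory pattern)
# ext = MyExtension()
# ext.init_app(app)

# 4. Use the extension
@app.route('/')
def index():
    return ext.do_something()
```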
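Why offer `init_app` at all? With the application factory pattern, the extension object is created once at import time, before any app exists, and is bound to one or more apps later. The sketch below isolates that two-phase pattern. It uses only the standard library: `make_fake_app` is a hypothetical stand-in for `flask.Flask` that exposes only the `config` and `extensions` attributes the pattern touches, and `MyExtension` and `MYEXT_GREETING` are illustrative names. In line with the recommendation in Flask's extension-development documentation, per-app state goes into `app.extensions` rather than onto the extension instance:

```python
from types import SimpleNamespace


class MyExtension:
    """Hypothetical extension showing the constructor / init_app pattern."""

    def __init__(self, app=None):
        # Direct style: MyExtension(app) binds immediately
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        # Register a default without overwriting a value the app already set
        app.config.setdefault("MYEXT_GREETING", "hello")
        # Keep per-app state on the app, so one extension object can serve many apps
        app.extensions["myext"] = {"greeting": app.config["MYEXT_GREETING"]}


def make_fake_app(**config):
    """Stand-in for flask.Flask: only the attributes init_app touches."""
    return SimpleNamespace(config=dict(config), extensions={})


ext = MyExtension()  # factory style: created before any app exists
app_a = make_fake_app()
app_b = make_fake_app(MYEXT_GREETING="hi")
ext.init_app(app_a)
ext.init_app(app_b)
print(app_a.extensions["myext"]["greeting"])  # hello
print(app_b.extensions["myext"]["greeting"])  # hi
```

Because `init_app` keeps no reference to a particular app, the same `ext` object serves both apps. The `Heartbeat` class developed below stores `self.app` and its settings on the instance instead. That is simpler, but it assumes one app per extension object.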
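3. Developing Your First Flask Extension: Flask-Heartbeat
Let's start with a simple health-check extension that adds health-check endpoints to a Flask application.
3.1 Basic Structure of the Extension
Start with the project's packaging files: `setup.py`, which describes how the extension is installed, and `README.md`, its documentation.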
```python
# setup.py: installation configuration for the extension
from pathlib import Path

from setuptools import setup, find_packages

setup(
    name="flask-heartbeat",
    version="1.0.0",
    description="A Flask extension for health checks",
    # Read the README explicitly as UTF-8 (and close the file) instead of open().read()
    long_description=Path("README.md").read_text(encoding="utf-8"),
    long_description_content_type="text/markdown",
    author="Your Name",
    author_email="your.email@example.com",
    url="https://github.com/yourusername/flask-heartbeat",
    packages=find_packages(),
    install_requires=[
        "flask>=2.0.0",
    ],
    python_requires=">=3.7",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Environment :: Web Environment",
        "Framework :: Flask",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    keywords="flask health check monitoring",
)
```
`README.md`, the extension's documentation:
````markdown
# Flask-Heartbeat

A Flask extension for adding health check endpoints to your Flask applications.

## Features

- Simple health check endpoint
- Customizable health check functions
- Database connection checks
- External service checks
- Performance monitoring

## Installation

```bash
pip install flask-heartbeat
```

## Quick Start

```python
from flask import Flask
from flask_heartbeat import Heartbeat

app = Flask(__name__)
heartbeat = Heartbeat(app)

@app.route('/')
def index():
    return 'Hello World!'
```

Visit `/health` to see the health status.
````
3.2 Core Implementation of the Extension
```python
# flask_heartbeat/__init__.py
"""
Flask-Heartbeat extension main module.
"""
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

from flask import Blueprint, jsonify, request

__version__ = "1.0.0"


class HealthStatus(Enum):
    """Health status enumeration."""
    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"


@dataclass
class HealthCheckResult:
    """Result of a single health check."""
    name: str
    status: HealthStatus
    message: str = ""
    # default_factory gives each result its own timestamp and its own data dict
    timestamp: datetime = field(default_factory=datetime.utcnow)
    duration: float = 0.0
    data: Dict[str, Any] = field(default_factory=dict)


class HeartbeatError(Exception):
    """Base exception for the Heartbeat extension."""


class Heartbeat:
    """
    Flask extension for health checks.

    Usage:
        app = Flask(__name__)
        heartbeat = Heartbeat(app)
        # or
        heartbeat = Heartbeat()
        heartbeat.init_app(app)
    """

    def __init__(
        self,
        app=None,
        endpoint: str = "/health",
        health_check_funcs: Optional[List[Callable]] = None,
        auth_required: bool = False,
        auth_func: Optional[Callable] = None,
        cache_timeout: int = 30
    ):
        """
        Initialize the Heartbeat extension.

        Args:
            app: Flask application instance
            endpoint: Health check endpoint path
            health_check_funcs: List of health check functions
            auth_required: Whether authentication is required
            auth_func: Authentication function; receives the request, returns a bool
            cache_timeout: Cache timeout in seconds
        """
        self.app = app
        self.endpoint = endpoint
        self.health_check_funcs = health_check_funcs or []
        self.auth_required = auth_required
        self.auth_func = auth_func
        self.cache_timeout = cache_timeout
        self._last_check = None
        self._last_result = None
        # Store additional configuration
        self._config = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app: 'Flask') -> None:
        """
        Initialize the extension with a Flask application.

        Args:
            app: Flask application instance
        """
        self.app = app
        # Set default configuration (values already present in app.config win)
        app.config.setdefault('HEARTBEAT_ENDPOINT', self.endpoint)
        app.config.setdefault('HEARTBEAT_AUTH_REQUIRED', self.auth_required)
        app.config.setdefault('HEARTBEAT_CACHE_TIMEOUT', self.cache_timeout)
        # Read the effective values back from the app config
        self.endpoint = app.config['HEARTBEAT_ENDPOINT']
        self.auth_required = app.config['HEARTBEAT_AUTH_REQUIRED']
        self.cache_timeout = app.config['HEARTBEAT_CACHE_TIMEOUT']
        # Create and register a blueprint for the health check endpoints
        self.blueprint = self._create_blueprint()
        app.register_blueprint(self.blueprint)
        # Register teardown handler
        self._register_teardown_handlers(app)
        # Register template filters
        self._add_template_filters(app)
        # Store a reference to the extension on the app
        if not hasattr(app, 'extensions'):
            app.extensions = {}
        app.extensions['heartbeat'] = self
        # Add CLI commands
        self._add_cli_commands(app)

    def _create_blueprint(self) -> Blueprint:
        """Create the blueprint that serves the health check endpoints."""
        bp = Blueprint('heartbeat', __name__)

        def auth_error():
            """Return a 401 response if authentication fails, otherwise None."""
            if not self.auth_required:
                return None
            # Fail closed: if auth is required but no auth_func was given, deny access
            if self.auth_func is None or not self.auth_func(request):
                return jsonify({
                    'error': 'Authentication required',
                    'status': 'error'
                }), 401
            return None

        @bp.route(self.endpoint, methods=['GET'])
        def health_check():
            """Health check endpoint (results are cached for cache_timeout seconds)."""
            error = auth_error()
            if error is not None:
                return error
            # Use the cached result if it exists and has not expired
            if (self._last_result is not None and self._last_check is not None and
                    time.time() - self._last_check < self.cache_timeout):
                return self._format_response(self._last_result)
            # Run the health checks and cache the results
            results = self.run_health_checks()
            self._last_check = time.time()
            self._last_result = results
            return self._format_response(results)

        @bp.route(f"{self.endpoint}/detailed", methods=['GET'])
        def detailed_health_check():
            """Detailed health check endpoint (never cached)."""
            error = auth_error()
            if error is not None:
                return error
            results = self.run_health_checks(detailed=True)
            return self._format_response(results, detailed=True)

        @bp.route(f"{self.endpoint}/ready", methods=['GET'])
        def readiness_check():
            """Readiness check endpoint (e.g. for a Kubernetes readiness probe)."""
            results = self.run_health_checks()
            # Ready only if every check is HEALTHY; DEGRADED also yields 503 here
            all_healthy = all(r.status == HealthStatus.HEALTHY for r in results)
            status_code = 200 if all_healthy else 503
            return self._format_response(results), status_code

        @bp.route(f"{self.endpoint}/live", methods=['GET'])
        def liveness_check():
            """Liveness check endpoint (e.g. for a Kubernetes liveness probe)."""
            return jsonify({
                'status': 'alive',
                'timestamp': datetime.utcnow().isoformat()
            })

        return bp

    def add_check(self, func: Callable) -> Callable:
        """
        Register a health check function; can be used as a decorator.

        Args:
            func: Health check function that returns a HealthCheckResult

        Returns:
            The original function (so it can be used as a decorator)
        """
        self.health_check_funcs.append(func)
        return func

    def add_database_check(self, get_db_func: Callable) -> None:
        """
        Add a database connection check.

        Args:
            get_db_func: Function that returns a database connection
        """
        def check_database():
            start_time = time.time()
            try:
                db = get_db_func()
                # Run a trivial query, depending on what the connection object supports
                if hasattr(db, 'execute'):
                    db.execute('SELECT 1')
                elif hasattr(db, 'cursor'):
                    # Not every DB-API cursor is a context manager, so close it explicitly
                    cursor = db.cursor()
                    try:
                        cursor.execute('SELECT 1')
                    finally:
                        cursor.close()
                else:
                    # Assume it's a connection that can be pinged
                    db.ping()
                return HealthCheckResult(
                    name="database",
                    status=HealthStatus.HEALTHY,
                    message="Database connection is healthy",
                    duration=time.time() - start_time
                )
            except Exception as e:
                return HealthCheckResult(
                    name="database",
                    status=HealthStatus.UNHEALTHY,
                    message=f"Database connection failed: {e}",
                    duration=time.time() - start_time
                )

        self.add_check(check_database)

    def add_redis_check(self, get_redis_func: Callable) -> None:
        """
        Add a Redis connection check.

        Args:
            get_redis_func: Function that returns a Redis connection
        """
        def check_redis():
            start_time = time.time()
            try:
                client = get_redis_func()
                # Try to ping Redis
                if client.ping():
                    return HealthCheckResult(
                        name="redis",
                        status=HealthStatus.HEALTHY,
                        message="Redis connection is healthy",
                        duration=time.time() - start_time
                    )
                return HealthCheckResult(
                    name="redis",
                    status=HealthStatus.UNHEALTHY,
                    message="Redis ping failed",
                    duration=time.time() - start_time
                )
            except Exception as e:
                return HealthCheckResult(
                    name="redis",
                    status=HealthStatus.UNHEALTHY,
                    message=f"Redis connection failed: {e}",
                    duration=time.time() - start_time
                )

        self.add_check(check_redis)

    def run_health_checks(self, detailed: bool = False) -> List[HealthCheckResult]:
        """
        Run all registered health checks.

        Args:
            detailed: Whether to include detailed information

        Returns:
            List of health check results
        """
        results = []
        for check_func in self.health_check_funcs:
            name = getattr(check_func, '__name__', 'unknown')
            try:
                result = check_func()
                if detailed:
                    # Record where the check came from
                    result.data['function_name'] = name
                    result.data['module'] = getattr(check_func, '__module__', None)
                results.append(result)
            except Exception as e:
                # A check that raises is reported as UNHEALTHY instead of breaking the endpoint
                results.append(HealthCheckResult(
                    name=name,
                    status=HealthStatus.UNHEALTHY,
                    message=f"Check failed with exception: {e}",
                    data={"error": str(e)} if detailed else {}
                ))
        return results

    def _format_response(
        self,
        results: List[HealthCheckResult],
        detailed: bool = False
    ) -> 'Response':
        """
        Format health check results as a JSON response.

        Args:
            results: List of health check results
            detailed: Whether to include detailed information

        Returns:
            A JSON response whose HTTP status reflects the overall status
        """
        # Determine the overall status: the worst individual status wins
        statuses = [r.status for r in results]
        if HealthStatus.UNHEALTHY in statuses:
            overall_status = HealthStatus.UNHEALTHY.value
            http_status = 503
        elif HealthStatus.DEGRADED in statuses:
            overall_status = HealthStatus.DEGRADED.value
            http_status = 200  # still serving, but degraded
        else:
            overall_status = HealthStatus.HEALTHY.value
            http_status = 200

        response = {
            'status': overall_status,
            'timestamp': datetime.utcnow().isoformat(),
            'checks': {
                'total': len(results),
                'healthy': len([r for r in results if r.status == HealthStatus.HEALTHY]),
                'unhealthy': len([r for r in results if r.status == HealthStatus.UNHEALTHY]),
                'degraded': len([r for r in results if r.status == HealthStatus.DEGRADED]),
            },
            'results': []
        }
        # Add individual check results
        for result in results:
            result_dict = {
                'name': result.name,
                'status': result.status.value,
                'message': result.message,
                'duration': round(result.duration, 3),
                'timestamp': result.timestamp.isoformat() if result.timestamp else None
            }
            if detailed and result.data:
                result_dict['data'] = result.data
            response['results'].append(result_dict)

        # Build the JSON response and apply the computed HTTP status
        # (previously http_status was computed but never used)
        resp = jsonify(response)
        resp.status_code = http_status
        return resp

    def _register_teardown_handlers(self, app: 'Flask') -> None:
        """Register teardown handlers for cleanup."""
        @app.teardown_appcontext
        def teardown_heartbeat(exception):
            """Clean up resources (placeholder: nothing to release yet)."""

    def _add_template_filters(self, app: 'Flask') -> None:
        """Register template filters."""
        @app.template_filter('health_status_class')
        def health_status_class(status):
            """Map a health status value to a CSS class name."""
            status_map = {
                HealthStatus.HEALTHY.value: 'success',
                HealthStatus.UNHEALTHY.value: 'danger',
                HealthStatus.DEGRADED.value: 'warning',
            }
            return status_map.get(status, 'secondary')

    def _add_cli_commands(self, app: 'Flask') -> None:
        """Add CLI commands for health checks (run as `flask health-check`)."""
        import click

        @app.cli.command('health-check')
        @click.option('--detailed', is_flag=True, help='Show detailed information')
        def health_check_cli(detailed):
            """Run health checks from the command line."""
            from flask import current_app
            heartbeat = current_app.extensions['heartbeat']
            results = heartbeat.run_health_checks(detailed=detailed)
            # Print results
            click.echo(f"Health Check Results ({len(results)} checks)")
            click.echo("=" * 50)
            for result in results:
                healthy = result.status == HealthStatus.HEALTHY
                status_icon = "✓" if healthy else "✗"
                status_color = "green" if healthy else "red"
                click.echo(
                    f"{status_icon} {result.name}: "
                    f"{click.style(result.status.value, fg=status_color)} "
                    f"({result.duration:.3f}s)"
                )
                if result.message:
                    click.echo(f"    {result.message}")
                if detailed and result.data:
                    click.echo(f"    data: {result.data}")
            # Summary
            healthy_count = len([r for r in results if r.status == HealthStatus.HEALTHY])
            click.echo("\n" + "=" * 50)
            click.echo(f"Summary: {healthy_count}/{len(results)} checks passed")
            if healthy_count < len(results):
                click.echo(click.style("Some checks failed!", fg="red", bold=True))
                # click.Abort makes the command exit with status 1
                raise click.Abort()

    def set_config(self, key: str, value: Any) -> None:
        """
        Set extension configuration.

        Args:
            key: Configuration key
            value: Configuration value
        """
        self._config[key] = value

    def get_config(self, key: str, default: Any = None) -> Any:
        """
        Get extension configuration.

        Args:
            key: Configuration key
            default: Default value if the key is not found

        Returns:
            Configuration value
        """
        return self._config.get(key, default)

    def clear_cache(self) -> None:
        """Clear cached health check results."""
        self._last_check = None
        self._last_result = None
```
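The rule `_format_response` applies (the worst individual status wins, and only UNHEALTHY maps to HTTP 503) can be isolated into a small pure function, which makes it easy to unit-test without a Flask app or test client. This is a stdlib-only sketch: `aggregate`, `_SEVERITY`, and `_HTTP` are hypothetical names, and `HealthStatus` is redeclared so the snippet runs on its own:

```python
from enum import Enum
from typing import Iterable, Tuple


class HealthStatus(Enum):
    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"


# Severity order used for "worst status wins"; higher is worse
_SEVERITY = {HealthStatus.HEALTHY: 0, HealthStatus.DEGRADED: 1, HealthStatus.UNHEALTHY: 2}
# HTTP code per overall status, mirroring _format_response (DEGRADED still returns 200)
_HTTP = {HealthStatus.HEALTHY: 200, HealthStatus.DEGRADED: 200, HealthStatus.UNHEALTHY: 503}


def aggregate(statuses: Iterable[HealthStatus]) -> Tuple[HealthStatus, int]:
    """Return (overall status, HTTP code); no checks at all counts as HEALTHY."""
    overall = max(statuses, key=_SEVERITY.__getitem__, default=HealthStatus.HEALTHY)
    return overall, _HTTP[overall]


print(aggregate([HealthStatus.HEALTHY, HealthStatus.DEGRADED]))
# (<HealthStatus.DEGRADED: 'degraded'>, 200)
print(aggregate([HealthStatus.DEGRADED, HealthStatus.UNHEALTHY]))
# (<HealthStatus.UNHEALTHY: 'unhealthy'>, 503)
```

Note that `/health/ready` applies a stricter rule: it returns 503 unless every check is HEALTHY, so a DEGRADED check fails readiness while `/health` still answers 200.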
3.3 Usage Example
```python
# example_app.py: a sample application demonstrating the extension
import sqlite3
import time

from flask import Flask, jsonify

from flask_heartbeat import Heartbeat, HealthCheckResult, HealthStatus

# Create the Flask application
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'
app.config['HEARTBEAT_ENDPOINT'] = '/health'
app.config['HEARTBEAT_CACHE_TIMEOUT'] = 10  # cache results for 10 seconds

# Initialize the Heartbeat extension
heartbeat = Heartbeat(app)


def get_database_connection():
    """Return a database connection (simulated)."""
    # In a real application, return a connection from your database layer
    return sqlite3.connect(':memory:')


def get_redis_connection():
    """Return a Redis connection (simulated)."""
    # In a real application, return a real client such as redis.Redis(...)
    class MockRedis:
        def ping(self):
            return True
    return MockRedis()


# Register the built-in checks
heartbeat.add_database_check(get_database_connection)
heartbeat.add_redis_check(get_redis_connection)


# Register custom checks with the decorator.
# HealthCheckResult is a module-level class in flask_heartbeat, not an
# attribute of the Heartbeat instance, so it is imported above.
@heartbeat.add_check
def check_disk_space():
    """Check free disk space on the root filesystem."""
    import shutil
    start_time = time.time()
    try:
        total, used, free = shutil.disk_usage("/")
        free_percent = (free / total) * 100
        if free_percent > 20:
            status = HealthStatus.HEALTHY
            message = f"Disk space OK: {free_percent:.1f}% free"
        elif free_percent > 10:
            status = HealthStatus.DEGRADED
            message = f"Disk space low: {free_percent:.1f}% free"
        else:
            status = HealthStatus.UNHEALTHY
            message = f"Disk space critically low: {free_percent:.1f}% free"
        return HealthCheckResult(
            name="disk_space",
            status=status,
            message=message,
            duration=time.time() - start_time,
            data={
                "total_gb": total / (1024**3),
                "used_gb": used / (1024**3),
                "free_gb": free / (1024**3),
                "free_percent": free_percent
            }
        )
    except Exception as e:
        return HealthCheckResult(
            name="disk_space",
            status=HealthStatus.UNHEALTHY,
            message=f"Disk space check failed: {e}",
            duration=time.time() - start_time
        )


@heartbeat.add_check
def check_external_api():
    """Check that an external API is reachable."""
    import requests
    start_time = time.time()
    try:
        # Example target: httpbin's /delay/1 endpoint responds after about one second
        response = requests.get("https://httpbin.org/delay/1", timeout=2)
        if response.status_code == 200:
            status = HealthStatus.HEALTHY
            message = "External API responded normally"
        else:
            status = HealthStatus.DEGRADED
            message = f"External API returned unexpected status code: {response.status_code}"
        return HealthCheckResult(
            name="external_api",
            status=status,
            message=message,
            duration=time.time() - start_time,
            data={
                "status_code": response.status_code,
                "response_time": response.elapsed.total_seconds()
            }
        )
    except requests.exceptions.Timeout:
        return HealthCheckResult(
            name="external_api",
            status=HealthStatus.UNHEALTHY,
            message="External API request timed out",
            duration=time.time() - start_time
        )
    except Exception as e:
        return HealthCheckResult(
            name="external_api",
            status=HealthStatus.UNHEALTHY,
            message=f"External API check failed: {e}",
            duration=time.time() - start_time
        )


@heartbeat.add_check
def check_application_specific():
    """Application-specific check: required configuration values are set."""
    start_time = time.time()
    try:
        required_configs = ['SECRET_KEY', 'DEBUG']
        # Flask pre-populates app.config with defaults (SECRET_KEY defaults to None),
        # so a key-membership test would always pass; test for None instead
        missing_configs = [key for key in required_configs if app.config.get(key) is None]
        if missing_configs:
            status = HealthStatus.UNHEALTHY
            message = f"Missing required configuration: {', '.join(missing_configs)}"
        else:
            status = HealthStatus.HEALTHY
            message = "Application configuration is complete"
        return HealthCheckResult(
            name="app_config",
            status=status,
            message=message,
            duration=time.time() - start_time,
            data={
                "missing_configs": missing_configs,
                "total_configs": len(app.config)
            }
        )
    except Exception as e:
        return HealthCheckResult(
            name="app_config",
            status=HealthStatus.UNHEALTHY,
            message=f"Application check failed: {e}",
            duration=time.time() - start_time
        )


# Other application routes
@app.route('/')
def index():
    return jsonify({
        'message': 'Welcome to the Flask app',
        'version': '1.0.0',
        'health_endpoint': '/health',
        'detailed_health': '/health/detailed',
        'readiness': '/health/ready',
        'liveness': '/health/live'
    })


@app.route('/api/data')
def get_data():
    return jsonify({
        'data': [1, 2, 3, 4, 5],
        'timestamp': time.time()
    })


if __name__ == '__main__':
    app.run(debug=True, port=5000)
```
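Each custom check above repeats the same scaffolding: record a start time, wrap the body in `try`/`except`, and build a `HealthCheckResult` with a duration. A small decorator can factor that out. This is a hypothetical helper, not part of Flask-Heartbeat as written. `timed_check` is an illustrative name, and `HealthStatus`/`HealthCheckResult` are simplified stand-ins, redeclared without the `timestamp` field, so the snippet runs without Flask:

```python
import functools
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict


class HealthStatus(Enum):
    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"


@dataclass
class HealthCheckResult:
    name: str
    status: HealthStatus
    message: str = ""
    duration: float = 0.0
    data: Dict[str, Any] = field(default_factory=dict)


def timed_check(name: str) -> Callable:
    """Time a check and turn any exception into an UNHEALTHY result.

    The wrapped function returns (status, message) or (status, message, data).
    """
    def decorator(func: Callable) -> Callable[[], HealthCheckResult]:
        @functools.wraps(func)
        def wrapper() -> HealthCheckResult:
            start = time.perf_counter()
            try:
                status, message, *rest = func()
                data = rest[0] if rest else {}
            except Exception as exc:  # any failure makes the check unhealthy
                status, message, data = HealthStatus.UNHEALTHY, f"{name} check failed: {exc}", {}
            return HealthCheckResult(name=name, status=status, message=message,
                                     duration=time.perf_counter() - start, data=data)
        return wrapper
    return decorator


@timed_check("always_ok")
def always_ok():
    return HealthStatus.HEALTHY, "fine"


@timed_check("broken")
def broken():
    raise RuntimeError("connection refused")


print(always_ok().status)  # HealthStatus.HEALTHY
print(broken().message)    # broken check failed: connection refused
```

In the real application you would import `HealthStatus` and `HealthCheckResult` from `flask_heartbeat` instead, and stack `@heartbeat.add_check` above `@timed_check(...)` so the registered function is the wrapper. The sketch uses `time.perf_counter()`, which is better suited to measuring short durations than `time.time()`.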
4. An Advanced Extension: Flask-CacheManager
Now let's build a more complex extension: a cache manager that supports several cache backends (in-memory, Redis, and the filesystem).
4.1 Structure of the Cache Manager Extension
```python
# flask_cachemanager/__init__.py
"""
Flask-CacheManager extension.

A flexible caching extension for Flask applications.
"""
import hashlib
import json
import pickle
import time
import functools
from typing import Any, Optional, Callable, Dict, List, Union, Tuple
from datetime import datetime, timedelta
from enum import Enum
from dataclasses import dataclass, asdict

from flask import Flask, current_app, request, g
from werkzeug.local import LocalProxy

__version__ = "1.0.0"


class CacheBackend(Enum):
    """Supported cache backends."""
    MEMORY = "memory"
    REDIS = "redis"
    FILESYSTEM = "filesystem"
    DATABASE = "database"  # declared, but _init_backend below does not implement it


class CacheError(Exception):
    """Base exception for CacheManager."""


class InvalidCacheKeyError(CacheError):
    """Invalid cache key error."""


class CacheUnavailableError(CacheError):
    """Cache backend unavailable error."""


@dataclass
class CacheStats:
    """Cache statistics."""
    hits: int = 0
    misses: int = 0
    sets: int = 0
    deletes: int = 0
    clears: int = 0

    def hit_rate(self) -> float:
        """Calculate the hit rate (hits / total lookups)."""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return asdict(self)


class CacheBackendInterface:
    """Interface that every cache backend implements."""

    def get(self, key: str) -> Any:
        """Get a value from the cache."""
        raise NotImplementedError

    def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
        """Set a value in the cache."""
        raise NotImplementedError

    def delete(self, key: str) -> bool:
        """Delete a value from the cache."""
        raise NotImplementedError

    def clear(self) -> bool:
        """Clear the whole cache."""
        raise NotImplementedError

    def has(self, key: str) -> bool:
        """Check whether a key exists in the cache."""
        raise NotImplementedError

    def ttl(self, key: str) -> Optional[int]:
        """Get the remaining time to live for a key."""
        raise NotImplementedError


class MemoryCacheBackend(CacheBackendInterface):
    """In-memory cache backend (per process: each worker has its own copy)."""

    def __init__(self):
        self._cache = {}
        self._expiry = {}

    def get(self, key: str) -> Any:
        """Get a value from the memory cache; None on a miss or if expired."""
        if key not in self._cache:
            return None
        # Lazily remove the entry if it has expired
        if key in self._expiry and self._expiry[key] < time.time():
            self.delete(key)
            return None
        return self._cache[key]

    def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
        """Set a value in the memory cache; a falsy timeout means no expiry."""
        self._cache[key] = value
        if timeout:
            self._expiry[key] = time.time() + timeout
        else:
            # Drop any expiry left over from a previous set()
            self._expiry.pop(key, None)
        return True

    def delete(self, key: str) -> bool:
        """Delete a value from the memory cache."""
        self._cache.pop(key, None)
        self._expiry.pop(key, None)
        return True

    def clear(self) -> bool:
        """Clear the whole memory cache."""
        self._cache.clear()
        self._expiry.clear()
        return True

    def has(self, key: str) -> bool:
        """Check whether a key exists in the memory cache."""
        if key not in self._cache:
            return False
        # Lazily remove the entry if it has expired
        if key in self._expiry and self._expiry[key] < time.time():
            self.delete(key)
            return False
        return True

    def ttl(self, key: str) -> Optional[int]:
        """Get the remaining time to live for a key; None if it never expires."""
        if key not in self._expiry:
            return None
        ttl = self._expiry[key] - time.time()
        return max(0, int(ttl))


class CacheManager:
    """
    Flask cache manager extension.

    Usage:
        app = Flask(__name__)
        cache = CacheManager(app)
        # Or, with the application factory pattern:
        cache = CacheManager()
        cache.init_app(app)
    """

    def __init__(
        self,
        app: Optional[Flask] = None,
        backend: Union[str, CacheBackend] = CacheBackend.MEMORY,
        default_timeout: int = 300,
        key_prefix: str = "flask_cache",
        serializer: str = "pickle",
        **backend_kwargs
    ):
        """
        Initialize CacheManager.

        Args:
            app: Flask application instance
            backend: Cache backend (memory, redis, filesystem, database)
            default_timeout: Default cache timeout in seconds
            key_prefix: Prefix for cache keys
            serializer: Serializer to use (pickle or json)
            **backend_kwargs: Backend-specific arguments
        """
        self.app = app
        self.backend_type = CacheBackend(backend) if isinstance(backend, str) else backend
        self.default_timeout = default_timeout
        self.key_prefix = key_prefix
        self.serializer = serializer
        self.backend_kwargs = backend_kwargs
        self._backend = None
        self._stats = CacheStats()
        self._callbacks = {
            'on_set': [],
            'on_get': [],
            'on_delete': [],
            'on_clear': [],
        }
        # Named caches created via get_cache()
        self._caches = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app: Flask) -> None:
        """
        Initialize the extension with a Flask application.

        Args:
            app: Flask application instance
        """
        self.app = app
        # Set default configuration (values already present in app.config win).
        # Note: Flask-Caching also reads CACHE_* keys such as CACHE_DEFAULT_TIMEOUT
        # and CACHE_KEY_PREFIX, so the two extensions would share these settings.
        app.config.setdefault('CACHE_BACKEND', self.backend_type.value)
        app.config.setdefault('CACHE_DEFAULT_TIMEOUT', self.default_timeout)
        app.config.setdefault('CACHE_KEY_PREFIX', self.key_prefix)
        app.config.setdefault('CACHE_SERIALIZER', self.serializer)
        # Read the effective values back from the app config
        self.backend_type = CacheBackend(app.config['CACHE_BACKEND'])
        self.default_timeout = app.config['CACHE_DEFAULT_TIMEOUT']
        self.key_prefix = app.config['CACHE_KEY_PREFIX']
        self.serializer = app.config['CACHE_SERIALIZER']
        # Initialize the backend
        self._init_backend()
        # Store a reference to the extension on the app
        if not hasattr(app, 'extensions'):
            app.extensions = {}
        app.extensions['cache_manager'] = self
        # Register context processor
        self._register_context_processor(app)
        # Register template filters
        self._register_template_filters(app)
        # Add CLI commands
        self._add_cli_commands(app)
        # Create the default cache
        self.default_cache = self.get_cache('default')

    def _init_backend(self) -> None:
        """Initialize the cache backend."""
        if self.backend_type == CacheBackend.MEMORY:
            self._backend = MemoryCacheBackend()
        elif self.backend_type == CacheBackend.REDIS:
            try:
                import redis
            except ImportError:
                raise ImportError(
                    "Redis backend requires the 'redis' package. "
                    "Install with: pip install redis"
                ) from None

            # backend_kwargs (host, port, db, ...) are passed straight to the pool
            pool = redis.ConnectionPool(**self.backend_kwargs)
            client = redis.Redis(connection_pool=pool)

            class RedisCacheBackend(CacheBackendInterface):
                def __init__(self, client):
                    self.client = client

                def get(self, key):
                    value = self.client.get(key)
                    if value is None:
                        return None
                    # Only unpickle data from a Redis instance you trust:
                    # pickle.loads can execute arbitrary code
                    try:
                        return pickle.loads(value)
                    except Exception:
                        # Value was not written by this backend
                        return value

                def set(self, key, value, timeout=None):
                    try:
                        # Always pickle, so get() round-trips every type consistently
                        # (storing str values raw made them come back as bytes)
                        serialized = pickle.dumps(value)
                        if timeout:
                            return bool(self.client.setex(key, timeout, serialized))
                        return bool(self.client.set(key, serialized))
                    except Exception:
                        return False

                def delete(self, key):
                    return self.client.delete(key) > 0

                def clear(self):
                    # Caution: FLUSHDB removes every key in the selected Redis
                    # database, not only the keys written by this extension
                    return bool(self.client.flushdb())

                def has(self, key):
                    return self.client.exists(key) > 0

                def ttl(self, key):
                    # Redis returns -1 for "no expiry" and -2 for "no such key"
                    ttl = self.client.ttl(key)
                    return ttl if ttl >= 0 else None

            self._backend = RedisCacheBackend(client)
        elif self.backend_type == CacheBackend.FILESYSTEM:
            import os
            import tempfile

            cache_dir = self.backend_kwargs.get(
                'cache_dir', os.path.join(tempfile.gettempdir(), 'flask_cache')
            )
            os.makedirs(cache_dir, exist_ok=True)

            class FilesystemCacheBackend(CacheBackendInterface):
                def __init__(self, cache_dir):
                    self.cache_dir = cache_dir

                def _get_path(self, key):
                    # Hash the key into a filesystem-safe file name
                    # (MD5 is used only for naming here, not for security)
                    key_hash = hashlib.md5(key.encode()).hexdigest()
                    return os.path.join(self.cache_dir, f"{key_hash}.cache")

                def get(self, key):
                    path = self._get_path(key)
                    if not os.path.exists(path):
                        return None
                    try:
                        with open(path, 'rb') as f:
                            data = pickle.load(f)
                        # Check expiry
                        if 'expiry' in data and data['expiry'] < time.time():
                            self.delete(key)
                            return None
                        return data.get('value')
                    except Exception:
                        return None

                def set(self, key, value, timeout=None):
                    path = self._get_path(key)
                    data = {
                        'value': value,
                        'created': time.time()
                    }
                    if timeout:
                        data['expiry'] = time.time() + timeout
                    try:
                        with open(path, 'wb') as f:
                            pickle.dump(data, f)
                        return True
                    except Exception:
                        return False

                def delete(self, key):
                    path = self._get_path(key)
                    if os.path.exists(path):
                        try:
                            os.remove(path)
                            return True
                        except Exception:
                            return False
                    return False

                def clear(self):
                    try:
                        for filename in os.listdir(self.cache_dir):
                            if filename.endswith('.cache'):
                                os.remove(os.path.join(self.cache_dir, filename))
                        return True
                    except Exception:
                        return False

                def has(self, key):
                    path = self._get_path(key)
                    if not os.path.exists(path):
                        return False
                    try:
                        with open(path, 'rb') as f:
                            data = pickle.load(f)
                        if 'expiry' in data and data['expiry'] < time.time():
                            self.delete(key)
                            return False
                        return True
                    except Exception:
                        return False

                def ttl(self, key):
                    path = self._get_path(key)
                    if not os.path.exists(path):
                        return None
                    try:
                        with open(path, 'rb') as f:
                            data = pickle.load(f)
                        if 'expiry' not in data:
                            return None
                        ttl = data['expiry'] - time.time()
                        return max(0, int(ttl))
                    except Exception:
                        return None

            self._backend = FilesystemCacheBackend(cache_dir)
        else:
            # CacheBackend.DATABASE (or any other value) has no implementation yet
            raise ValueError(f"Unsupported cache backend: {self.backend_type}")

    def get_cache(self, name: str = 'default', **kwargs) -> 'Cache':
        """
        Get or create a named cache.

        Args:
            name: Cache name
            **kwargs: Cache-specific arguments

        Returns:
            Cache instance
        """
        if name not in self._caches:
            self._caches[name] = Cache(
                manager=self,
                name=name,
                **kwargs
            )
        return self._caches[name]

    def make_key(self, *parts) -> str:
        """
        Create a cache key from parts.

        With the default prefix, make_key('user', 42) returns 'flask_cache:user:42'.

        Args:
            *parts: Key parts

        Returns:
            Cache key string
        """
        key_parts = [self.key_prefix] + [str(part) for part in parts]
        return ':'.join(key_parts)

    # Note: the backends above store values themselves (pickle or in-memory
    # objects); these two helpers are not called by them.
    def _serialize(self, value: Any) -> bytes:
        """Serialize a value for caching."""
        if self.serializer == 'json':
            return json.dumps(value).encode()
        return pickle.dumps(value)

    def _deserialize(self, data: bytes) -> Any:
        """Deserialize a value from the cache."""
        if self.serializer == 'json':
            return json.loads(data.decode())
        return pickle.loads(data)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value from the cache.

        Args:
            key: Cache key
            default: Default value if the key is not found

        Returns:
            Cached value or default
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        # Notify on_get callbacks before the lookup
        for callback in self._callbacks['on_get']:
            callback('get', key)
        value = self._backend.get(key)
        # Backends return None on a miss, so a cached None also counts as a miss
        if value is None:
            self._stats.misses += 1
            return default
        self._stats.hits += 1
        return value

    def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
        """
        Set a value in the cache.

        Args:
            key: Cache key
            value: Value to cache
            timeout: Cache timeout in seconds (None means default_timeout)

        Returns:
            True if successful, False otherwise
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        if timeout is None:
            timeout = self.default_timeout
        # Notify on_set callbacks before writing
        for callback in self._callbacks['on_set']:
            callback('set', key, value, timeout)
        success = self._backend.set(key, value, timeout)
        if success:
            self._stats.sets += 1
        return success

    def delete(self, key: str) -> bool:
        """
        Delete a value from the cache.

        Args:
            key: Cache key

        Returns:
            True if successful, False otherwise
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        # Notify on_delete callbacks before deleting
        for callback in self._callbacks['on_delete']:
            callback('delete', key)
        success = self._backend.delete(key)
        if success:
            self._stats.deletes += 1
        return success

    def clear(self) -> bool:
        """
        Clear the whole cache.

        Returns:
            True if successful, False otherwise
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        # Notify on_clear callbacks before clearing
        for callback in self._callbacks['on_clear']:
            callback('clear')
        success = self._backend.clear()
        if success:
            self._stats.clears += 1
        return success

    def has(self, key: str) -> bool:
        """
        Check whether a key exists in the cache.

        Args:
            key: Cache key

        Returns:
            True if the key exists, False otherwise
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        return self._backend.has(key)

    def ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining time to live for a key.

        Args:
            key: Cache key

        Returns:
            Time to live in seconds, or None if the key has no expiry
        """
        if self._backend is None:
            raise CacheUnavailableError("Cache backend not initialized")
        return self._backend.ttl(key)

    def incr(self, key: str, delta: int = 1) -> Optional[int]:
        """
        Increment an integer value in the cache.

        Not atomic: concurrent incr() calls can lose updates, and the write
        resets the key's TTL to default_timeout.

        Args:
            key: Cache key
            delta: Increment amount

        Returns:
            New value, or None if the key doesn't exist
        """
        value = self.get(key)
        if value is None:
            return None
        if not isinstance(value, int):
            raise TypeError(f"Value for key '{key}' is not an integer")
        new_value = value + delta
        self.set(key, new_value)
        return new_value

    def decr(self, key: str, delta: int = 1) -> Optional[int]:
        """
        Decrement an integer value in the cache.

        Args:
            key: Cache key
            delta: Decrement amount

        Returns:
            New value, or None if the key doesn't exist
        """
        return self.incr(key, -delta)

    def get_many(self, *keys: str) -> Dict[str, Any]:
        """
        Get multiple values from the cache.

        Args:
            *keys: Cache keys

        Returns:
            Dictionary of key-value pairs (None for missing keys)
        """
        return {key: self.get(key) for key in keys}

    def set_many(self, mapping: Dict[str, Any], timeout: Optional[int] = None) -> bool:
        """
        Set multiple values in the cache.

        Args:
            mapping: Dictionary of key-value pairs
            timeout: Cache timeout in seconds

        Returns:
            True if all succeeded, False otherwise
        """
        all_success = True
        for key, value in mapping.items():
            if not self.set(key, value, timeout):
                all_success = False
        return all_success

    def delete_many(self, *keys: str) -> bool:
        """
        Delete multiple values from the cache.

        Args:
            *keys: Cache keys

        Returns:
            True if all succeeded, False otherwise
        """
        all_success = True
        for key in keys:
            if not self.delete(key):
                all_success = False
        return all_success
def cache(
self,
timeout: Optional[int] = None,
key_prefix: str = "",
unless: Optional[Callable] = None,
forced_update: Optional[Callable] = None
):
"""
Decorator for caching function results.
Args:
timeout: Cache timeout in seconds
key_prefix: Prefix for cache key
unless: Callable that returns True to skip caching
forced_update: Callable that returns True to force update
Returns:
Decorator function
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Skip caching if unless condition is met
if unless is not None and unless():
return func(*args, **kwargs)
# Generate the cache key first: both branches below need it
cache_key = self._make_cache_key(func, key_prefix, args, kwargs)
# Force update if condition is met: recompute and overwrite the cached value
if forced_update is not None and forced_update():
result = func(*args, **kwargs)
self.set(cache_key, result, timeout)
return result
# Try to get from cache (a cached None cannot be told apart from a miss)
cached_result = self.get(cache_key)
if cached_result is not None:
return cached_result
# Call function and cache result
result = func(*args, **kwargs)
self.set(cache_key, result, timeout)
return result
return wrapper
return decorator
def _make_cache_key(
self,
func: Callable,
key_prefix: str,
args: tuple,
kwargs: dict
) -> str:
"""Generate cache key for function call."""
# Create key components
components = [
key_prefix or "",
func.__module__,
func.__name__,
str(args),
str(sorted(kwargs.items()))
]
# Create hash
key_string = ':'.join(components)
key_hash = hashlib.md5(key_string.encode()).hexdigest()
return self.make_key('func_cache', func.__name__, key_hash)
def memoize(self, timeout: Optional[int] = None):
"""
Memoization decorator (simpler version of cache).
Args:
timeout: Cache timeout in seconds
Returns:
Decorator function
"""
return self.cache(timeout=timeout)
def register_callback(self, event: str, callback: Callable) -> None:
"""
Register callback for cache events.
Args:
event: Event name (on_set, on_get, on_delete, on_clear)
callback: Callback function
"""
if event not in self._callbacks:
raise ValueError(f"Invalid event: {event}")
self._callbacks[event].append(callback)
def get_stats(self) -> CacheStats:
"""
Get cache statistics.
Returns:
CacheStats object
"""
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
self._stats = CacheStats()
def _register_context_processor(self, app: Flask) -> None:
"""Register context processor."""
@app.context_processor
def inject_cache():
"""Inject cache into template context."""
return {
'cache': self,
'cache_stats': self.get_stats()
}
def _register_template_filters(self, app: Flask) -> None:
"""Register template filters."""
@app.template_filter('cache_key')
def cache_key_filter(*parts):
"""Create cache key from template."""
return self.make_key(*parts)
def _add_cli_commands(self, app: Flask) -> None:
"""Add CLI commands."""
import click
@app.cli.command('cache-clear')
@click.option('--all', 'clear_all', is_flag=True, help='Clear all caches')
@click.option('--pattern', help='Clear keys matching pattern')
def cache_clear(clear_all, pattern):
"""Clear cache."""
# The parameter is named clear_all to avoid shadowing the built-in all()
if clear_all:
self.clear()
click.echo("All cache cleared")
elif pattern:
# This would need pattern matching support in the backend
click.echo("Pattern clearing is not implemented for all backends")
else:
self.clear()
click.echo("Cache cleared")
@app.cli.command('cache-stats')
def cache_stats():
"""Show cache statistics."""
stats = self.get_stats()
click.echo("Cache Statistics")
click.echo("=" * 50)
click.echo(f"Hits: {stats.hits}")
click.echo(f"Misses: {stats.misses}")
click.echo(f"Hit Rate: {stats.hit_rate():.1%}")
click.echo(f"Sets: {stats.sets}")
click.echo(f"Deletes: {stats.deletes}")
click.echo(f"Clears: {stats.clears}")
@app.cli.command('cache-info')
def cache_info():
"""Show cache configuration info."""
click.echo("Cache Configuration")
click.echo("=" * 50)
click.echo(f"Backend: {self.backend_type.value}")
click.echo(f"Default Timeout: {self.default_timeout}s")
click.echo(f"Key Prefix: {self.key_prefix}")
click.echo(f"Serializer: {self.serializer}")
click.echo(f"Named Caches: {len(self._caches)}")
class Cache:
"""
Named cache instance.
Provides a scoped interface to the cache manager.
"""
def __init__(self, manager: CacheManager, name: str = 'default', **kwargs):
"""
Initialize cache instance.
Args:
manager: CacheManager instance
name: Cache name
**kwargs: Cache-specific arguments
"""
self.manager = manager
self.name = name
self.config = kwargs
def make_key(self, *parts) -> str:
"""
Create a cache key with cache name prefix.
Args:
*parts: Key parts
Returns:
Cache key string
"""
return self.manager.make_key(self.name, *parts)
def get(self, key: str, default: Any = None) -> Any:
"""
Get value from cache.
Args:
key: Cache key
default: Default value if key not found
Returns:
Cached value or default
"""
full_key = self.make_key(key)
return self.manager.get(full_key, default)
def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
"""
Set value in cache.
Args:
key: Cache key
value: Value to cache
timeout: Cache timeout in seconds
Returns:
True if successful, False otherwise
"""
full_key = self.make_key(key)
return self.manager.set(full_key, value, timeout)
def delete(self, key: str) -> bool:
"""
Delete value from cache.
Args:
key: Cache key
Returns:
True if successful, False otherwise
"""
full_key = self.make_key(key)
return self.manager.delete(full_key)
def clear(self) -> bool:
"""
Clear cache entries.
Warning:
Without pattern-deletion support in the backend, this clears the
entire cache, including every other named cache, not only this one.
Returns:
True if successful, False otherwise
"""
return self.manager.clear()
def has(self, key: str) -> bool:
"""
Check if key exists in cache.
Args:
key: Cache key
Returns:
True if key exists, False otherwise
"""
full_key = self.make_key(key)
return self.manager.has(full_key)
def ttl(self, key: str) -> Optional[int]:
"""
Get time to live for key.
Args:
key: Cache key
Returns:
Time to live in seconds, or None if no expiry
"""
full_key = self.make_key(key)
return self.manager.ttl(full_key)
def incr(self, key: str, delta: int = 1) -> Optional[int]:
"""
Increment integer value in cache.
Args:
key: Cache key
delta: Increment amount
Returns:
New value, or None if key doesn't exist
"""
full_key = self.make_key(key)
return self.manager.incr(full_key, delta)
def decr(self, key: str, delta: int = 1) -> Optional[int]:
"""
Decrement integer value in cache.
Args:
key: Cache key
delta: Decrement amount
Returns:
New value, or None if key doesn't exist
"""
return self.incr(key, -delta)
# Create a default cache instance for easy access
cache = LocalProxy(lambda: current_app.extensions['cache_manager'].default_cache)
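The `_make_cache_key` method above derives a key from the module, the function name, `str(args)` and the sorted keyword arguments, then hashes the result. The standalone sketch below shows the same idea outside the extension; `make_function_key` is a hypothetical helper, not part of Flask-CacheManager, and it swaps in SHA-256 for MD5. Keys are only stable for arguments with a deterministic `repr()` (numbers, strings, tuples of these); objects whose default repr contains a memory address would produce a new key per instance.

```python
import hashlib
from typing import Any, Callable, Dict, Tuple


def make_function_key(
    func: Callable,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    prefix: str = "",
) -> str:
    """Build a deterministic cache key for a function call (hypothetical helper)."""
    # Sorting kwargs makes f(a=1, b=2) and f(b=2, a=1) map to the same key
    parts = [
        prefix,
        func.__module__,
        func.__qualname__,
        repr(args),
        repr(sorted(kwargs.items())),
    ]
    digest = hashlib.sha256(":".join(parts).encode("utf-8")).hexdigest()
    # Keep the function name readable in the key and hash everything else
    return f"func_cache:{func.__qualname__}:{digest}"


def add(a, b=0):
    return a + b


key1 = make_function_key(add, (1,), {"b": 2})
key2 = make_function_key(add, (1,), {"b": 2})
key3 = make_function_key(add, (2,), {"b": 2})
```

Note that `add(1, b=2)` and `add(a=1, b=2)` still produce different keys, because the sketch does not normalize positional versus keyword arguments; `inspect.signature(func).bind()` could be used for that if it matters.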
4.2 Cache Manager Usage Example
```python
# cache_example.py - usage example for the cache manager
import random
import time

from flask import Flask, jsonify, render_template, request
from flask_cachemanager import CacheManager, cache

app = Flask(__name__)

# Configure the cache manager
app.config.update({
    'CACHE_BACKEND': 'memory',        # in-process memory cache
    'CACHE_DEFAULT_TIMEOUT': 60,      # default timeout: 60 seconds
    'CACHE_KEY_PREFIX': 'myapp',
    'SECRET_KEY': 'your-secret-key',  # replace with a real secret in production
})

# Initialize the cache manager
cache_manager = CacheManager(app)

# Create named caches
user_cache = cache_manager.get_cache('users')
product_cache = cache_manager.get_cache('products')
session_cache = cache_manager.get_cache('sessions')


# Example 1: basic get/set caching
@app.route('/expensive-operation')
def expensive_operation():
    """Simulate an expensive operation and speed it up with the cache."""
    cache_key = 'expensive_result'
    # Try the cache first
    result = cache.get(cache_key)
    if result is None:
        # Simulate a 2-second computation
        time.sleep(2)
        result = {
            'data': [random.randint(1, 100) for _ in range(10)],
            'computed_at': time.time(),
            'message': 'This was expensive to compute!'
        }
        # Cache the result for 60 seconds
        cache.set(cache_key, result, timeout=60)
        from_cache = False
    else:
        from_cache = True
    # Build a new dict so the cached object itself is never mutated
    return jsonify({**result, 'from_cache': from_cache})


# Example 2: caching a function's result with the decorator.
# The decorator wraps a helper that returns a plain int rather than the view,
# so no Response object is ever stored in the cache.
@cache_manager.cache(timeout=300)  # cache for 5 minutes
def compute_fibonacci(n):
    """Iterative Fibonacci; the naive recursive version takes exponential time."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


@app.route('/fibonacci/<int:n>')
def fibonacci(n):
    """Return the n-th Fibonacci number (cached per n)."""
    return jsonify({
        'n': n,
        'fibonacci': compute_fibonacci(n),
        'computed_at': time.time()
    })


# Example 3: caching user profiles in a named cache
@app.route('/user/<int:user_id>')
def get_user(user_id):
    """Fetch user info (cached)."""
    cache_key = f'user_{user_id}'
    # Try the user cache first
    user_data = user_cache.get(cache_key)
    if user_data is None:
        # Simulate a database query
        time.sleep(0.5)
        user_data = {
            'id': user_id,
            'name': f'User {user_id}',
            'email': f'user{user_id}@example.com',
            'joined_at': time.time() - random.randint(86400, 31536000),
            'profile_views': random.randint(100, 10000)
        }
        # Cache the user data for 5 minutes
        user_cache.set(cache_key, user_data, timeout=300)
        from_cache = False
    else:
        from_cache = True

    # Count views. incr() returns None while the key does not exist yet,
    # so the counter is initialized on first use. incr() is a get followed
    # by a set, so this counter is NOT atomic under concurrent requests.
    views_key = f'user_{user_id}_views'
    views = user_cache.incr(views_key)
    if views is None:
        views = 1
        user_cache.set(views_key, views)

    return jsonify({**user_data, 'from_cache': from_cache, 'cache_views': views})


# Example 4: caching API responses per combination of query parameters
@app.route('/products')
def get_products():
    """List products (cached per query-parameter combination)."""
    category = request.args.get('category', 'all')
    # type=int falls back to the default for non-numeric input
    page = max(request.args.get('page', 1, type=int), 1)
    sort = request.args.get('sort', 'name')
    cache_key = f'products_{category}_{page}_{sort}'

    result = product_cache.get(cache_key)
    if result is None:
        # Simulate a database query
        time.sleep(1)
        products = [
            {
                'id': i + 1,
                'name': f'Product {i + 1}',
                'category': random.choice(['electronics', 'clothing', 'books', 'home']),
                'price': round(random.uniform(10.0, 1000.0), 2),
                'rating': round(random.uniform(1.0, 5.0), 1),
                'in_stock': random.choice([True, False])
            }
            for i in range(10)
        ]
        # Filter and sort (simulated)
        if category != 'all':
            products = [p for p in products if p['category'] == category]
        if sort == 'price':
            products.sort(key=lambda p: p['price'])
        elif sort == 'rating':
            products.sort(key=lambda p: p['rating'], reverse=True)
        # Paginate (simulated)
        page_size = 5
        start_idx = (page - 1) * page_size
        result = {
            'products': products[start_idx:start_idx + page_size],
            'total': len(products),
            'page': page,
            'page_size': page_size,
            'pages': (len(products) + page_size - 1) // page_size,
            'filters': {'category': category, 'sort': sort}
        }
        # Cache for 30 seconds
        product_cache.set(cache_key, result, timeout=30)
        from_cache = False
    else:
        from_cache = True
    return jsonify({**result, 'from_cache': from_cache})


# Example 5: caching the data behind a template
# (only the statistics are cached; the rendered HTML is not)
@app.route('/dashboard')
def dashboard():
    """Dashboard page whose expensive statistics are cached."""
    stats = cache.get('dashboard_stats')
    if stats is None:
        # Simulate computing the statistics
        time.sleep(2)
        stats = {
            'total_users': random.randint(1000, 10000),
            'active_today': random.randint(100, 1000),
            'total_orders': random.randint(5000, 50000),
            'revenue_today': round(random.uniform(1000.0, 10000.0), 2),
            'top_products': [
                {'name': 'Product A', 'sales': random.randint(100, 1000)},
                {'name': 'Product B', 'sales': random.randint(100, 1000)},
                {'name': 'Product C', 'sales': random.randint(100, 1000)}
            ]
        }
        # Cache the statistics for 60 seconds
        cache.set('dashboard_stats', stats, timeout=60)
    # Requires a templates/dashboard.html file
    return render_template('dashboard.html', stats=stats)


# Example 6: cache event callbacks
def log_cache_event(event, *args):
    """Log a cache event."""
    print(f"[CACHE EVENT] {event}: {args}")


# Register the callbacks
cache_manager.register_callback('on_set', lambda *args: log_cache_event('SET', *args))
cache_manager.register_callback('on_get', lambda *args: log_cache_event('GET', *args))
cache_manager.register_callback('on_delete', lambda *args: log_cache_event('DELETE', *args))


# Example 7: cache management API
@app.route('/cache/stats')
def cache_stats():
    """Return cache statistics."""
    stats = cache_manager.get_stats()
    return jsonify({
        'hits': stats.hits,
        'misses': stats.misses,
        'hit_rate': stats.hit_rate(),
        'sets': stats.sets,
        'deletes': stats.deletes,
        'clears': stats.clears
    })


@app.route('/cache/clear', methods=['POST'])
def clear_cache():
    """Clear the cache (protect this endpoint in production)."""
    cache_manager.clear()
    return jsonify({
        'message': 'Cache cleared',
        'timestamp': time.time()
    })


@app.route('/cache/info')
def cache_info():
    """Return the cache configuration."""
    return jsonify({
        'backend': cache_manager.backend_type.value,
        'default_timeout': cache_manager.default_timeout,
        'key_prefix': cache_manager.key_prefix,
        'serializer': cache_manager.serializer,
        'named_caches': list(cache_manager._caches.keys())
    })


if __name__ == '__main__':
    app.run(debug=True, port=5001)
```
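The view counter in Example 3 relies on `incr()`, which reads the value and writes it back in two separate steps, so two concurrent requests can read the same number and lose an increment. The sketch below shows an in-process counter that holds a lock across the whole read-modify-write and initializes missing keys; `LockedCounterStore` is a hypothetical class, not part of the extension. A `threading.Lock` only protects a single process: with several worker processes you would need an atomic operation in the backend itself (for example Redis `INCR`).

```python
import threading
from typing import Dict


class LockedCounterStore:
    """In-process counter store with atomic increments (hypothetical)."""

    def __init__(self) -> None:
        self._values: Dict[str, int] = {}
        self._lock = threading.Lock()

    def incr(self, key: str, delta: int = 1, initial: int = 0) -> int:
        # Read, add and write under one lock so concurrent callers
        # cannot interleave and overwrite each other's updates
        with self._lock:
            new_value = self._values.get(key, initial) + delta
            self._values[key] = new_value
            return new_value

    def get(self, key: str, default: int = 0) -> int:
        with self._lock:
            return self._values.get(key, default)


store = LockedCounterStore()


def hammer(n: int) -> None:
    # Simulate n page views from one worker thread
    for _ in range(n):
        store.incr("user_1_views")


threads = [threading.Thread(target=hammer, args=(1000,)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because a missing key starts at `initial`, the first `incr()` returns a usable value instead of `None`, which removes the special case the example view needs.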
5. Extension Testing and Quality Assurance
5.1 Writing Extension Tests
```python
# tests/test_heartbeat.py
import pytest
from flask import Flask

# HealthCheckResult is assumed to be exported by flask_heartbeat
# alongside Heartbeat and HealthStatus (see section 3)
from flask_heartbeat import HealthCheckResult, HealthStatus, Heartbeat


@pytest.fixture
def app():
    """Create the application under test."""
    app = Flask(__name__)
    app.config['TESTING'] = True
    app.config['HEARTBEAT_ENDPOINT'] = '/health'
    return app


@pytest.fixture
def heartbeat(app):
    """Attach the Heartbeat extension to the test app."""
    return Heartbeat(app)


@pytest.fixture
def client(app):
    """Create a test client."""
    return app.test_client()


@pytest.fixture
def runner(app):
    """Create a CLI runner for testing Click commands."""
    return app.test_cli_runner()


def test_heartbeat_initialization():
    """Direct initialization with Heartbeat(app)."""
    app = Flask(__name__)
    heartbeat = Heartbeat(app)
    assert heartbeat.app is app
    assert heartbeat.endpoint == '/health'
    assert heartbeat.health_check_funcs == []
    assert heartbeat.auth_required is False


def test_heartbeat_factory_pattern():
    """Application-factory initialization with init_app()."""
    app = Flask(__name__)
    heartbeat = Heartbeat()
    heartbeat.init_app(app)
    assert heartbeat.app is app
    assert 'heartbeat' in app.extensions


def test_health_endpoint(client, heartbeat):
    """The health endpoint reports a registered check."""
    @heartbeat.add_check
    def test_check():
        return HealthCheckResult(
            name="test",
            status=HealthStatus.HEALTHY,
            message="Test check passed"
        )

    response = client.get('/health')
    assert response.status_code == 200
    data = response.get_json()
    assert data['status'] == 'healthy'
    assert len(data['results']) == 1
    assert data['results'][0]['name'] == 'test'


def test_health_check_with_custom_endpoint():
    """HEARTBEAT_ENDPOINT changes the URL."""
    app = Flask(__name__)
    app.config['HEARTBEAT_ENDPOINT'] = '/custom-health'
    Heartbeat(app)
    with app.test_client() as client:
        response = client.get('/custom-health')
        assert response.status_code == 200


def test_multiple_health_checks(client, heartbeat):
    """Several checks are all executed and counted."""
    check_count = 3

    def make_check(i):
        # Bind i per check; a function defined directly in the loop body
        # would read the loop variable late and always see its last value
        def check():
            return HealthCheckResult(
                name=f"check_{i}",
                status=HealthStatus.HEALTHY,
                message=f"Check {i} passed"
            )
        check.__name__ = f"check_{i}"  # unique name in case checks are keyed by name
        return check

    for i in range(check_count):
        heartbeat.add_check(make_check(i))

    data = client.get('/health').get_json()
    assert len(data['results']) == check_count
    assert data['checks']['total'] == check_count
    assert data['checks']['healthy'] == check_count


def test_health_check_statuses(client, heartbeat):
    """One unhealthy check makes the overall status unhealthy."""
    @heartbeat.add_check
    def healthy_check():
        return HealthCheckResult(name="healthy", status=HealthStatus.HEALTHY,
                                 message="Healthy check")

    @heartbeat.add_check
    def unhealthy_check():
        return HealthCheckResult(name="unhealthy", status=HealthStatus.UNHEALTHY,
                                 message="Unhealthy check")

    @heartbeat.add_check
    def degraded_check():
        return HealthCheckResult(name="degraded", status=HealthStatus.DEGRADED,
                                 message="Degraded check")

    data = client.get('/health').get_json()
    assert data['status'] == 'unhealthy'  # at least one check is unhealthy
    assert data['checks']['healthy'] == 1
    assert data['checks']['unhealthy'] == 1
    assert data['checks']['degraded'] == 1


def test_cache_functionality(client, heartbeat):
    """Exercise the result cache (exact behaviour depends on configuration)."""
    execution_count = 0

    @heartbeat.add_check
    def counting_check():
        nonlocal execution_count
        execution_count += 1
        return HealthCheckResult(
            name="counting",
            status=HealthStatus.HEALTHY,
            message=f"Executed {execution_count} times"
        )

    # The first request must execute the check
    assert client.get('/health').status_code == 200
    assert execution_count >= 1
    first = execution_count
    # The second request may be served from the cache if its TTL is > 0
    assert client.get('/health').status_code == 200
    assert execution_count in (first, first + 1)


def test_authentication_required():
    """Requests without valid credentials are rejected."""
    app = Flask(__name__)
    app.config['HEARTBEAT_AUTH_REQUIRED'] = True

    def auth_func(request):
        return request.headers.get('X-API-Key') == 'secret'

    Heartbeat(app, auth_required=True, auth_func=auth_func)
    with app.test_client() as client:
        # Without credentials: 401
        assert client.get('/health').status_code == 401
        # With credentials: 200
        response = client.get('/health', headers={'X-API-Key': 'secret'})
        assert response.status_code == 200


def test_detailed_endpoint(client, heartbeat):
    """The detailed endpoint includes each check's data."""
    @heartbeat.add_check
    def detailed_check():
        return HealthCheckResult(
            name="detailed",
            status=HealthStatus.HEALTHY,
            message="Detailed check",
            data={"extra": "information"}
        )

    data = client.get('/health/detailed').get_json()
    assert 'results' in data
    assert len(data['results']) == 1
    assert 'data' in data['results'][0]


def test_readiness_endpoint(client, heartbeat):
    """Readiness returns 503 once any check is unhealthy."""
    @heartbeat.add_check
    def healthy_check():
        return HealthCheckResult(name="healthy", status=HealthStatus.HEALTHY,
                                 message="Healthy")

    assert client.get('/health/ready').status_code == 200  # all checks healthy

    @heartbeat.add_check
    def unhealthy_check():
        return HealthCheckResult(name="unhealthy", status=HealthStatus.UNHEALTHY,
                                 message="Unhealthy")

    assert client.get('/health/ready').status_code == 503  # one check unhealthy


def test_liveness_endpoint(client, heartbeat):
    """Liveness always answers while the process is up."""
    response = client.get('/health/live')
    assert response.status_code == 200
    data = response.get_json()
    assert data['status'] == 'alive'
    assert 'timestamp' in data


def test_cli_command(app, heartbeat, runner):
    """The health-check CLI command prints the results."""
    @heartbeat.add_check
    def test_check():
        return HealthCheckResult(name="test", status=HealthStatus.HEALTHY,
                                 message="Test check")

    result = runner.invoke(args=['health-check'])
    assert result.exit_code == 0
    assert 'Health Check Results' in result.output
    assert 'test' in result.output
```
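In `test_multiple_health_checks`, the check functions are created in a loop and read the loop variable `i`. Python closures capture variables, not values, so a function defined directly in the loop body sees whatever `i` holds when it is finally called. The plain-Python sketch below (both builder functions are illustrative names, not part of Flask-Heartbeat) shows the pitfall and the factory-function fix side by side.

```python
from typing import Callable, List


def build_checks_naive(count: int) -> List[Callable[[], str]]:
    checks = []
    for i in range(count):
        def check() -> str:
            # i is looked up when check() runs, i.e. after the loop finished
            return f"check_{i}"
        checks.append(check)
    return checks


def build_checks_bound(count: int) -> List[Callable[[], str]]:
    def make_check(i: int) -> Callable[[], str]:
        # Each call to make_check creates a new scope holding its own i
        def check() -> str:
            return f"check_{i}"
        return check
    return [make_check(i) for i in range(count)]


naive_names = [c() for c in build_checks_naive(3)]
bound_names = [c() for c in build_checks_bound(3)]
```

A default argument (`def check(i=i):`) achieves the same binding, but the factory keeps the check's signature free of parameters, which matters if the extension calls checks without arguments.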
5.2 Publishing and Distributing the Extension
5.2.1 Packaging the Extension
```python
# setup.py - complete packaging configuration for the extension
from setuptools import find_packages, setup

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

with open("requirements.txt", "r", encoding="utf-8") as fh:
    # Skip blank lines and comments, which are not valid requirement specifiers
    requirements = [
        line.strip()
        for line in fh
        if line.strip() and not line.strip().startswith("#")
    ]

setup(
    name="flask-heartbeat",
    version="1.0.0",
    author="Your Name",
    author_email="your.email@example.com",
    description="A Flask extension for health checks and monitoring",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/flask-heartbeat",
    project_urls={
        "Bug Tracker": "https://github.com/yourusername/flask-heartbeat/issues",
        "Documentation": "https://flask-heartbeat.readthedocs.io/",
        "Source Code": "https://github.com/yourusername/flask-heartbeat",
    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Environment :: Web Environment",
        "Framework :: Flask",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
        "Topic :: Software Development :: Libraries :: Python Modules",
        "Topic :: System :: Monitoring",
    ],
    packages=find_packages(exclude=["tests", "tests.*", "examples", "examples.*"]),
    python_requires=">=3.7",
    install_requires=requirements,
    extras_require={
        "dev": [
            "pytest>=6.0",
            "pytest-flask>=1.2.0",
            "black>=21.0",
            "flake8>=4.0",
            "mypy>=0.910",
            "sphinx>=4.0",
            "sphinx-rtd-theme>=1.0",
        ],
        "redis": [
            "redis>=4.0",
        ],
    },
    include_package_data=True,
    zip_safe=False,
    entry_points={
        "console_scripts": [
            "flask-heartbeat-check=flask_heartbeat.cli:main",
        ],
        "flask.commands": [
            "heartbeat=flask_heartbeat.cli:heartbeat_cli",
        ],
    },
)
```
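`setup.py` feeds `requirements.txt` straight into `install_requires`, but a requirements file may also contain inline comments and pip options such as `-r dev.txt` or `-e .`, none of which are valid requirement specifiers. The sketch below is a hypothetical `read_requirements` helper (not a setuptools API) that keeps only the specifiers; it assumes inline comments are separated from the requirement by whitespace, as pip requires.

```python
from pathlib import Path
from typing import List


def read_requirements(path: str) -> List[str]:
    """Return requirement specifiers from a requirements file (hypothetical helper)."""
    requirements = []
    for raw_line in Path(path).read_text(encoding="utf-8").splitlines():
        # Drop inline comments (" # ...") and surrounding whitespace
        line = raw_line.split(" #", 1)[0].strip()
        # Skip blank lines, full-line comments and pip options such as
        # "-r other.txt" or "-e .", which install_requires cannot express
        if not line or line.startswith(("#", "-")):
            continue
        requirements.append(line)
    return requirements
```

With this helper, the setup call would read `install_requires=read_requirements("requirements.txt")`.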
5.2.2 Writing Documentation
```python
# docs/conf.py - Sphinx configuration
import os
import sys

# Make the package importable for autodoc
sys.path.insert(0, os.path.abspath('..'))

project = 'Flask-Heartbeat'
copyright = '2023, Your Name'
author = 'Your Name'

extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
]

templates_path = ['_templates']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']

intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'flask': ('https://flask.palletsprojects.com/', None),
}

# Default options for the API reference generated by autodoc
autodoc_default_options = {
    'members': True,
    'member-order': 'bysource',
    'special-members': '__init__',
    'undoc-members': True,
    'exclude-members': '__weakref__'
}

# Version metadata
release = '1.0.0'  # full version string
version = '1.0'    # short X.Y version
```
6. Advanced Extension Development Techniques
6.1 Configuration Management
```python
class ConfigurableExtension:
    """
    Base class for a configurable extension.
    Shows one way to handle extension configuration cleanly.
    """

    # Default configuration
    DEFAULT_CONFIG = {
        'ENABLED': True,
        'DEBUG': False,
        'TIMEOUT': 30,
        'MAX_RETRIES': 3,
    }

    # Mapping from app config keys to extension config keys
    CONFIG_MAPPING = {
        'EXTENSION_ENABLED': 'ENABLED',
        'EXTENSION_DEBUG': 'DEBUG',
        'EXTENSION_TIMEOUT': 'TIMEOUT',
        'EXTENSION_MAX_RETRIES': 'MAX_RETRIES',
    }

    def __init__(self, app=None, **kwargs):
        self.app = app
        self._config = self.DEFAULT_CONFIG.copy()
        # Apply constructor overrides. Values found in app.config are loaded
        # later in init_app() and therefore take precedence over these.
        self._config.update(kwargs)
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to an application and load its configuration."""
        self.app = app
        # Load values from app.config
        self._load_app_config(app)
        # Validate the merged configuration
        self._validate_config()
        # Register in app.extensions
        self._register_with_app(app)

    def _load_app_config(self, app):
        """Copy mapped keys from app.config into the extension config."""
        for app_key, ext_key in self.CONFIG_MAPPING.items():
            if app_key in app.config:
                self._config[ext_key] = app.config[app_key]

    def _validate_config(self):
        """Reject invalid configuration values early."""
        if self._config['TIMEOUT'] <= 0:
            raise ValueError("TIMEOUT must be positive")
        if self._config['MAX_RETRIES'] < 0:
            raise ValueError("MAX_RETRIES cannot be negative")

    def _register_with_app(self, app):
        """Register the extension on the application."""
        # Flask apps always have an extensions dict; the guard only matters
        # for app-like stand-ins used in tests
        if not hasattr(app, 'extensions'):
            app.extensions = {}
        # Use the lower-cased class name as the extension key
        ext_name = self.__class__.__name__.lower()
        app.extensions[ext_name] = self

    def get_config(self, key, default=None):
        """Return a configuration value."""
        return self._config.get(key, default)

    def set_config(self, key, value):
        """Set a configuration value."""
        self._config[key] = value

    @property
    def is_enabled(self):
        """Whether the extension is enabled."""
        return self._config['ENABLED']

    @property
    def is_debug(self):
        """Whether debug mode is on."""
        return self._config['DEBUG']
```
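`ConfigurableExtension` layers three sources: defaults, constructor arguments and `app.config`, with `app.config` applied last. The standalone sketch below makes the layering explicit as a pure function, `resolve_config`, which is a hypothetical helper. Unlike the class above, it deliberately lets explicit constructor arguments win over `app.config` and rejects unknown keys; treat both as design choices of this sketch rather than as Flask conventions.

```python
from typing import Any, Dict, Mapping

DEFAULT_CONFIG: Dict[str, Any] = {"ENABLED": True, "DEBUG": False, "TIMEOUT": 30, "MAX_RETRIES": 3}
CONFIG_MAPPING: Dict[str, str] = {
    "EXTENSION_ENABLED": "ENABLED",
    "EXTENSION_DEBUG": "DEBUG",
    "EXTENSION_TIMEOUT": "TIMEOUT",
    "EXTENSION_MAX_RETRIES": "MAX_RETRIES",
}


def resolve_config(app_config: Mapping[str, Any], overrides: Mapping[str, Any]) -> Dict[str, Any]:
    """Merge defaults < app config < explicit overrides, then validate (hypothetical)."""
    # Layer 1: defaults
    config = dict(DEFAULT_CONFIG)
    # Layer 2: translate prefixed app-config keys to extension keys
    for app_key, ext_key in CONFIG_MAPPING.items():
        if app_key in app_config:
            config[ext_key] = app_config[app_key]
    # Layer 3: explicit overrides win, but typos are rejected
    unknown = set(overrides) - set(DEFAULT_CONFIG)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    config.update(overrides)
    # Validate the final, merged result
    if config["TIMEOUT"] <= 0:
        raise ValueError("TIMEOUT must be positive")
    if config["MAX_RETRIES"] < 0:
        raise ValueError("MAX_RETRIES cannot be negative")
    return config


cfg = resolve_config({"EXTENSION_TIMEOUT": 10, "UNRELATED": 1}, {"DEBUG": True})
```

Keeping the merge a pure function also makes it easy to unit-test without creating a Flask app.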
6.2 Dependency Injection
```python
import functools
import inspect
from typing import Any, Callable, Dict

from flask import g, has_app_context, has_request_context, request, session


class DependencyManager:
    """
    Dependency manager for an extension.
    Shows how an extension can manage its dependencies.
    """

    def __init__(self):
        self._dependencies: Dict[str, Any] = {}
        self._factories: Dict[str, Callable[[], Any]] = {}
        self._singletons: Dict[str, Any] = {}

    def register(self, name: str, dependency: Any) -> None:
        """Register a ready-made dependency instance under `name`."""
        self._dependencies[name] = dependency

    def register_factory(self, name: str, factory: Callable[[], Any]) -> None:
        """Register a factory that is called on every lookup."""
        self._factories[name] = factory
        # A plain factory replaces any singleton registered under the same name
        self._singletons.pop(name, None)

    def register_singleton(self, name: str, factory: Callable[[], Any]) -> None:
        """Register a factory whose result is created once and then reused."""
        self._factories[name] = factory
        # None marks "not created yet"
        self._singletons[name] = None

    def get(self, name: str) -> Any:
        """
        Resolve a dependency.
        Raises:
            KeyError: if nothing is registered under `name`
        """
        # Directly registered instances first
        if name in self._dependencies:
            return self._dependencies[name]
        if name in self._factories:
            factory = self._factories[name]
            if name in self._singletons:
                # Create the singleton on first access, then reuse it
                if self._singletons[name] is None:
                    self._singletons[name] = factory()
                return self._singletons[name]
            # Plain factory: a new instance on every lookup
            return factory()
        raise KeyError(f"Dependency '{name}' not found")

    def has(self, name: str) -> bool:
        """Return True if `name` can be resolved."""
        return name in self._dependencies or name in self._factories


class InjectableExtension:
    """Base class for extensions that support dependency injection."""

    def __init__(self, app=None):
        self.app = app
        self.dependencies = DependencyManager()
        # Register the extension itself
        self.dependencies.register('extension', self)
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to an application."""
        self.app = app
        self.dependencies.register('app', app)
        self.dependencies.register('config', app.config)
        # Common Flask objects, resolved lazily per call
        self.dependencies.register_factory('request', self._get_request)
        self.dependencies.register_factory('session', self._get_session)
        self.dependencies.register_factory('g', self._get_g)
        self._register_extension(app)

    # request, session and g are context-local proxies: importing them never
    # fails, so the context has to be checked explicitly.
    @staticmethod
    def _get_request():
        """Return the current request, or None outside a request context."""
        return request if has_request_context() else None

    @staticmethod
    def _get_session():
        """Return the current session, or None outside a request context."""
        return session if has_request_context() else None

    @staticmethod
    def _get_g():
        """Return g, or None outside an application context."""
        return g if has_app_context() else None

    def _register_extension(self, app):
        """Register the extension and its dependency manager on the app."""
        ext_name = self.__class__.__name__.lower()
        app.extensions[ext_name] = self
        app.extensions[f'{ext_name}_deps'] = self.dependencies

    def inject(self, func):
        """
        Dependency injection decorator.

        Example:
            @extension.inject
            def my_function(request, session, extension):
                # request, session and extension are injected automatically
                ...
        """
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Parameters the caller already supplied, positionally or by
            # keyword, are never injected (that would raise a TypeError)
            supplied = sig.bind_partial(*args, **kwargs).arguments
            for name in sig.parameters:
                if name not in supplied and self.dependencies.has(name):
                    kwargs[name] = self.dependencies.get(name)
            return func(*args, **kwargs)

        return wrapper
```
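Stripped of Flask, the injection mechanism above reduces to three ideas: providers keyed by parameter name, singletons created once, and `inspect.signature(...).bind_partial()` to avoid injecting arguments the caller already passed. The runnable sketch below uses a hypothetical `Container` class to show just that; it reuses `functools.lru_cache` on a zero-argument factory as a simple way to memoize the singleton.

```python
import functools
import inspect
from typing import Any, Callable, Dict


class Container:
    """Tiny Flask-free DI container (hypothetical, for illustration)."""

    def __init__(self) -> None:
        self._providers: Dict[str, Callable[[], Any]] = {}

    def value(self, name: str, obj: Any) -> None:
        self._providers[name] = lambda: obj

    def singleton(self, name: str, factory: Callable[[], Any]) -> None:
        # lru_cache on a zero-argument callable memoizes its single result
        self._providers[name] = functools.lru_cache(maxsize=None)(factory)

    def factory(self, name: str, factory: Callable[[], Any]) -> None:
        self._providers[name] = factory

    def inject(self, func: Callable) -> Callable:
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Only fill parameters the caller did not supply in any form
            supplied = sig.bind_partial(*args, **kwargs).arguments
            for name in sig.parameters:
                if name not in supplied and name in self._providers:
                    kwargs[name] = self._providers[name]()
            return func(*args, **kwargs)

        return wrapper


container = Container()
container.value("config", {"TIMEOUT": 30})
container.singleton("client", object)
container.factory("fresh", object)


@container.inject
def handler(user_id, config, client):
    return user_id, config["TIMEOUT"], client


@container.inject
def get_fresh(fresh):
    return fresh


r1 = handler(7)
r2 = handler(8)
r3 = handler(9, config={"TIMEOUT": 1})  # keyword argument is not overwritten
r4 = handler(10, {"TIMEOUT": 2})        # neither is a positional one
```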
6.3 Signals
```python
from blinker import Namespace
from flask import got_request_exception

# Signal namespace specific to this extension
extension_signals = Namespace()


class SignalExtension:
    """
    Extension that exposes its own signals.
    Shows how an extension can emit signals.
    """

    def __init__(self, app=None):
        self.app = app
        # Signals defined by the extension
        self.signals = {
            'before_init': extension_signals.signal('extension.before_init'),
            'after_init': extension_signals.signal('extension.after_init'),
            'before_request': extension_signals.signal('extension.before_request'),
            'after_request': extension_signals.signal('extension.after_request'),
            'error_occurred': extension_signals.signal('extension.error'),
        }
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to an application."""
        self.app = app
        self.signals['before_init'].send(self, app=app)
        self._setup_app(app)
        self.signals['after_init'].send(self, app=app)
        self._register_hooks(app)
        self._register_extension(app)

    def _setup_app(self, app):
        """Application setup (override in subclasses)."""

    def _register_hooks(self, app):
        """Register request hooks."""
        @app.before_request
        def before_request_hook():
            self.signals['before_request'].send(self)

        @app.after_request
        def after_request_hook(response):
            self.signals['after_request'].send(self, response=response)
            return response

        # Subscribe to Flask's got_request_exception signal (sent for
        # unhandled exceptions) instead of registering errorhandler(Exception):
        # a catch-all handler that re-raises turns HTTP errors such as 404
        # into 500 responses and overrides the app's own error handlers.
        got_request_exception.connect(self._on_exception, app)

    def _on_exception(self, sender, exception, **extra):
        """Forward Flask's exception signal to the extension's own signal."""
        self.signals['error_occurred'].send(self, error=exception)

    def _register_extension(self, app):
        """Register the extension on the application."""
        ext_name = self.__class__.__name__.lower()
        app.extensions[ext_name] = self

    def connect(self, signal_name, callback):
        """Connect `callback` to the signal called `signal_name`."""
        if signal_name not in self.signals:
            raise ValueError(f"Unknown signal: {signal_name}")
        self.signals[signal_name].connect(callback)

    def disconnect(self, signal_name, callback):
        """Disconnect `callback` from the signal called `signal_name`."""
        if signal_name in self.signals:
            self.signals[signal_name].disconnect(callback)
```
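`SignalExtension` relies on blinker, but the observer pattern underneath is small. The sketch below is a hypothetical `SimpleSignal` class, not blinker's API: it mirrors blinker's calling convention (`receiver(sender, **kwargs)`) and optional filtering by sender, while omitting blinker's weak references, async support and named-signal registry.

```python
from typing import Any, Callable, List, Tuple

ANY = object()  # sentinel meaning "accept any sender"


class SimpleSignal:
    """Minimal observer-pattern signal (hypothetical; blinker does much more)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._receivers: List[Tuple[Callable[..., Any], Any]] = []

    def connect(self, receiver: Callable[..., Any], sender: Any = ANY) -> Callable[..., Any]:
        self._receivers.append((receiver, sender))
        return receiver  # returning the receiver allows use as a decorator

    def disconnect(self, receiver: Callable[..., Any]) -> None:
        # != rather than "is not": bound methods are recreated on each access
        self._receivers = [(r, s) for r, s in self._receivers if r != receiver]

    def send(self, sender: Any, **kwargs: Any) -> List[Tuple[Callable[..., Any], Any]]:
        # Call every matching receiver and collect (receiver, return value) pairs
        results = []
        for receiver, wanted in list(self._receivers):
            if wanted is ANY or wanted is sender:
                results.append((receiver, receiver(sender, **kwargs)))
        return results


after_init = SimpleSignal("extension.after_init")
log: List[str] = []


def on_after_init(sender, **kwargs):
    log.append(f"{sender} initialized app {kwargs['app']}")


after_init.connect(on_after_init)
after_init.send("ext", app="demo")
after_init.disconnect(on_after_init)
after_init.send("ext", app="ignored")  # no receivers left, nothing logged
```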
7. Best Practices for Extension Development
7.1 Code Quality
```yaml
# .pre-commit-config.yaml - pre-commit hook configuration
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-json
      - id: check-toml
      - id: check-merge-conflict

  - repo: https://github.com/psf/black
    rev: 22.6.0
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/PyCQA/flake8
    rev: 5.0.4
    hooks:
      - id: flake8
        additional_dependencies: [flake8-docstrings]

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v0.971
    hooks:
      - id: mypy
        # mypy runs in an isolated environment: list the packages and stub
        # packages the code imports (the catch-all "types-all" is deprecated)
        additional_dependencies: [flask]
        exclude: ^tests/

  - repo: https://github.com/PyCQA/isort
    rev: 5.10.1
    hooks:
      - id: isort
```

```makefile
# Makefile - build and test automation (recipe lines must start with a tab)
.PHONY: install test lint format docs clean build publish-test publish

install:
	pip install -e ".[dev]"

test:
	pytest tests/ -v --cov=flask_heartbeat --cov-report=html --cov-report=term-missing

lint:
	flake8 flask_heartbeat/ tests/ examples/
	mypy flask_heartbeat/

format:
	black flask_heartbeat/ tests/ examples/
	isort flask_heartbeat/ tests/ examples/

docs:
	cd docs && make html

clean:
	rm -rf build/ dist/ *.egg-info .coverage htmlcov/ docs/_build/
	find . -type f -name "*.pyc" -delete
	find . -type d -name "__pycache__" -prune -exec rm -rf {} +

# "python setup.py sdist bdist_wheel" is deprecated; use the build frontend
build:
	python -m build

publish-test:
	twine upload --repository-url https://test.pypi.org/legacy/ dist/*

publish:
	twine upload dist/*
```
7.2 Performance Tips
```python
import time


class OptimizedExtension:
    """
    Performance-oriented extension example.
    Shows lazy loading, time-based caching and memoized config lookups.
    """

    CACHE_TTL = 60  # seconds a computed result stays valid

    def __init__(self, app=None):
        self.app = app
        # Lazy loading: computed on first access, then reused
        self._lazy_loaded_data = None
        # Single-value cache for get_cached_result()
        self._cached_result = None
        self._cache_timestamp = 0.0
        # Memoized config lookups (an instance attribute)
        self._local_cache = {}
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to an application."""
        self.app = app
        # Config values may differ between apps, so drop memoized lookups
        self._local_cache.clear()
        self._register_extension(app)

    def _register_extension(self, app):
        """Register the extension under its lower-cased class name."""
        app.extensions[self.__class__.__name__.lower()] = self

    @property
    def lazy_data(self):
        """Lazily loaded property."""
        if self._lazy_loaded_data is None:
            self._lazy_loaded_data = self._expensive_operation()
        return self._lazy_loaded_data

    def _expensive_operation(self):
        """Simulate an expensive operation (about 100 ms)."""
        time.sleep(0.1)
        return {"data": "expensive result"}

    def get_cached_result(self, force_refresh=False):
        """
        Return a result cached for CACHE_TTL seconds.
        Args:
            force_refresh: recompute even if the cached value is still valid
        Returns:
            The cached or freshly computed result
        """
        # monotonic() is unaffected by system clock changes, unlike time.time()
        now = time.monotonic()
        cache_valid = (
            not force_refresh
            and self._cached_result is not None
            and now - self._cache_timestamp < self.CACHE_TTL
        )
        if cache_valid:
            return self._cached_result
        # Recompute and refresh the cache
        self._cached_result = self._compute_result()
        self._cache_timestamp = now
        return self._cached_result

    def _compute_result(self):
        """Simulate a computation (about 50 ms)."""
        time.sleep(0.05)
        return {"computed": time.time()}

    def get_config_value(self, key, default=None):
        """
        Read a config value, memoizing keys that exist in app.config.
        Call clear_cache() after changing app.config at runtime.
        """
        if key in self._local_cache:
            return self._local_cache[key]
        if self.app is not None and key in self.app.config:
            value = self.app.config[key]
            # Only memoize keys that exist: caching the default would make
            # later calls with a different default return the first one
            self._local_cache[key] = value
            return value
        return default

    def clear_cache(self):
        """Clear all caches."""
        self._lazy_loaded_data = None
        self._cached_result = None
        self._cache_timestamp = 0.0
        self._local_cache.clear()
```
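`get_cached_result` hand-rolls a time-based cache for a single value. The same idea generalizes to a decorator that caches one entry per argument tuple. The sketch below is a hypothetical `ttl_cache` helper, not a standard-library function: the clock is injectable so expiry can be tested without sleeping, arguments must be hashable, and the cache is neither thread-safe nor size-bounded, so it suits small, read-mostly lookups.

```python
import functools
import time
from typing import Any, Callable, Dict, Hashable, Tuple


def ttl_cache(ttl: float, clock: Callable[[], float] = time.monotonic):
    """Memoize a function's results for `ttl` seconds (hypothetical helper)."""
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # Maps an argument key to (time stored, value)
        entries: Dict[Hashable, Tuple[float, Any]] = {}

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = (args, tuple(sorted(kwargs.items())))
            now = clock()
            hit = entries.get(key)
            if hit is not None and now - hit[0] < ttl:
                return hit[1]
            # Missing or expired: recompute and store with the current time
            value = func(*args, **kwargs)
            entries[key] = (now, value)
            return value

        wrapper.cache_clear = entries.clear  # type: ignore[attr-defined]
        return wrapper

    return decorator


fake_now = [0.0]
calls = []


@ttl_cache(ttl=60, clock=lambda: fake_now[0])
def compute(x):
    calls.append(x)
    return x * 2


first = compute(1)
second = compute(1)   # served from the cache
fake_now[0] = 61.0
third = compute(1)    # entry expired, so compute() runs again
```

After these three calls `calls` holds two entries: the function ran on the first call and again after the entry expired.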
7.3 Security Considerations
```python
import html
from urllib.parse import urlparse


class SecureExtension:
    """
    Example of a security-conscious extension.

    Demonstrates security considerations in extension development.
    """

    def __init__(self, app=None):
        self.app = app
        # Security-related settings (defaults)
        self.security_config = {
            'allow_unsafe_operations': False,
            'max_input_size': 1024 * 1024,  # 1 MB
            'allowed_schemes': ['http', 'https'],
            'sanitize_input': True,
        }
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize the extension for an application."""
        self.app = app
        # Load security settings from the app config
        self._load_security_config(app)
        # Validate the security settings
        self._validate_security_config()
        # Register the extension
        self._register_extension(app)

    def _load_security_config(self, app):
        """Load security settings from the app config."""
        security_keys = [
            ('EXTENSION_ALLOW_UNSAFE', 'allow_unsafe_operations'),
            ('EXTENSION_MAX_INPUT_SIZE', 'max_input_size'),
            ('EXTENSION_ALLOWED_SCHEMES', 'allowed_schemes'),
            ('EXTENSION_SANITIZE_INPUT', 'sanitize_input'),
        ]
        for app_key, ext_key in security_keys:
            if app_key in app.config:
                self.security_config[ext_key] = app.config[app_key]

    def _validate_security_config(self):
        """Validate the security settings."""
        # The maximum input size must be positive
        max_size = self.security_config['max_input_size']
        if max_size <= 0:
            raise ValueError("max_input_size must be positive")
        # The allowed schemes must be given as a list
        allowed_schemes = self.security_config['allowed_schemes']
        if not isinstance(allowed_schemes, list):
            raise TypeError("allowed_schemes must be a list")

    def _register_extension(self, app):
        """Register this instance in app.extensions."""
        # The 'secure_extension' key is an illustrative choice
        app.extensions['secure_extension'] = self

    def sanitize_input(self, input_data):
        """
        Sanitize input data.

        Args:
            input_data: The input data

        Returns:
            The sanitized data
        """
        if not self.security_config['sanitize_input']:
            return input_data
        if isinstance(input_data, str):
            # Escape HTML special characters in strings
            return html.escape(input_data)
        elif isinstance(input_data, dict):
            # Recursively sanitize dictionary values
            return {k: self.sanitize_input(v) for k, v in input_data.items()}
        elif isinstance(input_data, list):
            # Recursively sanitize list items
            return [self.sanitize_input(item) for item in input_data]
        else:
            # Return other types unchanged
            return input_data

    def validate_url(self, url):
        """
        Check that a URL is safe.

        Args:
            url: The URL to validate

        Returns:
            bool: True if the URL is safe

        Raises:
            ValueError: If the URL is not safe
        """
        parsed = urlparse(url)
        # Check the scheme
        if parsed.scheme not in self.security_config['allowed_schemes']:
            raise ValueError(f"URL scheme '{parsed.scheme}' is not allowed")
        # Check for special characters (simple example)
        dangerous_chars = ['<', '>', '"', "'", '`']
        for char in dangerous_chars:
            if char in url:
                raise ValueError(f"URL contains dangerous character: {char}")
        return True

    def safe_execute(self, operation, *args, **kwargs):
        """
        Execute an operation safely.

        Args:
            operation: The callable to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The result of the operation

        Raises:
            SecurityError: If the operation is rejected
        """
        # Check whether unsafe operations are allowed
        if not self.security_config['allow_unsafe_operations']:
            # Check whether this operation is safe
            if not self._is_operation_safe(operation, *args, **kwargs):
                raise SecurityError(f"Operation '{operation}' is not allowed")
        # Execute the operation
        return operation(*args, **kwargs)

    def _is_operation_safe(self, operation, *args, **kwargs):
        """Check whether an operation is safe."""
        # More thorough security checks can be implemented here
        operation_name = getattr(operation, '__name__', str(operation))
        # Denylist of unsafe operations
        unsafe_operations = [
            'eval',
            'exec',
            '__import__',
            'open',  # unsafe in some contexts
        ]
        return operation_name not in unsafe_operations


class SecurityError(Exception):
    """Raised when an operation is rejected for security reasons."""
    pass
```
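`_is_operation_safe` relies on a denylist keyed on `__name__`, which is easy to get around: `lambda s: eval(s)` is named `'<lambda>'`, and a `functools.partial(eval)` has no `__name__` at all, so both pass the check. A common alternative is an allowlist, where only explicitly registered operations may run. The sketch below is a hypothetical `OperationRegistry`, not part of the original `SecureExtension`; it defines its own `SecurityError` of the same shape so that it runs on its own.

```python
class SecurityError(Exception):
    """Raised when an operation is rejected."""


class OperationRegistry:
    """Allowlist of operations an extension may execute (hypothetical)."""

    def __init__(self):
        self._allowed = {}

    def register(self, name, func):
        # Operations are looked up by an explicit name chosen by the developer,
        # not by func.__name__, so wrapping a function cannot bypass the check.
        self._allowed[name] = func
        return func

    def execute(self, name, *args, **kwargs):
        func = self._allowed.get(name)
        if func is None:
            raise SecurityError(f"Operation '{name}' is not registered")
        return func(*args, **kwargs)


registry = OperationRegistry()
registry.register("add", lambda a, b: a + b)

total = registry.execute("add", 2, 3)

try:
    registry.execute("eval", "1 + 1")
    rejected = False
except SecurityError:
    rejected = True
```

The trade-off is that every legitimate operation has to be registered up front, which is usually the point: anything not explicitly permitted is refused by default.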
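8. Summary and Outlook
8.1 Key Points of Extension Development
- Follow Flask extension conventions: use standard naming, the standard initialization pattern, and conventional API design
- Support multiple initialization styles: both direct initialization and the application factory pattern
- Handle errors well: provide clear error messages and dedicated exception types
- Document thoroughly: code docstrings, usage examples, and API documentation
- Test comprehensively: cover the main features and edge cases
- Optimize performance: use caching and lazy loading where appropriate
- Consider security: validate input, sanitize output, and restrict dangerous operations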
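The first three points (the standard initialization pattern, support for both direct initialization and the application factory, and clear, dedicated exceptions) fit in a few lines. In the minimal sketch below, `types.SimpleNamespace` stands in for a Flask app so the example runs without Flask installed; a real `Flask` object exposes the same `config` and `extensions` attributes used here. `MyExtension`, `ExtensionConfigError`, `make_app`, and the `MYEXT_GREETING` key are hypothetical names.

```python
from types import SimpleNamespace


class ExtensionConfigError(Exception):
    """Raised with a clear message when the extension is misconfigured."""


class MyExtension:
    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        # Provide defaults without overriding values the user already set.
        app.config.setdefault("MYEXT_GREETING", "hello")
        if not isinstance(app.config["MYEXT_GREETING"], str):
            raise ExtensionConfigError("MYEXT_GREETING must be a string")
        # Register under app.extensions so state lives on the app, which lets
        # one extension instance serve several apps in the factory pattern.
        app.extensions["myext"] = self

    def greeting(self, app):
        return app.config["MYEXT_GREETING"]


def make_app(**config):
    # Stand-in for flask.Flask(__name__); only .config and .extensions are used.
    return SimpleNamespace(config=dict(config), extensions={})


# Direct initialization
app1 = make_app()
ext1 = MyExtension(app1)

# Application factory style: create the extension first, bind it later
ext2 = MyExtension()
app2 = make_app(MYEXT_GREETING="hi")
ext2.init_app(app2)
```

Note that `init_app` does not store the app on the instance, so `ext2.app` stays `None`. Here `greeting` takes the app explicitly; inside a real Flask extension one would typically read the active app from `flask.current_app` instead.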
8.2 The Extension Ecosystem
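[Diagram: the Flask extension ecosystem. Flask core sits at the root, with base extensions grouped into database, authentication, caching, and form categories (Flask-SQLAlchemy, Flask-MongoEngine, Flask-Login, Flask-Principal, Flask-Caching, Flask-WTF, plus custom cache and form extensions), and this article's examples, Flask-Heartbeat and Flask-CacheManager.]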
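8.3 Future Trends in Extension Development
- Async support: as Python's async ecosystem matures, extensions that support asynchronous operations will become increasingly important
- Type hints: more complete type annotations and static type checking
- Container friendliness: better support for containerized and cloud-native deployments
- Microservice integration: deeper integration with other microservice components
- AI/ML integration: support for deploying and serving machine learning models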
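8.4 Further Learning Resources
- Official documentation:
- Source code of well-written extensions:
- Community resources:

Having worked through this article, you should now have a grasp of the core concepts and skills of Flask extension development. Remember that the best way to learn is by doing: pick a real problem, try solving it as an extension, and keep iterating on and refining it. Good luck on your Flask extension development journey!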