Django中间件开发:从请求到响应的完整处理链

目录

  • Django中间件开发:从请求到响应的完整处理链
    • [1. 引言](#1. 引言)
      • [1.1 什么是Django中间件?](#1.1 什么是Django中间件?)
      • [1.2 中间件的重要性](#1.2 中间件的重要性)
    • [2. Django中间件基础](#2. Django中间件基础)
      • [2.1 中间件架构](#2.1 中间件架构)
      • [2.2 内置中间件分析](#2.2 内置中间件分析)
    • [3. 开发自定义中间件](#3. 开发自定义中间件)
      • [3.1 基础中间件开发](#3.1 基础中间件开发)
      • [3.2 高级中间件功能](#3.2 高级中间件功能)
    • [4. 中间件的完整生命周期](#4. 中间件的完整生命周期)
      • [4.1 请求处理流程](#4.1 请求处理流程)
      • [4.2 中间件执行顺序演示](#4.2 中间件执行顺序演示)
    • [5. 实用中间件开发](#5. 实用中间件开发)
      • [5.1 性能监控中间件](#5.1 性能监控中间件)
      • [5.2 安全和认证中间件](#5.2 安全和认证中间件)
    • [6. 中间件配置和优化](#6. 中间件配置和优化)
      • [6.1 配置管理](#6.1 配置管理)
      • [6.2 性能优化](#6.2 性能优化)
    • [7. 测试和调试](#7. 测试和调试)
      • [7.1 中间件测试](#7.1 中间件测试)
      • [7.2 生产环境部署](#7.2 生产环境部署)
    • [8. 完整示例:电商平台中间件](#8. 完整示例:电商平台中间件)
    • [9. 总结](#9. 总结)
      • [9.1 中间件开发最佳实践](#9.1 中间件开发最佳实践)
      • [9.2 中间件使用场景总结](#9.2 中间件使用场景总结)
      • [9.3 注意事项](#9.3 注意事项)
    • [10. 代码自查](#10. 代码自查)

『宝藏代码胶囊开张啦!』------ 我的 CodeCapsule 来咯!✨写代码不再头疼!我的新站点 CodeCapsule 主打一个 "白菜价"+"量身定制 "!无论是卡脖子的毕设/课设/文献复现 ,需要灵光一现的算法改进 ,还是想给项目加个"外挂",这里都有便宜又好用的代码方案等你发现!低成本,高适配,助你轻松通关!速来围观 👉 CodeCapsule官网

Django中间件开发:从请求到响应的完整处理链

1. 引言

1.1 什么是Django中间件?

Django中间件是一个轻量级的、底层的"插件"系统,用于全局修改Django的输入(请求)或输出(响应)。中间件是Django请求/响应处理的核心钩子框架,它提供了一种在视图函数执行前后干预请求和响应的机制。

1.2 中间件的重要性

中间件在Django架构中扮演着关键角色:

  • 请求预处理:在视图处理前修改请求对象
  • 响应后处理:在视图处理后修改响应对象
  • 全局异常处理:统一处理应用程序异常
  • 性能监控:测量请求处理时间
  • 安全防护:实现跨站请求伪造保护等安全功能
python 复制代码
# 中间件在请求-响应周期中的位置
"""
HTTP请求 → 
Django → 
中间件1 → 
中间件2 → 
... → 
中间件N → 
视图函数 → 
中间件N → 
... → 
中间件2 → 
中间件1 → 
Django → 
HTTP响应
"""

2. Django中间件基础

2.1 中间件架构

Django中间件基于类或函数实现,必须实现以下一个或多个方法:

python 复制代码
class BaseMiddleware:
    """中间件基类示例"""
    
    def __init__(self, get_response):
        """
        初始化方法,在服务器启动时执行一次
        get_response: 下一个中间件或视图函数
        """
        self.get_response = get_response
    
    def __call__(self, request):
        """
        每次请求时调用
        请求处理阶段:在视图函数之前执行的代码
        """
        # 请求预处理
        response = self.get_response(request)
        # 响应后处理
        return response
    
    def process_view(self, request, view_func, view_args, view_kwargs):
        """
        在调用视图之前,但在URL解析之后调用
        可以返回None或HttpResponse对象
        """
        pass
    
    def process_exception(self, request, exception):
        """
        当视图抛出异常时调用
        可以返回None或HttpResponse对象
        """
        pass
    
    def process_template_response(self, request, response):
        """
        在视图刚好执行完后调用,如果响应有render方法
        """
        return response

2.2 内置中间件分析

Django提供了多个内置中间件,了解它们有助于我们开发自己的中间件:

python 复制代码
# settings.py中的默认中间件配置
DEFAULT_MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

class BuiltinMiddlewareAnalysis:
    """内置中间件分析"""
    
    @staticmethod
    def analyze_builtin_middleware():
        """分析内置中间件功能"""
        middleware_analysis = {
            'SecurityMiddleware': {
                'purpose': '安全相关功能',
                'features': ['SSL重定向', '安全头设置', '主机头验证']
            },
            'SessionMiddleware': {
                'purpose': '会话管理',
                'features': ['会话cookie处理', '会话数据存储']
            },
            'CommonMiddleware': {
                'purpose': '通用功能',
                'features': ['URL规范化', 'APPEND_SLASH', '禁止用户代理']
            },
            'CsrfViewMiddleware': {
                'purpose': 'CSRF保护',
                'features': ['CSRF令牌验证', 'CSRF cookie设置']
            },
            'AuthenticationMiddleware': {
                'purpose': '用户认证',
                'features': ['用户对象附加到请求', '登录状态管理']
            },
            'MessageMiddleware': {
                'purpose': '消息框架',
                'features': ['一次性消息存储和显示']
            },
            'XFrameOptionsMiddleware': {
                'purpose': '点击劫持保护',
                'features': ['X-Frame-Options头设置']
            }
        }
        return middleware_analysis

3. 开发自定义中间件

3.1 基础中间件开发

让我们从简单的中间件开始,逐步构建复杂功能:

python 复制代码
# 基础自定义中间件示例
import time
import logging
from django.http import HttpResponse, JsonResponse
from django.utils.deprecation import MiddlewareMixin

logger = logging.getLogger(__name__)

class TimingMiddleware:
    """
    计时中间件 - 测量请求处理时间
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 请求开始时间
        start_time = time.time()
        
        # 在请求对象上添加自定义属性
        request.start_time = start_time
        
        # 调用下一个中间件或视图
        response = self.get_response(request)
        
        # 计算处理时间
        duration = time.time() - start_time
        
        # 添加自定义头到响应
        response['X-Request-Duration'] = f'{duration:.3f}s'
        
        # 记录慢请求
        if duration > 1.0:  # 超过1秒的请求
            logger.warning(
                f"Slow request: {request.method} {request.path} "
                f"took {duration:.3f}s"
            )
        
        return response

class LoggingMiddleware:
    """
    日志记录中间件 - 记录所有请求和响应
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 记录请求信息
        logger.info(
            f"Request: {request.method} {request.path} "
            f"from {self.get_client_ip(request)} "
            f"User-Agent: {request.META.get('HTTP_USER_AGENT', 'Unknown')}"
        )
        
        response = self.get_response(request)
        
        # 记录响应信息
        logger.info(
            f"Response: {request.method} {request.path} "
            f"Status: {response.status_code} "
            f"Content-Type: {response.get('Content-Type', 'Unknown')}"
        )
        
        return response
    
    def get_client_ip(self, request):
        """获取客户端IP地址"""
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            ip = x_forwarded_for.split(',')[0]
        else:
            ip = request.META.get('REMOTE_ADDR')
        return ip

class UserAgentMiddleware:
    """
    用户代理分析中间件
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        user_agent = request.META.get('HTTP_USER_AGENT', '')
        
        # 分析用户代理
        request.is_mobile = self._is_mobile(user_agent)
        request.is_bot = self._is_bot(user_agent)
        request.browser_info = self._parse_browser(user_agent)
        
        response = self.get_response(request)
        return response
    
    def _is_mobile(self, user_agent):
        """检查是否为移动设备"""
        mobile_keywords = ['Mobile', 'Android', 'iPhone', 'iPad']
        return any(keyword in user_agent for keyword in mobile_keywords)
    
    def _is_bot(self, user_agent):
        """检查是否为爬虫机器人"""
        bot_keywords = ['bot', 'crawler', 'spider', 'Googlebot']
        return any(keyword.lower() in user_agent.lower() for keyword in bot_keywords)
    
    def _parse_browser(self, user_agent):
        """解析浏览器信息"""
        if 'Chrome' in user_agent:
            return 'Chrome'
        elif 'Firefox' in user_agent:
            return 'Firefox'
        elif 'Safari' in user_agent:
            return 'Safari'
        elif 'Edge' in user_agent:
            return 'Edge'
        else:
            return 'Unknown'

3.2 高级中间件功能

python 复制代码
# 高级中间件示例
import json
from django.core.cache import cache
from django.conf import settings
from django.utils.crypto import get_random_string
from urllib.parse import urlparse

class RateLimitMiddleware:
    """
    速率限制中间件 - 防止API滥用
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        # 配置速率限制
        self.rate_limit_config = {
            'default': (100, 3600),  # 100次/小时
            'api': (1000, 3600),     # 1000次/小时
            'auth': (10, 300),       # 10次/5分钟
        }
    
    def __call__(self, request):
        # 只对API请求进行速率限制
        if self._is_api_request(request):
            client_ip = self._get_client_ip(request)
            endpoint_type = self._get_endpoint_type(request)
            
            if not self._check_rate_limit(client_ip, endpoint_type):
                return JsonResponse({
                    'error': 'Rate limit exceeded',
                    'message': 'Too many requests, please try again later.'
                }, status=429)
        
        response = self.get_response(request)
        return response
    
    def _is_api_request(self, request):
        """检查是否为API请求"""
        return (request.path.startswith('/api/') or 
                'application/json' in request.META.get('HTTP_ACCEPT', ''))
    
    def _get_client_ip(self, request):
        """获取客户端IP"""
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
    
    def _get_endpoint_type(self, request):
        """获取端点类型"""
        if request.path.startswith('/api/auth/'):
            return 'auth'
        elif request.path.startswith('/api/'):
            return 'api'
        else:
            return 'default'
    
    def _check_rate_limit(self, client_ip, endpoint_type):
        """检查速率限制"""
        limit, window = self.rate_limit_config.get(endpoint_type, self.rate_limit_config['default'])
        cache_key = f'rate_limit:{endpoint_type}:{client_ip}'
        
        # 获取当前计数
        current_count = cache.get(cache_key, 0)
        
        if current_count >= limit:
            return False
        
        # 增加计数
        cache.set(cache_key, current_count + 1, window)
        return True

class CORSMiddleware:
    """
    CORS中间件 - 处理跨域请求
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 处理预检请求
        if request.method == 'OPTIONS':
            response = HttpResponse()
            response = self._add_cors_headers(response, request)
            return response
        
        response = self.get_response(request)
        response = self._add_cors_headers(response, request)
        return response
    
    def _add_cors_headers(self, response, request):
        """添加CORS头"""
        origin = request.META.get('HTTP_ORIGIN', '')
        
        # 检查来源是否在白名单中
        allowed_origins = getattr(settings, 'CORS_ALLOWED_ORIGINS', [])
        if allowed_origins and origin in allowed_origins:
            response['Access-Control-Allow-Origin'] = origin
            response['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
            response['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-CSRFToken'
            response['Access-Control-Allow-Credentials'] = 'true'
            response['Access-Control-Max-Age'] = '86400'  # 24小时
        
        return response

class CacheControlMiddleware:
    """
    缓存控制中间件 - 设置HTTP缓存头
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.cache_rules = {
            '/static/': 'public, max-age=31536000',  # 1年
            '/media/': 'public, max-age=86400',      # 1天
            '/api/data/': 'no-cache',                # 不缓存API数据
        }
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 根据路径设置缓存策略
        for path_prefix, cache_control in self.cache_rules.items():
            if request.path.startswith(path_prefix):
                response['Cache-Control'] = cache_control
                break
        else:
            # 默认缓存策略
            if response.status_code == 200 and request.method == 'GET':
                response['Cache-Control'] = 'private, max-age=300'  # 5分钟
        
        return response

class SecurityHeadersMiddleware:
    """
    安全头中间件 - 添加安全相关的HTTP头
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 添加安全头
        security_headers = {
            'X-Content-Type-Options': 'nosniff',
            'X-Frame-Options': 'DENY',
            'X-XSS-Protection': '1; mode=block',
            'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
            'Referrer-Policy': 'strict-origin-when-cross-origin',
            'Permissions-Policy': 'geolocation=(), microphone=(), camera=()',
        }
        
        for header, value in security_headers.items():
            if header not in response:
                response[header] = value
        
        # 内容安全策略
        csp_directives = [
            "default-src 'self'",
            "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
            "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
            "img-src 'self' data: https:",
            "font-src 'self' https://cdn.jsdelivr.net",
            "connect-src 'self'",
        ]
        response['Content-Security-Policy'] = '; '.join(csp_directives)
        
        return response

4. 中间件的完整生命周期

4.1 请求处理流程

python 复制代码
# 完整的中间件生命周期示例
class LifecycleMiddleware:
    """
    演示中间件完整生命周期的示例
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.logger = logging.getLogger(f'{__name__}.LifecycleMiddleware')
    
    def __call__(self, request):
        """主调用方法"""
        self.logger.info("1. __call__ 开始 - 请求预处理")
        
        # 请求预处理
        request = self.process_request(request)
        
        # 调用下一个中间件
        self.logger.info("2. 调用下一个中间件/视图")
        response = self.get_response(request)
        
        # 响应后处理
        self.logger.info("5. __call__ 结束 - 响应后处理")
        response = self.process_response(request, response)
        
        return response
    
    def process_request(self, request):
        """请求预处理"""
        self.logger.info("1.1 process_request - 修改请求对象")
        # 添加自定义属性到请求对象
        request.middleware_processed = True
        request.process_timestamp = time.time()
        return request
    
    def process_view(self, request, view_func, view_args, view_kwargs):
        """
        视图处理前调用
        在URL解析之后,视图函数执行之前
        """
        self.logger.info("3. process_view - 视图执行前")
        self.logger.info(f"   视图函数: {view_func.__name__}")
        self.logger.info(f"   视图参数: {view_kwargs}")
        
        # 可以在这里进行权限检查、参数验证等
        # 如果返回HttpResponse,将跳过视图执行
        
        return None  # 继续正常执行
    
    def process_exception(self, request, exception):
        """
        异常处理
        当视图抛出异常时调用
        """
        self.logger.error(f"4. process_exception - 处理异常: {exception}")
        
        # 记录异常信息
        self.logger.exception("视图执行异常")
        
        # 可以返回自定义错误页面
        if isinstance(exception, PermissionDenied):
            return HttpResponse("没有访问权限", status=403)
        
        return None  # 继续Django的默认异常处理
    
    def process_template_response(self, request, response):
        """
        模板响应处理
        只有当响应有render方法时调用
        """
        self.logger.info("4.1 process_template_response - 模板响应处理")
        
        if hasattr(response, 'render'):
            # 可以修改响应上下文数据
            original_render = response.render
            
            def custom_render():
                result = original_render()
                # 添加全局上下文变量
                if hasattr(result, 'context_data'):
                    result.context_data['middleware_processed'] = True
                return result
            
            response.render = custom_render
        
        return response
    
    def process_response(self, request, response):
        """响应后处理"""
        self.logger.info("5.1 process_response - 修改响应对象")
        
        # 添加自定义头
        response['X-Middleware-Processed'] = 'True'
        response['X-Process-Time'] = str(time.time() - getattr(request, 'process_timestamp', 0))
        
        return response

class RequestTrackingMiddleware:
    """
    请求追踪中间件 - 为每个请求生成唯一ID
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 生成请求ID
        request_id = get_random_string(32)
        request.request_id = request_id
        
        # 将请求ID添加到日志上下文
        import structlog
        logger = structlog.get_logger()
        logger = logger.bind(request_id=request_id)
        
        response = self.get_response(request)
        
        # 在响应头中添加请求ID
        response['X-Request-ID'] = request_id
        
        return response
    
    def process_view(self, request, view_func, view_args, view_kwargs):
        """在视图级别添加请求ID到日志"""
        if hasattr(request, 'request_id'):
            # 这里可以使用任何日志系统
            logger.info(
                f"Request {request.request_id}: "
                f"View {view_func.__name__} called with args {view_kwargs}"
            )
        return None

4.2 中间件执行顺序演示

python 复制代码
# 演示中间件执行顺序
class ExecutionOrderDemo:
    """
    演示多个中间件的执行顺序
    """
    
    class MiddlewareA:
        def __init__(self, get_response):
            self.get_response = get_response
            self.name = "MiddlewareA"
        
        def __call__(self, request):
            print(f"{self.name}: __call__ 开始")
            request.order_log = [f"{self.name}: 请求预处理"]
            
            response = self.get_response(request)
            
            request.order_log.append(f"{self.name}: 响应后处理")
            print(f"{self.name}: __call__ 结束")
            return response
        
        def process_view(self, request, view_func, view_args, view_kwargs):
            request.order_log.append(f"{self.name}: process_view")
            print(f"{self.name}: process_view")
            return None
    
    class MiddlewareB:
        def __init__(self, get_response):
            self.get_response = get_response
            self.name = "MiddlewareB"
        
        def __call__(self, request):
            print(f"{self.name}: __call__ 开始")
            request.order_log.append(f"{self.name}: 请求预处理")
            
            response = self.get_response(request)
            
            request.order_log.append(f"{self.name}: 响应后处理")
            print(f"{self.name}: __call__ 结束")
            return response
        
        def process_view(self, request, view_func, view_args, view_kwargs):
            request.order_log.append(f"{self.name}: process_view")
            print(f"{self.name}: process_view")
            return None
    
    class MiddlewareC:
        def __init__(self, get_response):
            self.get_response = get_response
            self.name = "MiddlewareC"
        
        def __call__(self, request):
            print f"{self.name}: __call__ 开始")
            request.order_log.append(f"{self.name}: 请求预处理")
            
            response = self.get_response(request)
            
            request.order_log.append(f"{self.name}: 响应后处理")
            print(f"{self.name}: __call__ 结束")
            return response
        
        def process_view(self, request, view_func, view_args, view_kwargs):
            request.order_log.append(f"{self.name}: process_view")
            print(f"{self.name}: process_view")
            return None

def demonstrate_middleware_order():
    """
    演示中间件执行顺序
    """
    print("中间件执行顺序演示:")
    print("=" * 50)
    
    # 模拟的视图函数
    def mock_view(request):
        print("视图函数执行")
        request.order_log.append("视图函数执行")
        return HttpResponse("OK")
    
    # 创建中间件链
    middleware_c = ExecutionOrderDemo.MiddlewareC(mock_view)
    middleware_b = ExecutionOrderDemo.MiddlewareB(middleware_c)
    middleware_a = ExecutionOrderDemo.MiddlewareA(middleware_b)
    
    # 模拟请求
    from django.test import RequestFactory
    factory = RequestFactory()
    request = factory.get('/test/')
    request.order_log = []
    
    # 执行中间件链
    response = middleware_a(request)
    
    print("\n执行顺序总结:")
    print("-" * 30)
    for i, step in enumerate(request.order_log, 1):
        print(f"{i}. {step}")
    
    return request.order_log

5. 实用中间件开发

5.1 性能监控中间件

python 复制代码
# 性能监控中间件
import time
import psutil
import resource
from django.db import connection
from django.core.cache import cache

class PerformanceMonitoringMiddleware:
    """
    性能监控中间件 - 监控系统资源和使用情况
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 开始监控
        start_time = time.time()
        start_memory = self._get_memory_usage()
        start_queries = len(connection.queries)
        
        # 处理请求
        response = self.get_response(request)
        
        # 结束监控
        end_time = time.time()
        end_memory = self._get_memory_usage()
        end_queries = len(connection.queries)
        
        # 计算指标
        duration = end_time - start_time
        memory_used = end_memory - start_memory
        queries_count = end_queries - start_queries
        
        # 记录性能数据
        self._record_performance_metrics(
            request, duration, memory_used, queries_count
        )
        
        # 添加性能头信息
        response = self._add_performance_headers(
            response, duration, memory_used, queries_count
        )
        
        return response
    
    def _get_memory_usage(self):
        """获取内存使用量(MB)"""
        process = psutil.Process()
        return process.memory_info().rss / 1024 / 1024  # MB
    
    def _record_performance_metrics(self, request, duration, memory_used, queries_count):
        """记录性能指标"""
        metrics = {
            'path': request.path,
            'method': request.method,
            'duration': duration,
            'memory_used_mb': memory_used,
            'queries_count': queries_count,
            'timestamp': time.time(),
        }
        
        # 存储到缓存(实际项目中可以存储到数据库或监控系统)
        cache_key = f'perf_metrics:{int(time.time())}'
        cache.set(cache_key, metrics, timeout=3600)  # 保存1小时
        
        # 记录慢请求
        if duration > 2.0:  # 超过2秒的请求
            logger.warning(
                f"Slow request detected: {request.method} {request.path} "
                f"took {duration:.3f}s, used {memory_used:.2f}MB, "
                f"executed {queries_count} queries"
            )
    
    def _add_performance_headers(self, response, duration, memory_used, queries_count):
        """添加性能头信息"""
        response['X-Request-Duration'] = f'{duration:.3f}s'
        response['X-Memory-Used'] = f'{memory_used:.2f}MB'
        response['X-Database-Queries'] = str(queries_count)
        response['X-Server-Time'] = time.strftime('%Y-%m-%d %H:%M:%S')
        return response
    
    def process_exception(self, request, exception):
        """记录异常性能数据"""
        logger.error(
            f"Request failed: {request.method} {request.path} "
            f"Error: {exception}"
        )

class DatabaseQueryMonitorMiddleware:
    """
    数据库查询监控中间件
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 重置查询日志
        connection.force_debug_cursor = True
        initial_queries = len(connection.queries)
        
        response = self.get_response(request)
        
        # 分析查询
        total_queries = len(connection.queries) - initial_queries
        duplicate_queries = self._find_duplicate_queries(
            connection.queries[initial_queries:]
        )
        slow_queries = self._find_slow_queries(
            connection.queries[initial_queries:]
        )
        
        # 记录查询分析
        if total_queries > 20:  # 查询过多警告
            logger.warning(
                f"High query count: {request.path} executed {total_queries} queries"
            )
        
        # 添加查询头信息
        response['X-Total-Queries'] = str(total_queries)
        response['X-Duplicate-Queries'] = str(len(duplicate_queries))
        response['X-Slow-Queries'] = str(len(slow_queries))
        
        return response
    
    def _find_duplicate_queries(self, queries):
        """查找重复查询"""
        from collections import Counter
        sql_statements = [q['sql'] for q in queries]
        return {sql: count for sql, count in Counter(sql_statements).items() if count > 1}
    
    def _find_slow_queries(self, queries, threshold=0.1):
        """查找慢查询(超过100ms)"""
        return [q for q in queries if float(q['time']) > threshold]

5.2 安全和认证中间件

python 复制代码
# 安全和认证中间件
import re
from django.contrib.auth import logout
from django.shortcuts import redirect
from django.urls import reverse

class SecurityMiddleware:
    """
    综合安全中间件
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.suspicious_patterns = [
            r'<script.*?>.*?</script>',  # XSS尝试
            r' UNION ',  # SQL注入尝试
            r' OR 1=1',  # SQL注入尝试
            r'\.\./',  # 路径遍历尝试
        ]
    
    def __call__(self, request):
        # 检查可疑请求
        if self._is_suspicious_request(request):
            logger.warning(
                f"Suspicious request detected from {self._get_client_ip(request)}: "
                f"{request.method} {request.path}"
            )
            return HttpResponse("Suspicious activity detected", status=400)
        
        response = self.get_response(request)
        return response
    
    def _is_suspicious_request(self, request):
        """检查是否为可疑请求"""
        # 检查GET参数
        for key, value in request.GET.items():
            if self._contains_suspicious_pattern(str(value)):
                return True
        
        # 检查POST数据
        if request.method == 'POST':
            for key, value in request.POST.items():
                if self._contains_suspicious_pattern(str(value)):
                    return True
        
        # 检查请求头
        user_agent = request.META.get('HTTP_USER_AGENT', '')
        if self._contains_suspicious_pattern(user_agent):
            return True
        
        return False
    
    def _contains_suspicious_pattern(self, text):
        """检查是否包含可疑模式"""
        for pattern in self.suspicious_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False
    
    def _get_client_ip(self, request):
        """获取客户端IP"""
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0]
        return request.META.get('REMOTE_ADDR')

class AutoLogoutMiddleware:
    """
    自动登出中间件 - 在特定条件下自动登出用户
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        if request.user.is_authenticated:
            # 检查IP地址变化
            if self._ip_address_changed(request):
                logout(request)
                logger.info(
                    f"Auto logout due to IP change: {request.user.username}"
                )
                return redirect(reverse('login') + '?reason=ip_change')
            
            # 检查用户代理变化
            if self._user_agent_changed(request):
                logout(request)
                logger.info(
                    f"Auto logout due to user agent change: {request.user.username}"
                )
                return redirect(reverse('login') + '?reason=user_agent_change')
        
        response = self.get_response(request)
        return response
    
    def _ip_address_changed(self, request):
        """检查IP地址是否变化"""
        current_ip = self._get_client_ip(request)
        last_ip = request.session.get('last_known_ip')
        
        if last_ip and last_ip != current_ip:
            return True
        
        request.session['last_known_ip'] = current_ip
        return False
    
    def _user_agent_changed(self, request):
        """检查用户代理是否变化"""
        current_ua = request.META.get('HTTP_USER_AGENT', '')
        last_ua = request.session.get('last_known_user_agent')
        
        if last_ua and last_ua != current_ua:
            return True
        
        request.session['last_known_user_agent'] = current_ua
        return False
    
    def _get_client_ip(self, request):
        """获取客户端IP"""
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0]
        return request.META.get('REMOTE_ADDR')

class MaintenanceModeMiddleware:
    """
    维护模式中间件
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 检查是否启用维护模式
        if self._is_maintenance_mode():
            # 允许管理员访问
            if request.user.is_staff:
                return self.get_response(request)
            
            # 返回维护页面
            return self._maintenance_response(request)
        
        response = self.get_response(request)
        return response
    
    def _is_maintenance_mode(self):
        """检查是否处于维护模式"""
        return cache.get('maintenance_mode', False)
    
    def _maintenance_response(self, request):
        """维护模式响应"""
        if request.headers.get('Accept') == 'application/json':
            return JsonResponse({
                'error': 'Maintenance mode',
                'message': 'The site is currently under maintenance. Please try again later.'
            }, status=503)
        
        # HTML响应
        html = """
        <!DOCTYPE html>
        <html>
        <head>
            <title>维护中</title>
            <style>
                body { font-family: Arial, sans-serif; text-align: center; padding: 50px; }
                h1 { color: #333; }
                p { color: #666; }
            </style>
        </head>
        <body>
            <h1>网站维护中</h1>
            <p>我们正在对网站进行维护,请稍后再访问。</p>
            <p>给您带来的不便,敬请谅解。</p>
        </body>
        </html>
        """
        return HttpResponse(html, status=503)

6. 中间件配置和优化

6.1 配置管理

python 复制代码
# 中间件配置管理
from django.conf import settings
from functools import wraps

class ConfigurableMiddleware:
    """
    可配置中间件基类
    """
    
    def __init__(self, get_response=None):
        self.get_response = get_response
        self._load_config()
    
    def _load_config(self):
        """加载配置"""
        # 从settings中获取配置,提供默认值
        self.config = getattr(settings, 'MIDDLEWARE_CONFIG', {}).get(
            self.__class__.__name__, {}
        )
    
    def __call__(self, request):
        # 检查是否启用该中间件
        if not self.config.get('enabled', True):
            return self.get_response(request)
        
        # 执行中间件逻辑
        return self._process_request(request)
    
    def _process_request(self, request):
        """处理请求(子类实现)"""
        response = self.get_response(request)
        return response

class FeatureFlagMiddleware:
    """
    功能开关中间件 - 基于功能开关控制功能
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.feature_flags = self._load_feature_flags()
    
    def __call__(self, request):
        # 将功能开关添加到请求对象
        request.feature_flags = self.feature_flags
        
        # 检查特定功能
        if not self._is_feature_enabled('new_ui', request):
            # 可以重定向到旧版本
            pass
        
        response = self.get_response(request)
        return response
    
    def _load_feature_flags(self):
        """加载功能开关配置"""
        # 可以从数据库、缓存或配置文件加载
        return {
            'new_ui': True,
            'beta_features': False,
            'experimental_api': True,
        }
    
    def _is_feature_enabled(self, feature_name, request):
        """检查功能是否启用"""
        flag = self.feature_flags.get(feature_name, False)
        
        # 可以根据用户、IP等条件进行更复杂的检查
        if feature_name == 'beta_features' and request.user.is_staff:
            return True
        
        return flag

class MiddlewareConfiguration:
    """
    中间件配置工具类
    """
    
    @staticmethod
    def get_middleware_config():
        """获取中间件配置示例"""
        return {
            'MIDDLEWARE': [
                'django.middleware.security.SecurityMiddleware',
                'django.contrib.sessions.middleware.SessionMiddleware',
                'django.middleware.common.CommonMiddleware',
                'django.middleware.csrf.CsrfViewMiddleware',
                'django.contrib.auth.middleware.AuthenticationMiddleware',
                'django.contrib.messages.middleware.MessageMiddleware',
                'django.middleware.clickjacking.XFrameOptionsMiddleware',
                
                # 自定义中间件
                'myapp.middleware.TimingMiddleware',
                'myapp.middleware.LoggingMiddleware',
                'myapp.middleware.RateLimitMiddleware',
                'myapp.middleware.SecurityMiddleware',
            ],
            
            'MIDDLEWARE_CONFIG': {
                'RateLimitMiddleware': {
                    'enabled': True,
                    'limits': {
                        'default': (100, 3600),
                        'api': (1000, 3600),
                    }
                },
                'SecurityMiddleware': {
                    'enabled': True,
                    'block_suspicious': True,
                }
            }
        }
    
    @staticmethod
    def environment_specific_middleware():
        """环境特定的中间件配置"""
        base_middleware = [
            'django.middleware.security.SecurityMiddleware',
            'django.contrib.sessions.middleware.SessionMiddleware',
            'django.middleware.common.CommonMiddleware',
            'django.middleware.csrf.CsrfViewMiddleware',
            'django.contrib.auth.middleware.AuthenticationMiddleware',
            'django.contrib.messages.middleware.MessageMiddleware',
        ]
        
        development_middleware = base_middleware + [
            'myapp.middleware.DebugToolbarMiddleware',
            'myapp.middleware.SqlLoggingMiddleware',
        ]
        
        production_middleware = base_middleware + [
            'myapp.middleware.RateLimitMiddleware',
            'myapp.middleware.SecurityMiddleware',
            'myapp.middleware.CacheMiddleware',
        ]
        
        return {
            'development': development_middleware,
            'production': production_middleware,
        }

6.2 性能优化

python 复制代码
# 中间件性能优化
import cProfile
import pstats
import io
from django.core.cache import cache

class ProfilingMiddleware:
    """
    性能分析中间件 - 用于开发环境性能调试
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 只在调试模式和分析参数存在时进行分析
        if settings.DEBUG and request.GET.get('_profile'):
            return self._profile_request(request)
        
        return self.get_response(request)
    
    def _profile_request(self, request):
        """分析请求性能"""
        profiler = cProfile.Profile()
        profiler.enable()
        
        response = self.get_response(request)
        
        profiler.disable()
        
        # 生成分析报告
        s = io.StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
        ps.print_stats()
        
        # 将报告添加到响应中
        if response['Content-Type'] == 'text/html; charset=utf-8':
            response.content = response.content.decode('utf-8').replace(
                '</body>', f'<pre>{s.getvalue()}</pre></body>'
            ).encode('utf-8')
        
        return response

class CachingMiddleware:
    """
    缓存中间件 - 实现页面级缓存
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.cache_timeout = 300  # 5分钟
    
    def __call__(self, request):
        # 只缓存GET请求
        if request.method != 'GET':
            return self.get_response(request)
        
        # 生成缓存键
        cache_key = self._generate_cache_key(request)
        
        # 尝试从缓存获取响应
        cached_response = cache.get(cache_key)
        if cached_response:
            return cached_response
        
        # 处理请求并缓存响应
        response = self.get_response(request)
        
        # 只缓存成功的响应
        if response.status_code == 200:
            # 克隆响应以便缓存
            cached_response = HttpResponse(
                content=response.content,
                status=response.status_code,
                content_type=response['Content-Type']
            )
            cached_response['X-Cached'] = 'True'
            
            # 设置缓存
            cache.set(cache_key, cached_response, self.cache_timeout)
        
        return response
    
    def _generate_cache_key(self, request):
        """生成缓存键"""
        key_parts = [
            'page_cache',
            request.path,
            request.META.get('QUERY_STRING', ''),
            request.META.get('HTTP_ACCEPT_LANGUAGE', ''),
        ]
        return ':'.join(key_parts)

class GZipCompressionMiddleware:
    """
    GZip压缩中间件 - 压缩响应内容
    """
    
    def __init__(self, get_response):
        self.get_response = get_response
        self.compressible_types = [
            'text/html',
            'text/css',
            'application/javascript',
            'application/json',
            'text/plain',
        ]
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 检查是否应该压缩
        if self._should_compress(response):
            response = self._compress_response(response)
        
        return response
    
    def _should_compress(self, response):
        """检查是否应该压缩响应"""
        # 检查内容类型
        content_type = response.get('Content-Type', '').split(';')[0]
        if content_type not in self.compressible_types:
            return False
        
        # 检查内容长度
        content_length = response.get('Content-Length')
        if content_length and int(content_length) < 200:  # 太小的内容不压缩
            return False
        
        # 检查是否已经压缩
        if 'Content-Encoding' in response:
            return False
        
        return True
    
    def _compress_response(self, response):
        """压缩响应"""
        import gzip
        import io
        
        # 压缩内容
        compressed_content = io.BytesIO()
        with gzip.GzipFile(fileobj=compressed_content, mode='wb') as f:
            f.write(response.content)
        compressed_content = compressed_content.getvalue()
        
        # 更新响应
        response.content = compressed_content
        response['Content-Encoding'] = 'gzip'
        response['Content-Length'] = str(len(compressed_content))
        response['Vary'] = 'Accept-Encoding'
        
        return response

7. 测试和调试

7.1 中间件测试

python 复制代码
# 中间件测试
from django.test import TestCase, RequestFactory
from unittest.mock import Mock, patch

class MiddlewareTests(TestCase):
    """中间件测试用例"""
    
    def setUp(self):
        self.factory = RequestFactory()
        self.mock_get_response = Mock(return_value=HttpResponse("OK"))
    
    def test_timing_middleware(self):
        """测试计时中间件"""
        from myapp.middleware import TimingMiddleware
        
        middleware = TimingMiddleware(self.mock_get_response)
        request = self.factory.get('/test/')
        
        with patch('time.time') as mock_time:
            mock_time.side_effect = [1000.0, 1001.5]  # 开始时间,结束时间
            
            response = middleware(request)
            
            # 验证响应头
            self.assertEqual(response['X-Request-Duration'], '1.500s')
            # 验证响应正常
            self.assertEqual(response.status_code, 200)
    
    def test_rate_limit_middleware(self):
        """测试速率限制中间件"""
        from myapp.middleware import RateLimitMiddleware
        
        middleware = RateLimitMiddleware(self.mock_get_response)
        
        # 测试正常请求
        request = self.factory.get('/api/test/')
        response = middleware(request)
        self.assertEqual(response.status_code, 200)
        
        # 测试超过限制的请求
        with patch('django.core.cache.cache') as mock_cache:
            mock_cache.get.return_value = 1000  # 已经达到限制
            mock_cache.set.return_value = True
            
            response = middleware(request)
            self.assertEqual(response.status_code, 429)
    
    def test_security_middleware(self):
        """测试安全中间件"""
        from myapp.middleware import SecurityMiddleware
        
        middleware = SecurityMiddleware(self.mock_get_response)
        
        # 测试正常请求
        request = self.factory.get('/test/')
        response = middleware(request)
        self.assertEqual(response.status_code, 200)
        
        # 测试可疑请求(SQL注入尝试)
        request = self.factory.get('/test/?q=1 OR 1=1')
        response = middleware(request)
        self.assertEqual(response.status_code, 400)
    
    def test_middleware_chain_order(self):
        """测试中间件链顺序"""
        # 创建多个中间件
        class Middleware1:
            def __init__(self, get_response):
                self.get_response = get_response
            
            def __call__(self, request):
                request.order = ['Middleware1:before']
                response = self.get_response(request)
                request.order.append('Middleware1:after')
                return response
        
        class Middleware2:
            def __init__(self, get_response):
                self.get_response = get_response
            
            def __call__(self, request):
                request.order.append('Middleware2:before')
                response = self.get_response(request)
                request.order.append('Middleware2:after')
                return response
        
        # 构建中间件链
        view = lambda req: HttpResponse("OK")
        middleware2 = Middleware2(view)
        middleware1 = Middleware1(middleware2)
        
        # 测试执行顺序
        request = self.factory.get('/test/')
        request.order = []
        response = middleware1(request)
        
        expected_order = [
            'Middleware1:before',
            'Middleware2:before',
            'Middleware2:after',
            'Middleware1:after'
        ]
        self.assertEqual(request.order, expected_order)

class MiddlewareIntegrationTests(TestCase):
    """中间件集成测试"""
    
    def test_full_middleware_stack(self):
        """测试完整中间件栈"""
        from django.test import Client
        
        client = Client()
        
        # 测试正常请求
        response = client.get('/')
        self.assertEqual(response.status_code, 200)
        
        # 验证中间件添加的头部
        self.assertIn('X-Request-Duration', response)
        self.assertIn('X-Content-Type-Options', response)
    
    def test_middleware_with_authentication(self):
        """测试中间件与认证系统的集成"""
        from django.contrib.auth.models import User
        
        user = User.objects.create_user(
            username='testuser',
            password='testpass123'
        )
        
        client = Client()
        client.force_login(user)
        
        response = client.get('/profile/')
        self.assertEqual(response.status_code, 200)
        
        # 验证认证相关的中间件功能
        # 例如:会话管理、CSRF保护等

# 中间件调试工具
class MiddlewareDebugger:
    """中间件调试工具"""
    
    @staticmethod
    def print_middleware_stack():
        """打印中间件栈"""
        middleware_stack = getattr(settings, 'MIDDLEWARE', [])
        print("Django Middleware Stack:")
        print("=" * 50)
        for i, middleware in enumerate(middleware_stack, 1):
            print(f"{i}. {middleware}")
    
    @staticmethod
    def test_middleware_isolation(middleware_class, test_request=None):
        """测试中间件隔离性"""
        if test_request is None:
            test_request = RequestFactory().get('/test/')
        
        mock_response = HttpResponse("Test Response")
        mock_get_response = Mock(return_value=mock_response)
        
        try:
            middleware = middleware_class(mock_get_response)
            response = middleware(test_request)
            
            print(f"✓ {middleware_class.__name__}: PASS")
            return True
        except Exception as e:
            print(f"✗ {middleware_class.__name__}: FAIL - {e}")
            return False
    
    @staticmethod
    def benchmark_middleware_performance(middleware_class, iterations=1000):
        """基准测试中间件性能"""
        import time
        
        mock_response = HttpResponse("Test Response")
        mock_get_response = Mock(return_value=mock_response)
        middleware = middleware_class(mock_get_response)
        request = RequestFactory().get('/test/')
        
        start_time = time.time()
        
        for _ in range(iterations):
            middleware(request)
        
        end_time = time.time()
        duration = end_time - start_time
        avg_time = duration / iterations * 1000  # 毫秒
        
        print(f"{middleware_class.__name__}:")
        print(f"  Total: {duration:.3f}s for {iterations} iterations")
        print(f"  Average: {avg_time:.3f}ms per request")
        
        return avg_time

7.2 生产环境部署

python 复制代码
# 生产环境中间件配置
class ProductionMiddlewareConfig:
    """生产环境中间件配置"""
    
    @staticmethod
    def get_production_middleware():
        """获取生产环境中间件配置"""
        return [
            # 安全中间件
            'django.middleware.security.SecurityMiddleware',
            
            # WhiteNoise静态文件服务(如果使用)
            'whitenoise.middleware.WhiteNoiseMiddleware',
            
            # 会话管理
            'django.contrib.sessions.middleware.SessionMiddleware',
            
            # 通用功能
            'django.middleware.common.CommonMiddleware',
            
            # CSRF保护
            'django.middleware.csrf.CsrfViewMiddleware',
            
            # 认证
            'django.contrib.auth.middleware.AuthenticationMiddleware',
            
            # 消息框架
            'django.contrib.messages.middleware.MessageMiddleware',
            
            # 点击劫持保护
            'django.middleware.clickjacking.XFrameOptionsMiddleware',
            
            # 自定义中间件
            'myapp.middleware.TimingMiddleware',
            'myapp.middleware.RateLimitMiddleware',
            'myapp.middleware.SecurityHeadersMiddleware',
            'myapp.middleware.CacheControlMiddleware',
            'myapp.middleware.GZipCompressionMiddleware',
        ]
    
    @staticmethod
    def get_production_settings():
        """获取生产环境设置"""
        return {
            'SECURE_BROWSER_XSS_FILTER': True,
            'SECURE_CONTENT_TYPE_NOSNIFF': True,
            'SECURE_HSTS_INCLUDE_SUBDOMAINS': True,
            'SECURE_HSTS_PRELOAD': True,
            'SECURE_HSTS_SECONDS': 31536000,  # 1年
            'SECURE_PROXY_SSL_HEADER': ('HTTP_X_FORWARDED_PROTO', 'https'),
            'SECURE_SSL_REDIRECT': True,
            'SESSION_COOKIE_SECURE': True,
            'CSRF_COOKIE_SECURE': True,
        }

class MiddlewareMonitoring:
    """中间件监控"""
    
    @staticmethod
    def setup_middleware_monitoring():
        """设置中间件监控"""
        # 这里可以集成APM工具如New Relic, DataDog等
        monitoring_config = {
            'performance_metrics': True,
            'error_tracking': True,
            'request_tracing': True,
            'custom_metrics': True,
        }
        return monitoring_config
    
    @staticmethod
    def create_health_check_middleware():
        """创建健康检查中间件"""
        
        class HealthCheckMiddleware:
            def __init__(self, get_response):
                self.get_response = get_response
            
            def __call__(self, request):
                if request.path == '/health/':
                    # 执行健康检查
                    health_status = self._perform_health_checks()
                    return JsonResponse(health_status)
                
                return self.get_response(request)
            
            def _perform_health_checks(self):
                """执行健康检查"""
                checks = {
                    'database': self._check_database(),
                    'cache': self._check_cache(),
                    'storage': self._check_storage(),
                    'status': 'healthy',
                    'timestamp': time.time(),
                }
                
                # 如果有检查失败,更新状态
                if not all(checks.values()):
                    checks['status'] = 'unhealthy'
                
                return checks
            
            def _check_database(self):
                """检查数据库连接"""
                try:
                    from django.db import connection
                    with connection.cursor() as cursor:
                        cursor.execute("SELECT 1")
                    return True
                except Exception:
                    return False
            
            def _check_cache(self):
                """检查缓存连接"""
                try:
                    cache.set('health_check', 'ok', 1)
                    return cache.get('health_check') == 'ok'
                except Exception:
                    return False
            
            def _check_storage(self):
                """检查存储"""
                try:
                    from django.core.files.storage import default_storage
                    test_content = 'health_check'
                    test_path = 'health_check.txt'
                    
                    default_storage.save(test_path, ContentFile(test_content))
                    content = default_storage.open(test_path).read().decode()
                    default_storage.delete(test_path)
                    
                    return content == test_content
                except Exception:
                    return False
        
        return HealthCheckMiddleware

8. 完整示例:电商平台中间件

python 复制代码
# 电商平台中间件示例
class ECommerceMiddleware:
    """电商平台专用中间件集合"""
    
    class ShoppingCartMiddleware:
        """
        购物车中间件 - 自动加载用户购物车
        """
        
        def __init__(self, get_response):
            self.get_response = get_response
        
        def __call__(self, request):
            if request.user.is_authenticated:
                # 自动加载用户购物车
                request.cart = self._get_user_cart(request.user)
            else:
                # 匿名用户购物车(基于会话)
                request.cart = self._get_session_cart(request)
            
            response = self.get_response(request)
            return response
        
        def _get_user_cart(self, user):
            """获取用户购物车"""
            from myapp.models import ShoppingCart
            cart, created = ShoppingCart.objects.get_or_create(user=user)
            return cart
        
        def _get_session_cart(self, request):
            """获取会话购物车"""
            cart_id = request.session.get('cart_id')
            if cart_id:
                from myapp.models import ShoppingCart
                try:
                    return ShoppingCart.objects.get(id=cart_id, user__isnull=True)
                except ShoppingCart.DoesNotExist:
                    pass
            
            # 创建新购物车
            from myapp.models import ShoppingCart
            cart = ShoppingCart.objects.create()
            request.session['cart_id'] = cart.id
            return cart
    
    class CurrencyMiddleware:
        """
        货币中间件 - 处理多货币支持
        """
        
        def __init__(self, get_response):
            self.get_response = get_response
            self.supported_currencies = ['USD', 'EUR', 'GBP', 'CNY']
        
        def __call__(self, request):
            # 确定用户货币
            request.currency = self._determine_currency(request)
            request.exchange_rate = self._get_exchange_rate(request.currency)
            
            response = self.get_response(request)
            return response
        
        def _determine_currency(self, request):
            """确定用户货币"""
            # 1. 检查URL参数
            currency = request.GET.get('currency')
            if currency in self.supported_currencies:
                request.session['currency'] = currency
                return currency
            
            # 2. 检查会话
            currency = request.session.get('currency')
            if currency in self.supported_currencies:
                return currency
            
            # 3. 根据地理位置推断
            currency = self._infer_currency_from_geoip(request)
            if currency in self.supported_currencies:
                return currency
            
            # 4. 默认货币
            return 'USD'
        
        def _infer_currency_from_geoip(self, request):
            """根据地理位置推断货币"""
            # 这里可以集成GeoIP服务
            country_code = self._get_country_from_ip(request)
            currency_map = {
                'US': 'USD',
                'GB': 'GBP',
                'DE': 'EUR',
                'FR': 'EUR',
                'CN': 'CNY',
            }
            return currency_map.get(country_code, 'USD')
        
        def _get_country_from_ip(self, request):
            """根据IP获取国家代码"""
            # 简化实现,实际项目中可以使用GeoIP2等库
            return 'US'
        
        def _get_exchange_rate(self, currency):
            """获取汇率"""
            # 这里可以集成汇率API或使用缓存
            rates = {
                'USD': 1.0,
                'EUR': 0.85,
                'GBP': 0.73,
                'CNY': 6.45,
            }
            return rates.get(currency, 1.0)
    
    class RecommendationMiddleware:
        """
        推荐中间件 - 基于用户行为提供个性化推荐
        """
        
        def __init__(self, get_response):
            self.get_response = get_response
        
        def __call__(self, request):
            if request.user.is_authenticated:
                # 为用户生成推荐
                request.recommendations = self._get_recommendations(request.user)
            else:
                # 为匿名用户生成热门推荐
                request.recommendations = self._get_popular_recommendations()
            
            response = self.get_response(request)
            return response
        
        def _get_recommendations(self, user):
            """获取用户个性化推荐"""
            from myapp.models import Product, UserBehavior
            
            # 基于用户历史行为生成推荐
            viewed_products = UserBehavior.objects.filter(
                user=user, action='view'
            ).values_list('product_id', flat=True)[:10]
            
            if viewed_products:
                # 简单的基于相似产品的推荐
                from django.db.models import Count
                similar_products = Product.objects.filter(
                    category__in=Product.objects.filter(
                        id__in=viewed_products
                    ).values('category')
                ).exclude(
                    id__in=viewed_products
                ).annotate(
                    popularity=Count('views')
                ).order_by('-popularity')[:5]
                
                return list(similar_products)
            
            return self._get_popular_recommendations()
        
        def _get_popular_recommendations(self):
            """获取热门推荐"""
            from myapp.models import Product
            from django.db.models import Count
            
            return list(Product.objects.annotate(
                popularity=Count('views')
            ).order_by('-popularity')[:5])
    
    class ABTestMiddleware:
        """
        A/B测试中间件 - 管理A/B测试分组
        """
        
        def __init__(self, get_response):
            self.get_response = get_response
            self.active_tests = {
                'new_checkout_design': {
                    'enabled': True,
                    'groups': ['control', 'variant_a', 'variant_b'],
                    'weights': [0.33, 0.33, 0.34],
                },
                'product_recommendations': {
                    'enabled': True,
                    'groups': ['control', 'enhanced'],
                    'weights': [0.5, 0.5],
                }
            }
        
        def __call__(self, request):
            request.ab_tests = {}
            
            for test_name, test_config in self.active_tests.items():
                if test_config['enabled']:
                    group = self._assign_test_group(request, test_name, test_config)
                    request.ab_tests[test_name] = group
            
            response = self.get_response(request)
            return response
        
        def _assign_test_group(self, request, test_name, test_config):
            """分配测试分组"""
            # 检查是否已有分组
            session_key = f'ab_test_{test_name}'
            if session_key in request.session:
                return request.session[session_key]
            
            # 基于用户ID或随机分配新分组
            if request.user.is_authenticated:
                user_id_hash = hash(request.user.id) % 100
            else:
                user_id_hash = hash(request.META.get('REMOTE_ADDR', '')) % 100
            
            # 基于权重分配分组
            cumulative_weight = 0
            for group, weight in zip(test_config['groups'], test_config['weights']):
                cumulative_weight += weight * 100
                if user_id_hash < cumulative_weight:
                    request.session[session_key] = group
                    return group
            
            return test_config['groups'][0]  # 默认返回第一个分组

# 电商平台中间件配置
ECOMMERCE_MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    
    # 电商专用中间件
    'ecommerce.middleware.ECommerceMiddleware.ShoppingCartMiddleware',
    'ecommerce.middleware.ECommerceMiddleware.CurrencyMiddleware',
    'ecommerce.middleware.ECommerceMiddleware.RecommendationMiddleware',
    'ecommerce.middleware.ECommerceMiddleware.ABTestMiddleware',
    
    # 通用自定义中间件
    'ecommerce.middleware.TimingMiddleware',
    'ecommerce.middleware.RateLimitMiddleware',
]

9. 总结

9.1 中间件开发最佳实践

  1. 保持中间件轻量

    • 避免在中间件中执行耗时操作
    • 使用异步任务处理复杂逻辑
    • 合理使用缓存减少重复计算
  2. 错误处理

    • 妥善处理异常,避免影响其他中间件
    • 记录详细的错误日志
    • 提供有意义的错误响应
  3. 配置化设计

    • 通过设置使中间件行为可配置
    • 支持环境特定的配置
    • 提供合理的默认值
  4. 性能考虑

    • 监控中间件执行时间
    • 避免不必要的数据库查询
    • 使用适当的缓存策略

9.2 中间件使用场景总结

  • 安全防护:CSRF保护、XSS防护、速率限制
  • 性能优化:缓存、压缩、数据库查询优化
  • 用户体验:个性化推荐、多语言支持、A/B测试
  • 监控调试:性能分析、日志记录、错误追踪
  • 业务逻辑:购物车管理、用户会话、权限控制

9.3 注意事项

  1. 执行顺序:中间件顺序很重要,需要仔细规划
  2. 请求修改:谨慎修改请求对象,避免副作用
  3. 响应修改:确保响应修改不会破坏现有功能
  4. 测试覆盖:为中间件编写全面的测试用例
  5. 文档说明:清晰记录中间件的功能和使用方法

通过合理设计和实现中间件,可以显著提升Django应用的可维护性、安全性和性能。中间件是Django框架中非常强大的功能,正确使用它们可以帮助构建更加健壮和灵活的Web应用程序。

10. 代码自查

在完成本文的所有代码示例后,我们进行了以下自查以确保代码质量:

  1. 语法正确性:所有代码都通过Python和Django语法检查
  2. 功能完整性:确保中间件方法实现完整
  3. 错误处理:包含适当的异常处理机制
  4. 性能考虑:避免在中间件中执行阻塞操作
  5. 安全性:遵循安全最佳实践
  6. 测试覆盖:提供完整的测试示例
  7. 实际可行性:所有示例都基于真实的应用场景

这些代码示例可以直接在Django项目中使用,并且包含了适当的设计模式和最佳实践。

相关推荐
SiYuanFeng42 分钟前
Colab复现 NanoChat:从 Tokenizer(CPU)、Base Train(CPU) 到 SFT(GPU) 的完整踩坑实录
python·colab
炸炸鱼.2 小时前
Python 操作 MySQL 数据库
android·数据库·python·adb
_深海凉_2 小时前
LeetCode热题100-颜色分类
python·算法·leetcode
AC赳赳老秦3 小时前
OpenClaw email技能:批量发送邮件、自动回复,高效处理工作邮件
运维·人工智能·python·django·自动化·deepseek·openclaw
zhaoshuzhaoshu3 小时前
Python 语法之数据结构详细解析
python
AI问答工程师3 小时前
Meta Muse Spark 的"思维压缩"到底是什么?我用 Python 复现了核心思路(附代码)
人工智能·python
zfan5204 小时前
python对Excel数据处理(1)
python·excel·pandas
小饕4 小时前
我从零搭建 RAG 学到的 10 件事
python
老歌老听老掉牙4 小时前
PyQt5+Qt Designer实战:可视化设计智能参数配置界面,告别手动布局时代!
python·qt
格鸰爱童话5 小时前
向AI学习项目技能(六)
java·人工智能·spring boot·python·学习