Python 依赖注入详解

Python 依赖注入详解

目录

  1. 依赖注入基础概念
  2. 手动依赖注入
  3. 简单依赖注入容器
  4. 高级依赖注入框架
  5. 装饰器实现依赖注入
  6. 类型注解与依赖注入
  7. 生命周期管理
  8. 实际应用场景

1. 依赖注入基础概念

什么是依赖注入?

依赖注入(Dependency Injection,DI)是一种设计模式,是实现控制反转(Inversion of Control,IoC)的常用方式。对象所依赖的组件不在对象内部创建,而是由外部(调用方、工厂函数或 DI 容器)创建好之后,通过构造函数、属性或方法参数传入。

示例代码(Python):
# 传统方式 - 硬编码依赖
class EmailService:
    """Concrete e-mail sender used by the introductory examples."""

    def send(self, message):
        """Print *message* to stdout as an outgoing e-mail line."""
        line = f"发送邮件: {message}"
        print(line)

class UserService:
    """Registration service that creates its own e-mail collaborator.

    This is the "hard-coded dependency" counter-example: the collaborator
    is built inside the constructor, so callers cannot substitute another
    implementation.
    """

    def __init__(self):
        # The dependency is constructed right here rather than passed in.
        mailer = EmailService()
        self.email_service = mailer

    def register_user(self, user):
        """Print the registration line, then send the welcome message."""
        registration_line = f"注册用户: {user}"
        print(registration_line)
        welcome = f"欢迎 {user}"
        self.email_service.send(welcome)

# 依赖注入方式
class UserServiceWithDI:
    """Registration service whose e-mail collaborator is supplied by the caller."""

    def __init__(self, email_service):
        # Constructor injection: keep whatever object the caller handed over.
        self.email_service = email_service

    def register_user(self, user):
        """Print the registration line, then send the welcome message."""
        registration_line = f"注册用户: {user}"
        welcome = f"欢迎 {user}"
        print(registration_line)
        self.email_service.send(welcome)

# Usage example: build the dependency first, then pass it to the service
# through its constructor.
email_service = EmailService()
user_service = UserServiceWithDI(email_service)
user_service.register_user("张三")

依赖注入的优势

示例代码(Python):
"""
依赖注入的优势:

1. 降低耦合度:组件之间松耦合
2. 提高可测试性:容易进行单元测试
3. 增强可扩展性:容易替换实现
4. 符合开闭原则:对扩展开放,对修改封闭
5. 单一职责:每个类专注于自己的职责
"""

# 优势演示
from abc import ABC, abstractmethod

# 定义接口
class NotificationService(ABC):
    """Abstract notification channel; concrete channels implement ``send``."""

    @abstractmethod
    def send(self, message: str) -> None:
        """Deliver *message* through this channel."""

# 多种实现
class EmailNotificationService(NotificationService):
    """Notification channel that prints an e-mail style line."""

    def send(self, message: str) -> None:
        """Print *message* with the e-mail prefix."""
        line = f"📧 邮件通知: {message}"
        print(line)

class SMSNotificationService(NotificationService):
    """Notification channel that prints an SMS style line."""

    def send(self, message: str) -> None:
        """Print *message* with the SMS prefix."""
        line = f"📱 短信通知: {message}"
        print(line)

class PushNotificationService(NotificationService):
    """Notification channel that prints a push-notification style line."""

    def send(self, message: str) -> None:
        """Print *message* with the push prefix."""
        line = f"🔔 推送通知: {message}"
        print(line)

# 业务服务
class OrderService:
    """Order service that reports through an injected ``NotificationService``."""

    def __init__(self, notification_service: NotificationService):
        # Any object implementing the NotificationService interface works.
        self.notification_service = notification_service

    def create_order(self, order_id: str):
        """Print the creation line and send a success notification."""
        creation_line = f"创建订单: {order_id}"
        print(creation_line)
        notice = f"订单 {order_id} 创建成功"
        self.notification_service.send(notice)

# Build one instance of each notification channel.
email_service = EmailNotificationService()
sms_service = SMSNotificationService()
push_service = PushNotificationService()

# The same OrderService class works with any channel; only the injected
# object changes.
order_service_email = OrderService(email_service)
order_service_sms = OrderService(sms_service)
order_service_push = OrderService(push_service)

order_service_email.create_order("ORD001")
order_service_sms.create_order("ORD002")
order_service_push.create_order("ORD003")

2. 手动依赖注入

构造函数注入

示例代码(Python):
# 构造函数注入示例
class DatabaseConnection:
    """Stand-in database connection that only prints what it is asked to do."""

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        print(f"连接数据库: {host}:{port}")

    def execute(self, query: str):
        """Print *query* and return a placeholder result string for it."""
        print(f"执行查询: {query}")
        outcome = f"查询结果 for {query}"
        return outcome

class CacheService:
    """Stand-in cache client that prints writes and fabricates reads."""

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        print(f"连接缓存: {host}:{port}")

    def get(self, key: str):
        """Return a placeholder value string for *key* (never empty)."""
        value = f"缓存值 for {key}"
        return value

    def set(self, key: str, value: str):
        """Print the key/value pair that would be stored."""
        line = f"设置缓存: {key} = {value}"
        print(line)

class UserRepository:
    """User data-access object built on an injected ``DatabaseConnection``."""

    def __init__(self, db_connection: DatabaseConnection):
        self.db = db_connection

    def find_by_id(self, user_id: int):
        """Issue the lookup query and return a user dict for *user_id*.

        The returned dict is derived from *user_id* alone; the value
        returned by ``execute`` is not used.
        """
        query = f"SELECT * FROM users WHERE id = {user_id}"
        self.db.execute(query)
        return dict(id=user_id, name=f"用户{user_id}")

    def save(self, user: dict):
        """Issue the insert statement for *user* and print a confirmation."""
        statement = f"INSERT INTO users VALUES {user}"
        self.db.execute(statement)
        print(f"保存用户: {user}")

class UserService:
    """Cache-aside user lookup built from an injected repository and cache."""

    def __init__(self, user_repository: UserRepository, cache_service: CacheService):
        self.user_repository = user_repository
        self.cache = cache_service

    def get_user(self, user_id: int):
        """Return the user for *user_id*, preferring the cache.

        A truthy cache value is printed and returned as-is. Otherwise the
        user is loaded from the repository, its ``str()`` form is written
        to the cache, and the loaded object is returned.
        """
        cache_key = f"user:{user_id}"
        cached_user = self.cache.get(cache_key)

        if not cached_user:
            # Cache miss: load from the repository and populate the cache.
            loaded = self.user_repository.find_by_id(user_id)
            self.cache.set(cache_key, str(loaded))
            return loaded

        print(f"从缓存获取用户: {cached_user}")
        return cached_user

# 手动组装依赖
def create_user_service():
    """Assemble and return a fully wired ``UserService`` by hand."""
    # Leaf dependencies first, in the same order as before, so the
    # connection messages appear in the same sequence.
    database = DatabaseConnection("localhost", 5432)
    cache = CacheService("localhost", 6379)

    # The repository sits between the database and the service.
    repository = UserRepository(database)

    return UserService(repository, cache)

# Usage example: obtain the wired service from the factory function and
# look up one user.
user_service = create_user_service()
user = user_service.get_user(1)
print(f"获取到用户: {user}")

属性注入

示例代码(Python):
# 属性注入示例
class Logger:
    """Tiny logger that writes prefixed lines to stdout."""

    def log(self, message: str):
        """Print *message* behind the ``[LOG]`` prefix."""
        line = f"[LOG] {message}"
        print(line)

class ConfigService:
    """Read-only configuration lookup backed by a fixed in-code table."""

    def get(self, key: str):
        """Return the setting stored under *key*, or ``None`` if unknown."""
        settings = dict(api_key="secret_key_123", timeout=30, retries=3)
        return settings.get(key)

class ApiClient:
    """API client whose logger and config are supplied through setters."""

    def __init__(self):
        # Both collaborators are optional until a setter provides them.
        self.logger = None
        self.config = None

    def set_logger(self, logger: Logger):
        """Attach *logger* and return ``self`` so calls can be chained."""
        self.logger = logger
        return self

    def set_config(self, config: ConfigService):
        """Attach *config* and return ``self`` so calls can be chained."""
        self.config = config
        return self

    def call_api(self, endpoint: str):
        """Print the call for *endpoint* and return a fixed success payload.

        Without a config, the fallbacks ``"default_key"`` and ``10`` are
        used for the key and the timeout.
        """
        if self.logger:
            self.logger.log(f"调用API: {endpoint}")

        if self.config:
            api_key = self.config.get("api_key")
        else:
            api_key = "default_key"

        if self.config:
            timeout = self.config.get("timeout")
        else:
            timeout = 10

        print(f"API调用: {endpoint}, key: {api_key}, timeout: {timeout}")
        return dict(status="success", data="api_response")

# Setter injection through chained calls: each setter returns the client
# itself, so both dependencies can be attached in one expression.
logger = Logger()
config = ConfigService()

api_client = (ApiClient()
              .set_logger(logger)
              .set_config(config))

result = api_client.call_api("/api/users")
print(f"API结果: {result}")

方法注入

示例代码(Python):
# 方法注入示例
class EmailTemplate:
    """Placeholder template renderer."""

    def render(self, template_name: str, context: dict):
        """Return a string naming the template followed by its context."""
        rendered = f"模板 {template_name}: {context}"
        return rendered

class EmailSender:
    """Placeholder mail sender that prints instead of sending."""

    def send(self, to: str, subject: str, body: str):
        """Print the recipient/subject line followed by the body line."""
        header = f"发送邮件到 {to}: {subject}"
        content = f"内容: {body}"
        print(header)
        print(content)

class NotificationManager:
    """Sends templated notifications using per-call (method-injected) services."""

    def __init__(self):
        # Fallbacks used when a call does not inject its own services.
        self.default_template = None
        self.default_sender = None

    def send_notification(self, to: str, template_name: str, context: dict, 
                         template_service: EmailTemplate = None, 
                         sender_service: EmailSender = None):
        """Render *template_name* with *context* and send it to *to*.

        Services passed as arguments take precedence over the instance
        defaults.

        Raises:
            ValueError: If no usable template service or sender service is
                available.
        """
        renderer = template_service or self.default_template
        sender = sender_service or self.default_sender

        if not (renderer and sender):
            raise ValueError("缺少必要的依赖服务")

        body = renderer.render(template_name, context)
        subject = f"通知: {template_name}"
        sender.send(to, subject, body)

# Usage example
template_service = EmailTemplate()
sender_service = EmailSender()
notification_manager = NotificationManager()

# Method-level injection: the services are passed with the call itself.
notification_manager.send_notification(
    to="user@example.com",
    template_name="welcome",
    context={"username": "张三"},
    template_service=template_service,
    sender_service=sender_service
)

3. 简单依赖注入容器

基础容器实现

示例代码(Python):
# 简单依赖注入容器
from typing import Dict, Any, Callable, TypeVar, Type
import inspect

T = TypeVar('T')


class DIContainer:
    """Minimal dependency-injection container keyed by class name.

    Supported registrations:

    * ``register`` -- map an interface to an implementation class, either
      as a lazily created singleton (default) or as a transient that is
      rebuilt on every ``resolve``.
    * ``register_instance`` -- hand out an already constructed object.
    * ``register_factory`` -- call a zero-argument factory once and cache
      its result.

    Registering a name again replaces every earlier registration and any
    cached instance for that name. Constructor parameters are resolved
    recursively from their type annotations.
    """

    def __init__(self):
        # Transient registrations: service name -> implementation class.
        self._services: Dict[str, Any] = {}
        # Lazy singleton registrations: service name -> zero-arg factory.
        self._factories: Dict[str, Callable] = {}
        # Ready-made or already built singletons: service name -> object.
        self._singletons: Dict[str, Any] = {}

    def _forget(self, service_name: str) -> None:
        """Drop every registration and cached instance stored under a name."""
        self._services.pop(service_name, None)
        self._factories.pop(service_name, None)
        self._singletons.pop(service_name, None)

    def _is_registered(self, interface: Any) -> bool:
        """Return True if *interface* has an explicit registration."""
        name = getattr(interface, '__name__', None)
        if name is None:
            return False
        return (name in self._singletons
                or name in self._factories
                or name in self._services)

    def register(self, interface: Type[T], implementation: Type[T] = None, singleton: bool = True):
        """Register *implementation* (default: *interface* itself).

        Args:
            interface: Type whose ``__name__`` is the lookup key.
            implementation: Class to instantiate when resolving.
            singleton: Build once and cache when True, build a new
                instance on every ``resolve`` when False.

        Returns:
            The container, so registrations can be chained.
        """
        if implementation is None:
            implementation = interface

        service_name = interface.__name__
        # Without this, an earlier singleton registration would keep
        # winning in resolve() even after re-registering as transient.
        self._forget(service_name)

        if singleton:
            self._factories[service_name] = lambda: self._create_instance(implementation)
        else:
            self._services[service_name] = implementation

        return self

    def register_instance(self, interface: Type[T], instance: T):
        """Register a ready-made *instance* for *interface*; returns the container."""
        service_name = interface.__name__
        self._forget(service_name)
        self._singletons[service_name] = instance
        return self

    def register_factory(self, interface: Type[T], factory: Callable[[], T]):
        """Register a zero-argument *factory*; its first result is cached."""
        service_name = interface.__name__
        self._forget(service_name)
        self._factories[service_name] = factory
        return self

    def resolve(self, interface: Type[T]) -> T:
        """Return an instance for *interface*.

        Lookup order: cached singletons, factories (result is cached),
        transient registrations, and finally direct construction of
        *interface* itself when nothing is registered.
        """
        service_name = interface.__name__

        if service_name in self._singletons:
            return self._singletons[service_name]

        if service_name in self._factories:
            instance = self._factories[service_name]()
            self._singletons[service_name] = instance  # cache the singleton
            return instance

        if service_name in self._services:
            return self._create_instance(self._services[service_name])

        # Not registered: try to build the requested type directly.
        return self._create_instance(interface)

    def _create_instance(self, cls: Type[T]) -> T:
        """Instantiate *cls*, resolving annotated constructor parameters.

        Parameters without an annotation are left to the constructor.
        ``*args``/``**kwargs`` are never injected. A parameter that has a
        default keeps it unless its annotated type is explicitly
        registered, so ``retries: int = 3`` is not replaced by ``int()``.
        """
        signature = inspect.signature(cls.__init__)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            if param_name == 'self' or param.kind in variadic:
                continue

            param_type = param.annotation
            if param_type is inspect.Parameter.empty:
                continue

            has_default = param.default is not inspect.Parameter.empty
            if has_default and not self._is_registered(param_type):
                continue

            # Recursively resolve the dependency.
            kwargs[param_name] = self.resolve(param_type)

        return cls(**kwargs)

# 使用示例
class ILogger:
    """Logger interface; the base implementation does nothing."""

    def log(self, message: str):
        """Accept *message*; subclasses override this to emit it."""
        return None

class ConsoleLogger(ILogger):
    """``ILogger`` implementation that prints to stdout."""

    def log(self, message: str):
        """Print *message* behind the ``[Console]`` prefix."""
        line = f"[Console] {message}"
        print(line)

class IDatabase:
    """Database interface; the base implementation does nothing."""

    def query(self, sql: str):
        """Accept *sql*; subclasses override this to run it."""
        return None

class MySQLDatabase(IDatabase):
    """``IDatabase`` implementation that logs queries through an ``ILogger``."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        # Announce the connection through the injected logger.
        self.logger.log("MySQL数据库已连接")

    def query(self, sql: str):
        """Log *sql* and return a placeholder result string for it."""
        self.logger.log(f"执行SQL: {sql}")
        outcome = f"结果: {sql}"
        return outcome

class UserService:
    """User service whose database and logger are resolved by the container."""

    def __init__(self, database: IDatabase, logger: ILogger):
        self.database, self.logger = database, logger

    def create_user(self, username: str):
        """Log the request, run the insert for *username*, return the query result.

        NOTE(review): the statement is built by string interpolation; a
        real database layer should use parameterized queries.
        """
        self.logger.log(f"创建用户: {username}")
        statement = f"INSERT INTO users (name) VALUES ('{username}')"
        return self.database.query(statement)

# Configure the container
container = DIContainer()

# Register services: interfaces are mapped to implementations, and
# UserService is registered as its own implementation.
container.register(ILogger, ConsoleLogger)
container.register(IDatabase, MySQLDatabase)
container.register(UserService)

# Resolve the service; its constructor dependencies are resolved from the
# parameter annotations.
user_service = container.resolve(UserService)
result = user_service.create_user("张三")
print(f"创建结果: {result}")

高级容器功能

示例代码(Python):
# 高级依赖注入容器
from enum import Enum
from typing import Dict, Any, Callable, TypeVar, Type, Optional
import inspect
import threading

class Lifetime(Enum):
    """Lifetime policy of a registered service."""
    TRANSIENT = "transient"    # a fresh instance on every resolve
    SINGLETON = "singleton"    # one shared instance per container
    SCOPED = "scoped"         # one instance per scope


class ServiceDescriptor:
    """Registration record: how to obtain a service and how long it lives.

    Args:
        interface: Type used as the lookup key.
        implementation: Class to instantiate; falls back to *interface*.
        factory: Optional zero-argument callable that builds the service.
        instance: Optional ready-made object that is handed out as-is.
        lifetime: Lifetime policy applied by the container.
    """
    def __init__(self,
                 interface: Type,
                 implementation: Type = None,
                 factory: Callable = None,
                 instance: Any = None,
                 lifetime: Lifetime = Lifetime.SINGLETON):
        self.interface = interface
        self.implementation = implementation or interface
        self.factory = factory
        self.instance = instance
        self.lifetime = lifetime


class AdvancedDIContainer:
    """Dependency-injection container with transient, singleton and scoped lifetimes."""

    def __init__(self):
        self._descriptors: Dict[Type, ServiceDescriptor] = {}
        self._singletons: Dict[Type, Any] = {}
        # scope id -> {interface -> instance}
        self._scoped_instances: Dict[str, Dict[Type, Any]] = {}
        # Re-entrant on purpose: building one singleton may resolve other
        # singletons on the same thread while the lock is still held. A
        # plain Lock would block forever in that case.
        self._lock = threading.RLock()
        self._current_scope: Optional[str] = None

    def _register(self, descriptor: ServiceDescriptor):
        """Store *descriptor*, dropping a singleton cached for an older registration."""
        self._singletons.pop(descriptor.interface, None)
        self._descriptors[descriptor.interface] = descriptor
        return self

    def register_transient(self, interface: Type, implementation: Type = None):
        """Register a service that is rebuilt on every resolve; returns the container."""
        return self._register(
            ServiceDescriptor(interface, implementation, lifetime=Lifetime.TRANSIENT))

    def register_singleton(self, interface: Type, implementation: Type = None):
        """Register a service that is built once per container; returns the container."""
        return self._register(
            ServiceDescriptor(interface, implementation, lifetime=Lifetime.SINGLETON))

    def register_scoped(self, interface: Type, implementation: Type = None):
        """Register a service that is built once per scope; returns the container."""
        return self._register(
            ServiceDescriptor(interface, implementation, lifetime=Lifetime.SCOPED))

    def register_instance(self, interface: Type, instance: Any):
        """Register a ready-made *instance* as the singleton for *interface*."""
        self._register(
            ServiceDescriptor(interface, instance=instance, lifetime=Lifetime.SINGLETON))
        self._singletons[interface] = instance
        return self

    def register_factory(self, interface: Type, factory: Callable, lifetime: Lifetime = Lifetime.SINGLETON):
        """Register a zero-argument *factory* with the given *lifetime*."""
        return self._register(
            ServiceDescriptor(interface, factory=factory, lifetime=lifetime))

    def create_scope(self, scope_id: str = None):
        """Return a ``DIScope`` for *scope_id* (a random UUID string if omitted)."""
        if scope_id is None:
            import uuid
            scope_id = str(uuid.uuid4())

        return DIScope(self, scope_id)

    def resolve(self, interface: Type, scope_id: str = None):
        """Return an instance of *interface* according to its lifetime.

        Args:
            interface: A registered service type.
            scope_id: Scope used for scoped services; it is also passed on
                to the dependencies of the service being built.

        Raises:
            ValueError: If *interface* is not registered, or a scoped
                service is requested with no scope available.
        """
        descriptor = self._descriptors.get(interface)
        if descriptor is None:
            name = getattr(interface, '__name__', interface)
            raise ValueError(f"服务 {name} 未注册")

        if descriptor.lifetime == Lifetime.SINGLETON:
            return self._get_singleton(interface, descriptor)
        if descriptor.lifetime == Lifetime.SCOPED:
            return self._get_scoped(interface, descriptor, scope_id)
        return self._create_instance(descriptor, scope_id)  # TRANSIENT

    def _get_singleton(self, interface: Type, descriptor: ServiceDescriptor):
        """Return the cached singleton, building it under the lock on first use."""
        if interface in self._singletons:
            return self._singletons[interface]

        with self._lock:
            # Re-check inside the lock: another thread may have built it
            # while this one was waiting.
            if interface not in self._singletons:
                self._singletons[interface] = self._create_instance(descriptor)
            return self._singletons[interface]

    def _get_scoped(self, interface: Type, descriptor: ServiceDescriptor, scope_id: str):
        """Return the instance of *interface* belonging to the given scope.

        Falls back to the scope entered via ``with container.create_scope()``
        when *scope_id* is empty.
        """
        if not scope_id:
            scope_id = self._current_scope
            if not scope_id:
                raise ValueError("作用域服务需要在作用域内解析")

        scoped_dict = self._scoped_instances.setdefault(scope_id, {})

        if interface not in scoped_dict:
            scoped_dict[interface] = self._create_instance(descriptor, scope_id)

        return scoped_dict[interface]

    def _create_instance(self, descriptor: ServiceDescriptor, scope_id: str = None):
        """Produce the object for *descriptor*: instance, factory result, or auto-wired class."""
        if descriptor.instance is not None:
            return descriptor.instance

        if descriptor.factory is not None:
            return descriptor.factory()

        return self._auto_resolve(descriptor.implementation, scope_id)

    def _auto_resolve(self, cls: Type, scope_id: str = None):
        """Instantiate *cls*, resolving annotated constructor parameters.

        ``*args``/``**kwargs`` are never injected. A parameter that has a
        default keeps it when its annotated type is not registered; a
        required parameter of an unregistered type still raises
        ``ValueError`` from ``resolve``.
        """
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        kwargs = {}

        for param_name, param in inspect.signature(cls.__init__).parameters.items():
            if param_name == 'self' or param.kind in variadic:
                continue

            param_type = param.annotation
            if param_type is inspect.Parameter.empty:
                continue

            has_default = param.default is not inspect.Parameter.empty
            if has_default and param_type not in self._descriptors:
                continue

            # Pass the scope on so scoped dependencies land in the same scope.
            kwargs[param_name] = self.resolve(param_type, scope_id)

        return cls(**kwargs)

    def dispose_scope(self, scope_id: str):
        """Forget the scope and call ``dispose()`` on its instances that define one."""
        # Remove the scope first so it is gone even if a dispose() call raises.
        scoped_dict = self._scoped_instances.pop(scope_id, None)
        if not scoped_dict:
            return

        for instance in scoped_dict.values():
            if hasattr(instance, 'dispose'):
                instance.dispose()


class DIScope:
    """Context manager that makes one scope current on its container.

    Entering records the previously current scope; leaving restores it and
    disposes this scope's instances.
    """

    def __init__(self, container: AdvancedDIContainer, scope_id: str):
        self.container = container
        self.scope_id = scope_id
        self._previous_scope: Optional[str] = None

    def __enter__(self):
        self._previous_scope = self.container._current_scope
        self.container._current_scope = self.scope_id
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the enclosing scope (None at the outermost level) rather
        # than clearing it, so nested scopes do not break the outer one.
        self.container._current_scope = self._previous_scope
        self.container.dispose_scope(self.scope_id)

    def resolve(self, interface: Type):
        """Resolve *interface* inside this scope."""
        return self.container.resolve(interface, self.scope_id)

# 使用示例
class IRepository:
    """Repository interface; the base implementation does nothing."""

    def save(self, data):
        """Accept *data*; subclasses override this to persist it."""
        return None

class DatabaseRepository(IRepository):
    """``IRepository`` implementation that reports its activity via an ``ILogger``."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        self.logger.log("DatabaseRepository 已创建")

    def save(self, data):
        """Log the data that would be saved."""
        message = f"保存数据: {data}"
        self.logger.log(message)

    def dispose(self):
        """Print the release message; the container calls this when the scope ends."""
        print("DatabaseRepository 资源已释放")

class TransientService:
    """Service that announces each new instance together with its ``id()``."""

    def __init__(self):
        instance_id = id(self)
        print(f"TransientService 实例已创建: {instance_id}")

# Configure the advanced container
container = AdvancedDIContainer()

# Register one service per lifetime
container.register_singleton(ILogger, ConsoleLogger)
container.register_scoped(IRepository, DatabaseRepository)
container.register_transient(TransientService)

# Singleton check: both resolves are compared by id()
logger1 = container.resolve(ILogger)
logger2 = container.resolve(ILogger)
print(f"单例测试: {id(logger1) == id(logger2)}")

# Transient check: each resolve constructs a new TransientService
trans1 = container.resolve(TransientService)
trans2 = container.resolve(TransientService)
print(f"瞬态测试: {id(trans1) == id(trans2)}")

# Scope check: two resolves inside the same scope
with container.create_scope() as scope1:
    repo1 = scope1.resolve(IRepository)
    repo2 = scope1.resolve(IRepository)
    print(f"作用域内相同: {id(repo1) == id(repo2)}")

# A second scope; repo1 comes from scope1, which has already been exited
# (and disposed) at this point.
with container.create_scope() as scope2:
    repo3 = scope2.resolve(IRepository)
    print(f"不同作用域: {id(repo1) == id(repo3)}")

4. 高级依赖注入框架

dependency-injector 框架

示例代码(Python):
# 使用 dependency-injector 框架
# pip install dependency-injector

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
import logging

# 配置和服务定义
class Config:
    """Fixed application settings consumed by the container's providers."""

    def __init__(self):
        self.database_url, self.redis_url = (
            "postgresql://localhost/mydb",
            "redis://localhost:6379",
        )
        self.log_level = logging.INFO

class Database:
    """Stand-in database handle identified by its URL."""

    def __init__(self, url: str):
        self.url = url
        print(f"数据库连接: {url}")

    def get_connection(self):
        """Return a string describing a connection to ``self.url``."""
        description = f"Connection to {self.url}"
        return description

class Cache:
    """Stand-in cache handle identified by its URL."""

    def __init__(self, url: str):
        self.url = url
        print(f"缓存连接: {url}")

    def get(self, key: str):
        """Return a placeholder cached value string for *key* (never empty)."""
        value = f"Cached value for {key}"
        return value

class Logger:
    """Minimal logger that remembers a level and prints info lines."""

    def __init__(self, level: int):
        self.level = level
        print(f"日志级别: {level}")

    def info(self, message: str):
        """Print *message* behind the ``[INFO]`` prefix."""
        line = f"[INFO] {message}"
        print(line)

class UserRepository:
    """User lookup built from an injected ``Database`` and ``Logger``."""

    def __init__(self, database: Database, logger: Logger):
        self.database, self.logger = database, logger

    def find_user(self, user_id: int):
        """Log the lookup and return a user dict that includes the connection string."""
        self.logger.info(f"查找用户: {user_id}")
        connection = self.database.get_connection()
        return dict(id=user_id, name=f"用户{user_id}", connection=connection)

class UserService:
    """Cache-first user lookup built from a repository, a cache and a logger."""

    def __init__(self, repository: UserRepository, cache: Cache, logger: Logger):
        self.repository, self.cache, self.logger = repository, cache, logger

    def get_user(self, user_id: int):
        """Return the user for *user_id*, preferring a truthy cache value.

        Every step is reported through the logger.
        """
        self.logger.info(f"获取用户服务调用: {user_id}")

        cached = self.cache.get(f"user:{user_id}")
        if not cached:
            # Cache miss: fall back to the repository.
            user = self.repository.find_user(user_id)
            self.logger.info("从数据库返回")
            return user

        self.logger.info("从缓存返回")
        return cached

# 容器配置
class Container(containers.DeclarativeContainer):
    """Declarative container describing how the application objects are built."""

    # Configuration provider: a single shared Config object.
    config = providers.Singleton(Config)

    # Infrastructure providers: each is a singleton whose constructor
    # argument is taken from an attribute of the provided Config.
    database = providers.Singleton(
        Database,
        url=config.provided.database_url
    )

    cache = providers.Singleton(
        Cache,
        url=config.provided.redis_url
    )

    logger = providers.Singleton(
        Logger,
        level=config.provided.log_level
    )

    # Business-service providers: declared as factories and wired to the
    # providers defined above.
    user_repository = providers.Factory(
        UserRepository,
        database=database,
        logger=logger
    )

    user_service = providers.Factory(
        UserService,
        repository=user_repository,
        cache=cache,
        logger=logger
    )

# 使用装饰器注入
class UserController:
    """Controller whose constructor arguments default to container providers."""

    @inject
    def __init__(self, 
                 user_service: UserService = Provide[Container.user_service],
                 logger: Logger = Provide[Container.logger]):
        # The Provide[...] defaults are markers for the @inject decorator;
        # they are meant to be filled in after the container has been
        # wired (see container.wire in the demo function).
        self.user_service = user_service
        self.logger = logger

    def get_user_endpoint(self, user_id: int):
        """Log the call and return the user obtained from the user service."""
        self.logger.info(f"API调用: 获取用户 {user_id}")
        return self.user_service.get_user(user_id)

# 函数级注入
@inject
def get_user_function(user_id: int, 
                     user_service: UserService = Provide[Container.user_service]):
    """Return the user for *user_id* using a service supplied via ``Provide``."""
    return user_service.get_user(user_id)

# 使用示例
def demo_dependency_injector():
    """Run the three access styles: direct, injected controller, injected function."""
    container = Container()

    # Wire this module so the Provide[...] markers can be filled in.
    container.wire(modules=[__name__])

    # 1) Ask the container for the service directly.
    direct_service = container.user_service()
    direct_user = direct_service.get_user(1)
    print(f"直接获取: {direct_user}")

    # 2) Let the controller receive its dependencies by injection.
    controller = UserController()
    controller_user = controller.get_user_endpoint(2)
    print(f"控制器获取: {controller_user}")

    # 3) Call a plain function whose service argument is injected.
    function_user = get_user_function(3)
    print(f"函数获取: {function_user}")

# demo_dependency_injector()

自定义装饰器框架

示例代码(Python):
# 自定义装饰器依赖注入框架
from typing import Dict, Any, Callable, TypeVar, get_type_hints
from functools import wraps
import inspect

class Injectable:
    """Marker base class for injectable types; it defines no behaviour."""

class ServiceRegistry:
    """Class-level service registry with singleton and transient registrations.

    All state is stored on the class, so every ``ServiceRegistry()`` (which
    is itself a single shared object) and every classmethod call see the
    same registrations.
    """
    _instance = None
    # Cached singleton instances: interface -> object.
    _services: Dict[type, Any] = {}
    # Singleton registrations: interface -> zero-arg callable returning
    # the implementation (a class to instantiate, or a ready object).
    _factories: Dict[type, Callable] = {}
    # Transient registrations: interface -> implementation class that is
    # instantiated anew on every get().
    _transients: Dict[type, type] = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def register(cls, interface: type, implementation: type = None, singleton: bool = True):
        """Register *implementation* (default: *interface* itself).

        Registering an interface again replaces the earlier registration
        and drops any instance cached for it.

        Args:
            interface: Lookup key.
            implementation: Class to instantiate when the service is requested.
            singleton: Build once and cache when True; build a new
                instance on every ``get`` when False.
        """
        if implementation is None:
            implementation = interface

        cls._services.pop(interface, None)
        cls._factories.pop(interface, None)
        cls._transients.pop(interface, None)

        if singleton:
            cls._factories[interface] = lambda: implementation
        else:
            # Kept apart from the instance cache: storing the class in
            # _services made get() hand back the class itself.
            cls._transients[interface] = implementation

    @classmethod
    def _is_registered(cls, interface) -> bool:
        """Return True if *interface* has an explicit registration or cached instance."""
        return (interface in cls._services
                or interface in cls._factories
                or interface in cls._transients)

    @classmethod
    def get(cls, interface: type):
        """Return an instance for *interface*.

        Lookup order: cached singleton, transient registration (new
        instance each time), singleton registration (built and cached),
        and finally direct construction of *interface* when unregistered.
        """
        if interface in cls._services:
            return cls._services[interface]

        if interface in cls._transients:
            return cls._auto_wire(cls._transients[interface])

        if interface in cls._factories:
            factory = cls._factories[interface]
            instance = cls._auto_wire(factory())
            cls._services[interface] = instance
            return instance

        # Not registered: try to build the requested type directly.
        return cls._auto_wire(interface)

    @classmethod
    def _auto_wire(cls, target):
        """Instantiate *target* if it is a class; return it unchanged otherwise."""
        if inspect.isclass(target):
            return cls._create_instance(target)
        return target

    @classmethod
    def _create_instance(cls, target_class):
        """Instantiate *target_class*, injecting type-hinted constructor parameters.

        ``*args``/``**kwargs`` are never injected, and a parameter that
        has a default keeps it unless its hinted type is registered.
        """
        signature = inspect.signature(target_class.__init__)
        type_hints = get_type_hints(target_class.__init__)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            if param_name == 'self' or param.kind in variadic:
                continue

            param_type = type_hints.get(param_name)
            if not param_type:
                continue

            has_default = param.default is not inspect.Parameter.empty
            if has_default and not cls._is_registered(param_type):
                continue

            kwargs[param_name] = cls.get(param_type)

        return target_class(**kwargs)

def injectable(cls):
    """Class decorator: auto-inject constructor dependencies on a no-argument call.

    When the decorated class is instantiated without any arguments, every
    named constructor parameter that has a type hint is obtained from
    ``ServiceRegistry.get``. If the caller passes any argument, nothing is
    injected and the original constructor runs unchanged.

    ``*args`` and ``**kwargs`` parameters are never treated as
    dependencies, even when annotated.

    Returns:
        The same class, with ``__init__`` replaced by the injecting wrapper.
    """
    original_init = cls.__init__
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        # Only a completely argument-free call triggers auto-injection.
        if not args and not kwargs:
            # Hints are looked up at call time so that types defined after
            # the decorated class can still be resolved.
            signature = inspect.signature(original_init)
            type_hints = get_type_hints(original_init)

            for param_name, param in signature.parameters.items():
                if param_name == 'self' or param.kind in variadic:
                    continue

                param_type = type_hints.get(param_name)
                if param_type:
                    kwargs[param_name] = ServiceRegistry.get(param_type)

        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls

def inject(func):
    """Function decorator: fill in missing type-hinted arguments from ``ServiceRegistry``.

    Arguments supplied by the caller are always kept. Each remaining
    named parameter that has a type hint is obtained from
    ``ServiceRegistry.get``. ``*args``/``**kwargs`` parameters are never
    injected.

    The wrapped function is called with the arguments in their original
    positional/keyword form, so positional-only parameters and variadic
    parameters work as they would without the decorator.
    """
    signature = inspect.signature(func)
    type_hints = get_type_hints(func)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound_args = signature.bind_partial(*args, **kwargs)

        for param_name, param in signature.parameters.items():
            if param_name in bound_args.arguments or param.kind in variadic:
                continue

            param_type = type_hints.get(param_name)
            if param_type:
                bound_args.arguments[param_name] = ServiceRegistry.get(param_type)

        # Do not call func(**arguments): that turns *args/**kwargs into
        # keywords named after the parameters and rejects positional-only
        # parameters.
        return func(*bound_args.args, **bound_args.kwargs)

    return wrapper

# 使用示例
class IEmailService:
    """E-mail service interface; the base implementation does nothing."""

    def send_email(self, to: str, subject: str, body: str):
        """Accept a message; subclasses override this to send it."""
        return None

@injectable
class EmailService(IEmailService):
    """``IEmailService`` implementation that prints instead of sending."""

    def send_email(self, to: str, subject: str, body: str):
        """Print the recipient and subject of the message."""
        line = f"发送邮件到 {to}: {subject}"
        print(line)

class IUserRepository:
    """User repository interface; the base implementation does nothing."""

    def get_user(self, user_id: int):
        """Accept *user_id*; subclasses override this to return a user."""
        return None

@injectable
class UserRepository(IUserRepository):
    """``IUserRepository`` implementation that fabricates users from their id."""

    def __init__(self, email_service: IEmailService):
        self.email_service = email_service

    def get_user(self, user_id: int):
        """Return a user dict with the id and a derived e-mail address."""
        return dict(id=user_id, email=f"user{user_id}@example.com")

@injectable
class UserService:
    """Notifies users by e-mail using an injected repository and e-mail service."""

    def __init__(self, user_repository: IUserRepository, email_service: IEmailService):
        self.user_repository, self.email_service = user_repository, email_service

    def notify_user(self, user_id: int, message: str):
        """Look up the user and e-mail *message* to the address on record."""
        recipient = self.user_repository.get_user(user_id)
        address = recipient["email"]
        self.email_service.send_email(address, "通知", message)

# Register services: map each interface to its implementation (both use
# the default singleton=True).
ServiceRegistry.register(IEmailService, EmailService)
ServiceRegistry.register(IUserRepository, UserRepository)

# 使用注入的函数
@inject
def send_notification(user_id: int, message: str, user_service: UserService):
    """Send *message* to the user with *user_id* through *user_service*.

    The demo calls this without ``user_service``; the ``@inject``
    decorator is expected to supply it.
    """
    user_service.notify_user(user_id, message)

# 测试
def demo_custom_di():
    """Exercise the custom framework: an injected class, then an injected function."""
    # Instantiating without arguments triggers constructor auto-injection.
    service = UserService()
    service.notify_user(1, "欢迎使用系统")

    # The missing user_service argument is filled in by the decorator.
    send_notification(2, "系统维护通知")

# demo_custom_di()

5. 装饰器实现依赖注入

基于装饰器的高级依赖注入

示例代码(Python):
# 高级装饰器依赖注入系统
from typing import Dict, Any, Callable, Type, get_type_hints
from functools import wraps
import inspect
from enum import Enum

class Scope(Enum):
    """Lifetime of a service registered in ``AdvancedDIRegistry``."""
    SINGLETON = "singleton"
    TRANSIENT = "transient"
    REQUEST = "request"


class DIMetadata:
    """Per-service registration settings."""
    def __init__(self):
        self.scope = Scope.SINGLETON
        self.lazy = False
        self.factory = None
        self.dependencies = []


class AdvancedDIRegistry:
    """Class-level registry supporting singleton, transient and request scopes.

    All state is stored on the class and shared by every caller.
    """
    # service type -> implementation class (None for factory registrations)
    _services: Dict[Type, Any] = {}
    # service type -> registration settings
    _metadata: Dict[Type, DIMetadata] = {}
    # service type -> cached singleton instance
    _instances: Dict[Type, Any] = {}
    # request id -> {service type -> instance}
    _request_instances: Dict[str, Dict[Type, Any]] = {}

    @classmethod
    def register_service(cls, service_type: Type, implementation: Type = None,
                        scope: Scope = Scope.SINGLETON, lazy: bool = False):
        """Register *implementation* (default: *service_type*) under *service_type*.

        Registering a type again replaces its settings and drops any
        singleton already cached for it.
        """
        if implementation is None:
            implementation = service_type

        cls._services[service_type] = implementation
        cls._instances.pop(service_type, None)

        metadata = DIMetadata()
        metadata.scope = scope
        metadata.lazy = lazy
        cls._metadata[service_type] = metadata

    @classmethod
    def register_factory(cls, service_type: Type, factory: Callable,
                         scope: Scope = Scope.SINGLETON):
        """Register a zero-argument *factory* that builds *service_type*."""
        # None marks "no implementation class"; the factory stored in the
        # metadata is used instead.
        cls._services[service_type] = None
        cls._instances.pop(service_type, None)

        metadata = DIMetadata()
        metadata.scope = scope
        metadata.factory = factory
        cls._metadata[service_type] = metadata

    @classmethod
    def resolve(cls, service_type: Type, request_id: str = None):
        """Return an instance of *service_type* according to its scope.

        Args:
            service_type: A registered service type.
            request_id: Identifies the request for request-scoped services.
                It is passed on to the dependencies of transient and
                request-scoped services.

        Raises:
            ValueError: If the type is not registered, or a request-scoped
                service is needed and no *request_id* is available.
        """
        if service_type not in cls._services:
            raise ValueError(f"服务 {service_type.__name__} 未注册")

        metadata = cls._metadata[service_type]

        if metadata.scope == Scope.SINGLETON:
            return cls._get_singleton(service_type, metadata)
        if metadata.scope == Scope.REQUEST:
            return cls._get_request_scoped(service_type, metadata, request_id)
        return cls._create_instance(service_type, metadata, request_id)  # TRANSIENT

    @classmethod
    def _get_singleton(cls, service_type: Type, metadata: DIMetadata):
        """Return the cached singleton, building it on first use."""
        if service_type in cls._instances:
            return cls._instances[service_type]

        # Deliberately built without a request id: a singleton outlives
        # any single request, so it must not capture request-scoped
        # objects. Such a dependency still raises ValueError.
        instance = cls._create_instance(service_type, metadata)
        cls._instances[service_type] = instance
        return instance

    @classmethod
    def _get_request_scoped(cls, service_type: Type, metadata: DIMetadata, request_id: str):
        """Return the instance of *service_type* belonging to *request_id*."""
        if not request_id:
            raise ValueError("请求作用域服务需要提供request_id")

        request_dict = cls._request_instances.setdefault(request_id, {})

        if service_type not in request_dict:
            request_dict[service_type] = cls._create_instance(
                service_type, metadata, request_id)

        return request_dict[service_type]

    @classmethod
    def _create_instance(cls, service_type: Type, metadata: DIMetadata,
                         request_id: str = None):
        """Build the service through its factory, or by auto-wiring its class."""
        if metadata.factory:
            return metadata.factory()

        implementation = cls._services[service_type]
        if implementation is None:
            implementation = service_type

        return cls._inject_dependencies(implementation, request_id)

    @classmethod
    def _inject_dependencies(cls, target_class: Type, request_id: str = None):
        """Instantiate *target_class*, injecting parameters whose hinted type is registered.

        Parameters of unregistered types are left to the constructor, so
        their defaults (if any) apply.
        """
        signature = inspect.signature(target_class.__init__)
        type_hints = get_type_hints(target_class.__init__)

        kwargs = {}
        for param_name in signature.parameters:
            if param_name == 'self':
                continue

            param_type = type_hints.get(param_name)
            if param_type and param_type in cls._services:
                # Pass the request id on; without it every request-scoped
                # dependency would fail to resolve.
                kwargs[param_name] = cls.resolve(param_type, request_id)

        return target_class(**kwargs)

    @classmethod
    def clear_request_scope(cls, request_id: str):
        """Forget every instance created for *request_id*."""
        cls._request_instances.pop(request_id, None)

# 装饰器定义
def service(scope: Scope = Scope.SINGLETON, lazy: bool = False):
    """Return a class decorator that registers the class as its own implementation.

    Args:
        scope: Lifetime under which the class is registered.
        lazy: Stored with the registration settings.
    """
    def decorator(cls):
        AdvancedDIRegistry.register_service(
            service_type=cls,
            implementation=cls,
            scope=scope,
            lazy=lazy,
        )
        return cls
    return decorator

def singleton(cls):
    """Class decorator: register *cls* with singleton scope and return it."""
    register = service(Scope.SINGLETON)
    return register(cls)

def transient(cls):
    """Class decorator: register *cls* with transient scope and return it."""
    register = service(Scope.TRANSIENT)
    return register(cls)

def request_scoped(cls):
    """Class decorator: register *cls* with request scope and return it."""
    register = service(Scope.REQUEST)
    return register(cls)

def inject_dependencies(cls):
    """Class decorator that fills in missing constructor dependencies.

    ``cls.__init__`` is replaced by a wrapper.  On every construction the
    wrapper looks at the original ``__init__`` signature and, for each
    parameter the caller did not supply, resolves the parameter's type hint
    through ``AdvancedDIRegistry.resolve`` when that type is a key of
    ``AdvancedDIRegistry._services``.  Parameters without a registered type
    hint are left alone, so normal defaults (or a ``TypeError`` for a
    missing required argument) still apply.

    Args:
        cls: The class whose constructor should be wrapped.

    Returns:
        The same class object, with ``__init__`` replaced.

    Raises:
        TypeError: From the wrapper, when the caller's arguments do not fit
            the original ``__init__`` signature.
    """
    original_init = cls.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        # Evaluated per call, so forward references only need to exist by
        # the time an instance is actually created.
        type_hints = get_type_hints(original_init)
        signature = inspect.signature(original_init)

        # bind_partial maps positional AND keyword arguments to parameter
        # names.  Checking only ``kwargs`` would re-inject a dependency that
        # was passed positionally and fail with "multiple values".
        supplied = signature.bind_partial(self, *args, **kwargs).arguments

        for param_name in signature.parameters:
            if param_name in supplied:
                continue

            param_type = type_hints.get(param_name)
            if param_type and param_type in AdvancedDIRegistry._services:
                kwargs[param_name] = AdvancedDIRegistry.resolve(param_type)

        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls

def inject_method(*dependencies):
    """Return a decorator that injects selected dependencies into a callable.

    Args:
        *dependencies: The types eligible for injection.  A parameter is
            injected only when its type hint is one of these types and the
            caller did not supply the argument.

    Returns:
        A decorator.  The decorated callable resolves each eligible missing
        parameter through ``AdvancedDIRegistry.resolve`` and passes it by
        keyword before invoking the original function.

    Raises:
        TypeError: From the wrapper, when the caller's arguments do not fit
            the wrapped function's signature.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            type_hints = get_type_hints(func)
            signature = inspect.signature(func)

            # Account for positional arguments as well as keywords;
            # otherwise a positionally supplied dependency would be
            # injected a second time and the call would fail with
            # "got multiple values for argument".
            supplied = signature.bind_partial(*args, **kwargs).arguments

            for param_name in signature.parameters:
                if param_name in supplied:
                    continue

                param_type = type_hints.get(param_name)
                if param_type and param_type in dependencies:
                    kwargs[param_name] = AdvancedDIRegistry.resolve(param_type)

            return func(*args, **kwargs)
        return wrapper
    return decorator

# 使用示例
from abc import ABC, abstractmethod

class ILogger(ABC):
    """Abstract logging interface; concrete loggers implement ``log``."""

    @abstractmethod
    def log(self, message: str):
        """Record ``message``.  The abstract body does nothing."""

@singleton
class ConsoleLogger(ILogger):
    """``ILogger`` implementation that prints to standard output."""

    def __init__(self):
        # Announce construction so the demo shows when an instance is built.
        print("ConsoleLogger 单例已创建")

    def log(self, message: str):
        """Print ``message`` prefixed with ``[LOG]``."""
        print(f"[LOG] {message}")

class IEmailService(ABC):
    """Abstract e-mail interface; concrete services implement ``send``."""

    @abstractmethod
    def send(self, to: str, subject: str, body: str):
        """Send ``body`` with ``subject`` to ``to``.  The abstract body does nothing."""

@transient
class EmailService(IEmailService):
    """``IEmailService`` implementation that only logs what it would send."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        # Logged on every construction, which makes new instances visible
        # in the demo output.
        logger.log("EmailService 实例已创建")

    def send(self, to: str, subject: str, body: str):
        """Log the recipient and subject; ``body`` is accepted but not used."""
        self.logger.log(f"发送邮件到 {to}: {subject}")

@request_scoped
class RequestContext:
    """Per-request context object identified by its own ``id()``."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        # The object's identity doubles as the request id.
        self.request_id = id(self)
        logger.log(f"RequestContext 已创建: {self.request_id}")

    def get_request_id(self):
        """Return the id assigned at construction time."""
        return self.request_id

@inject_dependencies
class OrderService:
    """Order service whose collaborators are supplied by ``inject_dependencies``."""

    def __init__(self, email_service: IEmailService, logger: ILogger,
                 request_context: RequestContext):
        self.request_context = request_context
        self.logger = logger
        self.email_service = email_service

    def create_order(self, customer_email: str):
        """Log the order creation and send a confirmation to ``customer_email``."""
        request_id = self.request_context.get_request_id()
        self.logger.log(f"创建订单 - 请求ID: {request_id}")
        self.email_service.send(customer_email, "订单确认", "您的订单已创建")

# Register the concrete implementations behind their interfaces.
AdvancedDIRegistry.register_service(ILogger, ConsoleLogger)
AdvancedDIRegistry.register_service(IEmailService, EmailService)

# Factory function example.
def create_special_logger():
    """Factory for ``ILogger``: returns a freshly constructed ``ConsoleLogger``."""
    return ConsoleLogger()

# NOTE(review): this factory calls ConsoleLogger() directly, so the object it
# returns does not come from the registry's cache for the @singleton-decorated
# ConsoleLogger. How this interacts with the ILogger registration above
# depends on register_factory/resolve, which are not visible here; confirm.
AdvancedDIRegistry.register_factory(ILogger, create_special_logger)

# 使用方法注入
@inject_method(ILogger, IEmailService)
def process_order(order_id: int, logger: ILogger, email_service: IEmailService):
    """Log that ``order_id`` is being processed and e-mail a status notice.

    ``logger`` and ``email_service`` may be omitted by the caller; the
    ``inject_method`` decorator supplies them.
    """
    logger.log(f"处理订单: {order_id}")
    recipient = "customer@example.com"
    email_service.send(recipient, "处理中", f"订单 {order_id} 正在处理")

# 测试
def demo_advanced_decorators():
    """Exercise the decorator-based registry and print what each scope does.

    Builds two ``OrderService`` objects through injection, reports whether
    they share a logger and whether they received distinct e-mail services,
    creates an order with each, runs ``process_order`` with injected
    arguments, and finally clears the request scope.
    """
    import uuid

    # NOTE(review): request_id is only handed to clear_request_scope below;
    # nothing visible here associates it with the services being resolved.
    # Confirm how the registry learns the current request id.
    request_id = str(uuid.uuid4())

    # Create the order services; all constructor arguments are injected.
    order_service1 = OrderService()
    order_service2 = OrderService()

    # Singleton check: both services should hold the same logger object.
    print(f"Logger是否为单例: {id(order_service1.logger) == id(order_service2.logger)}")

    # Transient check: the services are transient when the two instances
    # DIFFER, so compare with != (== printed False for a working transient).
    print(f"EmailService是否为瞬态: {id(order_service1.email_service) != id(order_service2.email_service)}")

    # Create orders.
    order_service1.create_order("customer1@example.com")
    order_service2.create_order("customer2@example.com")

    # Method injection: logger and email_service are filled in.
    process_order(12345)

    # Drop the instances cached for this request.
    AdvancedDIRegistry.clear_request_scope(request_id)

# demo_advanced_decorators()

6. 类型注解与依赖注入

基于类型注解的依赖注入

python 复制代码
# 基于类型注解的依赖注入系统
from typing import Dict, Any, Callable, Type, TypeVar, get_type_hints, get_origin, get_args
from functools import wraps
import inspect
from abc import ABC, abstractmethod

# Type variable tying an interface to the instance type produced for it.
T = TypeVar('T')


class TypedDIContainer:
    """Dependency-injection container driven by constructor type hints.

    Three registries are keyed by type: ready-made instances, zero-argument
    factories, and interface -> implementation mappings.  Objects produced
    for a registered factory or service are cached, so each registered
    interface yields one instance per container.  Constructor parameters
    hinted as ``str``/``int``/``float``/``bool``, or as a generic alias that
    is not registered, are treated as configuration values and looked up by
    parameter name.
    """

    def __init__(self):
        self._services: Dict[Type, Type] = {}
        self._instances: Dict[Type, Any] = {}
        self._factories: Dict[Type, Callable] = {}
        self._configurations: Dict[str, Any] = {}

    def register(self, interface: Type[T], implementation: Type[T] = None) -> 'TypedDIContainer':
        """Map ``interface`` to ``implementation`` (default: itself).

        Returns:
            This container, to allow call chaining.
        """
        if implementation is None:
            implementation = interface

        self._services[interface] = implementation
        return self

    def register_instance(self, interface: Type[T], instance: T) -> 'TypedDIContainer':
        """Use ``instance`` whenever ``interface`` is resolved; returns ``self``."""
        self._instances[interface] = instance
        return self

    def register_factory(self, interface: Type[T], factory: Callable[[], T]) -> 'TypedDIContainer':
        """Register a zero-argument ``factory`` for ``interface``; returns ``self``."""
        self._factories[interface] = factory
        return self

    def configure(self, key: str, value: Any) -> 'TypedDIContainer':
        """Store a configuration value under ``key``; returns ``self``.

        ``key`` is matched against constructor parameter names.
        """
        self._configurations[key] = value
        return self

    def resolve(self, interface: Type[T]) -> T:
        """Return the object registered or constructible for ``interface``.

        Lookup order: cached/registered instance, factory, registered
        service, and finally direct construction of ``interface`` itself.
        Factory and service results are cached; the direct-construction
        fallback is not.
        """
        if interface in self._instances:
            return self._instances[interface]

        if interface in self._factories:
            instance = self._factories[interface]()
            self._instances[interface] = instance
            return instance

        if interface in self._services:
            implementation = self._services[interface]
            instance = self._create_typed_instance(implementation)
            self._instances[interface] = instance
            return instance

        # Not registered anywhere: try to build the type directly.
        return self._create_typed_instance(interface)

    def _create_typed_instance(self, cls: Type[T]) -> T:
        """Construct ``cls``, supplying every type-hinted ``__init__`` parameter.

        Configuration-typed parameters come from ``configure`` values; all
        other hinted parameters are resolved recursively.  Unhinted
        parameters are not passed.
        """
        type_hints = get_type_hints(cls.__init__)
        signature = inspect.signature(cls.__init__)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            if param_name == 'self':
                continue
            # *args / **kwargs cannot be supplied by name.
            if param.kind in variadic:
                continue

            param_type = type_hints.get(param_name)
            if not param_type:
                continue

            if self._is_config_type(param_type):
                # Keep the constructor's declared default when nothing was
                # configured, instead of overwriting it with ""/0/False/None.
                if (param_name not in self._configurations
                        and param.default is not inspect.Parameter.empty):
                    continue
                kwargs[param_name] = self._resolve_config(param_name, param_type)
            else:
                kwargs[param_name] = self.resolve(param_type)

        return cls(**kwargs)

    def _is_registered(self, param_type: Any) -> bool:
        """Return True when ``param_type`` is a key of any registry."""
        try:
            return (param_type in self._instances
                    or param_type in self._factories
                    or param_type in self._services)
        except TypeError:
            # Unhashable hints can never have been used as registry keys.
            return False

    def _is_config_type(self, param_type: Type) -> bool:
        """Return True when ``param_type`` should come from configuration.

        Scalars are always configuration.  Generic aliases (anything with
        ``__origin__``) are configuration only when they are NOT registered:
        a registered alias such as ``Repository[User]`` must be resolved as
        a service, otherwise ``None`` would be injected in its place.
        """
        if param_type in (str, int, float, bool):
            return True
        if not hasattr(param_type, '__origin__'):
            return False
        return not self._is_registered(param_type)

    def _resolve_config(self, param_name: str, param_type: Type):
        """Return the configured value for ``param_name``.

        Falls back to a neutral value per scalar type (``""``, ``0``,
        ``0.0``, ``False``) and to ``None`` for any other type.
        """
        if param_name in self._configurations:
            return self._configurations[param_name]

        defaults = {
            str: "",
            int: 0,
            float: 0.0,
            bool: False
        }
        return defaults.get(param_type, None)

# 类型安全的装饰器
def typed_injectable(container: TypedDIContainer):
    """Return a class decorator that auto-wires constructors from ``container``.

    The decorated class can be created with no arguments at all; every
    type-hinted ``__init__`` parameter is then taken from ``container``
    (configuration lookup for configuration types, ``resolve`` otherwise).
    As soon as the caller passes any positional or keyword argument, the
    original ``__init__`` is called with exactly those arguments and
    nothing is injected.
    """
    def decorator(cls: Type[T]) -> Type[T]:
        wrapped_init = cls.__init__

        def collect_arguments():
            # Build the keyword arguments for a fully automatic construction.
            hints = get_type_hints(wrapped_init)
            injected = {}
            for name in inspect.signature(wrapped_init).parameters:
                if name == 'self':
                    continue
                hint = hints.get(name)
                if not hint:
                    continue
                if container._is_config_type(hint):
                    injected[name] = container._resolve_config(name, hint)
                else:
                    injected[name] = container.resolve(hint)
            return injected

        @wraps(wrapped_init)
        def new_init(self, *args, **kwargs):
            if args or kwargs:
                # Caller took control: pass the arguments through untouched.
                wrapped_init(self, *args, **kwargs)
            else:
                wrapped_init(self, **collect_arguments())

        cls.__init__ = new_init
        return cls

    return decorator

def typed_inject(container: TypedDIContainer):
    """Return a decorator that fills missing typed arguments from ``container``.

    Type hints and the signature are read once, when the function is
    decorated.  On each call, every type-hinted parameter the caller did not
    supply is obtained from ``container``: by configuration lookup when
    ``container._is_config_type`` says so, otherwise via
    ``container.resolve``.  ``*args``/``**kwargs`` parameters are never
    injected.

    Raises:
        TypeError: From the wrapper, when the caller's arguments do not fit
            the wrapped function's signature.
    """
    def decorator(func: Callable) -> Callable:
        type_hints = get_type_hints(func)
        signature = inspect.signature(func)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_args = signature.bind_partial(*args, **kwargs)

            for param_name, param in signature.parameters.items():
                if param_name in bound_args.arguments or param.kind in variadic:
                    continue

                param_type = type_hints.get(param_name)
                if not param_type:
                    continue

                if container._is_config_type(param_type):
                    value = container._resolve_config(param_name, param_type)
                else:
                    value = container.resolve(param_type)
                bound_args.arguments[param_name] = value

            # Re-expand through .args/.kwargs rather than **arguments:
            # positional-only parameters cannot be passed by keyword, and
            # *args/**kwargs are stored under their parameter names.
            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    return decorator

# 泛型支持
from typing import Generic, List, Optional

class Repository(Generic[T], ABC):
    """Generic repository interface parameterised by the entity type ``T``."""

    @abstractmethod
    def find_by_id(self, id: int) -> Optional[T]:
        """Return the entity with the given ``id``, or ``None``."""

    @abstractmethod
    def find_all(self) -> List[T]:
        """Return all stored entities."""

    @abstractmethod
    def save(self, entity: T) -> T:
        """Store ``entity`` and return the stored entity."""

class User:
    """Plain user record with an id, a name and an e-mail address."""

    def __init__(self, id: int, name: str, email: str):
        self.id, self.name, self.email = id, name, email

    def __repr__(self):
        return f"User(id={self.id}, name='{self.name}', email='{self.email}')"

class Product:
    """Plain product record with an id, a name and a price."""

    def __init__(self, id: int, name: str, price: float):
        self.id, self.name, self.price = id, name, price

    def __repr__(self):
        return f"Product(id={self.id}, name='{self.name}', price={self.price})"

class UserRepository(Repository[User]):
    """In-memory ``Repository[User]`` pre-populated with two sample users."""

    def __init__(self, database_url: str):
        # The URL is only stored; the data lives in the list below.
        self.database_url = database_url
        self._users = [
            User(1, "张三", "zhangsan@example.com"),
            User(2, "李四", "lisi@example.com")
        ]

    def find_by_id(self, id: int) -> Optional[User]:
        """Return the first user whose ``id`` matches, or ``None``."""
        for candidate in self._users:
            if candidate.id == id:
                return candidate
        return None

    def find_all(self) -> List[User]:
        """Return a new list containing every stored user."""
        return list(self._users)

    def save(self, entity: User) -> User:
        """Append ``entity`` to the store and return it."""
        self._users.append(entity)
        return entity

class ProductRepository(Repository[Product]):
    """In-memory ``Repository[Product]`` pre-populated with two sample products."""

    def __init__(self, database_url: str):
        # The URL is only stored; the data lives in the list below.
        self.database_url = database_url
        self._products = [
            Product(1, "笔记本电脑", 5999.99),
            Product(2, "智能手机", 2999.99)
        ]

    def find_by_id(self, id: int) -> Optional[Product]:
        """Return the first product whose ``id`` matches, or ``None``."""
        for candidate in self._products:
            if candidate.id == id:
                return candidate
        return None

    def find_all(self) -> List[Product]:
        """Return a new list containing every stored product."""
        return list(self._products)

    def save(self, entity: Product) -> Product:
        """Append ``entity`` to the store and return it."""
        self._products.append(entity)
        return entity

class EmailService:
    """E-mail service stub that prints messages instead of sending them."""

    def __init__(self, smtp_host: str, smtp_port: int):
        self.smtp_host, self.smtp_port = smtp_host, smtp_port

    def send_email(self, to: str, subject: str, body: str):
        """Print the recipient (with the SMTP endpoint), subject and body."""
        print(
            f"发送邮件到 {to} (通过 {self.smtp_host}:{self.smtp_port})",
            f"主题: {subject}",
            f"内容: {body}",
            sep="\n",
        )

class UserService:
    """User operations built on a user repository and an e-mail service."""

    def __init__(self,
                 user_repository: Repository[User],
                 email_service: EmailService):
        self.email_service = email_service
        self.user_repository = user_repository

    def get_user(self, user_id: int) -> Optional[User]:
        """Return the user with ``user_id`` from the repository, or ``None``."""
        found = self.user_repository.find_by_id(user_id)
        return found

    def create_user(self, name: str, email: str) -> User:
        """Create, store and welcome a new user.

        The new id is the current number of stored users plus one.  A
        welcome e-mail is sent after saving.

        Returns:
            Whatever the repository's ``save`` returns for the new user.
        """
        existing = self.user_repository.find_all()
        user = User(len(existing) + 1, name, email)
        saved_user = self.user_repository.save(user)

        # Send the welcome e-mail.
        welcome_body = f"欢迎 {user.name} 注册我们的服务!"
        self.email_service.send_email(user.email, "欢迎注册", welcome_body)

        return saved_user

# 配置容器
def setup_typed_container():
    """Create a ``TypedDIContainer`` wired for the demo services.

    Sets the database/SMTP configuration values, maps the two generic
    repository interfaces to their implementations, and registers
    ``EmailService`` and ``UserService`` as themselves.
    """
    container = TypedDIContainer()

    # Configuration values, matched to constructor parameters by name.
    settings = (
        ("database_url", "postgresql://localhost/mydb"),
        ("smtp_host", "smtp.example.com"),
        ("smtp_port", 587),
    )
    for key, value in settings:
        container.configure(key, value)

    # Generic repository interfaces -> concrete repositories.
    container.register(Repository[User], UserRepository)
    container.register(Repository[Product], ProductRepository)

    # Plain services, registered as their own implementation.
    for concrete in (EmailService, UserService):
        container.register(concrete)

    return container

# Module-level container used by the type-safe decorators below.
container = setup_typed_container()

@typed_injectable(container)
class OrderService:
    """Order service whose collaborators are injected from ``container``."""

    def __init__(self,
                 user_repository: Repository[User],
                 product_repository: Repository[Product],
                 email_service: EmailService):
        self.email_service = email_service
        self.product_repository = product_repository
        self.user_repository = user_repository

    def create_order(self, user_id: int, product_id: int):
        """Create an order for an existing user and product.

        Sends a confirmation e-mail to the user and returns a dict with the
        keys ``order_id``, ``user`` and ``product``.

        Raises:
            ValueError: If the user or the product lookup returns a falsy
                value.
        """
        user = self.user_repository.find_by_id(user_id)
        product = self.product_repository.find_by_id(product_id)

        if not (user and product):
            raise ValueError("用户或产品不存在")

        order_id = f"ORD-{user_id}-{product_id}"

        confirmation = f"您的订单 {order_id} 已创建,产品: {product.name}"
        self.email_service.send_email(user.email, "订单确认", confirmation)

        order = {
            "order_id": order_id,
            "user": user,
            "product": product
        }
        return order

@typed_inject(container)
def get_user_info(user_id: int, user_service: UserService) -> Optional[User]:
    """Look up a user by id; ``user_service`` is injected when omitted."""
    found = user_service.get_user(user_id)
    return found

# 测试类型化依赖注入
def demo_typed_di():
    """Walk through the typed container: resolve, decorate, inject.

    Uses the module-level ``container`` and prints the result of each step.
    """

    # Resolve a service directly from the container.
    user_service = container.resolve(UserService)

    # Create a user (this also sends the welcome e-mail).
    new_user = user_service.create_user("王五", "wangwu@example.com")
    print(f"创建用户: {new_user}")

    # Fetch an existing user.
    user = user_service.get_user(1)
    print(f"获取用户: {user}")

    # Class wired by @typed_injectable: no constructor arguments needed.
    order_service = OrderService()
    order = order_service.create_order(1, 1)
    print(f"创建订单: {order}")

    # Function wired by @typed_inject: user_service is filled in.
    user_info = get_user_info(2)
    print(f"函数获取用户: {user_info}")

# demo_typed_di()
相关推荐
Swizard16 分钟前
拒绝“狗熊掰棒子”!用 EWC (Elastic Weight Consolidation) 彻底终结 AI 的灾难性遗忘
python·算法·ai·训练
Spider赵毅18 分钟前
python实战 | 如何使用海外代理IP抓取Amazon黑五数据
python·tcp/ip·php
月光技术杂谈23 分钟前
基于Python的网络性能分析实践:从Ping原理到自动化监控
网络·python·性能分析·ping·时延·自动化监控
龘龍龙27 分钟前
Python基础学习(四)
开发语言·python·学习
洵有兮1 小时前
python第四次作业
开发语言·python
kkoral1 小时前
单机docker部署的redis sentinel,使用python调用redis,报错
redis·python·docker·sentinel
BoBoZz191 小时前
IterativeClosestPoints icp配准矩阵
python·vtk·图形渲染·图形处理
test管家2 小时前
PyTorch动态图编程与自定义网络层实战教程
python
laocooon5238578862 小时前
python 收发信的功能。
开发语言·python
清水白石0082 小时前
《Python 责任链模式实战指南:从设计思想到工程落地》
开发语言·python·责任链模式