Python 依赖注入详解

目录

  1. 依赖注入基础概念
  2. 手动依赖注入
  3. 简单依赖注入容器
  4. 高级依赖注入框架
  5. 装饰器实现依赖注入
  6. 类型注解与依赖注入
  7. 生命周期管理
  8. 实际应用场景

1. 依赖注入基础概念

什么是依赖注入?

依赖注入(Dependency Injection,DI)是一种实现控制反转(IoC)的设计模式:对象所依赖的其他对象不在对象内部创建,而是在运行时由外部提供——既可以由调用方手动组装(见第 2 节),也可以交给依赖注入容器自动完成(见第 3、4 节)。

python 复制代码
# Traditional approach: the dependency is hard-coded inside the class.
class EmailService:
    """Concrete e-mail sender used by both examples below."""

    def send(self, message):
        """Print *message* as if it had been e-mailed."""
        text = f"发送邮件: {message}"
        print(text)


class UserService:
    """User service that builds its own EmailService (tight coupling)."""

    def __init__(self):
        # The collaborator is created here, so tests cannot substitute it and
        # another implementation cannot be plugged in without editing the class.
        self.email_service = EmailService()

    def register_user(self, user):
        """Register *user*, then send a welcome message."""
        print(f"注册用户: {user}")
        greeting = f"欢迎 {user}"
        self.email_service.send(greeting)


# Dependency-injection approach: the collaborator is supplied from outside.
class UserServiceWithDI:
    """User service that receives its e-mail sender through the constructor."""

    def __init__(self, email_service):
        # Constructor injection: the caller decides which sender is used.
        self.email_service = email_service

    def register_user(self, user):
        """Register *user*, then send a welcome message via the injected sender."""
        print(f"注册用户: {user}")
        greeting = f"欢迎 {user}"
        self.email_service.send(greeting)


# Usage: wire the dependency by hand, then call the service.
email_service = EmailService()
user_service = UserServiceWithDI(email_service)
user_service.register_user("张三")

依赖注入的优势

python 复制代码
"""
依赖注入的优势:

1. 降低耦合度:组件之间松耦合
2. 提高可测试性:容易进行单元测试
3. 增强可扩展性:容易替换实现
4. 符合开闭原则:对扩展开放,对修改封闭
5. 单一职责:每个类专注于自己的职责
"""

# Demonstration of the benefits
from abc import ABC, abstractmethod


class NotificationService(ABC):
    """Abstract notification channel (the interface OrderService depends on)."""

    @abstractmethod
    def send(self, message: str) -> None:
        """Deliver *message* through this channel."""


class EmailNotificationService(NotificationService):
    """Channel that 'sends' by printing an e-mail notice."""

    def send(self, message: str) -> None:
        print(f"📧 邮件通知: {message}")


class SMSNotificationService(NotificationService):
    """Channel that 'sends' by printing an SMS notice."""

    def send(self, message: str) -> None:
        print(f"📱 短信通知: {message}")


class PushNotificationService(NotificationService):
    """Channel that 'sends' by printing a push notice."""

    def send(self, message: str) -> None:
        print(f"🔔 推送通知: {message}")


class OrderService:
    """Business service; only knows the NotificationService interface."""

    def __init__(self, notification_service: NotificationService):
        self.notification_service = notification_service

    def create_order(self, order_id: str):
        """Create order *order_id* and notify through the injected channel."""
        print(f"创建订单: {order_id}")
        notice = f"订单 {order_id} 创建成功"
        self.notification_service.send(notice)


# One instance of each channel implementation.
email_service = EmailNotificationService()
sms_service = SMSNotificationService()
push_service = PushNotificationService()

# Swapping the channel only changes what is passed in.
order_service_email = OrderService(email_service)
order_service_sms = OrderService(sms_service)
order_service_push = OrderService(push_service)

order_service_email.create_order("ORD001")
order_service_sms.create_order("ORD002")
order_service_push.create_order("ORD003")

2. 手动依赖注入

构造函数注入

python 复制代码
# Constructor-injection example: every collaborator arrives through __init__.
class DatabaseConnection:
    """Fake database connection that prints the queries it receives."""

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        print(f"连接数据库: {host}:{port}")

    def execute(self, query: str):
        """Print *query* and return a placeholder result string."""
        print(f"执行查询: {query}")
        return f"查询结果 for {query}"


class CacheService:
    """Fake key-value cache; ``get`` always returns a non-empty value."""

    def __init__(self, host: str, port: int):
        self.host, self.port = host, port
        print(f"连接缓存: {host}:{port}")

    def get(self, key: str):
        """Return a placeholder value for *key*."""
        return f"缓存值 for {key}"

    def set(self, key: str, value: str):
        """Pretend to store *value* under *key* (prints only)."""
        print(f"设置缓存: {key} = {value}")


class UserRepository:
    """Data-access layer for users, built on an injected DatabaseConnection."""

    def __init__(self, db_connection: DatabaseConnection):
        self.db = db_connection

    def find_by_id(self, user_id: int):
        """Run a lookup query and return a synthetic user record."""
        # The query's return value is not used; the record below is synthesized.
        # NOTE: illustration only - real code should use a parameterized query.
        self.db.execute(f"SELECT * FROM users WHERE id = {user_id}")
        return {"id": user_id, "name": f"用户{user_id}"}

    def save(self, user: dict):
        """Issue an INSERT for *user* and report it."""
        self.db.execute(f"INSERT INTO users VALUES {user}")
        print(f"保存用户: {user}")


class UserService:
    """User lookup that consults the cache before the repository."""

    def __init__(self, user_repository: UserRepository, cache_service: CacheService):
        self.user_repository = user_repository
        self.cache = cache_service

    def get_user(self, user_id: int):
        """Return the cached entry for *user_id*, or load it and cache it."""
        key = f"user:{user_id}"
        hit = self.cache.get(key)
        if not hit:
            fresh = self.user_repository.find_by_id(user_id)
            self.cache.set(key, str(fresh))
            return fresh
        print(f"从缓存获取用户: {hit}")
        return hit


# Manual wiring
def create_user_service():
    """Factory: build the object graph bottom-up and return the UserService."""
    # Infrastructure first (the database message prints before the cache one)...
    db = DatabaseConnection("localhost", 5432)
    cache = CacheService("localhost", 6379)
    # ...then the layers that depend on it.
    return UserService(UserRepository(db), cache)


# Usage
user_service = create_user_service()
user = user_service.get_user(1)
print(f"获取到用户: {user}")

属性注入

python 复制代码
# Setter (property) injection example.
class Logger:
    """Minimal logger writing tagged lines to stdout."""

    def log(self, message: str):
        print(f"[LOG] {message}")


class ConfigService:
    """Read-only configuration lookup backed by a fixed table."""

    def get(self, key: str):
        """Return the value for *key*, or None if the key is unknown."""
        settings = dict(api_key="secret_key_123", timeout=30, retries=3)
        return settings.get(key)


class ApiClient:
    """API client whose logger and config are injected through setters.

    Both dependencies are optional: without a logger nothing is logged, and
    without a config the built-in fallbacks are used.
    """

    def __init__(self):
        self.logger = None
        self.config = None

    def set_logger(self, logger: Logger):
        """Setter injection for the logger; returns self so calls can chain."""
        self.logger = logger
        return self

    def set_config(self, config: ConfigService):
        """Setter injection for the config; returns self so calls can chain."""
        self.config = config
        return self

    def call_api(self, endpoint: str):
        """Simulate a call to *endpoint* and return a fixed success payload."""
        if self.logger:
            self.logger.log(f"调用API: {endpoint}")

        if self.config:
            api_key = self.config.get("api_key")
            timeout = self.config.get("timeout")
        else:
            api_key, timeout = "default_key", 10

        print(f"API调用: {endpoint}, key: {api_key}, timeout: {timeout}")
        return {"status": "success", "data": "api_response"}


# Inject the dependencies with chained setter calls.
logger = Logger()
config = ConfigService()

api_client = ApiClient().set_logger(logger).set_config(config)

result = api_client.call_api("/api/users")
print(f"API结果: {result}")

方法注入

python 复制代码
# Method-injection example: dependencies can be passed per call.
class EmailTemplate:
    """Trivial template renderer."""

    def render(self, template_name: str, context: dict):
        return f"模板 {template_name}: {context}"


class EmailSender:
    """Prints an e-mail instead of sending it."""

    def send(self, to: str, subject: str, body: str):
        print(f"发送邮件到 {to}: {subject}")
        print(f"内容: {body}")


class NotificationManager:
    """Render and send notifications using per-call or default dependencies.

    Args:
        default_template: Renderer used when a call passes no
            ``template_service``. Defaults to None, meaning the renderer must
            be passed per call.
        default_sender: Sender used when a call passes no ``sender_service``.
            Defaults to None.
    """

    def __init__(self, default_template: EmailTemplate = None,
                 default_sender: EmailSender = None):
        # Optional defaults make the fallback path in send_notification
        # reachable without assigning attributes from outside.
        self.default_template = default_template
        self.default_sender = default_sender

    def send_notification(self, to: str, template_name: str, context: dict,
                          template_service: EmailTemplate = None,
                          sender_service: EmailSender = None):
        """Method injection: render *template_name* with *context* and send it to *to*.

        Services passed as arguments take precedence over the constructor
        defaults.

        Raises:
            ValueError: if no template service or no sender is available.
        """
        # Compare with None instead of relying on truthiness, so an injected
        # service is never silently replaced just because it is falsy.
        if template_service is None:
            template_service = self.default_template
        if sender_service is None:
            sender_service = self.default_sender

        if template_service is None or sender_service is None:
            raise ValueError("缺少必要的依赖服务")

        body = template_service.render(template_name, context)
        sender_service.send(to, f"通知: {template_name}", body)


# Usage
template_service = EmailTemplate()
sender_service = EmailSender()
notification_manager = NotificationManager()

# Inject the dependencies at method level.
notification_manager.send_notification(
    to="user@example.com",
    template_name="welcome",
    context={"username": "张三"},
    template_service=template_service,
    sender_service=sender_service
)

3. 简单依赖注入容器

基础容器实现

python 复制代码
# 简单依赖注入容器
from typing import Dict, Any, Callable, TypeVar, Type
import inspect

T = TypeVar('T')


class DIContainer:
    """Simple dependency-injection container.

    Resolution order in :meth:`resolve`:
      1. a cached singleton or an instance registered with ``register_instance``;
      2. a registered factory (its result is cached and reused afterwards);
      3. a non-singleton registration (a fresh instance on every call);
      4. otherwise the requested type itself is built by auto-wiring.

    Auto-wiring inspects ``__init__`` annotations. Annotated parameters without
    a default are always resolved. Annotated parameters *with* a default are
    only resolved when their type is explicitly registered; otherwise the
    declared default is kept.
    """

    def __init__(self):
        # Every map is keyed by the interface type itself. Keying by
        # ``__name__`` would let two unrelated classes with the same name
        # (e.g. two different ``UserService`` classes) overwrite each other.
        self._services: Dict[type, Any] = {}        # non-singleton: interface -> implementation class
        self._factories: Dict[type, Callable] = {}  # interface -> zero-arg builder (result cached)
        self._singletons: Dict[type, Any] = {}      # interface -> already-built instance
        self._resolving: list = []                  # classes currently under construction (cycle check)

    def register(self, interface: Type[T], implementation: Type[T] = None, singleton: bool = True):
        """Register *implementation* (default: *interface*) under *interface*.

        Args:
            interface: Type used as the lookup key.
            implementation: Concrete class to build; defaults to *interface*.
            singleton: If True, build once on first resolve and reuse the
                instance; otherwise build a new instance on every resolve.

        Returns:
            The container, so calls can be chained.
        """
        if implementation is None:
            implementation = interface

        # Replace any earlier registration or cached instance so the new
        # mapping takes effect even after the interface was resolved.
        self._forget(interface)

        if singleton:
            self._factories[interface] = lambda: self._create_instance(implementation)
        else:
            self._services[interface] = implementation

        return self

    def register_instance(self, interface: Type[T], instance: T):
        """Register a ready-made *instance* for *interface*; returns the container."""
        self._forget(interface)
        self._singletons[interface] = instance
        return self

    def register_factory(self, interface: Type[T], factory: Callable[[], T]):
        """Register a zero-arg *factory*; its first result is cached. Returns the container."""
        self._forget(interface)
        self._factories[interface] = factory
        return self

    def resolve(self, interface: Type[T]) -> T:
        """Return an instance for *interface* (see the class docstring for the order).

        Raises:
            RuntimeError: if constructing it requires itself (circular dependency).
        """
        if interface in self._singletons:
            return self._singletons[interface]

        if interface in self._factories:
            instance = self._factories[interface]()
            self._singletons[interface] = instance  # cache as singleton
            return instance

        if interface in self._services:
            return self._create_instance(self._services[interface])

        # Not registered: try to build the type itself.
        return self._create_instance(interface)

    def _is_registered(self, interface) -> bool:
        """Return True if *interface* has any explicit registration or cached instance."""
        return (interface in self._singletons
                or interface in self._factories
                or interface in self._services)

    def _forget(self, interface) -> None:
        """Drop every registration and cached instance for *interface*."""
        for table in (self._singletons, self._factories, self._services):
            table.pop(interface, None)

    def _create_instance(self, cls: Type[T]) -> T:
        """Instantiate *cls*, resolving its annotated constructor parameters.

        Raises:
            RuntimeError: on a circular dependency, naming the chain.
        """
        if cls in self._resolving:
            chain = " -> ".join(getattr(c, "__name__", repr(c))
                                for c in [*self._resolving, cls])
            raise RuntimeError(f"检测到循环依赖: {chain}")

        self._resolving.append(cls)
        try:
            kwargs = {}
            for param_name, param in inspect.signature(cls.__init__).parameters.items():
                if param_name == 'self':
                    continue
                # *args / **kwargs collect extras; there is no single type to inject.
                if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                    continue

                param_type = param.annotation
                if param_type is inspect.Parameter.empty:
                    continue

                # Keep an explicit default unless the type was registered;
                # otherwise e.g. `retries: int = 3` would receive int() == 0.
                if (param.default is not inspect.Parameter.empty
                        and not self._is_registered(param_type)):
                    continue

                kwargs[param_name] = self.resolve(param_type)

            return cls(**kwargs)
        finally:
            self._resolving.pop()

# 使用示例
class ILogger:
    """Logger interface; the base implementation discards the message."""

    def log(self, message: str):
        return None


class ConsoleLogger(ILogger):
    """ILogger that prints to stdout with a ``[Console]`` prefix."""

    def log(self, message: str):
        print(f"[Console] {message}")


class IDatabase:
    """Database interface; the base implementation returns None."""

    def query(self, sql: str):
        return None


class MySQLDatabase(IDatabase):
    """Simulated MySQL database that logs every statement it runs."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        logger.log("MySQL数据库已连接")

    def query(self, sql: str):
        """Log *sql* and return a placeholder result string."""
        self.logger.log(f"执行SQL: {sql}")
        return f"结果: {sql}"

class UserService:
    """Application service that creates users through injected collaborators.

    Args:
        database: Object providing ``query(sql)``; resolved by the container.
        logger: Object providing ``log(message)``.
    """

    def __init__(self, database: IDatabase, logger: ILogger):
        self.database = database
        self.logger = logger

    def create_user(self, username: str):
        """Insert *username* and return whatever ``database.query`` returns."""
        self.logger.log(f"创建用户: {username}")
        # IDatabase.query only accepts a complete SQL string, so parameter
        # binding is unavailable here. Double any single quotes so a name such
        # as "O'Brien" cannot close the literal early (SQL injection). A real
        # driver should use parameterized queries instead.
        escaped = username.replace("'", "''")
        result = self.database.query(f"INSERT INTO users (name) VALUES ('{escaped}')")
        return result

# Configure the container: map each interface to its implementation.
# register() returns the container, so the calls can be chained.
container = DIContainer()
container.register(ILogger, ConsoleLogger).register(IDatabase, MySQLDatabase).register(UserService)

# Resolve the top-level service; its constructor dependencies are auto-wired.
user_service = container.resolve(UserService)
result = user_service.create_user("张三")
print(f"创建结果: {result}")

高级容器功能

python 复制代码
# 高级依赖注入容器
from enum import Enum
from typing import Dict, Any, Callable, TypeVar, Type, Optional
import inspect
import threading

class Lifetime(Enum):
    """Service lifetime policies supported by AdvancedDIContainer."""
    TRANSIENT = "transient"    # a new instance on every resolve
    SINGLETON = "singleton"    # one instance per container
    SCOPED = "scoped"          # one instance per scope id


class ServiceDescriptor:
    """Registration record: how a service is built and how long it lives.

    ``AdvancedDIContainer._create_instance`` uses the first available
    strategy in this order: a fixed ``instance``, then ``factory()``, then
    auto-wiring of ``implementation`` (which defaults to ``interface``).
    """
    def __init__(self,
                 interface: Type,
                 implementation: Type = None,
                 factory: Callable = None,
                 instance: Any = None,
                 lifetime: Lifetime = Lifetime.SINGLETON):
        self.interface = interface
        self.implementation = implementation or interface
        self.factory = factory
        self.instance = instance
        self.lifetime = lifetime


class AdvancedDIContainer:
    """DI container with transient, singleton and scoped lifetimes.

    Constructor dependencies are discovered from ``__init__`` annotations and
    resolved recursively. Every dependency type must be registered, except
    that an annotated parameter with a default keeps that default when its
    type is not registered.
    """

    def __init__(self):
        self._descriptors: Dict[Type, ServiceDescriptor] = {}
        self._singletons: Dict[Type, Any] = {}
        self._scoped_instances: Dict[str, Dict[Type, Any]] = {}
        # Re-entrant on purpose: building a singleton may resolve another
        # not-yet-built singleton on the same thread, which re-enters
        # _get_singleton. A plain Lock deadlocks in that case.
        self._lock = threading.RLock()
        # Scope entered via `with container.create_scope():`, if any.
        self._current_scope: Optional[str] = None

    def register_transient(self, interface: Type, implementation: Type = None):
        """Register a service built anew on every resolve; returns the container."""
        descriptor = ServiceDescriptor(interface, implementation, lifetime=Lifetime.TRANSIENT)
        self._descriptors[interface] = descriptor
        return self

    def register_singleton(self, interface: Type, implementation: Type = None):
        """Register a service built once per container; returns the container."""
        descriptor = ServiceDescriptor(interface, implementation, lifetime=Lifetime.SINGLETON)
        self._descriptors[interface] = descriptor
        return self

    def register_scoped(self, interface: Type, implementation: Type = None):
        """Register a service built once per scope id; returns the container."""
        descriptor = ServiceDescriptor(interface, implementation, lifetime=Lifetime.SCOPED)
        self._descriptors[interface] = descriptor
        return self

    def register_instance(self, interface: Type, instance: Any):
        """Register a ready-made singleton *instance*; returns the container."""
        descriptor = ServiceDescriptor(interface, instance=instance, lifetime=Lifetime.SINGLETON)
        self._descriptors[interface] = descriptor
        self._singletons[interface] = instance
        return self

    def register_factory(self, interface: Type, factory: Callable, lifetime: Lifetime = Lifetime.SINGLETON):
        """Register a zero-arg *factory* with the given *lifetime*; returns the container."""
        descriptor = ServiceDescriptor(interface, factory=factory, lifetime=lifetime)
        self._descriptors[interface] = descriptor
        return self

    def create_scope(self, scope_id: str = None):
        """Return a DIScope for *scope_id* (a random UUID string when omitted)."""
        if scope_id is None:
            import uuid
            scope_id = str(uuid.uuid4())

        return DIScope(self, scope_id)

    def resolve(self, interface: Type, scope_id: str = None):
        """Resolve *interface* according to its registered lifetime.

        Args:
            interface: A registered service type.
            scope_id: Scope used for scoped services. It is also passed on to
                the dependencies of scoped and transient services. Falls back
                to the scope entered with ``with container.create_scope():``.

        Raises:
            ValueError: if *interface* is not registered, or if a scoped
                service is resolved with no scope available.
        """
        if interface not in self._descriptors:
            raise ValueError(f"服务 {interface.__name__} 未注册")

        descriptor = self._descriptors[interface]

        if descriptor.lifetime == Lifetime.SINGLETON:
            return self._get_singleton(interface, descriptor)
        elif descriptor.lifetime == Lifetime.SCOPED:
            return self._get_scoped(interface, descriptor, scope_id)
        else:  # TRANSIENT
            return self._create_instance(descriptor, scope_id)

    def _get_singleton(self, interface: Type, descriptor: ServiceDescriptor):
        """Return the singleton for *interface*, building it at most once."""
        if interface in self._singletons:
            return self._singletons[interface]

        with self._lock:
            # Re-check: another thread may have built it while we waited.
            if interface not in self._singletons:
                # No scope id is passed on: a singleton must not capture one
                # particular scope's instances explicitly.
                self._singletons[interface] = self._create_instance(descriptor)

        return self._singletons[interface]

    def _get_scoped(self, interface: Type, descriptor: ServiceDescriptor, scope_id: str):
        """Return the instance of *interface* cached for the effective scope."""
        if not scope_id:
            scope_id = self._current_scope
            if not scope_id:
                raise ValueError("作用域服务需要在作用域内解析")

        scoped_dict = self._scoped_instances.setdefault(scope_id, {})

        if interface not in scoped_dict:
            # Pass the scope on so scoped dependencies share this scope.
            scoped_dict[interface] = self._create_instance(descriptor, scope_id)

        return scoped_dict[interface]

    def _create_instance(self, descriptor: ServiceDescriptor, scope_id: str = None):
        """Build a service from *descriptor* (instance, then factory, then auto-wiring)."""
        if descriptor.instance is not None:
            return descriptor.instance

        if descriptor.factory is not None:
            return descriptor.factory()

        return self._auto_resolve(descriptor.implementation, scope_id)

    def _auto_resolve(self, cls: Type, scope_id: str = None):
        """Instantiate *cls*, resolving its annotated constructor parameters."""
        signature = inspect.signature(cls.__init__)
        kwargs = {}

        for param_name, param in signature.parameters.items():
            if param_name == 'self':
                continue
            # *args / **kwargs collect extras; there is no single type to inject.
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue

            param_type = param.annotation
            if param_type is inspect.Parameter.empty:
                continue

            # Keep a declared default for unregistered types instead of
            # failing with "未注册" (e.g. `timeout: int = 30`).
            if (param.default is not inspect.Parameter.empty
                    and param_type not in self._descriptors):
                continue

            kwargs[param_name] = self.resolve(param_type, scope_id)

        return cls(**kwargs)

    def dispose_scope(self, scope_id: str):
        """Forget *scope_id* and call ``dispose()`` on its instances that define it."""
        # Remove the scope first so it is released even if a dispose() raises.
        scoped_dict = self._scoped_instances.pop(scope_id, None)
        if scoped_dict is None:
            return

        for instance in scoped_dict.values():
            if hasattr(instance, 'dispose'):
                instance.dispose()


class DIScope:
    """Context manager representing one dependency-injection scope."""

    def __init__(self, container: AdvancedDIContainer, scope_id: str):
        self.container = container
        self.scope_id = scope_id
        self._previous_scope: Optional[str] = None

    def __enter__(self):
        # Remember the enclosing scope so nested `with` blocks restore it.
        self._previous_scope = self.container._current_scope
        self.container._current_scope = self.scope_id
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.container._current_scope = self._previous_scope
        self.container.dispose_scope(self.scope_id)

    def resolve(self, interface: Type):
        """Resolve *interface* within this scope."""
        return self.container.resolve(interface, self.scope_id)

# Usage example
class IRepository:
    """Repository interface; the base implementation does nothing."""

    def save(self, data):
        pass


class DatabaseRepository(IRepository):
    """Repository that logs through an injected ILogger and supports dispose()."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        logger.log("DatabaseRepository 已创建")

    def save(self, data):
        """Log that *data* is being saved."""
        self.logger.log(f"保存数据: {data}")

    def dispose(self):
        """Release resources; called by dispose_scope when the scope ends."""
        print("DatabaseRepository 资源已释放")


class TransientService:
    """Dependency-free service that prints its identity when constructed."""

    def __init__(self):
        print(f"TransientService 实例已创建: {id(self)}")

# Configure the advanced container with one service per lifetime.
container = AdvancedDIContainer()
(container
 .register_singleton(ILogger, ConsoleLogger)
 .register_scoped(IRepository, DatabaseRepository)
 .register_transient(TransientService))

# Singleton: both resolves return the same object.
logger1 = container.resolve(ILogger)
logger2 = container.resolve(ILogger)
print(f"单例测试: {logger1 is logger2}")

# Transient: every resolve builds a new object.
trans1 = container.resolve(TransientService)
trans2 = container.resolve(TransientService)
print(f"瞬态测试: {trans1 is trans2}")

# Scoped: shared inside one scope, distinct across scopes.
with container.create_scope() as scope1:
    repo1 = scope1.resolve(IRepository)
    repo2 = scope1.resolve(IRepository)
    print(f"作用域内相同: {repo1 is repo2}")

with container.create_scope() as scope2:
    repo3 = scope2.resolve(IRepository)
    print(f"不同作用域: {repo1 is repo3}")

4. 高级依赖注入框架

dependency-injector 框架

python 复制代码
# 使用 dependency-injector 框架
# pip install dependency-injector

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
import logging

# Configuration and services used by the dependency-injector example.
class Config:
    """Static application settings."""

    def __init__(self):
        self.database_url = "postgresql://localhost/mydb"
        self.redis_url = "redis://localhost:6379"
        self.log_level = logging.INFO


class Database:
    """Fake database that announces its connection URL."""

    def __init__(self, url: str):
        self.url = url
        print(f"数据库连接: {url}")

    def get_connection(self):
        """Return a descriptive placeholder for a connection."""
        return f"Connection to {self.url}"


class Cache:
    """Fake cache; ``get`` always returns a non-empty value."""

    def __init__(self, url: str):
        self.url = url
        print(f"缓存连接: {url}")

    def get(self, key: str):
        return f"Cached value for {key}"


class Logger:
    """Prints INFO messages; the level is only reported, not used for filtering."""

    def __init__(self, level: int):
        self.level = level
        print(f"日志级别: {level}")

    def info(self, message: str):
        print(f"[INFO] {message}")


class UserRepository:
    """Builds user records using an injected Database and Logger."""

    def __init__(self, database: Database, logger: Logger):
        self.database = database
        self.logger = logger

    def find_user(self, user_id: int):
        """Return a synthetic record for *user_id*, including the connection string."""
        self.logger.info(f"查找用户: {user_id}")
        connection = self.database.get_connection()
        return {"id": user_id, "name": f"用户{user_id}", "connection": connection}


class UserService:
    """User lookup that consults the cache before the repository."""

    def __init__(self, repository: UserRepository, cache: Cache, logger: Logger):
        self.repository = repository
        self.cache = cache
        self.logger = logger

    def get_user(self, user_id: int):
        """Return the cached value for *user_id*, else the repository record."""
        self.logger.info(f"获取用户服务调用: {user_id}")

        cached = self.cache.get(f"user:{user_id}")
        if not cached:
            record = self.repository.find_user(user_id)
            self.logger.info("从数据库返回")
            return record

        self.logger.info("从缓存返回")
        return cached

# Container configuration
class Container(containers.DeclarativeContainer):
    """Declarative dependency-injector container for the example services.

    Providers are class attributes, and later providers reference earlier
    ones, so attribute order matters. Per dependency-injector's documentation,
    ``providers.Singleton`` hands out one shared instance and
    ``providers.Factory`` builds a new instance on every call.
    """

    # Configuration provider (one shared Config object).
    config = providers.Singleton(Config)

    # Infrastructure providers. `config.provided.<attr>` reads the attribute
    # from the Config instance when the provider is called
    # (dependency-injector "provided" attribute injection).
    database = providers.Singleton(
        Database,
        url=config.provided.database_url
    )

    cache = providers.Singleton(
        Cache,
        url=config.provided.redis_url
    )

    logger = providers.Singleton(
        Logger,
        level=config.provided.log_level
    )

    # Business-service providers: a new repository/service per call, built
    # on top of the shared singletons above.
    user_repository = providers.Factory(
        UserRepository,
        database=database,
        logger=logger
    )

    user_service = providers.Factory(
        UserService,
        repository=user_repository,
        cache=cache,
        logger=logger
    )

# Injection with the @inject decorator and Provide[...] markers.
class UserController:
    """Controller whose collaborators are injected after the module is wired."""

    @inject
    def __init__(
        self,
        user_service: UserService = Provide[Container.user_service],
        logger: Logger = Provide[Container.logger],
    ):
        self.user_service, self.logger = user_service, logger

    def get_user_endpoint(self, user_id: int):
        """Log the call, then delegate to the injected user service."""
        self.logger.info(f"API调用: 获取用户 {user_id}")
        service = self.user_service
        return service.get_user(user_id)


# Function-level injection
@inject
def get_user_function(user_id: int,
                      user_service: UserService = Provide[Container.user_service]):
    """Look a user up through the injected UserService."""
    service = user_service
    return service.get_user(user_id)

# Usage
def demo_dependency_injector():
    """Run the dependency-injector example end to end, printing each result."""
    container = Container()

    # Wire this module so the @inject / Provide markers use `container`.
    container.wire(modules=[__name__])

    # 1) Pull a service straight from the container.
    direct = container.user_service().get_user(1)
    print(f"直接获取: {direct}")

    # 2) Let @inject populate the controller's constructor.
    via_controller = UserController().get_user_endpoint(2)
    print(f"控制器获取: {via_controller}")

    # 3) Let @inject populate a plain function's argument.
    via_function = get_user_function(3)
    print(f"函数获取: {via_function}")

# demo_dependency_injector()

自定义装饰器框架

python 复制代码
# 自定义装饰器依赖注入框架
from typing import Dict, Any, Callable, TypeVar, get_type_hints
from functools import wraps
import inspect

class Injectable:
    """Marker class for injectable types (ServiceRegistry does not inspect it)."""
    pass


class ServiceRegistry:
    """Process-wide service registry; all state lives on the class.

    * ``register(..., singleton=True)``: the implementation is built on the
      first ``get`` and that instance is cached and reused.
    * ``register(..., singleton=False)``: every ``get`` builds a new instance.
    * Unregistered classes are built on the fly by auto-wiring their
      ``__init__`` type hints; non-class objects are returned unchanged.
    """
    _instance = None
    _services: Dict[type, Any] = {}        # interface -> cached singleton instance
    _factories: Dict[type, Callable] = {}  # interface -> zero-arg callable returning the implementation
    _transients: Dict[type, type] = {}     # interface -> implementation rebuilt on every get()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def register(cls, interface: type, implementation: type = None, singleton: bool = True):
        """Register *implementation* (default: *interface*) under *interface*.

        Any earlier registration, including an already-built singleton, is
        discarded so the new mapping takes effect immediately.
        """
        if implementation is None:
            implementation = interface

        cls._services.pop(interface, None)
        cls._factories.pop(interface, None)
        cls._transients.pop(interface, None)

        if singleton:
            cls._factories[interface] = lambda: implementation
        else:
            # Kept apart from _services, which holds *instances*. Storing the
            # class there made get() return the class itself.
            cls._transients[interface] = implementation

    @classmethod
    def get(cls, interface: type):
        """Return an instance for *interface* (see the class docstring)."""
        if interface in cls._services:
            return cls._services[interface]

        if interface in cls._transients:
            return cls._auto_wire(cls._transients[interface])

        if interface in cls._factories:
            factory = cls._factories[interface]
            instance = cls._auto_wire(factory())
            cls._services[interface] = instance
            return instance

        # Not registered: build it directly.
        return cls._auto_wire(interface)

    @classmethod
    def _auto_wire(cls, target):
        """Instantiate *target* if it is a class; otherwise return it as is."""
        if inspect.isclass(target):
            return cls._create_instance(target)
        return target

    @classmethod
    def _create_instance(cls, target_class):
        """Instantiate *target_class*, resolving type-hinted ``__init__`` params."""
        signature = inspect.signature(target_class.__init__)
        type_hints = get_type_hints(target_class.__init__)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            if param_name == 'self':
                continue
            # *args / **kwargs collect extras; there is no single type to inject.
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue

            param_type = type_hints.get(param_name)
            if param_type:
                kwargs[param_name] = cls.get(param_type)

        return target_class(**kwargs)

def injectable(cls):
    """Class decorator: a no-argument construction auto-injects dependencies.

    When the class is instantiated with no positional and no keyword
    arguments, every ``__init__`` parameter that has a type hint is filled
    from ``ServiceRegistry.get``. Otherwise the arguments pass through untouched.
    """
    wrapped_init = cls.__init__

    @wraps(wrapped_init)
    def _auto_init(self, *args, **kwargs):
        if args or kwargs:
            wrapped_init(self, *args, **kwargs)
            return

        hints = get_type_hints(wrapped_init)
        injected = {
            name: ServiceRegistry.get(hints[name])
            for name in inspect.signature(wrapped_init).parameters
            if name != 'self' and hints.get(name)
        }
        wrapped_init(self, **injected)

    cls.__init__ = _auto_init
    return cls

def inject(func):
    """Function decorator: fill missing type-hinted arguments from ServiceRegistry.

    Each parameter the caller did not supply that has a type hint is
    resolved with ``ServiceRegistry.get(hint)``. ``*args``/``**kwargs``
    parameters are never injected.
    """
    signature = inspect.signature(func)
    type_hints = get_type_hints(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound_args = signature.bind_partial(*args, **kwargs)

        for param_name, param in signature.parameters.items():
            if param_name in bound_args.arguments:
                continue
            # Variadic parameters collect extras; there is no single type to inject.
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            param_type = type_hints.get(param_name)
            if param_type:
                bound_args.arguments[param_name] = ServiceRegistry.get(param_type)

        # Re-spread via BoundArguments.args/kwargs: passing `**arguments`
        # breaks positional-only parameters and turns *args/**kwargs into
        # bogus keyword arguments named after the variadic parameters.
        return func(*bound_args.args, **bound_args.kwargs)

    return wrapper

# Usage example: interfaces plus @injectable implementations.
class IEmailService:
    """E-mail interface; the base implementation does nothing."""

    def send_email(self, to: str, subject: str, body: str):
        pass


@injectable
class EmailService(IEmailService):
    """Prints the recipient and subject instead of sending mail."""

    def send_email(self, to: str, subject: str, body: str):
        print(f"发送邮件到 {to}: {subject}")


class IUserRepository:
    """User lookup interface; the base implementation returns None."""

    def get_user(self, user_id: int):
        pass


@injectable
class UserRepository(IUserRepository):
    """Returns synthetic user records; holds an injected e-mail service."""

    def __init__(self, email_service: IEmailService):
        self.email_service = email_service

    def get_user(self, user_id: int):
        return {"id": user_id, "email": f"user{user_id}@example.com"}


@injectable
class UserService:
    """Looks up a user and e-mails them a message."""

    def __init__(self, user_repository: IUserRepository, email_service: IEmailService):
        self.user_repository = user_repository
        self.email_service = email_service

    def notify_user(self, user_id: int, message: str):
        """Send *message* to the e-mail address of *user_id*."""
        address = self.user_repository.get_user(user_id)["email"]
        self.email_service.send_email(address, "通知", message)

# Register interface -> implementation mappings (singletons by default).
ServiceRegistry.register(IEmailService, EmailService)
ServiceRegistry.register(IUserRepository, UserRepository)


# A function whose `user_service` argument is filled in by @inject.
@inject
def send_notification(user_id: int, message: str, user_service: UserService):
    """Send *message* to *user_id* through the injected UserService."""
    user_service.notify_user(user_id, message)

# Demo
def demo_custom_di():
    """Exercise @injectable construction and an @inject function call."""
    # No constructor arguments: @injectable resolves them from ServiceRegistry.
    UserService().notify_user(1, "欢迎使用系统")

    # `user_service` is omitted here and supplied by @inject.
    send_notification(2, "系统维护通知")

# demo_custom_di()

5. 装饰器实现依赖注入

基于装饰器的高级依赖注入

python 复制代码
# 高级装饰器依赖注入系统
from typing import Dict, Any, Callable, Type, get_type_hints
from functools import wraps
import inspect
from enum import Enum

class Scope(Enum):
    """Lifetimes understood by AdvancedDIRegistry."""
    SINGLETON = "singleton"   # one instance for the whole registry
    TRANSIENT = "transient"   # a new instance on every resolve
    REQUEST = "request"       # one instance per request_id


class DIMetadata:
    """Per-service registration options."""
    def __init__(self):
        self.scope = Scope.SINGLETON
        # NOTE(review): `lazy` and `dependencies` are stored but never read by
        # AdvancedDIRegistry below; kept for compatibility.
        self.lazy = False
        self.factory = None
        self.dependencies = []


class AdvancedDIRegistry:
    """Class-level registry supporting singleton, transient and request scopes.

    Only registered types are injected into constructors. Parameters whose
    annotation is not registered keep whatever default they declare.
    """
    _services: Dict[Type, Any] = {}                       # service type -> implementation (None for factories)
    _metadata: Dict[Type, DIMetadata] = {}                # service type -> registration options
    _instances: Dict[Type, Any] = {}                      # singleton cache
    _request_instances: Dict[str, Dict[Type, Any]] = {}   # request_id -> per-request cache

    @classmethod
    def register_service(cls, service_type: Type, implementation: Type = None,
                         scope: Scope = Scope.SINGLETON, lazy: bool = False):
        """Register *implementation* (default: *service_type*) with *scope*."""
        if implementation is None:
            implementation = service_type

        cls._services[service_type] = implementation

        metadata = DIMetadata()
        metadata.scope = scope
        metadata.lazy = lazy
        cls._metadata[service_type] = metadata

    @classmethod
    def register_factory(cls, service_type: Type, factory: Callable,
                         scope: Scope = Scope.SINGLETON):
        """Register a zero-arg *factory* that produces *service_type*."""
        cls._services[service_type] = None

        metadata = DIMetadata()
        metadata.scope = scope
        metadata.factory = factory
        cls._metadata[service_type] = metadata

    @classmethod
    def resolve(cls, service_type: Type, request_id: str = None):
        """Resolve *service_type* according to its scope.

        Args:
            service_type: A registered type.
            request_id: Required for request-scoped services. It is also
                passed on to the dependencies of request-scoped and transient
                services, so request-scoped dependencies share the request.

        Raises:
            ValueError: if the type is unregistered, or a request-scoped
                service is resolved without a request_id.
        """
        if service_type not in cls._services:
            raise ValueError(f"服务 {service_type.__name__} 未注册")

        metadata = cls._metadata[service_type]

        if metadata.scope == Scope.SINGLETON:
            return cls._get_singleton(service_type, metadata)
        elif metadata.scope == Scope.REQUEST:
            return cls._get_request_scoped(service_type, metadata, request_id)
        else:  # TRANSIENT
            return cls._create_instance(service_type, metadata, request_id)

    @classmethod
    def _get_singleton(cls, service_type: Type, metadata: DIMetadata):
        """Return the cached singleton, creating it on first use."""
        if service_type in cls._instances:
            return cls._instances[service_type]

        # No request_id is passed on: a singleton must not capture one
        # request's instances.
        instance = cls._create_instance(service_type, metadata)
        cls._instances[service_type] = instance
        return instance

    @classmethod
    def _get_request_scoped(cls, service_type: Type, metadata: DIMetadata, request_id: str):
        """Return the instance cached for *request_id*, creating it if needed."""
        if not request_id:
            raise ValueError("请求作用域服务需要提供request_id")

        request_dict = cls._request_instances.setdefault(request_id, {})

        if service_type not in request_dict:
            request_dict[service_type] = cls._create_instance(service_type, metadata, request_id)

        return request_dict[service_type]

    @classmethod
    def _create_instance(cls, service_type: Type, metadata: DIMetadata, request_id: str = None):
        """Build *service_type* via its factory or by injecting its constructor."""
        if metadata.factory:
            return metadata.factory()

        implementation = cls._services[service_type]
        if implementation is None:
            implementation = service_type

        return cls._inject_dependencies(implementation, request_id)

    @classmethod
    def _inject_dependencies(cls, target_class: Type, request_id: str = None):
        """Instantiate *target_class*, resolving its registered constructor dependencies."""
        signature = inspect.signature(target_class.__init__)
        type_hints = get_type_hints(target_class.__init__)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            if param_name == 'self':
                continue

            param_type = type_hints.get(param_name)
            if param_type and param_type in cls._services:
                # Pass request_id on so request-scoped dependencies of a
                # request-scoped service are resolved within the same request.
                kwargs[param_name] = cls.resolve(param_type, request_id)

        return target_class(**kwargs)

    @classmethod
    def clear_request_scope(cls, request_id: str):
        """Drop every instance cached for *request_id*."""
        cls._request_instances.pop(request_id, None)

# Decorator definitions
def service(scope: Scope = Scope.SINGLETON, lazy: bool = False):
    """Return a class decorator that registers the class as its own service."""
    def register(target):
        AdvancedDIRegistry.register_service(target, target, scope, lazy)
        return target
    return register


def singleton(cls):
    """Register *cls* with ``Scope.SINGLETON``."""
    return service(scope=Scope.SINGLETON)(cls)


def transient(cls):
    """Register *cls* with ``Scope.TRANSIENT``."""
    return service(scope=Scope.TRANSIENT)(cls)


def request_scoped(cls):
    """Register *cls* with ``Scope.REQUEST``."""
    return service(scope=Scope.REQUEST)(cls)

def inject_dependencies(cls):
    """Class decorator: fill missing constructor dependencies from ``AdvancedDIRegistry``.

    On instantiation, each ``__init__`` parameter is checked. If the caller
    did not supply it (positionally or by keyword) and its type annotation is
    a key of ``AdvancedDIRegistry._services``, it is resolved through
    ``AdvancedDIRegistry.resolve`` and passed by keyword. Explicit arguments
    always take precedence; ``*args`` / ``**kwargs`` are never injected.
    """
    original_init = cls.__init__
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        # Read lazily, at instantiation time, so annotations may refer to
        # classes that are defined after the decorator runs.
        type_hints = get_type_hints(original_init)
        signature = inspect.signature(original_init)
        # Bind what the caller actually passed (including ``self``). This way
        # positional arguments count as supplied; checking only ``kwargs`` made
        # a positionally-passed dependency get injected a second time
        # ("got multiple values for argument").
        supplied = signature.bind_partial(self, *args, **kwargs).arguments

        for param_name, param in signature.parameters.items():
            if param_name in supplied or param.kind in variadic:
                continue

            param_type = type_hints.get(param_name)
            if param_type and param_type in AdvancedDIRegistry._services:
                kwargs[param_name] = AdvancedDIRegistry.resolve(param_type)

        original_init(self, *args, **kwargs)

    cls.__init__ = new_init
    return cls

def inject_method(*dependencies):
    """Function decorator: inject registry-resolved services for the listed types.

    A parameter is injected only if the caller did not supply it
    (positionally or by keyword) and its annotation is one of
    *dependencies*. Such parameters are resolved through
    ``AdvancedDIRegistry.resolve`` and passed by keyword. ``*args`` /
    ``**kwargs`` are never injected.
    """
    def decorator(func):
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Read lazily so annotations may reference later-defined classes.
            type_hints = get_type_hints(func)
            signature = inspect.signature(func)
            # Positional arguments count as supplied too; checking only
            # ``kwargs`` re-injected them and raised "multiple values".
            supplied = signature.bind_partial(*args, **kwargs).arguments

            for param_name, param in signature.parameters.items():
                if param_name in supplied or param.kind in variadic:
                    continue

                param_type = type_hints.get(param_name)
                if param_type and param_type in dependencies:
                    kwargs[param_name] = AdvancedDIRegistry.resolve(param_type)

            return func(*args, **kwargs)
        return wrapper
    return decorator

# 使用示例
from abc import ABC, abstractmethod

class ILogger(ABC):
    """Logging interface: implementations record one message per call."""

    @abstractmethod
    def log(self, message: str):
        """Record *message*."""

@singleton
class ConsoleLogger(ILogger):
    """Console ``ILogger`` implementation, registered through ``@singleton``."""

    def __init__(self):
        # Printed on every construction, which makes it visible how many
        # times the registry actually builds this class.
        print("ConsoleLogger 单例已创建")

    def log(self, message: str):
        line = f"[LOG] {message}"
        print(line)

class IEmailService(ABC):
    """Mail-sending interface."""

    @abstractmethod
    def send(self, to: str, subject: str, body: str):
        """Send a message with *subject* and *body* to *to*."""

@transient
class EmailService(IEmailService):
    """``IEmailService`` that only logs; registered as transient via ``@transient``."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        logger.log("EmailService 实例已创建")

    def send(self, to: str, subject: str, body: str):
        # Only recipient and subject are logged; *body* is not used here.
        message = f"发送邮件到 {to}: {subject}"
        self.logger.log(message)

@request_scoped
class RequestContext:
    """Per-request context object, registered with request scope via ``@request_scoped``."""

    def __init__(self, logger: ILogger):
        self.logger = logger
        # The object's identity doubles as its request identifier.
        self.request_id = id(self)
        logger.log(f"RequestContext 已创建: {self.request_id}")

    def get_request_id(self):
        """Return the identifier assigned in ``__init__``."""
        return self.request_id

@inject_dependencies
class OrderService:
    """Order service whose collaborators are filled in by ``@inject_dependencies``."""

    def __init__(self, email_service: IEmailService, logger: ILogger,
                 request_context: RequestContext):
        self.email_service = email_service
        self.logger = logger
        self.request_context = request_context

    def create_order(self, customer_email: str):
        """Log the order against the current request id and email a confirmation."""
        current_request = self.request_context.get_request_id()
        self.logger.log(f"创建订单 - 请求ID: {current_request}")
        self.email_service.send(customer_email, "订单确认", "您的订单已创建")

# Bind the interfaces to concrete implementations, so that parameters
# annotated with ILogger / IEmailService can be resolved by the registry.
# NOTE(review): these calls use register_service's default scope (its
# definition is outside this section). The @singleton / @transient decorators
# registered the concrete classes as keys, not these interfaces. Confirm the
# intended scope for the interface keys.
AdvancedDIRegistry.register_service(ILogger, ConsoleLogger)
AdvancedDIRegistry.register_service(IEmailService, EmailService)

# Factory function example
def create_special_logger():
    """Factory for ``ILogger``: construct a ``ConsoleLogger`` by calling the class.

    ``@singleton`` returns the class unchanged, so calling it directly
    bypasses the registry and builds a new instance on every call.
    """
    logger = ConsoleLogger()
    return logger

# Also register a factory for ILogger.
# NOTE(review): how this interacts with the earlier
# register_service(ILogger, ConsoleLogger) call (override vs. precedence, and
# whether the factory result is cached) is decided by AdvancedDIRegistry,
# whose registration/resolve code is outside this section. Verify.
AdvancedDIRegistry.register_factory(ILogger, create_special_logger)

# Method injection example
@inject_method(ILogger, IEmailService)
def process_order(order_id: int, logger: ILogger, email_service: IEmailService):
    """Log and email a progress notice for *order_id*.

    *logger* and *email_service* are filled in by ``@inject_method`` when the
    caller omits them.
    """
    log_line = f"处理订单: {order_id}"
    logger.log(log_line)
    email_service.send(
        "customer@example.com",
        "处理中",
        f"订单 {order_id} 正在处理",
    )

# Demo
def demo_advanced_decorators():
    """Exercise the decorator-based registry: scopes, class injection, method injection.

    NOTE(review): ``request_id`` is generated here but never handed to the
    registry before ``OrderService()`` resolves ``RequestContext``. It is only
    used by the final ``clear_request_scope`` call. Whether that call cleans up
    anything depends on how ``AdvancedDIRegistry.resolve`` chooses the request
    id, which is not visible here. Confirm.
    """
    import uuid

    # Exercise the different scopes
    request_id = str(uuid.uuid4())

    # Build two order services; their dependencies come from @inject_dependencies
    order_service1 = OrderService()
    order_service2 = OrderService()

    # Singleton: both services must share the very same logger object
    print(f"Logger是否为单例: {order_service1.logger is order_service2.logger}")

    # Transient means every resolution yields a new object. The answer to
    # "is it transient?" is therefore "the two instances differ"; an equality
    # check would print the inverted result.
    print(f"EmailService是否为瞬态: {order_service1.email_service is not order_service2.email_service}")

    # Create orders
    order_service1.create_order("customer1@example.com")
    order_service2.create_order("customer2@example.com")

    # Method injection: logger and email_service are resolved by @inject_method
    process_order(12345)

    # Clean up the request scope
    AdvancedDIRegistry.clear_request_scope(request_id)

# demo_advanced_decorators()

6. 类型注解与依赖注入

基于类型注解的依赖注入

python 复制代码
# 基于类型注解的依赖注入系统
from typing import Dict, Any, Callable, Type, TypeVar, get_type_hints, get_origin, get_args
from functools import wraps
import inspect
from abc import ABC, abstractmethod

# Type variable tying interface and implementation types together in the
# container's annotations and in the generic Repository[T] defined later.
T = TypeVar('T')

class TypedDIContainer:
    """Dependency-injection container that wires constructors from type annotations.

    How an annotated constructor parameter is filled:

    * ``str`` / ``int`` / ``float`` / ``bool``, or a parameterized generic
      (``List[str]``, ``Optional[int]`` ...) that is *not* registered, is a
      configuration value. It is looked up by parameter name in the values
      given to :meth:`configure`.
    * Everything else, including registered generics such as
      ``Repository[User]``, is resolved recursively through :meth:`resolve`.
    """

    def __init__(self):
        self._services: Dict[Type, Type] = {}       # interface -> implementation class
        self._instances: Dict[Type, Any] = {}       # interface -> cached instance
        self._factories: Dict[Type, Callable] = {}  # interface -> zero-argument factory
        self._configurations: Dict[str, Any] = {}   # parameter name -> config value

    def register(self, interface: Type[T], implementation: Type[T] = None) -> 'TypedDIContainer':
        """Register *implementation* for *interface* (defaults to *interface* itself).

        Returns the container so calls can be chained.
        """
        if implementation is None:
            implementation = interface

        self._services[interface] = implementation
        return self

    def register_instance(self, interface: Type[T], instance: T) -> 'TypedDIContainer':
        """Register a ready-made *instance* for *interface*; returns the container."""
        self._instances[interface] = instance
        return self

    def register_factory(self, interface: Type[T], factory: Callable[[], T]) -> 'TypedDIContainer':
        """Register a zero-argument *factory* for *interface*; returns the container."""
        self._factories[interface] = factory
        return self

    def configure(self, key: str, value: Any) -> 'TypedDIContainer':
        """Set the configuration value injected into parameters named *key*; returns the container."""
        self._configurations[key] = value
        return self

    def resolve(self, interface: Type[T]) -> T:
        """Return an instance for *interface*.

        Lookup order: cached instance, then factory, then registered service,
        and finally direct construction of *interface*. Factory and service
        results are cached in ``_instances``, so later calls return the same
        object. Unregistered types are constructed anew on every call.
        """
        # Cached instance
        if interface in self._instances:
            return self._instances[interface]

        # Factory function
        if interface in self._factories:
            instance = self._factories[interface]()
            self._instances[interface] = instance
            return instance

        # Registered service
        if interface in self._services:
            implementation = self._services[interface]
            instance = self._create_typed_instance(implementation)
            self._instances[interface] = instance
            return instance

        # Fall back to constructing the type directly
        return self._create_typed_instance(interface)

    def _create_typed_instance(self, cls: Type[T]) -> T:
        """Instantiate *cls*, filling annotated ``__init__`` parameters.

        Unannotated parameters and ``*args`` / ``**kwargs`` are left to the
        constructor. An unconfigured config-typed parameter that declares a
        default keeps that default.
        """
        type_hints = get_type_hints(cls.__init__)
        signature = inspect.signature(cls.__init__)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        kwargs = {}
        for param_name, param in signature.parameters.items():
            # Variadic parameters cannot be passed by keyword.
            if param_name == 'self' or param.kind in variadic:
                continue

            param_type = type_hints.get(param_name)
            if not param_type:
                continue

            if self._is_config_type(param_type):
                # Without this, e.g. ``retries: int = 3`` would be silently
                # overwritten with the zero value from _resolve_config.
                if (param_name not in self._configurations
                        and param.default is not inspect.Parameter.empty):
                    continue
                kwargs[param_name] = self._resolve_config(param_name, param_type)
            else:
                kwargs[param_name] = self.resolve(param_type)

        return cls(**kwargs)

    def _is_registered(self, interface) -> bool:
        """Return True if *interface* has an instance, factory or service registration."""
        try:
            return (interface in self._instances
                    or interface in self._factories
                    or interface in self._services)
        except TypeError:
            # Unhashable annotation (e.g. Annotated with dict metadata): cannot be a key.
            return False

    def _is_config_type(self, param_type: Type) -> bool:
        """Return True if *param_type* should be filled from configuration."""
        if param_type in (str, int, float, bool):
            return True
        # Parameterized generics carry ``__origin__``. They count as config
        # only when not registered; otherwise a registered dependency like
        # Repository[User] was treated as config and injected as None.
        return hasattr(param_type, '__origin__') and not self._is_registered(param_type)

    def _resolve_config(self, param_name: str, param_type: Type):
        """Return the configured value for *param_name*.

        Falls back to a zero value for ``str``/``int``/``float``/``bool``, and to
        ``None`` for any other type.
        """
        if param_name in self._configurations:
            return self._configurations[param_name]

        # Fallback values when nothing is configured
        defaults = {
            str: "",
            int: 0,
            float: 0.0,
            bool: False
        }
        return defaults.get(param_type, None)

# Type-safe class decorator
def typed_injectable(container: TypedDIContainer):
    """Class decorator that lets *container* supply the constructor arguments.

    Injection only happens when the class is instantiated with no arguments
    at all. If any positional or keyword argument is given, the call is
    forwarded to the original ``__init__`` unchanged. Config-typed parameters
    that are not configured keep their declared default, if any.
    """
    def decorator(cls: Type[T]) -> Type[T]:
        original_init = cls.__init__
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        @wraps(original_init)
        def new_init(self, *args, **kwargs):
            if not args and not kwargs:
                # Hints are read at instantiation time, so forward references
                # only need to be resolvable once the class is first built.
                type_hints = get_type_hints(original_init)
                signature = inspect.signature(original_init)

                for param_name, param in signature.parameters.items():
                    # Variadic parameters cannot be passed by keyword.
                    if param_name == 'self' or param.kind in variadic:
                        continue

                    param_type = type_hints.get(param_name)
                    if not param_type:
                        continue

                    if container._is_config_type(param_type):
                        # Keep the declared default instead of replacing it
                        # with the type's zero value when nothing is configured.
                        if (param_name not in container._configurations
                                and param.default is not inspect.Parameter.empty):
                            continue
                        kwargs[param_name] = container._resolve_config(param_name, param_type)
                    else:
                        kwargs[param_name] = container.resolve(param_type)

            original_init(self, *args, **kwargs)

        cls.__init__ = new_init
        return cls

    return decorator

def typed_inject(container: TypedDIContainer):
    """Function decorator that fills unsupplied annotated parameters from *container*.

    Arguments given by the caller (positionally or by keyword) always win.
    Only missing parameters are injected: config-typed ones by name, all
    others as resolved services. Config-typed parameters that are not
    configured keep their declared default, if any. ``*args`` / ``**kwargs``
    are never injected.
    """
    def decorator(func: Callable) -> Callable:
        type_hints = get_type_hints(func)
        signature = inspect.signature(func)
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_args = signature.bind_partial(*args, **kwargs)

            for param_name, param in signature.parameters.items():
                if param_name in bound_args.arguments or param.kind in variadic:
                    continue

                param_type = type_hints.get(param_name)
                if not param_type:
                    continue

                if container._is_config_type(param_type):
                    # Don't override the function's own default with a zero value.
                    if (param_name not in container._configurations
                            and param.default is not inspect.Parameter.empty):
                        continue
                    bound_args.arguments[param_name] = container._resolve_config(param_name, param_type)
                else:
                    bound_args.arguments[param_name] = container.resolve(param_type)

            # Rebuild the call from args/kwargs so positional-only and
            # variadic parameters are passed the way the signature requires;
            # ``func(**bound_args.arguments)`` fails for both.
            return func(*bound_args.args, **bound_args.kwargs)

        return wrapper

    return decorator

# 泛型支持
from typing import Generic, List, Optional

class Repository(Generic[T], ABC):
    """Abstract generic repository: look up, list and persist entities of type ``T``."""

    @abstractmethod
    def find_by_id(self, id: int) -> Optional[T]:
        """Return the entity whose id is *id*, or ``None`` when there is none."""

    @abstractmethod
    def find_all(self) -> List[T]:
        """Return every stored entity."""

    @abstractmethod
    def save(self, entity: T) -> T:
        """Persist *entity* and return it."""

class User:
    """Plain user record with ``id``, ``name`` and ``email`` attributes."""

    def __init__(self, id: int, name: str, email: str):
        self.id, self.name, self.email = id, name, email

    def __repr__(self):
        return "User(id={}, name='{}', email='{}')".format(self.id, self.name, self.email)

class Product:
    """Plain product record with ``id``, ``name`` and ``price`` attributes."""

    def __init__(self, id: int, name: str, price: float):
        self.id, self.name, self.price = id, name, price

    def __repr__(self):
        return "Product(id={}, name='{}', price={})".format(self.id, self.name, self.price)

class UserRepository(Repository[User]):
    """In-memory ``Repository[User]`` seeded with two demo users."""

    def __init__(self, database_url: str):
        # Only stored; the data itself lives in the in-memory list.
        self.database_url = database_url
        self._users = [
            User(1, "张三", "zhangsan@example.com"),
            User(2, "李四", "lisi@example.com"),
        ]

    def find_by_id(self, id: int) -> Optional[User]:
        """Return the first user whose ``id`` matches, else ``None``."""
        for candidate in self._users:
            if candidate.id == id:
                return candidate
        return None

    def find_all(self) -> List[User]:
        """Return a shallow copy, so callers cannot mutate the internal list."""
        return list(self._users)

    def save(self, entity: User) -> User:
        """Append *entity* (no duplicate-id check) and return it."""
        self._users.append(entity)
        return entity

class ProductRepository(Repository[Product]):
    """In-memory ``Repository[Product]`` seeded with two demo products."""

    def __init__(self, database_url: str):
        # Only stored; the data itself lives in the in-memory list.
        self.database_url = database_url
        self._products = [
            Product(1, "笔记本电脑", 5999.99),
            Product(2, "智能手机", 2999.99),
        ]

    def find_by_id(self, id: int) -> Optional[Product]:
        """Return the first product whose ``id`` matches, else ``None``."""
        for candidate in self._products:
            if candidate.id == id:
                return candidate
        return None

    def find_all(self) -> List[Product]:
        """Return a shallow copy, so callers cannot mutate the internal list."""
        return list(self._products)

    def save(self, entity: Product) -> Product:
        """Append *entity* (no duplicate-id check) and return it."""
        self._products.append(entity)
        return entity

class EmailService:
    """Demo mail service that prints each message instead of sending it."""

    def __init__(self, smtp_host: str, smtp_port: int):
        self.smtp_host, self.smtp_port = smtp_host, smtp_port

    def send_email(self, to: str, subject: str, body: str):
        """Print the recipient/server line, the subject and the body, one per line."""
        lines = (
            f"发送邮件到 {to} (通过 {self.smtp_host}:{self.smtp_port})",
            f"主题: {subject}",
            f"内容: {body}",
        )
        for line in lines:
            print(line)

class UserService:
    """User application service built on a ``Repository[User]`` and an ``EmailService``."""

    def __init__(self,
                 user_repository: Repository[User],
                 email_service: EmailService):
        self.user_repository = user_repository
        self.email_service = email_service

    def get_user(self, user_id: int) -> Optional[User]:
        """Return the user with *user_id* from the repository (``None`` if absent)."""
        return self.user_repository.find_by_id(user_id)

    def create_user(self, name: str, email: str) -> User:
        """Create, save and welcome a new user; returns what the repository saved.

        The new id is one past the largest existing id (1 for an empty
        repository). The previous ``len(...) + 1`` scheme reused an existing
        id whenever ids were not exactly 1..n, e.g. ids [2, 3] gave 3 again.
        """
        existing_ids = [existing.id for existing in self.user_repository.find_all()]
        user_id = max(existing_ids, default=0) + 1
        user = User(user_id, name, email)
        saved_user = self.user_repository.save(user)

        # Send the welcome email
        self.email_service.send_email(
            user.email,
            "欢迎注册",
            f"欢迎 {user.name} 注册我们的服务!"
        )

        return saved_user

# Container setup
def setup_typed_container():
    """Build the demo container: config values, generic repositories, then services."""
    # Config values are injected into constructor parameters with matching names.
    container = (
        TypedDIContainer()
        .configure("database_url", "postgresql://localhost/mydb")
        .configure("smtp_host", "smtp.example.com")
        .configure("smtp_port", 587)
    )

    # Generic repositories, keyed by their parameterized interface
    (container
        .register(Repository[User], UserRepository)
        .register(Repository[Product], ProductRepository))

    # Concrete services registered as their own implementation
    container.register(EmailService).register(UserService)

    return container

# Module-level container shared by the @typed_injectable / @typed_inject
# decorators that follow.
container = setup_typed_container()

@typed_injectable(container)
class OrderService:
    """Order service wired by ``@typed_injectable``; build it with ``OrderService()``."""

    def __init__(self,
                 user_repository: Repository[User],
                 product_repository: Repository[Product],
                 email_service: EmailService):
        self.user_repository = user_repository
        self.product_repository = product_repository
        self.email_service = email_service

    def create_order(self, user_id: int, product_id: int):
        """Create an order for an existing user/product pair and email a confirmation.

        Returns a dict with ``order_id``, ``user`` and ``product``.

        Raises:
            ValueError: if either lookup returns a falsy value (e.g. ``None``).
        """
        user = self.user_repository.find_by_id(user_id)
        product = self.product_repository.find_by_id(product_id)
        if not (user and product):
            raise ValueError("用户或产品不存在")

        order_id = f"ORD-{user_id}-{product_id}"
        confirmation = f"您的订单 {order_id} 已创建,产品: {product.name}"
        self.email_service.send_email(user.email, "订单确认", confirmation)

        return {"order_id": order_id, "user": user, "product": product}

@typed_inject(container)
def get_user_info(user_id: int, user_service: UserService) -> Optional[User]:
    """Look up *user_id* via a ``UserService`` that ``@typed_inject`` supplies when omitted."""
    found = user_service.get_user(user_id)
    return found

# Demo of annotation-driven injection
def demo_typed_di():
    """Walk through direct resolution, class-decorator injection and function injection."""
    # Resolve the service straight from the container
    svc = container.resolve(UserService)

    created = svc.create_user("王五", "wangwu@example.com")
    print(f"创建用户: {created}")

    fetched = svc.get_user(1)
    print(f"获取用户: {fetched}")

    # Class whose constructor arguments come from @typed_injectable
    order = OrderService().create_order(1, 1)
    print(f"创建订单: {order}")

    # Function whose UserService argument comes from @typed_inject
    info = get_user_info(2)
    print(f"函数获取用户: {info}")

# demo_typed_di()
相关推荐
R-G-B1 分钟前
OpenCV Python——Numpy基本操作(Numpy 矩阵操作、Numpy 矩阵的检索与赋值、Numpy 操作ROI)
python·opencv·numpy·numpy基本操作·numpy 矩阵操作·numpy 矩阵的检索与赋值·numpy 操作roi
细节处有神明8 分钟前
Jupyter 中实现交互式图表:ipywidgets 从入门到部署
ide·python·jupyter
小小码农一只8 分钟前
Python 爬虫实战:玩转 Playwright 跨浏览器自动化(Chromium/Firefox/WebKit 全支持)
爬虫·python·自动化
深盾安全1 小时前
Python脚本安全防护策略全解析(上)
python
杜子不疼.1 小时前
《Python学习之使用标准库:从入门到实战》
开发语言·python·学习
胡耀超2 小时前
从哲学(业务)视角看待数据挖掘:从认知到实践的螺旋上升
人工智能·python·数据挖掘·大模型·特征工程·crisp-dm螺旋认知·批判性思维
tomelrg2 小时前
多台服务器批量发布arcgisserver服务并缓存切片
服务器·python·arcgis
A尘埃2 小时前
Java+Python混合微服务OCR系统设计
java·python·微服务·混合
冬天vs不冷3 小时前
Java基础(九):Object核心类深度剖析
java·开发语言·python