Python工厂模式与依赖注入:FastAPI的Depends到底在干嘛
前言
这篇是设计模式系列的学习笔记,这次来聊聊工厂模式和依赖注入。
为什么把这俩放一起讲?因为它们在 FastAPI 里是"黄金搭档"------FastAPI 的 Depends 机制本质上就是依赖注入,而工厂函数是创建依赖的常用方式。
说实话,很多人写 FastAPI 的时候,Depends 用得很溜,但不太清楚它背后的原理。什么是依赖注入?为什么这样设计?Depends 内部到底干了啥?这篇文章就把这些事儿讲透。
理解了工厂模式和依赖注入,你写出来的代码会更容易测试、更容易维护,而且能更好地利用 FastAPI 的特性。这可能是这个系列里最"值钱"的一篇,因为它直接关系到你的代码架构。
🏠个人主页:山沐与山
文章目录
- 一、先搞懂工厂模式
- 二、依赖注入是什么鬼
- 三、FastAPI的Depends深度剖析
- 四、实战:分层架构中的依赖注入
- 五、进阶:自己实现一个依赖注入容器
- 六、测试中的巨大优势
- 七、常见问题与最佳实践
- 八、总结
一、先搞懂工厂模式
1.1 为什么需要工厂
假设你在写一个通知系统,需要支持多种通知方式:
python
class EmailNotifier:
def send(self, message: str):
print(f"发送邮件: {message}")
class SMSNotifier:
def send(self, message: str):
print(f"发送短信: {message}")
class PushNotifier:
def send(self, message: str):
print(f"发送推送: {message}")
# 使用的时候
def notify_user(user, message, notify_type):
if notify_type == "email":
notifier = EmailNotifier()
elif notify_type == "sms":
notifier = SMSNotifier()
elif notify_type == "push":
notifier = PushNotifier()
else:
raise ValueError(f"不支持的通知类型: {notify_type}")
notifier.send(message)
问题来了:这段 if-elif 逻辑可能在很多地方都要写。新增一种通知方式?得改好几个地方。累不累?
工厂模式就是把"创建对象"这件事封装起来,让调用方不用关心具体怎么创建。你想想,如果创建逻辑复杂(比如需要读配置、连接数据库),散落在各处岂不是噩梦?
1.2 简单工厂
最直接的封装方式:
python
from typing import Protocol
class Notifier(Protocol):
"""通知器协议"""
def send(self, message: str) -> None: ...
class NotifierFactory:
"""通知器工厂"""
@staticmethod
def create(notify_type: str) -> Notifier:
factories = {
"email": EmailNotifier,
"sms": SMSNotifier,
"push": PushNotifier,
}
if notify_type not in factories:
raise ValueError(f"不支持的通知类型: {notify_type}")
return factories[notify_type]()
# 使用
def notify_user(user, message, notify_type):
notifier = NotifierFactory.create(notify_type)
notifier.send(message)
看到没?现在创建逻辑集中在一个地方了,新增通知类型只需要改工厂。这里用了 Protocol 来定义接口,这是 Python 3.8+ 的结构化子类型(鸭子类型的正式版),不需要显式继承,只要实现了 send 方法就行。
1.3 工厂方法
简单工厂有个问题:所有创建逻辑都在一个类里,违反开闭原则(对扩展开放,对修改关闭)。啥意思?就是每次新增类型,都得改 NotifierFactory 的代码。
工厂方法把创建逻辑分散到各个子类:
python
from abc import ABC, abstractmethod
class NotifierFactory(ABC):
"""抽象工厂"""
@abstractmethod
def create_notifier(self) -> Notifier:
pass
def notify(self, message: str):
# 模板方法:先创建通知器,再发送
notifier = self.create_notifier()
notifier.send(message)
class EmailNotifierFactory(NotifierFactory):
def __init__(self, smtp_server: str, port: int):
self.smtp_server = smtp_server
self.port = port
def create_notifier(self) -> Notifier:
# 每种工厂管理自己的配置
return EmailNotifier(self.smtp_server, self.port)
class SMSNotifierFactory(NotifierFactory):
def __init__(self, api_key: str):
self.api_key = api_key
def create_notifier(self) -> Notifier:
return SMSNotifier(self.api_key)
# 使用
email_factory = EmailNotifierFactory("smtp.example.com", 587)
email_factory.notify("Hello!")
sms_factory = SMSNotifierFactory("your-api-key")
sms_factory.notify("Hello!")
每种通知器有自己的工厂,各自管理各自的配置和创建逻辑。新增类型?只需要新增工厂类,不用改已有代码。这就是开闭原则的体现。
1.4 抽象工厂
当你需要创建一系列相关的对象时,用抽象工厂。
举个例子:你在开发一个跨平台的 UI 框架,需要同时创建按钮、输入框、弹窗,而且这些组件得是同一风格(Windows 风格或 Mac 风格)。
python
from abc import ABC, abstractmethod
# 产品族:定义各类 UI 组件的接口
class Button(ABC):
@abstractmethod
def render(self) -> str: pass
class Input(ABC):
@abstractmethod
def render(self) -> str: pass
class Dialog(ABC):
@abstractmethod
def show(self) -> str: pass
# Windows 风格的产品族
class WindowsButton(Button):
def render(self):
return "<button class='windows-btn'>Click</button>"
class WindowsInput(Input):
def render(self):
return "<input class='windows-input'/>"
class WindowsDialog(Dialog):
def show(self):
return "Windows 风格弹窗"
# Mac 风格的产品族
class MacButton(Button):
def render(self):
return "<button class='mac-btn'>Click</button>"
class MacInput(Input):
def render(self):
return "<input class='mac-input'/>"
class MacDialog(Dialog):
def show(self):
return "Mac 风格弹窗"
# 抽象工厂:定义创建产品族的接口
class UIFactory(ABC):
@abstractmethod
def create_button(self) -> Button: pass
@abstractmethod
def create_input(self) -> Input: pass
@abstractmethod
def create_dialog(self) -> Dialog: pass
class WindowsUIFactory(UIFactory):
def create_button(self) -> Button:
return WindowsButton()
def create_input(self) -> Input:
return WindowsInput()
def create_dialog(self) -> Dialog:
return WindowsDialog()
class MacUIFactory(UIFactory):
def create_button(self) -> Button:
return MacButton()
def create_input(self) -> Input:
return MacInput()
def create_dialog(self) -> Dialog:
return MacDialog()
# 使用:客户端代码不关心具体是什么风格
def render_form(factory: UIFactory):
button = factory.create_button()
input_field = factory.create_input()
return f"{input_field.render()} {button.render()}"
# 根据系统自动选择工厂
import platform
factory = WindowsUIFactory() if platform.system() == "Windows" else MacUIFactory()
print(render_form(factory))
抽象工厂确保你创建的一系列对象是"配套"的------Windows 风格的按钮配 Windows 风格的输入框,不会混搭出四不像。
1.5 Python 风格的工厂:函数就够了
Java 程序员看完上面可能觉得很亲切,但 Python 程序员会说:搞这么复杂干嘛?
在 Python 里,很多时候不需要搞这么复杂的类结构,一个函数就能当工厂:
python
def create_notifier(notify_type: str, **config) -> Notifier:
"""工厂函数:根据类型创建通知器"""
if notify_type == "email":
return EmailNotifier(
smtp_server=config.get("smtp_server", "localhost"),
port=config.get("port", 587)
)
elif notify_type == "sms":
return SMSNotifier(api_key=config["api_key"])
elif notify_type == "push":
return PushNotifier(app_id=config["app_id"])
else:
raise ValueError(f"Unknown type: {notify_type}")
# 使用
notifier = create_notifier("email", smtp_server="smtp.gmail.com")
notifier.send("Hello!")
或者用字典映射 + 装饰器注册,这个模式在插件系统里特别常用:
python
from typing import Dict, Type
class NotifierRegistry:
"""通知器注册表"""
_registry: Dict[str, Type[Notifier]] = {}
@classmethod
def register(cls, name: str):
"""装饰器:注册通知器类"""
def decorator(notifier_class: Type[Notifier]):
cls._registry[name] = notifier_class
return notifier_class
return decorator
@classmethod
def create(cls, name: str, **kwargs) -> Notifier:
"""创建通知器实例"""
if name not in cls._registry:
raise ValueError(f"未注册的通知器: {name}")
return cls._registry[name](**kwargs)
@classmethod
def list_available(cls) -> list:
"""列出所有可用的通知器"""
return list(cls._registry.keys())
# 用装饰器注册,代码更优雅
@NotifierRegistry.register("email")
class EmailNotifier:
def __init__(self, smtp_server: str = "localhost", **kwargs):
self.smtp_server = smtp_server
def send(self, message: str):
print(f"[Email via {self.smtp_server}] {message}")
@NotifierRegistry.register("sms")
class SMSNotifier:
def __init__(self, api_key: str, **kwargs):
self.api_key = api_key
def send(self, message: str):
print(f"[SMS] {message}")
# 使用
print(NotifierRegistry.list_available()) # ['email', 'sms']
notifier = NotifierRegistry.create("email", smtp_server="smtp.gmail.com")
notifier.send("Hello!")
看到没?装饰器一加,类就自动注册了。新增通知类型?写个类加上装饰器就完事,连工厂代码都不用改。这才是 Python 风格。
1.6 工厂模式小结
| 工厂类型 | 适用场景 | Python 实现 |
|---|---|---|
| 简单工厂 | 类型固定,创建逻辑简单 | 工厂函数 + 字典映射 |
| 工厂方法 | 需要子类化,各产品配置不同 | 抽象基类 + 子类实现 |
| 抽象工厂 | 创建一系列相关产品 | 工厂接口 + 产品族 |
| 注册式工厂 | 插件系统,动态扩展 | 装饰器 + 注册表 |
二、依赖注入是什么鬼
2.1 先看一个"紧耦合"的例子
python
class UserRepository:
"""用户仓储"""
def __init__(self):
# 直接在内部创建数据库连接
self.db = DatabaseConnection("postgresql://localhost/mydb")
def get_user(self, user_id: int):
return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
class UserService:
"""用户服务"""
def __init__(self):
# 直接在内部创建仓储
self.repo = UserRepository()
# 直接在内部创建缓存
self.cache = RedisClient("redis://localhost")
def get_user(self, user_id: int):
# 先查缓存
cached = self.cache.get(f"user:{user_id}")
if cached:
return cached
# 缓存没有,查数据库
user = self.repo.get_user(user_id)
self.cache.set(f"user:{user_id}", user)
return user
# 使用
service = UserService()
user = service.get_user(1)
看起来能跑,但问题大了:
测试困难 :想测 UserService,必须有真实的数据库和 Redis。想 mock?改不了,依赖都写死在构造函数里了。
改配置要改代码 :数据库地址变了?得改 UserRepository 的代码。Redis 换成 Memcached?大改。
复用困难 :想在另一个项目用 UserService,但那个项目用 MySQL 不用 PostgreSQL?不好意思,改代码吧。
职责不清 :UserService 不光要处理业务逻辑,还要负责创建依赖。它管得太多了。
2.2 依赖注入:把依赖"注入"进来
核心思想超简单:不要在内部创建依赖,而是从外部传入。
python
from typing import Protocol
class IUserRepository(Protocol):
"""仓储接口"""
def get_user(self, user_id: int) -> dict: ...
class ICache(Protocol):
"""缓存接口"""
def get(self, key: str) -> any: ...
def set(self, key: str, value: any) -> None: ...
class UserService:
"""用户服务 - 依赖注入版"""
def __init__(self, repo: IUserRepository, cache: ICache):
# 依赖从外部注入,不在内部创建
self.repo = repo
self.cache = cache
def get_user(self, user_id: int):
cached = self.cache.get(f"user:{user_id}")
if cached:
return cached
user = self.repo.get_user(user_id)
self.cache.set(f"user:{user_id}", user)
return user
# 使用时,由调用方决定用什么实现
repo = PostgresUserRepository(db_url="postgresql://localhost/mydb")
cache = RedisCache(url="redis://localhost")
service = UserService(repo=repo, cache=cache)
user = service.get_user(1)
对比一下变化:
| 方面 | 紧耦合 | 依赖注入 |
|---|---|---|
| 依赖创建 | 内部 self.x = X() |
外部传入 __init__(self, x) |
| 依赖类型 | 具体类 | 接口/协议 |
| 可测试性 | 差,需要真实依赖 | 好,可以传入 mock |
| 灵活性 | 差,改依赖要改代码 | 好,换实现不改代码 |
| 职责 | 混杂 | 单一 |
这就是依赖注入的精髓:依赖接口,不依赖实现。
2.3 依赖注入的三种方式
1. 构造函数注入(最常用)
python
class UserService:
def __init__(self, repo: IUserRepository, cache: ICache):
self.repo = repo
self.cache = cache
优点:对象创建后就是"完整"的,所有依赖都准备好了。
2. 属性注入
python
class UserService:
repo: IUserRepository = None
cache: ICache = None
def get_user(self, user_id: int):
# 使用前需要先设置属性
pass
service = UserService()
service.repo = PostgresUserRepository()
service.cache = RedisCache()
缺点:对象可能处于"不完整"状态(忘了设置某个属性)。
3. 方法注入
python
class UserService:
def get_user(self, user_id: int, repo: IUserRepository, cache: ICache):
# 每次调用时传入依赖
pass
缺点:每次调用都要传,麻烦。但有时候确实需要(比如不同请求用不同的数据库连接)。
实际项目中,构造函数注入最常用,因为它能保证对象创建后就是可用的。
2.4 控制反转(IoC)
依赖注入是实现"控制反转"的一种方式。这俩概念经常一起出现,到底啥关系?
传统方式 :高层模块(UserService)主动创建低层模块(UserRepository),高层控制低层。
UserService 说:我需要 UserRepository,我自己创建一个
↓
控制权在 UserService
控制反转:高层模块不创建低层模块,而是由外部(IoC 容器或调用方)创建并注入。控制权"反转"了。
外部容器说:UserService 需要 UserRepository?我来创建,注入给你
↓
控制权在外部
用代码说明:
python
# 传统:UserService 自己创建依赖(控制权在 UserService)
class UserService:
def __init__(self):
self.repo = UserRepository() # 我自己 new
# IoC:外部创建依赖并注入(控制权反转到外部)
class UserService:
def __init__(self, repo: IUserRepository):
self.repo = repo # 外部给我
# 外部(IoC 容器)负责组装
repo = UserRepository()
service = UserService(repo) # 控制权在这里
依赖注入就是实现 IoC 的具体手段。IoC 是原则,DI 是实现。
2.5 为什么依赖接口
你可能注意到了,依赖注入通常配合接口使用:
python
class UserService:
def __init__(self, repo: IUserRepository): # 依赖接口,不是具体类
self.repo = repo
为什么?因为这样 UserService 就不关心你传的是 PostgresUserRepository 还是 MySQLUserRepository 还是 MockUserRepository,只要它实现了 IUserRepository 接口就行。
这就是依赖倒置原则(SOLID 中的 D):
- 高层模块不应该依赖低层模块,两者都应该依赖抽象
- 抽象不应该依赖细节,细节应该依赖抽象
听起来很绕?其实就是:依赖接口,不依赖实现。
三、FastAPI的Depends深度剖析
好了,理论讲完了,来看看 FastAPI 是怎么玩的。
3.1 Depends 的基本用法
python
from fastapi import FastAPI, Depends, Header, HTTPException
app = FastAPI()
# 依赖函数:获取数据库会话
def get_db():
db = DatabaseSession()
try:
yield db
finally:
db.close()
# 依赖函数:获取当前用户
def get_current_user(token: str = Header(...)):
user = decode_token(token)
if not user:
raise HTTPException(status_code=401)
return user
# 在路由中使用
@app.get("/users/me")
def read_current_user(user = Depends(get_current_user)):
return user
@app.get("/items")
def read_items(
db = Depends(get_db),
user = Depends(get_current_user)
):
return db.query(Item).filter(Item.owner_id == user.id).all()
Depends 告诉 FastAPI:"调用这个路由之前,先执行这个函数,把返回值传给我"。这就是依赖注入!
3.2 Depends 的执行流程
当请求到达时,FastAPI 会做什么?让我用伪代码解释:
python
# 伪代码,展示 Depends 的执行逻辑
def handle_request(request):
# 1. 解析路由函数的参数签名
params = inspect_function_params(route_function)
# 2. 对于每个 Depends 参数,执行依赖函数
resolved_deps = {}
for param_name, dependency in params.items():
if isinstance(dependency.default, Depends):
# 递归解析(依赖可能还有依赖)
dep_result = resolve_dependency(dependency.default.dependency)
resolved_deps[param_name] = dep_result
# 3. 用解析后的参数调用路由函数
result = route_function(**resolved_deps)
return result
关键点:
- 惰性执行:只有请求到达时才执行依赖函数,不是应用启动时
- 递归解析 :依赖可以有自己的依赖,
FastAPI会自动解析整个依赖树 - 生命周期管理 :用
yield的依赖会在请求结束后执行清理代码
3.3 依赖的各种形态
FastAPI 的 Depends 非常灵活,可以接受多种形态的"依赖":
1. 普通函数
python
def get_query_params(skip: int = 0, limit: int = 10):
return {"skip": skip, "limit": limit}
@app.get("/items")
def read_items(params: dict = Depends(get_query_params)):
return {"params": params}
2. 生成器函数(带清理逻辑)
这个是 FastAPI 的杀手锏!用 yield 可以实现"请求前-请求中-请求后"的生命周期管理:
python
def get_db():
db = SessionLocal()
try:
yield db # 请求处理期间使用这个 db
finally:
db.close() # 请求结束后自动清理
@app.get("/items")
def read_items(db: Session = Depends(get_db)):
return db.query(Item).all()
请求进来
↓
执行 get_db(),到 yield 暂停,返回 db
↓
路由函数用 db 处理请求
↓
请求处理完
↓
继续执行 get_db() 的 finally 块,关闭 db
这比你在路由函数里写 try-finally 优雅多了。
3. 类(可调用对象)
python
class Pagination:
def __init__(self, skip: int = 0, limit: int = 10):
self.skip = skip
self.limit = limit
@app.get("/items")
def read_items(pagination: Pagination = Depends()):
# 当 Depends() 不传参数时,会实例化类型注解的类
return {"skip": pagination.skip, "limit": pagination.limit}
或者用 __call__ 让实例可调用,这个技巧特别适合做参数化的依赖:
python
class PermissionChecker:
def __init__(self, required_permissions: list[str]):
self.required_permissions = required_permissions
def __call__(self, user: User = Depends(get_current_user)):
for perm in self.required_permissions:
if perm not in user.permissions:
raise HTTPException(status_code=403, detail=f"缺少权限: {perm}")
return user
# 不同路由要求不同权限
@app.get("/admin/users")
def list_users(user = Depends(PermissionChecker(["admin:read"]))):
return {"users": []}
@app.delete("/admin/users/{user_id}")
def delete_user(user = Depends(PermissionChecker(["admin:write"]))):
return {"message": "deleted"}
看到没?PermissionChecker(["admin:read"]) 创建一个实例,FastAPI 调用这个实例(触发 __call__),就能实现参数化的权限检查。
4. 异步函数
python
async def get_async_db():
async with AsyncSession() as session:
yield session
@app.get("/items")
async def read_items(db = Depends(get_async_db)):
result = await db.execute(select(Item))
return result.scalars().all()
3.4 依赖的依赖(嵌套依赖)
依赖可以有自己的依赖,FastAPI 会自动解析整个依赖树:
python
# 第一层依赖:获取配置
def get_settings():
return Settings()
# 第二层依赖:获取数据库,依赖配置
def get_db(settings: Settings = Depends(get_settings)):
engine = create_engine(settings.database_url)
Session = sessionmaker(bind=engine)
db = Session()
try:
yield db
finally:
db.close()
# 第三层依赖:获取仓储,依赖数据库
def get_user_repository(db: Session = Depends(get_db)):
return UserRepository(db)
# 第四层依赖:获取服务,依赖仓储
def get_user_service(repo: UserRepository = Depends(get_user_repository)):
return UserService(repo)
# 路由:只需要声明最顶层的依赖
@app.get("/users/{user_id}")
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
return service.get_user(user_id)
依赖链:
get_user (路由)
└── get_user_service
└── get_user_repository
└── get_db
└── get_settings
FastAPI 会从最底层开始解析:先执行 get_settings,拿到配置;再执行 get_db,传入配置;一层层往上,最后把 UserService 传给路由函数。
这就是依赖注入 + 工厂函数的完美结合。
3.5 依赖缓存
默认情况下,同一个请求中,同一个依赖只会执行一次:
python
def get_db():
print("创建数据库连接") # 只打印一次!
return DatabaseSession()
@app.get("/test")
def test(
db1 = Depends(get_db),
db2 = Depends(get_db) # 不会再次执行 get_db
):
print(db1 is db2) # True,是同一个对象
return {"status": "ok"}
这是个很重要的优化。想象一下,如果一个请求的多个依赖都需要数据库连接,没有缓存的话得创建好几个连接。
如果你需要每次都执行(比如获取不同的随机数),设置 use_cache=False:
python
import random
def get_random():
return random.randint(1, 100)
@app.get("/test")
def test(
r1 = Depends(get_random),
r2 = Depends(get_random, use_cache=False) # 强制重新执行
):
return {"r1": r1, "r2": r2} # r1 和 r2 可能不同
3.6 全局依赖
可以给整个应用或路由组添加依赖:
python
# 全局依赖:所有路由都会检查 API Key
app = FastAPI(dependencies=[Depends(verify_api_key)])
# 路由组依赖:只有这个路由组会检查 admin 权限
from fastapi import APIRouter
router = APIRouter(
prefix="/admin",
dependencies=[Depends(require_admin)]
)
@router.get("/users")
def list_users(): # 自动检查 admin 权限
return []
@router.get("/settings")
def get_settings(): # 也会自动检查
return {}
app.include_router(router)
四、实战:分层架构中的依赖注入
来个完整的例子,展示如何在 FastAPI 项目中用依赖注入实现分层架构。
4.1 项目结构
myapp/
├── main.py # 应用入口
├── config.py # 配置管理
├── database.py # 数据库连接
├── dependencies.py # 依赖注入配置
├── models/
│ └── user.py # 数据库模型
├── repositories/
│ └── user_repository.py # 数据访问层
├── services/
│ └── user_service.py # 业务逻辑层
└── routers/
└── user_router.py # 路由层
4.2 配置层
python
# config.py
from functools import lru_cache
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
app_name: str = "MyApp"
database_url: str = "postgresql://localhost/mydb"
redis_url: str = "redis://localhost:6379"
secret_key: str = "your-secret-key"
class Config:
env_file = ".env"
@lru_cache()
def get_settings() -> Settings:
"""配置单例:整个应用只创建一次"""
return Settings()
@lru_cache() 确保配置只加载一次,这是单例模式的 Python 风格实现。
4.3 数据库层
python
# database.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from typing import Generator
from config import get_settings
settings = get_settings()
# 创建引擎(连接池)
engine = create_engine(
settings.database_url,
pool_size=5, # 连接池大小
max_overflow=10, # 超出 pool_size 后最多再创建 10 个
pool_pre_ping=True, # 使用前 ping 一下,检查连接是否有效
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
"""数据库会话工厂"""
db = SessionLocal()
try:
yield db
finally:
db.close()
4.4 模型层
python
# models/user.py
from sqlalchemy import Column, Integer, String, Boolean, DateTime
from sqlalchemy.sql import func
from database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, index=True, nullable=False)
email = Column(String(100), unique=True, index=True, nullable=False)
hashed_password = Column(String(100), nullable=False)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
4.5 仓储层
python
# repositories/user_repository.py
from typing import Optional, List, Protocol
from sqlalchemy.orm import Session
from models.user import User
class IUserRepository(Protocol):
"""用户仓储接口"""
def get_by_id(self, user_id: int) -> Optional[User]: ...
def get_by_username(self, username: str) -> Optional[User]: ...
def get_by_email(self, email: str) -> Optional[User]: ...
def get_all(self, skip: int, limit: int) -> List[User]: ...
def create(self, user: User) -> User: ...
def update(self, user: User) -> User: ...
def delete(self, user_id: int) -> bool: ...
class UserRepository:
"""用户仓储实现"""
def __init__(self, db: Session):
self.db = db
def get_by_id(self, user_id: int) -> Optional[User]:
return self.db.query(User).filter(User.id == user_id).first()
def get_by_username(self, username: str) -> Optional[User]:
return self.db.query(User).filter(User.username == username).first()
def get_by_email(self, email: str) -> Optional[User]:
return self.db.query(User).filter(User.email == email).first()
def get_all(self, skip: int = 0, limit: int = 100) -> List[User]:
return self.db.query(User).offset(skip).limit(limit).all()
def create(self, user: User) -> User:
self.db.add(user)
self.db.commit()
self.db.refresh(user)
return user
def update(self, user: User) -> User:
self.db.commit()
self.db.refresh(user)
return user
def delete(self, user_id: int) -> bool:
user = self.get_by_id(user_id)
if user:
self.db.delete(user)
self.db.commit()
return True
return False
4.6 服务层
python
# services/user_service.py
from typing import Optional, List
from fastapi import HTTPException, status
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from repositories.user_repository import IUserRepository
from models.user import User
# 请求/响应模型
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
class UserUpdate(BaseModel):
email: Optional[EmailStr] = None
password: Optional[str] = None
class UserResponse(BaseModel):
id: int
username: str
email: str
is_active: bool
class Config:
from_attributes = True
# 服务实现
class UserService:
"""用户服务:处理业务逻辑"""
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def __init__(self, user_repo: IUserRepository):
self.user_repo = user_repo
def _hash_password(self, password: str) -> str:
return self.pwd_context.hash(password)
def _verify_password(self, plain: str, hashed: str) -> bool:
return self.pwd_context.verify(plain, hashed)
def get_user(self, user_id: int) -> UserResponse:
user = self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return UserResponse.model_validate(user)
def get_users(self, skip: int = 0, limit: int = 100) -> List[UserResponse]:
users = self.user_repo.get_all(skip, limit)
return [UserResponse.model_validate(u) for u in users]
def create_user(self, user_data: UserCreate) -> UserResponse:
# 检查用户名是否已存在
if self.user_repo.get_by_username(user_data.username):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已存在"
)
# 检查邮箱是否已存在
if self.user_repo.get_by_email(user_data.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="邮箱已被注册"
)
# 创建用户
user = User(
username=user_data.username,
email=user_data.email,
hashed_password=self._hash_password(user_data.password)
)
user = self.user_repo.create(user)
return UserResponse.model_validate(user)
def update_user(self, user_id: int, user_data: UserUpdate) -> UserResponse:
user = self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
if user_data.email is not None:
user.email = user_data.email
if user_data.password is not None:
user.hashed_password = self._hash_password(user_data.password)
user = self.user_repo.update(user)
return UserResponse.model_validate(user)
def delete_user(self, user_id: int) -> dict:
if not self.user_repo.delete(user_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return {"message": "用户已删除"}
4.7 依赖注入配置
python
# dependencies.py
from fastapi import Depends
from sqlalchemy.orm import Session
from database import get_db
from repositories.user_repository import UserRepository
from services.user_service import UserService
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
"""用户仓储工厂"""
return UserRepository(db)
def get_user_service(
user_repo: UserRepository = Depends(get_user_repository)
) -> UserService:
"""用户服务工厂"""
return UserService(user_repo)
这就是依赖注入的配置中心。所有的"组装"逻辑都在这里,路由层只需要声明需要什么,不用管怎么创建。
4.8 路由层
python
# routers/user_router.py
from typing import List
from fastapi import APIRouter, Depends, status
from services.user_service import UserService, UserCreate, UserUpdate, UserResponse
from dependencies import get_user_service
router = APIRouter(prefix="/users", tags=["users"])
@router.get("/", response_model=List[UserResponse])
def list_users(
skip: int = 0,
limit: int = 100,
service: UserService = Depends(get_user_service)
):
"""获取用户列表"""
return service.get_users(skip, limit)
@router.get("/{user_id}", response_model=UserResponse)
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""获取单个用户"""
return service.get_user(user_id)
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(
user_data: UserCreate,
service: UserService = Depends(get_user_service)
):
"""创建用户"""
return service.create_user(user_data)
@router.put("/{user_id}", response_model=UserResponse)
def update_user(
user_id: int,
user_data: UserUpdate,
service: UserService = Depends(get_user_service)
):
"""更新用户"""
return service.update_user(user_id, user_data)
@router.delete("/{user_id}")
def delete_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""删除用户"""
return service.delete_user(user_id)
4.9 主应用
python
# main.py
from fastapi import FastAPI
from routers import user_router
app = FastAPI(title="用户管理系统")
app.include_router(user_router.router)
@app.get("/")
def root():
return {"message": "Welcome to User Management API"}
4.10 架构图
┌─────────────────────────────────────────────────────┐
│ Router │
│ (user_router.py - 处理HTTP请求,调用Service) │
└───────────────────────┬─────────────────────────────┘
│ Depends(get_user_service)
▼
┌─────────────────────────────────────────────────────┐
│ Service │
│ (user_service.py - 业务逻辑,调用Repository) │
└───────────────────────┬─────────────────────────────┘
│ Depends(get_user_repository)
▼
┌─────────────────────────────────────────────────────┐
│ Repository │
│ (user_repository.py - 数据访问,操作数据库) │
└───────────────────────┬─────────────────────────────┘
│ Depends(get_db)
▼
┌─────────────────────────────────────────────────────┐
│ Database │
│ (database.py - 数据库会话管理) │
└─────────────────────────────────────────────────────┘
每一层只依赖它下面一层的接口,不关心具体实现。想换数据库?改 UserRepository。想换缓存?改 dependencies.py 里的工厂函数。业务逻辑不受影响。
五、进阶:自己实现一个依赖注入容器
理解原理最好的方式是自己实现一个简化版本。
python
import inspect
from typing import Callable, Dict, Any, Type, get_type_hints
from functools import wraps
class DIContainer:
"""简单的依赖注入容器"""
def __init__(self):
self._singletons: Dict[Type, Any] = {} # 单例缓存
self._factories: Dict[Type, Callable] = {} # 工厂函数
def register_singleton(self, interface: Type, implementation: Callable = None):
"""
注册单例:整个应用生命周期只创建一次
"""
if implementation is None:
implementation = interface
def decorator(cls):
self._factories[interface] = cls
return cls
if callable(implementation) and implementation is not interface:
self._factories[interface] = implementation
return implementation
return decorator
def register_transient(self, interface: Type, implementation: Callable = None):
"""
注册瞬态:每次请求都创建新实例
"""
if implementation is None:
implementation = interface
self._factories[interface] = implementation
return implementation
def resolve(self, interface: Type) -> Any:
"""
解析依赖:根据类型获取实例
"""
# 检查是否有单例缓存
if interface in self._singletons:
return self._singletons[interface]
# 检查是否注册了工厂
if interface not in self._factories:
raise ValueError(f"未注册的依赖: {interface}")
factory = self._factories[interface]
# 创建实例,自动注入依赖
instance = self._create_instance(factory)
return instance
def _create_instance(self, factory: Callable) -> Any:
"""
创建实例,自动解析并注入构造函数的依赖
"""
# 获取构造函数的参数签名
sig = inspect.signature(factory)
hints = get_type_hints(factory) if hasattr(factory, '__annotations__') else {}
kwargs = {}
for param_name, param in sig.parameters.items():
if param_name == 'self':
continue
# 获取参数类型
param_type = hints.get(param_name)
if param_type and param_type in self._factories:
# 递归解析依赖
kwargs[param_name] = self.resolve(param_type)
elif param.default is not inspect.Parameter.empty:
# 使用默认值
kwargs[param_name] = param.default
return factory(**kwargs)
def inject(self, func: Callable) -> Callable:
"""
装饰器:自动注入函数参数
"""
@wraps(func)
def wrapper(*args, **kwargs):
sig = inspect.signature(func)
hints = get_type_hints(func)
for param_name, param in sig.parameters.items():
if param_name in kwargs:
continue
param_type = hints.get(param_name)
if param_type and param_type in self._factories:
kwargs[param_name] = self.resolve(param_type)
return func(*args, **kwargs)
return wrapper
# ========== 使用示例 ==========
# 创建容器
container = DIContainer()
# 定义接口
class ILogger:
def log(self, message: str): ...
class IUserRepository:
def get_user(self, user_id: int): ...
class IUserService:
def get_user(self, user_id: int): ...
# 实现类
class ConsoleLogger:
def log(self, message: str):
print(f"[LOG] {message}")
class UserRepository:
def __init__(self, logger: ILogger): # 依赖 ILogger
self.logger = logger
def get_user(self, user_id: int):
self.logger.log(f"查询用户: {user_id}")
return {"id": user_id, "name": f"User_{user_id}"}
class UserService:
def __init__(self, repo: IUserRepository, logger: ILogger): # 依赖多个
self.repo = repo
self.logger = logger
def get_user(self, user_id: int):
self.logger.log(f"获取用户: {user_id}")
return self.repo.get_user(user_id)
# 注册依赖
container.register_singleton(ILogger, ConsoleLogger)
container.register_transient(IUserRepository, UserRepository)
container.register_transient(IUserService, UserService)
# 解析并使用
service = container.resolve(IUserService)
user = service.get_user(1)
print(user)
# 输出:
# [LOG] 获取用户: 1
# [LOG] 查询用户: 1
# {'id': 1, 'name': 'User_1'}
# 使用装饰器注入
@container.inject
def some_function(service: IUserService, user_id: int = 1):
return service.get_user(user_id)
result = some_function(user_id=2)
print(result)
这只是个简化版,真正的 DI 容器(如 dependency-injector 库)会更复杂,支持作用域(请求级、会话级)、异步、配置等特性。但核心原理就是这样:根据类型注解,递归地创建和注入依赖。
六、测试中的巨大优势
依赖注入最大的好处之一就是方便测试。为什么?因为你可以轻松地用 mock 替换真实依赖。
6.1 不用依赖注入的窘境
python
class UserService:
def __init__(self):
self.db = RealDatabase() # 硬编码,测试时必须有真实数据库
self.cache = RealRedis() # 测试时必须有真实 Redis
def get_user(self, user_id: int):
cached = self.cache.get(f"user:{user_id}")
if cached:
return cached
user = self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
self.cache.set(f"user:{user_id}", user)
return user
# 测试
def test_get_user():
service = UserService() # 需要真实的数据库和 Redis!
user = service.get_user(1)
assert user is not None
# 问题:
# 1. 测试需要真实环境
# 2. 测试速度慢(网络IO)
# 3. 测试数据难以准备和清理
# 4. 测试不稳定(网络抖动、数据库状态)
6.2 用依赖注入轻松 mock
python
# 用依赖注入的版本
class UserService:
def __init__(self, db: IDatabase, cache: ICache):
self.db = db
self.cache = cache
def get_user(self, user_id: int):
cached = self.cache.get(f"user:{user_id}")
if cached:
return cached
user = self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
self.cache.set(f"user:{user_id}", user)
return user
# 测试变得超简单
import pytest
from unittest.mock import Mock
class TestUserService:
def test_get_user_from_cache(self):
"""测试从缓存获取用户"""
# 准备 mock
mock_db = Mock()
mock_cache = Mock()
mock_cache.get.return_value = {"id": 1, "name": "Test"}
# 注入 mock
service = UserService(db=mock_db, cache=mock_cache)
# 执行
user = service.get_user(1)
# 断言
assert user == {"id": 1, "name": "Test"}
mock_cache.get.assert_called_once_with("user:1")
mock_db.query.assert_not_called() # 有缓存时不应该查数据库
def test_get_user_from_db(self):
"""测试从数据库获取用户(缓存未命中)"""
mock_db = Mock()
mock_db.query.return_value = {"id": 1, "name": "Test"}
mock_cache = Mock()
mock_cache.get.return_value = None # 缓存未命中
service = UserService(db=mock_db, cache=mock_cache)
user = service.get_user(1)
assert user == {"id": 1, "name": "Test"}
mock_db.query.assert_called_once()
mock_cache.set.assert_called_once() # 应该设置缓存
看到没?不需要真实数据库,不需要真实 Redis,测试跑得飞快,而且每个测试都是独立的。
6.3 FastAPI 测试中的依赖覆盖
FastAPI 提供了 app.dependency_overrides 机制来覆盖依赖:
python
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
import pytest
# 测试用的 mock 依赖
def get_mock_db():
"""返回一个 mock 数据库会话"""
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = User(
id=1, username="testuser", email="test@example.com"
)
return mock_db
def get_mock_user_service():
"""返回一个 mock 服务"""
mock_service = MagicMock(spec=UserService)
mock_service.get_user.return_value = UserResponse(
id=1, username="testuser", email="test@example.com", is_active=True
)
return mock_service
class TestUserRouter:
@pytest.fixture
def client(self):
"""创建测试客户端,覆盖依赖"""
# 覆盖依赖
app.dependency_overrides[get_db] = get_mock_db
app.dependency_overrides[get_user_service] = get_mock_user_service
client = TestClient(app)
yield client
# 清理
app.dependency_overrides.clear()
def test_get_user(self, client):
response = client.get("/users/1")
assert response.status_code == 200
data = response.json()
assert data["id"] == 1
assert data["username"] == "testuser"
def test_create_user(self, client):
response = client.post("/users/", json={
"username": "newuser",
"email": "new@example.com",
"password": "password123"
})
assert response.status_code == 201
6.4 更精细的测试控制
python
class TestUserRouterEdgeCases:
def test_user_not_found(self):
"""测试用户不存在的情况"""
def get_service_that_raises():
service = MagicMock(spec=UserService)
service.get_user.side_effect = HTTPException(
status_code=404, detail="用户不存在"
)
return service
app.dependency_overrides[get_user_service] = get_service_that_raises
client = TestClient(app)
response = client.get("/users/999")
assert response.status_code == 404
assert response.json()["detail"] == "用户不存在"
app.dependency_overrides.clear()
6.5 使用 pytest fixtures 组织测试
python
# conftest.py
import pytest
from unittest.mock import MagicMock
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from main import app
from database import Base, get_db
# 测试数据库(用内存 SQLite)
SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///./test.db"
test_engine = create_engine(
SQLALCHEMY_TEST_DATABASE_URL,
connect_args={"check_same_thread": False}
)
TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
@pytest.fixture(scope="function")
def test_db():
"""为每个测试创建新的数据库"""
Base.metadata.create_all(bind=test_engine)
db = TestSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=test_engine)
@pytest.fixture(scope="function")
def client(test_db):
"""创建测试客户端"""
def override_get_db():
yield test_db
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
# tests/test_users.py
class TestUserAPI:
def test_create_and_get_user(self, client):
"""集成测试:创建并获取用户"""
# 创建用户
response = client.post("/users/", json={
"username": "testuser",
"email": "test@example.com",
"password": "password123"
})
assert response.status_code == 201
user_id = response.json()["id"]
# 获取用户
response = client.get(f"/users/{user_id}")
assert response.status_code == 200
assert response.json()["username"] == "testuser"
七、常见问题与最佳实践
7.1 循环依赖怎么办
python
# 错误示例:A 依赖 B,B 依赖 A
class ServiceA:
def __init__(self, service_b: "ServiceB"):
self.service_b = service_b
class ServiceB:
def __init__(self, service_a: ServiceA): # 循环依赖!
self.service_a = service_a
解决方法 1:重新设计,打破循环
python
# 提取公共依赖
class CommonService:
def shared_logic(self):
pass
class ServiceA:
def __init__(self, common: CommonService):
self.common = common
class ServiceB:
def __init__(self, common: CommonService):
self.common = common
解决方法 2:使用中介者模式
python
class Mediator:
service_a: "ServiceA" = None
service_b: "ServiceB" = None
def notify(self, sender, event):
if sender == "A":
self.service_b.handle(event)
elif sender == "B":
self.service_a.handle(event)
解决方法 3:延迟注入
python
class ServiceA:
def __init__(self):
self._service_b = None
@property
def service_b(self):
if self._service_b is None:
from dependencies import get_service_b
self._service_b = get_service_b()
return self._service_b
7.2 依赖太多怎么办
如果一个类的构造函数参数超过 5 个,可能说明职责太多了:
python
# 不好:依赖太多,职责不清
class OrderService:
def __init__(
self,
user_repo: UserRepository,
order_repo: OrderRepository,
product_repo: ProductRepository,
inventory_service: InventoryService,
payment_service: PaymentService,
notification_service: NotificationService,
email_service: EmailService,
sms_service: SMSService,
logger: Logger,
):
...
解决方法 1:拆分服务
python
class OrderCreationService:
def __init__(self, user_repo, order_repo, product_repo):
...
class OrderPaymentService:
def __init__(self, payment_service, inventory_service):
...
class OrderNotificationService:
def __init__(self, notification_service, email_service, sms_service):
...
解决方法 2:引入门面(Facade)
python
class NotificationFacade:
"""通知门面,封装多种通知方式"""
def __init__(self, email, sms, push):
self.email = email
self.sms = sms
self.push = push
def notify_user(self, user, message, channels=None):
...
class OrderService:
def __init__(
self,
order_repo: OrderRepository,
notification: NotificationFacade, # 用门面代替多个服务
):
...
7.3 生命周期管理
不同的依赖需要不同的生命周期:
| 生命周期 | 说明 | 适用场景 | 实现方式 |
|---|---|---|---|
| 单例 | 整个应用只有一个实例 | 配置、连接池 | @lru_cache() |
| 请求级 | 每个请求一个实例 | 数据库会话 | yield + Depends |
| 瞬态 | 每次注入新实例 | 无状态工具 | 普通函数 |
python
# 单例:整个应用共享
@lru_cache()
def get_settings():
return Settings()
# 请求级:每个请求一个实例
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# 瞬态:每次注入都创建新实例
def get_uuid():
return str(uuid.uuid4())
7.4 依赖注入 vs 单例
什么时候用单例,什么时候用依赖注入?
| 场景 | 推荐方式 | 原因 |
|---|---|---|
| 配置管理 | 单例 + DI | 配置是只读的,全局共享没问题 |
| 数据库连接池 | 单例 | 连接池本身应该是唯一的 |
| 数据库会话 | DI(请求级) | 每个请求独立的事务 |
| 业务服务 | DI | 方便测试,避免状态共享 |
| 无状态工具类 | DI | 虽然可以单例,但 DI 更灵活 |
实际项目中往往是单例做底层资源管理,DI 做上层业务组装:
python
# 底层:单例
@lru_cache()
def get_settings():
return Settings()
# 底层:单例(连接池)
engine = create_engine(...) # 模块级单例
# 上层:DI
def get_db():
return SessionLocal()
def get_user_repo(db = Depends(get_db)):
return UserRepository(db)
def get_user_service(repo = Depends(get_user_repo)):
return UserService(repo)
7.5 异步依赖的注意事项
FastAPI 完全支持异步依赖:
python
async def get_async_db():
async with AsyncSession() as session:
yield session
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_async_db)
):
result = await db.execute(
select(User).where(User.token == token)
)
return result.scalar_one_or_none()
@app.get("/users/me")
async def read_current_user(
user: User = Depends(get_current_user)
):
return user
注意:同步路由可以使用异步依赖,但异步路由里的阻塞代码要小心(会阻塞事件循环)。
八、总结
这篇内容有点多,来梳理一下关键点:
工厂模式
把"创建对象"的逻辑封装起来。Python 里不用搞 Java 那套复杂的类结构,一个函数就能当工厂。装饰器 + 注册表的模式特别适合插件系统。
依赖注入
不在类内部创建依赖,而是从外部传入。好处是解耦、易测试、易替换。核心原则:依赖接口,不依赖实现。
FastAPI 的 Depends
本质就是依赖注入的实现。它会:
- 自动解析依赖树
- 管理生命周期(
yield支持清理) - 支持依赖缓存
- 支持全局依赖
用好了能让代码架构很清晰。
实践建议
| 建议 | 说明 |
|---|---|
用 Depends 管理依赖 |
数据库会话、服务实例等 |
| 配合分层架构 | Router → Service → Repository |
| 依赖接口不依赖实现 | 用 Protocol 定义接口 |
| 测试时覆盖依赖 | app.dependency_overrides |
| 避免循环依赖 | 重新设计或用中介者 |
| 依赖太多就拆分 | 单一职责原则 |
关键要点总结
| 概念 | 说明 | 适用场景 |
|---|---|---|
| 简单工厂 | 封装创建逻辑 | 类型固定,逻辑简单 |
| 工厂方法 | 子类决定创建什么 | 需要扩展,配置不同 |
| 抽象工厂 | 创建产品族 | 一系列相关对象 |
| 依赖注入 | 外部传入依赖 | 解耦、测试、灵活 |
Depends |
FastAPI 的 DI 实现 | Web 应用的依赖管理 |
工厂模式 + 依赖注入 + 单例 ,这三个配合起来用,就是 FastAPI 项目的依赖管理"三板斧"。理解了这些,你写出来的代码会更好维护、更容易测试。
下一篇打算讲仓储模式(Repository Pattern),把数据访问层的抽象再深入讲讲,和这篇的分层架构正好呼应。
热门专栏推荐
- Agent小册
- 服务器部署
- Java基础合集
- Python基础合集
- Go基础合集
- 大数据合集
- 前端小册
- 数据库合集
- Redis 合集
- Spring 全家桶
- 微服务全家桶
- 数据结构与算法合集
- 设计模式小册
- 消息队列合集
等等等还有许多优秀的合集在主页等着大家的光顾,感谢大家的支持
文章到这里就结束了,如果有什么疑问的地方请指出,诸佬们一起来评论区一起讨论😊
希望能和诸佬们一起努力,今后我们一起观看感谢您的阅读🙏
如果帮助到您不妨3连支持一下,创造不易您们的支持是我的动力🌟