Python仓储模式:让数据访问不再混乱
前言
这篇是设计模式系列的学习笔记,这次来聊聊仓储模式(Repository Pattern)。
上一篇讲了工厂模式和依赖注入,最后提到了分层架构里的 Repository 层。这篇就专门把数据访问这块的设计讲透。
说实话,很多人写 FastAPI 项目的时候,习惯在路由函数里直接写 SQLAlchemy 查询,一个接口几十行,查询、过滤、分页全揉在一起。项目小的时候还好,一旦业务复杂起来,代码就变成一锅粥了------同样的查询逻辑写了好几遍,改个字段名得改十几个地方,测试更是没法写。
仓储模式就是解决这个问题的。它把数据访问逻辑封装起来,让业务层不用关心数据是怎么存的、怎么取的。听起来是不是有点像工厂模式?没错,设计模式之间经常是互相配合的,仓储模式和工厂模式、依赖注入一起用,能让代码架构非常清晰。
这篇文章会从"为什么需要"讲起,然后一步步实现一个完整的仓储层,包括泛型仓储、工作单元模式这些进阶内容。内容比较多,但都是实战中会用到的,耐心看完肯定有收获。
🏠个人主页:山沐与山
文章目录
- 一、仓储模式是什么
- 二、不用仓储模式会怎样
- 三、仓储模式的基本实现
- 四、泛型仓储:减少重复代码
- 五、工作单元模式配合仓储
- 六、FastAPI完整实战
- [七、仓储模式 vs 直接用ORM](#七、仓储模式 vs 直接用ORM)
- 八、测试中的巨大优势
- 九、常见问题与最佳实践
- 十、总结
一、仓储模式是什么
1.1 从一个比喻说起
想象一下图书馆。你想借一本书,需要知道这本书具体放在哪个架子的第几层吗?不需要,你只需要告诉图书管理员书名,管理员会帮你找到并拿给你。还书的时候也一样,你不用操心书该放回哪里,交给管理员就行。
仓储模式里的 Repository 就是这个"图书管理员"。你的业务代码只需要说"我要 ID 为 1 的用户"或者"帮我保存这个订单",至于数据是存在 MySQL 还是 PostgreSQL,用的是什么 ORM,业务代码完全不需要知道。
1.2 正式定义
仓储模式(Repository Pattern)是一种将数据访问逻辑与业务逻辑分离的设计模式。它提供一个类似"集合"的接口来访问领域对象,让业务层可以像操作内存中的集合一样操作数据,而不需要关心底层的数据存储细节。
用更直白的话说:Repository 是数据层的抽象,它把"怎么存数据"这件事藏起来,只暴露"存什么、取什么"的接口。
1.3 仓储模式的核心思想
仓储模式有几个核心理念:
第一,隔离数据访问细节 。业务层不直接和数据库打交道,而是通过 Repository 这个中间层。这样数据库换了(比如从 MySQL 换到 PostgreSQL),或者 ORM 换了(比如从 SQLAlchemy 换到 Tortoise),业务层的代码不用改。
第二,统一数据访问接口 。不管底层是关系型数据库、NoSQL、文件系统还是远程 API,Repository 对外提供的接口是一致的。这让业务代码变得简洁,也让测试变得容易(可以用内存实现替换真实数据库)。
第三,集中管理查询逻辑 。所有和某个实体相关的查询都放在对应的 Repository 里。想找"获取活跃用户"的逻辑?去 UserRepository 里找就行,不用在几十个文件里翻。
1.4 仓储模式在分层架构中的位置
在典型的分层架构中,Repository 处于数据访问层,向上对接服务层(Service),向下对接具体的数据存储:
┌─────────────────────────────────────────────────────┐
│ Presentation Layer │
│ (Controllers / Routers) │
│ 处理HTTP请求,调用Service,返回响应 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Service Layer │
│ (Business Logic) │
│ 业务逻辑,调用Repository获取数据 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Repository Layer │
│ (Data Access Logic) │
│ 封装数据访问,提供类集合的操作接口 │
└───────────────────────┬─────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────┐
│ Data Storage Layer │
│ (Database / Cache / External API) │
│ 实际的数据存储 │
└─────────────────────────────────────────────────────┘
每一层只和相邻的层打交道,这样修改任何一层都不会影响其他层(只要接口不变)。
二、不用仓储模式会怎样
在讲怎么实现之前,先看看不用仓储模式会遇到什么问题。这样你才能理解为什么要引入这个模式。
2.1 典型的"面条代码"
很多 FastAPI 项目一开始是这样写的,所有逻辑都塞在路由函数里:
python
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from database import get_db
from models import User, Order, Product
app = FastAPI()
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
# 直接在路由里写查询逻辑
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
return user
@app.get("/users")
def list_users(
skip: int = 0,
limit: int = 10,
is_active: bool = None,
keyword: str = None,
db: Session = Depends(get_db)
):
# 查询逻辑越来越复杂
query = db.query(User)
if is_active is not None:
query = query.filter(User.is_active == is_active)
if keyword:
query = query.filter(
or_(
User.username.ilike(f"%{keyword}%"),
User.email.ilike(f"%{keyword}%")
)
)
total = query.count()
users = query.offset(skip).limit(limit).all()
return {"total": total, "items": users}
@app.get("/users/{user_id}/orders")
def get_user_orders(
user_id: int,
status: str = None,
start_date: str = None,
end_date: str = None,
db: Session = Depends(get_db)
):
# 又是一大段查询逻辑
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
query = db.query(Order).filter(Order.user_id == user_id)
if status:
query = query.filter(Order.status == status)
if start_date:
query = query.filter(Order.created_at >= start_date)
if end_date:
query = query.filter(Order.created_at <= end_date)
return query.all()
@app.post("/orders")
def create_order(
user_id: int,
product_id: int,
quantity: int,
db: Session = Depends(get_db)
):
# 更多的数据库操作...
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
product = db.query(Product).filter(Product.id == product_id).first()
if not product:
raise HTTPException(status_code=404, detail="商品不存在")
if product.stock < quantity:
raise HTTPException(status_code=400, detail="库存不足")
# 创建订单
order = Order(
user_id=user_id,
product_id=product_id,
quantity=quantity,
total_price=product.price * quantity
)
db.add(order)
# 扣减库存
product.stock -= quantity
db.commit()
db.refresh(order)
return order
看起来能跑,但问题其实很多。
2.2 问题分析
问题一:代码重复严重
"根据 ID 查用户"这个逻辑在 get_user、get_user_orders、create_order 里都写了一遍。如果以后查询条件变了(比如要加上软删除过滤),得改好几个地方。累不累?
python
# 这段代码重复了三次
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
问题二:业务逻辑和数据访问逻辑混在一起
路由函数本来应该只负责"接收请求、返回响应",现在还要处理复杂的数据库查询。一个函数做了太多事情,违反单一职责原则。
更糟糕的是,业务规则(比如"库存不足不能下单")和数据操作(查询、更新)交织在一起,代码的意图变得不清晰,后来维护的人很难理解这段代码到底在干什么。
问题三:难以测试
想测试 create_order 的业务逻辑,必须准备一个真实的数据库,还得提前插入用户和商品数据。测试又慢又脆弱,稍微改点东西测试就挂了。
如果能把数据访问抽象出来,测试时用个假的实现(返回固定数据),就能快速验证业务逻辑是否正确。
问题四:难以复用
假设有另一个地方也需要"获取用户的订单列表",怎么办?把那一大段代码复制过去?还是提取成函数?提取成函数放在哪里?
没有统一的规范,代码组织会越来越乱,每个人都有自己的写法,最后就是一团糟。
问题五:切换数据源困难
现在用的是 SQLAlchemy 查 PostgreSQL,哪天要加个 Redis 缓存,或者某些数据要从外部 API 获取,改动会非常大。因为数据库操作散落在各处,没有统一的抽象层。
2.3 我们需要什么
总结一下,我们需要:
| 需求 | 说明 |
|---|---|
| 统一的地方 | 放所有和某个实体相关的数据访问逻辑 |
| 抽象的接口 | 让业务层不依赖具体的数据库实现 |
| 易于测试 | 可以用假实现替换真实数据库 |
| 可复用 | 查询逻辑写一次到处用 |
这就是仓储模式要解决的问题。
三、仓储模式的基本实现
理解了问题,现在来看解决方案。
3.1 定义仓储接口
首先,定义一个抽象的接口,规定 Repository 应该提供哪些操作。Python 里可以用 Protocol 或 ABC 来定义接口。
Protocol 是 Python 3.8+ 引入的,它实现的是"结构化子类型"(structural subtyping),也叫鸭子类型的静态版本。只要一个类实现了 Protocol 定义的方法,就被认为是该 Protocol 的实现,不需要显式继承。这比 ABC 更灵活。
python
from typing import Protocol, TypeVar, Generic, Optional, List
# 定义泛型类型变量
T = TypeVar('T') # 实体类型
ID = TypeVar('ID') # ID类型
class IRepository(Protocol[T, ID]):
"""
仓储接口
这是一个泛型协议,T 是实体类型,ID 是主键类型。
所有具体的 Repository 都应该实现这个接口定义的方法。
"""
def get_by_id(self, id: ID) -> Optional[T]:
"""根据ID获取单个实体,不存在返回 None"""
...
def get_all(self) -> List[T]:
"""获取所有实体"""
...
def add(self, entity: T) -> T:
"""添加一个实体"""
...
def update(self, entity: T) -> T:
"""更新一个实体"""
...
def delete(self, id: ID) -> bool:
"""删除一个实体,返回是否成功"""
...
这个接口定义了最基本的 CRUD 操作(Create、Read、Update、Delete)。注意,这里没有任何和具体数据库相关的代码,它只是一个"契约",规定实现类必须提供这些方法。
3.2 实现具体的仓储
有了接口,接下来实现具体的 Repository。这里以用户仓储为例,使用 SQLAlchemy 作为 ORM:
python
from typing import Optional, List
from sqlalchemy.orm import Session
from models import User
class UserRepository:
"""
用户仓储的 SQLAlchemy 实现
这个类负责所有和 User 实体相关的数据库操作。
业务层通过这个类来访问用户数据,不需要知道底层用的是什么数据库。
"""
def __init__(self, db: Session):
"""
构造函数
Args:
db: SQLAlchemy 的数据库会话,通过依赖注入传入
"""
self.db = db
def get_by_id(self, user_id: int) -> Optional[User]:
"""
根据ID获取用户
返回 Optional 类型意味着可能返回 None,
调用方需要处理用户不存在的情况。
"""
return self.db.query(User).filter(User.id == user_id).first()
def get_by_username(self, username: str) -> Optional[User]:
"""根据用户名获取用户"""
return self.db.query(User).filter(User.username == username).first()
def get_by_email(self, email: str) -> Optional[User]:
"""根据邮箱获取用户"""
return self.db.query(User).filter(User.email == email).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> List[User]:
"""
获取用户列表
支持分页和过滤。在实际项目中,这个方法的参数可能会更多,
比如排序字段、排序方向、多种过滤条件等。
"""
query = self.db.query(User)
# 条件过滤
if is_active is not None:
query = query.filter(User.is_active == is_active)
# 分页
return query.offset(skip).limit(limit).all()
def count(self, is_active: Optional[bool] = None) -> int:
"""统计用户数量"""
query = self.db.query(User)
if is_active is not None:
query = query.filter(User.is_active == is_active)
return query.count()
def add(self, user: User) -> User:
"""
添加用户
注意这里只是把对象加到 session 里,
真正写入数据库是在 commit 时。
"""
self.db.add(user)
self.db.flush() # flush 会生成 ID,但不提交事务
self.db.refresh(user) # 刷新对象,获取数据库生成的值
return user
def update(self, user: User) -> User:
"""更新用户"""
self.db.flush()
self.db.refresh(user)
return user
def delete(self, user_id: int) -> bool:
"""删除用户"""
user = self.get_by_id(user_id)
if user:
self.db.delete(user)
return True
return False
def exists_by_username(self, username: str) -> bool:
"""检查用户名是否存在"""
return self.db.query(
self.db.query(User).filter(User.username == username).exists()
).scalar()
def exists_by_email(self, email: str) -> bool:
"""检查邮箱是否存在"""
return self.db.query(
self.db.query(User).filter(User.email == email).exists()
).scalar()
这个实现有几点值得注意:
构造函数接收 Session :Repository 不自己创建数据库连接,而是从外部传入。这是依赖注入的体现,让 Repository 可以在不同的上下文中复用(比如测试时传入一个内存数据库的 Session)。
方法命名清晰 :get_by_id、get_by_username 这样的命名一眼就能看出方法的作用。Repository 的方法名应该反映业务语义,而不是技术细节。
分离了 flush 和 commit :Repository 只负责数据操作,不负责事务控制(commit)。事务的边界应该由更上层(Service 层或工作单元)来决定,这样可以把多个操作放在一个事务里。
3.3 在服务层使用仓储
有了 Repository,服务层的代码就变得简洁多了:
python
from typing import Optional, List
from fastapi import HTTPException, status
from repositories.user_repository import UserRepository
from models import User
from schemas import UserCreate, UserUpdate, UserResponse
class UserService:
"""
用户服务
服务层负责业务逻辑,它通过 Repository 来访问数据,
而不直接操作数据库。
"""
def __init__(self, user_repo: UserRepository):
"""通过依赖注入接收 Repository"""
self.user_repo = user_repo
def get_user(self, user_id: int) -> UserResponse:
"""获取用户信息"""
user = self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return UserResponse.model_validate(user)
def get_users(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> dict:
"""获取用户列表(带分页)"""
users = self.user_repo.get_all(skip=skip, limit=limit, is_active=is_active)
total = self.user_repo.count(is_active=is_active)
return {
"total": total,
"items": [UserResponse.model_validate(u) for u in users]
}
def create_user(self, user_data: UserCreate) -> UserResponse:
"""创建用户"""
# 业务规则:检查用户名是否已存在
if self.user_repo.exists_by_username(user_data.username):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已存在"
)
# 业务规则:检查邮箱是否已存在
if self.user_repo.exists_by_email(user_data.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="邮箱已被注册"
)
# 创建用户实体
user = User(
username=user_data.username,
email=user_data.email,
hashed_password=self._hash_password(user_data.password)
)
# 通过 Repository 保存
user = self.user_repo.add(user)
return UserResponse.model_validate(user)
def _hash_password(self, password: str) -> str:
"""密码哈希"""
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.hash(password)
对比一下之前的"面条代码",现在的服务层:
| 改进点 | 说明 |
|---|---|
| 职责清晰 | 只处理业务逻辑,不写 SQL 查询 |
| 代码简洁 | 数据访问都委托给 Repository |
| 易于测试 | 可以 mock Repository 来测试业务逻辑 |
| 可读性强 | 方法名反映业务意图,读代码像读文档 |
3.4 配置依赖注入
最后,用 FastAPI 的依赖注入把这些串起来:
python
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session
from database import get_db
from repositories.user_repository import UserRepository
from services.user_service import UserService
app = FastAPI()
def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
"""创建用户仓储实例"""
return UserRepository(db)
def get_user_service(
user_repo: UserRepository = Depends(get_user_repository)
) -> UserService:
"""创建用户服务实例"""
return UserService(user_repo)
@app.get("/users/{user_id}")
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""路由函数变得非常简洁"""
return service.get_user(user_id)
@app.get("/users")
def list_users(
skip: int = 0,
limit: int = 100,
is_active: bool = None,
service: UserService = Depends(get_user_service)
):
return service.get_users(skip=skip, limit=limit, is_active=is_active)
依赖链是这样的:Router → Service → Repository → Database Session。每一层只依赖下一层的抽象,不依赖具体实现,这就是依赖倒置原则的体现。
四、泛型仓储:减少重复代码
上面的 UserRepository 挺好的,但如果项目里有 User、Order、Product、Category 等十几个实体,每个都写一遍基础的 CRUD 方法,太重复了。
泛型仓储(Generic Repository)可以解决这个问题------把通用的 CRUD 逻辑提取到基类,具体的 Repository 只需要继承基类,添加特有的方法。
4.1 实现泛型基类
python
from typing import TypeVar, Generic, Optional, List, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import inspect
ModelType = TypeVar("ModelType") # ORM 模型类型
class BaseRepository(Generic[ModelType]):
"""
泛型仓储基类
这个类实现了通用的 CRUD 操作,具体的 Repository 可以继承它,
自动获得这些基础功能,只需要添加特有的查询方法。
"""
def __init__(self, model: Type[ModelType], db: Session):
"""
Args:
model: ORM 模型类,比如 User、Order
db: 数据库会话
"""
self.model = model
self.db = db
def get_by_id(self, id: Any) -> Optional[ModelType]:
"""根据主键获取实体"""
# 自动检测模型的主键字段
pk_columns = inspect(self.model).primary_key
if len(pk_columns) == 1:
pk_name = pk_columns[0].name
return self.db.query(self.model).filter(
getattr(self.model, pk_name) == id
).first()
else:
# 复合主键的情况
filters = [
getattr(self.model, col.name) == id[col.name]
for col in pk_columns
]
return self.db.query(self.model).filter(*filters).first()
def get_all(
self,
skip: int = 0,
limit: int = 100,
**filters
) -> List[ModelType]:
"""
获取实体列表,支持分页和简单过滤
filters 参数允许传入任意字段的过滤条件,比如:
get_all(is_active=True, category_id=1)
"""
query = self.db.query(self.model)
# 应用过滤条件
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.filter(getattr(self.model, field) == value)
return query.offset(skip).limit(limit).all()
def count(self, **filters) -> int:
"""统计实体数量"""
query = self.db.query(self.model)
for field, value in filters.items():
if value is not None and hasattr(self.model, field):
query = query.filter(getattr(self.model, field) == value)
return query.count()
def exists(self, id: Any) -> bool:
"""检查实体是否存在"""
return self.get_by_id(id) is not None
def add(self, entity: ModelType) -> ModelType:
"""添加实体"""
self.db.add(entity)
self.db.flush()
self.db.refresh(entity)
return entity
def add_many(self, entities: List[ModelType]) -> List[ModelType]:
"""批量添加实体"""
self.db.add_all(entities)
self.db.flush()
for entity in entities:
self.db.refresh(entity)
return entities
def update(self, entity: ModelType) -> ModelType:
"""更新实体"""
self.db.flush()
self.db.refresh(entity)
return entity
def delete(self, id: Any) -> bool:
"""删除实体"""
entity = self.get_by_id(id)
if entity:
self.db.delete(entity)
return True
return False
def delete_many(self, ids: List[Any]) -> int:
"""批量删除实体,返回实际删除的数量"""
pk_columns = inspect(self.model).primary_key
pk_name = pk_columns[0].name
count = self.db.query(self.model).filter(
getattr(self.model, pk_name).in_(ids)
).delete(synchronize_session=False)
return count
4.2 继承基类创建具体仓储
现在,创建具体的 Repository 变得非常简单:
python
from typing import Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import or_
from models import User, Order, Product
from repositories.base import BaseRepository
class UserRepository(BaseRepository[User]):
"""
用户仓储
继承自 BaseRepository[User],自动获得所有基础 CRUD 方法。
这里只需要添加 User 特有的查询方法。
"""
def __init__(self, db: Session):
super().__init__(User, db)
# ===== User 特有的方法 =====
def get_by_username(self, username: str) -> Optional[User]:
"""根据用户名查询"""
return self.db.query(User).filter(User.username == username).first()
def get_by_email(self, email: str) -> Optional[User]:
"""根据邮箱查询"""
return self.db.query(User).filter(User.email == email).first()
def search(
self,
keyword: str,
skip: int = 0,
limit: int = 100
) -> List[User]:
"""搜索用户(在用户名和邮箱中模糊匹配)"""
return self.db.query(User).filter(
or_(
User.username.ilike(f"%{keyword}%"),
User.email.ilike(f"%{keyword}%")
)
).offset(skip).limit(limit).all()
def exists_by_username(self, username: str) -> bool:
"""检查用户名是否存在"""
return self.db.query(
self.db.query(User).filter(User.username == username).exists()
).scalar()
def exists_by_email(self, email: str) -> bool:
"""检查邮箱是否存在"""
return self.db.query(
self.db.query(User).filter(User.email == email).exists()
).scalar()
class OrderRepository(BaseRepository[Order]):
"""订单仓储"""
def __init__(self, db: Session):
super().__init__(Order, db)
def get_by_user(
self,
user_id: int,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[Order]:
"""获取用户的订单"""
query = self.db.query(Order).filter(Order.user_id == user_id)
if status:
query = query.filter(Order.status == status)
return query.order_by(Order.created_at.desc()).offset(skip).limit(limit).all()
def get_recent_orders(self, days: int = 7) -> List[Order]:
"""获取最近几天的订单"""
from datetime import datetime, timedelta
since = datetime.now() - timedelta(days=days)
return self.db.query(Order).filter(Order.created_at >= since).all()
def calculate_user_total(self, user_id: int) -> float:
"""计算用户的累计消费金额"""
from sqlalchemy import func
result = self.db.query(func.sum(Order.total_price)).filter(
Order.user_id == user_id,
Order.status == "completed"
).scalar()
return result or 0.0
class ProductRepository(BaseRepository[Product]):
"""商品仓储"""
def __init__(self, db: Session):
super().__init__(Product, db)
def get_by_category(
self,
category_id: int,
in_stock: bool = True
) -> List[Product]:
"""获取某分类下的商品"""
query = self.db.query(Product).filter(Product.category_id == category_id)
if in_stock:
query = query.filter(Product.stock > 0)
return query.all()
def search_by_name(self, keyword: str) -> List[Product]:
"""搜索商品"""
return self.db.query(Product).filter(
Product.name.ilike(f"%{keyword}%")
).all()
def decrease_stock(self, product_id: int, quantity: int) -> bool:
"""
扣减库存
使用乐观锁防止超卖:只有 stock >= quantity 时才会更新成功
"""
result = self.db.query(Product).filter(
Product.id == product_id,
Product.stock >= quantity
).update(
{Product.stock: Product.stock - quantity},
synchronize_session=False
)
return result > 0
看到了吗?每个具体的 Repository 只需要几行代码就定义好了,因为基础方法都继承自 BaseRepository。特有的业务查询方法单独定义,代码量大大减少,而且结构清晰。
4.3 泛型仓储的优缺点
| 优点 | 缺点 |
|---|---|
| 减少重复代码 | 复杂查询仍需单独实现 |
| 统一的 CRUD 接口 | 增加一层抽象 |
新增实体 Repository 很快 |
泛型可能让新手困惑 |
| 便于统一添加日志、审计等功能 | 过度抽象可能适得其反 |
五、工作单元模式配合仓储
前面的 Repository 实现里,我们只用了 flush() 而没有 commit()。这是故意的------事务控制应该由更上层来负责。工作单元模式(Unit of Work)就是专门管理事务的。
5.1 什么是工作单元
工作单元的核心思想是:把一组相关的数据库操作放在一起,要么全部成功,要么全部失败。
比如创建订单这个业务,需要:
- 创建订单记录
- 更新商品库存
- 创建支付记录
这三个操作必须是原子的------要么都成功,要么都回滚。如果订单创建成功但库存更新失败,数据就不一致了。
工作单元负责:
- 追踪一个业务操作涉及的所有数据变更
- 在合适的时机统一提交(
commit)或回滚(rollback) - 确保事务的完整性
5.2 实现工作单元
python
from typing import Optional
from sqlalchemy.orm import Session
from repositories.user_repository import UserRepository
from repositories.order_repository import OrderRepository
from repositories.product_repository import ProductRepository
class UnitOfWork:
"""
工作单元
管理多个 Repository 的操作,确保这些操作在同一个事务中执行。
典型用法:
with UnitOfWork(db) as uow:
user = uow.users.get_by_id(1)
order = Order(user_id=user.id, ...)
uow.orders.add(order)
uow.commit()
"""
def __init__(self, session: Session):
self._session = session
# 懒加载 Repository
self._users: Optional[UserRepository] = None
self._orders: Optional[OrderRepository] = None
self._products: Optional[ProductRepository] = None
@property
def users(self) -> UserRepository:
"""用户仓储(懒加载)"""
if self._users is None:
self._users = UserRepository(self._session)
return self._users
@property
def orders(self) -> OrderRepository:
"""订单仓储"""
if self._orders is None:
self._orders = OrderRepository(self._session)
return self._orders
@property
def products(self) -> ProductRepository:
"""商品仓储"""
if self._products is None:
self._products = ProductRepository(self._session)
return self._products
def __enter__(self):
"""进入上下文管理器"""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""退出上下文管理器,异常时自动回滚"""
if exc_type is not None:
self.rollback()
def commit(self):
"""提交事务"""
try:
self._session.commit()
except Exception:
self.rollback()
raise
def rollback(self):
"""回滚事务"""
self._session.rollback()
def flush(self):
"""刷新会话(不提交)"""
self._session.flush()
5.3 在服务层使用工作单元
python
from fastapi import HTTPException, status
from unit_of_work import UnitOfWork
from models import Order
from schemas import OrderCreate
class OrderService:
"""订单服务"""
def __init__(self, uow: UnitOfWork):
self.uow = uow
def create_order(self, order_data: OrderCreate) -> Order:
"""
创建订单
涉及多个表(订单、库存),必须在一个事务中完成
"""
# 1. 检查用户是否存在
user = self.uow.users.get_by_id(order_data.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
# 2. 检查商品并计算总价
total_price = 0
order_items = []
for item in order_data.items:
product = self.uow.products.get_by_id(item.product_id)
if not product:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"商品 {item.product_id} 不存在"
)
if product.stock < item.quantity:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"商品 {product.name} 库存不足"
)
subtotal = product.price * item.quantity
total_price += subtotal
order_items.append({
'product': product,
'quantity': item.quantity,
'price': product.price
})
# 3. 创建订单
order = Order(
user_id=user.id,
total_price=total_price,
status="pending"
)
self.uow.orders.add(order)
self.uow.flush() # 获取订单 ID
# 4. 扣减库存
for item_data in order_items:
product = item_data['product']
quantity = item_data['quantity']
success = self.uow.products.decrease_stock(product.id, quantity)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"商品 {product.name} 库存不足"
)
# 5. 提交事务
self.uow.commit()
return order
看到没?UnitOfWork 把多个 Repository 组织在一起,服务层通过 uow.users、uow.orders、uow.products 访问不同的仓储,最后统一 commit()。任何一步失败,整个事务都会回滚。
5.4 配置依赖注入
python
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session
from database import get_db
from unit_of_work import UnitOfWork
from services.order_service import OrderService
app = FastAPI()
def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
"""每个请求一个工作单元"""
return UnitOfWork(db)
def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
"""创建订单服务"""
return OrderService(uow)
@app.post("/orders")
def create_order(
order_data: OrderCreate,
service: OrderService = Depends(get_order_service)
):
return service.create_order(order_data)
六、FastAPI完整实战
把前面讲的内容整合起来,看看一个完整的项目结构:
6.1 项目结构
myapp/
├── main.py # 应用入口
├── config.py # 配置
├── database.py # 数据库连接
├── unit_of_work.py # 工作单元
├── dependencies.py # 依赖注入配置
│
├── models/ # ORM 模型
│ ├── __init__.py
│ ├── user.py
│ ├── order.py
│ └── product.py
│
├── schemas/ # Pydantic 模型
│ ├── __init__.py
│ ├── user.py
│ ├── order.py
│ └── product.py
│
├── repositories/ # 仓储层
│ ├── __init__.py
│ ├── base.py # 泛型基类
│ ├── user_repository.py
│ ├── order_repository.py
│ └── product_repository.py
│
├── services/ # 服务层
│ ├── __init__.py
│ ├── user_service.py
│ ├── order_service.py
│ └── product_service.py
│
└── routers/ # 路由层
├── __init__.py
├── user_router.py
├── order_router.py
└── product_router.py
6.2 依赖注入配置
python
# dependencies.py
from fastapi import Depends
from sqlalchemy.orm import Session
from database import get_db
from unit_of_work import UnitOfWork
from services.user_service import UserService
from services.order_service import OrderService
from services.product_service import ProductService
def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
"""获取工作单元"""
return UnitOfWork(db)
def get_user_service(uow: UnitOfWork = Depends(get_uow)) -> UserService:
"""获取用户服务"""
return UserService(uow)
def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
"""获取订单服务"""
return OrderService(uow)
def get_product_service(uow: UnitOfWork = Depends(get_uow)) -> ProductService:
"""获取商品服务"""
return ProductService(uow)
6.3 路由示例
python
# routers/user_router.py
from typing import Optional
from fastapi import APIRouter, Depends, status
from dependencies import get_user_service
from services.user_service import UserService
from schemas.user import UserCreate, UserUpdate, UserResponse, UserListResponse
router = APIRouter(prefix="/users", tags=["用户管理"])
@router.get("/", response_model=UserListResponse)
def list_users(
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
service: UserService = Depends(get_user_service)
):
"""获取用户列表"""
return service.get_users(skip=skip, limit=limit, is_active=is_active)
@router.get("/{user_id}", response_model=UserResponse)
def get_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""获取用户详情"""
return service.get_user(user_id)
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(
user_data: UserCreate,
service: UserService = Depends(get_user_service)
):
"""创建用户"""
return service.create_user(user_data)
@router.put("/{user_id}", response_model=UserResponse)
def update_user(
user_id: int,
user_data: UserUpdate,
service: UserService = Depends(get_user_service)
):
"""更新用户"""
return service.update_user(user_id, user_data)
@router.delete("/{user_id}")
def delete_user(
user_id: int,
service: UserService = Depends(get_user_service)
):
"""删除用户"""
return service.delete_user(user_id)
6.4 主应用
python
# main.py
from fastapi import FastAPI
from routers import user_router, order_router, product_router
app = FastAPI(
title="电商API",
description="使用仓储模式的 FastAPI 示例",
version="1.0.0"
)
app.include_router(user_router.router)
app.include_router(order_router.router)
app.include_router(product_router.router)
@app.get("/")
def root():
return {"message": "Welcome to the API"}
@app.get("/health")
def health():
return {"status": "healthy"}
七、仓储模式 vs 直接用ORM
看到这里,你可能会想:这搞得这么复杂,直接用 SQLAlchemy 不是更简单?
这是个好问题。仓储模式不是银弹,它有成本也有收益,需要根据项目情况权衡。
7.1 直接使用 ORM 的场景
适合的情况:
- 项目较小,CRUD 为主,业务逻辑简单
- 团队对 ORM 熟悉,开发速度优先
- 不太可能更换数据库或 ORM
- 不需要严格的测试覆盖
python
# 直接用 ORM,简单直接
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=404)
return user
这样写没什么问题。如果你的项目就是几个简单接口,引入仓储模式反而是过度设计。
7.2 使用仓储模式的场景
适合的情况:
- 项目较大,业务逻辑复杂
- 需要良好的测试覆盖
- 可能更换底层存储
- 团队协作,需要清晰的代码边界
- 追求代码的可维护性和可扩展性
7.3 对比总结
| 方面 | 直接用ORM | 仓储模式 |
|---|---|---|
| 代码量 | 少 | 多 |
| 学习成本 | 低 | 中等 |
| 开发速度(初期) | 快 | 慢一点 |
| 维护成本(后期) | 高 | 低 |
| 可测试性 | 差 | 好 |
| 灵活性 | 差 | 好 |
| 适合项目规模 | 小型 | 中大型 |
7.4 我的建议
| 项目类型 | 建议 |
|---|---|
| 小项目(几个简单接口) | 直接用 ORM |
| 中型项目 | 至少把数据访问逻辑提取到单独模块 |
| 大型项目 | 完整分层架构 + 工作单元 |
| 需要高测试覆盖 | 仓储模式 |
八、测试中的巨大优势
仓储模式最大的优势之一就是方便测试。
8.1 不使用仓储模式的测试困境
python
# 业务代码直接操作数据库
class UserService:
def __init__(self, db: Session):
self.db = db
def create_user(self, username: str, email: str, password: str):
if self.db.query(User).filter(User.username == username).first():
raise ValueError("用户名已存在")
user = User(username=username, email=email, password=password)
self.db.add(user)
self.db.commit()
return user
# 测试代码
def test_create_user():
# 必须准备真实数据库
db = TestSessionLocal()
service = UserService(db)
# 清理可能存在的测试数据
db.query(User).filter(User.username == "testuser").delete()
db.commit()
# 执行测试
user = service.create_user("testuser", "test@example.com", "password")
assert user.username == "testuser"
# 清理
db.delete(user)
db.commit()
db.close()
问题很多:需要真实数据库、测试数据要准备和清理、测试速度慢、测试之间可能互相影响。
8.2 使用仓储模式的测试
创建一个内存实现的"假仓储":
python
from typing import Optional, List, Dict
from models import User
class FakeUserRepository:
"""假的用户仓储,用于测试"""
def __init__(self):
self._users: Dict[int, User] = {}
self._next_id = 1
def get_by_id(self, user_id: int) -> Optional[User]:
return self._users.get(user_id)
def get_by_username(self, username: str) -> Optional[User]:
for user in self._users.values():
if user.username == username:
return user
return None
def add(self, user: User) -> User:
user.id = self._next_id
self._next_id += 1
self._users[user.id] = user
return user
def exists_by_username(self, username: str) -> bool:
return self.get_by_username(username) is not None
def clear(self):
"""清空所有数据"""
self._users.clear()
self._next_id = 1
测试变得简单快速:
python
import pytest
from services.user_service import UserService
class FakeUnitOfWork:
"""假的工作单元"""
def __init__(self):
self.users = FakeUserRepository()
self.committed = False
def commit(self):
self.committed = True
def rollback(self):
pass
class TestUserService:
@pytest.fixture
def uow(self):
return FakeUnitOfWork()
@pytest.fixture
def service(self, uow):
return UserService(uow)
def test_create_user_success(self, service, uow):
"""测试成功创建用户"""
user_data = UserCreate(
username="testuser",
email="test@example.com",
password="password123"
)
result = service.create_user(user_data)
assert result.username == "testuser"
assert uow.committed
def test_create_user_duplicate_username(self, service, uow):
"""测试用户名重复"""
# 先添加一个用户
existing_user = User(
username="existinguser",
email="existing@example.com",
hashed_password="xxx"
)
uow.users.add(existing_user)
# 尝试创建同名用户
user_data = UserCreate(
username="existinguser",
email="new@example.com",
password="password123"
)
with pytest.raises(HTTPException) as exc_info:
service.create_user(user_data)
assert exc_info.value.status_code == 400
assert "用户名已存在" in exc_info.value.detail
8.3 FastAPI 集成测试
python
import pytest
from fastapi.testclient import TestClient
from main import app
from dependencies import get_uow
class TestUserAPI:
@pytest.fixture
def client(self):
fake_uow = FakeUnitOfWork()
# 预设测试数据
fake_uow.users.add(User(
username="existinguser",
email="existing@example.com",
hashed_password="xxx"
))
# 覆盖依赖
app.dependency_overrides[get_uow] = lambda: fake_uow
client = TestClient(app)
yield client
app.dependency_overrides.clear()
def test_create_user(self, client):
response = client.post("/users/", json={
"username": "newuser",
"email": "new@example.com",
"password": "password123"
})
assert response.status_code == 201
assert response.json()["username"] == "newuser"
8.4 测试的好处总结
| 好处 | 说明 |
|---|---|
| 快速 | 内存运行,毫秒级完成 |
| 隔离 | 每个测试有自己的假仓储 |
| 稳定 | 不受网络、数据库状态影响 |
| 简单 | 准备测试数据很方便 |
| 专注 | 单独测试业务逻辑 |
九、常见问题与最佳实践
9.1 Repository 应该返回什么?
返回 ORM 模型还是 DTO?
建议:Repository 返回 ORM 模型,转换成 DTO 是 Service 层的事。
python
# Repository 返回 ORM 对象
class UserRepository:
def get_by_id(self, user_id: int) -> Optional[User]:
return self.db.query(User).filter(User.id == user_id).first()
# Service 层做转换
class UserService:
def get_user(self, user_id: int) -> UserResponse:
user = self.repo.get_by_id(user_id) # User(ORM)
return UserResponse.model_validate(user) # UserResponse(DTO)
9.2 复杂查询放在哪里?
原则:和数据获取相关的放 Repository,和业务规则相关的放 Service。
python
# Repository:复杂的数据查询
class OrderRepository:
def get_monthly_stats(self, year: int, month: int) -> dict:
"""获取月度订单统计(数据查询)"""
from sqlalchemy import func
result = self.db.query(
func.count(Order.id).label('count'),
func.sum(Order.total_price).label('total')
).filter(
func.extract('year', Order.created_at) == year,
func.extract('month', Order.created_at) == month,
Order.status == 'completed'
).first()
return {
'count': result.count or 0,
'total': float(result.total or 0)
}
# Service:基于数据的业务计算
class ReportService:
def generate_monthly_report(self, year: int, month: int) -> dict:
"""生成月度报告(业务逻辑)"""
stats = self.uow.orders.get_monthly_stats(year, month)
# 业务计算
avg_order_value = (
stats['total'] / stats['count']
if stats['count'] > 0 else 0
)
return {
'order_count': stats['count'],
'total_revenue': stats['total'],
'avg_order_value': avg_order_value
}
9.3 如何处理 N+1 问题?
使用 joinedload 或 selectinload 预加载关联数据:
python
from sqlalchemy.orm import joinedload, selectinload
class OrderRepository:
def get_with_items(self, order_id: int) -> Optional[Order]:
"""获取订单及其订单项(预加载)"""
return self.db.query(Order).options(
selectinload(Order.items).selectinload(OrderItem.product)
).filter(Order.id == order_id).first()
9.4 仓储应该有多细粒度?
一个实体一个仓储是常见做法。但紧密相关的实体可以共用:
python
# 聚合根模式:订单项通过订单仓储访问
class OrderRepository:
def add_item(self, order_id: int, item: OrderItem): ...
def remove_item(self, order_id: int, item_id: int): ...
def get_items(self, order_id: int) -> List[OrderItem]: ...
9.5 性能优化建议
批量操作:避免循环中一个个操作数据库
python
# 不好:N次数据库操作
for user_id in user_ids:
user = repo.get_by_id(user_id)
# 好:1次数据库操作
users = repo.get_by_ids(user_ids)
分页:始终使用分页
python
# 好
users = repo.get_all(skip=0, limit=100)
# 不好
users = repo.get_all() # 可能有几万条
十、总结
这篇文章把仓储模式的各个方面都讲到了,总结一下关键点:
仓储模式是什么
它是数据访问层的抽象,把"怎么存数据"的细节藏起来,对外提供类似"集合"的接口。Repository 就像图书馆管理员,业务代码只需要说"我要这本书",不用管书放在哪个架子上。
为什么要用仓储模式
主要是为了解耦、可测试、可维护。数据访问逻辑集中管理,业务层不依赖具体的数据库实现,测试时可以用假实现替换真数据库。
怎么实现
从最基础的 Repository 类开始,然后引入泛型基类减少重复代码,再配合工作单元模式管理事务。
什么时候用
小项目直接用 ORM 就够了,中大型项目、需要高测试覆盖的项目建议使用。
关键要点总结
| 概念 | 说明 | 适用场景 |
|---|---|---|
Repository |
数据访问抽象 | 封装查询逻辑 |
| 泛型仓储 | 减少重复的 CRUD 代码 | 多实体项目 |
| 工作单元 | 事务管理 | 多表操作 |
| 假仓储 | 测试用的内存实现 | 单元测试 |
仓储模式 + 工厂模式 + 依赖注入 + 单例,这几个模式组合起来,形成了完整的分层架构。
下一篇打算讲观察者模式,这个在事件驱动的场景(比如发消息通知、触发异步任务)非常有用。在 FastAPI 里配合背景任务或消息队列使用,能让系统架构更加松耦合。
热门专栏推荐
- Agent小册
- 服务器部署
- Java基础合集
- Python基础合集
- Go基础合集
- 大数据合集
- 前端小册
- 数据库合集
- Redis 合集
- Spring 全家桶
- 微服务全家桶
- 数据结构与算法合集
- 设计模式小册
- 消息队列合集
等等等还有许多优秀的合集在主页等着大家的光顾,感谢大家的支持
文章到这里就结束了,如果有什么疑问的地方请指出,诸佬们一起来评论区一起讨论😊
希望能和诸佬们一起努力,今后我们一起观看感谢您的阅读🙏
如果帮助到您不妨3连支持一下,创造不易您们的支持是我的动力🌟