【设计模式】Python仓储模式:从入门到实战

Python仓储模式:让数据访问不再混乱

前言

这篇是设计模式系列的学习笔记,这次来聊聊仓储模式(Repository Pattern)。

上一篇讲了工厂模式和依赖注入,最后提到了分层架构里的 Repository 层。这篇就专门把数据访问这块的设计讲透。

说实话,很多人写 FastAPI 项目的时候,习惯在路由函数里直接写 SQLAlchemy 查询,一个接口几十行,查询、过滤、分页全揉在一起。项目小的时候还好,一旦业务复杂起来,代码就变成一锅粥了------同样的查询逻辑写了好几遍,改个字段名得改十几个地方,测试更是没法写。

仓储模式就是解决这个问题的。它把数据访问逻辑封装起来,让业务层不用关心数据是怎么存的、怎么取的。听起来是不是有点像工厂模式?没错,设计模式之间经常是互相配合的,仓储模式和工厂模式、依赖注入一起用,能让代码架构非常清晰。

这篇文章会从"为什么需要"讲起,然后一步步实现一个完整的仓储层,包括泛型仓储、工作单元模式这些进阶内容。内容比较多,但都是实战中会用到的,耐心看完肯定有收获。

🏠个人主页:山沐与山


文章目录


一、仓储模式是什么

1.1 从一个比喻说起

想象一下图书馆。你想借一本书,需要知道这本书具体放在哪个架子的第几层吗?不需要,你只需要告诉图书管理员书名,管理员会帮你找到并拿给你。还书的时候也一样,你不用操心书该放回哪里,交给管理员就行。

仓储模式里的 Repository 就是这个"图书管理员"。你的业务代码只需要说"我要 ID 为 1 的用户"或者"帮我保存这个订单",至于数据是存在 MySQL 还是 PostgreSQL,用的是什么 ORM,业务代码完全不需要知道。

1.2 正式定义

仓储模式(Repository Pattern)是一种将数据访问逻辑与业务逻辑分离的设计模式。它提供一个类似"集合"的接口来访问领域对象,让业务层可以像操作内存中的集合一样操作数据,而不需要关心底层的数据存储细节。

用更直白的话说:Repository 是数据层的抽象,它把"怎么存数据"这件事藏起来,只暴露"存什么、取什么"的接口

1.3 仓储模式的核心思想

仓储模式有几个核心理念:

第一,隔离数据访问细节 。业务层不直接和数据库打交道,而是通过 Repository 这个中间层。这样数据库换了(比如从 MySQL 换到 PostgreSQL),或者 ORM 换了(比如从 SQLAlchemy 换到 Tortoise),业务层的代码不用改。

第二,统一数据访问接口 。不管底层是关系型数据库、NoSQL、文件系统还是远程 API,Repository 对外提供的接口是一致的。这让业务代码变得简洁,也让测试变得容易(可以用内存实现替换真实数据库)。

第三,集中管理查询逻辑 。所有和某个实体相关的查询都放在对应的 Repository 里。想找"获取活跃用户"的逻辑?去 UserRepository 里找就行,不用在几十个文件里翻。

1.4 仓储模式在分层架构中的位置

在典型的分层架构中,Repository 处于数据访问层,向上对接服务层(Service),向下对接具体的数据存储:

复制代码
┌─────────────────────────────────────────────────────┐
│                   Presentation Layer                 │
│              (Controllers / Routers)                 │
│         处理HTTP请求,调用Service,返回响应            │
└───────────────────────┬─────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────┐
│                   Service Layer                      │
│                  (Business Logic)                    │
│           业务逻辑,调用Repository获取数据            │
└───────────────────────┬─────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────┐
│                 Repository Layer                     │
│                (Data Access Logic)                   │
│        封装数据访问,提供类集合的操作接口              │
└───────────────────────┬─────────────────────────────┘
                        │
                        ▼
┌─────────────────────────────────────────────────────┐
│               Data Storage Layer                     │
│        (Database / Cache / External API)             │
│              实际的数据存储                           │
└─────────────────────────────────────────────────────┘

每一层只和相邻的层打交道,这样修改任何一层都不会影响其他层(只要接口不变)。


二、不用仓储模式会怎样

在讲怎么实现之前,先看看不用仓储模式会遇到什么问题。这样你才能理解为什么要引入这个模式。

2.1 典型的"面条代码"

很多 FastAPI 项目一开始是这样写的,所有逻辑都塞在路由函数里:

python 复制代码
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from database import get_db
from models import User, Order, Product

app = FastAPI()


@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
    # 直接在路由里写查询逻辑
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    return user


@app.get("/users")
def list_users(
    skip: int = 0,
    limit: int = 10,
    is_active: bool = None,
    keyword: str = None,
    db: Session = Depends(get_db)
):
    # 查询逻辑越来越复杂
    query = db.query(User)

    if is_active is not None:
        query = query.filter(User.is_active == is_active)

    if keyword:
        query = query.filter(
            or_(
                User.username.ilike(f"%{keyword}%"),
                User.email.ilike(f"%{keyword}%")
            )
        )

    total = query.count()
    users = query.offset(skip).limit(limit).all()

    return {"total": total, "items": users}


@app.get("/users/{user_id}/orders")
def get_user_orders(
    user_id: int,
    status: str = None,
    start_date: str = None,
    end_date: str = None,
    db: Session = Depends(get_db)
):
    # 又是一大段查询逻辑
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    query = db.query(Order).filter(Order.user_id == user_id)

    if status:
        query = query.filter(Order.status == status)

    if start_date:
        query = query.filter(Order.created_at >= start_date)

    if end_date:
        query = query.filter(Order.created_at <= end_date)

    return query.all()


@app.post("/orders")
def create_order(
    user_id: int,
    product_id: int,
    quantity: int,
    db: Session = Depends(get_db)
):
    # 更多的数据库操作...
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    product = db.query(Product).filter(Product.id == product_id).first()
    if not product:
        raise HTTPException(status_code=404, detail="商品不存在")

    if product.stock < quantity:
        raise HTTPException(status_code=400, detail="库存不足")

    # 创建订单
    order = Order(
        user_id=user_id,
        product_id=product_id,
        quantity=quantity,
        total_price=product.price * quantity
    )
    db.add(order)

    # 扣减库存
    product.stock -= quantity

    db.commit()
    db.refresh(order)

    return order

看起来能跑,但问题其实很多。

2.2 问题分析

问题一:代码重复严重

"根据 ID 查用户"这个逻辑在 get_userget_user_orderscreate_order 里都写了一遍。如果以后查询条件变了(比如要加上软删除过滤),得改好几个地方。累不累?

python 复制代码
# 这段代码重复了三次
user = db.query(User).filter(User.id == user_id).first()
if not user:
    raise HTTPException(status_code=404, detail="用户不存在")

问题二:业务逻辑和数据访问逻辑混在一起

路由函数本来应该只负责"接收请求、返回响应",现在还要处理复杂的数据库查询。一个函数做了太多事情,违反单一职责原则。

更糟糕的是,业务规则(比如"库存不足不能下单")和数据操作(查询、更新)交织在一起,代码的意图变得不清晰,后来维护的人很难理解这段代码到底在干什么。

问题三:难以测试

想测试 create_order 的业务逻辑,必须准备一个真实的数据库,还得提前插入用户和商品数据。测试又慢又脆弱,稍微改点东西测试就挂了。

如果能把数据访问抽象出来,测试时用个假的实现(返回固定数据),就能快速验证业务逻辑是否正确。

问题四:难以复用

假设有另一个地方也需要"获取用户的订单列表",怎么办?把那一大段代码复制过去?还是提取成函数?提取成函数放在哪里?

没有统一的规范,代码组织会越来越乱,每个人都有自己的写法,最后就是一团糟。

问题五:切换数据源困难

现在用的是 SQLAlchemyPostgreSQL,哪天要加个 Redis 缓存,或者某些数据要从外部 API 获取,改动会非常大。因为数据库操作散落在各处,没有统一的抽象层。

2.3 我们需要什么

总结一下,我们需要:

需求 说明
统一的地方 放所有和某个实体相关的数据访问逻辑
抽象的接口 让业务层不依赖具体的数据库实现
易于测试 可以用假实现替换真实数据库
可复用 查询逻辑写一次到处用

这就是仓储模式要解决的问题。


三、仓储模式的基本实现

理解了问题,现在来看解决方案。

3.1 定义仓储接口

首先,定义一个抽象的接口,规定 Repository 应该提供哪些操作。Python 里可以用 ProtocolABC 来定义接口。

Protocol 是 Python 3.8+ 引入的,它实现的是"结构化子类型"(structural subtyping),也叫鸭子类型的静态版本。只要一个类实现了 Protocol 定义的方法,就被认为是该 Protocol 的实现,不需要显式继承。这比 ABC 更灵活。

python 复制代码
from typing import Protocol, TypeVar, Generic, Optional, List

# 定义泛型类型变量
T = TypeVar('T')   # 实体类型
ID = TypeVar('ID') # ID类型


class IRepository(Protocol[T, ID]):
    """
    仓储接口

    这是一个泛型协议,T 是实体类型,ID 是主键类型。
    所有具体的 Repository 都应该实现这个接口定义的方法。
    """

    def get_by_id(self, id: ID) -> Optional[T]:
        """根据ID获取单个实体,不存在返回 None"""
        ...

    def get_all(self) -> List[T]:
        """获取所有实体"""
        ...

    def add(self, entity: T) -> T:
        """添加一个实体"""
        ...

    def update(self, entity: T) -> T:
        """更新一个实体"""
        ...

    def delete(self, id: ID) -> bool:
        """删除一个实体,返回是否成功"""
        ...

这个接口定义了最基本的 CRUD 操作(Create、Read、Update、Delete)。注意,这里没有任何和具体数据库相关的代码,它只是一个"契约",规定实现类必须提供这些方法。

3.2 实现具体的仓储

有了接口,接下来实现具体的 Repository。这里以用户仓储为例,使用 SQLAlchemy 作为 ORM:

python 复制代码
from typing import Optional, List
from sqlalchemy.orm import Session
from models import User


class UserRepository:
    """
    用户仓储的 SQLAlchemy 实现

    这个类负责所有和 User 实体相关的数据库操作。
    业务层通过这个类来访问用户数据,不需要知道底层用的是什么数据库。
    """

    def __init__(self, db: Session):
        """
        构造函数

        Args:
            db: SQLAlchemy 的数据库会话,通过依赖注入传入
        """
        self.db = db

    def get_by_id(self, user_id: int) -> Optional[User]:
        """
        根据ID获取用户

        返回 Optional 类型意味着可能返回 None,
        调用方需要处理用户不存在的情况。
        """
        return self.db.query(User).filter(User.id == user_id).first()

    def get_by_username(self, username: str) -> Optional[User]:
        """根据用户名获取用户"""
        return self.db.query(User).filter(User.username == username).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """根据邮箱获取用户"""
        return self.db.query(User).filter(User.email == email).first()

    def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None
    ) -> List[User]:
        """
        获取用户列表

        支持分页和过滤。在实际项目中,这个方法的参数可能会更多,
        比如排序字段、排序方向、多种过滤条件等。
        """
        query = self.db.query(User)

        # 条件过滤
        if is_active is not None:
            query = query.filter(User.is_active == is_active)

        # 分页
        return query.offset(skip).limit(limit).all()

    def count(self, is_active: Optional[bool] = None) -> int:
        """统计用户数量"""
        query = self.db.query(User)

        if is_active is not None:
            query = query.filter(User.is_active == is_active)

        return query.count()

    def add(self, user: User) -> User:
        """
        添加用户

        注意这里只是把对象加到 session 里,
        真正写入数据库是在 commit 时。
        """
        self.db.add(user)
        self.db.flush()   # flush 会生成 ID,但不提交事务
        self.db.refresh(user)  # 刷新对象,获取数据库生成的值
        return user

    def update(self, user: User) -> User:
        """更新用户"""
        self.db.flush()
        self.db.refresh(user)
        return user

    def delete(self, user_id: int) -> bool:
        """删除用户"""
        user = self.get_by_id(user_id)
        if user:
            self.db.delete(user)
            return True
        return False

    def exists_by_username(self, username: str) -> bool:
        """检查用户名是否存在"""
        return self.db.query(
            self.db.query(User).filter(User.username == username).exists()
        ).scalar()

    def exists_by_email(self, email: str) -> bool:
        """检查邮箱是否存在"""
        return self.db.query(
            self.db.query(User).filter(User.email == email).exists()
        ).scalar()

这个实现有几点值得注意:

构造函数接收 SessionRepository 不自己创建数据库连接,而是从外部传入。这是依赖注入的体现,让 Repository 可以在不同的上下文中复用(比如测试时传入一个内存数据库的 Session)。

方法命名清晰get_by_idget_by_username 这样的命名一眼就能看出方法的作用。Repository 的方法名应该反映业务语义,而不是技术细节。

分离了 flushcommitRepository 只负责数据操作,不负责事务控制(commit)。事务的边界应该由更上层(Service 层或工作单元)来决定,这样可以把多个操作放在一个事务里。

3.3 在服务层使用仓储

有了 Repository,服务层的代码就变得简洁多了:

python 复制代码
from typing import Optional, List
from fastapi import HTTPException, status

from repositories.user_repository import UserRepository
from models import User
from schemas import UserCreate, UserUpdate, UserResponse


class UserService:
    """
    用户服务

    服务层负责业务逻辑,它通过 Repository 来访问数据,
    而不直接操作数据库。
    """

    def __init__(self, user_repo: UserRepository):
        """通过依赖注入接收 Repository"""
        self.user_repo = user_repo

    def get_user(self, user_id: int) -> UserResponse:
        """获取用户信息"""
        user = self.user_repo.get_by_id(user_id)

        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        return UserResponse.model_validate(user)

    def get_users(
        self,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None
    ) -> dict:
        """获取用户列表(带分页)"""
        users = self.user_repo.get_all(skip=skip, limit=limit, is_active=is_active)
        total = self.user_repo.count(is_active=is_active)

        return {
            "total": total,
            "items": [UserResponse.model_validate(u) for u in users]
        }

    def create_user(self, user_data: UserCreate) -> UserResponse:
        """创建用户"""
        # 业务规则:检查用户名是否已存在
        if self.user_repo.exists_by_username(user_data.username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )

        # 业务规则:检查邮箱是否已存在
        if self.user_repo.exists_by_email(user_data.email):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已被注册"
            )

        # 创建用户实体
        user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=self._hash_password(user_data.password)
        )

        # 通过 Repository 保存
        user = self.user_repo.add(user)

        return UserResponse.model_validate(user)

    def _hash_password(self, password: str) -> str:
        """密码哈希"""
        from passlib.context import CryptContext
        pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
        return pwd_context.hash(password)

对比一下之前的"面条代码",现在的服务层:

改进点 说明
职责清晰 只处理业务逻辑,不写 SQL 查询
代码简洁 数据访问都委托给 Repository
易于测试 可以 mock Repository 来测试业务逻辑
可读性强 方法名反映业务意图,读代码像读文档

3.4 配置依赖注入

最后,用 FastAPI 的依赖注入把这些串起来:

python 复制代码
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session

from database import get_db
from repositories.user_repository import UserRepository
from services.user_service import UserService

app = FastAPI()


def get_user_repository(db: Session = Depends(get_db)) -> UserRepository:
    """创建用户仓储实例"""
    return UserRepository(db)


def get_user_service(
    user_repo: UserRepository = Depends(get_user_repository)
) -> UserService:
    """创建用户服务实例"""
    return UserService(user_repo)


@app.get("/users/{user_id}")
def get_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    """路由函数变得非常简洁"""
    return service.get_user(user_id)


@app.get("/users")
def list_users(
    skip: int = 0,
    limit: int = 100,
    is_active: bool = None,
    service: UserService = Depends(get_user_service)
):
    return service.get_users(skip=skip, limit=limit, is_active=is_active)

依赖链是这样的:Router → Service → Repository → Database Session。每一层只依赖下一层的抽象,不依赖具体实现,这就是依赖倒置原则的体现。


四、泛型仓储:减少重复代码

上面的 UserRepository 挺好的,但如果项目里有 UserOrderProductCategory 等十几个实体,每个都写一遍基础的 CRUD 方法,太重复了。

泛型仓储(Generic Repository)可以解决这个问题------把通用的 CRUD 逻辑提取到基类,具体的 Repository 只需要继承基类,添加特有的方法。

4.1 实现泛型基类

python 复制代码
from typing import TypeVar, Generic, Optional, List, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import inspect

ModelType = TypeVar("ModelType")  # ORM 模型类型


class BaseRepository(Generic[ModelType]):
    """
    泛型仓储基类

    这个类实现了通用的 CRUD 操作,具体的 Repository 可以继承它,
    自动获得这些基础功能,只需要添加特有的查询方法。
    """

    def __init__(self, model: Type[ModelType], db: Session):
        """
        Args:
            model: ORM 模型类,比如 User、Order
            db: 数据库会话
        """
        self.model = model
        self.db = db

    def get_by_id(self, id: Any) -> Optional[ModelType]:
        """根据主键获取实体"""
        # 自动检测模型的主键字段
        pk_columns = inspect(self.model).primary_key
        if len(pk_columns) == 1:
            pk_name = pk_columns[0].name
            return self.db.query(self.model).filter(
                getattr(self.model, pk_name) == id
            ).first()
        else:
            # 复合主键的情况
            filters = [
                getattr(self.model, col.name) == id[col.name]
                for col in pk_columns
            ]
            return self.db.query(self.model).filter(*filters).first()

    def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        **filters
    ) -> List[ModelType]:
        """
        获取实体列表,支持分页和简单过滤

        filters 参数允许传入任意字段的过滤条件,比如:
        get_all(is_active=True, category_id=1)
        """
        query = self.db.query(self.model)

        # 应用过滤条件
        for field, value in filters.items():
            if value is not None and hasattr(self.model, field):
                query = query.filter(getattr(self.model, field) == value)

        return query.offset(skip).limit(limit).all()

    def count(self, **filters) -> int:
        """统计实体数量"""
        query = self.db.query(self.model)

        for field, value in filters.items():
            if value is not None and hasattr(self.model, field):
                query = query.filter(getattr(self.model, field) == value)

        return query.count()

    def exists(self, id: Any) -> bool:
        """检查实体是否存在"""
        return self.get_by_id(id) is not None

    def add(self, entity: ModelType) -> ModelType:
        """添加实体"""
        self.db.add(entity)
        self.db.flush()
        self.db.refresh(entity)
        return entity

    def add_many(self, entities: List[ModelType]) -> List[ModelType]:
        """批量添加实体"""
        self.db.add_all(entities)
        self.db.flush()
        for entity in entities:
            self.db.refresh(entity)
        return entities

    def update(self, entity: ModelType) -> ModelType:
        """更新实体"""
        self.db.flush()
        self.db.refresh(entity)
        return entity

    def delete(self, id: Any) -> bool:
        """删除实体"""
        entity = self.get_by_id(id)
        if entity:
            self.db.delete(entity)
            return True
        return False

    def delete_many(self, ids: List[Any]) -> int:
        """批量删除实体,返回实际删除的数量"""
        pk_columns = inspect(self.model).primary_key
        pk_name = pk_columns[0].name

        count = self.db.query(self.model).filter(
            getattr(self.model, pk_name).in_(ids)
        ).delete(synchronize_session=False)

        return count

4.2 继承基类创建具体仓储

现在,创建具体的 Repository 变得非常简单:

python 复制代码
from typing import Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import or_

from models import User, Order, Product
from repositories.base import BaseRepository


class UserRepository(BaseRepository[User]):
    """
    用户仓储

    继承自 BaseRepository[User],自动获得所有基础 CRUD 方法。
    这里只需要添加 User 特有的查询方法。
    """

    def __init__(self, db: Session):
        super().__init__(User, db)

    # ===== User 特有的方法 =====

    def get_by_username(self, username: str) -> Optional[User]:
        """根据用户名查询"""
        return self.db.query(User).filter(User.username == username).first()

    def get_by_email(self, email: str) -> Optional[User]:
        """根据邮箱查询"""
        return self.db.query(User).filter(User.email == email).first()

    def search(
        self,
        keyword: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[User]:
        """搜索用户(在用户名和邮箱中模糊匹配)"""
        return self.db.query(User).filter(
            or_(
                User.username.ilike(f"%{keyword}%"),
                User.email.ilike(f"%{keyword}%")
            )
        ).offset(skip).limit(limit).all()

    def exists_by_username(self, username: str) -> bool:
        """检查用户名是否存在"""
        return self.db.query(
            self.db.query(User).filter(User.username == username).exists()
        ).scalar()

    def exists_by_email(self, email: str) -> bool:
        """检查邮箱是否存在"""
        return self.db.query(
            self.db.query(User).filter(User.email == email).exists()
        ).scalar()


class OrderRepository(BaseRepository[Order]):
    """订单仓储"""

    def __init__(self, db: Session):
        super().__init__(Order, db)

    def get_by_user(
        self,
        user_id: int,
        status: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[Order]:
        """获取用户的订单"""
        query = self.db.query(Order).filter(Order.user_id == user_id)

        if status:
            query = query.filter(Order.status == status)

        return query.order_by(Order.created_at.desc()).offset(skip).limit(limit).all()

    def get_recent_orders(self, days: int = 7) -> List[Order]:
        """获取最近几天的订单"""
        from datetime import datetime, timedelta
        since = datetime.now() - timedelta(days=days)
        return self.db.query(Order).filter(Order.created_at >= since).all()

    def calculate_user_total(self, user_id: int) -> float:
        """计算用户的累计消费金额"""
        from sqlalchemy import func
        result = self.db.query(func.sum(Order.total_price)).filter(
            Order.user_id == user_id,
            Order.status == "completed"
        ).scalar()
        return result or 0.0


class ProductRepository(BaseRepository[Product]):
    """商品仓储"""

    def __init__(self, db: Session):
        super().__init__(Product, db)

    def get_by_category(
        self,
        category_id: int,
        in_stock: bool = True
    ) -> List[Product]:
        """获取某分类下的商品"""
        query = self.db.query(Product).filter(Product.category_id == category_id)

        if in_stock:
            query = query.filter(Product.stock > 0)

        return query.all()

    def search_by_name(self, keyword: str) -> List[Product]:
        """搜索商品"""
        return self.db.query(Product).filter(
            Product.name.ilike(f"%{keyword}%")
        ).all()

    def decrease_stock(self, product_id: int, quantity: int) -> bool:
        """
        扣减库存

        使用乐观锁防止超卖:只有 stock >= quantity 时才会更新成功
        """
        result = self.db.query(Product).filter(
            Product.id == product_id,
            Product.stock >= quantity
        ).update(
            {Product.stock: Product.stock - quantity},
            synchronize_session=False
        )
        return result > 0

看到了吗?每个具体的 Repository 只需要几行代码就定义好了,因为基础方法都继承自 BaseRepository。特有的业务查询方法单独定义,代码量大大减少,而且结构清晰。

4.3 泛型仓储的优缺点

优点 缺点
减少重复代码 复杂查询仍需单独实现
统一的 CRUD 接口 增加一层抽象
新增实体 Repository 很快 泛型可能让新手困惑
便于统一添加日志、审计等功能 过度抽象可能适得其反

五、工作单元模式配合仓储

前面的 Repository 实现里,我们只用了 flush() 而没有 commit()。这是故意的------事务控制应该由更上层来负责。工作单元模式(Unit of Work)就是专门管理事务的。

5.1 什么是工作单元

工作单元的核心思想是:把一组相关的数据库操作放在一起,要么全部成功,要么全部失败

比如创建订单这个业务,需要:

  1. 创建订单记录
  2. 更新商品库存
  3. 创建支付记录

这三个操作必须是原子的------要么都成功,要么都回滚。如果订单创建成功但库存更新失败,数据就不一致了。

工作单元负责:

  • 追踪一个业务操作涉及的所有数据变更
  • 在合适的时机统一提交(commit)或回滚(rollback
  • 确保事务的完整性

5.2 实现工作单元

python 复制代码
from typing import Optional
from sqlalchemy.orm import Session

from repositories.user_repository import UserRepository
from repositories.order_repository import OrderRepository
from repositories.product_repository import ProductRepository


class UnitOfWork:
    """
    工作单元

    管理多个 Repository 的操作,确保这些操作在同一个事务中执行。

    典型用法:
        with UnitOfWork(db) as uow:
            user = uow.users.get_by_id(1)
            order = Order(user_id=user.id, ...)
            uow.orders.add(order)
            uow.commit()
    """

    def __init__(self, session: Session):
        self._session = session

        # 懒加载 Repository
        self._users: Optional[UserRepository] = None
        self._orders: Optional[OrderRepository] = None
        self._products: Optional[ProductRepository] = None

    @property
    def users(self) -> UserRepository:
        """用户仓储(懒加载)"""
        if self._users is None:
            self._users = UserRepository(self._session)
        return self._users

    @property
    def orders(self) -> OrderRepository:
        """订单仓储"""
        if self._orders is None:
            self._orders = OrderRepository(self._session)
        return self._orders

    @property
    def products(self) -> ProductRepository:
        """商品仓储"""
        if self._products is None:
            self._products = ProductRepository(self._session)
        return self._products

    def __enter__(self):
        """进入上下文管理器"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """退出上下文管理器,异常时自动回滚"""
        if exc_type is not None:
            self.rollback()

    def commit(self):
        """提交事务"""
        try:
            self._session.commit()
        except Exception:
            self.rollback()
            raise

    def rollback(self):
        """回滚事务"""
        self._session.rollback()

    def flush(self):
        """刷新会话(不提交)"""
        self._session.flush()

5.3 在服务层使用工作单元

python 复制代码
from fastapi import HTTPException, status

from unit_of_work import UnitOfWork
from models import Order
from schemas import OrderCreate


class OrderService:
    """订单服务"""

    def __init__(self, uow: UnitOfWork):
        self.uow = uow

    def create_order(self, order_data: OrderCreate) -> Order:
        """
        创建订单

        涉及多个表(订单、库存),必须在一个事务中完成
        """
        # 1. 检查用户是否存在
        user = self.uow.users.get_by_id(order_data.user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        # 2. 检查商品并计算总价
        total_price = 0
        order_items = []

        for item in order_data.items:
            product = self.uow.products.get_by_id(item.product_id)

            if not product:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"商品 {item.product_id} 不存在"
                )

            if product.stock < item.quantity:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"商品 {product.name} 库存不足"
                )

            subtotal = product.price * item.quantity
            total_price += subtotal

            order_items.append({
                'product': product,
                'quantity': item.quantity,
                'price': product.price
            })

        # 3. 创建订单
        order = Order(
            user_id=user.id,
            total_price=total_price,
            status="pending"
        )
        self.uow.orders.add(order)
        self.uow.flush()  # 获取订单 ID

        # 4. 扣减库存
        for item_data in order_items:
            product = item_data['product']
            quantity = item_data['quantity']

            success = self.uow.products.decrease_stock(product.id, quantity)
            if not success:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"商品 {product.name} 库存不足"
                )

        # 5. 提交事务
        self.uow.commit()

        return order

看到没?UnitOfWork 把多个 Repository 组织在一起,服务层通过 uow.usersuow.ordersuow.products 访问不同的仓储,最后统一 commit()。任何一步失败,整个事务都会回滚。

5.4 配置依赖注入

python 复制代码
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session

from database import get_db
from unit_of_work import UnitOfWork
from services.order_service import OrderService

app = FastAPI()


def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
    """每个请求一个工作单元"""
    return UnitOfWork(db)


def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
    """创建订单服务"""
    return OrderService(uow)


@app.post("/orders")
def create_order(
    order_data: OrderCreate,
    service: OrderService = Depends(get_order_service)
):
    return service.create_order(order_data)

六、FastAPI完整实战

把前面讲的内容整合起来,看看一个完整的项目结构:

6.1 项目结构

复制代码
myapp/
├── main.py                    # 应用入口
├── config.py                  # 配置
├── database.py                # 数据库连接
├── unit_of_work.py            # 工作单元
├── dependencies.py            # 依赖注入配置
│
├── models/                    # ORM 模型
│   ├── __init__.py
│   ├── user.py
│   ├── order.py
│   └── product.py
│
├── schemas/                   # Pydantic 模型
│   ├── __init__.py
│   ├── user.py
│   ├── order.py
│   └── product.py
│
├── repositories/              # 仓储层
│   ├── __init__.py
│   ├── base.py               # 泛型基类
│   ├── user_repository.py
│   ├── order_repository.py
│   └── product_repository.py
│
├── services/                  # 服务层
│   ├── __init__.py
│   ├── user_service.py
│   ├── order_service.py
│   └── product_service.py
│
└── routers/                   # 路由层
    ├── __init__.py
    ├── user_router.py
    ├── order_router.py
    └── product_router.py

6.2 依赖注入配置

python 复制代码
# dependencies.py
from fastapi import Depends
from sqlalchemy.orm import Session

from database import get_db
from unit_of_work import UnitOfWork
from services.user_service import UserService
from services.order_service import OrderService
from services.product_service import ProductService


def get_uow(db: Session = Depends(get_db)) -> UnitOfWork:
    """获取工作单元"""
    return UnitOfWork(db)


def get_user_service(uow: UnitOfWork = Depends(get_uow)) -> UserService:
    """获取用户服务"""
    return UserService(uow)


def get_order_service(uow: UnitOfWork = Depends(get_uow)) -> OrderService:
    """获取订单服务"""
    return OrderService(uow)


def get_product_service(uow: UnitOfWork = Depends(get_uow)) -> ProductService:
    """获取商品服务"""
    return ProductService(uow)

6.3 路由示例

python 复制代码
# routers/user_router.py
from typing import Optional
from fastapi import APIRouter, Depends, status

from dependencies import get_user_service
from services.user_service import UserService
from schemas.user import UserCreate, UserUpdate, UserResponse, UserListResponse

router = APIRouter(prefix="/users", tags=["用户管理"])


@router.get("/", response_model=UserListResponse)
def list_users(
    skip: int = 0,
    limit: int = 100,
    is_active: Optional[bool] = None,
    service: UserService = Depends(get_user_service)
):
    """获取用户列表"""
    return service.get_users(skip=skip, limit=limit, is_active=is_active)


@router.get("/{user_id}", response_model=UserResponse)
def get_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    """获取用户详情"""
    return service.get_user(user_id)


@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def create_user(
    user_data: UserCreate,
    service: UserService = Depends(get_user_service)
):
    """创建用户"""
    return service.create_user(user_data)


@router.put("/{user_id}", response_model=UserResponse)
def update_user(
    user_id: int,
    user_data: UserUpdate,
    service: UserService = Depends(get_user_service)
):
    """更新用户"""
    return service.update_user(user_id, user_data)


@router.delete("/{user_id}")
def delete_user(
    user_id: int,
    service: UserService = Depends(get_user_service)
):
    """删除用户"""
    return service.delete_user(user_id)

6.4 主应用

python 复制代码
# main.py
from fastapi import FastAPI

from routers import user_router, order_router, product_router

app = FastAPI(
    title="电商API",
    description="使用仓储模式的 FastAPI 示例",
    version="1.0.0"
)

app.include_router(user_router.router)
app.include_router(order_router.router)
app.include_router(product_router.router)


@app.get("/")
def root():
    return {"message": "Welcome to the API"}


@app.get("/health")
def health():
    return {"status": "healthy"}

七、仓储模式 vs 直接用ORM

看到这里,你可能会想:这搞得这么复杂,直接用 SQLAlchemy 不是更简单?

这是个好问题。仓储模式不是银弹,它有成本也有收益,需要根据项目情况权衡。

7.1 直接使用 ORM 的场景

适合的情况:

  • 项目较小,CRUD 为主,业务逻辑简单
  • 团队对 ORM 熟悉,开发速度优先
  • 不太可能更换数据库或 ORM
  • 不需要严格的测试覆盖
python 复制代码
# 直接用 ORM,简单直接
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404)
    return user

这样写没什么问题。如果你的项目就是几个简单接口,引入仓储模式反而是过度设计。

7.2 使用仓储模式的场景

适合的情况:

  • 项目较大,业务逻辑复杂
  • 需要良好的测试覆盖
  • 可能更换底层存储
  • 团队协作,需要清晰的代码边界
  • 追求代码的可维护性和可扩展性

7.3 对比总结

方面 直接用ORM 仓储模式
代码量
学习成本 中等
开发速度(初期) 慢一点
维护成本(后期)
可测试性
灵活性
适合项目规模 小型 中大型

7.4 我的建议

项目类型 建议
小项目(几个简单接口) 直接用 ORM
中型项目 至少把数据访问逻辑提取到单独模块
大型项目 完整分层架构 + 工作单元
需要高测试覆盖 仓储模式

八、测试中的巨大优势

仓储模式最大的优势之一就是方便测试。

8.1 不使用仓储模式的测试困境

python 复制代码
# 业务代码直接操作数据库
class UserService:
    def __init__(self, db: Session):
        self.db = db

    def create_user(self, username: str, email: str, password: str):
        if self.db.query(User).filter(User.username == username).first():
            raise ValueError("用户名已存在")
        user = User(username=username, email=email, password=password)
        self.db.add(user)
        self.db.commit()
        return user


# 测试代码
def test_create_user():
    # 必须准备真实数据库
    db = TestSessionLocal()
    service = UserService(db)

    # 清理可能存在的测试数据
    db.query(User).filter(User.username == "testuser").delete()
    db.commit()

    # 执行测试
    user = service.create_user("testuser", "test@example.com", "password")
    assert user.username == "testuser"

    # 清理
    db.delete(user)
    db.commit()
    db.close()

问题很多:需要真实数据库、测试数据要准备和清理、测试速度慢、测试之间可能互相影响。

8.2 使用仓储模式的测试

创建一个内存实现的"假仓储":

python 复制代码
from typing import Optional, List, Dict
from models import User


class FakeUserRepository:
    """假的用户仓储,用于测试"""

    def __init__(self):
        self._users: Dict[int, User] = {}
        self._next_id = 1

    def get_by_id(self, user_id: int) -> Optional[User]:
        return self._users.get(user_id)

    def get_by_username(self, username: str) -> Optional[User]:
        for user in self._users.values():
            if user.username == username:
                return user
        return None

    def add(self, user: User) -> User:
        user.id = self._next_id
        self._next_id += 1
        self._users[user.id] = user
        return user

    def exists_by_username(self, username: str) -> bool:
        return self.get_by_username(username) is not None

    def clear(self):
        """清空所有数据"""
        self._users.clear()
        self._next_id = 1

测试变得简单快速:

python 复制代码
import pytest
from services.user_service import UserService


class FakeUnitOfWork:
    """假的工作单元"""

    def __init__(self):
        self.users = FakeUserRepository()
        self.committed = False

    def commit(self):
        self.committed = True

    def rollback(self):
        pass


class TestUserService:

    @pytest.fixture
    def uow(self):
        return FakeUnitOfWork()

    @pytest.fixture
    def service(self, uow):
        return UserService(uow)

    def test_create_user_success(self, service, uow):
        """测试成功创建用户"""
        user_data = UserCreate(
            username="testuser",
            email="test@example.com",
            password="password123"
        )

        result = service.create_user(user_data)

        assert result.username == "testuser"
        assert uow.committed

    def test_create_user_duplicate_username(self, service, uow):
        """测试用户名重复"""
        # 先添加一个用户
        existing_user = User(
            username="existinguser",
            email="existing@example.com",
            hashed_password="xxx"
        )
        uow.users.add(existing_user)

        # 尝试创建同名用户
        user_data = UserCreate(
            username="existinguser",
            email="new@example.com",
            password="password123"
        )

        with pytest.raises(HTTPException) as exc_info:
            service.create_user(user_data)

        assert exc_info.value.status_code == 400
        assert "用户名已存在" in exc_info.value.detail

8.3 FastAPI 集成测试

python 复制代码
import pytest
from fastapi.testclient import TestClient

from main import app
from dependencies import get_uow


class TestUserAPI:

    @pytest.fixture
    def client(self):
        fake_uow = FakeUnitOfWork()

        # 预设测试数据
        fake_uow.users.add(User(
            username="existinguser",
            email="existing@example.com",
            hashed_password="xxx"
        ))

        # 覆盖依赖
        app.dependency_overrides[get_uow] = lambda: fake_uow

        client = TestClient(app)
        yield client

        app.dependency_overrides.clear()

    def test_create_user(self, client):
        response = client.post("/users/", json={
            "username": "newuser",
            "email": "new@example.com",
            "password": "password123"
        })

        assert response.status_code == 201
        assert response.json()["username"] == "newuser"

8.4 测试的好处总结

好处 说明
快速 内存运行,毫秒级完成
隔离 每个测试有自己的假仓储
稳定 不受网络、数据库状态影响
简单 准备测试数据很方便
专注 单独测试业务逻辑

九、常见问题与最佳实践

9.1 Repository 应该返回什么?

返回 ORM 模型还是 DTO?

建议:Repository 返回 ORM 模型,转换成 DTO 是 Service 层的事。

python 复制代码
# Repository 返回 ORM 对象
class UserRepository:
    def get_by_id(self, user_id: int) -> Optional[User]:
        return self.db.query(User).filter(User.id == user_id).first()


# Service 层做转换
class UserService:
    def get_user(self, user_id: int) -> UserResponse:
        user = self.repo.get_by_id(user_id)  # User(ORM)
        return UserResponse.model_validate(user)  # UserResponse(DTO)

9.2 复杂查询放在哪里?

原则:和数据获取相关的放 Repository,和业务规则相关的放 Service

python 复制代码
# Repository:复杂的数据查询
class OrderRepository:
    def get_monthly_stats(self, year: int, month: int) -> dict:
        """获取月度订单统计(数据查询)"""
        from sqlalchemy import func

        result = self.db.query(
            func.count(Order.id).label('count'),
            func.sum(Order.total_price).label('total')
        ).filter(
            func.extract('year', Order.created_at) == year,
            func.extract('month', Order.created_at) == month,
            Order.status == 'completed'
        ).first()

        return {
            'count': result.count or 0,
            'total': float(result.total or 0)
        }


# Service:基于数据的业务计算
class ReportService:
    def generate_monthly_report(self, year: int, month: int) -> dict:
        """生成月度报告(业务逻辑)"""
        stats = self.uow.orders.get_monthly_stats(year, month)

        # 业务计算
        avg_order_value = (
            stats['total'] / stats['count']
            if stats['count'] > 0 else 0
        )

        return {
            'order_count': stats['count'],
            'total_revenue': stats['total'],
            'avg_order_value': avg_order_value
        }

9.3 如何处理 N+1 问题?

使用 joinedloadselectinload 预加载关联数据:

python 复制代码
from sqlalchemy.orm import joinedload, selectinload


class OrderRepository:
    def get_with_items(self, order_id: int) -> Optional[Order]:
        """获取订单及其订单项(预加载)"""
        return self.db.query(Order).options(
            selectinload(Order.items).selectinload(OrderItem.product)
        ).filter(Order.id == order_id).first()

9.4 仓储应该有多细粒度?

一个实体一个仓储是常见做法。但紧密相关的实体可以共用:

python 复制代码
# 聚合根模式:订单项通过订单仓储访问
class OrderRepository:
    def add_item(self, order_id: int, item: OrderItem): ...
    def remove_item(self, order_id: int, item_id: int): ...
    def get_items(self, order_id: int) -> List[OrderItem]: ...

9.5 性能优化建议

批量操作:避免循环中一个个操作数据库

python 复制代码
# 不好:N次数据库操作
for user_id in user_ids:
    user = repo.get_by_id(user_id)

# 好:1次数据库操作
users = repo.get_by_ids(user_ids)

分页:始终使用分页

python 复制代码
# 好
users = repo.get_all(skip=0, limit=100)

# 不好
users = repo.get_all()  # 可能有几万条

十、总结

这篇文章把仓储模式的各个方面都讲到了,总结一下关键点:

仓储模式是什么

它是数据访问层的抽象,把"怎么存数据"的细节藏起来,对外提供类似"集合"的接口。Repository 就像图书馆管理员,业务代码只需要说"我要这本书",不用管书放在哪个架子上。

为什么要用仓储模式

主要是为了解耦、可测试、可维护。数据访问逻辑集中管理,业务层不依赖具体的数据库实现,测试时可以用假实现替换真数据库。

怎么实现

从最基础的 Repository 类开始,然后引入泛型基类减少重复代码,再配合工作单元模式管理事务。

什么时候用

小项目直接用 ORM 就够了,中大型项目、需要高测试覆盖的项目建议使用。

关键要点总结

概念 说明 适用场景
Repository 数据访问抽象 封装查询逻辑
泛型仓储 减少重复的 CRUD 代码 多实体项目
工作单元 事务管理 多表操作
假仓储 测试用的内存实现 单元测试

仓储模式 + 工厂模式 + 依赖注入 + 单例,这几个模式组合起来,形成了完整的分层架构。

下一篇打算讲观察者模式,这个在事件驱动的场景(比如发消息通知、触发异步任务)非常有用。在 FastAPI 里配合背景任务或消息队列使用,能让系统架构更加松耦合。


热门专栏推荐

等等等还有许多优秀的合集在主页等着大家的光顾,感谢大家的支持

文章到这里就结束了,如果有什么疑问的地方请指出,诸佬们一起来评论区一起讨论😊

希望能和诸佬们一起努力,今后我们一起观看感谢您的阅读🙏

如果帮助到您不妨3连支持一下,创造不易您们的支持是我的动力🌟

相关推荐
MarkHD2 小时前
智能体在车联网中的应用:第25天 深度Q网络(DQN)实战:在CartPole环境中用PyTorch从零实现
人工智能·pytorch·python
kobe_OKOK_2 小时前
windows 部署 django 的 方案
后端·python·django
智算菩萨2 小时前
【Python进阶】数据结构的精巧与算法的智慧:AI提速的关键
开发语言·人工智能·python
BoBoZz192 小时前
GenerateCubesFromLabels 提取和可视化特定标签所代表的 3D 结构
python·vtk·图形渲染·图形处理
liwulin05062 小时前
【PYTHON】视频转图片
开发语言·python·音视频
惆怅客1232 小时前
libuvc初探
python·c·libuvc
渡我白衣2 小时前
Python 与数据科学工具链入门:NumPy、Pandas、Matplotlib 快速上手
人工智能·python·机器学习·自然语言处理·numpy·pandas·matplotlib
love530love2 小时前
【笔记】把已有的 ComfyUI 插件发布到 Comfy Registry(官方节点商店)全流程实录
人工智能·windows·笔记·python·aigc·comfyui·torchmonitor
星火飞码iFlyCode2 小时前
iFlyCode实践规范驱动开发(SDD):招考平台报名相片质量抽检功能开发实战
java·前端·python·算法·ai编程·科大讯飞