FastAPI 进阶:ORM(SQLAlchemy 异步)

一、ORM 简介

ORM(Object-Relational Mapping,对象关系映射) 是一种编程技术,它在面向对象编程语言和关系型数据库之间建立映射。开发者通过操作对象的方式与数据库交互,无需直接编写复杂的 SQL 语句。

核心优势

  • 减少重复的 SQL 代码
  • 代码更简洁、易读,面向对象风格
  • 自动处理数据库连接和事务
  • 内置防止 SQL 注入攻击(通过参数化查询)
  • 数据库无关性:可轻松更换数据库(如 MySQL → PostgreSQL)

二、主流 ORM 工具对比

排名 ORM 工具 特点 适应场景
1 SQLAlchemy ORM 功能最强、最灵活、企业级 各类 API、微服务、数据应用
2 Django ORM 封装好、上手快 Django 项目、管理后台
3 Tortoise ORM 全异步 异步 Web 服务、高并发 API

在 FastAPI 项目中,通常推荐 SQLAlchemy(异步模式)。

三、SQLAlchemy 异步 ORM 完整使用流程

Step 1:安装依赖

bash 复制代码
pip install sqlalchemy[asyncio] aiomysql
  • sqlalchemy[asyncio]:SQLAlchemy 异步支持
  • aiomysql:MySQL 的异步驱动(也可用 asyncpg 连接 PostgreSQL)

Step 2:建库与建表

2.1 创建异步引擎
python 复制代码
# database.py
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, DateTime, func
from datetime import datetime

# 数据库连接 URL(MySQL 异步)
ASYNC_DATABASE_URL = "mysql+aiomysql://root:123456@localhost:3306/fastapi_test?charset=utf8"

# 创建异步引擎
async_engine = create_async_engine(
    ASYNC_DATABASE_URL,
    echo=True,          # 输出 SQL 日志(开发环境可开启)
    pool_size=10,       # 连接池中保持的持久连接数
    max_overflow=20     # 连接池允许创建的额外连接数
)
2.2 定义模型基类和模型类
python 复制代码
# models.py
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, DateTime, func
from datetime import datetime

class Base(DeclarativeBase):
    """所有模型类的基类,包含公共字段(创建时间、更新时间)"""
    create_time: Mapped[datetime] = mapped_column(
        DateTime,
        insert_default=func.now(),
        default=datetime.now,
        comment="创建时间"
    )
    update_time: Mapped[datetime] = mapped_column(
        DateTime,
        insert_default=func.now(),
        onupdate=func.now(),
        default=datetime.now,
        comment="修改时间"
    )

class Book(Base):
    __tablename__ = "book"   # 数据库表名

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    bookname: Mapped[str] = mapped_column(String(255), nullable=False)
    author: Mapped[str] = mapped_column(String(255), nullable=False)
    price: Mapped[float] = mapped_column(nullable=True, default=0.0)
  • Mapped[type] 是 SQLAlchemy 2.0 风格的类型注解,更加简洁。
  • mapped_column() 定义列属性,可指定长度、默认值、注释等。
  • 基类中的 create_timeupdate_time 会自动添加到所有子表。
2.3 创建数据库表(在应用启动时执行)
python 复制代码
# main.py 或 lifespan
from contextlib import asynccontextmanager
from fastapi import FastAPI
from database import async_engine
from models import Base

async def create_tables():
    """异步创建所有表(如果不存在)"""
    async with async_engine.begin() as conn:
        # run_sync 用于在异步环境中同步执行建表操作
        await conn.run_sync(Base.metadata.create_all)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # 启动时执行
    await create_tables()
    yield
    # 关闭时释放引擎(可选)
    await async_engine.dispose()

app = FastAPI(lifespan=lifespan)
  • 使用 @asynccontextmanager 实现 lifespan,替代已弃用的 @app.on_event("startup")
  • conn.run_sync(Base.metadata.create_all) 是因为 create_all 是同步方法,需要在异步环境中包装执行。

Step 3:数据库会话依赖注入(关键)

使用依赖注入为每个请求提供独立的数据库会话,并确保正确关闭。

python 复制代码
# dependencies.py
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import sessionmaker
from database import async_engine

# 创建异步会话工厂
AsyncSessionLocal = sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False
)

async def get_db() -> AsyncSession:
    """依赖项:提供数据库会话,请求结束后自动关闭"""
    async with AsyncSessionLocal() as session:
        yield session

然后在路由中使用 Depends(get_db) 获取会话。

Step 4:CRUD 操作示例

以下所有代码都放在路由文件中,使用 Depends(get_db) 注入会话。

4.1 查询(多种方式)
python 复制代码
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from models import Book
from dependencies import get_db

router = APIRouter(prefix="/book", tags=["书籍"])

# 查询所有书籍
@router.get("/get_books")
async def get_books(db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Book))
    books = result.scalars().all()   # 返回列表
    return books

# 根据主键查询单条(get 方法)
@router.get("/get_book/{book_id}")
async def get_book_by_id(book_id: int, db: AsyncSession = Depends(get_db)):
    book = await db.get(Book, book_id)
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    return book

# 使用 scalar_one_or_none 避免多条数据时报错
@router.get("/find_by_name")
async def find_book(name: str, db: AsyncSession = Depends(get_db)):
    stmt = select(Book).where(Book.bookname == name)
    result = await db.execute(stmt)
    book = result.scalar_one_or_none()
    return book
4.2 条件查询
python 复制代码
# 比较判断
@router.get("/price_gt")
async def books_price_gt(price: float, db: AsyncSession = Depends(get_db)):
    stmt = select(Book).where(Book.price > price)
    result = await db.execute(stmt)
    return result.scalars().all()

# 模糊查询 like
@router.get("/author_like")
async def books_by_author_pattern(pattern: str, db: AsyncSession = Depends(get_db)):
    stmt = select(Book).where(Book.author.like(f"{pattern}%"))
    result = await db.execute(stmt)
    return result.scalars().all()

# 多条件组合 & (与) , | (或)
@router.get("/filter")
async def complex_filter(db: AsyncSession = Depends(get_db)):
    from sqlalchemy import and_, or_
    stmt = select(Book).where(
        and_(Book.price > 50, Book.author == "曹雪芹")
    )
    result = await db.execute(stmt)
    return result.scalars().all()

# 包含查询 in_
@router.get("/ids")
async def books_by_ids(ids: str, db: AsyncSession = Depends(get_db)):
    id_list = [int(i) for i in ids.split(",")]
    stmt = select(Book).where(Book.id.in_(id_list))
    result = await db.execute(stmt)
    return result.scalars().all()
4.3 聚合查询
python 复制代码
from sqlalchemy import func

@router.get("/count")
async def get_book_count(db: AsyncSession = Depends(get_db)):
    stmt = select(func.count(Book.id))
    result = await db.execute(stmt)
    count = result.scalar()
    return {"total": count}

@router.get("/avg_price")
async def get_avg_price(db: AsyncSession = Depends(get_db)):
    stmt = select(func.avg(Book.price))
    result = await db.execute(stmt)
    avg = result.scalar()
    return {"avg_price": avg}
4.4 分页查询
python 复制代码
@router.get("/pagination")
async def get_books_paginated(
    page: int = 1,
    page_size: int = 10,
    db: AsyncSession = Depends(get_db)
):
    skip = (page - 1) * page_size
    stmt = select(Book).offset(skip).limit(page_size)
    result = await db.execute(stmt)
    books = result.scalars().all()
    return {"page": page, "page_size": page_size, "books": books}
4.5 新增数据
python 复制代码
from pydantic import BaseModel

class BookCreate(BaseModel):
    bookname: str
    author: str
    price: float = 0.0

@router.post("/add_book")
async def add_book(book_data: BookCreate, db: AsyncSession = Depends(get_db)):
    # 创建 ORM 对象
    new_book = Book(
        bookname=book_data.bookname,
        author=book_data.author,
        price=book_data.price
    )
    db.add(new_book)
    await db.commit()          # 提交事务
    await db.refresh(new_book) # 刷新获取自增 id 等
    return new_book
4.6 更新数据
python 复制代码
class BookUpdate(BaseModel):
    bookname: str = None
    author: str = None
    price: float = None

@router.put("/update_book/{book_id}")
async def update_book(
    book_id: int,
    book_data: BookUpdate,
    db: AsyncSession = Depends(get_db)
):
    book = await db.get(Book, book_id)
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    
    # 只更新传入的非空字段
    if book_data.bookname is not None:
        book.bookname = book_data.bookname
    if book_data.author is not None:
        book.author = book_data.author
    if book_data.price is not None:
        book.price = book_data.price
    
    await db.commit()
    await db.refresh(book)
    return book
4.7 删除数据
python 复制代码
@router.delete("/delete_book/{book_id}")
async def delete_book(book_id: int, db: AsyncSession = Depends(get_db)):
    book = await db.get(Book, book_id)
    if not book:
        raise HTTPException(status_code=404, detail="Book not found")
    await db.delete(book)
    await db.commit()
    return {"message": "Book deleted successfully"}

四、ORM 使用流程总结

  1. 安装依赖

sqlalchemy[asyncio] + aiomysql
2. 创建异步引擎

create_async_engine

配置URL、连接池、日志
3. 定义模型基类和模型类

DeclarativeBase + Mapped/mapped_column

映射表名、字段类型、约束
4. 应用启动时建表

run_sync(Base.metadata.create_all)
5. 创建依赖项 get_db

sessionmaker + AsyncSession

async with 确保会话自动关闭

通过 Depends 注入到路由函数
6. 在路由中执行 CRUD

查询:select → execute → scalars

新增:add → commit

更新:修改属性 → commit

删除:delete → commit

五、ORM 核心要点速记

操作 核心代码
查询所有 result = await db.execute(select(Book)); books = result.scalars().all()
主键查询 book = await db.get(Book, id)
条件查询 select(Book).where(Book.price > 50)
分页 select(Book).offset(skip).limit(limit)
聚合 select(func.count(Book.id))
新增 db.add(obj); await db.commit()
更新 obj.attr = value; await db.commit()
删除 await db.delete(obj); await db.commit()

六、常见问题与注意事项

  1. 异步环境必须全程异步

    • 使用 async def 定义路由和依赖项。
    • 所有数据库操作都需要 await
  2. 不要在异步中使用同步 SQLAlchemy

    • 必须使用 create_async_engine + AsyncSession,而不是普通引擎。
  3. scalars() 的作用

    • result.scalars() 将行结果转换为 ORM 对象,.all() 返回列表,.first() 返回第一个。
  4. 事务提交

    • 增、删、改后必须 await db.commit(),否则不会持久化。
    • 查询不需要 commit。
  5. 依赖注入会话的正确写法

    python 复制代码
    async def get_db() -> AsyncSession:
        async with AsyncSessionLocal() as session:
            yield session   # 请求处理完后自动关闭
相关推荐
dinl_vin4 小时前
FastAPI 系列 ·(十一):ClickHouse 集成——大数据查询实战
大数据·clickhouse·fastapi
li星野19 小时前
FastAPI 入门:异步与同步端点的性能差异与并发测试解析
fastapi
dinl_vin1 天前
FastAPI 系列 · (十):测试——从单元到集成
fastapi
dinl_vin1 天前
FastAPI 系列 ·(九):中间件与错误处理:让服务更健壮
中间件·状态模式·fastapi
圣殿骑士-Khtangc1 天前
Python后端开发实战:FastAPI构建高性能RESTful API完整指南
python·restful·fastapi
展示猪肝1 天前
FastAPI 全局异常处理最佳实践:自定义异常、统一响应、兜底处理
python·异常处理·fastapi·后端开发
青衫客362 天前
从零实现多智能体 Runtime(一):系统架构、状态机与任务编排设计
agent·fastapi
曲幽2 天前
FastApiAdmin 后端接口开发好了,前端管理界面怎么调用与显示?
python·vue3·api·fastapi·web·ant design·view·menu·frontend
dinl_vin2 天前
FastAPI 系列·(七):Redis 集成——缓存、分布式锁与 Session 管理
redis·缓存·fastapi