Using async SQLAlchemy, Celery, and WebSockets correctly in FastAPI
Starting with version 1.4, SQLAlchemy supports asyncio. In this article we will build a simple project using async SQLAlchemy, encryption, Celery, and WebSockets. Let's start with the database connection.
Setting up the database with async SQLAlchemy
First, let's create the async engine and session factory:
python
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, echo=True)
SessionLocal = sessionmaker(
    expire_on_commit=False,
    class_=AsyncSession,
    bind=engine,
)
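The URL passed to create_async_engine must point at an async driver. With PostgreSQL and the asyncpg driver, the setting could look like this (credentials and database name are placeholders):
python
# example value for settings.SQLALCHEMY_DATABASE_URL (placeholder credentials)
SQLALCHEMY_DATABASE_URL = "postgresql+asyncpg://user:password@localhost:5432/app_db"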
We use FastAPI's dependency injection to provide the db session:
python
async def get_db() -> AsyncSession:
    async with SessionLocal() as session:
        yield session
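The routes later in this article type their database parameter as DBSession and the authenticated user as CurrentUser, both imported from app.api.deps. The article never shows those aliases, so here is a minimal sketch of what they might look like, assuming Annotated dependencies (the get_current_user helper is hypothetical):
python
from typing import Annotated

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession

# sketch of app/api/deps.py: Annotated aliases used by the routes below
DBSession = Annotated[AsyncSession, Depends(get_db)]
# CurrentUser = Annotated[User, Depends(get_current_user)]  # hypothetical auth dependency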
Once that is in place we can start using the DB. The project uses token authentication to manage user logins, so we need two tables: users and user_tokens.
python
# Import the required dependencies
import uuid

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy_utils import EmailType, force_auto_coercion, PasswordType

# Base is the declarative base class; our tables must inherit from it to get ORM capabilities
Base = declarative_base()
force_auto_coercion()


class User(Base):
    """User model."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50))
    email = Column(EmailType(50), unique=True, nullable=False)
    password = Column(PasswordType(schemes=["pbkdf2_sha512"]), nullable=False)
    tokens = relationship(
        "UserToken",
        back_populates="user",
        lazy='dynamic',
        cascade="all, delete-orphan",
    )


class UserToken(Base):
    """Table for user tokens."""

    __tablename__ = "user_tokens"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    token = Column(
        UUID(as_uuid=True), unique=True, nullable=False, default=uuid.uuid4
    )
    expires = Column(DateTime)
    user = relationship("User", back_populates="tokens", lazy='joined')
Note that force_auto_coercion() is called before the models are defined. It makes sure the password is hashed before the record is saved to the database.
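PasswordType also handles verification: comparing the stored column value with a plain string hashes the candidate and compares digests, so a login check can be a simple equality test. A minimal sketch:
python
# PasswordType hashes on assignment and verifies on comparison
def check_password(user: User, candidate: str) -> bool:
    # user.password holds the hash; __eq__ verifies the plain candidate against it
    return user.password == candidate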
These days we almost never create database tables by hand in web projects; instead we use a migration tool, and here we will use alembic. (If your project does not use a migration tool yet, I recommend adopting one.)
Install alembic:
shell
pip install alembic
Initialize alembic:
shell
alembic init migrations
The command above creates a migrations directory containing env.py, README, and script.py.mako.
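As a side note, newer alembic releases also ship an asyncio template that generates an env.py very similar to the one we build below; if it is available in your version, you can start from it instead:
shell
alembic init -t async migrations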
To make alembic work with our database, we need to update the env.py file.
python
import asyncio
import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine

from alembic import context

config = context.config

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Here we import our models and point alembic at their metadata
from app.db.base import Base

target_metadata = Base.metadata


# This function returns the URL of our DB
def get_url():
    return os.getenv("SQLALCHEMY_DATABASE_URL", "")


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    # Specify which database we use with alembic
    url = get_url()
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)

    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    configuration = config.get_section(config.config_ini_section)
    configuration["sqlalchemy.url"] = get_url()
    connectable = AsyncEngine(
        engine_from_config(
            configuration,
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
            future=True,
        )
    )

    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)

    await connectable.dispose()


if context.is_offline_mode():
    run_migrations_offline()
else:
    asyncio.run(run_migrations_online())
Use the following command to create a migration file:
shell
alembic revision --autogenerate -m "Added required tables"
Run the following command to apply the migration and update the database:
shell
alembic upgrade head
Once that completes, the tables exist in the database.
Now that the database is ready, let's try creating new users and tokens:
python
from datetime import datetime, timedelta

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.base import User, UserToken
from app.schemas.user import UserCreate


async def get_user_by_email(db: AsyncSession, email: str) -> User:
    """Look up a user by email; this is why the email column must be unique."""
    statement = select(User).where(User.email == email)
    result = await db.execute(statement)
    return result.scalars().first()


async def create_user(db: AsyncSession, user: UserCreate) -> User:
    """Create a user."""
    db_user = User(
        email=user.email,
        name=user.name,
        password=user.password,
    )
    db.add(db_user)
    await db.commit()
    await db.refresh(db_user)
    return db_user


async def create_user_token(db: AsyncSession, user: User) -> UserToken:
    """Create a user token after the user signs up or logs in successfully."""
    db_token = UserToken(
        user=user, expires=datetime.now() + timedelta(weeks=2)
    )
    db.add(db_token)
    await db.commit()
    return db_token
Next, let's write the code for registering new users:
python
from typing import Optional

from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel, EmailStr, constr

from app.api.deps import DBSession
from app.crud import crud_user
# TokenBase is the token response schema (access_token, expires, token_type),
# assumed to be defined alongside the other schemas in the project

app = FastAPI()
router = APIRouter()


class UserBase(BaseModel):
    email: EmailStr
    name: str


class UserCreate(UserBase):
    password: constr(strip_whitespace=True, min_length=8)


class User(UserBase):
    id: Optional[int] = None
    token: TokenBase | None = None

    class Config:
        orm_mode = True
Add the sign-up route:
python
@router.post("/sign-up/", response_model=User)
async def create_user(user: UserCreate, db: DBSession):
user_db = await crud_user.get_user_by_email(db, email=user.email)
if user_db:
raise HTTPException(status_code=400, detail="User already registered")
user = await crud_user.create_user(db, user=user)
user.token = await crud_user.create_user_token(db, user=user)
return user
app.include_router(user_routes)
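With the router registered, a quick manual check of the endpoint could look like this (host and port assume a local development server):
shell
curl -X POST http://127.0.0.1:8000/sign-up/ \
  -H "Content-Type: application/json" \
  -d '{"email": "demo007x@juejin.com", "name": "demo007", "password": "12345678"}'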
Testing the code
Yay! The registration logic is in place, so it is a good idea to add some tests to check that everything works as expected. Since we use an async database connection, the tests must be async too, which requires some extra plumbing:
python
import asyncio

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.api.deps import get_db
from app.core.config import settings
from app.db.base import Base
from app.main import app


@pytest.fixture(scope="session")
def event_loop() -> asyncio.AbstractEventLoop:
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
def engine():
    engine = create_async_engine(settings.TEST_SQLALCHEMY_DATABASE_URL)
    yield engine
    engine.sync_engine.dispose()


@pytest_asyncio.fixture(scope="session")
async def prepare_db():
    create_db_engine = create_async_engine(
        settings.POSTGRES_DATABASE_URL,
        isolation_level="AUTOCOMMIT",
    )
    async with create_db_engine.begin() as connection:
        await connection.execute(
            text(
                "drop database if exists {name};".format(
                    name=settings.TEST_DB_NAME
                )
            ),
        )
        await connection.execute(
            text("create database {name};".format(name=settings.TEST_DB_NAME)),
        )


@pytest_asyncio.fixture(scope="session")
async def db_session(engine) -> AsyncSession:
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)
        await connection.run_sync(Base.metadata.create_all)
        TestingSessionLocal = sessionmaker(
            expire_on_commit=False,
            class_=AsyncSession,
            bind=engine,
        )
        async with TestingSessionLocal(bind=connection) as session:
            yield session
            await session.flush()
            await session.rollback()


@pytest.fixture(scope="session")
def override_get_db(prepare_db, db_session: AsyncSession):
    async def _override_get_db():
        yield db_session

    return _override_get_db


@pytest_asyncio.fixture(scope="session")
async def async_client(override_get_db):
    app.dependency_overrides[get_db] = override_get_db
    async with AsyncClient(app=app, base_url="http://test") as ac:
        yield ac
First, we change the scope of the event_loop fixture. By default it is a function-scoped fixture, but then our DB fixtures would have to be function-scoped too, which causes performance problems; using session scope solves that.
We also add an engine fixture so that the tests use a test database instead of the real one. In prepare_db we make sure the test database exists. In db_session we create the tables and return a database session. Then in override_get_db we update the project dependencies so that the views do not touch the real database during tests. Finally, we create async_client to perform async requests against our API.
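These fixtures rely on a few test-only packages; if they are not already in the project, installing them looks roughly like this:
shell
pip install pytest pytest-asyncio httpx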
With all the preparation done, let's write the tests:
python
import pytest
import pytest_asyncio
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.crud_user import create_user
from app.db.base import User
from app.models.users import UserToken
from app.schemas.user import UserCreate


@pytest_asyncio.fixture
async def user(db_session: AsyncSession) -> User:
    user = UserCreate(
        email="demo007x@juejin.com",
        name="demo007",
        password="12345678"
    )
    user_db = await create_user(db_session, user)
    yield user_db
    await db_session.delete(user_db)
    await db_session.commit()


@pytest.mark.asyncio
async def test_sign_up(async_client, db_session):
    request_data = {
        "email": "demo008x@juejin.com",
        "name": "demo 008",
        "password": "12345678",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    token_counts = await db_session.execute(select(func.count(UserToken.id)))
    assert token_counts.scalar_one() == 1
    assert response.status_code == 200
    assert response.json()["id"] is not None
    assert response.json()["email"] == "demo008x@juejin.com"
    assert response.json()["name"] == "demo 008"
    assert response.json()["token"]["access_token"] is not None
    assert response.json()["token"]["expires"] is not None
    assert response.json()["token"]["token_type"] == "bearer"


@pytest.mark.asyncio
async def test_sign_up_existing_user(async_client, user):
    request_data = {
        "email": user.email,
        "name": "Weatherwax",
        "password": "12345678",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    assert response.status_code == 400
    assert response.json()["detail"] == "User already registered"


@pytest.mark.asyncio
async def test_sign_up_weak_password(async_client):
    request_data = {
        "email": "xx@ww.com",
        "name": "Vimes",
        "password": "123",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    assert response.status_code == 422
    assert (
        response.json()["detail"][0]["msg"]
        == "ensure this value has at least 8 characters"
    )
    assert (
        response.json()["detail"][0]["type"]
        == "value_error.any_str.min_length"
    )
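With the fixtures and tests in place, the suite runs like any other pytest suite:
shell
pytest -v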
Using Celery tasks
AsyncIO is great for IO-bound work, which is why we use it to read data from the database. But what if we need to run a CPU-heavy task? In that case we should hand the task off to a separate process, and Celery (see its documentation) helps us do exactly that.
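Celery also needs a message broker and a worker process running next to the API. A minimal local setup might look like the following; the Redis broker, the CELERY_BROKER_URL value, and the app.worker module path are assumptions, so adjust them to your project layout (the Celery app defined later in this article reads CELERY_BROKER_URL from the environment):
shell
# start a local Redis broker (assumes Docker is available)
docker run -d -p 6379:6379 redis
# tell Celery where the broker is
export CELERY_BROKER_URL=redis://localhost:6379/0
# start a worker, assuming the Celery app shown below lives in app/worker.py
celery -A app.worker worker --loglevel=info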
In our system, users can create posts, but a post's content is encrypted before it is stored in the database. Encryption is a CPU-intensive task, so we use Celery for it. Let's create the required models:
python
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship

from app.db.base_class import Base


class UserKeys(Base):
    __tablename__ = "user_keys"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    public_key = Column(String(2000), nullable=False)
    is_revoked = Column(Boolean, default=False)
    user = relationship("User", back_populates="keys")


class UserGroup(Base):
    __tablename__ = "user_groups"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50))
    users = relationship(
        "User",
        secondary="user_group_association",
        back_populates="groups",
    )
    posts = relationship(
        "Post",
        back_populates="user_group",
        cascade="all, delete-orphan",
    )


class UserGroupAssociation(Base):
    __tablename__ = "user_group_association"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    group_id = Column(Integer, ForeignKey("user_groups.id"))


class Post(Base):
    __tablename__ = "posts"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(100))
    content = Column(Text)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    group_id = Column(
        Integer,
        ForeignKey("user_groups.id", ondelete='CASCADE'),
        nullable=False,
    )
    author = relationship("User", back_populates="posts")
    user_group = relationship("UserGroup", back_populates="posts")
    keys = relationship(
        "PostKeys",
        back_populates="post",
        cascade="all, delete-orphan",
    )


class PostKeys(Base):
    __tablename__ = "post_keys"

    id = Column(Integer, primary_key=True, index=True)
    post_id = Column(
        Integer,
        ForeignKey("posts.id", ondelete='CASCADE'),
        nullable=False,
    )
    public_key_id = Column(
        Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
    )
    encrypted_key = Column(Text)
    post = relationship("Post", back_populates="keys")
    public_key = relationship("UserKeys")
Each user has a public/private key pair. The public key is uploaded to the server, while the private key stays secret. There are also user groups: a user can belong to several groups, but each post is attached to exactly one group, so only members of that group can read the post's content.
When a new post is added, the system generates a temporary key and encrypts the post's content with it; the temporary key is then encrypted with each group member's public key. When a user fetches the post from the server, he receives the encrypted content along with the temporary key encrypted with his public key. He can decrypt the temporary key with his private key and use it to decrypt the post's content. Let's look at the code.
python
from typing import Optional

from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.base import Post, User


class PostBase(BaseModel):
    title: str
    content: str
    group_id: int


class PostInDBBase(PostBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


async def create_post(db: AsyncSession, post: PostBase, author: User) -> Post:
    db_post = Post(
        title=post.title,
        content=post.content,
        group_id=post.group_id,
        author=author,
    )
    db.add(db_post)
    await db.commit()
    await db.refresh(db_post)
    return db_post
Add the route for creating posts:
python
@router.post("/posts/", response_model=PostInDBBase, status_code=201)
async def create_post(
post: PostBase,
db: DBSession,
current_user: CurrentUser,
):
plain_content = post.content
post.content = ""
post = await create_post(
db=db,
post=post,
author=current_user,
)
encrypt_post_content.delay(post_id=post.id, content=plain_content)
return post
This is the view that receives a post and saves it to the database. The most interesting part here is the encrypt_post_content.delay() call: it is actually a Celery task that will be executed in a separate process. Here it is:
python
import os

from celery import Celery
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import sessionmaker

from app.core.crypto_tools import (
    asymmetric_encryption,
    generate_symmetric_key,
    symmetric_encryption,
)
from app.db.base import Post, PostKeys, User, UserGroup, UserKeys
from app.core.config import settings

celery = Celery("secureblogs")
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL")

sync_engine = create_engine(settings.SYNC_SQLALCHEMY_DATABASE_URL, echo=True)
SyncSessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=sync_engine,
)


@celery.task(name="encrypt_post_content")
def encrypt_post_content(post_id: int, content: str):
    # generate temp key and encrypt content
    key = generate_symmetric_key()
    encrypted_content = symmetric_encryption(content, key)
    with SyncSessionLocal() as db:
        # update post instance
        post_statement = (
            update(Post)
            .returning(Post.group_id)
            .where(Post.id == post_id)
            .values(content=encrypted_content)
        )
        post = db.execute(post_statement).fetchone()
        # fetch user's public keys from DB
        users_subquery = (
            select(User.id)
            .where(User.groups.any(UserGroup.id.in_([post.group_id])))
            .subquery()
        )
        statement = select(UserKeys).where(
            (UserKeys.user_id.in_(users_subquery))
            & (UserKeys.is_revoked == False)
        )
        public_keys = db.execute(statement).scalars().all()
        db_post_keys = []
        for public_key in public_keys:
            # Save generated keys in DB
            public_pem_data = public_key.public_key
            public_key_object = load_pem_public_key(public_pem_data.encode())
            encrypted_key = asymmetric_encryption(key, public_key_object)
            db_post_keys.append(
                PostKeys(
                    post_id=post_id,
                    public_key_id=public_key.id,
                    encrypted_key=encrypted_key,
                )
            )
        db.bulk_save_objects(db_post_keys)
        db.commit()
A Celery task is just a synchronous Python function, so inside it we use a synchronous database session to run our queries.
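The app.core.crypto_tools module is never shown in this article, so here is a minimal sketch of what the three helpers used above could look like, assuming Fernet for the symmetric part and RSA-OAEP for the asymmetric part; the function names match the imports above, but the implementation itself is an assumption, not the original code:
python
# hypothetical sketch of app/core/crypto_tools.py
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa


def generate_symmetric_key() -> bytes:
    """Generate a random symmetric (Fernet) key."""
    return Fernet.generate_key()


def symmetric_encryption(content: str, key: bytes) -> str:
    """Encrypt post content with the temporary symmetric key."""
    return Fernet(key).encrypt(content.encode()).decode()


def asymmetric_encryption(key: bytes, public_key: rsa.RSAPublicKey) -> bytes:
    """Encrypt the temporary key with a user's RSA public key (OAEP padding)."""
    return public_key.encrypt(
        key,
        padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        ),
    )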
Websockets
Our API now allows creating encrypted posts. But wait, that only works for existing users. What if a new user joins a group and wants to read some posts? He can't, because he cannot decrypt the temporary key: he needs the person who created the post to send it to him. This is where WebSockets come in handy. When a user requests access to a post, we send a real-time notification to the post's author. The author receives the notification and decides whether to approve or reject the request. Let's implement it. First, add a new model. It holds information about the user requesting access, the post itself, and the user's public key:
python
class ReadPostRequest(Base):
    __tablename__ = "read_post_request"

    id = Column(Integer, primary_key=True, index=True)
    post_id = Column(
        Integer,
        ForeignKey("posts.id", ondelete='CASCADE'),
        nullable=False,
    )
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    public_key_id = Column(
        Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
    )
    post = relationship("Post")
    requester = relationship("User")
    public_key = relationship("UserKeys")
Now we need an API endpoint that allows creating new requests:
python
async def get_post(db: AsyncSession, post_id: int) -> Post:
    statement = select(Post).where(Post.id == post_id)
    result = await db.execute(statement)
    return result.scalars().first()


async def get_user_key(
    db: AsyncSession,
    user: User,
) -> UserKeys:
    statement = select(UserKeys).where(
        (UserKeys.user == user) & (UserKeys.is_revoked == False)
    )
    result = await db.execute(statement)
    return result.scalars().first()


async def add_read_post_request(
    db: AsyncSession, user: User, post_id: int
) -> ReadPostRequest:
    exists_statement = select(ReadPostRequest.id).where(
        (ReadPostRequest.user_id == user.id)
        & (ReadPostRequest.post_id == post_id)
    )
    result = await db.execute(exists_statement)
    if result.scalars().first():
        return None
    public_key_statement = select(UserKeys).where(
        (UserKeys.is_revoked == False) & (UserKeys.user_id == user.id)
    )
    result = await db.execute(public_key_statement)
    if not (public_key := result.scalars().first()):
        return None
    db_read_post_request = ReadPostRequest(
        user_id=user.id,
        post_id=post_id,
        public_key=public_key,
    )
    db.add(db_read_post_request)
    await db.commit()
    await db.refresh(db_read_post_request)
    return db_read_post_request
Add the request route:
python
@router.post("/posts/{post_id}/request_read/", status_code=204)
async def add_read_post_request(
post_id: int,
db: DBSession,
current_user: CurrentUser,
):
post = await crud_post.get_post(db, post_id)
if not post:
raise HTTPException(status_code=404)
user_key = await crud_user.get_user_key(db, current_user)
request = await crud_post.add_read_post_request(db, current_user, post_id)
if request:
await ws_manager.send_personal_message(
{
'request_id': request.id,
'post_id': post_id,
'requested_user_id': current_user.id,
'user_public_key': user_key.public_key,
},
post.user_id,
)
Apart from the last call, there is nothing new here. We check that the post actually exists in the database, fetch the user's public key, create the new request, and finally send a WebSocket notification. Now let's take a closer look at how that part is implemented.
python
from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        self.active_connections: dict[int, WebSocket] = {}

    async def connect(self, user_id: int, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: int):
        self.active_connections.pop(user_id)

    async def send_personal_message(self, message: dict, user_id: int):
        if websocket := self.active_connections.get(user_id):
            await websocket.send_json(message)


ws_manager = ConnectionManager()
This is our WebSocket manager. It keeps a dictionary that maps user IDs to the WebSocket connection associated with each user. send_personal_message is used whenever we want to send a personal notification.
Finally, let's look at how a new WebSocket connection is created.
python
from typing import Annotated

from fastapi import (
    APIRouter,
    Depends,
    Query,
    status,
    WebSocket,
    WebSocketDisconnect,
    WebSocketException,
)

from app.api.deps import DBSession
from app.api.websockets.managers import ws_manager
from app.crud.crud_user import get_user_by_token

router = APIRouter()


async def get_token(
    websocket: WebSocket,
    token: Annotated[str | None, Query()] = None,
):
    if token is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token
Add the WebSocket endpoint:
python
@router.websocket("/ws/post_request")
async def websocket_endpoint(
    websocket: WebSocket,
    db: DBSession,
    token: Annotated[str, Depends(get_token)],
):
    user = await get_user_by_token(db, token)
    if not user:
        # WebSocket connections are closed with a WebSocket close code,
        # not an HTTP status code
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    try:
        await ws_manager.connect(user.id, websocket)
        await ws_manager.send_personal_message(
            {"message": "connection accepted"},
            user.id,
        )
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        ws_manager.disconnect(user.id)
We created a new WebSocket endpoint at /ws/post_request. It checks the user's token and, if the token is valid, opens a new connection and sends the user a connection accepted confirmation message.
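To try the endpoint, a small client can connect with the token passed as a query parameter. This sketch uses the third-party websockets package; the host, port, and token value are assumptions for a local setup:
python
import asyncio

import websockets


async def listen(token: str):
    url = f"ws://localhost:8000/ws/post_request?token={token}"
    async with websockets.connect(url) as ws:
        # the first message is the "connection accepted" confirmation
        print(await ws.recv())
        # then wait for read-access notifications
        while True:
            print(await ws.recv())


asyncio.run(listen("your-token-here"))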
Summary
Using an app where users publish and read posts as the running example, we covered:
- building a FastAPI application;
- using async SQLAlchemy for database access in FastAPI;
- using Celery to offload work to an asynchronous task queue in FastAPI;
- using WebSockets for notifications in FastAPI.