一文讲透如何二次封装 FastAPI 框架,看完直呼 "Pythonic"

注:本文并不包含项目全部代码,查看全部代码请直接跳转到 Github

鉴于 FastAPIFlask 框架一样保留了足够多的扩展性,无法做到在企业级项目中开箱即用,本文主要讲述了如何对 FastAPI 框架进行二次封装。以下内容主要分为 8 个部分,分别是统一接口返回、全局异常处理、自定义上下文、用户鉴权、单元测试、多环境配置、数据迁移、CLI启动等。看完之后如果您有更好的建议,欢迎到下方留言,或者到 Github 发起 PR。

统一接口返回

一般来讲,返回给前端的接口数据都是统一格式,例如:

json 复制代码
{
  "success": true,
  "msg": "",
  "data": {}
}

我们定义两个方法,请求成功的结构如下:

python 复制代码
def success_response(data=''):
    new_body = {'success': True, 'msg': '', 'data': data}
    return new_body

请求失败的结构如下:

python 复制代码
def failed_response(error_type, error_message, error_data=None):
    """failed response
    """
    new_body = {
        'success': False,
        'error_type': error_type,
        'msg': error_message,
        'data': ''
    }
    if error_data is not None:
        new_body['data'] = error_data
    return JSONResponse(new_body)

如果每次在接口返回时都去调用这些方法是一件非常麻烦的事,本着 DRY 的原则,开始我们初步的封装。 FastAPIresponse_model 参数需要指定模型,所以先定义 APIResponse 类:

python 复制代码
class APIResponse(GenericModel, Generic[T]):
    success: bool
    msg: str
    data: T = None

然后自定义新的路由函数,用来封装我们的接口:

python 复制代码
def route(
    router: APIRouter,
    path: str,
    methods: List[str],
    response_model=None,
    **options
):
    common_response_model = APIResponse[response_model]
​
    def wrapper(func: Callable[..., Any]):
​
        async def decorator(*args, **kwargs):
            response = await func(*args, **kwargs)
            if isinstance(response, Response):
                # The response may have already been wrapped, this situation should be ignored.
                return response
​
            return success_response(response)
​
        signature = inspect.signature(func)
        decorator.__signature__ = signature
        decorator.__name__ = func.__name__
        decorator.__doc__ = func.__doc__
        router.add_api_route(
            path,
            endpoint=decorator,
            response_model=common_response_model,
            methods=methods,
            **options
        )
        return decorator
​
    return wrapper

使用新的路由,一个统一的返回模型就做好了:

python 复制代码
@route(router, '/', ['GET'], response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
    persons = await sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

现在每次使用 @route 还需要编写 GET 等参数,再次本着 DRY 原则,拿出大名鼎鼎的 functools 库,我们创建新的方法:

python 复制代码
get = functools.partial(route, methods=['GET'])

上面的代码最终被改写为:

python 复制代码
@get(router, '/', response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
    persons = await sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

统一异常处理

FastAPI 已经提供了统一的异常处理,我们只需要简单的封装。

定义最顶层的异常类:

python 复制代码
class APIException(Exception):
    """api 异常
    """
    error_type = 'api_error'
    error_message = 'A server error occurred.'
​
    def __init__(self, error_type=None, error_message=None):
        if error_type is not None:
            self.error_type = error_type
        if error_message is not None:
            self.error_message = error_message
​
    def __repr__(self):
        return '<{} {}: {}>'.format(
            self.__class__, self.error_type, self.error_message
        )

创建统一异常处理类:

python 复制代码
class GlobalExceptionHandler:
​
    def __init__(self, app: FastAPI):
        self.app = app
​
    @staticmethod
    async def handle_api_exception(request: Request, error: APIException):
        return failed_response(
            error_type=error.error_type, error_message=error.error_message
        )
​
    @staticmethod
    async def handle_exception(request: Request, error: Exception):
        logger.error(f'{request.url} {error}')
        return failed_response(error_type='server_error', error_message='Server error')
​
    def init(self):
        self.app.add_exception_handler(
            RequestValidationError, self.handle_request_validation_error
        )
        self.app.add_exception_handler(APIException, self.handle_api_exception)
        self.app.add_exception_handler(Exception, self.handle_exception)

然后创建一个自定义异常:

python 复制代码
class PersonNotFound(APIException):
    error_type = 'person_not_found'
    error_message = 'Person not found'

最后只需要在代码里面抛出异常:

python 复制代码
@get(router, '/{first_name}', response_model=PersonOut)
async def get_person(first_name: str, sa_session: AsyncSession = Depends(get_db_session)):
    person = await sa_session.get(Person, first_name)
    if not person:
        raise PersonNotFound
    return person

现在已经统一处理了接口成功和失败的响应,这已经为我们省去了大量的重复编码工作,但是这还远远不够,我们还需要让这个项目更加适合日常开发。

自定义上下文

日常开发中,我们需要使用到各种三方库,例如 Redis, Mysql , Kafka 等,得益于 FastAPIDepends 功能,使得我们引用这些中间件变得非常简单,但还是本着 DRY 的原则,我们把这些代码都封装起来。

我个人最喜欢 FastAPI 框架的一个点就是可以有效解决 PythonPycharm 里面不提示类型的问题,所以我们先定义一个类:

python 复制代码
class AppContext:
​
    def __init__(self):
        self.request: Request | None = None
        self.response: Response | None = None
        self.sa_session: AsyncSession | None = None
        self.redis: Redis | None = None

创建上下文的获取方法:

python 复制代码
async def get_app_ctx(
    request: Request,
    response: Response,
    sa_session: AsyncSession = Depends(get_db_session),
    redis: Redis = Depends(get_redis)
) -> AppContext:
​
    context = AppContext()
    context.request = request
    context.response = response
    context.sa_session = sa_session
    context.redis = redis
    return context

然后直接使用:

python 复制代码
@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = Depends(get_app_ctx)):
    persons = await context.sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

但是我还是觉得不太方便,所以进一步封装:

python 复制代码
DependsOnContext: AppContext = cast(AppContext, Depends(get_app_ctx))

最终把上面的参数改写为:

python 复制代码
@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = DependsOnContext):
    persons = await context.sa_session.scalars(
        select(Person).order_by(Person.created_at.desc())
    )
    return persons.all()

现在我们引入任何的三方库都可以在这个上下文里面去添加,后续使用起来就很方便啦。

获取当前用户

企业项目往往在网关中处理用户鉴权,传递到后端的已经是被解析出来的用户数据。但是在这里我们还是来模拟下基于 JWT 编码和解码的过程,获取当前用户信息。

登录接口如下:

python 复制代码
@post(router, '/login', response_model=LoginOut)
async def login(user_in: UserIn, context: AppContext = DependsOnContext):
    user = await context.sa_session.scalar(
        select(User).where(and_(User.username == user_in.username))
    )
    if not user:
        raise UserNotFound
​
    if not await user.verify_password(user_in.password):
        raise AccountOrPasswordWrong
​
    token = await generate_token(UserOut.from_orm(user).dict())
    await context.redis.set(
        f'{config.SERVICE_NAME}:user:token:{user.id}', token,
        config.EXPIRED_SECONDS
    )
​
    return {'token': token, 'user': user}

自定义 CurrentUser 模型:

python 复制代码
class CurrentUser(BaseModel):
    id: int
    username: str

假设前端拿到 token 后,后面的每次请求都在 header 中携带 token

python 复制代码
async def get_current_user(
    token: str = Header('token'), redis: Redis = Depends(get_redis)
):
    if not token:
        raise ApiSignatureExpired
​
    payload = await parse_token(token)
​
    if 'id' in payload:
        try:
            current_user = CurrentUser.parse_obj(payload)
        except ValidationError:
            raise JwtTokenError
​
        cache_token = await redis.get(
            f'{config.SERVICE_NAME}:user:token:{current_user.id}'
        )
        if cache_token != token:
            raise ApiSignatureExpired
​
        return current_user
​
    raise ApiSignatureExpired

自定义 Depends

python 复制代码
DependsOnUser: CurrentUser = cast(CurrentUser, Depends(get_current_user))

然后就可以获取当前登录用户信息了:

python 复制代码
@get(router, '/', response_model=UserOut)
async def get_user_detail(current_user: CurrentUser = DependsOnUser):
    return current_user

单元测试

真实世界的项目开发中,为了保证代码的质量,自然少不了编写单元测试,在这里我对 pytest 框架就不再做过多的赘述。我在封装该模块的时候,考虑的最多的是保证单元测试的一致性。

所以在这里我覆盖了原本的数据库 session,保证每次测试运行之后回滚数据:

python 复制代码
@pytest.fixture
async def session(app: FastAPI) -> AsyncSession:
    async with engine.begin() as connection:
        async_session_local = async_sessionmaker(
            bind=connection,
            autoflush=False,
            future=True,
            autocommit=False,
            expire_on_commit=False
        )
        async_session = async_session_local()
        # Overwrite the current database so that every time the test is run, the transaction is rollback
        app.dependency_overrides[get_db_session] = lambda: async_session
        yield async_session
        await async_session.close()
​
        await connection.rollback()

测试用例的编写:

python 复制代码
async def test_list_examples(client: AsyncClient):
    response = await client.get('/v1/examples/')
    assert response.status_code == 200
​
    json_result = response.json()
    assert len(json_result['data']) == 0
​
    # 创建
    response = await client.post(
        '/v1/examples/', json={
            'first_name': 'test',
            'last_name': 'test'
        }
    )
    assert response.status_code == 200
​
    response = await client.post(
        '/v1/examples/', json={
            'first_name': 'test2',
            'last_name': 'test2'
        }
    )
    assert response.status_code == 200
​
    response = await client.get('/v1/examples/')
    assert response.status_code == 200
​
    json_result = response.json()
    assert len(json_result['data']) == 2

多环境配置

一般来讲我们需要区分开发、测试、预发布、正式等环境,每个环境的数据库、秘钥等配置都各不相同,在这里我通过环境变量来导入不同的配置文件。

python 复制代码
@lru_cache(maxsize=1)
def get_config(config_name: str = None) -> BaseConfig:
    """
    if config name is none, get active profile from env
    """
    if not config_name:
        config_name = get_active_env()
​
    configs_module = importlib.import_module('configs')
    config_class = getattr(configs_module, config_name.capitalize())
    return config_class()

如果环境变量值为 Development,则导入的就是 Development 类下的配置文件:

python 复制代码
class Development(BaseConfig):
    STAGE: str = 'dev'
​
class Docker(BaseConfig):
    STAGE = 'docker'
​
    DB_HOST = 'mysql'
    REDIS_URL = 'redis://redis:6379/0'
​
class Testing(BaseConfig):
    STAGE: str = 'test'
​
    # logger config
    LOGGING_LEVEL: str = 'DEBUG'
​
    # db config
    DB_DATABASE = 'example_test'
    DB_ENABLE_ECHO = False
​
class Production(BaseConfig):
    STAGE: str = 'prod'
    DEBUG: bool = False
​
    # logger config
    LOGGING_LEVEL: str = 'INFO'

数据迁移

这里数据迁移是基于 alembic 实现的,它能够与 Sqlachelmy 很好的集成,只需要简单的编码:

python 复制代码
def migrate(commit):
    path = Path('models', 'migrations', 'versions')
    if path.exists():
        has_versions = any(
            filter(lambda _dir: _dir.name.endswith('.py'), path.iterdir())
        )
    else:
        path.mkdir()
        has_versions = False
​
    revision_args = ['revision', '--autogenerate', '-m']
    if has_versions is False:
        revision_args.append('"init db"')
    else:
        if commit:
            revision_args.append(f'"{commit}"')
        else:
            revision_args.append('"update"')
​
    alembic.config.main(argv=revision_args)
​
    migrate_args = ['--raiseerr', 'upgrade', 'head']
    alembic.config.main(argv=migrate_args)

CLI

看到这里,我们的项目不单单要运行 FastAPI 应用,还可能要运行测试、数据迁移等,这个时候单一的程序入口已经不能满足最基本的要求,所以我基于 click 框架对项目入口进一步封装。

python 复制代码
@click.group()
@click.version_option(version='1.0.0')
def cli():
    """CLI management for FastAPI project
    """

添加项目启动方法:

python 复制代码
@cli.command('start')
def start():
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
        debug=config.DEBUG,
        log_config=config.log_config()
    )

添加测试运行方法,并支持运行单个测试:

python 复制代码
@cli.command('test')
@click.argument('test_names', required=False, nargs=-1)
def test(test_names):
    import pytest
​
    args = config.PYTEST_ARGS
    if test_names:
        args.extend(test_names)
    pytest.main(args)

添加数据迁移方法,并支持自定义 commit

python 复制代码
@cli.command('migrate')
@click.argument('commit', required=False, nargs=-1)
def migrate(commit):
    ...

现在我们可以在命令行中执行以下命令:

shell 复制代码
# 启动项目
python manage.py start
​
# 运行全部测试
python manage.py test
​
# 运行单个测试
python manage.py test unittests/test_example.py::test_list_examples2
​
# 数据迁移
python manage.py migrate

总结

感谢您阅读到这里,目前为止已经涵盖了日常开发最常见的一部分,但还不足以包含企业级项目开发中的所有情况,所以我热烈欢迎您可以来继续完善这个项目,以及在评论区和我积极讨论,谢谢!

相关推荐
databook8 小时前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar9 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户8356290780519 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_9 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
数据智能老司机16 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机17 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机17 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机17 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i17 小时前
drf初步梳理
python·django
每日AI新事件17 小时前
python的异步函数
python