注:本文并不包含项目全部代码,查看全部代码请直接跳转到 Github
鉴于 FastAPI
和 Flask
框架一样保留了足够多的扩展性,无法做到在企业级项目中开箱即用,本文主要讲述了如何对 FastAPI
框架进行二次封装。以下内容主要分为 8 个部分,分别是统一接口返回、全局异常处理、自定义上下文、用户鉴权、单元测试、多环境配置、数据迁移、CLI启动等。看完之后如果您有更好的建议,欢迎到下方留言,或者到 Github
发起 PR。
统一接口返回
一般来讲,返回给前端的接口数据都是统一格式,例如:
json
{
"success": true,
"msg": "",
"data": {}
}
我们定义两个方法,请求成功的结构如下:
python
def success_response(data=''):
new_body = {'success': True, 'msg': '', 'data': data}
return new_body
请求失败的结构如下:
python
def failed_response(error_type, error_message, error_data=None):
"""failed response
"""
new_body = {
'success': False,
'error_type': error_type,
'msg': error_message,
'data': ''
}
if error_data is not None:
new_body['data'] = error_data
return JSONResponse(new_body)
如果每次在接口返回时都去调用这些方法是一件非常麻烦的事,本着 DRY
的原则,开始我们初步的封装。 FastAPI
的 response_model
参数需要指定模型,所以先定义 APIResponse
类:
python
class APIResponse(GenericModel, Generic[T]):
success: bool
msg: str
data: T = None
然后自定义新的路由函数,用来封装我们的接口:
python
def route(
router: APIRouter,
path: str,
methods: List[str],
response_model=None,
**options
):
common_response_model = APIResponse[response_model]
def wrapper(func: Callable[..., Any]):
async def decorator(*args, **kwargs):
response = await func(*args, **kwargs)
if isinstance(response, Response):
# The response may have already been wrapped, this situation should be ignored.
return response
return success_response(response)
signature = inspect.signature(func)
decorator.__signature__ = signature
decorator.__name__ = func.__name__
decorator.__doc__ = func.__doc__
router.add_api_route(
path,
endpoint=decorator,
response_model=common_response_model,
methods=methods,
**options
)
return decorator
return wrapper
使用新的路由,一个统一的返回模型就做好了:
python
@route(router, '/', ['GET'], response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
persons = await sa_session.scalars(
select(Person).order_by(Person.created_at.desc())
)
return persons.all()
现在每次使用 @route
还需要编写 GET
等参数,再次本着 DRY
原则,拿出大名鼎鼎的 functools
库,我们创建新的方法:
python
get = functools.partial(route, methods=['GET'])
上面的代码最终被改写为:
python
@get(router, '/', response_model=List[PersonOut])
async def list_examples(sa_session: AsyncSession = Depends(get_db_session)):
persons = await sa_session.scalars(
select(Person).order_by(Person.created_at.desc())
)
return persons.all()
统一异常处理
FastAPI
已经提供了统一的异常处理,我们只需要简单的封装。
定义最顶层的异常类:
python
class APIException(Exception):
"""api 异常
"""
error_type = 'api_error'
error_message = 'A server error occurred.'
def __init__(self, error_type=None, error_message=None):
if error_type is not None:
self.error_type = error_type
if error_message is not None:
self.error_message = error_message
def __repr__(self):
return '<{} {}: {}>'.format(
self.__class__, self.error_type, self.error_message
)
创建统一异常处理类:
python
class GlobalExceptionHandler:
def __init__(self, app: FastAPI):
self.app = app
@staticmethod
async def handle_api_exception(request: Request, error: APIException):
return failed_response(
error_type=error.error_type, error_message=error.error_message
)
@staticmethod
async def handle_exception(request: Request, error: Exception):
logger.error(f'{request.url} {error}')
return failed_response(error_type='server_error', error_message='Server error')
def init(self):
self.app.add_exception_handler(
RequestValidationError, self.handle_request_validation_error
)
self.app.add_exception_handler(APIException, self.handle_api_exception)
self.app.add_exception_handler(Exception, self.handle_exception)
然后创建一个自定义异常:
python
class PersonNotFound(APIException):
error_type = 'person_not_found'
error_message = 'Person not found'
最后只需要在代码里面抛出异常:
python
@get(router, '/{first_name}', response_model=PersonOut)
async def get_person(first_name: str, sa_session: AsyncSession = Depends(get_db_session)):
person = await sa_session.get(Person, first_name)
if not person:
raise PersonNotFound
return person
现在已经统一处理了接口成功和失败的响应,这已经为我们省去了大量的重复编码工作,但是这还远远不够,我们还需要让这个项目更加适合日常开发。
自定义上下文
日常开发中,我们需要使用到各种三方库,例如 Redis
, Mysql
, Kafka
等,得益于 FastAPI
的 Depends
功能,使得我们引用这些中间件变得非常简单,但还是本着 DRY
的原则,我们把这些代码都封装起来。
我个人最喜欢 FastAPI
框架的一个点就是可以有效解决 Python
在 Pycharm
里面不提示类型的问题,所以我们先定义一个类:
python
class AppContext:
def __init__(self):
self.request: Request | None = None
self.response: Response | None = None
self.sa_session: AsyncSession | None = None
self.redis: Redis | None = None
创建上下文的获取方法:
python
async def get_app_ctx(
request: Request,
response: Response,
sa_session: AsyncSession = Depends(get_db_session),
redis: Redis = Depends(get_redis)
) -> AppContext:
context = AppContext()
context.request = request
context.response = response
context.sa_session = sa_session
context.redis = redis
return context
然后直接使用:
python
@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = Depends(get_app_ctx)):
persons = await context.sa_session.scalars(
select(Person).order_by(Person.created_at.desc())
)
return persons.all()
但是我还是觉得不太方便,所以进一步封装:
python
DependsOnContext: AppContext = cast(AppContext, Depends(get_app_ctx))
最终把上面的参数改写为:
python
@get(router, '/', response_model=List[PersonOut])
async def list_examples(context: AppContext = DependsOnContext):
persons = await context.sa_session.scalars(
select(Person).order_by(Person.created_at.desc())
)
return persons.all()
现在我们引入任何的三方库都可以在这个上下文里面去添加,后续使用起来就很方便啦。
获取当前用户
企业项目往往在网关中处理用户鉴权,传递到后端的已经是被解析出来的用户数据。但是在这里我们还是来模拟下基于 JWT
编码和解码的过程,获取当前用户信息。
登录接口如下:
python
@post(router, '/login', response_model=LoginOut)
async def login(user_in: UserIn, context: AppContext = DependsOnContext):
user = await context.sa_session.scalar(
select(User).where(and_(User.username == user_in.username))
)
if not user:
raise UserNotFound
if not await user.verify_password(user_in.password):
raise AccountOrPasswordWrong
token = await generate_token(UserOut.from_orm(user).dict())
await context.redis.set(
f'{config.SERVICE_NAME}:user:token:{user.id}', token,
config.EXPIRED_SECONDS
)
return {'token': token, 'user': user}
自定义 CurrentUser
模型:
python
class CurrentUser(BaseModel):
id: int
username: str
假设前端拿到 token
后,后面的每次请求都在 header
中携带 token
:
python
async def get_current_user(
token: str = Header('token'), redis: Redis = Depends(get_redis)
):
if not token:
raise ApiSignatureExpired
payload = await parse_token(token)
if 'id' in payload:
try:
current_user = CurrentUser.parse_obj(payload)
except ValidationError:
raise JwtTokenError
cache_token = await redis.get(
f'{config.SERVICE_NAME}:user:token:{current_user.id}'
)
if cache_token != token:
raise ApiSignatureExpired
return current_user
raise ApiSignatureExpired
自定义 Depends
:
python
DependsOnUser: CurrentUser = cast(CurrentUser, Depends(get_current_user))
然后就可以获取当前登录用户信息了:
python
@get(router, '/', response_model=UserOut)
async def get_user_detail(current_user: CurrentUser = DependsOnUser):
return current_user
单元测试
真实世界的项目开发中,为了保证代码的质量,自然少不了编写单元测试,在这里我对 pytest
框架就不再做过多的赘述。我在封装该模块的时候,考虑的最多的是保证单元测试的一致性。
所以在这里我覆盖了原本的数据库 session
,保证每次测试运行之后回滚数据:
python
@pytest.fixture
async def session(app: FastAPI) -> AsyncSession:
async with engine.begin() as connection:
async_session_local = async_sessionmaker(
bind=connection,
autoflush=False,
future=True,
autocommit=False,
expire_on_commit=False
)
async_session = async_session_local()
# Overwrite the current database so that every time the test is run, the transaction is rollback
app.dependency_overrides[get_db_session] = lambda: async_session
yield async_session
await async_session.close()
await connection.rollback()
测试用例的编写:
python
async def test_list_examples(client: AsyncClient):
response = await client.get('/v1/examples/')
assert response.status_code == 200
json_result = response.json()
assert len(json_result['data']) == 0
# 创建
response = await client.post(
'/v1/examples/', json={
'first_name': 'test',
'last_name': 'test'
}
)
assert response.status_code == 200
response = await client.post(
'/v1/examples/', json={
'first_name': 'test2',
'last_name': 'test2'
}
)
assert response.status_code == 200
response = await client.get('/v1/examples/')
assert response.status_code == 200
json_result = response.json()
assert len(json_result['data']) == 2
多环境配置
一般来讲我们需要区分开发、测试、预发布、正式等环境,每个环境的数据库、秘钥等配置都各不相同,在这里我通过环境变量来导入不同的配置文件。
python
@lru_cache(maxsize=1)
def get_config(config_name: str = None) -> BaseConfig:
"""
if config name is none, get active profile from env
"""
if not config_name:
config_name = get_active_env()
configs_module = importlib.import_module('configs')
config_class = getattr(configs_module, config_name.capitalize())
return config_class()
如果环境变量值为 Development
,则导入的就是 Development
类下的配置文件:
python
class Development(BaseConfig):
STAGE: str = 'dev'
class Docker(BaseConfig):
STAGE = 'docker'
DB_HOST = 'mysql'
REDIS_URL = 'redis://redis:6379/0'
class Testing(BaseConfig):
STAGE: str = 'test'
# logger config
LOGGING_LEVEL: str = 'DEBUG'
# db config
DB_DATABASE = 'example_test'
DB_ENABLE_ECHO = False
class Production(BaseConfig):
STAGE: str = 'prod'
DEBUG: bool = False
# logger config
LOGGING_LEVEL: str = 'INFO'
数据迁移
这里数据迁移是基于 alembic
实现的,它能够与 Sqlachelmy
很好的集成,只需要简单的编码:
python
def migrate(commit):
path = Path('models', 'migrations', 'versions')
if path.exists():
has_versions = any(
filter(lambda _dir: _dir.name.endswith('.py'), path.iterdir())
)
else:
path.mkdir()
has_versions = False
revision_args = ['revision', '--autogenerate', '-m']
if has_versions is False:
revision_args.append('"init db"')
else:
if commit:
revision_args.append(f'"{commit}"')
else:
revision_args.append('"update"')
alembic.config.main(argv=revision_args)
migrate_args = ['--raiseerr', 'upgrade', 'head']
alembic.config.main(argv=migrate_args)
CLI
看到这里,我们的项目不单单要运行 FastAPI
应用,还可能要运行测试、数据迁移等,这个时候单一的程序入口已经不能满足最基本的要求,所以我基于 click
框架对项目入口进一步封装。
python
@click.group()
@click.version_option(version='1.0.0')
def cli():
"""CLI management for FastAPI project
"""
添加项目启动方法:
python
@cli.command('start')
def start():
uvicorn.run(
app,
host=config.HOST,
port=config.PORT,
debug=config.DEBUG,
log_config=config.log_config()
)
添加测试运行方法,并支持运行单个测试:
python
@cli.command('test')
@click.argument('test_names', required=False, nargs=-1)
def test(test_names):
import pytest
args = config.PYTEST_ARGS
if test_names:
args.extend(test_names)
pytest.main(args)
添加数据迁移方法,并支持自定义 commit
:
python
@cli.command('migrate')
@click.argument('commit', required=False, nargs=-1)
def migrate(commit):
...
现在我们可以在命令行中执行以下命令:
shell
# 启动项目
python manage.py start
# 运行全部测试
python manage.py test
# 运行单个测试
python manage.py test unittests/test_example.py::test_list_examples2
# 数据迁移
python manage.py migrate
总结
感谢您阅读到这里,目前为止已经涵盖了日常开发最常见的一部分,但还不足以包含企业级项目开发中的所有情况,所以我热烈欢迎您可以来继续完善这个项目,以及在评论区和我积极讨论,谢谢!