参考代码
python
import time
from api.middleware import Middleware
from api.middleware.base import BaseHTTPMiddleware
from api.exceptions import AuthException
import jwt
has_logger = False
try:
from log import logger
has_logger = True
except:
pass
def roleapi(has_auth_func, jwtkey=None):
"""
判断用户角色是否有某个接口的权限
:param has_auth_func 考虑到可能使用异步的MySQL库,此函数必须是一个异步函数
接受Token中解析出来的role和request中的method和path作为参数,用于判断用户对该路径是否有访问权限
"""
return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)
class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, has_auth_func, jwtkey, app):
super().__init__(app)
self.has_auth_func = has_auth_func
self.jwtkey = jwtkey
async def dispatch(self, request, call_next):
# 判断是否需要跳过校验
if "/login" in str(request.url) or "/register" in str(request.url):
if has_logger:
logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
return await call_next(request)
if has_logger:
msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
logger.debug(msg)
# 获取token
token = request.headers.get("Authorization")
if token is None:
if has_logger:
logger.error("权限不足:没有携带Token", token=token)
raise AuthException("权限不足:没有携带Token")
# 解析token
data = None
try:
if isinstance(self.jwtkey, str):
data = jwt.parse_token(token, key=self.jwtkey)
else:
data = jwt.parse_token(token)
except Exception as e:
if has_logger:
logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
raise AuthException("无效的Token")
# token的结果必须是字典类型
if not isinstance(data, dict):
if has_logger:
logger.error("Token的解析结果不是字典类型", data=data)
raise AuthException("无效的Token")
# 必须要有过期时间
expired = data.get("expired")
if expired is None:
if has_logger:
logger.error("该Token没有设置过期时间", data=data, expired=expired)
raise AuthException("无效的Token")
# 校验过期时间的类型
if not (isinstance(expired, int) or isinstance(expired, float)):
if has_logger:
logger.error("该Token的过期时间不是数字类型", expired=expired)
raise AuthException("无效的Token")
# 校验是否过期
now = time.time()
if now > expired:
if has_logger:
logger.warning("Token已过期", expired=expired)
raise AuthException("Token已过期")
# 必须要有用户名和用户ID
userid = data.get("id")
username = data.get("username")
userrole = data.get("role")
if not (userid or username or userrole):
if has_logger:
logger.error(
"Token中应该包含用户ID,用户名和用户角色",
userid=userid,
username=username,
userrole=userrole,
token=token,
data=data,
)
raise AuthException("无效的Token")
# 获取请求方法和请求路径
base_url = request.base_url
url = request.url
method = request.method
path = str(url).replace(str(base_url), "")
if has_logger:
logger.debug(
"开始查询用户是否具备接口级别的权限",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
has_auth = False
try:
has_auth = await self.has_auth_func(userrole, method, path)
except Exception as e:
if has_logger:
logger.error(
"查询用户接口级别权限失败",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
error=e,
)
raise AuthException("无效的Token")
if has_logger:
logger.debug(
"成功查询用户是否具备接口级别的权限",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
if not has_auth:
if has_logger:
logger.warning(
"权限不足",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
raise AuthException("权限不足")
# 权限充足,发送请求,获取响应
response = await call_next(request)
return response
调用栈分析
使用中间件
python
app1 = api.Api(
routes=[
api.resp.get("/", index),
api.resp.get("/1", index),
api.resp.post("/login", login),
],
middleware=[
# 默认是:zhangdapeng zhangdapng520
# 可以传入账号和密码进行覆盖
apimidauth.roleapi(has_auth_func)
]
)
apimidauth.roleapi(has_auth_func) 实例化了一个中间件对象。
has_auth_func 是一个函数
这个函数,会在中间件里面,被调用到,完整代码如下。
python
async def has_auth_func(role, method, path):
"""校验角色对method和path是否有访问权限"""
print(role, method, path)
if not str(path).startswith("/"):
path = f"/{path}"
# GET:/1
auth = str(method).upper() + ":" + path
# 判断是否有权限
role_auth_dict = auth_dict.get(role)
if not isinstance(role_auth_dict, dict):
return False
if not role_auth_dict.get(auth):
return False
return True
roleapi 方法
完整代码如下:
python
def roleapi(has_auth_func, jwtkey=None):
"""
判断用户角色是否有某个接口的权限
:param has_auth_func 考虑到可能使用异步的MySQL库,此函数必须是一个异步函数
接受Token中解析出来的role和request中的method和path作为参数,用于判断用户对该路径是否有访问权限
"""
return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)
这个其实就是一个普通的函数。
不过比较特殊的地方在于,这个函数的返回值是一个Middleware类的实例。这个类接收如下参数:
- RoleAPIAuthMiddleware:自定义的中间件
- has_auth_func:实际上是自定义中间件需要的一个参数
- jwtkey:实际上也是自定义中间件需要的一个参数
注意:has_auth_func 和 jwtkey 这两个参数,不是 Middleware 类必须的,而是我们自定义的 RoleAPIAuthMiddleware 需要的。
RoleAPIAuthMiddleware 的代码分析
完整代码:
python
class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, has_auth_func, jwtkey, app):
super().__init__(app)
self.has_auth_func = has_auth_func
self.jwtkey = jwtkey
async def dispatch(self, request, call_next):
# 判断是否需要跳过校验
if "/login" in str(request.url) or "/register" in str(request.url):
if has_logger:
logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
return await call_next(request)
if has_logger:
msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
logger.debug(msg)
# 获取token
token = request.headers.get("Authorization")
if token is None:
if has_logger:
logger.error("权限不足:没有携带Token", token=token)
raise AuthException("权限不足:没有携带Token")
# 解析token
data = None
try:
if isinstance(self.jwtkey, str):
data = jwt.parse_token(token, key=self.jwtkey)
else:
data = jwt.parse_token(token)
except Exception as e:
if has_logger:
logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
raise AuthException("无效的Token")
# token的结果必须是字典类型
if not isinstance(data, dict):
if has_logger:
logger.error("Token的解析结果不是字典类型", data=data)
raise AuthException("无效的Token")
# 必须要有过期时间
expired = data.get("expired")
if expired is None:
if has_logger:
logger.error("该Token没有设置过期时间", data=data, expired=expired)
raise AuthException("无效的Token")
# 校验过期时间的类型
if not (isinstance(expired, int) or isinstance(expired, float)):
if has_logger:
logger.error("该Token的过期时间不是数字类型", expired=expired)
raise AuthException("无效的Token")
# 校验是否过期
now = time.time()
if now > expired:
if has_logger:
logger.warning("Token已过期", expired=expired)
raise AuthException("Token已过期")
# 必须要有用户名和用户ID
userid = data.get("id")
username = data.get("username")
userrole = data.get("role")
if not (userid or username or userrole):
if has_logger:
logger.error(
"Token中应该包含用户ID,用户名和用户角色",
userid=userid,
username=username,
userrole=userrole,
token=token,
data=data,
)
raise AuthException("无效的Token")
# 获取请求方法和请求路径
base_url = request.base_url
url = request.url
method = request.method
path = str(url).replace(str(base_url), "")
if has_logger:
logger.debug(
"开始查询用户是否具备接口级别的权限",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
has_auth = False
try:
has_auth = await self.has_auth_func(userrole, method, path)
except Exception as e:
if has_logger:
logger.error(
"查询用户接口级别权限失败",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
error=e,
)
raise AuthException("无效的Token")
if has_logger:
logger.debug(
"成功查询用户是否具备接口级别的权限",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
if not has_auth:
if has_logger:
logger.warning(
"权限不足",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
raise AuthException("权限不足")
# 权限充足,发送请求,获取响应
response = await call_next(request)
return response
基本结构分析
python
class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
def __init__(self, has_auth_func, jwtkey, app):
super().__init__(app)
self.has_auth_func = has_auth_func
self.jwtkey = jwtkey
自定义的中间件类,需要继承:BaseHTTPMiddleware,这个类来自于 from api.middleware.base import BaseHTTPMiddleware 。
初始化方法:
python
def __init__(self, has_auth_func, jwtkey, app):
super().__init__(app)
self.has_auth_func = has_auth_func
self.jwtkey = jwtkey
在这个初始化方法中,我们定义了此中间件需要的参数。
我们在这里定义的是类的参数,但是实际上最后传递参数的方式是:
python
return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)
这里的 app 是没有传参的。
核心是 dispatch 方法
python
async def dispatch(self, request, call_next):
# 判断是否需要跳过校验
if "/login" in str(request.url) or "/register" in str(request.url):
if has_logger:
logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
return await call_next(request)
if has_logger:
msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
logger.debug(msg)
# 获取token
token = request.headers.get("Authorization")
if token is None:
if has_logger:
logger.error("权限不足:没有携带Token", token=token)
raise AuthException("权限不足:没有携带Token")
# 解析token
data = None
try:
if isinstance(self.jwtkey, str):
data = jwt.parse_token(token, key=self.jwtkey)
else:
data = jwt.parse_token(token)
except Exception as e:
if has_logger:
logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
raise AuthException("无效的Token")
# token的结果必须是字典类型
if not isinstance(data, dict):
if has_logger:
logger.error("Token的解析结果不是字典类型", data=data)
raise AuthException("无效的Token")
# 必须要有过期时间
expired = data.get("expired")
if expired is None:
if has_logger:
logger.error("该Token没有设置过期时间", data=data, expired=expired)
raise AuthException("无效的Token")
# 校验过期时间的类型
if not (isinstance(expired, int) or isinstance(expired, float)):
if has_logger:
logger.error("该Token的过期时间不是数字类型", expired=expired)
raise AuthException("无效的Token")
# 校验是否过期
now = time.time()
if now > expired:
if has_logger:
logger.warning("Token已过期", expired=expired)
raise AuthException("Token已过期")
# 必须要有用户名和用户ID
userid = data.get("id")
username = data.get("username")
userrole = data.get("role")
if not (userid or username or userrole):
if has_logger:
logger.error(
"Token中应该包含用户ID,用户名和用户角色",
userid=userid,
username=username,
userrole=userrole,
token=token,
data=data,
)
raise AuthException("无效的Token")
# 获取请求方法和请求路径
base_url = request.base_url
url = request.url
method = request.method
path = str(url).replace(str(base_url), "")
if has_logger:
logger.debug(
"开始查询用户是否具备接口级别的权限",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
has_auth = False
try:
has_auth = await self.has_auth_func(userrole, method, path)
except Exception as e:
if has_logger:
logger.error(
"查询用户接口级别权限失败",
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
error=e,
)
raise AuthException("无效的Token")
if has_logger:
logger.debug(
"成功查询用户是否具备接口级别的权限",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
if not has_auth:
if has_logger:
logger.warning(
"权限不足",
has_auth=has_auth,
method=method,
path=path,
userid=userid,
username=username,
userrole=userrole,
)
raise AuthException("权限不足")
# 权限充足,发送请求,获取响应
response = await call_next(request)
return response
中间件核心方法分析
参数是什么
python
async def dispatch(self, request, call_next):
首先,这个方法是一个异步方法。
第一个参数是请求对象,存储了客户端的所有请求信息。
第二个参数是调用下一个中间件的对象,如果成功了,则调用此方法得到一个response对象,返回response对象即可。
如果成功了返回什么?
如果成功了,则调用此 call_next 方法得到一个response对象,返回response对象即可。
python
response = await call_next(request)
return response
如果失败了,该返回什么?
示例代码如下:
python
if has_logger:
logger.error("Token的解析结果不是字典类型", data=data)
raise AuthException("无效的Token")
首先,我们是记录错误日志。
然后,抛出一个异常。
因为 zdppy_api 已经内部封装了全局异常处理,所以这个异常,最终会被全局异常错误处理器自动捕获,并返回给客户端一个比较通用且友好的错误信息。
支持哪些异常类
首先是 AuthException,这个来源于:
python
from api.exceptions import AuthException
通过查看源码,我们可以知道, zdppy_api 框架,目前内置了如下异常处理器:
python
default_exception_handlers = {
404: not_found,
500: server_error,
HTTPException: handle_http_exception,
AuthException: handle_auth_exception,
Exception: handle_exception,
}
最终总结
如果,我们要实现 db 请求上下文中间件:
- 请求开始时,自动建立连接
- 请求结束是,自动断开连接
那么,我们的实现思路如下:
- 1、实现一个 OrmRequestMiddleware 中间件类。这个类继承 api.middleware.base.BaseHTTPMiddleware,接收一个 db 作为参数。
- 2、实现一个 apimidorm.request(db) 方法,这个方法的返回值是 Middleware(OrmRequestMiddleware, db=db)
- 3、在 OrmRequestMiddleware 自定义中间件类中,重写 async def dispatch(self, request, call_next) 方法
- 4、方法体中实现具体的逻辑。请求开始之前,调用 db.connect(),调用 response = call_next(),之后调用 db.colse(),最后返回 response。
以上是一个具体的实现思路,仅供参考。