zdppy_api如何实现带参数的中间件

参考代码

python 复制代码
import time

from api.middleware import Middleware
from api.middleware.base import BaseHTTPMiddleware
from api.exceptions import AuthException
import jwt

has_logger = False
try:
    from log import logger

    has_logger = True
except:
    pass


def roleapi(has_auth_func, jwtkey=None):
    """
    判断用户角色是否有某个接口的权限
    :param has_auth_func 考虑到可能使用异步的MySQL库,此函数必须是一个异步函数
    接受Token中解析出来的role和request中的method和path作为参数,用于判断用户对该路径是否有访问权限
    """
    return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)


class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, has_auth_func, jwtkey, app):
        super().__init__(app)
        self.has_auth_func = has_auth_func
        self.jwtkey = jwtkey

    async def dispatch(self, request, call_next):
        # 判断是否需要跳过校验
        if "/login" in str(request.url) or "/register" in str(request.url):
            if has_logger:
                logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
            return await call_next(request)

        if has_logger:
            msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
            logger.debug(msg)

        # 获取token
        token = request.headers.get("Authorization")
        if token is None:
            if has_logger:
                logger.error("权限不足:没有携带Token", token=token)
            raise AuthException("权限不足:没有携带Token")

        # 解析token
        data = None
        try:
            if isinstance(self.jwtkey, str):
                data = jwt.parse_token(token, key=self.jwtkey)
            else:
                data = jwt.parse_token(token)
        except Exception as e:
            if has_logger:
                logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
            raise AuthException("无效的Token")

        # token的结果必须是字典类型
        if not isinstance(data, dict):
            if has_logger:
                logger.error("Token的解析结果不是字典类型", data=data)
            raise AuthException("无效的Token")

        # 必须要有过期时间
        expired = data.get("expired")
        if expired is None:
            if has_logger:
                logger.error("该Token没有设置过期时间", data=data, expired=expired)
            raise AuthException("无效的Token")

        # 校验过期时间的类型
        if not (isinstance(expired, int) or isinstance(expired, float)):
            if has_logger:
                logger.error("该Token的过期时间不是数字类型", expired=expired)
            raise AuthException("无效的Token")

        # 校验是否过期
        now = time.time()
        if now > expired:
            if has_logger:
                logger.warning("Token已过期", expired=expired)
            raise AuthException("Token已过期")

        # 必须要有用户名和用户ID
        userid = data.get("id")
        username = data.get("username")
        userrole = data.get("role")
        if not (userid or username or userrole):
            if has_logger:
                logger.error(
                    "Token中应该包含用户ID,用户名和用户角色",
                    userid=userid,
                    username=username,
                    userrole=userrole,
                    token=token,
                    data=data,
                )
            raise AuthException("无效的Token")

        # 获取请求方法和请求路径
        base_url = request.base_url
        url = request.url
        method = request.method
        path = str(url).replace(str(base_url), "")
        if has_logger:
            logger.debug(
                "开始查询用户是否具备接口级别的权限",
                method=method,
                path=path,
                userid=userid,
                username=username,
                userrole=userrole,
            )
        has_auth = False
        try:
            has_auth = await self.has_auth_func(userrole, method, path)
        except Exception as e:
            if has_logger:
                logger.error(
                    "查询用户接口级别权限失败",
                    method=method,
                    path=path,
                    userid=userid,
                    username=username,
                    userrole=userrole,
                    error=e,
                )
            raise AuthException("无效的Token")
        if has_logger:
            logger.debug(
                "成功查询用户是否具备接口级别的权限",
                has_auth=has_auth,
                method=method,
                path=path,
                userid=userid,
                username=username,
                userrole=userrole,
            )
        if not has_auth:
            if has_logger:
                logger.warning(
                    "权限不足",
                    has_auth=has_auth,
                    method=method,
                    path=path,
                    userid=userid,
                    username=username,
                    userrole=userrole,
                )
            raise AuthException("权限不足")

        # 权限充足,发送请求,获取响应
        response = await call_next(request)
        return response

调用栈分析

使用中间件

python 复制代码
app1 = api.Api(
    routes=[
        api.resp.get("/", index),
        api.resp.get("/1", index),
        api.resp.post("/login", login),
    ],
    middleware=[
        # 默认是:zhangdapeng zhangdapng520
        # 可以传入账号和密码进行覆盖
        apimidauth.roleapi(has_auth_func)
    ]
)

apimidauth.roleapi(has_auth_func) 实例化了一个中间件对象。

has_auth_func 是一个函数

这个函数,会在中间件里面,被调用到,完整代码如下。

python 复制代码
async def has_auth_func(role, method, path):
    """校验角色对method和path是否有访问权限"""
    print(role, method, path)
    if not str(path).startswith("/"):
        path = f"/{path}"

    # GET:/1
    auth = str(method).upper() + ":" + path

    # 判断是否有权限
    role_auth_dict = auth_dict.get(role)
    if not isinstance(role_auth_dict, dict):
        return False
    if not role_auth_dict.get(auth):
        return False

    return True

roleapi 方法

完整代码如下:

python 复制代码
def roleapi(has_auth_func, jwtkey=None):
    """
    判断用户角色是否有某个接口的权限
    :param has_auth_func 考虑到可能使用异步的MySQL库,此函数必须是一个异步函数
    接受Token中解析出来的role和request中的method和path作为参数,用于判断用户对该路径是否有访问权限
    """
    return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)

这个其实就是一个普通的函数。

不过比较特殊的地方在于,这个函数的返回值是一个Middleware类的实例。这个类接收如下参数:

  • RoleAPIAuthMiddleware:自定义的中间件
  • has_auth_func:实际上是自定义中间件需要的一个参数
  • jwtkey:实际上也是自定义中间件需要的一个参数

注意:has_auth_func 和 jwtkey 这两个参数,不是 Middleware 类必须的,而是我们自定义的 RoleAPIAuthMiddleware 需要的。

RoleAPIAuthMiddleware 的代码分析

完整代码:

python 复制代码
class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, has_auth_func, jwtkey, app):
        super().__init__(app)
        self.has_auth_func = has_auth_func
        self.jwtkey = jwtkey

    async def dispatch(self, request, call_next):
        # 判断是否需要跳过校验
        if "/login" in str(request.url) or "/register" in str(request.url):
            if has_logger:
                logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
            return await call_next(request)

        if has_logger:
            msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
            logger.debug(msg)

        # 获取token
        token = request.headers.get("Authorization")
        if token is None:
            if has_logger:
                logger.error("权限不足:没有携带Token", token=token)
            raise AuthException("权限不足:没有携带Token")

        # 解析token
        data = None
        try:
            if isinstance(self.jwtkey, str):
                data = jwt.parse_token(token, key=self.jwtkey)
            else:
                data = jwt.parse_token(token)
        except Exception as e:
            if has_logger:
                logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
            raise AuthException("无效的Token")

        # token的结果必须是字典类型
        if not isinstance(data, dict):
            if has_logger:
                logger.error("Token的解析结果不是字典类型", data=data)
            raise AuthException("无效的Token")

        # 必须要有过期时间
        expired = data.get("expired")
        if expired is None:
            if has_logger:
                logger.error("该Token没有设置过期时间", data=data, expired=expired)
            raise AuthException("无效的Token")

        # 校验过期时间的类型
        if not (isinstance(expired, int) or isinstance(expired, float)):
            if has_logger:
                logger.error("该Token的过期时间不是数字类型", expired=expired)
            raise AuthException("无效的Token")

        # 校验是否过期
        now = time.time()
        if now > expired:
            if has_logger:
                logger.warning("Token已过期", expired=expired)
            raise AuthException("Token已过期")

        # 必须要有用户名和用户ID
        userid = data.get("id")
        username = data.get("username")
        userrole = data.get("role")
        if not (userid or username or userrole):
            if has_logger:
                logger.error(
                    "Token中应该包含用户ID,用户名和用户角色",
                    userid=userid,
                    username=username,
                    userrole=userrole,
                    token=token,
                    data=data,
                )
            raise AuthException("无效的Token")

        # 获取请求方法和请求路径
        base_url = request.base_url
        url = request.url
        method = request.method
        path = str(url).replace(str(base_url), "")
        if has_logger:
            logger.debug(
                "开始查询用户是否具备接口级别的权限",
                method=method,
                path=path,
                userid=userid,
                username=username,
                userrole=userrole,
            )
        has_auth = False
        try:
            has_auth = await self.has_auth_func(userrole, method, path)
        except Exception as e:
            if has_logger:
                logger.error(
                    "查询用户接口级别权限失败",
                    method=method,
                    path=path,
                    userid=userid,
                    username=username,
                    userrole=userrole,
                    error=e,
                )
            raise AuthException("无效的Token")
        if has_logger:
            logger.debug(
                "成功查询用户是否具备接口级别的权限",
                has_auth=has_auth,
                method=method,
                path=path,
                userid=userid,
                username=username,
                userrole=userrole,
            )
        if not has_auth:
            if has_logger:
                logger.warning(
                    "权限不足",
                    has_auth=has_auth,
                    method=method,
                    path=path,
                    userid=userid,
                    username=username,
                    userrole=userrole,
                )
            raise AuthException("权限不足")

        # 权限充足,发送请求,获取响应
        response = await call_next(request)
        return response

基本结构分析

python 复制代码
class RoleAPIAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, has_auth_func, jwtkey, app):
        super().__init__(app)
        self.has_auth_func = has_auth_func
        self.jwtkey = jwtkey

自定义的中间件类,需要继承:BaseHTTPMiddleware,这个类来自于 from api.middleware.base import BaseHTTPMiddleware 。

初始化方法:

python 复制代码
def __init__(self, has_auth_func, jwtkey, app):
    super().__init__(app)
    self.has_auth_func = has_auth_func
    self.jwtkey = jwtkey

在这个初始化方法中,我们定义了此中间件需要的参数。

我们在这里定义的是类的参数,但是实际上最后传递参数的方式是:

python 复制代码
return Middleware(RoleAPIAuthMiddleware, has_auth_func=has_auth_func, jwtkey=jwtkey)

这里的 app 是没有传参的。

核心是 dispatch 方法

python 复制代码
async def dispatch(self, request, call_next):
  # 判断是否需要跳过校验
  if "/login" in str(request.url) or "/register" in str(request.url):
      if has_logger:
          logger.debug("识别到正在访问不需要校验的路径", method=request.method, url=request.url)
      return await call_next(request)

  if has_logger:
      msg = f"开始进行基于角色和接口的高级别权限校验:method={request.method} url={request.url}"
      logger.debug(msg)

  # 获取token
  token = request.headers.get("Authorization")
  if token is None:
      if has_logger:
          logger.error("权限不足:没有携带Token", token=token)
      raise AuthException("权限不足:没有携带Token")

  # 解析token
  data = None
  try:
      if isinstance(self.jwtkey, str):
          data = jwt.parse_token(token, key=self.jwtkey)
      else:
          data = jwt.parse_token(token)
  except Exception as e:
      if has_logger:
          logger.error("获取用户Token失败", headers=request.headers, token=token, error=e)
      raise AuthException("无效的Token")

  # token的结果必须是字典类型
  if not isinstance(data, dict):
      if has_logger:
          logger.error("Token的解析结果不是字典类型", data=data)
      raise AuthException("无效的Token")

  # 必须要有过期时间
  expired = data.get("expired")
  if expired is None:
      if has_logger:
          logger.error("该Token没有设置过期时间", data=data, expired=expired)
      raise AuthException("无效的Token")

  # 校验过期时间的类型
  if not (isinstance(expired, int) or isinstance(expired, float)):
      if has_logger:
          logger.error("该Token的过期时间不是数字类型", expired=expired)
      raise AuthException("无效的Token")

  # 校验是否过期
  now = time.time()
  if now > expired:
      if has_logger:
          logger.warning("Token已过期", expired=expired)
      raise AuthException("Token已过期")

  # 必须要有用户名和用户ID
  userid = data.get("id")
  username = data.get("username")
  userrole = data.get("role")
  if not (userid or username or userrole):
      if has_logger:
          logger.error(
              "Token中应该包含用户ID,用户名和用户角色",
              userid=userid,
              username=username,
              userrole=userrole,
              token=token,
              data=data,
          )
      raise AuthException("无效的Token")

  # 获取请求方法和请求路径
  base_url = request.base_url
  url = request.url
  method = request.method
  path = str(url).replace(str(base_url), "")
  if has_logger:
      logger.debug(
          "开始查询用户是否具备接口级别的权限",
          method=method,
          path=path,
          userid=userid,
          username=username,
          userrole=userrole,
      )
  has_auth = False
  try:
      has_auth = await self.has_auth_func(userrole, method, path)
  except Exception as e:
      if has_logger:
          logger.error(
              "查询用户接口级别权限失败",
              method=method,
              path=path,
              userid=userid,
              username=username,
              userrole=userrole,
              error=e,
          )
      raise AuthException("无效的Token")
  if has_logger:
      logger.debug(
          "成功查询用户是否具备接口级别的权限",
          has_auth=has_auth,
          method=method,
          path=path,
          userid=userid,
          username=username,
          userrole=userrole,
      )
  if not has_auth:
      if has_logger:
          logger.warning(
              "权限不足",
              has_auth=has_auth,
              method=method,
              path=path,
              userid=userid,
              username=username,
              userrole=userrole,
          )
      raise AuthException("权限不足")

  # 权限充足,发送请求,获取响应
  response = await call_next(request)
  return response

中间件核心方法分析

参数是什么

python 复制代码
async def dispatch(self, request, call_next):

首先,这个方法是一个异步方法。

第一个参数是请求对象,存储了客户端的所有请求信息。

第二个参数是调用下一个中间件的对象,如果成功了,则调用此方法得到一个response对象,返回response对象即可。

如果成功了返回什么?

如果成功了,则调用此 call_next 方法得到一个response对象,返回response对象即可。

python 复制代码
response = await call_next(request)
return response

如果失败了,该返回什么?

示例代码如下:

python 复制代码
if has_logger:
    logger.error("Token的解析结果不是字典类型", data=data)
raise AuthException("无效的Token")

首先,我们是记录错误日志。

然后,抛出一个异常。

因为 zdppy_api 已经内部封装了全局异常处理,所以这个异常,最终会被全局异常错误处理器自动捕获,并返回给客户端一个比较通用且友好的错误信息。

支持哪些异常类

首先是 AuthException,这个来源于:

python 复制代码
from api.exceptions import AuthException

通过查看源码,我们可以知道, zdppy_api 框架,目前内置了如下异常处理器:

python 复制代码
default_exception_handlers = {
    404: not_found,
    500: server_error,
    HTTPException: handle_http_exception,
    AuthException: handle_auth_exception,
    Exception: handle_exception,
}

最终总结

如果,我们要实现 db 请求上下文中间件:

  • 请求开始时,自动建立连接
  • 请求结束是,自动断开连接

那么,我们的实现思路如下:

  • 1、实现一个 OrmRequestMiddleware 中间件类。这个类继承 api.middleware.base.BaseHTTPMiddleware,接收一个 db 作为参数。
  • 2、实现一个 apimidorm.request(db) 方法,这个方法的返回值是 Middleware(OrmRequestMiddleware, db=db)
  • 3、在 OrmRequestMiddleware 自定义中间件类中,重写 async def dispatch(self, request, call_next) 方法
  • 4、方法体中实现具体的逻辑。请求开始之前,调用 db.connect(),调用 response = call_next(),之后调用 db.colse(),最后返回 response。

以上是一个具体的实现思路,仅供参考。

相关推荐
小馒头学python2 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
zmd-zk8 分钟前
kafka+zookeeper的搭建
大数据·分布式·zookeeper·中间件·kafka
神奇夜光杯11 分钟前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
千天夜23 分钟前
使用UDP协议传输视频流!(分片、缓存)
python·网络协议·udp·视频流
测试界的酸菜鱼27 分钟前
Python 大数据展示屏实例
大数据·开发语言·python
羊小猪~~31 分钟前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
放飞自我的Coder1 小时前
【python ROUGE BLEU jiaba.cut NLP常用的指标计算】
python·自然语言处理·bleu·rouge·jieba分词
正义的彬彬侠1 小时前
【scikit-learn 1.2版本后】sklearn.datasets中load_boston报错 使用 fetch_openml 函数来加载波士顿房价
python·机器学习·sklearn
张小生1802 小时前
PyCharm中 argparse 库 的使用方法
python·pycharm
秃头佛爷2 小时前
Python使用PDF相关组件案例详解
python