Python Flask 上下文管理源码分析

Python Flask 上下文管理源码分析

前言

Flask 上下文管理可以说是 Flask 非常具有特色的设计,它总共可分为 2 个大的方向:

  • 应用上下文管理:通过 current_app 即可拿到当前 Flask 实例
  • 请求上下文管理:通过导入的 request 对象即可拿到当前的请求对象

特别是请求上下文管理我们会经常用到。

在两年前左右我曾经拜读过 Flask 源码,当时给我的感受是比较震撼的。

现如今 Flask 增添了对异步的支持,因此上下文管理做了很大程度的变更,它不再依赖于 threading-local 而是依赖于 contextvars 模块。

本着 「路漫漫其修远兮、吾将上下而求索」 的精神,我决定重新再读一遍 Flask 上下文管理器的源码,并在本文做个记录,方便感兴趣的小伙伴查阅相关资料。

基本介绍

WSGI

Flask 是一个标准的基于 WSGI 协议 werkzeug 模块的 WEB 应用框架。

首先我们看一下 werkzeug 的基本使用:

python 复制代码
from werkzeug.wrappers import Request, Response

# 创建一个应用程序函数,它将处理 HTTP 请求
def application(environ, start_response):
    request = Request(environ)
    response = Response('Hello, World!', content_type='text/plain')
    return response(environ, start_response)


if __name__ == '__main__':
    from werkzeug.serving import run_simple
    run_simple('localhost', 5000, application)

werkzeug 模块让构建 WEB 应用十分简单、它提供了很多常用的基本功能。

如:

  • environ : 一个 dict、用于存放请求相关的数据
  • start_response : 一个 function、用于发送响应头部

在上面的示例中,可以看到 werkzeug 服务是通过 run_simple() 函数进行启动的,先记住这一点,下面会进行详细介绍。

globals

进入 Flask 源码后,需要注意一段导入语句:

python 复制代码
from .globals import current_app as current_app
from .globals import g as g
from .globals import request as request
from .globals import session as session

这 4 个导入的资源是 Flask 实现上下文的核心:

python 复制代码
current_app: Flask = LocalProxy(  # type: ignore[assignment]
    _cv_app, "app", unbound_message=_no_app_msg
)
g: _AppCtxGlobals = LocalProxy(  # type: ignore[assignment]
    _cv_app, "g", unbound_message=_no_app_msg
)
request: Request = LocalProxy(  # type: ignore[assignment]
    _cv_request, "request", unbound_message=_no_req_msg
)
session: SessionMixin = LocalProxy(  # type: ignore[assignment]
    _cv_request, "session", unbound_message=_no_req_msg
)

除此之外还有几个需要留意的变量:

python 复制代码
_cv_app: ContextVar[AppContext] = ContextVar("flask.app_ctx")
_cv_request: ContextVar[RequestContext] = ContextVar("flask.request_ctx")

先牢记这些点,接下来我们开始正式阅读源码。

Flask 与 werkzeug

Flask 的启动

Flask 是通过 run() 方法启动的:

python 复制代码
import flask

app = flask.Flask(__name__)

@app.route("/index/", methods=["GET"])
def index():
    return "hello world"

if __name__ == "__main__":
    app.run(host="localhost", port=5000)

所以直接点进源码查看 Flask 类下的 run() 方法即可,可以看见在该方法中调用了 werkzeug 模块里的 run_simple() 函数:

python 复制代码
class Flask(Scaffold):
    def run(self, host, port, debug, load_dotenv, **options):
        # ... 省略代码
        from werkzeug.serving import run_simple
        try:
            run_simple(t.cast(str, host), port, self, **options)
        finally:
            # reset the first request information if the development server
            # reset normally.  This makes it possible to restart the server
            # without reloader and that stuff from an interactive shell.
            self._got_first_request = False

注意!这里 run_simple() 第 3 个参数传入的是 self、也就是 Flask 实例对象本身。

所以 run_simple() 运行起来会去调用 Flask.call() 方法。

python 复制代码
class Flask(Scaffold):
    def __call__(self, environ: dict, start_response: t.Callable) -> t.Any:
        return self.wsgi_app(environ, start_response)

Flask.call() 又调用了 Flask.wsgi_app() 方法。

Flask.wsgi_app()

Flask.wsgi_app() 是整个 Flask 代码的核心:

python 复制代码
class Flask(Scaffold):
    def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:

        # 封装一个上下文对象
        ctx = self.request_context(environ)
        error: BaseException | None = None
        try:
            try:
                # 进行上下文管理
                ctx.push()
                # 执行视图函数
                response = self.full_dispatch_request()
            except Exception as e:
                error = e
                response = self.handle_exception(e)
            except:  # noqa: B001
                error = sys.exc_info()[1]
                raise
            # 返回响应
            return response(environ, start_response)
        finally:
            if "werkzeug.debug.preserve_context" in environ:
                environ["werkzeug.debug.preserve_context"](_cv_app.get())
                environ["werkzeug.debug.preserve_context"](_cv_request.get())

            if error is not None and self.should_ignore_error(error):
                error = None
            # 清除上下文管理内容
            ctx.pop(error)

上下文的存入

RequestContext 的实例化

Flask.request_context() 方法的源码非常简单,用于返回一个请求上下文对象:

python 复制代码
class Flask(Scaffold)
    def request_context(self, environ: dict) -> RequestContext:
        return RequestContext(self, environ)

RequestContext 主要包含以下内容:

  • 一个 request 对象
  • 一个 session 对象

源码如下:

python 复制代码
class RequestContext:

    request_class = Request

    def __init__(
        self,
        app: Flask,
        environ: dict,
        request: Request | None = None,
        session: SessionMixin | None = None,
    ) -> None:
        # 这里是 Flask 类的实例对象
        self.app = app

        if request is None:
            # request_class 是一个类属性,指向了 Request 类
            # 这里是实例化了 Request 类,具体源码不做细看
            # 常见的 request.method、request.args、request.json 等都在这里封装
            request = app.request_class(environ)
            request.json_module = app.json

        self.request: Request = request
        self.url_adapter = None
        try:
            # 创建 url 和请求的适配器
            # 适配器的作用有:
            # - 路由匹配
            # - URL 生成
            # - URL 变量提取
            # 具体源码不做细看
            self.url_adapter = app.create_url_adapter(self.request)
        except HTTPException as e:
            self.request.routing_exception = e
        self.flashes: list[tuple[str, str]] | None = None

        # session 是 None,这里不做细看
        self.session: SessionMixin | None = session
        self._after_request_functions: list[ft.AfterRequestCallable] = []

        # 用于存储上下文 Token、以及应用上下文的列表
        self._cv_tokens: list[tuple[contextvars.Token, AppContext | None]] = []

RequestContext.push()

在 Flask.wsgi_app() 方法中、得到 RequestContext 的实例化对象 ctx 后随后会立即调用其下的 push() 方法:

python 复制代码
class RequestContext:

    def push(self) -> None:
        # _cv_app 是空的、所以这里得到的是 None
        app_ctx = _cv_app.get(None)

        if app_ctx is None or app_ctx.app is not self.app:
            # 创建一个应用上下文并调用其 push() 方法
            app_ctx = self.app.app_context()
            app_ctx.push()
        else:
            app_ctx = None

        # 将当前的请求上下文的 Token 和应用上下文绑定到 self._cv_tokens 中
        # 注意!这里将当前请求上下文设置在了 _cv_request 中
        # set() 方法会返回一个 Token、通过这个 Token 后续可以做 reset() 操作
        self._cv_tokens.append((_cv_request.set(self), app_ctx))

        # session 相关,暂时略过
        if self.session is None:
            session_interface = self.app.session_interface
            self.session = session_interface.open_session(self.app, self.request)

            if self.session is None:
                self.session = session_interface.make_null_session(self.app)

        # 匹配请求,暂时略过
        if self.url_adapter is not None:
            self.match_request()

AppContext 的实例化

在 RequestContext.push() 方法中,app_ctx 是通过 Flask.app_context() 方法得到的。

该方法也非常简单,用于得到一个应用上下文对象:

python 复制代码
class Flask(Scaffold):
    def app_context(self) -> AppContext:
        return AppContext(self)

AppContext 主要包含以下内容:

  • 一个 g 对象
  • 当前的 Flask 类的实例对象

源码如下:

python 复制代码
class AppContext:

    def __init__(self, app: Flask) -> None:
        self.app = app
        self.url_adapter = app.create_url_adapter(None)
        self.g: _AppCtxGlobals = app.app_ctx_globals_class()
        self._cv_tokens: list[contextvars.Token] = []

AppContext.push()

在 RequestContext.push() 方法中,app_ctx 实例化完成后会调用其 push() 方法。

python 复制代码
class AppContext:
    def push(self) -> None:
        # 当前应用上下文的 Token 将被存储到 self._cv_tokens 中
        # 注意!这里将当前应用上下文设置在了 _cv_app 中
        self._cv_tokens.append(_cv_app.set(self))
        # 发送 1 个信号
        appcontext_pushed.send(self.app, _async_wrapper=self.app.ensure_sync)

Flask.wsgi_app() 的回顾

到目前为止、我们得到了请求上下文对象、应用上下文对象。

并将其分别设置在了 _cv_request 全局变量和 _cv_app 全局变量中。

python 复制代码
class Flask(Scaffold):
    def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:

        # 封装一个上下文对象
        ctx = self.request_context(environ)
        error: BaseException | None = None
        try:
            try:
                # 进行上下文管理
                ctx.push()
                # 执行视图函数
                response = self.full_dispatch_request()
            except Exception as e:
                error = e
                response = self.handle_exception(e)
            except:  # noqa: B001
                error = sys.exc_info()[1]
                raise
            # 返回响应
            return response(environ, start_response)
        finally:
            if "werkzeug.debug.preserve_context" in environ:
                environ["werkzeug.debug.preserve_context"](_cv_app.get())
                environ["werkzeug.debug.preserve_context"](_cv_request.get())

            if error is not None and self.should_ignore_error(error):
                error = None
            # 清除上下文管理内容
            ctx.pop(error)

其实到这里的时候、ctx.push() 方法已经执行完毕了。下面执行视图函数、返回响应的具体实现就先不看了。

我们切换一个视角,看看 import 语句是如何拿到当前的 request 以及 app 对象的。

获取当前上下文

请求上下文管理器的初始化

当执行以下语句时,会发生什么事?

python 复制代码
from flask import request

点进源码,会发现我们来到了 globals.py 文件中:

python 复制代码
_no_req_msg = """\
Working outside of request context.

This typically means that you attempted to use functionality that needed
an active HTTP request. Consult the documentation on testing for
information about how to avoid this problem.\
"""

_cv_request: ContextVar[RequestContext] = ContextVar("flask.request_ctx")

request: Request = LocalProxy(  # type: ignore[assignment]
    _cv_request, "request", unbound_message=_no_req_msg
)

所以此时时间线这里要拉回之前 LocalProxy 实例化的时候,先看看它在 Flask 启动时会如何进行实例化:

python 复制代码
class LocalProxy(t.Generic[T]):
    __slots__ = ("__wrapped", "_get_current_object")
    _get_current_object: t.Callable[[], T]

    def __init__(
        self,
        local: ContextVar[T] | Local | LocalStack[T] | t.Callable[[], T],
        name: str | None = None,
        *,
        unbound_message: str | None = None,
    ) -> None:

        # local 是 _cv_request
        # name 是 request
        # unbound_message 是 _no_req_msg

        # 不执行
        if name is None:
            get_name = _identity
        else:
            get_name = attrgetter(name)

        # 不执行
        if unbound_message is None:
            unbound_message = "object is not bound"

        # 不执行
        if isinstance(local, Local):
            if name is None:
                raise TypeError("'name' is required when proxying a 'Local' object.")

            def _get_current_object() -> T:
                try:
                    return get_name(local)  # type: ignore[return-value]
                except AttributeError:
                    raise RuntimeError(unbound_message) from None

        # 不执行
        elif isinstance(local, LocalStack):

            def _get_current_object() -> T:
                obj = local.top

                if obj is None:
                    raise RuntimeError(unbound_message)

                return get_name(obj)

        # 执行
        elif isinstance(local, ContextVar):

            # 标记、这个闭包函数后续会有其他用途
            def _get_current_object() -> T:
                try:
                    obj = local.get()
                except LookupError:
                    raise RuntimeError(unbound_message) from None

                return get_name(obj)

        # 不执行
        elif callable(local):

            def _get_current_object() -> T:
                return get_name(local())

        else:
            raise TypeError(f"Don't know how to proxy '{type(local)}'.")

        # 注意、这里是调用的 object 的 __setattr__ 为实例对象 self 设置属性
        object.__setattr__(self, "_LocalProxy__wrapped", local)
        object.__setattr__(self, "_get_current_object", _get_current_object)

在请求管理器中获取当前的请求

总所周知、当我们在 view func 中使用 request.method 等诸如此类的 attribute 但该 attribute 不存在时会自动调用其类下的 getattr() 方法。

我们 import 的 request 对象,实际上是 LocalProxy 的实例对象,所以当使用 request 中的 attrgetter 时自然会去调用 LocalProxy 其下的 getattr() 方法。

故在这里我们要研究:通过 LocalProxy 是如何找到 request 对象的。

python 复制代码
class LocalProxy(t.Generic[T]):
    __getattr__ = _ProxyLookup(getattr)

可以看到、LocalProxy 的 getattr() 方法实际上是 _ProxyLookup 的实例对象,在这里要明白 2 件事情:

  • 当 Flask 启动时、会自动调用 _ProxyLookup 的 init() 方法完成对象初始化
  • 当使用 request 中的 attribute 时、经过 LocalProxy 的 getattr() 方法后会去自动调用 _ProxyLookup 类下的 call() 方法,因为实例对象加括号调用会自动寻找类下的 call() 方法

对于 LocalProxy 的 getattr() 方法来说,_ProxyLookup 所做的事情无非是添加了一个 bind_f 的属性。

python 复制代码
class _ProxyLookup:
    __slots__ = ("bind_f", "fallback", "is_attr", "class_value", "name")

    def __init__(
        self,
        f: t.Callable | None = None,
        fallback: t.Callable | None = None,
        class_value: t.Any | None = None,
        is_attr: bool = False,
    ) -> None:
        bind_f: t.Callable[[LocalProxy, t.Any], t.Callable] | None

        # 这里的 f 就是 getattr
        if hasattr(f, "__get__"):
            def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable:
                return f.__get__(obj, type(obj))

        # 会执行这里
        elif f is not None:
            def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable:
                # f = getattr
                return partial(f, obj)

        else:
            bind_f = None

        self.bind_f = bind_f
        self.fallback = fallback # None
        self.class_value = class_value # None
        self.is_attr = is_attr # None

第 2 件事、即调用 request 的 attribute 时 _ProxyLookup 中 call() 方法会做什么:

python 复制代码
class LocalProxy(t.Generic[T]):
    def __call__(self, instance: LocalProxy, *args: t.Any, **kwargs: t.Any) -> t.Any:
        # instance: LocalProxy 
        return self.__get__(instance, type(instance))(*args, **kwargs)


    def __get__(self, instance: LocalProxy, owner: type | None = None) -> t.Any:
        # 不执行
        if instance is None:
            if self.class_value is not None:
                return self.class_value

            return self

        try:
            # 这里实际上调用的是 _cv_request 的 get() 方法
            # 看上面代码注释中有一处 「标记、这个闭包函数后续会有其他用途」
            # 这里可以直接获取到当前上下文的 request 实例对象
            obj = instance._get_current_object()
        except RuntimeError:
            if self.fallback is None:
                raise

            fallback = self.fallback.__get__(instance, owner)

            if self.is_attr:
                # __class__ and __doc__ are attributes, not methods.
                # Call the fallback to get the value.
                return fallback()

            return fallback

        # 会执行这个 bind_f 方法,并进行 return
        if self.bind_f is not None:
            # bind_f 方法其实是将 LocalProxy 实例对象和 request 实例对象都传进去
            # 然后返回 request 实例对象的 __getattr__ 方法
            return self.bind_f(instance, obj)

        return getattr(obj, self.name)

这里的调用顺序可能有些绕,需要静心仔细研读。

经过上面的一系列操作,我们最终通过 request 这个 LocalProxy 实例对象成功拿到了 request 实例对象。

应用上下问管理器的初始化

同分析请求上下文管理器类似,当执行以下语句时,会发生什么事?

python 复制代码
from flask import current_app

基本过程和上面的过程非常相似,进入源码后会发现以下语句:

python 复制代码
_cv_app: ContextVar[AppContext] = ContextVar("flask.app_ctx")

current_app: Flask = LocalProxy(
    _cv_app, "app", unbound_message=_no_app_msg
)

还是一样的实例化过程,这里不再详细举例。

在应用管理器中获取当前的应用

整体是和请求管理器中获取当前请求的步骤相同。不再详细举例。

其他的导入语句与上下文管理器

Flask 中包括 g 对象、session 对象的 import 都会使用到上下文管理器。

整个获取过程也和上面介绍的类似,具体这里就不再进行分析了。

上下文的清理

再回到 Flask.wsgi_app()

Flask.wsgi_app() 中最后会对上下文进行清理:

python 复制代码
class Flask(Scaffold):
    def wsgi_app(self, environ: dict, start_response: t.Callable) -> t.Any:

        # 封装一个上下文对象
        ctx = self.request_context(environ)
        error: BaseException | None = None
        try:
            try:
                # 进行上下文管理
                ctx.push()
                # 执行视图函数
                response = self.full_dispatch_request()
            except Exception as e:
                error = e
                response = self.handle_exception(e)
            except:  # noqa: B001
                error = sys.exc_info()[1]
                raise
            # 返回响应
            return response(environ, start_response)
        finally:
            if "werkzeug.debug.preserve_context" in environ:
                environ["werkzeug.debug.preserve_context"](_cv_app.get())
                environ["werkzeug.debug.preserve_context"](_cv_request.get())

            if error is not None and self.should_ignore_error(error):
                error = None

            # 清除上下文管理内容
            ctx.pop(error)

RequestContext.pop()

清除请求上下文会通过 pop() 方法进行:

python 复制代码
class RequestContext:
    def pop(self, exc: BaseException | None = _sentinel) -> None:  # type: ignore

        # 一个布尔值、判断是否要清理请求
        clear_request = len(self._cv_tokens) == 1

        try:
            if clear_request:
                if exc is _sentinel:
                    exc = sys.exc_info()[1]
                self.app.do_teardown_request(exc)

                request_close = getattr(self.request, "close", None)
                if request_close is not None:
                    request_close()
        finally:
            # 获取请求上下文
            ctx = _cv_request.get()
            # 弹出请求上下文的 token 和应用上下文
            token, app_ctx = self._cv_tokens.pop()
            # 清理请求上下文
            _cv_request.reset(token)

            if clear_request:
                ctx.request.environ["werkzeug.request"] = None

            if app_ctx is not None:

                # 这里是清理应用上下文
                app_ctx.pop(exc)

            if ctx is not self:
                raise AssertionError(
                    f"Popped wrong request context. ({ctx!r} instead of {self!r})"
                )

AppContext.pop()

在清理请求上下文的过程中,如果包含应用上下文、则应用上下文也会清理:

python 复制代码
class AppContext:

    def __init__(self, app: Flask) -> None:
        # ...
        self._cv_tokens: list[contextvars.Token] = []

    def pop(self, exc: BaseException | None = _sentinel) -> None:  # type: ignore
        try:

            # 一个布尔值、判断是否要清理请求
            if len(self._cv_tokens) == 1:
                if exc is _sentinel:
                    exc = sys.exc_info()[1]
                self.app.do_teardown_appcontext(exc)

        finally:
            # 获取应用上下文
            ctx = _cv_app.get()

            # 弹出应用上下文的 token 并进行 reset
            _cv_app.reset(self._cv_tokens.pop())

        if ctx is not self:
            raise AssertionError(
                f"Popped wrong app context. ({ctx!r} instead of {self!r})"
            )

        # 发送信号
        appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync)