自定义上下文管理器实战:数据库连接池、文件锁与超时控制

文章目录

一、从"会用"到"会造"

#09 讲透了 with 语句的字节码执行流程,#10 展示了 @contextmanager 如何把一个生成器变成上下文管理器。但在真实工程中,最考验功底的不是"会用 open()",而是设计出稳健的自定义上下文管理器------它需要处理并发访问、超时中断、异常回滚、资源复用等一系列棘手问题。

本篇聚焦三个高频工程场景,每个场景都包含一个朴素的初始实现,然后逐步揭示生产环境中的隐患并给出改进方案。


二、场景一:数据库连接池

2.1 为什么需要连接池

直接在每次请求中 open 一个新连接是最常见的入门写法,代价是:

  • 每次建连需要 TCP 握手 + 认证,大约消耗 10~100 ms
  • 并发量稍高时,数据库连接数很快触达上限(PostgreSQL 默认 100)
  • 连接资源持有者崩溃后,服务端连接不会立即释放

连接池的核心思路是复用:维护一组长连接,每次请求从池中取出一个,用完放回去,池满时排队等待。

2.2 基础版本

python 复制代码
import threading
import time
from typing import Optional
from contextlib import contextmanager
from queue import Queue, Empty


class DBConnection:
    """模拟一个数据库连接对象"""
    _id_counter = 0
    _lock = threading.Lock()

    def __init__(self, dsn: str):
        with DBConnection._lock:
            DBConnection._id_counter += 1
            self.id = DBConnection._id_counter
        self.dsn = dsn
        self._alive = True
        print(f"[连接 #{self.id}] 已建立")

    def execute(self, sql: str) -> list:
        if not self._alive:
            raise RuntimeError(f"连接 #{self.id} 已失效")
        time.sleep(0.01)  # 模拟查询耗时
        return [{"result": f"{sql} ok"}]

    def close(self):
        self._alive = False
        print(f"[连接 #{self.id}] 已关闭")

    def ping(self) -> bool:
        return self._alive


class ConnectionPool:
    """
    线程安全的数据库连接池
    - 支持最大连接数限制
    - 支持等待超时
    - 支持连接健康检查(ping)
    """

    def __init__(self, dsn: str, max_size: int = 5, timeout: float = 5.0):
        self.dsn = dsn
        self.max_size = max_size
        self.timeout = timeout
        self._pool: Queue = Queue(maxsize=max_size)
        self._created = 0
        self._lock = threading.Lock()

        # 预热:创建初始连接
        for _ in range(2):
            self._pool.put(self._create_connection())

    def _create_connection(self) -> DBConnection:
        with self._lock:
            self._created += 1
        return DBConnection(self.dsn)

    def _get(self) -> DBConnection:
        try:
            # 尝试从池中取一个已有连接
            conn = self._pool.get(block=False)
            if not conn.ping():
                print(f"[连接 #{conn.id}] 健康检查失败,重建连接")
                conn.close()
                conn = self._create_connection()
            return conn
        except Empty:
            pass

        # 池为空:如果还没达到上限,新建连接
        with self._lock:
            if self._created < self.max_size:
                return self._create_connection()

        # 池满且连接数已达上限,等待
        try:
            conn = self._pool.get(timeout=self.timeout)
            if not conn.ping():
                conn.close()
                conn = self._create_connection()
            return conn
        except Empty:
            raise TimeoutError(
                f"连接池已满({self.max_size} 个连接全部占用),"
                f"等待 {self.timeout}s 超时"
            )

    def _put(self, conn: DBConnection):
        if conn.ping():
            self._pool.put(conn)
        else:
            with self._lock:
                self._created -= 1
            print(f"[连接 #{conn.id}] 失效,从池中移除")

    @contextmanager
    def acquire(self):
        """从连接池借出一个连接,with 块结束后自动归还"""
        conn = self._get()
        try:
            yield conn
        except Exception:
            # 异常时连接可能处于脏状态,需要 ping 确认是否可用
            if conn.ping():
                self._put(conn)
            else:
                with self._lock:
                    self._created -= 1
            raise
        else:
            self._put(conn)

    def close_all(self):
        """关闭所有连接(应用退出时调用)"""
        while not self._pool.empty():
            try:
                conn = self._pool.get_nowait()
                conn.close()
            except Empty:
                break

使用方式:

python 复制代码
pool = ConnectionPool("postgresql://localhost/mydb", max_size=5)

# 单次查询
with pool.acquire() as conn:
    results = conn.execute("SELECT * FROM users WHERE id=1")
    print(results)

# 并发查询
def worker(pool, query):
    with pool.acquire() as conn:
        return conn.execute(query)

import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(worker, pool, f"SELECT {i}") for i in range(20)]
    for f in concurrent.futures.as_completed(futures):
        print(f.result())

pool.close_all()

2.3 增加事务嵌套支持(Savepoint 语义)

真实业务场景中,事务往往会被嵌套调用------外层开启事务,内层某个函数也想开启一个"局部事务"。直接嵌套 BEGIN 在大多数数据库中会报错,正确做法是使用 Savepoint

python 复制代码
from contextlib import contextmanager


class TransactionManager:
    """
    支持嵌套事务的上下文管理器
    - 第一层:BEGIN TRANSACTION
    - 嵌套层:SAVEPOINT sp_N
    """

    def __init__(self, conn: DBConnection):
        self.conn = conn
        self._depth = 0          # 嵌套深度计数器
        self._savepoints: list[str] = []

    @contextmanager
    def transaction(self):
        self._depth += 1
        if self._depth == 1:
            self.conn.execute("BEGIN TRANSACTION")
            print(f"[事务] 开启主事务")
        else:
            sp = f"sp_{self._depth}"
            self._savepoints.append(sp)
            self.conn.execute(f"SAVEPOINT {sp}")
            print(f"[事务] 设置 Savepoint: {sp}")

        try:
            yield self
            # 正常退出
            if self._depth == 1:
                self.conn.execute("COMMIT")
                print(f"[事务] 主事务提交")
            else:
                sp = self._savepoints.pop()
                self.conn.execute(f"RELEASE SAVEPOINT {sp}")
                print(f"[事务] 释放 Savepoint: {sp}")
        except Exception as e:
            # 异常时回滚
            if self._depth == 1:
                self.conn.execute("ROLLBACK")
                print(f"[事务] 主事务回滚: {e}")
            else:
                sp = self._savepoints.pop()
                self.conn.execute(f"ROLLBACK TO SAVEPOINT {sp}")
                print(f"[事务] 回滚到 Savepoint {sp}: {e}")
            raise
        finally:
            self._depth -= 1


# 用法
with pool.acquire() as conn:
    tm = TransactionManager(conn)
    with tm.transaction():
        conn.execute("INSERT INTO orders VALUES (...)")
        with tm.transaction():           # 嵌套事务
            conn.execute("UPDATE inventory ...")
            # 如果这里失败,只回滚内层 Savepoint,不影响外层事务

三、场景二:跨平台文件锁

3.1 为什么文件锁比线程锁更复杂

threading.Lock 只在进程内有效;multiprocessing.Lock 依赖共享内存,只适合同一台机器上的多进程。文件锁是真正跨进程、跨机器(NFS 场景)的互斥机制,但实现细节因操作系统而异:

机制 系统 API 特点
fcntl.flock Unix fcntl.flock(fd, fcntl.LOCK_EX) 进程级锁,子进程继承,不可重入
fcntl.lockf Unix fcntl.lockf(fd, fcntl.LOCK_EX) 记录锁,线程独立,可按字节范围锁定
LockFileEx Windows msvcrt.locking 只支持字节范围锁,语义与 Unix 差异大

3.2 跨平台文件锁实现

python 复制代码
import os
import sys
import time
from contextlib import contextmanager


def _lock_file_unix(fd, exclusive: bool = True, nonblocking: bool = False):
    import fcntl
    flags = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
    if nonblocking:
        flags |= fcntl.LOCK_NB
    fcntl.flock(fd, flags)


def _unlock_file_unix(fd):
    import fcntl
    fcntl.flock(fd, fcntl.LOCK_UN)


def _lock_file_windows(fd, exclusive: bool = True, nonblocking: bool = False):
    import msvcrt
    # Windows 的 locking 只支持排他锁
    if nonblocking:
        mode = msvcrt.LK_NBLCK
    else:
        mode = msvcrt.LK_LOCK
    msvcrt.locking(fd, mode, 1)


def _unlock_file_windows(fd):
    import msvcrt
    msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)


class FileLock:
    """
    跨平台文件锁
    - 支持共享锁(多读)/ 排他锁(单写)
    - 支持非阻塞模式(加锁失败立即抛出 BlockingIOError)
    - 支持超时模式(轮询直到超时)
    - 保证 __exit__ 中无论如何都会解锁
    """

    def __init__(self, lock_file: str, exclusive: bool = True, timeout: float = -1):
        self.lock_file = lock_file
        self.exclusive = exclusive
        self.timeout = timeout       # -1 表示永久阻塞
        self._fd = None

    def __enter__(self):
        self._fd = open(self.lock_file, "w")
        if self.timeout < 0:
            # 阻塞模式:直到获取锁
            self._lock(nonblocking=False)
        else:
            # 超时模式:轮询
            deadline = time.monotonic() + self.timeout
            while True:
                try:
                    self._lock(nonblocking=True)
                    break
                except (BlockingIOError, OSError):
                    if time.monotonic() >= deadline:
                        self._fd.close()
                        raise TimeoutError(
                            f"无法在 {self.timeout}s 内获取文件锁: {self.lock_file}"
                        )
                    time.sleep(0.05)  # 50ms 轮询间隔
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._unlock()
        finally:
            if self._fd:
                self._fd.close()
                self._fd = None
        return False

    def _lock(self, nonblocking: bool):
        if sys.platform == "win32":
            _lock_file_windows(self._fd.fileno(), self.exclusive, nonblocking)
        else:
            _lock_file_unix(self._fd.fileno(), self.exclusive, nonblocking)

    def _unlock(self):
        if sys.platform == "win32":
            _unlock_file_windows(self._fd.fileno())
        else:
            _unlock_file_unix(self._fd.fileno())


# 用法:保护配置文件的并发写入
with FileLock("/tmp/config.lock", exclusive=True, timeout=3.0):
    with open("/etc/myapp/config.json", "w") as f:
        import json
        json.dump({"key": "value"}, f)
# 锁在 with 退出后自动释放

3.3 文件锁的常见陷阱

陷阱一:NFS 上的文件锁不可靠

fcntl.flock 在 NFS 挂载目录上的行为取决于 NFS 版本和挂载选项(nolock 挂载参数会让所有锁调用静默忽略)。分布式场景下,建议改用 Redis SETNX 或 ZooKeeper 节点作为互斥锁,而非文件锁。

陷阱二:锁文件本身的原子创建

open("lockfile", "w") 不是原子操作------两个进程同时执行时,都可能成功创建并写入。正确做法是用 O_CREAT | O_EXCL 标志,这个操作在内核层面是原子的:

python 复制代码
import os

def try_create_lock(lock_path: str) -> bool:
    """尝试创建锁文件,成功返回 True,锁已存在返回 False"""
    try:
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        os.write(fd, str(os.getpid()).encode())
        os.close(fd)
        return True
    except FileExistsError:
        return False

四、场景三:精准超时控制

4.1 超时控制的三种实现路径

方案 原理 精度 局限性
signal.alarm(SIGALRM) 操作系统信号,秒级定时 仅 Unix,只能在主线程使用
threading.Timer + 事件 子线程定时设置 Event 毫秒 不能中断阻塞中的系统调用
concurrent.futures.TimeoutError 线程/进程池的 wait(timeout=) 毫秒 无法中断已提交的 Future

4.2 基于 threading.Event 的通用超时上下文管理器

python 复制代码
import threading
import time
from contextlib import contextmanager


class TimeoutExpired(Exception):
    """超时异常"""
    def __init__(self, seconds: float):
        super().__init__(f"操作超时({seconds}s)")
        self.seconds = seconds


class Timeout:
    """
    线程安全的超时上下文管理器
    - 在后台线程计时,到期后设置 cancel_event
    - with 块内可检查 cancel_event.is_set() 实现协作式中断
    - 对于不支持事件检查的阻塞调用,提供 strict 模式强制抛出异常
    """

    def __init__(self, seconds: float, strict: bool = False):
        self.seconds = seconds
        self.strict = strict
        self._cancel_event = threading.Event()
        self._timer: threading.Timer | None = None
        self._expired = False

    @property
    def cancel_event(self) -> threading.Event:
        """调用方可以在循环中检查这个事件来实现协作式中断"""
        return self._cancel_event

    def _on_timeout(self):
        self._expired = True
        self._cancel_event.set()

    def __enter__(self):
        self._timer = threading.Timer(self.seconds, self._on_timeout)
        self._timer.daemon = True
        self._timer.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._timer:
            self._timer.cancel()
            self._timer = None
        if self._expired and self.strict and exc_type is None:
            # strict 模式:超时后如果没有其他异常,主动抛出
            raise TimeoutExpired(self.seconds)
        return False

    def check(self):
        """主动检查点:超时则抛出异常"""
        if self._cancel_event.is_set():
            raise TimeoutExpired(self.seconds)


# 用法一:协作式中断(循环中主动检查)
def long_running_task(data: list, timeout_ctx: Timeout):
    results = []
    for item in data:
        timeout_ctx.check()       # 主动检查点:超时立即中断
        results.append(process(item))
    return results


with Timeout(5.0) as t:
    result = long_running_task(large_dataset, t)


# 用法二:用于可以拆分的 I/O 循环
def read_stream_with_timeout(stream, timeout: float = 3.0) -> list:
    lines = []
    with Timeout(timeout) as t:
        while True:
            t.check()
            line = stream.readline()
            if not line:
                break
            lines.append(line)
    return lines

4.3 SIGALRM 方案(Unix 精确中断)

在 Unix 系统上,signal.alarm 可以真正中断阻塞的系统调用(如 socket.recvtime.sleep),精度优于 threading.Timer

python 复制代码
import signal
import contextlib


@contextlib.contextmanager
def unix_timeout(seconds: int):
    """
    基于 SIGALRM 的秒级超时上下文管理器(仅 Unix 主线程可用)
    - 能中断 socket.recv、subprocess.wait 等阻塞调用
    - 不能用于子线程(信号只能在主线程处理)
    """

    class _Timeout(Exception):
        pass

    def _handler(signum, frame):
        raise _Timeout(f"操作超时({seconds}s)")

    old_handler = signal.signal(signal.SIGALRM, _handler)
    old_alarm = signal.alarm(seconds)
    try:
        yield
    except _Timeout:
        raise TimeoutExpired(seconds)
    finally:
        signal.alarm(old_alarm or 0)  # 恢复原有的 alarm(避免嵌套使用时清除外层 alarm)
        signal.signal(signal.SIGALRM, old_handler)


# 使用:中断卡住的网络请求
import socket

with unix_timeout(3):
    sock = socket.socket()
    sock.connect(("10.0.0.1", 8080))  # 如果 3 秒内没有响应,抛出 TimeoutExpired
    data = sock.recv(4096)

五、三大场景的架构全览

超时控制流程
正常完成
调用 t.check()
True
False
timer 触发
Timeout.enter()
启动 threading.Timer
yield self(执行代码)
exit: timer.cancel()
cancel_event.is_set()?
抛出 TimeoutExpired
_cancel_event.set()

_expired = True
strict 模式且无其他异常

→ 抛出 TimeoutExpired
文件锁流程
< 0(永久阻塞)
>= 0(超时模式)
超时
加锁成功
FileLock.enter()
open(lock_file, 'w')
timeout 参数?
flock LOCK_EX(阻塞直到成功)
循环 flock LOCK_NB

直到超时
抛出 TimeoutError
yield(执行受保护的代码)
exit: flock LOCK_UN

close(fd)
连接池架构

无,未满
无,已满
超时
正常退出
异常退出
请求线程

pool.acquire()
ConnectionPool._get()
池中有空闲连接?
取出连接,ping 健康检查
新建 DBConnection
Queue.get(timeout=T)

等待归还
抛出 TimeoutError
yield conn(执行业务逻辑)
pool._put(conn) 归还
ping 确认可用性

可用则归还,否则销毁


六、设计上下文管理器的七条原则

回顾三个场景,可以总结出一套可操作的设计原则:

# 原则 反例 正例
1 __exit__ 必须无条件执行清理 if not error: release() try...finally: release()
2 不要在 __exit__ 中引发新异常 raise ValueError("cleanup failed") 记录日志,返回 False
3 明确 __exit__ 是否压制异常 默认 return True 只对已知、可安全忽略的异常 return True
4 __enter__ 失败时不调用 __exit__ 不做任何处理 __enter__ 内用 try/except 清理已初始化的资源
5 支持 as 返回有用的对象 return None return conn/return self 让调用方能访问状态
6 多资源按逆序释放 手动按顺序 close() ExitStack 自动管理逆序
7 处理嵌套进入(可重入性) 重入时抛出异常或死锁 用深度计数器或 Savepoint 支持嵌套

总结

三个场景揭示了上下文管理器在工程实践中的三个层次:

  • 连接池 :上下文管理器作为资源分配器acquire() 封装了借出和归还的完整生命周期,包括健康检查和异常时的连接销毁
  • 文件锁 :上下文管理器作为互斥原语,处理跨平台差异、非阻塞模式和超时轮询
  • 超时控制 :上下文管理器作为执行约束,通过后台定时器或 SIGALRM 信号为任意代码块添加时间上限

这三种角色------资源分配器、互斥原语、执行约束------几乎覆盖了生产代码中上下文管理器的全部典型用途。


相关好文推荐

  • with 语句底层原理:深入理解 __enter__/__exit__ 的字节码执行机制,是读懂本篇连接池实现的前提
  • contextlib 工具箱@contextmanagerExitStacksuppress 等工具在连接池和文件锁场景中大量使用

如果觉得这篇文章有帮助,欢迎点赞 👍、收藏 ⭐、转发!

三个场景里哪一个踩过坑?欢迎在评论区交流。关注作者,持续更新中。

相关推荐
吃着火锅x唱着歌1 小时前
LeetCode 503.下一个更大元素II
算法·leetcode·职场和发展
小短腿的代码世界1 小时前
从KB到字节:Qt行情数据压缩与传输优化的全链路透视——LZ4、Snappy与自定义二进制协议的极限压榨
开发语言·qt
_深海凉_1 小时前
LeetCode热题100-将有序数组转换为二叉搜索树
数据结构·算法·leetcode
AI技术控1 小时前
Transformer 的 Encoder 和 Decoder 模块介绍:从结构原理到大模型应用实践
人工智能·python·深度学习·自然语言处理·transformer
晚风_END1 小时前
Linux|操作系统|最新版zfs编译后的适用于centos7的rpm安装包完全离线安装介绍
linux·运维·服务器·c++·python·缓存·github
KaMeidebaby1 小时前
卡梅德生物技术快报|单克隆抗体人源化 PEG 修饰质控方法体系构建与验证
服务器·前端·数据库·人工智能·算法·百度·新浪微博
wuxinyan1231 小时前
工业级大模型学习之路015:RAG零基础入门教程(第十一篇):系统重构与代码规范化
人工智能·python·学习·重构·rag
humors2211 小时前
检查网址连通性的python脚本
网络·python·网站·检测网址·查询网址·网址连通性·网址可访问性
灵机一物1 小时前
灵机一物AI原生电商小程序、PC端(已上线)-【技术深度解析】Bun 6 天 AI 重写 96 万行代码:从 Zig 迁移 Rust 全流程与行业影响
开发语言·人工智能·rust