前言
周五正在摸鱼刷v站,突然看到一个帖子,里面有几个面试题,其中有一个觉得蛮有意思的,刚好正在学习FastAPI,就扩展下。
面试题:三个窗口同时卖100张票,卖完窗口关闭,分别统计三个窗口的卖票时间和数量,不能超卖。
这个题肯定是多线程,但是作为一个前端开发,其实还是非常懵逼的,一点思路都没有,靠着同事以前给讲解的后端知识,可能觉得要加锁,但是怎么加,毫无思路,问了下gpt,发现有好几个种解决方法,要根据业务量来设计不同的解决方案:小规模可以使用线程锁,中等规模需要用到数据库锁,规模较大的话,就需要用到消息队列或者redis原子操作。
接下来根据上面的解决方案挨个介绍下:
多线程锁
Python 是一个支持多线程的语言,但由于全局解释器锁(GIL)的存在,它的多线程有一些特殊性。Python虽然提供了多线程的API,但是Python标准解释器CPython中,同一时刻只能有一个线程在执行Python代码,但是只是限制CPU密集型任务,比如在一些大量的计算、图像、视频处理和深度学习计算,这种场景下,只能一个线程占着GIL,但是在密集IO操作,比如网络IO、文件IO、数据路IO就非常适合多线程操作了,因为在IO 阻塞期间,GIL会自动释放让其它线程可以继续执行
Python开启多线程也比较简单,有常见的两种方式: threading 模块和 ThreadPoolExecutor 线程池,接下来分别介绍下:
css
import threading as th
import time
def worker(name):
print(f"{name} start")
time.sleep(1)
print(f"{name} end")
t1 = th.Thread(target=worker, args=("Thread-1",))
t2 = th.Thread(target=worker, args=("Thread-2",))
t1.start()
t2.start()
t1.join()
t2.join()
print("All threads done")
#输出
Thread-1 start
Thread-2 start
Thread-2 end
Thread-1 end
All threads done
这个例子就是通过threading 创建了两个线程,还是比较繁琐的,需要手动维护线程数量,而且没有办法去复用线程
ThreadPoolExecutor 是官方为解决线程管理复杂,性能低的问题而提供的高级API。
改写下上面的例子:
python
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
def worker(name):
print(f"{name} start")
time.sleep(1)
print(f"{name} end")
return name
if __name__ == "__main__":
names = ["Thread-1", "Thread-2", "Thread-3"]
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [executor.submit(worker, n) for n in names]
for future in as_completed(futures):
_ = future.result()
print("All threads done")
#输出
Thread-1 start
Thread-2 start
Thread-3 start
Thread-2 end
Thread-1 end
Thread-3 end
All threads done
其中max_workers=3表明最多三个线程并行,executor.submit(worker, n)就是将任务提交到线程池的队列中,等待worker线程来执行。as_completed生成器按照任务实力完成顺序逐个返回已完成的futures对象,
_ = future.result()阻塞调用,获取已完成任务的结果,而-算是一种Python约定,用下划线表示忽略这个变量,只关心任务完成的结果,不存储结果。
前置知识已经了解的差不多了,回到面试题,三个窗口并发卖票一百张,卖完就停止,统计数量和耗时,严格不超卖。然后在并发中,多个线程就需要同时去检查与修改共享库存,将检查并减库存作为一个不可分割的原子操作,用互斥锁保护,确保同一时刻只有一个线程可以执行这段代码,
看下下面的代码:
python
import threading
import time
from typing import Dict, List
TOTAL_TICKETS = 100
SALE_DELAY_SEC = 0.01 # 每次卖票的模拟耗时
class TicketCounter:
def __init__(self, total: int):
self.total = total
self.sold = 0
self.lock = threading.Lock()
def try_sell_one(self) -> bool:
"""尝试卖一张票,保证不超卖。成功返回 True,失败返回 False。"""
with self.lock:
if self.sold < self.total:
self.sold += 1
return True
return False
def window_worker(name: str, counter: TicketCounter, stats: Dict[str, Dict[str, float]]):
start = time.perf_counter()
sold_count = 0
while True:
if not counter.try_sell_one():
break
sold_count += 1
time.sleep(SALE_DELAY_SEC)
end = time.perf_counter()
stats[name] = {
"sold": sold_count,
"seconds": round(end - start, 4),
}
def main():
counter = TicketCounter(TOTAL_TICKETS)
stats: Dict[str, Dict[str, float]] = {}
windows = [f"窗口-{i}" for i in range(1, 4)]
threads: List[threading.Thread] = []
for w in windows:
t = threading.Thread(target=window_worker, args=(w, counter, stats), name=w)
t.start()
threads.append(t)
for t in threads:
t.join()
# 汇总
total_sold = sum(s["sold"] for s in stats.values())
print(f"总票数: {TOTAL_TICKETS}, 实际售出: {total_sold}")
print("是否超卖:", "是" if total_sold > TOTAL_TICKETS else "否")
print("各窗口统计:")
for w in windows:
s = stats[w]
print(f"- {w}: 卖出 {s['sold']} 张, 用时 {s['seconds']} 秒")
if __name__ == "__main__":
main()
来逐行解释下上面的代码:
全局定义了一个常亮TOTAL_TICKETS来当做总票数,SALE_DELAY_SEC当做卖票的模拟时间,TicketCounter这个类就是共享的计数器,类中定义了三个变量:total是总票数量,sold已售数量,lock就是这个问题的核心互斥锁,类中的方法try_sell_one就是模拟卖票,也就是核心代码:with self.lock:,这个with实际上就是一个语法糖,会调用共享对象的互斥锁,执行完成后会自动将锁给释放掉。锁中的代码,就是判断当前已经售出的票和总票数。
window_worker方法就是模拟窗口,start = time.perf_counter()就是开始计时,sold_count就是当前窗口的卖票数量,通过while True:创建一个无限循环,直到if not counter.try_sell_one(): 而try_sell_one方法返回false的条件就是票卖完了,后续也就是一些累计窗口卖票的数量和时间。
在main方法中,首先就创建一个共享对象,然后开启第三个线程:
ini
t = threading.Thread(target=window_worker, args=(w, counter, stats), name=w)
线程中回去执行window_worker,也就是卖票的方法,t.start()启动线程,会在内部创建一个线程,将目标方法也就是window_worker传给他,并且异步调用。threads.append(t)将线程对象放到list中,后续遍历,调用t.jion()时,主线程在这里会阻塞,直到list中所有的线程任务全部结束,后续就是统计各个窗口的卖票数量、时间了,他们都在定义在stats: Dict[str, Dict[str, float]] = {}这个模型中。
mysql行锁
MySQL行锁就是MySQL InnoDB 引擎在执行 SQL 时,只锁住某几行记录,而不是整张表,通过更多的并发SQL同时执行,来提高吞吐量。
行锁主要有以下几个特点:
- 粒度小:只锁定需要操作的记录,其他记录可以被其他事物访问和修改
- 并发性高:多个事务可以同时操作一张表的不同记录,不会互相阻塞
- 需要索引: InnoDB 的行锁是基于索引实现的,如果条件没有走索引,可能会升级成表锁
- 避免脏写:保证同一时间只有一个实务可以修改记录,避免数据冲突
行锁主要有两种:
-
共享锁,允许事务读取,但不允许修改,多个实务可以同时共享锁
sqlSELECT ... LOCK IN SHARE MODE -
排他锁,不允许其他事务读或者写,只有加锁事务可以自己读写
sqlSELECT ... FOR UPDATE
比如下面这几行SQL
sql
BEGIN;
-- 事务1
SELECT * FROM users WHERE id = 1 FOR UPDATE; -- 对id=1的行加排他锁
-- 事务2(此时会被阻塞)
SELECT * FROM users WHERE id = 1 FOR UPDATE;
在事务1提交前,事务2无法修改或加锁同一行。
回到面试题,思路都大差不差,多线程锁是设计一个共享类来存储总票数和已售票数,类上再写一个方法模拟卖票,使用MySQL行锁,共享类这部分的功能就被MySQL代替了,看下代码:
python
import asyncio
import time
from typing import Dict, List
from sqlalchemy.ext.asyncio import create_async_engine, AsyncConnection
from sqlalchemy import text
# 你的数据库地址
DATABASE_URL = "mysql+aiomysql://root:my-mysqlxxxx"
TOTAL_TICKETS = 100 # 与数据库中的 total 一致即可,此处仅用于校验展示
SALE_DELAY_SEC = 0.01 # 每次卖票的模拟耗时
STOCK_ID = 1 # tickets_stock 表的主键 id
async def sell_one(conn: AsyncConnection, window_name: str) -> bool:
"""在一个事务中卖出一张票,使用 SELECT ... FOR UPDATE 防止超卖。"""
try:
async with conn.begin():
result = await conn.execute(
text("SELECT sold, total FROM tickets_stock WHERE id=:id FOR UPDATE"),
{"id": STOCK_ID},
)
row = result.fetchone()
if row is None:
return False
sold, total = row[0], row[1]
if sold < total:
await conn.execute(
text("UPDATE tickets_stock SET sold = sold + 1 WHERE id=:id"),
{"id": STOCK_ID},
)
await conn.execute(
text(
"""
INSERT INTO sale_logs(`window`, `qty`, `sold_after`)
VALUES(:window, :qty, :sold_after)
"""
),
{"window": window_name, "qty": 1, "sold_after": sold + 1},
)
# 事务上下文退出时自动提交
return True
# 不可售时事务回滚(上下文自动处理)
return False
except Exception as e:
print(f"[ERROR] sell_one 异常: {e}")
return False
async def window_worker(conn: AsyncConnection, name: str) -> Dict[str, float]:
start = time.perf_counter()
sold_count = 0
while True:
ok = await sell_one(conn, name)
if not ok:
# 未能售出,可能是售罄、无行、或SQL错误
break
sold_count += 1
await asyncio.sleep(SALE_DELAY_SEC)
seconds = round(time.perf_counter() - start, 4)
return {"window": name, "sold": sold_count, "seconds": seconds}
async def main():
engine = create_async_engine(DATABASE_URL, pool_size=5, max_overflow=10, pool_pre_ping=True)
try:
# 初始化:创建 sale_logs 表;确保 tickets_stock 存在 id=1 记录
async with engine.begin() as init_conn:
await init_conn.execute(text(
"""
CREATE TABLE IF NOT EXISTS sale_logs (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
`window` VARCHAR(20) NOT NULL,
`qty` INT NOT NULL,
`sold_after` INT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
))
# 确保库存行存在
result = await init_conn.execute(text("SELECT COUNT(*) FROM tickets_stock WHERE id=:id"), {"id": STOCK_ID})
count = result.scalar_one()
if count == 0:
await init_conn.execute(text("INSERT INTO tickets_stock (id, total, sold) VALUES (:id, :total, 0)"), {"id": STOCK_ID, "total": TOTAL_TICKETS})
# 为每个窗口分别获取连接,确保在事件循环关闭前正确关闭
async with engine.connect() as conn1:
async with engine.connect() as conn2:
async with engine.connect() as conn3:
windows = ["窗口-1", "窗口-2", "窗口-3"]
tasks = [
window_worker(conn1, windows[0]),
window_worker(conn2, windows[1]),
window_worker(conn3, windows[2]),
]
stats: List[Dict[str, float]] = await asyncio.gather(*tasks)
# 读取最终库存用于校验
result = await conn1.execute(
text("SELECT sold, total FROM tickets_stock WHERE id=:id"),
{"id": STOCK_ID},
)
row = result.fetchone()
sold, total = (row[0], row[1]) if row else (0, TOTAL_TICKETS)
if row is None:
print("[WARN] tickets_stock 找不到 id=1 的行,请确认已插入初始库存记录。")
# 输出统计
total_sold_by_windows = sum(int(s["sold"]) for s in stats)
print(f"数据库记录: 总票数={total}, 已售={sold}")
print(f"窗口统计: 实际售出={total_sold_by_windows}")
print("是否超卖:", "是" if sold > total else "否")
for s in stats:
print(f"- {s['window']}: 卖出 {int(s['sold'])} 张, 用时 {s['seconds']} 秒")
finally:
# 显式释放引擎与连接池,避免事件循环关闭后清理触发异常
await engine.dispose()
if __name__ == "__main__":
asyncio.run(main())
首先看下sell_one这个方法就是单词卖票的实务函数,这个方法里先查询了当前的库存并且给这一行加上排他锁:SELECT sold, total FROM tickets_stock WHERE id=:id FOR UPDATE,如果sold < total就表明库存还有,可以正常卖票,UPDATE tickets_stock SET sold = sold + 1 WHERE id=:id就是简单的将表中的sold加一,表明已经卖出了一张票,同时还有张表sale_logs来存储相关信息。
window_worker就是单次卖票函数,和上面的一样,在main方法中,有点不一样首先创建了日志表sale_logs,然后创建了三个连接:
csharp
async with engine.connect() as conn1:
async with engine.connect() as conn2:
async with engine.connect() as conn3:
每个并发"窗口"都要独立执行事务和等待锁,单个连接无法同时处理多个协程的数据库操作;为每个窗口分配一个独立连接,才能真正并发、正确加锁并避免连接被并发复用导致的错误或串行化。
SELECT ... FOR UPDATE 需要在单个事务中持有行级锁。如果三个窗口共用同一个连接,那么同一会话的事务会互相干扰或被串行化,达不到并发效果;独立连接意味着独立会话,每个窗口各自开事务、各自持锁。
到这里,我有一个误解,我以为是有几个线程,就需要开启几个mysql连接,其实并不是的, 在并发卖票这种场景里,每个"窗口"需要一个独立的事务与会话来持有 SELECT ... FOR UPDATE 的行锁;如果多个线程共享同一个连接(同一会话),数据库操作会被串行化或互相干扰,无法形成真实的锁竞争,还可能触发连接并发复用错误。这就是"为每个窗口分配一个独立连接"的原因。只要保证"同时活跃的并发事务数"各有独立连接即可,通常通过连接池按需借还来满足这一点。也就是说,独立连接是并发单位上的必需,但连接的分配是动态的,不需要静态一一绑定到线程。
redis锁
Redis 锁是利用 Redis 的原子操作(SETNX / SET with NX EX)实现的一种分布式锁,用来保证多个服务竞争同一资源时,同一时刻只有一个人能修改。
python
import asyncio
import time
from typing import Dict, List
from redis.asyncio import from_url, Redis
# Redis 连接地址
REDIS_URL = "redis://:xxx"
# 业务常量
TOTAL_TICKETS = 100
TICKET_KEY = "tickets:remaining" # 剩余票数的键
LOCK_KEY = "lock:tickets" # 分布式锁键
SALE_DELAY_SEC = 0.01 # 每次卖票的模拟耗时
def get_redis() -> Redis:
return from_url(REDIS_URL, decode_responses=True)
async def init_stock(r: Redis) -> None:
# 如果不存在则初始化为 TOTAL_TICKETS;存在则保留当前值(便于多次运行观察)
exists = await r.exists(TICKET_KEY)
if not exists:
await r.set(TICKET_KEY, TOTAL_TICKETS)
async def sell_one_with_lock(r: Redis, window_name: str) -> bool:
"""使用 Redis 分布式锁保护"检查+扣减"关键区,成功卖出返回 True,售罄或失败返回 False。"""
lock = r.lock(LOCK_KEY, timeout=5, blocking_timeout=1) # 超时时间与获取等待时间可调
acquired = await lock.acquire(blocking=True)
if not acquired:
# 未拿到锁,视为本次卖票失败(可重试)
return False
try:
# 关键区:读取剩余、判定、扣减
remaining_str = await r.get(TICKET_KEY)
remaining = int(remaining_str) if remaining_str is not None else 0
if remaining <= 0:
return False
# 扣减一张(原子自减命令)
await r.decr(TICKET_KEY)
return True
finally:
try:
await lock.release()
except Exception:
# 若锁已过期或其他异常,忽略释放错误
pass
async def window_worker(name: str) -> Dict[str, float]:
r = get_redis()
start = time.perf_counter()
sold_count = 0
try:
while True:
ok = await sell_one_with_lock(r, name)
if not ok:
# 售罄或未拿到锁(可以继续尝试),策略:若剩余为0则退出;否则短暂休眠后重试
remaining_str = await r.get(TICKET_KEY)
remaining = int(remaining_str) if remaining_str is not None else 0
if remaining <= 0:
break
await asyncio.sleep(0.002)
continue
sold_count += 1
# 模拟售票耗时(不占用锁,避免长时间持锁)
await asyncio.sleep(SALE_DELAY_SEC)
finally:
await r.aclose()
seconds = round(time.perf_counter() - start, 4)
return {"window": name, "sold": sold_count, "seconds": seconds}
async def main():
r = get_redis()
try:
await init_stock(r)
remaining_str = await r.get(TICKET_KEY)
remaining = int(remaining_str) if remaining_str is not None else 0
print(f"初始剩余票数: {remaining}")
finally:
await r.aclose()
# 三个窗口并发卖票
windows = ["窗口-1", "窗口-2", "窗口-3"]
stats: List[Dict[str, float]] = await asyncio.gather(
*[window_worker(w) for w in windows]
)
# 汇总
r2 = get_redis()
try:
remaining_str = await r2.get(TICKET_KEY)
remaining = int(remaining_str) if remaining_str is not None else 0
finally:
await r2.aclose()
total_sold_by_windows = sum(int(s["sold"]) for s in stats)
print(f"数据库(Redis)记录: 总票数={TOTAL_TICKETS}, 剩余={remaining}, 已售={TOTAL_TICKETS - remaining}")
print(f"窗口统计: 实际售出={total_sold_by_windows}")
print("是否超卖:", "是" if (TOTAL_TICKETS - remaining) > TOTAL_TICKETS else "否")
for s in stats:
print(f"- {s['window']}: 卖出 {int(s['sold'])} 张, 用时 {s['seconds']} 秒")
if __name__ == "__main__":
asyncio.run(main())
分析下代码:get_redis返回一个异步的redis客户端,init_stock方法用来保证tickets:remaining存在,如果不存在就直接初始化; sell_one方法就是加锁的核心代码,lock = r.lock(LOCK_KEY, timeout=5, blocking_timeout=1)创建分布式锁对象, 并用 await lock.acquire(blocking=True) 阻塞尝试获取,然后就是常规的读取库存,判断是否超出,不超过就库存减少。window_worker方法就是并发窗口,功能和前面都大差不差。
消息队列
使用Redis List 充当消息队列,每个队列元素代表"一张票"的消费任务;三个窗口作为并发消费者,通过阻塞弹出 BRPOP 领取票任务,卖完队列为空即停止;以队列驱动消费的方式天然避免超卖:每张票只能被弹出一次
python
import asyncio
import time
from typing import Dict, List
from redis.asyncio import from_url, Redis
# Redis 连接地址(复用现有环境)
REDIS_URL = "redis://:xxx"
# 业务常量与键名
TOTAL_TICKETS = 100
QUEUE_KEY = "queue:tickets" # 队列键,存放每张票的任务
STATS_HASH = "stats:tickets" # 可选:统计哈希(本脚本内部统计即可,不强依赖Redis)
SALE_DELAY_SEC = 0.01 # 模拟售票耗时
def get_redis() -> Redis:
return from_url(REDIS_URL, decode_responses=True)
async def init_queue(r: Redis) -> None:
# 清空旧队列,重新放入 TOTAL_TICKETS 个任务(每个任务代表一张票)
await r.delete(QUEUE_KEY)
if TOTAL_TICKETS > 0:
# 用LPUSH批量插入,提高初始化速度
values = [str(i) for i in range(1, TOTAL_TICKETS + 1)]
# Redis 允许 LPUSH 多值:LPUSH key v1 v2 ...
await r.lpush(QUEUE_KEY, *values)
async def window_worker(name: str) -> Dict[str, float]:
r = get_redis()
start = time.perf_counter()
sold_count = 0
try:
while True:
# BRPOP 从队列右侧阻塞弹出1秒,若超时返回None,视为可能售罄
try:
item = await r.brpop(QUEUE_KEY, timeout=1)
except Exception:
item = None
if not item:
# 再次检查队列长度,0则退出
length = await r.llen(QUEUE_KEY)
if length == 0:
break
await asyncio.sleep(0.002)
continue
# 成功弹出一张票
sold_count += 1
# 模拟售票耗时(不占用队列操作)
await asyncio.sleep(SALE_DELAY_SEC)
finally:
await r.aclose()
seconds = round(time.perf_counter() - start, 4)
return {"window": name, "sold": sold_count, "seconds": seconds}
async def main():
# 初始化队列
r = get_redis()
try:
await init_queue(r)
length = await r.llen(QUEUE_KEY)
print(f"队列初始化完成: 票数={length}")
finally:
await r.aclose()
# 三个窗口并发消费队列
windows = ["窗口-1", "窗口-2", "窗口-3"]
stats: List[Dict[str, float]] = await asyncio.gather(
*[window_worker(w) for w in windows]
)
# 汇总与校验
total_sold_by_windows = sum(int(s["sold"]) for s in stats)
r2 = get_redis()
try:
remaining = await r2.llen(QUEUE_KEY)
finally:
await r2.aclose()
print(f"消息队列卖票: 总票数={TOTAL_TICKETS}, 剩余={remaining}, 已售={TOTAL_TICKETS - remaining}")
print(f"窗口统计: 实际售出={total_sold_by_windows}")
print("是否超卖:", "是" if (TOTAL_TICKETS - remaining) > TOTAL_TICKETS else "否")
for s in stats:
print(f"- {s['window']}: 卖出 {int(s['sold'])} 张, 用时 {s['seconds']} 秒")
if __name__ == "__main__":
asyncio.run(main())
就算有ai的解释,很多东西也都是第一次听说,也都是一知半解的,就那个多线程算是看懂了,后续的数据库锁,消息队列都不太懂,是时候需要强化下数据库的知识了。