引言
在AI的开发中,WebSocket的开发尤其重要,比如与大模型的对接,一般都是使用WebSocket通讯,达到全双工与实时响应的效果。FastAPI 作为现代化的 Python Web 框架,提供了强大而简洁的 WebSocket 支持。本文将由浅入深,通过几个范例讲解,逐步掌握FastAPI的WebSocket的开发技巧。
1 什么是 WebSocket?
WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议。与传统的 HTTP 请求-响应模式不同,WebSocket 允许服务器和客户端之间建立持久连接,实现实时数据交换。
2 FastAPI的WebSocket基础用法。
python
from fastapi import FastAPI, WebSocket
app = FastAPI()
@app.websocket("/ws")
async def simple_websocket(websocket: WebSocket):
# 接受连接
await websocket.accept()
try:
while True:
# 接收消息
data = await websocket.receive_text()
# 发送回应
await websocket.send_text(f"Echo: {data}")
except Exception as e:
print(f"Connection closed: {e}")
这个基础的范例展示了 WebSocket 的基本工作流程:建立连接、持续通信、处理断开。
3 多消息处理用法
python
from fastapi import WebSocket, WebSocketDisconnect
@app.websocket("/ws/advanced")
async def advanced_websocket(websocket: WebSocket):
await websocket.accept()
try:
while True:
# 接收多种类型的消息
message = await websocket.receive()
if "text" in message:
await websocket.send_text(f"Text received: {message['text']}")
elif "bytes" in message:
await websocket.send_bytes(message["bytes"])
elif "json" in message:
await websocket.send_json({"echo": message["json"]})
except WebSocketDisconnect:
print("Client disconnected gracefully")
4 使用连接管理器实现多连接管理
4.1 连接管理器的实现
在实际开发中,需要管理多个客户端的连接,因此有必要通过一个连接管理器来管理这些连接。
python
from typing import Dict, List
import json
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.connection_groups: Dict[str, List[str]] = {}
async def connect(self, websocket: WebSocket, client_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):
self.active_connections.pop(client_id, None)
# 从所有群组中移除
for group in self.connection_groups.values():
if client_id in group:
group.remove(client_id)
async def send_personal_message(self, message: dict, client_id: str):
if client_id in self.active_connections:
await self.active_connections[client_id].send_json(message)
async def broadcast(self, message: dict):
disconnected_clients = []
for client_id, websocket in self.active_connections.items():
try:
await websocket.send_json(message)
except Exception:
disconnected_clients.append(client_id)
# 清理断开的连接
for client_id in disconnected_clients:
self.disconnect(client_id)
async def add_to_group(self, group_name: str, client_id: str):
if group_name not in self.connection_groups:
self.connection_groups[group_name] = []
if client_id not in self.connection_groups[group_name]:
self.connection_groups[group_name].append(client_id)
async def send_to_group(self, group_name: str, message: dict):
if group_name in self.connection_groups:
disconnected_clients = []
for client_id in self.connection_groups[group_name]:
if client_id in self.active_connections:
try:
await self.active_connections[client_id].send_json(message)
except Exception:
disconnected_clients.append(client_id)
# 清理断开的连接
for client_id in disconnected_clients:
self.disconnect(client_id)
manager = ConnectionManager()
4.2 连接管理器的使用
python
@app.websocket("/ws/chat/{client_id}")
async def chat_websocket(websocket: WebSocket, client_id: str):
await manager.connect(websocket, client_id)
try:
# 通知所有用户新用户加入
await manager.broadcast({
"type": "user_joined",
"client_id": client_id,
"message": f"User {client_id} joined the chat"
})
while True:
data = await websocket.receive_text()
message_data = json.loads(data)
if message_data.get("type") == "join_room":
# 加入聊天室
room_name = message_data["room_name"]
await manager.add_to_group(room_name, client_id)
await manager.send_to_group(room_name, {
"type": "room_join",
"client_id": client_id,
"room_name": room_name
})
else:
# 广播消息
await manager.broadcast({
"type": "message",
"client_id": client_id,
"content": message_data.get("content", ""),
"timestamp": message_data.get("timestamp")
})
except WebSocketDisconnect:
manager.disconnect(client_id)
await manager.broadcast({
"type": "user_left",
"client_id": client_id,
"message": f"User {client_id} left the chat"
})
5 多线线程安全与异步处理
在真实的生成环境中,仅仅处理多连接还不行,还需要考虑多线程与异步的处理。以下范例的连接管理器增加了安全锁与发送队列,其他线程发送消息先放到发送队列即可,由单独的消息分发任务进行发送。
5.1 线程安全与异步处理的连接管理器
python
import asyncio
import threading
from typing import Dict
from queue import Queue, Empty
class ThreadSafeWebSocketManager:
def __init__(self):
self.connections: Dict[str, WebSocket] = {}
self.message_queues: Dict[str, Queue] = {}
self.lock = threading.RLock()
self.dispatcher_tasks: Dict[str, asyncio.Task] = {}
async def add_connection(self, client_id: str, websocket: WebSocket):
"""添加连接并启动消息分发器"""
with self.lock:
self.connections[client_id] = websocket
self.message_queues[client_id] = Queue()
# 启动消息分发任务
task = asyncio.create_task(self._message_dispatcher(client_id))
self.dispatcher_tasks[client_id] = task
def remove_connection(self, client_id: str):
"""移除连接"""
with self.lock:
websocket = self.connections.pop(client_id, None)
queue = self.message_queues.pop(client_id, None)
task = self.dispatcher_tasks.pop(client_id, None)
# 取消任务
if task:
task.cancel()
# 关闭 WebSocket
if websocket:
try:
asyncio.create_task(websocket.close())
except:
pass
def send_message(self, client_id: str, message: dict):
"""从任何线程安全发送消息"""
with self.lock:
if client_id in self.message_queues:
self.message_queues[client_id].put(message)
def broadcast(self, message: dict, exclude_clients: set = None):
"""广播消息到所有连接"""
exclude_clients = exclude_clients or set()
with self.lock:
for client_id in self.connections:
if client_id not in exclude_clients:
self.send_message(client_id, message)
async def _message_dispatcher(self, client_id: str):
"""异步消息分发器"""
while client_id in self.connections:
try:
# 使用异步方式等待消息
message = await asyncio.get_event_loop().run_in_executor(
None,
self._get_message_safe,
client_id
)
if message and client_id in self.connections:
websocket = self.connections[client_id]
await websocket.send_json(message)
except Exception as e:
print(f"Error in dispatcher for {client_id}: {e}")
break
def _get_message_safe(self, client_id: str):
"""安全地从队列获取消息"""
try:
return self.message_queues[client_id].get(timeout=0.1)
except Empty:
return None
thread_safe_manager = ThreadSafeWebSocketManager()
此连接管理器中,每个WebSocket连接成功后都启动了一个消息分发任务,专门发送该连接的消息。
5.2 连接管理器的使用
python
import time
import threading
from datetime import datetime
def start_background_notifications(manager: ThreadSafeWebSocketManager):
"""启动后台通知任务"""
def notification_generator():
"""生成系统通知"""
count = 0
while True:
try:
notification = {
"type": "system_notification",
"message": f"System update #{count}",
"timestamp": datetime.now().isoformat(),
"priority": "info"
}
# 安全地广播通知
manager.broadcast(notification)
count += 1
time.sleep(30) # 每30秒发送一次
except Exception as e:
print(f"Notification generator error: {e}")
time.sleep(5) # 错误后等待5秒重试
# 启动后台线程
thread = threading.Thread(target=notification_generator, daemon=True)
thread.start()
@app.on_event("startup")
async def startup_event():
start_background_notifications(thread_safe_manager)
@app.websocket("/ws/thread-safe/{client_id}")
async def thread_safe_websocket(websocket: WebSocket, client_id: str):
await thread_safe_manager.add_connection(client_id, websocket)
try:
# 发送欢迎消息
thread_safe_manager.send_message(client_id, {
"type": "welcome",
"message": "Connected to thread-safe WebSocket",
"timestamp": datetime.now().isoformat()
})
# 处理客户端消息
while True:
data = await websocket.receive_text()
print(f"Received from {client_id}: {data}")
except Exception as e:
print(f"WebSocket error for {client_id}: {e}")
finally:
thread_safe_manager.remove_connection(client_id)
此范例中,启动一个定时广播的任务,每隔一段时间发送广播消息给每个WebSocket连接,最终也是调用send_message把消息放到发送队列里,由消息分发任务来发送消息。
6 双队列异步处理
当然我们也可以增加接收队列,由单独的接收任务来接收与处理消息,WebSocket的主线程仅仅是建立连接与定时发送心跳消息。
6.1 双队列异步连接管理器
python
import asyncio
import json
from typing import Dict
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
app = FastAPI()
class SimpleDualQueueManager:
"""简单的双队列 WebSocket 管理器"""
def __init__(self):
# 连接存储
self.connections: Dict[str, WebSocket] = {}
# 双队列系统:接收队列和发送队列
self.receive_queues: Dict[str, asyncio.Queue] = {}
self.send_queues: Dict[str, asyncio.Queue] = {}
# 任务存储
self.receive_tasks: Dict[str, asyncio.Task] = {}
self.process_tasks: Dict[str, asyncio.Task] = {}
self.send_tasks: Dict[str, asyncio.Task] = {}
async def add_connection(self, client_id: str, websocket: WebSocket):
"""添加连接并启动三个核心任务"""
self.connections[client_id] = websocket
self.receive_queues[client_id] = asyncio.Queue()
self.send_queues[client_id] = asyncio.Queue()
# 启动三个异步任务
self.receive_tasks[client_id] = asyncio.create_task(
self._receive_messages(client_id, websocket)
)
self.process_tasks[client_id] = asyncio.create_task(
self._process_messages(client_id)
)
self.send_tasks[client_id] = asyncio.create_task(
self._send_messages(client_id, websocket)
)
# 发送欢迎消息(通过发送队列)
await self.send_queues[client_id].put({
"type": "welcome",
"message": f"Client {client_id} connected"
})
async def remove_connection(self, client_id: str):
"""移除连接并清理资源"""
# 取消所有任务
for task in [self.receive_tasks.get(client_id),
self.process_tasks.get(client_id),
self.send_tasks.get(client_id)]:
if task:
task.cancel()
# 清理资源
self.connections.pop(client_id, None)
self.receive_queues.pop(client_id, None)
self.send_queues.pop(client_id, None)
self.receive_tasks.pop(client_id, None)
self.process_tasks.pop(client_id, None)
self.send_tasks.pop(client_id, None)
async def _receive_messages(self, client_id: str, websocket: WebSocket):
"""任务1: 接收消息并放入接收队列"""
try:
while True:
# 从WebSocket接收消息
data = await websocket.receive_text()
message = json.loads(data)
# 放入接收队列
if client_id in self.receive_queues:
await self.receive_queues[client_id].put(message)
except WebSocketDisconnect:
print(f"Client {client_id} disconnected")
except Exception as e:
print(f"Receive error for {client_id}: {e}")
finally:
await self.remove_connection(client_id)
async def _process_messages(self, client_id: str):
"""任务2: 从接收队列处理消息,结果放入发送队列"""
try:
while client_id in self.receive_queues:
# 从接收队列获取消息
message = await self.receive_queues[client_id].get()
# 处理消息(这里简单回声)
response = {
"type": "echo",
"original": message,
"timestamp": "now"
}
# 将响应放入发送队列
if client_id in self.send_queues:
await self.send_queues[client_id].put(response)
except Exception as e:
print(f"Process error for {client_id}: {e}")
async def _send_messages(self, client_id: str, websocket: WebSocket):
"""任务3: 从发送队列取出消息并发送"""
try:
while client_id in self.send_queues:
# 从发送队列获取消息
message = await self.send_queues[client_id].get()
# 通过WebSocket发送
await websocket.send_json(message)
except Exception as e:
print(f"Send error for {client_id}: {e}")
def send_message(self, client_id: str, message: dict):
"""从外部线程安全发送消息"""
if client_id in self.send_queues:
# 使用线程安全的方式将消息放入队列
asyncio.run_coroutine_threadsafe(
self.send_queues[client_id].put(message),
asyncio.get_event_loop()`在这里插入代码片`
)
# 创建管理器实例
manager = SimpleDualQueueManager()
6.2 连接管理器的使用
python
@app.websocket("/ws/simple/{client_id}")
async def simple_websocket(websocket: WebSocket, client_id: str):
await websocket.accept()
# 注册连接到管理器
await manager.add_connection(client_id, websocket)
try:
# 主循环只负责保持连接
# 实际的消息处理已经在后台任务中运行
while True:
# 简单的心跳检查
await asyncio.sleep(30)
# 检查连接是否仍然有效
if client_id not in manager.connections:
break
except Exception as e:
print(f"WebSocket error for {client_id}: {e}")
finally:
await manager.remove_connection(client_id)
# 后台任务示例
import threading
import time
def background_task():
"""模拟后台任务发送消息"""
count = 0
while True:
try:
count += 1
message = {
"type": "background",
"count": count,
"timestamp": time.time()
}
# 向所有连接的客户端广播消息
for client_id in list(manager.connections.keys()):
manager.send_message(client_id, message)
time.sleep(5) # 每5秒发送一次
except Exception as e:
print(f"Background task error: {e}")
time.sleep(1)
# 启动后台任务
@app.on_event("startup")
async def startup():
thread = threading.Thread(target=background_task, daemon=True)
thread.start()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
说明:
三个独立任务:
· 接收任务:从 WebSocket 接收消息 → 放入接收队列
· 处理任务:从接收队列取出消息 → 处理 → 放入发送队列
· 发送任务:从发送队列取出消息 → 通过 WebSocket 发送
队列优势:
· 解耦:接收、处理、发送相互独立,互不阻塞
· 缓冲:处理速度不一致时,队列起到缓冲作用
· 线程安全:外部线程可以通过队列安全地发送消息
异步处理:
· 所有操作都是异步的,不会阻塞事件循环
· 使用 asyncio.Queue 实现异步队列
· 每个客户端有自己独立的队列和任务
这个简单范例包含了双队列异步处理的核心思想,可以根据需要扩展更复杂的功能如错误处理、优先级队列、批处理等。