FastAPI WebSocket 由浅入深的开发范例

引言

在AI的开发中,WebSocket的开发尤其重要,比如与大模型的对接,一般都是使用WebSocket通讯,达到全双工与实时响应的效果。FastAPI 作为现代化的 Python Web 框架,提供了强大而简洁的 WebSocket 支持。本文将由浅入深,通过几个范例讲解,逐步掌握FastAPI的WebSocket的开发技巧。

1 什么是 WebSocket?

WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议。与传统的 HTTP 请求-响应模式不同,WebSocket 允许服务器和客户端之间建立持久连接,实现实时数据交换。

2 FastAPI的WebSocket基础用法。

python 复制代码
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/ws")
async def simple_websocket(websocket: WebSocket):
    # 接受连接
    await websocket.accept()
    
    try:
        while True:
            # 接收消息
            data = await websocket.receive_text()
            # 发送回应
            await websocket.send_text(f"Echo: {data}")
    except Exception as e:
        print(f"Connection closed: {e}")

这个基础的范例展示了 WebSocket 的基本工作流程:建立连接、持续通信、处理断开。

3 多消息处理用法

python 复制代码
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws/advanced")
async def advanced_websocket(websocket: WebSocket):
    await websocket.accept()
    
    try:
        while True:
            # 接收多种类型的消息
            message = await websocket.receive()
            
            if "text" in message:
                await websocket.send_text(f"Text received: {message['text']}")
            elif "bytes" in message:
                await websocket.send_bytes(message["bytes"])
            elif "json" in message:
                await websocket.send_json({"echo": message["json"]})
                
    except WebSocketDisconnect:
        print("Client disconnected gracefully")

4 使用连接管理器实现多连接管理

4.1 连接管理器的实现

在实际开发中,需要管理多个客户端的连接,因此有必要通过一个连接管理器来管理这些连接。

python 复制代码
from typing import Dict, List
import json

class ConnectionManager:
    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.connection_groups: Dict[str, List[str]] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        self.active_connections.pop(client_id, None)
        # 从所有群组中移除
        for group in self.connection_groups.values():
            if client_id in group:
                group.remove(client_id)

    async def send_personal_message(self, message: dict, client_id: str):
        if client_id in self.active_connections:
            await self.active_connections[client_id].send_json(message)

    async def broadcast(self, message: dict):
        disconnected_clients = []
        for client_id, websocket in self.active_connections.items():
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected_clients.append(client_id)
        
        # 清理断开的连接
        for client_id in disconnected_clients:
            self.disconnect(client_id)

    async def add_to_group(self, group_name: str, client_id: str):
        if group_name not in self.connection_groups:
            self.connection_groups[group_name] = []
        if client_id not in self.connection_groups[group_name]:
            self.connection_groups[group_name].append(client_id)

    async def send_to_group(self, group_name: str, message: dict):
        if group_name in self.connection_groups:
            disconnected_clients = []
            for client_id in self.connection_groups[group_name]:
                if client_id in self.active_connections:
                    try:
                        await self.active_connections[client_id].send_json(message)
                    except Exception:
                        disconnected_clients.append(client_id)
            
            # 清理断开的连接
            for client_id in disconnected_clients:
                self.disconnect(client_id)

manager = ConnectionManager()

4.2 连接管理器的使用

python 复制代码
@app.websocket("/ws/chat/{client_id}")
async def chat_websocket(websocket: WebSocket, client_id: str):
    await manager.connect(websocket, client_id)
    
    try:
        # 通知所有用户新用户加入
        await manager.broadcast({
            "type": "user_joined",
            "client_id": client_id,
            "message": f"User {client_id} joined the chat"
        })
        
        while True:
            data = await websocket.receive_text()
            message_data = json.loads(data)
            
            if message_data.get("type") == "join_room":
                # 加入聊天室
                room_name = message_data["room_name"]
                await manager.add_to_group(room_name, client_id)
                await manager.send_to_group(room_name, {
                    "type": "room_join",
                    "client_id": client_id,
                    "room_name": room_name
                })
            else:
                # 广播消息
                await manager.broadcast({
                    "type": "message",
                    "client_id": client_id,
                    "content": message_data.get("content", ""),
                    "timestamp": message_data.get("timestamp")
                })
                
    except WebSocketDisconnect:
        manager.disconnect(client_id)
        await manager.broadcast({
            "type": "user_left",
            "client_id": client_id,
            "message": f"User {client_id} left the chat"
        })

5 多线线程安全与异步处理

在真实的生成环境中,仅仅处理多连接还不行,还需要考虑多线程与异步的处理。以下范例的连接管理器增加了安全锁与发送队列,其他线程发送消息先放到发送队列即可,由单独的消息分发任务进行发送。

5.1 线程安全与异步处理的连接管理器

python 复制代码
import asyncio
import threading
from typing import Dict
from queue import Queue, Empty

class ThreadSafeWebSocketManager:
    def __init__(self):
        self.connections: Dict[str, WebSocket] = {}
        self.message_queues: Dict[str, Queue] = {}
        self.lock = threading.RLock()
        self.dispatcher_tasks: Dict[str, asyncio.Task] = {}
        
    async def add_connection(self, client_id: str, websocket: WebSocket):
        """添加连接并启动消息分发器"""
        with self.lock:
            self.connections[client_id] = websocket
            self.message_queues[client_id] = Queue()
            
        # 启动消息分发任务
        task = asyncio.create_task(self._message_dispatcher(client_id))
        self.dispatcher_tasks[client_id] = task
        
    def remove_connection(self, client_id: str):
        """移除连接"""
        with self.lock:
            websocket = self.connections.pop(client_id, None)
            queue = self.message_queues.pop(client_id, None)
            task = self.dispatcher_tasks.pop(client_id, None)
            
        # 取消任务
        if task:
            task.cancel()
            
        # 关闭 WebSocket
        if websocket:
            try:
                asyncio.create_task(websocket.close())
            except:
                pass
                
    def send_message(self, client_id: str, message: dict):
        """从任何线程安全发送消息"""
        with self.lock:
            if client_id in self.message_queues:
                self.message_queues[client_id].put(message)
                
    def broadcast(self, message: dict, exclude_clients: set = None):
        """广播消息到所有连接"""
        exclude_clients = exclude_clients or set()
        with self.lock:
            for client_id in self.connections:
                if client_id not in exclude_clients:
                    self.send_message(client_id, message)
                    
    async def _message_dispatcher(self, client_id: str):
        """异步消息分发器"""
        while client_id in self.connections:
            try:
                # 使用异步方式等待消息
                message = await asyncio.get_event_loop().run_in_executor(
                    None, 
                    self._get_message_safe, 
                    client_id
                )
                
                if message and client_id in self.connections:
                    websocket = self.connections[client_id]
                    await websocket.send_json(message)
                    
            except Exception as e:
                print(f"Error in dispatcher for {client_id}: {e}")
                break
                
    def _get_message_safe(self, client_id: str):
        """安全地从队列获取消息"""
        try:
            return self.message_queues[client_id].get(timeout=0.1)
        except Empty:
            return None

thread_safe_manager = ThreadSafeWebSocketManager()

此连接管理器中,每个WebSocket连接成功后都启动了一个消息分发任务,专门发送该连接的消息。

5.2 连接管理器的使用

python 复制代码
import time
import threading
from datetime import datetime

def start_background_notifications(manager: ThreadSafeWebSocketManager):
    """启动后台通知任务"""
    
    def notification_generator():
        """生成系统通知"""
        count = 0
        while True:
            try:
                notification = {
                    "type": "system_notification",
                    "message": f"System update #{count}",
                    "timestamp": datetime.now().isoformat(),
                    "priority": "info"
                }
                
                # 安全地广播通知
                manager.broadcast(notification)
                count += 1
                time.sleep(30)  # 每30秒发送一次
                
            except Exception as e:
                print(f"Notification generator error: {e}")
                time.sleep(5)  # 错误后等待5秒重试
                
    # 启动后台线程
    thread = threading.Thread(target=notification_generator, daemon=True)
    thread.start()

@app.on_event("startup")
async def startup_event():
    start_background_notifications(thread_safe_manager)

@app.websocket("/ws/thread-safe/{client_id}")
async def thread_safe_websocket(websocket: WebSocket, client_id: str):
    await thread_safe_manager.add_connection(client_id, websocket)
    
    try:
        # 发送欢迎消息
        thread_safe_manager.send_message(client_id, {
            "type": "welcome",
            "message": "Connected to thread-safe WebSocket",
            "timestamp": datetime.now().isoformat()
        })
        
        # 处理客户端消息
        while True:
            data = await websocket.receive_text()
            print(f"Received from {client_id}: {data}")
            
    except Exception as e:
        print(f"WebSocket error for {client_id}: {e}")
    finally:
        thread_safe_manager.remove_connection(client_id)

此范例中,启动一个定时广播的任务,每隔一段时间发送广播消息给每个WebSocket连接,最终也是调用send_message把消息放到发送队列里,由消息分发任务来发送消息。

6 双队列异步处理

当然我们也可以增加接收队列,由单独的接收任务来接收与处理消息,WebSocket的主线程仅仅是建立连接与定时发送心跳消息。

6.1 双队列异步连接管理器

python 复制代码
import asyncio
import json
from typing import Dict
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

class SimpleDualQueueManager:
    """简单的双队列 WebSocket 管理器"""
    
    def __init__(self):
        # 连接存储
        self.connections: Dict[str, WebSocket] = {}
        
        # 双队列系统:接收队列和发送队列
        self.receive_queues: Dict[str, asyncio.Queue] = {}
        self.send_queues: Dict[str, asyncio.Queue] = {}
        
        # 任务存储
        self.receive_tasks: Dict[str, asyncio.Task] = {}
        self.process_tasks: Dict[str, asyncio.Task] = {}
        self.send_tasks: Dict[str, asyncio.Task] = {}
    
    async def add_connection(self, client_id: str, websocket: WebSocket):
        """添加连接并启动三个核心任务"""
        self.connections[client_id] = websocket
        self.receive_queues[client_id] = asyncio.Queue()
        self.send_queues[client_id] = asyncio.Queue()
        
        # 启动三个异步任务
        self.receive_tasks[client_id] = asyncio.create_task(
            self._receive_messages(client_id, websocket)
        )
        self.process_tasks[client_id] = asyncio.create_task(
            self._process_messages(client_id)
        )
        self.send_tasks[client_id] = asyncio.create_task(
            self._send_messages(client_id, websocket)
        )
        
        # 发送欢迎消息(通过发送队列)
        await self.send_queues[client_id].put({
            "type": "welcome",
            "message": f"Client {client_id} connected"
        })
    
    async def remove_connection(self, client_id: str):
        """移除连接并清理资源"""
        # 取消所有任务
        for task in [self.receive_tasks.get(client_id), 
                    self.process_tasks.get(client_id), 
                    self.send_tasks.get(client_id)]:
            if task:
                task.cancel()
        
        # 清理资源
        self.connections.pop(client_id, None)
        self.receive_queues.pop(client_id, None)
        self.send_queues.pop(client_id, None)
        self.receive_tasks.pop(client_id, None)
        self.process_tasks.pop(client_id, None)
        self.send_tasks.pop(client_id, None)
    
    async def _receive_messages(self, client_id: str, websocket: WebSocket):
        """任务1: 接收消息并放入接收队列"""
        try:
            while True:
                # 从WebSocket接收消息
                data = await websocket.receive_text()
                message = json.loads(data)
                
                # 放入接收队列
                if client_id in self.receive_queues:
                    await self.receive_queues[client_id].put(message)
                    
        except WebSocketDisconnect:
            print(f"Client {client_id} disconnected")
        except Exception as e:
            print(f"Receive error for {client_id}: {e}")
        finally:
            await self.remove_connection(client_id)
    
    async def _process_messages(self, client_id: str):
        """任务2: 从接收队列处理消息,结果放入发送队列"""
        try:
            while client_id in self.receive_queues:
                # 从接收队列获取消息
                message = await self.receive_queues[client_id].get()
                
                # 处理消息(这里简单回声)
                response = {
                    "type": "echo",
                    "original": message,
                    "timestamp": "now"
                }
                
                # 将响应放入发送队列
                if client_id in self.send_queues:
                    await self.send_queues[client_id].put(response)
                    
        except Exception as e:
            print(f"Process error for {client_id}: {e}")
    
    async def _send_messages(self, client_id: str, websocket: WebSocket):
        """任务3: 从发送队列取出消息并发送"""
        try:
            while client_id in self.send_queues:
                # 从发送队列获取消息
                message = await self.send_queues[client_id].get()
                
                # 通过WebSocket发送
                await websocket.send_json(message)
                
        except Exception as e:
            print(f"Send error for {client_id}: {e}")
    
    def send_message(self, client_id: str, message: dict):
        """从外部线程安全发送消息"""
        if client_id in self.send_queues:
            # 使用线程安全的方式将消息放入队列
            asyncio.run_coroutine_threadsafe(
                self.send_queues[client_id].put(message),
                asyncio.get_event_loop()`在这里插入代码片`
            )


# 创建管理器实例
manager = SimpleDualQueueManager()

6.2 连接管理器的使用

python 复制代码
@app.websocket("/ws/simple/{client_id}")
async def simple_websocket(websocket: WebSocket, client_id: str):
    await websocket.accept()
    
    # 注册连接到管理器
    await manager.add_connection(client_id, websocket)
    
    try:
        # 主循环只负责保持连接
        # 实际的消息处理已经在后台任务中运行
        while True:
            # 简单的心跳检查
            await asyncio.sleep(30)
            
            # 检查连接是否仍然有效
            if client_id not in manager.connections:
                break
                
    except Exception as e:
        print(f"WebSocket error for {client_id}: {e}")
    finally:
        await manager.remove_connection(client_id)

# 后台任务示例
import threading
import time

def background_task():
    """模拟后台任务发送消息"""
    count = 0
    while True:
        try:
            count += 1
            message = {
                "type": "background",
                "count": count,
                "timestamp": time.time()
            }
            
            # 向所有连接的客户端广播消息
            for client_id in list(manager.connections.keys()):
                manager.send_message(client_id, message)
            
            time.sleep(5)  # 每5秒发送一次
            
        except Exception as e:
            print(f"Background task error: {e}")
            time.sleep(1)

# 启动后台任务
@app.on_event("startup")
async def startup():
    thread = threading.Thread(target=background_task, daemon=True)
    thread.start()

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

说明:

三个独立任务:

· 接收任务:从 WebSocket 接收消息 → 放入接收队列

· 处理任务:从接收队列取出消息 → 处理 → 放入发送队列

· 发送任务:从发送队列取出消息 → 通过 WebSocket 发送

队列优势:

· 解耦:接收、处理、发送相互独立,互不阻塞

· 缓冲:处理速度不一致时,队列起到缓冲作用

· 线程安全:外部线程可以通过队列安全地发送消息

异步处理:

· 所有操作都是异步的,不会阻塞事件循环

· 使用 asyncio.Queue 实现异步队列

· 每个客户端有自己独立的队列和任务

这个简单范例包含了双队列异步处理的核心思想,可以根据需要扩展更复杂的功能如错误处理、优先级队列、批处理等。

相关推荐
weixin_456904272 小时前
UDP端口释放和清理时间分析
网络·网络协议·udp
Net_Walke3 小时前
【网络协议】数字签名与证书
网络·网络协议
北京耐用通信8 小时前
一“网”跨协议,万“设”皆可通!耐达讯自动化Modbus TCP转Profibus ,让控制无界,让能源有道。
网络·人工智能·网络协议·自动化·信息与通信
weixin_4365250711 小时前
芋道源码 - RabbitMQ + WebSocket 实现分布式消息推送
分布式·websocket·rabbitmq
小样还想跑15 小时前
UniApp ConnectSocket连接websocket
websocket·elasticsearch·uni-app
FreeBuf_16 小时前
Zloader木马再次升级:通过DNS隧道和WebSocket C2实现更隐蔽的攻击
websocket·网络协议·php
chuxinweihui16 小时前
Socket编程UDP
linux·网络·网络协议·udp·通信
北京耐用通信16 小时前
神秘魔法?耐达讯自动化Modbus TCP 转 Profibus 如何为光伏逆变器编织通信“天网”
网络·人工智能·网络协议·网络安全·自动化·信息与通信
游戏开发爱好者818 小时前
TCP 抓包分析:tcp抓包工具、 iOS/HTTPS 流量解析全流程
网络协议·tcp/ip·ios·小程序·https·uni-app·iphone