Contents
- 1. Network Programming Fundamentals and Protocol Comparison
  - 1.1 Computer Network Models
  - 1.2 Core Differences Between TCP and UDP
  - 1.3 Protocol Header Structures Compared
    - 1.3.1 TCP Header Structure (20 Bytes + Options)
    - 1.3.2 UDP Header Structure (8 Bytes)
  - 1.4 Ports and Sockets
- 2. TCP Programming in Depth
  - 2.1 TCP Three-Way Handshake and Four-Way Teardown
    - 2.1.1 Establishing a Connection: the Three-Way Handshake
    - 2.1.2 Closing a Connection: the Four-Way Teardown
  - 2.2 TCP State Transitions in Detail
  - 2.3 A Basic TCP Server
  - 2.4 Advanced TCP Features
  - 2.5 TCP Client Implementation
- 3. UDP Programming in Depth
  - 3.1 UDP Characteristics and Use Cases
  - 3.2 A Basic UDP Server
  - 3.3 UDP Client Implementation
- 4. Advanced Network Programming Techniques
  - 4.1 Asynchronous I/O and Concurrency Models
  - 4.2 Network Security and Encryption
  - 4.3 Network Diagnostics and Debugging Tools
- 5. Hands-On Project: a Multi-Protocol Chat System
- 6. Performance Optimization and Best Practices
  - 6.1 Network Programming Performance Optimization
  - 6.2 Error Handling and Fault Tolerance
- 7. Summary
  - 7.1 Key Takeaways
    - 7.1.1 TCP Programming Essentials
    - 7.1.2 UDP Programming Essentials
    - 7.1.3 Advanced Network Techniques
  - 7.2 Protocol Selection Decision Tree
  - 7.3 Performance Comparison Matrix
  - 7.4 Best-Practice Checklist
    - 7.4.1 TCP Programming Best Practices
    - 7.4.2 UDP Programming Best Practices
    - 7.4.3 General Best Practices
  - 7.5 Future Trends
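TCP/UDP Network Programming in Practice: A Deep Dive
1. Network Programming Fundamentals and Protocol Comparison
1.1 Computer Network Models
Network communication follows a layered architecture. The two most widely used reference models are the seven-layer OSI model and the four-layer TCP/IP model, which map onto each other as follows:
| TCP/IP four-layer model | OSI seven-layer model |
|---|---|
| Application layer | Application, Presentation, and Session layers |
| Transport layer | Transport layer |
| Network layer | Network layer |
| Network interface layer | Data link and Physical layers |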
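1.2 Core Differences Between TCP and UDP
| Property | TCP (Transmission Control Protocol) | UDP (User Datagram Protocol) |
|---|---|---|
| Connection | Connection-oriented; requires a three-way handshake | Connectionless |
| Reliability | Reliable delivery with acknowledgments and retransmission | Unreliable, best-effort delivery |
| Ordering | Bytes are delivered to the application in order | Datagrams may arrive out of order |
| Flow control | Sliding-window mechanism | None |
| Congestion control | Several congestion-control algorithms | None |
| Header overhead | 20-60 bytes | 8 bytes |
| Speed | Higher latency and overhead (handshake, ACKs, retransmission) | Lower latency and overhead |
| Typical uses | Web, email, file transfer | Video streaming, DNS, games |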
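One row of the table that is easy to miss in practice: UDP preserves message boundaries, while TCP delivers an unframed byte stream. The following is a minimal loopback sketch of that difference, not a benchmark; the function names are illustrative, and it assumes the two UDP datagrams are neither lost nor dropped on the loopback interface.

```python
import socket

def udp_keeps_boundaries() -> list:
    """Send two datagrams over loopback UDP; each recvfrom() returns one datagram."""
    receiver = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    receiver.bind(("127.0.0.1", 0))          # port 0: let the OS pick a free port
    receiver.settimeout(2.0)
    sender = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        for payload in (b"hello", b"world"):
            sender.sendto(payload, receiver.getsockname())   # no connect() needed
        return [receiver.recvfrom(1024)[0] for _ in range(2)]
    finally:
        sender.close()
        receiver.close()

def tcp_is_a_byte_stream() -> bytes:
    """Send two messages over loopback TCP; the receiver only sees a run of bytes."""
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("127.0.0.1", 0))
    listener.listen(1)
    client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    client.connect(listener.getsockname())   # the three-way handshake happens here
    server_side, _ = listener.accept()
    server_side.settimeout(2.0)
    try:
        client.sendall(b"hello")
        client.sendall(b"world")
        received = b""
        while len(received) < 10:            # recv() may return 5 + 5 bytes or all 10
            chunk = server_side.recv(1024)
            if not chunk:
                break
            received += chunk
        return received
    finally:
        client.close()
        server_side.close()
        listener.close()

if __name__ == "__main__":
    print(udp_keeps_boundaries())
    print(tcp_is_a_byte_stream())
```

Because TCP gives no hint where one message ends, any TCP application protocol needs its own framing, such as a delimiter or a length prefix.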
1.3 Protocol Header Structures Compared
1.3.1 TCP Header Structure (20 Bytes + Options)
 0                   1                   2                   3
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|          Source Port          |       Destination Port        |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|                        Sequence Number                        |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|                     Acknowledgment Number                     |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|  Data |           |U|A|P|R|S|F|                               |
| Offset| Reserved  |R|C|S|S|Y|I|            Window             |
|       |           |G|K|H|T|N|N|                               |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|           Checksum            |        Urgent Pointer         |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|                    Options                    |    Padding    |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
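To make the layout concrete, here is a small sketch that decodes the fixed 20-byte part of the header with `struct`. The names `TCPHeader`, `FLAG_BITS`, and `parse_tcp_header` are illustrative, the sample bytes are hand-built rather than captured from a network, and options are ignored.

```python
import struct
from typing import FrozenSet, NamedTuple

# Flag bits in the low 6 bits of the 16-bit "offset/reserved/flags" word.
FLAG_BITS = {"FIN": 0x01, "SYN": 0x02, "RST": 0x04, "PSH": 0x08, "ACK": 0x10, "URG": 0x20}

class TCPHeader(NamedTuple):
    src_port: int
    dst_port: int
    seq: int
    ack: int
    header_len: int          # in bytes, derived from the 4-bit Data Offset field
    flags: FrozenSet[str]
    window: int
    checksum: int
    urgent: int

def parse_tcp_header(raw: bytes) -> TCPHeader:
    """Decode the fixed 20-byte TCP header (network byte order)."""
    if len(raw) < 20:
        raise ValueError("a TCP header is at least 20 bytes")
    src, dst, seq, ack, off_flags, window, checksum, urgent = struct.unpack(
        "!HHIIHHHH", raw[:20]
    )
    header_len = (off_flags >> 12) * 4       # Data Offset counts 32-bit words
    flags = frozenset(name for name, bit in FLAG_BITS.items() if off_flags & bit)
    return TCPHeader(src, dst, seq, ack, header_len, flags, window, checksum, urgent)

if __name__ == "__main__":
    # Hand-built SYN segment: port 54321 -> 80, seq=1000, Data Offset=5 (no options).
    sample = struct.pack("!HHIIHHHH", 54321, 80, 1000, 0, (5 << 12) | 0x02, 65535, 0, 0)
    print(parse_tcp_header(sample))
```

A Data Offset of 5 means 5 x 4 = 20 bytes, the minimum header size; the maximum value 15 gives the 60-byte upper bound.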
1.3.2 UDP Header Structure (8 Bytes)
 0                   1                   2                   3
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|          Source Port          |       Destination Port        |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|            Length             |           Checksum            |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
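The UDP header is simple enough to build and parse in a few lines. This sketch uses illustrative helper names (`build_udp_header`, `parse_udp_header`) and leaves the checksum at 0, which in IPv4 means "no checksum computed"; a real stack computes it over a pseudo-header plus the payload.

```python
import struct

UDP_HEADER = struct.Struct("!HHHH")   # source port, destination port, length, checksum

def build_udp_header(src_port: int, dst_port: int, payload: bytes) -> bytes:
    """Build an 8-byte UDP header; Length covers the header plus the payload."""
    return UDP_HEADER.pack(src_port, dst_port, UDP_HEADER.size + len(payload), 0)

def parse_udp_header(raw: bytes) -> dict:
    """Decode the first 8 bytes of a UDP datagram."""
    src, dst, length, checksum = UDP_HEADER.unpack(raw[:UDP_HEADER.size])
    return {"src_port": src, "dst_port": dst, "length": length, "checksum": checksum}

if __name__ == "__main__":
    header = build_udp_header(40000, 53, b"query")
    print(len(header), parse_udp_header(header))
```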
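1.4 Ports and Sockets
Port number ranges:
- 0-1023: well-known ports
- 1024-49151: registered ports
- 49152-65535: dynamic/private ports
A socket is an endpoint of network communication. For a given transport protocol, it is identified by an IP address and a port number:
Socket = (IP address, port number)
An established TCP connection is therefore identified by a pair of sockets: (source IP, source port, destination IP, destination port).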
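The three ranges above can be expressed as a small helper. `classify_port` is a hypothetical name used only for illustration.

```python
def classify_port(port: int) -> str:
    """Return the range name for a TCP/UDP port number."""
    if not 0 <= port <= 65535:
        raise ValueError(f"port out of range: {port}")
    if port <= 1023:
        return "well-known"
    if port <= 49151:
        return "registered"
    return "dynamic/private"

if __name__ == "__main__":
    for p in (80, 8888, 50000):
        print(p, classify_port(p))
```

The default ports used by the examples in this article (8888, 8443, 9999) all fall in the registered range.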
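2. TCP Programming in Depth
2.1 TCP Three-Way Handshake and Four-Way Teardown
2.1.1 Establishing a Connection: the Three-Way Handshake
- Client → Server: SYN (seq=x). The client enters SYN_SENT.
- Server → Client: SYN-ACK (seq=y, ack=x+1). The server enters SYN_RECEIVED.
- Client → Server: ACK (seq=x+1, ack=y+1). The connection is established and data transfer can begin.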
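The sequence and acknowledgment arithmetic of the handshake can be traced with a short sketch. It only models the numbers: `three_way_handshake` is an illustrative function that never touches the network, and x and y stand for the initial sequence numbers each side picks.

```python
SEQ_MODULUS = 2 ** 32   # sequence numbers are 32-bit and wrap around

def three_way_handshake(x: int, y: int) -> list:
    """Return the three handshake segments for client ISN x and server ISN y."""
    return [
        ("client->server", "SYN", {"seq": x}),
        ("server->client", "SYN-ACK", {"seq": y, "ack": (x + 1) % SEQ_MODULUS}),
        ("client->server", "ACK", {"seq": (x + 1) % SEQ_MODULUS,
                                   "ack": (y + 1) % SEQ_MODULUS}),
    ]

if __name__ == "__main__":
    for direction, kind, fields in three_way_handshake(x=100, y=300):
        print(f"{direction:16} {kind:8} {fields}")
```

Each SYN consumes one sequence number, which is why the acknowledgment is the peer's initial sequence number plus one even though no data has been sent yet.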
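2.1.2 Closing a Connection: the Four-Way Teardown
- Client → Server: FIN (seq=u). The client enters FIN_WAIT_1.
- Server → Client: ACK (ack=u+1). The server enters CLOSE_WAIT.
- Server → Client: FIN (seq=v, ack=u+1). The server enters LAST_ACK.
- Client → Server: ACK (seq=u+1, ack=v+1). The client enters TIME_WAIT and stays there for 2MSL.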
2.2 TCP State Transitions in Detail
python
class TCPStateMachine:
    """Simplified TCP state machine (a teaching model, not a full implementation)."""

    STATES = {
        'CLOSED': 'Closed',
        'LISTEN': 'Listening',
        'SYN_SENT': 'SYN sent',
        'SYN_RECEIVED': 'SYN received',
        'ESTABLISHED': 'Connection established',
        'FIN_WAIT_1': 'Waiting for the ACK of our FIN',
        'FIN_WAIT_2': "Waiting for the peer's FIN",
        'CLOSE_WAIT': 'Waiting for the local application to close',
        'CLOSING': 'Simultaneous close',
        'LAST_ACK': 'Waiting for the final ACK',
        'TIME_WAIT': 'Waiting 2MSL before closing',
    }

    def __init__(self):
        self.current_state = 'CLOSED'
        self.state_history = []

    def transition(self, event):
        """Apply an event; return True if it is valid in the current state."""
        transitions = {
            # Passive open (server side)
            ('CLOSED', 'open'): 'LISTEN',
            ('LISTEN', 'syn_received'): 'SYN_RECEIVED',
            ('SYN_RECEIVED', 'ack_received'): 'ESTABLISHED',
            # Active close (matches the diagram below)
            ('ESTABLISHED', 'close'): 'FIN_WAIT_1',
            ('FIN_WAIT_1', 'ack_received'): 'FIN_WAIT_2',
            ('FIN_WAIT_2', 'fin_received'): 'TIME_WAIT',
            ('TIME_WAIT', 'timeout'): 'CLOSED',
            # Passive close
            ('ESTABLISHED', 'fin_received'): 'CLOSE_WAIT',
            ('CLOSE_WAIT', 'close'): 'LAST_ACK',
            ('LAST_ACK', 'ack_received'): 'CLOSED',
        }
        key = (self.current_state, event)
        if key in transitions:
            self.state_history.append((self.current_state, event))
            self.current_state = transitions[key]
            return True
        return False

    def get_state_diagram(self):
        """Return a Graphviz DOT description of the state diagram."""
        diagram = """
        digraph TCPStateMachine {
            rankdir=LR;
            size="8,5";
            node [shape = circle];
            CLOSED -> LISTEN [label = "passive open"];
            LISTEN -> SYN_RECEIVED [label = "SYN received"];
            SYN_RECEIVED -> ESTABLISHED [label = "ACK received"];
            ESTABLISHED -> FIN_WAIT_1 [label = "active close"];
            FIN_WAIT_1 -> FIN_WAIT_2 [label = "ACK received"];
            FIN_WAIT_2 -> TIME_WAIT [label = "FIN received"];
            TIME_WAIT -> CLOSED [label = "2MSL timeout"];
            ESTABLISHED -> CLOSE_WAIT [label = "FIN received"];
            CLOSE_WAIT -> LAST_ACK [label = "application closes"];
            LAST_ACK -> CLOSED [label = "ACK received"];
        }
        """
        return diagram
2.3 A Basic TCP Server
python
# tcp_basic_server.py
import socket
import threading
import time
import struct
from typing import Tuple, Optional
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('TCPServer')


class TCPBasicServer:
    """Basic thread-per-client TCP server."""

    def __init__(self, host: str = '0.0.0.0', port: int = 8888):
        """
        Initialize the TCP server.

        Args:
            host: address to listen on
            port: port to listen on
        """
        self.host = host
        self.port = port
        self.server_socket: Optional[socket.socket] = None
        self.running = False
        self.client_threads = []
        self.client_count = 0
        # Connection statistics (updated from several threads without a lock,
        # so the numbers are approximate)
        self.stats = {
            'connections_total': 0,
            'connections_active': 0,
            'bytes_received': 0,
            'bytes_sent': 0,
            'errors_total': 0
        }
    def start(self):
        """Start the server and block in the accept loop."""
        try:
            # Create the TCP socket
            self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Allow quick restarts while old connections sit in TIME_WAIT
            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Bind the address and port
            self.server_socket.bind((self.host, self.port))
            # Start listening; 5 is the backlog of pending (not yet accepted)
            # connections, not a cap on concurrent clients
            self.server_socket.listen(5)
            # Optionally switch the socket to non-blocking mode
            # self.server_socket.setblocking(False)
            self.running = True
            logger.info(f"TCP server listening on {self.host}:{self.port}")
            # Start the statistics thread
            stats_thread = threading.Thread(target=self._stats_monitor, daemon=True)
            stats_thread.start()
            # Main loop: accept connections
            self._accept_connections()
        except Exception as e:
            logger.error(f"Failed to start server: {e}")
            self.stop()

    def _accept_connections(self):
        """Accept client connections until the server stops."""
        while self.running:
            try:
                # Accept a client connection
                client_socket, client_address = self.server_socket.accept()
                # Update statistics
                self.stats['connections_total'] += 1
                self.stats['connections_active'] += 1
                logger.info(f"New client: {client_address} "
                            f"(active connections: {self.stats['connections_active']})")
                # One thread per client
                client_thread = threading.Thread(
                    target=self._handle_client,
                    args=(client_socket, client_address),
                    daemon=True
                )
                client_thread.start()
                self.client_threads.append(client_thread)
            except socket.timeout:
                continue
            except OSError as e:
                # Raised when stop() closes the listening socket
                if self.running:
                    logger.error(f"Accept error: {e}")
                break
            except Exception as e:
                logger.error(f"Unexpected error: {e}")
                break
    def _send_text(self, client_socket: socket.socket, text: str):
        """Encode and send text; sendall() retries partial sends, and the
        statistics count encoded bytes rather than characters."""
        payload = text.encode('utf-8')
        client_socket.sendall(payload)
        self.stats['bytes_sent'] += len(payload)

    def _handle_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Serve one client connection."""
        client_id = f"{client_address[0]}:{client_address[1]}"
        try:
            # Idle timeout, used to trigger heartbeats
            client_socket.settimeout(30.0)
            # Send the welcome message
            self._send_text(client_socket,
                            f"Welcome to the TCP server {self.host}:{self.port}\r\n")
            # Per-client loop
            while self.running:
                try:
                    # Receive data (TCP is a byte stream, so one recv() is not
                    # guaranteed to hold exactly one message)
                    data = client_socket.recv(1024)
                    if not data:
                        # The client closed the connection
                        logger.info(f"Client {client_id} disconnected")
                        break
                    # Update receive statistics
                    self.stats['bytes_received'] += len(data)
                    # Process the received data
                    message = data.decode('utf-8', errors='ignore').strip()
                    logger.info(f"Message from {client_id}: {message}")
                    # Echo the message
                    self._send_text(client_socket, f"Server received: {message}\r\n")
                    # Special commands
                    if message.lower() == 'time':
                        current_time = time.strftime("%Y-%m-%d %H:%M:%S")
                        self._send_text(client_socket, f"Server time: {current_time}\r\n")
                    elif message.lower() == 'stats':
                        self._send_text(client_socket, self._get_stats_message())
                    elif message.lower() == 'quit':
                        self._send_text(client_socket, "Goodbye!\r\n")
                        break
                except socket.timeout:
                    # Send a heartbeat to keep the connection alive
                    try:
                        client_socket.sendall(b"<HEARTBEAT>\r\n")
                    except OSError:
                        break
                except ConnectionResetError:
                    logger.warning(f"Client {client_id} reset the connection")
                    break
                except Exception as e:
                    logger.error(f"Error while handling client {client_id}: {e}")
                    self.stats['errors_total'] += 1
                    break
        except Exception as e:
            logger.error(f"Client {client_id} handler failed: {e}")
            self.stats['errors_total'] += 1
        finally:
            # Release resources
            try:
                client_socket.close()
            except OSError:
                pass
            self.stats['connections_active'] -= 1
            logger.info(f"Connection to client {client_id} closed")
    def _get_stats_message(self) -> str:
        """Format the statistics as a text message."""
        return (
            f"Server statistics:\r\n"
            f"- Total connections: {self.stats['connections_total']}\r\n"
            f"- Active connections: {self.stats['connections_active']}\r\n"
            f"- Bytes received: {self.stats['bytes_received']}\r\n"
            f"- Bytes sent: {self.stats['bytes_sent']}\r\n"
            f"- Total errors: {self.stats['errors_total']}\r\n"
        )

    def _stats_monitor(self):
        """Statistics thread: log the counters every 10 seconds."""
        while self.running:
            time.sleep(10)
            logger.info(f"Stats monitor - {self._get_stats_message()}")

    def stop(self):
        """Stop the server."""
        self.running = False
        # Close the listening socket; this unblocks accept() with an OSError
        if self.server_socket:
            try:
                self.server_socket.close()
            except OSError:
                pass
        logger.info("TCP server stopped")
        # Give client threads a moment to finish
        for thread in self.client_threads:
            thread.join(timeout=2.0)
        # Log the final statistics
        logger.info("Final statistics:")
        logger.info(self._get_stats_message())
class TCPAdvancedServer(TCPBasicServer):
    """Advanced TCP server that speaks a length-prefixed binary protocol."""

    MAX_MESSAGE_SIZE = 16 * 1024 * 1024  # reject absurd length fields

    def __init__(self, host='0.0.0.0', port=8888):
        super().__init__(host, port)
        # Message type -> handler (only handlers defined in this class)
        self.message_handlers = {
            1: self._handle_text_message,
            2: self._handle_file_chunk,
            3: self._handle_command,
        }

    def _handle_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Client handler with message framing."""
        client_id = f"{client_address[0]}:{client_address[1]}"
        buffer = b""  # accumulates bytes so split or merged messages can be reassembled
        try:
            # Send the protocol header
            protocol_header = struct.pack('!I', 0xA1B2C3D4)  # magic number
            client_socket.sendall(protocol_header)
            while self.running:
                try:
                    # Receive data
                    data = client_socket.recv(4096)
                    if not data:
                        break
                    buffer += data
                    # Process every complete message in the buffer
                    while len(buffer) >= 8:  # header: 4-byte length + 4-byte type
                        msg_length, msg_type = struct.unpack('!II', buffer[:8])
                        if msg_length > self.MAX_MESSAGE_SIZE:
                            raise ValueError(f"message too large: {msg_length} bytes")
                        if len(buffer) < 8 + msg_length:
                            # Incomplete message: wait for more data
                            break
                        # Extract one complete message
                        message_data = buffer[8:8 + msg_length]
                        buffer = buffer[8 + msg_length:]
                        # Dispatch it
                        self._process_message(client_socket, msg_type, message_data)
                except socket.timeout:
                    continue
                except struct.error as e:
                    logger.error(f"Protocol parse error: {e}")
                    break
                except Exception as e:
                    logger.error(f"Data processing error: {e}")
                    break
        except Exception as e:
            logger.error(f"Client {client_id} handler failed: {e}")
        finally:
            client_socket.close()
            self.stats['connections_active'] -= 1
    def _process_message(self, client_socket: socket.socket, msg_type: int, data: bytes):
        """Dispatch a message by type."""
        handlers = {
            1: self._handle_text_message,
            2: self._handle_file_chunk,
            3: self._handle_command,
        }
        handler = handlers.get(msg_type)
        if handler:
            handler(client_socket, data)
        else:
            logger.warning(f"Unknown message type: {msg_type}")

    def _handle_text_message(self, client_socket: socket.socket, data: bytes):
        """Handle a text message."""
        try:
            text = data.decode('utf-8')
            logger.info(f"Text message received: {text}")
            # Echo it back using the same framing
            response = f"ECHO: {text}"
            response_data = response.encode('utf-8')
            response_header = struct.pack('!II', len(response_data), 1)
            client_socket.sendall(response_header + response_data)
        except UnicodeDecodeError:
            logger.error("Failed to decode text message")

    def _handle_file_chunk(self, client_socket: socket.socket, data: bytes):
        """Handle a file chunk."""
        # File-receiving logic can be implemented here
        logger.info(f"File chunk received, size: {len(data)} bytes")

    def _handle_command(self, client_socket: socket.socket, data: bytes):
        """Handle a command message.

        WARNING: this runs whatever the client sends through a shell, with no
        authentication. It is remote code execution by design; use it only on
        an isolated lab machine, never on a reachable host.
        """
        try:
            command = data.decode('utf-8').strip()
            logger.info(f"Command received: {command}")
            # Run the command and return its output
            import subprocess
            result = subprocess.run(command, shell=True, capture_output=True, text=True)
            output = result.stdout + result.stderr
            output_data = output.encode('utf-8')
            output_header = struct.pack('!II', len(output_data), 1)
            client_socket.sendall(output_header + output_data)
        except Exception as e:
            error_msg = f"Command execution error: {e}"
            error_data = error_msg.encode('utf-8')
            error_header = struct.pack('!II', len(error_data), 1)
            client_socket.sendall(error_header + error_data)
def main():
    """Entry point."""
    # Parse command-line arguments
    import argparse
    parser = argparse.ArgumentParser(description='TCP server')
    parser.add_argument('--host', default='0.0.0.0', help='address to listen on')
    parser.add_argument('--port', type=int, default=8888, help='port to listen on')
    parser.add_argument('--advanced', action='store_true', help='use the advanced server')
    args = parser.parse_args()
    # Create and start the server
    if args.advanced:
        server = TCPAdvancedServer(args.host, args.port)
    else:
        server = TCPBasicServer(args.host, args.port)
    try:
        server.start()
    except KeyboardInterrupt:
        logger.info("Interrupt received, shutting down...")
        server.stop()
    except Exception as e:
        logger.error(f"Server error: {e}")
        server.stop()


if __name__ == "__main__":
    main()
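The advanced server above frames each message as a 4-byte length, a 4-byte type, and then the payload, all in network byte order. A client for that wire format might look like the sketch below. The helper names `send_frame`, `recv_exact`, and `recv_frame` are illustrative, and `socket.socketpair()` stands in for a real connection so the example runs on its own.

```python
import socket
import struct

FRAME_HEADER = struct.Struct("!II")   # payload length, message type

def send_frame(sock: socket.socket, msg_type: int, payload: bytes) -> None:
    """Send one length-prefixed frame."""
    sock.sendall(FRAME_HEADER.pack(len(payload), msg_type) + payload)

def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes, looping because recv() may return fewer."""
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise ConnectionError("peer closed the connection mid-frame")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)

def recv_frame(sock: socket.socket):
    """Read one frame and return (msg_type, payload)."""
    length, msg_type = FRAME_HEADER.unpack(recv_exact(sock, FRAME_HEADER.size))
    return msg_type, recv_exact(sock, length)

if __name__ == "__main__":
    a, b = socket.socketpair()        # two connected stream sockets
    try:
        send_frame(a, 1, b"hello")    # sent back-to-back, so they may arrive
        send_frame(a, 1, b"world")    # merged in a single recv() on the other side
        print(recv_frame(b))
        print(recv_frame(b))
    finally:
        a.close()
        b.close()
```

Reading the header first and then exactly `length` bytes is an alternative to the accumulate-and-slice buffer used in the server; both recover message boundaries that TCP itself does not keep.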
2.4 Advanced TCP Features
python
# tcp_advanced_features.py
import socket
import threading
import queue
import time
import select
import ssl
from typing import List, Dict, Tuple
import hashlib


class TCPConnectionPool:
    """TCP connection pool."""

    def __init__(self, host: str, port: int, max_size: int = 10):
        self.host = host
        self.port = port
        self.max_size = max_size
        self.pool = queue.Queue(maxsize=max_size)
        self.active_count = 0  # connections created and not yet closed
        self.lock = threading.Lock()
        # Pre-populate the pool
        self._initialize_pool()

    def _initialize_pool(self):
        """Open up to five connections ahead of time."""
        for _ in range(min(5, self.max_size)):
            try:
                conn = self._create_connection()
                self.pool.put(conn)
                self.active_count += 1
            except Exception as e:
                print(f"Failed to create connection: {e}")

    def _create_connection(self) -> socket.socket:
        """Open a new connection."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(5.0)
        sock.connect((self.host, self.port))
        return sock

    def get_connection(self, timeout: float = 5.0) -> socket.socket:
        """Take a connection from the pool, creating one if allowed."""
        # Try the queue first; get_nowait() avoids the race between
        # empty() and get() when several threads ask at once
        try:
            return self.pool.get_nowait()
        except queue.Empty:
            pass
        # Queue is empty: reserve a slot if we are below the limit
        with self.lock:
            can_create = self.active_count < self.max_size
            if can_create:
                self.active_count += 1
        if can_create:
            try:
                return self._create_connection()
            except Exception:
                with self.lock:
                    self.active_count -= 1
                raise
        # At the limit: wait for a connection to be released
        try:
            return self.pool.get(timeout=timeout)
        except queue.Empty:
            raise TimeoutError("Timed out waiting for a connection")

    def release_connection(self, conn: socket.socket):
        """Return a connection to the pool."""
        try:
            self.pool.put_nowait(conn)
        except queue.Full:
            conn.close()
            with self.lock:
                self.active_count -= 1

    def close_all(self):
        """Close every pooled connection."""
        while not self.pool.empty():
            try:
                conn = self.pool.get_nowait()
                conn.close()
                with self.lock:
                    self.active_count -= 1
            except queue.Empty:
                break
class TCPLoadBalancer:
    """TCP load balancer (server selection only)."""

    def __init__(self, backend_servers: List[Tuple[str, int]]):
        """
        Initialize the load balancer.

        Args:
            backend_servers: list of backends [(host1, port1), (host2, port2), ...]
        """
        self.backend_servers = backend_servers
        self.current_index = 0
        self.lock = threading.Lock()
        self.server_stats = {server: {'requests': 0, 'errors': 0} for server in backend_servers}

    def get_next_server(self) -> Tuple[str, int]:
        """Pick the next server (round-robin)."""
        with self.lock:
            server = self.backend_servers[self.current_index]
            self.current_index = (self.current_index + 1) % len(self.backend_servers)
            self.server_stats[server]['requests'] += 1
            return server

    def get_least_connections_server(self, connection_counts: Dict[Tuple[str, int], int]) -> Tuple[str, int]:
        """Pick the server with the fewest connections."""
        if not connection_counts:
            return self.get_next_server()
        return min(connection_counts.items(), key=lambda x: x[1])[0]

    def get_server_stats(self) -> Dict:
        """Return a copy of the per-server statistics."""
        return self.server_stats.copy()
class SSLTCPServer:
    """TLS-secured TCP server."""

    def __init__(self, host: str = '0.0.0.0', port: int = 8443):
        self.host = host
        self.port = port
        self.context = self._create_ssl_context()
        self.running = False

    def _create_ssl_context(self) -> ssl.SSLContext:
        """Create the TLS context (expects server.crt and server.key in the working directory)."""
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile='server.crt', keyfile='server.key')
        # Security settings: refuse TLS 1.0 and 1.1
        context.minimum_version = ssl.TLSVersion.TLSv1_2
        context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20')
        return context

    def start(self):
        """Start the TLS server."""
        try:
            # Create the TCP socket
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen(5)
            # Wrap it; accepted connections perform the TLS handshake
            ssl_sock = self.context.wrap_socket(sock, server_side=True)
            self.running = True
            print(f"TLS server listening on {self.host}:{self.port}")
            while self.running:
                try:
                    client_socket, client_address = ssl_sock.accept()
                    # Handle the TLS client in its own thread
                    client_thread = threading.Thread(
                        target=self._handle_ssl_client,
                        args=(client_socket, client_address)
                    )
                    client_thread.start()
                except ssl.SSLError as e:
                    print(f"TLS error: {e}")
                except Exception as e:
                    print(f"Accept error: {e}")
        except Exception as e:
            print(f"Failed to start TLS server: {e}")

    def _handle_ssl_client(self, client_socket: ssl.SSLSocket, client_address: Tuple[str, int]):
        """Serve one TLS client."""
        try:
            # Client certificate info; empty unless the context requests
            # client certificates (verify_mode), which this one does not
            cert = client_socket.getpeercert()
            if cert:
                print(f"Client certificate subject: {cert.get('subject', {})}")
            # The TLS handshake is complete; traffic is now encrypted
            welcome_msg = "Secure connection established\r\n"
            client_socket.sendall(welcome_msg.encode('utf-8'))
            # Handle client messages
            while True:
                data = client_socket.recv(1024)
                if not data:
                    break
                # recv() returns already-decrypted application data
                message = data.decode('utf-8', errors='ignore').strip()
                print(f"Secure message: {message}")
                # Echo
                response = f"Secure echo: {message}\r\n"
                client_socket.sendall(response.encode('utf-8'))
        except ssl.SSLError as e:
            print(f"TLS communication error: {e}")
        except Exception as e:
            print(f"Client handling error: {e}")
        finally:
            client_socket.close()
class TCPFileTransferServer:
    """TCP file transfer server."""

    def __init__(self, host: str = '0.0.0.0', port: int = 9999):
        self.host = host
        self.port = port
        self.running = False
        self.file_transfers = {}

    def start(self):
        """Start the file transfer server."""
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind((self.host, self.port))
        server_socket.listen(5)
        self.running = True
        print(f"File transfer server listening on {self.host}:{self.port}")
        while self.running:
            client_socket, client_address = server_socket.accept()
            client_thread = threading.Thread(
                target=self._handle_file_transfer,
                args=(client_socket, client_address)
            )
            client_thread.start()

    def _recv_exact(self, sock: socket.socket, n: int) -> bytes:
        """Read exactly n bytes; recv() alone may return fewer."""
        data = b""
        while len(data) < n:
            chunk = sock.recv(n - len(data))
            if not chunk:
                raise ConnectionError("connection closed before the header was complete")
            data += chunk
        return data

    def _handle_file_transfer(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Receive one file."""
        import os
        try:
            # Header: filename length (4 bytes) + filename + file size (8 bytes).
            # Each field is read exactly; a single recv(256) could swallow the
            # first bytes of the file body along with the header.
            filename_len = int.from_bytes(self._recv_exact(client_socket, 4), 'big')
            filename = self._recv_exact(client_socket, filename_len).decode('utf-8')
            file_size = int.from_bytes(self._recv_exact(client_socket, 8), 'big')
            # Keep only the base name so a client cannot write outside this directory
            filename = os.path.basename(filename)
            print(f"Receiving file: {filename}, size: {file_size} bytes")
            # Create the file and receive the data
            with open(f"received_{filename}", 'wb') as f:
                received = 0
                next_report = 10  # report progress once per 10%
                hash_md5 = hashlib.md5()
                while received < file_size:
                    chunk = client_socket.recv(min(4096, file_size - received))
                    if not chunk:
                        break
                    f.write(chunk)
                    hash_md5.update(chunk)
                    received += len(chunk)
                    # Send progress
                    progress = (received / file_size) * 100
                    if progress >= next_report:
                        client_socket.sendall(f"Progress: {progress:.1f}%\r\n".encode('utf-8'))
                        next_report = (int(progress) // 10 + 1) * 10
            # Send the completion message and MD5 (an integrity check, not a security measure)
            md5_hash = hash_md5.hexdigest()
            completion_msg = f"File received, MD5: {md5_hash}\r\n"
            client_socket.sendall(completion_msg.encode('utf-8'))
            print(f"File received: {filename}, MD5: {md5_hash}")
        except Exception as e:
            print(f"File transfer error: {e}")
        finally:
            client_socket.close()
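The file transfer header described above (a 4-byte filename length, the UTF-8 filename, then an 8-byte file size, all big-endian) can be exercised without any sockets. In this sketch, `encode_file_header` and `decode_file_header` are illustrative names. The decoder also returns the header length so the caller knows where the file body begins.

```python
from typing import Tuple

def encode_file_header(filename: str, file_size: int) -> bytes:
    """Build the header: name length (4 bytes) + name + file size (8 bytes)."""
    name = filename.encode("utf-8")
    return len(name).to_bytes(4, "big") + name + file_size.to_bytes(8, "big")

def decode_file_header(data: bytes) -> Tuple[str, int, int]:
    """Return (filename, file_size, header_length) from the start of data."""
    if len(data) < 4:
        raise ValueError("incomplete header: missing filename length")
    name_len = int.from_bytes(data[:4], "big")
    end = 4 + name_len + 8
    if len(data) < end:
        raise ValueError("incomplete header: need more bytes")
    filename = data[4:4 + name_len].decode("utf-8")
    file_size = int.from_bytes(data[4 + name_len:end], "big")
    return filename, file_size, end

if __name__ == "__main__":
    # Header and body arriving in the same read, as can happen with TCP.
    stream = encode_file_header("a.txt", 3) + b"abc"
    filename, size, header_len = decode_file_header(stream)
    print(filename, size, stream[header_len:header_len + size])
```

Because the header has a variable length, a receiver cannot know in advance how many bytes to read for it in one call; it either reads field by field or slices the body off using the returned header length.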
2.5 TCP Client Implementation
python
# tcp_client.py
import socket
import threading
import time
import struct
import sys
from typing import Optional, Callable
import json


class TCPClient:
    """Base TCP client."""

    def __init__(self, host: str = 'localhost', port: int = 8888):
        self.host = host
        self.port = port
        self.socket: Optional[socket.socket] = None
        self.connected = False
        self.receive_thread: Optional[threading.Thread] = None
        self.message_handlers = []

    def connect(self, timeout: float = 5.0) -> bool:
        """Connect to the server."""
        try:
            # Create the TCP socket
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.settimeout(timeout)
            # Connect to the server
            self.socket.connect((self.host, self.port))
            self.connected = True
            print(f"Connected to server {self.host}:{self.port}")
            # Start the receive thread
            self.receive_thread = threading.Thread(target=self._receive_messages, daemon=True)
            self.receive_thread.start()
            return True
        except socket.timeout:
            print("Connection timed out")
            return False
        except ConnectionRefusedError:
            print("Connection refused; the server may not be running")
            return False
        except Exception as e:
            print(f"Connection error: {e}")
            return False

    def send(self, message: str) -> bool:
        """Send one newline-terminated message."""
        if not self.connected or not self.socket:
            print("Not connected to the server")
            return False
        try:
            # Append the line terminator
            if not message.endswith('\n'):
                message += '\n'
            # sendall() keeps sending until every byte is written
            self.socket.sendall(message.encode('utf-8'))
            return True
        except Exception as e:
            print(f"Failed to send message: {e}")
            self.connected = False
            return False

    def _receive_messages(self):
        """Receive thread: split the byte stream into lines."""
        # Buffer raw bytes and decode per line, so a multi-byte UTF-8
        # character split across two recv() calls is not corrupted
        buffer = b""
        while self.connected and self.socket:
            try:
                data = self.socket.recv(1024)
                if not data:
                    print("Server closed the connection")
                    self.connected = False
                    break
                buffer += data
                # Split messages on newlines
                while b'\n' in buffer:
                    raw_line, buffer = buffer.split(b'\n', 1)
                    line = raw_line.decode('utf-8', errors='ignore').strip()
                    if line:
                        self._handle_received_message(line)
            except socket.timeout:
                continue
            except ConnectionResetError:
                print("Connection reset")
                self.connected = False
                break
            except Exception as e:
                print(f"Receive error: {e}")
                self.connected = False
                break

    def _handle_received_message(self, message: str):
        """Handle one received message."""
        print(f"Received: {message}")
        # Call the registered handlers (iterate over a copy, since handlers
        # may be added or removed from another thread)
        for handler in list(self.message_handlers):
            try:
                handler(message)
            except Exception as e:
                print(f"Message handler error: {e}")

    def register_message_handler(self, handler: Callable[[str], None]):
        """Register a message handler."""
        self.message_handlers.append(handler)

    def disconnect(self):
        """Disconnect from the server."""
        self.connected = False
        if self.socket:
            try:
                self.socket.close()
            except OSError:
                pass
        print("Disconnected")
class SmartTCPClient(TCPClient):
    """TCP client with retries, acknowledgments, and file sending."""

    def __init__(self, host='localhost', port=8888):
        super().__init__(host, port)
        self.auto_reconnect = True
        self.reconnect_attempts = 0
        self.max_reconnect_attempts = 5
        self.reconnect_delay = 2  # seconds (initial delay)

    def connect(self, timeout: float = 5.0) -> bool:
        """Connect with automatic retries and exponential backoff."""
        delay = self.reconnect_delay  # local copy, so the base delay is not inflated
        self.reconnect_attempts = 0
        while self.reconnect_attempts < self.max_reconnect_attempts:
            if super().connect(timeout):
                self.reconnect_attempts = 0
                return True
            self.reconnect_attempts += 1
            if not self.auto_reconnect or self.reconnect_attempts >= self.max_reconnect_attempts:
                break
            print(f"Connection failed, retrying in {delay} s... "
                  f"({self.reconnect_attempts}/{self.max_reconnect_attempts})")
            time.sleep(delay)
            delay *= 2  # exponential backoff
        return False

    def send_with_ack(self, message: str, timeout: float = 5.0) -> Optional[str]:
        """Send a message and wait for a reply.

        Note: the next line received from the server is treated as the
        acknowledgment; there is no request ID to match replies to requests.
        """
        ack_event = threading.Event()
        ack_response = None

        def ack_handler(response: str):
            nonlocal ack_response
            ack_response = response
            ack_event.set()

        # Register a temporary handler
        self.register_message_handler(ack_handler)
        try:
            # Send the message
            if not self.send(message):
                return None
            # Wait for the acknowledgment
            if ack_event.wait(timeout):
                return ack_response
            print("Timed out waiting for acknowledgment")
            return None
        finally:
            self.message_handlers.remove(ack_handler)

    def send_file(self, file_path: str, chunk_size: int = 4096) -> bool:
        """Send a file (the peer must speak the file transfer protocol from section 2.4)."""
        try:
            import os
            if not os.path.exists(file_path):
                print(f"File not found: {file_path}")
                return False
            file_size = os.path.getsize(file_path)
            filename = os.path.basename(file_path)
            # Header: filename length + filename + file size
            filename_bytes = filename.encode('utf-8')
            header = (
                len(filename_bytes).to_bytes(4, 'big') +
                filename_bytes +
                file_size.to_bytes(8, 'big')
            )
            self.socket.sendall(header)
            # Send the file contents
            sent_bytes = 0
            next_report = 10  # print progress once per 10%
            with open(file_path, 'rb') as f:
                while sent_bytes < file_size:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break
                    self.socket.sendall(chunk)
                    sent_bytes += len(chunk)
                    # Show progress
                    progress = (sent_bytes / file_size) * 100
                    if progress >= next_report:
                        print(f"Send progress: {progress:.1f}%")
                        next_report = (int(progress) // 10 + 1) * 10
            print(f"File sent: {filename}")
            return True
        except Exception as e:
            print(f"Failed to send file: {e}")
            return False
class TCPBenchmarkClient:
    """TCP benchmark client."""

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.results = []

    def run_latency_test(self, message: str = "PING", iterations: int = 100) -> dict:
        """Run a latency test.

        Each sample covers connection setup plus one request and the first
        bytes of the reply, so it measures more than a bare round trip.
        """
        latencies = []
        for i in range(iterations):
            try:
                # Open a fresh connection for every sample
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(5.0)
                start_time = time.perf_counter()
                sock.connect((self.host, self.port))
                # Send the message
                sock.sendall(f"{message}\n".encode('utf-8'))
                # Receive the response
                response = sock.recv(1024)
                end_time = time.perf_counter()
                latency = (end_time - start_time) * 1000  # milliseconds
                latencies.append(latency)
                sock.close()
                if i % 10 == 0:
                    print(f"Completed {i+1}/{iterations} samples")
                time.sleep(0.1)  # avoid overloading the server
            except Exception as e:
                print(f"Sample {i+1} failed: {e}")
                continue
        # Compute statistics
        if latencies:
            import statistics
            stats = {
                'iterations': iterations,
                'successful': len(latencies),
                'failed': iterations - len(latencies),
                'min_latency': min(latencies),
                'max_latency': max(latencies),
                'avg_latency': statistics.mean(latencies),
                'median_latency': statistics.median(latencies),
                'std_dev': statistics.stdev(latencies) if len(latencies) > 1 else 0,
            }
            print("\nLatency test results:")
            for key, value in stats.items():
                if 'latency' in key:
                    print(f"{key}: {value:.2f} ms")
                else:
                    print(f"{key}: {value}")
            return stats
        else:
            print("All samples failed")
            return {}

    def run_throughput_test(self, duration: int = 10, message_size: int = 1024) -> dict:
        """Run a throughput test."""
        try:
            # Open the connection
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(5.0)
            sock.connect((self.host, self.port))
            # Build the test message once
            payload = ('X' * message_size + '\n').encode('utf-8')
            messages_sent = 0
            bytes_sent = 0
            start_time = time.perf_counter()
            print(f"Starting throughput test, duration: {duration} s")
            while time.perf_counter() - start_time < duration:
                try:
                    sock.sendall(payload)  # whole message or an exception
                    messages_sent += 1
                    bytes_sent += len(payload)
                    # Drain replies without blocking (a 5 s blocking recv
                    # here would stall the send loop)
                    sock.settimeout(0.0)
                    try:
                        if not sock.recv(4096):
                            break  # server closed the connection
                    except BlockingIOError:
                        pass
                    finally:
                        sock.settimeout(5.0)
                except Exception as e:
                    print(f"Send error: {e}")
                    break
            end_time = time.perf_counter()
            actual_duration = end_time - start_time
            sock.close()
            # Compute throughput
            throughput_messages = messages_sent / actual_duration
            throughput_bytes = bytes_sent / actual_duration
            results = {
                'duration': actual_duration,
                'messages_sent': messages_sent,
                'bytes_sent': bytes_sent,
                'throughput_messages': throughput_messages,
                'throughput_bytes': throughput_bytes,
                'throughput_mbps': (throughput_bytes * 8) / 1_000_000,  # megabits per second
            }
            print("\nThroughput test results:")
            print(f"Duration: {actual_duration:.2f} s")
            print(f"Messages sent: {messages_sent}")
            print(f"Bytes sent: {bytes_sent}")
            print(f"Message throughput: {throughput_messages:.2f} messages/s")
            print(f"Byte throughput: {throughput_bytes:.2f} bytes/s")
            print(f"Bandwidth: {results['throughput_mbps']:.2f} Mbps")
            return results
        except Exception as e:
            print(f"Throughput test failed: {e}")
            return {}
def interactive_client():
    """Interactive client."""
    import argparse
    parser = argparse.ArgumentParser(description='TCP client')
    parser.add_argument('--host', default='localhost', help='server address')
    parser.add_argument('--port', type=int, default=8888, help='server port')
    parser.add_argument('--benchmark', action='store_true', help='run the benchmarks')
    args = parser.parse_args()
    if args.benchmark:
        # Run the benchmarks
        benchmark = TCPBenchmarkClient(args.host, args.port)
        print("=== Latency test ===")
        latency_results = benchmark.run_latency_test(iterations=50)
        print("\n=== Throughput test ===")
        throughput_results = benchmark.run_throughput_test(duration=5)
    else:
        # Run the interactive client
        client = SmartTCPClient(args.host, args.port)
        if not client.connect():
            print("Connection failed, exiting")
            return

        # Register a message handler
        def message_handler(msg: str):
            if msg.startswith("Progress:"):
                print(f"\r{msg}", end='')
            else:
                print(f"\n{msg}")

        client.register_message_handler(message_handler)
        print("Type a message to send it, 'quit' to exit, 'file <path>' to send a file")
        try:
            while True:
                try:
                    user_input = input("> ").strip()
                    if not user_input:
                        continue
                    if user_input.lower() == 'quit':
                        break
                    elif user_input.lower().startswith('file '):
                        # Send a file
                        file_path = user_input[5:].strip()
                        if client.send_file(file_path):
                            print("File sent")
                        else:
                            print("File send failed")
                    else:
                        # Send a plain message
                        client.send(user_input)
                except KeyboardInterrupt:
                    print("\nInput interrupted")
                    break
                except EOFError:
                    break
        finally:
            client.disconnect()


if __name__ == "__main__":
    interactive_client()
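The retry schedule used by the reconnect logic can be isolated into a small generator, which makes the delays easy to inspect and to cap. `backoff_delays` and its `cap` parameter are illustrative additions, not part of the client above.

```python
from typing import Iterator

def backoff_delays(base: float, attempts: int, cap: float) -> Iterator[float]:
    """Yield exponentially growing delays: base, 2*base, 4*base, ... capped at cap."""
    for attempt in range(attempts):
        yield min(base * (2 ** attempt), cap)

if __name__ == "__main__":
    # With a 2 s base and a 30 s cap, five retries wait 2, 4, 8, 16, and 30 seconds.
    print(list(backoff_delays(base=2, attempts=5, cap=30)))
```

A cap keeps a long outage from producing multi-minute waits; production clients often add random jitter as well so that many clients do not retry in lockstep.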
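3. UDP Programming in Depth
3.1 UDP Characteristics and Use Cases
UDP (User Datagram Protocol) is a connectionless transport-layer protocol. As the comparison in section 1.2 shows, it sends datagrams without a handshake, offers best-effort delivery with no acknowledgments or ordering guarantees, has no flow or congestion control, and carries only an 8-byte header.
Typical use cases:
- Real-time audio and video: a little packet loss is tolerable, delay is not
- DNS queries: simple request-response with no need for a connection
- Online games: low latency and frequent state synchronization
- Broadcast and multicast applications: one-to-many communication
- IoT device communication: resource-constrained environments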
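Since UDP does not retransmit, a request-response application such as a DNS-style query has to handle loss itself. Below is a minimal sketch of a timeout-and-retry client against a loopback echo server. The names `start_echo_server` and `request_with_retry` and the retry parameters are assumptions made for illustration.

```python
import socket
import threading

def start_echo_server():
    """Start a UDP echo server on a free loopback port; return (address, stop_event)."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.bind(("127.0.0.1", 0))
    sock.settimeout(0.2)              # wake up regularly to check the stop flag
    stop = threading.Event()

    def serve():
        while not stop.is_set():
            try:
                data, addr = sock.recvfrom(2048)
            except socket.timeout:
                continue
            sock.sendto(data, addr)   # echo the datagram back to its sender
        sock.close()

    threading.Thread(target=serve, daemon=True).start()
    return sock.getsockname(), stop

def request_with_retry(server_addr, payload: bytes, retries: int = 3, timeout: float = 0.5):
    """Send a datagram and wait for a reply, resending on timeout."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as client:
        client.settimeout(timeout)
        for _ in range(retries):
            client.sendto(payload, server_addr)
            try:
                data, _ = client.recvfrom(2048)
                return data
            except socket.timeout:
                continue              # request or reply was lost: try again
    return None                       # every attempt timed out

if __name__ == "__main__":
    addr, stop = start_echo_server()
    print(request_with_retry(addr, b"ping"))
    stop.set()
```

Retrying is only safe when the request is idempotent or carries an ID the server can use to detect duplicates, because a lost reply looks the same to the client as a lost request.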
3.2 A Basic UDP Server
python
# udp_basic_server.py
import socket
import threading
import time
import struct
from typing import Tuple, Optional, Dict
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('UDPServer')


class UDPBasicServer:
    """Basic UDP server."""

    def __init__(self, host: str = '0.0.0.0', port: int = 9999):
        """
        Initialize the UDP server.

        Args:
            host: address to listen on
            port: port to listen on
        """
        self.host = host
        self.port = port
        self.server_socket: Optional[socket.socket] = None
        self.running = False
        # Per-client session tracking (UDP itself has no connections,
        # so a "session" is just bookkeeping keyed by client address)
        self.client_sessions: Dict[Tuple[str, int], Dict] = {}
        # Statistics
        self.stats = {
            'packets_received': 0,
            'packets_sent': 0,
            'bytes_received': 0,
            'bytes_sent': 0,
            'errors_total': 0,
            'unique_clients': 0,
        }

    def start(self):
        """Start the UDP server and block in the receive loop."""
        try:
            # Create the UDP socket
            self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            # Socket options
            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Optionally enlarge the receive buffer
            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
            # Bind the address and port
            self.server_socket.bind((self.host, self.port))
            # Timeout so the receive loop can notice self.running changing
            self.server_socket.settimeout(1.0)
            self.running = True
            logger.info(f"UDP server listening on {self.host}:{self.port}")
            # Start the statistics thread
            stats_thread = threading.Thread(target=self._stats_monitor, daemon=True)
            stats_thread.start()
            # Start the session cleanup thread
            cleanup_thread = threading.Thread(target=self._session_cleanup, daemon=True)
            cleanup_thread.start()
            # Main receive loop
            self._receive_loop()
        except Exception as e:
            logger.error(f"Failed to start server: {e}")
            self.stop()
    def _receive_loop(self):
        """Receive datagrams until the server stops."""
        while self.running:
            try:
                # Receive one datagram (65536 exceeds the largest possible UDP payload)
                data, client_address = self.server_socket.recvfrom(65536)
                # Update statistics
                self.stats['packets_received'] += 1
                self.stats['bytes_received'] += len(data)
                # Update the client session
                if client_address not in self.client_sessions:
                    self.client_sessions[client_address] = {
                        'first_seen': time.time(),
                        'last_seen': time.time(),
                        'packets_received': 0,
                        'packets_sent': 0,
                        'bytes_received': 0,
                        'bytes_sent': 0,
                    }
                    self.stats['unique_clients'] += 1
                session = self.client_sessions[client_address]
                session['last_seen'] = time.time()
                session['packets_received'] += 1
                session['bytes_received'] += len(data)
                # Handle the datagram
                self._handle_packet(data, client_address)
            except socket.timeout:
                continue
            except BlockingIOError:
                continue
            except Exception as e:
                logger.error(f"Receive error: {e}")
                self.stats['errors_total'] += 1
                if not self.running:
                    break

    def _handle_packet(self, data: bytes, client_address: Tuple[str, int]):
        """Handle one received datagram."""
        try:
            # Try to decode it as text
            try:
                message = data.decode('utf-8').strip()
                logger.info(f"Message from {client_address}: {message}")
                # Echo the message
                response = f"Server received: {message}"
                self._send_response(response.encode('utf-8'), client_address)
                # Special commands
                if message.lower() == 'time':
                    current_time = time.strftime("%Y-%m-%d %H:%M:%S")
                    time_msg = f"Server time: {current_time}"
                    self._send_response(time_msg.encode('utf-8'), client_address)
                elif message.lower() == 'stats':
                    stats_msg = self._get_stats_message()
                    self._send_response(stats_msg.encode('utf-8'), client_address)
                elif message.lower() == 'session':
                    session_info = self._get_session_info(client_address)
                    self._send_response(session_info.encode('utf-8'), client_address)
            except UnicodeDecodeError:
                # Binary data
                logger.info(f"Binary data from {client_address}, length: {len(data)}")
                # Acknowledge with the first 100 bytes
                self._send_response(b"BINARY_ACK:" + data[:100], client_address)
        except Exception as e:
            logger.error(f"Error handling datagram: {e}")

    def _send_response(self, data: bytes, client_address: Tuple[str, int]):
        """Send a response datagram to a client."""
        try:
            sent = self.server_socket.sendto(data, client_address)
            # Update statistics
            self.stats['packets_sent'] += 1
            self.stats['bytes_sent'] += sent
            # Update session statistics
            if client_address in self.client_sessions:
                self.client_sessions[client_address]['packets_sent'] += 1
                self.client_sessions[client_address]['bytes_sent'] += sent
            logger.debug(f"Sent response to {client_address}, size: {sent} bytes")
        except Exception as e:
            logger.error(f"Error sending response: {e}")
            self.stats['errors_total'] += 1
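    def _get_stats_message(self) -> str:
        """Build a human-readable statistics summary."""
        now = time.time()
        # Snapshot the sessions first: the receive loop may add entries while
        # the monitor thread iterates, which would otherwise raise
        # "dictionary changed size during iteration".
        active_sessions = sum(
            1 for session in list(self.client_sessions.values())
            if now - session['last_seen'] < 300  # active within the last 5 minutes
        )
        return (
            f"UDP server statistics:\n"
            f"- Packets received: {self.stats['packets_received']}\n"
            f"- Packets sent: {self.stats['packets_sent']}\n"
            f"- Bytes received: {self.stats['bytes_received']}\n"
            f"- Bytes sent: {self.stats['bytes_sent']}\n"
            f"- Unique clients: {self.stats['unique_clients']}\n"
            f"- Active sessions: {active_sessions}\n"
            f"- Total errors: {self.stats['errors_total']}\n"
        )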
def _get_session_info(self, client_address: Tuple[str, int]) -> str:
"""获取客户端会话信息"""
if client_address not in self.client_sessions:
return "无会话信息"
session = self.client_sessions[client_address]
duration = time.time() - session['first_seen']
return (
f"客户端 {client_address} 会话信息:\n"
f"- 首次连接: {time.ctime(session['first_seen'])}\n"
f"- 最后活动: {time.ctime(session['last_seen'])}\n"
f"- 会话时长: {duration:.1f} 秒\n"
f"- 接收数据包: {session['packets_received']}\n"
f"- 发送数据包: {session['packets_sent']}\n"
f"- 接收字节数: {session['bytes_received']}\n"
f"- 发送字节数: {session['bytes_sent']}\n"
)
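    def _session_cleanup(self):
        """Periodically remove sessions that have been idle too long."""
        while self.running:
            time.sleep(60)  # run once a minute
            current_time = time.time()
            # Iterate over a snapshot so that concurrent inserts from the
            # receive loop cannot invalidate the iterator.
            expired_clients = [
                client_address
                for client_address, session in list(self.client_sessions.items())
                if current_time - session['last_seen'] > 600  # idle for 10 minutes
            ]
            for client_address in expired_clients:
                # pop() tolerates a session that has already been removed
                self.client_sessions.pop(client_address, None)
                logger.info(f"Removed expired session: {client_address}")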
def _stats_monitor(self):
"""统计监控线程"""
while self.running:
time.sleep(30)
logger.info(f"统计监控 - {self._get_stats_message()}")
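    def stop(self):
        """Stop the server and log the final statistics."""
        self.running = False
        if self.server_socket:
            try:
                self.server_socket.close()
            except OSError:
                pass  # the socket may already be closed
        logger.info("UDP server stopped")
        # Log the final statistics
        logger.info("Final statistics:")
        logger.info(self._get_stats_message())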
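class ReliableUDPServer(UDPBasicServer):
    """Reliable UDP server: adds sequence numbers and ACKs on top of UDP.

    This is a simplified stop-and-wait scheme for illustration. Unlike QUIC it
    has no congestion control, encryption, or stream multiplexing.
    """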
def __init__(self, host='0.0.0.0', port=9999):
super().__init__(host, port)
# 序列号管理
self.client_sequences: Dict[Tuple[str, int], Dict] = {}
# 数据包确认机制
self.pending_acks: Dict[Tuple[str, int], set] = {}
# 重传队列
self.retransmission_queue = []
def _handle_packet(self, data: bytes, client_address: Tuple[str, int]):
"""增强的数据包处理,支持可靠性"""
try:
# 解析数据包头部
if len(data) >= 12:
# 头部格式:4字节魔数 + 4字节序列号 + 4字节确认号
magic, seq_num, ack_num = struct.unpack('!III', data[:12])
if magic == 0x12345678: # 验证魔数
payload = data[12:]
# 处理确认
if ack_num > 0:
self._handle_acknowledgement(client_address, ack_num)
# 处理序列号
expected_seq = self._get_expected_sequence(client_address)
if seq_num == expected_seq:
# 正确的序列号,处理数据
self._process_reliable_payload(payload, client_address)
# 发送确认
self._send_acknowledgement(client_address, seq_num)
# 更新期望序列号
self._update_expected_sequence(client_address, seq_num + 1)
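                    elif seq_num < expected_seq:
                        # Duplicate packet: the earlier ACK was probably lost,
                        # so acknowledge this sequence number again
                        self._send_acknowledgement(client_address, seq_num)
                    else:
                        # Out-of-order packet: this simplified server does not
                        # buffer it; it is dropped and the sender must retransmit
                        logger.warning(f"Out-of-order packet: got {seq_num}, expected {expected_seq}")
                    return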
# 普通UDP数据包处理
super()._handle_packet(data, client_address)
except Exception as e:
logger.error(f"可靠数据包处理错误: {e}")
def _get_expected_sequence(self, client_address: Tuple[str, int]) -> int:
"""获取期望的序列号"""
if client_address not in self.client_sequences:
self.client_sequences[client_address] = {
'expected_seq': 1,
'last_acked': 0,
}
return self.client_sequences[client_address]['expected_seq']
def _update_expected_sequence(self, client_address: Tuple[str, int], seq_num: int):
"""更新期望序列号"""
if client_address in self.client_sequences:
self.client_sequences[client_address]['expected_seq'] = seq_num
    def _send_acknowledgement(self, client_address: Tuple[str, int], ack_num: int):
        """Send a bare ACK (sequence number 0) for a received packet."""
        # ACKs are not acknowledged themselves, so nothing is tracked here.
        ack_packet = struct.pack('!III', 0x12345678, 0, ack_num)
        self._send_response(ack_packet, client_address)

    def _handle_acknowledgement(self, client_address: Tuple[str, int], ack_num: int):
        """Handle the client's ACK for a reliable packet this server sent."""
        pending = self.pending_acks.get(client_address)
        if not pending or ack_num not in pending:
            return
        pending.remove(ack_num)
        # The acknowledged packet no longer needs to be retransmitted
        self.retransmission_queue = [
            item for item in self.retransmission_queue
            if (item['client_address'], item['seq_num']) != (client_address, ack_num)
        ]
        state = self.client_sequences.get(client_address)
        if state is not None:
            state['last_acked'] = max(state['last_acked'], ack_num)
        logger.debug(f"ACK {ack_num} received from {client_address}")

    def _process_reliable_payload(self, payload: bytes, client_address: Tuple[str, int]):
        """Process the payload of an in-order reliable packet."""
        try:
            message = payload.decode('utf-8').strip()
            logger.info(f"Reliable message from {client_address}: {message}")
            response = f"Reliable ack: {message}"
            self._send_reliable_response(response.encode('utf-8'), client_address)
        except UnicodeDecodeError:
            logger.info(f"Reliable binary data from {client_address}, length: {len(payload)}")

    def _send_reliable_response(self, data: bytes, client_address: Tuple[str, int]):
        """Send a sequenced packet and remember it until the client ACKs it."""
        state = self.client_sequences.setdefault(
            client_address, {'expected_seq': 1, 'last_acked': 0}
        )
        # Simplified numbering: each outgoing packet gets the next number for
        # this client. A full implementation would also handle wrap-around.
        seq_num = state.get('next_seq', state['last_acked'] + 1)
        state['next_seq'] = seq_num + 1
        reliable_packet = struct.pack('!III', 0x12345678, seq_num, 0) + data
        # Register the packet as awaiting an ACK before sending it
        self.pending_acks.setdefault(client_address, set()).add(seq_num)
        self.retransmission_queue.append({
            'client_address': client_address,
            'packet': reliable_packet,
            'seq_num': seq_num,
            'timestamp': time.time(),
            'retry_count': 0,
        })
        # Note: nothing in this example drains the queue on a timer, so a lost
        # response is not actually retransmitted.
        self._send_response(reliable_packet, client_address)
class UDPBroadcastServer:
"""UDP广播服务器"""
def __init__(self, broadcast_port: int = 9998):
self.broadcast_port = broadcast_port
self.running = False
# 服务发现信息
self.service_info = {
'service_name': 'UDP_Chat_Server',
'version': '1.0.0',
'port': 9999,
'host': self._get_local_ip(),
'timestamp': time.time(),
}
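    def _get_local_ip(self) -> str:
        """Return the local IP address used for outbound traffic."""
        try:
            # connect() on a UDP socket sends no packets; it only makes the OS
            # choose the outgoing interface, whose address is then read back
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('8.8.8.8', 80))
                return s.getsockname()[0]
        except OSError:
            return '127.0.0.1'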
def start(self):
"""启动广播服务器"""
try:
# 创建UDP广播socket
self.broadcast_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.broadcast_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.broadcast_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
self.running = True
logger.info(f"UDP广播服务器启动,端口: {self.broadcast_port}")
# 广播线程
broadcast_thread = threading.Thread(target=self._broadcast_loop, daemon=True)
broadcast_thread.start()
# 监听线程
listen_thread = threading.Thread(target=self._listen_for_discovery, daemon=True)
listen_thread.start()
except Exception as e:
logger.error(f"广播服务器启动失败: {e}")
def _broadcast_loop(self):
"""广播循环"""
while self.running:
try:
# 更新时间戳
self.service_info['timestamp'] = time.time()
# 序列化服务信息
import json
service_json = json.dumps(self.service_info).encode('utf-8')
# 广播到局域网
broadcast_address = ('255.255.255.255', self.broadcast_port)
self.broadcast_socket.sendto(service_json, broadcast_address)
logger.debug(f"广播服务信息: {self.service_info}")
time.sleep(10) # 每10秒广播一次
except Exception as e:
logger.error(f"广播错误: {e}")
time.sleep(5)
    def _listen_for_discovery(self):
        """Answer DISCOVER requests with the service information."""
        import json
        listen_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listen_socket.bind(('0.0.0.0', self.broadcast_port))
        # Without a timeout recvfrom() blocks indefinitely, so the loop would
        # never notice that self.running has been cleared
        listen_socket.settimeout(1.0)
        try:
            while self.running:
                try:
                    data, client_address = listen_socket.recvfrom(1024)
                    request = data.decode('utf-8', errors='ignore').strip()
                    if request == 'DISCOVER':
                        logger.info(f"Discovery request from: {client_address}")
                        response = json.dumps(self.service_info).encode('utf-8')
                        listen_socket.sendto(response, client_address)
                except socket.timeout:
                    continue
                except Exception as e:
                    logger.error(f"Discovery listener error: {e}")
        finally:
            listen_socket.close()
def main():
    """Entry point."""
    import argparse
    parser = argparse.ArgumentParser(description='UDP server')
    parser.add_argument('--host', default='0.0.0.0', help='address to listen on')
    parser.add_argument('--port', type=int, default=9999, help='port to listen on')
    parser.add_argument('--reliable', action='store_true', help='use reliable UDP')
    parser.add_argument('--broadcast', action='store_true', help='enable service-discovery broadcasts')
    args = parser.parse_args()
    # Start the discovery broadcaster if requested
    broadcast_server = None
    if args.broadcast:
        broadcast_server = UDPBroadcastServer()
        # Advertise the port the UDP server actually listens on instead of
        # the hard-coded default
        broadcast_server.service_info['port'] = args.port
        broadcast_server.start()
    # Create and start the UDP server
    if args.reliable:
        server = ReliableUDPServer(args.host, args.port)
    else:
        server = UDPBasicServer(args.host, args.port)
    try:
        server.start()
    except KeyboardInterrupt:
        logger.info("Interrupt received, shutting down the server...")
        server.stop()
        if broadcast_server:
            broadcast_server.running = False
    except Exception as e:
        logger.error(f"Server error: {e}")
        server.stop()
        if broadcast_server:
            broadcast_server.running = False
if __name__ == "__main__":
main()
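The receive path of `ReliableUDPServer` comes down to one decision per packet: deliver and ACK it, re-ACK a duplicate, or drop a packet that arrived ahead of its turn. The following is a minimal, socket-free sketch of that decision, assuming the same 12-byte header (magic, sequence number, ack number). The names `build_packet`, `StopAndWaitReceiver`, and `on_packet` are hypothetical and are not part of the server above.

```python
import struct
from typing import List, Optional

HEADER = struct.Struct('!III')  # magic, sequence number, ack number (network byte order)
MAGIC = 0x12345678


def build_packet(seq_num: int, ack_num: int, payload: bytes = b'') -> bytes:
    """Prefix the payload with the 12-byte header."""
    return HEADER.pack(MAGIC, seq_num, ack_num) + payload


class StopAndWaitReceiver:
    """Tracks the next expected sequence number for a single peer."""

    def __init__(self) -> None:
        self.expected_seq = 1
        self.delivered: List[bytes] = []

    def on_packet(self, packet: bytes) -> Optional[bytes]:
        """Process one datagram; return the ACK packet to send back, if any."""
        if len(packet) < HEADER.size:
            return None  # too short to carry a header
        magic, seq_num, _ack_num = HEADER.unpack_from(packet)
        if magic != MAGIC or seq_num == 0:
            return None  # not a reliable data packet (sequence 0 marks a bare ACK)
        if seq_num == self.expected_seq:
            # In order: deliver the payload and advance
            self.delivered.append(packet[HEADER.size:])
            self.expected_seq += 1
            return build_packet(0, seq_num)
        if seq_num < self.expected_seq:
            # Duplicate: the earlier ACK was probably lost, so repeat it
            return build_packet(0, seq_num)
        # Gap: drop the packet and let the sender retransmit
        return None
```

Keeping the decision separate from the socket code makes it easy to exercise duplicates and gaps without a network, for example by feeding `on_packet` the same packet twice and checking that `delivered` grows only once.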
3.3 UDP Client Implementation
python
# udp_client.py
import queue
import select
import socket
import struct
import threading
import time
from typing import Callable, Optional, Tuple, Union


class UDPClient:
    """Base UDP client."""

    def __init__(self, server_host: str = 'localhost', server_port: int = 9999):
        self.server_host = server_host
        self.server_port = server_port
        self.client_socket: Optional[socket.socket] = None
        self.running = False
        self.receive_thread: Optional[threading.Thread] = None
        self.message_handlers = []
        # Set while send_with_response() is waiting: the receive thread hands
        # the next datagram to this queue instead of the message handlers
        self._pending_response: Optional[queue.Queue] = None
        # Statistics
        self.stats = {
            'packets_sent': 0,
            'packets_received': 0,
            'bytes_sent': 0,
            'bytes_received': 0,
            'timeout_count': 0,
        }

    def connect(self):
        """Create the socket and start the receive thread (UDP has no connection)."""
        try:
            self.client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self.client_socket.settimeout(5.0)
            # Request a larger send buffer
            self.client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
            # Bind to an ephemeral port up front so that the receive thread can
            # poll the socket before the first datagram is sent
            self.client_socket.bind(('', 0))
            self.running = True
            print(f"UDP client ready, target server: {self.server_host}:{self.server_port}")
            self.receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
            self.receive_thread.start()
            return True
        except Exception as e:
            print(f"Client initialization failed: {e}")
            return False

    def send(self, message: Union[str, bytes], timeout: float = 5.0) -> bool:
        """Send a message to the server; str is encoded as UTF-8, bytes are sent as-is.

        `timeout` is unused and kept only for call compatibility.
        """
        if not self.client_socket:
            print("Client not initialized")
            return False
        try:
            data = message if isinstance(message, bytes) else message.encode('utf-8')
            sent = self.client_socket.sendto(data, (self.server_host, self.server_port))
            self.stats['packets_sent'] += 1
            self.stats['bytes_sent'] += sent
            print(f"Sent {sent} bytes to the server")
            return True
        except socket.timeout:
            print("Send timed out")
            self.stats['timeout_count'] += 1
            return False
        except Exception as e:
            print(f"Send failed: {e}")
            return False

    def send_with_response(self, message: Union[str, bytes], timeout: float = 5.0) -> Optional[bytes]:
        """Send a message and wait for the next datagram from the server."""
        # The receive thread is the socket's only reader; calling recvfrom()
        # here as well would race with it. Register a one-slot queue before
        # sending so that the reply cannot be missed.
        waiter: queue.Queue = queue.Queue(maxsize=1)
        self._pending_response = waiter
        try:
            if not self.send(message):
                return None
            response = waiter.get(timeout=timeout)
            print(f"Response received, size: {len(response)} bytes")
            return response
        except queue.Empty:
            print("Timed out waiting for a response")
            self.stats['timeout_count'] += 1
            return None
        finally:
            self._pending_response = None

    def _receive_loop(self):
        """Receive datagrams in the background (the socket's only reader)."""
        while self.running and self.client_socket:
            try:
                # Poll with a short timeout so the loop can notice shutdown
                ready = select.select([self.client_socket], [], [], 0.1)
                if not ready[0]:
                    continue
                data, server_address = self.client_socket.recvfrom(65536)
                self.stats['packets_received'] += 1
                self.stats['bytes_received'] += len(data)
                waiter = self._pending_response
                if waiter is not None:
                    try:
                        # Hand the datagram to the waiting send_with_response()
                        waiter.put_nowait(data)
                        continue
                    except queue.Full:
                        pass  # the waiter already has its reply
                self._handle_received_data(data, server_address)
            except socket.timeout:
                continue
            except Exception as e:
                if self.running:
                    print(f"Receive error: {e}")
def _handle_received_data(self, data: bytes, server_address: Tuple[str, int]):
"""处理接收到的数据"""
try:
# 尝试解码为文本
try:
message = data.decode('utf-8').strip()
print(f"收到来自 {server_address} 的消息: {message}")
# 调用消息处理器
for handler in self.message_handlers:
try:
handler(message, server_address)
except Exception as e:
print(f"消息处理器错误: {e}")
except UnicodeDecodeError:
# 二进制数据
print(f"收到来自 {server_address} 的二进制数据,大小: {len(data)} 字节")
except Exception as e:
print(f"处理接收数据错误: {e}")
def register_message_handler(self, handler: Callable[[str, Tuple], None]):
"""注册消息处理器"""
self.message_handlers.append(handler)
    def disconnect(self):
        """Close the socket and print the statistics."""
        self.running = False
        if self.client_socket:
            try:
                self.client_socket.close()
            except OSError:
                pass  # already closed
        print("UDP client disconnected")
        # Print the statistics
        print("\nClient statistics:")
        for key, value in self.stats.items():
            print(f"{key}: {value}")
class ReliableUDPClient(UDPClient):
    """Reliable UDP client: stop-and-wait with ACKs and bounded retries."""

    def __init__(self, server_host='localhost', server_port=9999):
        super().__init__(server_host, server_port)
        # Reliability state
        self.sequence_number = 1
        self.expected_ack = 0
        self.pending_packets = {}  # packets awaiting an ACK, keyed by sequence number
        self.retry_limit = 3
        self.retry_delay = 1.0  # seconds
        # The receive thread is the socket's only reader. It updates the state
        # above and buffers server payloads under this condition variable.
        self._cond = threading.Condition()
        self._responses = []

    def send_reliable(self, message: str, timeout: float = 5.0) -> Optional[bytes]:
        """Send a message reliably and return the payload of the server's reply."""
        packet_id = self.sequence_number
        reliable_data = self._create_reliable_packet(message.encode('utf-8'))
        with self._cond:
            self._responses.clear()  # discard replies left over from earlier calls
        for attempt in range(self.retry_limit):
            # Register the packet before sending so an early ACK is not missed
            with self._cond:
                self.pending_packets[packet_id] = {
                    'data': reliable_data,
                    'send_time': time.time(),
                    'attempt': attempt + 1,
                }
            if not self.send(reliable_data):
                return None
            deadline = time.time() + timeout
            with self._cond:
                acked = self._cond.wait_for(
                    lambda: packet_id < self.expected_ack, timeout
                )
            if acked:
                print(f"Message acknowledged, sequence number: {packet_id}")
                # The reply may already be buffered if it arrived before the ACK
                return self._wait_for_response(deadline - time.time())
            print(f"Attempt {attempt + 1} timed out, retrying...")
            time.sleep(self.retry_delay)
        with self._cond:
            self.pending_packets.pop(packet_id, None)
        print(f"Send failed after {self.retry_limit} attempts")
        return None

    def _create_reliable_packet(self, data: bytes) -> bytes:
        """Prefix the payload with the header and advance the sequence number."""
        # Header: 4-byte magic + 4-byte sequence number + 4-byte ack number
        packet_header = struct.pack('!III', 0x12345678, self.sequence_number, 0)
        self.sequence_number += 1
        return packet_header + data

    def _wait_for_response(self, timeout: float) -> Optional[bytes]:
        """Return a reply payload (header stripped) buffered by the receive thread."""
        with self._cond:
            if self._cond.wait_for(lambda: bool(self._responses), max(timeout, 0.0)):
                return self._responses.pop(0)
        return None

    def _handle_received_data(self, data: bytes, server_address: Tuple[str, int]):
        """Handle incoming datagrams, including ACKs (runs on the receive thread)."""
        if len(data) >= 12:
            magic, seq_num, ack_num = struct.unpack('!III', data[:12])
            if magic == 0x12345678:
                # ACK for a packet this client sent
                if ack_num > 0:
                    self._handle_acknowledgement(ack_num)
                # Data packet from the server
                if seq_num > 0:
                    payload = data[12:]
                    with self._cond:
                        self._responses.append(payload)
                        self._cond.notify_all()
                    super()._handle_received_data(payload, server_address)
                    self._send_acknowledgement(seq_num)
                return
        # Plain datagram without the reliable header
        super()._handle_received_data(data, server_address)

    def _handle_acknowledgement(self, ack_num: int):
        """Record an ACK and wake any sender waiting for it."""
        with self._cond:
            if ack_num in self.pending_packets:
                print(f"ACK received: {ack_num}")
                del self.pending_packets[ack_num]
            if ack_num >= self.expected_ack:
                self.expected_ack = ack_num + 1
            self._cond.notify_all()

    def _send_acknowledgement(self, seq_num: int):
        """Send a bare ACK (sequence number 0) for a server packet."""
        ack_packet = struct.pack('!III', 0x12345678, 0, seq_num)
        self.send(ack_packet)
class UDPDiscoveryClient:
"""UDP服务发现客户端"""
def __init__(self, broadcast_port: int = 9998, timeout: float = 10.0):
self.broadcast_port = broadcast_port
self.timeout = timeout
self.discovered_services = []
    def discover_services(self) -> list:
        """Broadcast a DISCOVER request and collect the replies."""
        import json
        print(f"Starting service discovery, timeout: {self.timeout} s")
        self.discovered_services = []
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as discover_socket:
                discover_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                discover_socket.settimeout(1.0)
                # No bind(): the OS picks an ephemeral source port on the first
                # sendto() and servers reply to that address. Binding to the
                # broadcast port would also hand our own DISCOVER back to us.
                discover_socket.sendto(b'DISCOVER', ('255.255.255.255', self.broadcast_port))
                deadline = time.time() + self.timeout
                while time.time() < deadline:
                    try:
                        data, address = discover_socket.recvfrom(1024)
                    except socket.timeout:
                        continue
                    try:
                        service_info = json.loads(data.decode('utf-8'))
                    except ValueError:
                        # Covers both UnicodeDecodeError and json.JSONDecodeError
                        service_info = None
                    if not isinstance(service_info, dict):
                        print(f"Invalid service info from {address}")
                        continue
                    service_info['discovery_address'] = address
                    service_info['discovery_time'] = time.time()
                    self.discovered_services.append(service_info)
                    print(f"Found service: {service_info.get('service_name', 'unknown')} "
                          f"at {address[0]}:{service_info.get('port', 'unknown')}")
            print(f"\nFound {len(self.discovered_services)} service(s)")
            return self.discovered_services
        except OSError as e:
            print(f"Service discovery failed: {e}")
            return []
def listen_for_broadcasts(self, duration: float = 30.0):
"""监听服务广播"""
print(f"开始监听服务广播,持续时间: {duration} 秒")
try:
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind(('0.0.0.0', self.broadcast_port))
listen_socket.settimeout(1.0)
self.discovered_services = []
start_time = time.time()
while time.time() - start_time < duration:
try:
data, address = listen_socket.recvfrom(1024)
# 解析服务信息
import json
try:
service_info = json.loads(data.decode('utf-8'))
# 检查是否已存在
existing = False
for svc in self.discovered_services:
if (svc.get('host') == service_info.get('host') and
svc.get('port') == service_info.get('port')):
svc['last_seen'] = time.time()
existing = True
break
if not existing:
service_info['discovery_address'] = address
service_info['discovery_time'] = time.time()
service_info['last_seen'] = time.time()
self.discovered_services.append(service_info)
print(f"发现新服务: {service_info.get('service_name', '未知')} "
f"来自 {address[0]}:{service_info.get('port', '未知')}")
except json.JSONDecodeError:
pass # 忽略非JSON数据
except socket.timeout:
continue
listen_socket.close()
print(f"\n监听结束,共发现 {len(self.discovered_services)} 个服务")
return self.discovered_services
except Exception as e:
print(f"监听广播失败: {e}")
return []
def interactive_udp_client():
"""交互式UDP客户端"""
import argparse
parser = argparse.ArgumentParser(description='UDP客户端')
parser.add_argument('--host', default='localhost', help='服务器地址')
parser.add_argument('--port', type=int, default=9999, help='服务器端口')
parser.add_argument('--reliable', action='store_true', help='使用可靠UDP')
parser.add_argument('--discover', action='store_true', help='发现服务')
parser.add_argument('--listen', action='store_true', help='监听服务广播')
args = parser.parse_args()
if args.discover or args.listen:
# 服务发现模式
discovery = UDPDiscoveryClient()
if args.listen:
services = discovery.listen_for_broadcasts(duration=30)
else:
services = discovery.discover_services()
if services:
print("\n发现的服务列表:")
for i, service in enumerate(services, 1):
print(f"{i}. {service.get('service_name')} - "
f"{service.get('host')}:{service.get('port')} "
f"(版本: {service.get('version', '未知')})")
# 选择服务连接
choice = input("\n选择要连接的服务编号 (0退出): ")
            try:
                choice_idx = int(choice) - 1
                if 0 <= choice_idx < len(services):
                    args.host = services[choice_idx]['host']
                    args.port = services[choice_idx]['port']
                else:
                    return
            except (ValueError, KeyError):
                # Non-numeric input, or a service entry without host/port
                return
else:
print("未发现任何服务")
return
# 创建客户端
if args.reliable:
client = ReliableUDPClient(args.host, args.port)
else:
client = UDPClient(args.host, args.port)
if not client.connect():
print("客户端连接失败")
return
# 注册消息处理器
def message_handler(msg: str, address: Tuple[str, int]):
print(f"\n[来自 {address}] {msg}")
client.register_message_handler(message_handler)
print("输入消息发送到服务器,输入 'quit' 退出")
print("特殊命令: 'time' - 获取服务器时间, 'stats' - 获取统计信息")
try:
while True:
try:
user_input = input("> ").strip()
if not user_input:
continue
if user_input.lower() == 'quit':
break
# 发送消息
if args.reliable:
response = client.send_reliable(user_input, timeout=5)
if response:
try:
print(f"服务器响应: {response.decode('utf-8')}")
except:
print(f"服务器响应 (二进制): {response[:50]}...")
else:
response = client.send_with_response(user_input, timeout=5)
if response:
try:
print(f"服务器响应: {response.decode('utf-8')}")
except:
print(f"服务器响应 (二进制): {response[:50]}...")
except KeyboardInterrupt:
print("\n中断输入")
break
except EOFError:
break
finally:
client.disconnect()
if __name__ == "__main__":
interactive_udp_client()
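The retry loop in `send_reliable` resends after a fixed timeout. A common variation lengthens the wait after each failed attempt. The sketch below isolates that loop from the socket code so the policy can be tested on its own. It is an illustration under assumptions: `send_with_retry`, its parameters, and the two callables it takes are hypothetical names, not part of the client above.

```python
from typing import Callable, Optional


def send_with_retry(
    send: Callable[[], None],
    wait_for_ack: Callable[[float], bool],
    retry_limit: int = 3,
    timeout: float = 1.0,
    backoff: float = 2.0,
) -> Optional[int]:
    """Send until acknowledged; return the number of attempts used, or None.

    `send` transmits the packet once. `wait_for_ack(t)` blocks for at most
    `t` seconds and reports whether the ACK arrived.
    """
    for attempt in range(1, retry_limit + 1):
        send()
        if wait_for_ack(timeout):
            return attempt
        # No ACK: wait longer on the next attempt (exponential backoff)
        timeout *= backoff
    return None
```

With the client above, `send` could wrap `client.send(packet)` and `wait_for_ack` could wait on whatever signal the receive thread sets when an ACK arrives.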
4. Advanced Network Programming Techniques
4.1 Asynchronous I/O and Concurrency Models
python
# async_network.py
import asyncio
import socket
import selectors
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Tuple
import time
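class AsyncTCPServer:
    """Event-driven TCP server: one thread, non-blocking sockets, selectors.

    Despite the name it does not use asyncio; the selector loop below plays
    the role of the event loop.
    """

    def __init__(self, host='0.0.0.0', port=8888):
        self.host = host
        self.port = port
        self.selector = selectors.DefaultSelector()
        self.running = False
        self.client_count = 0

    def start(self):
        """Run the event loop; blocks until self.running is cleared.

        This was declared `async def` but never awaited anything, so it would
        have blocked any asyncio loop that ran it. It is a plain method now.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.setblocking(False)
        server_socket.bind((self.host, self.port))
        server_socket.listen()
        print(f"Event-driven TCP server listening on {self.host}:{self.port}")
        # Register the listening socket; the callback travels in key.data
        self.selector.register(server_socket, selectors.EVENT_READ, self._accept_connection)
        self.running = True
        try:
            while self.running:
                events = self.selector.select(timeout=1)
                for key, mask in events:
                    callback = key.data
                    callback(key.fileobj, mask)
        finally:
            self.selector.close()
            server_socket.close()

    def _accept_connection(self, sock: socket.socket, mask: int):
        """Accept a new connection and watch it for readable data."""
        client_socket, client_address = sock.accept()
        client_socket.setblocking(False)
        self.client_count += 1
        print(f"New client: {client_address} (total: {self.client_count})")
        self.selector.register(client_socket, selectors.EVENT_READ,
                               self._handle_client_data)

    def _close_client(self, sock: socket.socket):
        """Stop watching a client socket and close it."""
        self.selector.unregister(sock)
        sock.close()
        self.client_count -= 1

    def _handle_client_data(self, sock: socket.socket, mask: int):
        """Read from a client socket and echo the message back."""
        try:
            data = sock.recv(1024)
        except BlockingIOError:
            return  # spurious wakeup: nothing to read yet
        except ConnectionError:
            print("Connection reset")
            self._close_client(sock)
            return
        if not data:
            print("Client disconnected")
            self._close_client(sock)
            return
        # errors='replace' keeps malformed input from crashing the whole loop
        message = data.decode('utf-8', errors='replace').strip()
        print(f"Received: {message}")
        response = f"ECHO: {message}\n"
        try:
            # Fine for short replies; large writes on a non-blocking socket
            # would need an output buffer and EVENT_WRITE
            sock.sendall(response.encode('utf-8'))
        except (BlockingIOError, ConnectionError):
            self._close_client(sock)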
class ThreadPoolTCPServer:
"""线程池TCP服务器"""
def __init__(self, host='0.0.0.0', port=8888, max_workers=10):
self.host = host
self.port = port
self.max_workers = max_workers
self.running = False
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.client_tasks = []
def start(self):
"""启动线程池服务器"""
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((self.host, self.port))
server_socket.listen()
self.running = True
print(f"线程池TCP服务器启动在 {self.host}:{self.port}, 最大线程数: {self.max_workers}")
try:
while self.running:
client_socket, client_address = server_socket.accept()
# 提交任务到线程池
future = self.executor.submit(self._handle_client, client_socket, client_address)
self.client_tasks.append(future)
# 清理已完成的任务
self.client_tasks = [f for f in self.client_tasks if not f.done()]
except KeyboardInterrupt:
print("收到中断信号")
finally:
self.stop()
def _handle_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
"""处理客户端连接(在线程中运行)"""
try:
print(f"线程 {threading.current_thread().name} 处理客户端: {client_address}")
# 发送欢迎消息
welcome = f"欢迎,你是第 {len(self.client_tasks)} 个连接\n"
client_socket.send(welcome.encode('utf-8'))
# 处理客户端消息
while True:
data = client_socket.recv(1024)
if not data:
break
message = data.decode('utf-8').strip()
print(f"线程 {threading.current_thread().name} 收到: {message}")
# 模拟耗时操作
if message == 'slow':
time.sleep(5)
response = "慢操作完成\n"
else:
response = f"线程 {threading.current_thread().name} 回复: {message}\n"
client_socket.send(response.encode('utf-8'))
except Exception as e:
print(f"客户端处理错误: {e}")
finally:
client_socket.close()
print(f"客户端 {client_address} 连接关闭")
def stop(self):
"""停止服务器"""
self.running = False
self.executor.shutdown(wait=True)
print("服务器已停止")
class IOMultiplexingServer:
"""I/O多路复用服务器"""
def __init__(self, host='0.0.0.0', port=8888):
self.host = host
self.port = port
self.running = False
# 使用epoll(Linux)或kqueue(BSD)
if hasattr(selectors, 'EpollSelector'):
self.selector = selectors.EpollSelector()
print("使用epoll")
elif hasattr(selectors, 'KqueueSelector'):
self.selector = selectors.KqueueSelector()
print("使用kqueue")
else:
self.selector = selectors.DefaultSelector()
print("使用select")
def start(self):
"""启动I/O多路复用服务器"""
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.setblocking(False)
server_socket.bind((self.host, self.port))
server_socket.listen()
# 注册服务器socket
self.selector.register(server_socket, selectors.EVENT_READ, self._accept)
self.running = True
print(f"I/O多路复用服务器启动在 {self.host}:{self.port}")
# 事件循环
while self.running:
try:
events = self.selector.select(timeout=0.1)
for key, mask in events:
callback = key.data
callback(key.fileobj, mask)
except KeyboardInterrupt:
print("收到中断信号")
break
except Exception as e:
print(f"事件循环错误: {e}")
break
self.selector.close()
server_socket.close()
def _accept(self, sock: socket.socket, mask: int):
"""接受新连接"""
client_socket, client_address = sock.accept()
client_socket.setblocking(False)
print(f"新连接: {client_address}")
# 注册客户端socket
self.selector.register(client_socket, selectors.EVENT_READ, self._read)
def _read(self, sock: socket.socket, mask: int):
"""读取数据"""
try:
data = sock.recv(1024)
if data:
message = data.decode('utf-8').strip()
print(f"收到: {message}")
# 回显
response = f"多路复用回显: {message}\n"
sock.send(response.encode('utf-8'))
else:
# 客户端断开连接
print("客户端断开连接")
self.selector.unregister(sock)
sock.close()
except ConnectionResetError:
print("连接被重置")
self.selector.unregister(sock)
sock.close()
def _run_server(server_class, port):
    """Child-process entry point.

    It lives at module level because a lambda or nested function cannot be
    pickled when multiprocessing uses the spawn start method.
    """
    server_class(port=port).start()


def test_server(server_class, server_name, num_clients=100, port=9990):
    """Time sequential connect/send/recv round trips against one server."""
    import multiprocessing
    print(f"\n=== Testing {server_name} ===")
    # Start the server in a separate process
    server_process = multiprocessing.Process(target=_run_server, args=(server_class, port))
    server_process.start()
    time.sleep(2)  # crude wait for the server to start listening
    succeeded = 0
    start_time = time.time()
    for _ in range(num_clients):
        try:
            with socket.create_connection(('localhost', port), timeout=5) as sock:
                sock.sendall(b'test\n')
                sock.recv(1024)
            succeeded += 1
        except OSError:
            pass  # count only the clients that completed a round trip
    duration = time.time() - start_time
    print(f"Served {succeeded}/{num_clients} clients")
    print(f"Total time: {duration:.2f} s")
    print(f"Average: {duration / num_clients * 1000:.2f} ms per client")
    # Stop the server
    server_process.terminate()
    server_process.join()
    return succeeded, duration


def benchmark_servers():
    """Compare the server implementations."""
    results = []
    # Basic server
    from tcp_basic_server import TCPBasicServer
    clients, duration = test_server(TCPBasicServer, "basic TCP server", 50)
    results.append(("basic TCP server", clients, duration))
    # Thread-pool server
    clients, duration = test_server(ThreadPoolTCPServer, "thread-pool TCP server", 50)
    results.append(("thread-pool TCP server", clients, duration))
    # Print the comparison
    print("\n=== Comparison ===")
    for name, clients, duration in results:
        # Guard against division by zero when no client succeeded
        per_client = duration / clients * 1000 if clients else float('nan')
        print(f"{name}: {clients} clients, {duration:.2f} s, "
              f"{per_client:.2f} ms per client")
if __name__ == "__main__":
# 运行基准测试
benchmark_servers()
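The servers above multiplex sockets with `selectors` or threads; none of them uses the `asyncio` module that the file imports. For comparison, here is a minimal sketch of the same line-echo behaviour written with asyncio streams. The function names `handle_echo` and `echo_once` are hypothetical, and the server binds to an ephemeral loopback port purely so that the example is self-contained.

```python
import asyncio


async def handle_echo(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Echo each line back to the client until it disconnects."""
    try:
        while True:
            line = await reader.readline()
            if not line:  # EOF: the client closed the connection
                break
            message = line.decode('utf-8', errors='replace').strip()
            writer.write(f"ECHO: {message}\n".encode('utf-8'))
            await writer.drain()  # wait for the send buffer instead of growing it
    finally:
        writer.close()
        await writer.wait_closed()


async def echo_once(message: str) -> str:
    """Start the server, send one line as a client, and return the reply."""
    server = await asyncio.start_server(handle_echo, '127.0.0.1', 0)
    port = server.sockets[0].getsockname()[1]  # port chosen by the OS
    async with server:
        reader, writer = await asyncio.open_connection('127.0.0.1', port)
        writer.write((message + '\n').encode('utf-8'))
        await writer.drain()
        reply = await reader.readline()
        writer.close()
        await writer.wait_closed()
    return reply.decode('utf-8').strip()
```

Running `asyncio.run(echo_once("hello"))` performs one round trip. Each connection gets its own coroutine, so a slow client suspends at an `await` instead of holding up the others or occupying a thread.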
4.2 Network Security and Encryption
python
# network_security.py
import socket
import ssl
import hashlib
import hmac
import base64
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC as PBKDF2  # the class is named PBKDF2HMAC; aliased to the name used below
from typing import Tuple, Optional
import os
class SecureTCPConnection:
"""安全TCP连接"""
def __init__(self, use_ssl: bool = True, verify_cert: bool = True):
self.use_ssl = use_ssl
self.verify_cert = verify_cert
self.ssl_context = None
if use_ssl:
self._create_ssl_context()
def _create_ssl_context(self):
"""创建SSL上下文"""
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if self.verify_cert:
# 验证服务器证书
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.check_hostname = True
# 加载系统CA证书
self.ssl_context.load_default_certs()
else:
# 不验证证书(仅用于测试)
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
def create_secure_socket(self, host: str, port: int) -> socket.socket:
"""创建安全socket"""
# 创建普通TCP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(10.0)
if self.use_ssl and self.ssl_context:
# 包装为SSL socket
ssl_sock = self.ssl_context.wrap_socket(sock, server_hostname=host)
return ssl_sock
return sock
def connect(self, host: str, port: int) -> Optional[socket.socket]:
"""安全连接到服务器"""
try:
sock = self.create_secure_socket(host, port)
sock.connect((host, port))
if self.use_ssl:
# 获取证书信息
cert = sock.getpeercert()
if cert:
print(f"服务器证书: {cert.get('subject', {})}")
return sock
except ssl.SSLError as e:
print(f"SSL连接错误: {e}")
return None
except Exception as e:
print(f"连接错误: {e}")
return None
class MessageAuthenticator:
"""消息认证"""
def __init__(self, secret_key: bytes):
self.secret_key = secret_key
def create_hmac(self, message: bytes) -> bytes:
"""创建HMAC"""
return hmac.new(self.secret_key, message, hashlib.sha256).digest()
def verify_hmac(self, message: bytes, received_hmac: bytes) -> bool:
"""验证HMAC"""
expected_hmac = self.create_hmac(message)
return hmac.compare_digest(expected_hmac, received_hmac)
def create_signed_message(self, message: bytes) -> bytes:
"""创建签名消息"""
hmac_digest = self.create_hmac(message)
return message + hmac_digest
def verify_signed_message(self, signed_message: bytes) -> Tuple[bool, Optional[bytes]]:
"""验证签名消息"""
if len(signed_message) < 32: # SHA256 HMAC是32字节
return False, None
message = signed_message[:-32]
received_hmac = signed_message[-32:]
if self.verify_hmac(message, received_hmac):
return True, message
else:
return False, None
class DataEncryptor:
"""数据加密"""
def __init__(self, password: str, salt: Optional[bytes] = None):
self.password = password.encode('utf-8')
if salt is None:
self.salt = os.urandom(16)
else:
self.salt = salt
# 派生密钥
self._derive_key()
def _derive_key(self):
"""派生加密密钥"""
kdf = PBKDF2(
algorithm=hashes.SHA256(),
length=32,
salt=self.salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(self.password))
self.cipher = Fernet(key)
def encrypt(self, data: bytes) -> bytes:
"""加密数据"""
return self.cipher.encrypt(data)
def decrypt(self, encrypted_data: bytes) -> bytes:
"""解密数据"""
return self.cipher.decrypt(encrypted_data)
def encrypt_message(self, message: str) -> Tuple[bytes, bytes]:
"""加密消息,返回加密数据和salt"""
encrypted = self.encrypt(message.encode('utf-8'))
return encrypted, self.salt
@staticmethod
def decrypt_message(encrypted_data: bytes, salt: bytes, password: str) -> Optional[str]:
"""解密消息"""
try:
encryptor = DataEncryptor(password, salt)
decrypted = encryptor.decrypt(encrypted_data)
return decrypted.decode('utf-8')
except:
return None
class SecureChatClient:
"""安全聊天客户端"""
def __init__(self, host: str, port: int, shared_secret: str):
self.host = host
self.port = port
self.shared_secret = shared_secret
# 初始化安全组件
self.secure_conn = SecureTCPConnection(use_ssl=True)
self.authenticator = MessageAuthenticator(shared_secret.encode('utf-8'))
self.encryptor = DataEncryptor(shared_secret)
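    @staticmethod
    def _recv_exact(sock, n: int) -> bytes:
        """Read exactly n bytes; a single recv() may return fewer."""
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("connection closed by peer")
            buf.extend(chunk)
        return bytes(buf)

    def connect_and_chat(self):
        """Connect and run the interactive secure chat loop."""
        sock = self.secure_conn.connect(self.host, self.port)
        if not sock:
            return
        print("Secure connection established, starting chat...")
        try:
            # Handshake: exchange the 16-byte key-derivation salts
            sock.sendall(self.encryptor.salt)
            server_salt = self._recv_exact(sock, 16)
            while True:
                message = input("You: ").strip()
                if message.lower() == 'quit':
                    break
                # Encrypt the message, then append the HMAC
                encrypted, _ = self.encryptor.encrypt_message(message)
                signed = self.authenticator.create_signed_message(encrypted)
                # Send one length-prefixed frame; sendall() retries partial writes
                sock.sendall(len(signed).to_bytes(4, 'big') + signed)
                # Receive the length-prefixed response in full
                response_len = int.from_bytes(self._recv_exact(sock, 4), 'big')
                signed_response = self._recv_exact(sock, response_len)
                # Verify the HMAC, then decrypt
                valid, encrypted_response = self.authenticator.verify_signed_message(signed_response)
                if valid:
                    response = DataEncryptor.decrypt_message(
                        encrypted_response, server_salt, self.shared_secret
                    )
                    if response:
                        print(f"Server: {response}")
                    else:
                        print("Failed to decrypt the response")
                else:
                    print("Response authentication failed")
        except Exception as e:
            print(f"Chat error: {e}")
        finally:
            sock.close()
            print("Connection closed")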
def main():
    """Secure client example."""
    import argparse
    parser = argparse.ArgumentParser(description='Secure chat client')
    parser.add_argument('--host', default='localhost', help='Server address')
    parser.add_argument('--port', type=int, default=8888, help='Server port')
    parser.add_argument('--secret', default='my-secret-key', help='Shared secret')
    args = parser.parse_args()
    client = SecureChatClient(args.host, args.port, args.secret)
    client.connect_and_chat()


if __name__ == "__main__":
    main()
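
The client above encrypts each message and then signs the ciphertext before sending it. `MessageAuthenticator` is not shown in this part of the article, so the block below is only a sketch of one common way such a sign/verify step can work, assuming an HMAC-SHA256 tag prepended to the payload. The function names `sign_message` and `verify_message` and the `tag || payload` layout are assumptions made for illustration, not the original class's API.

```python
import hashlib
import hmac
from typing import Optional

# Length of an HMAC-SHA256 tag in bytes (32)
TAG_LEN = hashlib.sha256().digest_size


def sign_message(key: bytes, payload: bytes) -> bytes:
    """Return tag || payload, where tag = HMAC-SHA256(key, payload)."""
    tag = hmac.new(key, payload, hashlib.sha256).digest()
    return tag + payload


def verify_message(key: bytes, signed: bytes) -> Optional[bytes]:
    """Return the payload if the tag matches, otherwise None."""
    if len(signed) < TAG_LEN:
        return None
    tag, payload = signed[:TAG_LEN], signed[TAG_LEN:]
    expected = hmac.new(key, payload, hashlib.sha256).digest()
    # compare_digest takes constant time, so it does not leak
    # how many leading bytes of the tag were correct
    if hmac.compare_digest(tag, expected):
        return payload
    return None


if __name__ == "__main__":
    key = b"demo-shared-secret"  # hypothetical key, for illustration only
    signed = sign_message(key, b"ciphertext bytes")
    print(verify_message(key, signed) is not None)
    print(verify_message(b"wrong key", signed) is None)
```

Signing the ciphertext rather than the plaintext lets the receiver reject a tampered message before attempting to decrypt it.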
4.3 Network Diagnostics and Debugging Tools
python
# network_diagnostics.py
import socket
import struct
import time
import subprocess
import platform
from typing import Dict, List, Tuple, Optional
import ipaddress
class NetworkScanner:
    """Network scanner."""

    def __init__(self):
        self.open_ports = {}

    def scan_ports(self, target_ip: str, start_port: int = 1,
                   end_port: int = 1024, timeout: float = 1.0) -> Dict[int, str]:
        """Scan a TCP port range with connect() and try to grab a banner."""
        print(f"Scanning {target_ip}, ports {start_port}-{end_port}")
        open_ports = {}
        for port in range(start_port, end_port + 1):
            try:
                # The context manager closes the socket even if an error occurs
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.settimeout(timeout)
                    if sock.connect_ex((target_ip, port)) == 0:
                        # Port is open; try to get service information
                        try:
                            sock.sendall(b'HEAD / HTTP/1.0\r\n\r\n')
                            banner = sock.recv(1024)
                            # errors='ignore' never raises, so no fallback is needed
                            banner_str = banner.decode('utf-8', errors='ignore')[:100]
                            service_info = self._guess_service(port, banner_str)
                        except OSError:
                            # No banner (timeout or reset); fall back to the port number
                            service_info = self._guess_service(port)
                        open_ports[port] = service_info
                        print(f"Port {port} open: {service_info}")
            except Exception as e:
                print(f"Error scanning port {port}: {e}")
            # Show progress
            if port % 100 == 0:
                print(f"Scanned {port - start_port + 1}/{end_port - start_port + 1} ports")
        self.open_ports = open_ports
        print(f"Scan complete, {len(open_ports)} open ports found")
        return open_ports

    def _guess_service(self, port: int, banner: str = "") -> str:
        """Guess the service behind a port."""
        common_ports = {
            20: "FTP Data",
            21: "FTP Control",
            22: "SSH",
            23: "Telnet",
            25: "SMTP",
            53: "DNS",
            80: "HTTP",
            110: "POP3",
            143: "IMAP",
            443: "HTTPS",
            465: "SMTPS",
            993: "IMAPS",
            995: "POP3S",
            3306: "MySQL",
            3389: "RDP",
            5432: "PostgreSQL",
            6379: "Redis",
            8080: "HTTP Proxy",
            8443: "HTTPS Alt",
        }
        if port in common_ports:
            service = common_ports[port]
        elif "HTTP" in banner.upper():
            service = "HTTP Service"
        elif "SMTP" in banner.upper():
            service = "SMTP Service"
        else:
            service = "Unknown"
        if banner:
            service += f" ({banner[:50]})"
        return service

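    def ping_sweep(self, network: str, timeout: float = 1.0) -> List[str]:
        """Ping-sweep a network using the system ping command."""
        print(f"Starting ping sweep of network: {network}")
        active_hosts = []
        system = platform.system().lower()
        try:
            net = ipaddress.ip_network(network, strict=False)
            for ip in net.hosts():
                ip_str = str(ip)
                # The timeout flag and its unit differ between platforms
                if system == 'windows':
                    # Windows: -n count, -w timeout in milliseconds
                    command = ['ping', '-n', '1', '-w', str(int(timeout * 1000)), ip_str]
                elif system == 'darwin':
                    # macOS: -W wait time in milliseconds
                    command = ['ping', '-c', '1', '-W', str(int(timeout * 1000)), ip_str]
                else:
                    # Linux (iputils): -W timeout in seconds
                    command = ['ping', '-c', '1', '-W', str(max(1, round(timeout))), ip_str]
                try:
                    output = subprocess.run(command, capture_output=True, timeout=timeout + 1)
                    if output.returncode == 0:
                        active_hosts.append(ip_str)
                        print(f"Active host found: {ip_str}")
                except subprocess.TimeoutExpired:
                    continue
                except Exception as e:
                    print(f"Ping {ip_str} failed: {e}")
                # Throttle the scan rate (hosts are probed one at a time)
                time.sleep(0.1)
        except Exception as e:
            print(f"Network scan error: {e}")
        print(f"Scan complete, {len(active_hosts)} active hosts found")
        return active_hosts

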
class NetworkMonitor:
    """Network monitor."""

    def __init__(self, interface: str = ""):
        self.interface = interface
        self.stats = {
            'start_time': time.time(),
            'bytes_sent': 0,
            'bytes_received': 0,
            'packets_sent': 0,
            'packets_received': 0,
        }

    def get_network_info(self) -> Dict:
        """Collect the hostname, IP address, and per-interface addresses."""
        info = {}
        try:
            # Hostname
            info['hostname'] = socket.gethostname()
            # IP address
            info['ip_address'] = socket.gethostbyname(info['hostname'])
            # Interface details (requires the third-party netifaces package)
            import netifaces
            interfaces = netifaces.interfaces()
            interface_info = {}
            for iface in interfaces:
                if self.interface and iface != self.interface:
                    continue
                addrs = netifaces.ifaddresses(iface)
                iface_info = {}
                if netifaces.AF_INET in addrs:
                    iface_info['ipv4'] = addrs[netifaces.AF_INET]
                if netifaces.AF_INET6 in addrs:
                    iface_info['ipv6'] = addrs[netifaces.AF_INET6]
                if netifaces.AF_LINK in addrs:
                    iface_info['mac'] = addrs[netifaces.AF_LINK]
                if iface_info:
                    interface_info[iface] = iface_info
            info['interfaces'] = interface_info
        except ImportError:
            info['error'] = "the netifaces library is required"
        except Exception as e:
            info['error'] = str(e)
        return info

    def monitor_traffic(self, duration: int = 60):
        """Monitor network traffic (simplified)."""
        print(f"Monitoring network traffic for {duration} seconds")
        try:
            import psutil
            self.stats['start_time'] = time.time()
            previous = psutil.net_io_counters(pernic=True)
            for i in range(duration):
                time.sleep(1)
                current = psutil.net_io_counters(pernic=True)
                for iface, now in current.items():
                    if self.interface and iface != self.interface:
                        continue
                    if iface not in previous:
                        continue
                    before = previous[iface]
                    # Delta over the last second only; diffing against the initial
                    # snapshot would add the same traffic to the totals repeatedly
                    bytes_sent = now.bytes_sent - before.bytes_sent
                    bytes_recv = now.bytes_recv - before.bytes_recv
                    self.stats['bytes_sent'] += bytes_sent
                    self.stats['bytes_received'] += bytes_recv
                    self.stats['packets_sent'] += now.packets_sent - before.packets_sent
                    self.stats['packets_received'] += now.packets_recv - before.packets_recv
                    if bytes_sent > 0 or bytes_recv > 0:
                        print(f"[{iface}] sent: {bytes_sent/1024:.1f} KB/s, "
                              f"received: {bytes_recv/1024:.1f} KB/s")
                previous = current
                if i % 10 == 0:
                    self._print_summary()
            print("\nMonitoring finished")
            self._print_summary()
        except ImportError:
            print("the psutil library is required")
        except Exception as e:
            print(f"Monitoring error: {e}")

    def _print_summary(self):
        """Print a traffic summary."""
        duration = time.time() - self.stats['start_time']
        print("\n=== Traffic Summary ===")
        print(f"Duration: {duration:.1f} s")
        print(f"Total bytes sent: {self.stats['bytes_sent']}")
        print(f"Total bytes received: {self.stats['bytes_received']}")
        print(f"Total packets sent: {self.stats['packets_sent']}")
        print(f"Total packets received: {self.stats['packets_received']}")
        print(f"Average send rate: {self.stats['bytes_sent']/duration/1024:.1f} KB/s")
        print(f"Average receive rate: {self.stats['bytes_received']/duration/1024:.1f} KB/s")


class PacketAnalyzer:
    """Packet analyzer (simplified)."""

    def __init__(self):
        self.packet_count = 0
        self.protocol_stats = {}

    def analyze_packet(self, raw_data: bytes):
        """Analyze one packet; raw_data is expected to start at the IPv4 header."""
        self.packet_count += 1
        if len(raw_data) < 20:
            return
        # Parse the IP header
        try:
            # Fixed part of the IP header: first 20 bytes
            ip_header = raw_data[:20]
            # Version (high nibble) and header length in 32-bit words (low nibble)
            first_byte = ip_header[0]
            version = first_byte >> 4
            ihl = (first_byte & 0x0F) * 4
            # A valid IPv4 header is at least 20 bytes long
            if version == 4 and ihl >= 20 and len(raw_data) >= ihl:
                # Protocol field
                protocol = ip_header[9]
                protocol_names = {
                    1: "ICMP",
                    6: "TCP",
                    17: "UDP",
                }
                protocol_name = protocol_names.get(protocol, f"Unknown({protocol})")
                # Update the statistics
                self.protocol_stats[protocol_name] = self.protocol_stats.get(protocol_name, 0) + 1
                # Source and destination addresses
                src_ip = socket.inet_ntoa(ip_header[12:16])
                dst_ip = socket.inet_ntoa(ip_header[16:20])
                if self.packet_count % 100 == 0:
                    print(f"Packet {self.packet_count}: {src_ip} -> {dst_ip} [{protocol_name}]")
        except Exception as e:
            print(f"Error parsing packet: {e}")

    def print_stats(self):
        """Print the statistics."""
        print("\n=== Packet Analysis Statistics ===")
        print(f"Total packets: {self.packet_count}")
        print("Protocol distribution:")
        for protocol, count in self.protocol_stats.items():
            percentage = (count / self.packet_count * 100) if self.packet_count > 0 else 0
            print(f"  {protocol}: {count} ({percentage:.1f}%)")


def network_diagnostics_menu():
    """Network diagnostics command-line menu."""
    import argparse
    parser = argparse.ArgumentParser(description='Network diagnostics tool')
    parser.add_argument('--scan', metavar='IP', help='Scan a host for open ports')
    parser.add_argument('--ping-sweep', metavar='NETWORK', help='Find active hosts in a network')
    parser.add_argument('--monitor', action='store_true', help='Monitor network traffic')
    parser.add_argument('--info', action='store_true', help='Show network information')
    args = parser.parse_args()
    if args.scan:
        scanner = NetworkScanner()
        ports = scanner.scan_ports(args.scan, 1, 1000)
        if ports:
            print("\nOpen ports:")
            for port, service in ports.items():
                print(f"  {port}: {service}")
    elif args.ping_sweep:
        scanner = NetworkScanner()
        hosts = scanner.ping_sweep(args.ping_sweep)
        if hosts:
            print("\nActive hosts:")
            for host in hosts:
                print(f"  {host}")
    elif args.monitor:
        monitor = NetworkMonitor()
        monitor.monitor_traffic(duration=30)
    elif args.info:
        monitor = NetworkMonitor()
        info = monitor.get_network_info()
        print("Network information:")
        for key, value in info.items():
            if key == 'interfaces':
                print(f"{key}:")
                for iface, iface_info in value.items():
                    print(f"  {iface}:")
                    for addr_type, addresses in iface_info.items():
                        for addr in addresses:
                            print(f"    {addr_type}: {addr.get('addr', 'N/A')}")
            else:
                print(f"{key}: {value}")
    else:
        parser.print_help()


if __name__ == "__main__":
    network_diagnostics_menu()
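
`PacketAnalyzer` above reads IPv4 header fields by byte offset. As a complement, the sketch below unpacks the fixed 20-byte part of the header in one step with `struct`. The function name `parse_ipv4_header` and the shape of the returned dictionary are choices made for this illustration, and the input is assumed to start at the IP header (no Ethernet framing).

```python
import socket
import struct
from typing import Dict, Optional

# Fixed 20-byte IPv4 header in network byte order:
# version/IHL, TOS, total length, identification, flags/fragment offset,
# TTL, protocol, header checksum, source address, destination address
IPV4_HEADER = struct.Struct("!BBHHHBBH4s4s")


def parse_ipv4_header(raw: bytes) -> Optional[Dict[str, object]]:
    """Parse an IPv4 header; return None if the data is not a valid IPv4 header."""
    if len(raw) < IPV4_HEADER.size:
        return None
    (ver_ihl, _tos, total_len, _ident, _flags_frag,
     ttl, protocol, _checksum, src, dst) = IPV4_HEADER.unpack_from(raw)
    version = ver_ihl >> 4
    header_len = (ver_ihl & 0x0F) * 4  # IHL counts 32-bit words
    if version != 4 or header_len < 20 or len(raw) < header_len:
        return None
    return {
        "header_len": header_len,
        "total_len": total_len,
        "ttl": ttl,
        "protocol": protocol,
        "src": socket.inet_ntoa(src),
        "dst": socket.inet_ntoa(dst),
    }


if __name__ == "__main__":
    # Hand-built sample header: TCP (protocol 6), TTL 64, checksum left at 0
    sample = IPV4_HEADER.pack(
        0x45, 0, 40, 1, 0, 64, 6, 0,
        socket.inet_aton("192.168.1.10"), socket.inet_aton("10.0.0.1"),
    )
    print(parse_ipv4_header(sample))
```

Keeping the layout in a single `struct.Struct` puts every field offset in one place, which tends to be easier to check against the header diagram than scattered index expressions.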
5. Hands-On Project: Multi-Protocol Chat System
python
# multiprotocol_chat_system.py
"""
Multi-protocol chat system.
Supports TCP, UDP, and WebSocket.
"""
import socket
import threading
import json
import time
import struct
import base64
import hashlib
from enum import Enum
from typing import Dict, List, Tuple, Optional, Any
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('ChatSystem')


class Protocol(Enum):
    """Protocol type."""
    TCP = "TCP"
    UDP = "UDP"
    WEBSOCKET = "WebSocket"


class ChatMessage:
    """Chat message."""

    def __init__(self, sender: str, content: str, timestamp: Optional[float] = None):
        self.sender = sender
        self.content = content
        # Compare against None so that an explicit timestamp of 0 is kept
        self.timestamp = timestamp if timestamp is not None else time.time()
        # MD5 is used here only to derive a short identifier, not for security
        self.message_id = hashlib.md5(
            f"{sender}{content}{self.timestamp}".encode()
        ).hexdigest()[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return {
            'sender': self.sender,
            'content': self.content,
            'timestamp': self.timestamp,
            'message_id': self.message_id,
            'time_str': time.strftime("%H:%M:%S", time.localtime(self.timestamp))
        }

    def to_json(self) -> str:
        """Convert to JSON."""
        return json.dumps(self.to_dict())

    @staticmethod
    def from_json(json_str: str) -> 'ChatMessage':
        """Create a message from JSON."""
        data = json.loads(json_str)
        return ChatMessage(
            sender=data['sender'],
            content=data['content'],
            timestamp=data['timestamp']
        )


class ChatRoom:
"""聊天室"""
def __init__(self, room_id: str, name: str):
self.room_id = room_id
self.name = name
self.messages: List[ChatMessage] = []
self.members: Dict[str, Any] = {} # username -> connection info
self.created_at = time.time()
self.max_messages = 1000 # 最大消息数量
def add_message(self, message: ChatMessage):
"""添加消息"""
self.messages.append(message)
# 限制消息数量
if len(self.messages) > self.max_messages:
self.messages = self.messages[-self.max_messages:]
def add_member(self, username: str, connection_info: Dict):
"""添加成员"""
self.members[username] = {
**connection_info,
'joined_at': time.time(),
'last_active': time.time()
}
def remove_member(self, username: str):
"""移除成员"""
if username in self.members:
del self.members[username]
def broadcast_message(self, message: ChatMessage, exclude_sender: str = None):
"""广播消息给所有成员"""
message_data = message.to_json()
for username, member_info in self.members.items():
if username == exclude_sender:
continue
# 这里应该通过成员连接发送消息
# 简化实现,只记录日志
logger.info(f"广播消息到 {username}: {message.content[:50]}...")
def get_info(self) -> Dict[str, Any]:
"""获取聊天室信息"""
return {
'room_id': self.room_id,
'name': self.name,
'member_count': len(self.members),
'message_count': len(self.messages),
'created_at': self.created_at,
'members': list(self.members.keys())
}
class MultiProtocolChatServer:
"""多协议聊天服务器"""
def __init__(self, host: str = '0.0.0.0', tcp_port: int = 8888,
udp_port: int = 9999, ws_port: int = 8765):
self.host = host
self.tcp_port = tcp_port
self.udp_port = udp_port
self.ws_port = ws_port
# 聊天室管理
self.chat_rooms: Dict[str, ChatRoom] = {
'general': ChatRoom('general', 'General Chat'),
'random': ChatRoom('random', 'Random Talk'),
'help': ChatRoom('help', 'Help & Support'),
}
# 用户管理
self.users: Dict[str, Dict] = {} # username -> user info
# 协议服务器
self.tcp_server: Optional[threading.Thread] = None
self.udp_server: Optional[threading.Thread] = None
self.ws_server: Optional[threading.Thread] = None
self.running = False
def start(self):
"""启动所有协议服务器"""
self.running = True
logger.info("启动多协议聊天服务器...")
logger.info(f"TCP: {self.host}:{self.tcp_port}")
logger.info(f"UDP: {self.host}:{self.udp_port}")
logger.info(f"WebSocket: {self.host}:{self.ws_port}")
# 启动TCP服务器
self.tcp_server = threading.Thread(
target=self._start_tcp_server,
daemon=True
)
self.tcp_server.start()
# 启动UDP服务器
self.udp_server = threading.Thread(
target=self._start_udp_server,
daemon=True
)
self.udp_server.start()
# 启动WebSocket服务器(简化版)
self.ws_server = threading.Thread(
target=self._start_websocket_server,
daemon=True
)
self.ws_server.start()
# 主线程等待
try:
while self.running:
time.sleep(1)
except KeyboardInterrupt:
self.stop()
    def _start_tcp_server(self):
        """Start the TCP server."""
        try:
            server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind((self.host, self.tcp_port))
            server_socket.listen(5)
            # A timeout lets accept() return periodically so stop() can take effect
            server_socket.settimeout(1.0)
            logger.info("TCP server started")
            with server_socket:
                while self.running:
                    try:
                        client_socket, client_address = server_socket.accept()
                        # Client sockets are handled with blocking reads
                        client_socket.settimeout(None)
                        # One thread per client
                        client_thread = threading.Thread(
                            target=self._handle_tcp_client,
                            args=(client_socket, client_address),
                            daemon=True
                        )
                        client_thread.start()
                    except socket.timeout:
                        continue
                    except Exception as e:
                        if self.running:
                            logger.error(f"TCP accept error: {e}")
        except Exception as e:
            logger.error(f"Failed to start TCP server: {e}")

    def _handle_tcp_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Handle one TCP client."""
        client_id = f"TCP_{client_address[0]}:{client_address[1]}"
        logger.info(f"TCP client connected: {client_id}")
        try:
            # Send the welcome message
            welcome = {
                'type': 'welcome',
                'message': 'Welcome to Multi-Protocol Chat Server',
                'protocol': 'TCP',
                'timestamp': time.time(),
                'rooms': [room.get_info() for room in self.chat_rooms.values()]
            }
            # Terminate with '\n' like every other message on this connection
            client_socket.sendall((json.dumps(welcome) + '\n').encode('utf-8'))
            # Process client messages
            buffer = b""
            while self.running:
                try:
                    data = client_socket.recv(4096)
                    if not data:
                        break
                    buffer += data
                    # Messages are newline-delimited JSON; keep any partial line buffered
                    while b'\n' in buffer:
                        message_data, buffer = buffer.split(b'\n', 1)
                        try:
                            message = json.loads(message_data.decode('utf-8'))
                            self._process_client_message(
                                Protocol.TCP, client_id, client_socket, message
                            )
                        except (json.JSONDecodeError, UnicodeDecodeError):
                            logger.error(f"Invalid JSON message: {message_data}")
                except ConnectionResetError:
                    break
                except Exception as e:
                    logger.error(f"Error handling TCP message: {e}")
                    break
        except Exception as e:
            logger.error(f"TCP client handler error: {e}")
        finally:
            client_socket.close()
            # Release usernames registered by this connection so they can be reused
            for username, user_info in list(self.users.items()):
                if user_info.get('client_id') == client_id:
                    for room in self.chat_rooms.values():
                        room.remove_member(username)
                    self.users.pop(username, None)
            logger.info(f"TCP client disconnected: {client_id}")

    def _start_udp_server(self):
        """Start the UDP server."""
        try:
            udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            udp_socket.bind((self.host, self.udp_port))
            # A timeout lets recvfrom() return periodically so stop() can take effect
            udp_socket.settimeout(1.0)
            logger.info("UDP server started")
            with udp_socket:
                while self.running:
                    try:
                        data, client_address = udp_socket.recvfrom(65536)
                        # Each datagram carries one JSON message
                        try:
                            message = json.loads(data.decode('utf-8'))
                            self._process_udp_message(udp_socket, client_address, message)
                        except (json.JSONDecodeError, UnicodeDecodeError):
                            logger.error("Invalid UDP JSON message")
                            # Send an error response
                            error_response = {
                                'type': 'error',
                                'message': 'Invalid JSON message'
                            }
                            udp_socket.sendto(
                                json.dumps(error_response).encode('utf-8'),
                                client_address
                            )
                    except socket.timeout:
                        continue
                    except Exception as e:
                        if self.running:
                            logger.error(f"UDP handling error: {e}")
        except Exception as e:
            logger.error(f"Failed to start UDP server: {e}")

def _process_udp_message(self, udp_socket: socket.socket,
client_address: Tuple[str, int], message: Dict):
"""处理UDP消息"""
client_id = f"UDP_{client_address[0]}:{client_address[1]}"
message_type = message.get('type', '')
if message_type == 'connect':
# UDP连接请求
response = {
'type': 'connected',
'message': 'UDP connection established',
'client_id': client_id,
'timestamp': time.time()
}
udp_socket.sendto(json.dumps(response).encode('utf-8'), client_address)
elif message_type == 'chat':
# 聊天消息
room_id = message.get('room_id', 'general')
content = message.get('content', '')
sender = message.get('sender', 'Anonymous')
if room_id in self.chat_rooms:
chat_message = ChatMessage(sender, content)
self.chat_rooms[room_id].add_message(chat_message)
# 广播消息
self.chat_rooms[room_id].broadcast_message(chat_message, exclude_sender=sender)
# 发送确认
response = {
'type': 'message_ack',
'message_id': chat_message.message_id,
'timestamp': time.time()
}
udp_socket.sendto(json.dumps(response).encode('utf-8'), client_address)
elif message_type == 'get_rooms':
# 获取聊天室列表
response = {
'type': 'room_list',
'rooms': [room.get_info() for room in self.chat_rooms.values()],
'timestamp': time.time()
}
udp_socket.sendto(json.dumps(response).encode('utf-8'), client_address)
def _start_websocket_server(self):
"""启动WebSocket服务器(简化版)"""
# 这里简化实现,实际需要完整的WebSocket协议处理
logger.info("WebSocket服务器已启动(简化版)")
# 在实际项目中,这里应该使用websockets库或其他WebSocket实现
# 为了简化,我们只记录日志
while self.running:
time.sleep(5)
logger.debug("WebSocket服务器运行中...")
    def _process_client_message(self, protocol: Protocol, client_id: str,
                                connection: Any, message: Dict):
        """Dispatch a client message to its handler (shared by all protocols)."""
        if not isinstance(message, dict):
            # json.loads may also return a list, string, or number
            self._send_response(protocol, connection, client_id,
                                {'type': 'error', 'message': 'Message must be a JSON object'})
            return
        message_type = message.get('type', '')
        preview = str(message.get('content', ''))[:50]
        logger.info(f"{protocol.value} message [{message_type}]: {preview}...")
        handlers = {
            'login': self._handle_login,
            'join_room': self._handle_join_room,
            'chat': self._handle_chat_message,
            'leave_room': self._handle_leave_room,
            'create_room': self._handle_create_room,
            'list_users': self._handle_list_users,
            'list_rooms': self._handle_list_rooms,
        }
        handler = handlers.get(message_type) if isinstance(message_type, str) else None
        if handler is None:
            self._send_response(protocol, connection, client_id,
                                {'type': 'error', 'message': 'Unknown message type'})
            return
        handler(protocol, client_id, connection, message)

def _handle_login(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理登录"""
username = message.get('username', '').strip()
password = message.get('password', '') # 简化处理,实际需要加密
if not username:
response = {'type': 'error', 'message': 'Username is required'}
elif username in self.users:
response = {'type': 'error', 'message': 'Username already exists'}
else:
# 注册用户
self.users[username] = {
'protocol': protocol.value,
'client_id': client_id,
'login_time': time.time(),
'last_active': time.time(),
'current_room': 'general'
}
# 加入默认聊天室
if 'general' in self.chat_rooms:
self.chat_rooms['general'].add_member(username, {
'protocol': protocol.value,
'client_id': client_id
})
response = {
'type': 'login_success',
'username': username,
'message': 'Login successful',
'timestamp': time.time()
}
self._send_response(protocol, connection, client_id, response)
def _handle_join_room(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理加入聊天室"""
username = message.get('username', '')
room_id = message.get('room_id', 'general')
if username not in self.users:
response = {'type': 'error', 'message': 'User not logged in'}
elif room_id not in self.chat_rooms:
response = {'type': 'error', 'message': 'Room does not exist'}
else:
# 离开当前聊天室
current_room = self.users[username].get('current_room')
if current_room in self.chat_rooms:
self.chat_rooms[current_room].remove_member(username)
# 加入新聊天室
self.chat_rooms[room_id].add_member(username, {
'protocol': protocol.value,
'client_id': client_id
})
self.users[username]['current_room'] = room_id
self.users[username]['last_active'] = time.time()
# 发送房间历史消息
room_messages = [
msg.to_dict() for msg in
self.chat_rooms[room_id].messages[-20:] # 最近20条消息
]
response = {
'type': 'room_joined',
'room_id': room_id,
'room_name': self.chat_rooms[room_id].name,
'members': list(self.chat_rooms[room_id].members.keys()),
'history': room_messages,
'timestamp': time.time()
}
# 通知其他成员
join_message = ChatMessage(
'System',
f"{username} 加入了聊天室"
)
self.chat_rooms[room_id].add_message(join_message)
self.chat_rooms[room_id].broadcast_message(join_message, exclude_sender=username)
self._send_response(protocol, connection, client_id, response)
def _handle_chat_message(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理聊天消息"""
username = message.get('username', '')
room_id = message.get('room_id', 'general')
content = message.get('content', '').strip()
if not content:
return
if username not in self.users:
response = {'type': 'error', 'message': 'User not logged in'}
self._send_response(protocol, connection, client_id, response)
return
if room_id not in self.chat_rooms:
response = {'type': 'error', 'message': 'Room does not exist'}
self._send_response(protocol, connection, client_id, response)
return
# 创建聊天消息
chat_message = ChatMessage(username, content)
self.chat_rooms[room_id].add_message(chat_message)
# 更新用户活动时间
self.users[username]['last_active'] = time.time()
# 发送确认
response = {
'type': 'message_sent',
'message_id': chat_message.message_id,
'timestamp': time.time()
}
self._send_response(protocol, connection, client_id, response)
# 广播消息给房间其他成员
self.chat_rooms[room_id].broadcast_message(chat_message, exclude_sender=username)
def _handle_leave_room(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理离开聊天室"""
username = message.get('username', '')
room_id = message.get('room_id', 'general')
if username in self.users and room_id in self.chat_rooms:
self.chat_rooms[room_id].remove_member(username)
# 通知其他成员
leave_message = ChatMessage(
'System',
f"{username} 离开了聊天室"
)
self.chat_rooms[room_id].add_message(leave_message)
self.chat_rooms[room_id].broadcast_message(leave_message, exclude_sender=username)
response = {
'type': 'room_left',
'room_id': room_id,
'timestamp': time.time()
}
else:
response = {'type': 'error', 'message': 'Operation failed'}
self._send_response(protocol, connection, client_id, response)
def _handle_create_room(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理创建聊天室"""
username = message.get('username', '')
room_name = message.get('room_name', '').strip()
if not room_name:
response = {'type': 'error', 'message': 'Room name is required'}
elif username not in self.users:
response = {'type': 'error', 'message': 'User not logged in'}
else:
# 生成房间ID
room_id = hashlib.md5(f"{room_name}{time.time()}".encode()).hexdigest()[:8]
# 创建聊天室
new_room = ChatRoom(room_id, room_name)
self.chat_rooms[room_id] = new_room
# 创建者自动加入
new_room.add_member(username, {
'protocol': protocol.value,
'client_id': client_id
})
self.users[username]['current_room'] = room_id
response = {
'type': 'room_created',
'room_id': room_id,
'room_name': room_name,
'timestamp': time.time()
}
self._send_response(protocol, connection, client_id, response)
def _handle_list_users(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理列出用户"""
room_id = message.get('room_id', 'general')
if room_id in self.chat_rooms:
members = list(self.chat_rooms[room_id].members.keys())
response = {
'type': 'user_list',
'room_id': room_id,
'users': members,
'count': len(members),
'timestamp': time.time()
}
else:
response = {'type': 'error', 'message': 'Room does not exist'}
self._send_response(protocol, connection, client_id, response)
def _handle_list_rooms(self, protocol: Protocol, client_id: str,
connection: Any, message: Dict):
"""处理列出聊天室"""
rooms_info = []
for room_id, room in self.chat_rooms.items():
room_info = room.get_info()
rooms_info.append(room_info)
response = {
'type': 'room_list',
'rooms': rooms_info,
'count': len(rooms_info),
'timestamp': time.time()
}
self._send_response(protocol, connection, client_id, response)
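    def _send_response(self, protocol: Protocol, connection: Any,
                       client_id: str, response: Dict):
        """Send a response as one newline-terminated JSON message."""
        response_json = json.dumps(response) + '\n'
        try:
            if protocol == Protocol.TCP:
                # sendall() keeps sending until the whole message is written;
                # send() may write only part of it
                connection.sendall(response_json.encode('utf-8'))
            elif protocol == Protocol.UDP:
                # UDP would need the address parsed from client_id;
                # not handled in this simplified implementation
                pass
            elif protocol == Protocol.WEBSOCKET:
                # WebSocket is not handled in this simplified implementation
                pass
        except Exception as e:
            logger.error(f"Error sending response: {e}")
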
def stop(self):
"""停止服务器"""
self.running = False
logger.info("正在停止多协议聊天服务器...")
# 等待线程结束
if self.tcp_server:
self.tcp_server.join(timeout=2)
if self.udp_server:
self.udp_server.join(timeout=2)
if self.ws_server:
self.ws_server.join(timeout=2)
logger.info("多协议聊天服务器已停止")
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='多协议聊天服务器')
parser.add_argument('--host', default='0.0.0.0', help='监听地址')
parser.add_argument('--tcp-port', type=int, default=8888, help='TCP端口')
parser.add_argument('--udp-port', type=int, default=9999, help='UDP端口')
parser.add_argument('--ws-port', type=int, default=8765, help='WebSocket端口')
args = parser.parse_args()
server = MultiProtocolChatServer(
host=args.host,
tcp_port=args.tcp_port,
udp_port=args.udp_port,
ws_port=args.ws_port
)
try:
server.start()
except KeyboardInterrupt:
server.stop()
except Exception as e:
logger.error(f"服务器运行异常: {e}")
server.stop()
if __name__ == "__main__":
main()
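
The TCP handler above treats the byte stream as newline-delimited JSON and keeps an incomplete trailing line in a buffer until more data arrives. The sketch below isolates that framing logic so it can be exercised without sockets. The names `JsonLineFramer`, `encode_message`, and the `max_line` limit are hypothetical helpers introduced for this illustration and are not part of the server code.

```python
import json
from typing import Any, List


def encode_message(message: Any) -> bytes:
    """Serialize one message as a single newline-terminated JSON line."""
    # json.dumps escapes newlines inside strings, so the only raw '\n'
    # in the output is the terminator appended here
    return (json.dumps(message) + "\n").encode("utf-8")


class JsonLineFramer:
    """Reassemble newline-delimited JSON messages from arbitrary byte chunks."""

    def __init__(self, max_line: int = 65536):
        self._buffer = bytearray()
        self._max_line = max_line  # guard against a peer that never sends '\n'

    def feed(self, data: bytes) -> List[Any]:
        """Add received bytes and return every complete message decoded so far.

        Raises ValueError on malformed JSON, invalid UTF-8, or an over-long line.
        """
        self._buffer.extend(data)
        messages = []
        while True:
            index = self._buffer.find(b"\n")
            if index < 0:
                break
            line = bytes(self._buffer[:index])
            del self._buffer[:index + 1]
            if line.strip():  # skip empty lines
                messages.append(json.loads(line.decode("utf-8")))
        if len(self._buffer) > self._max_line:
            raise ValueError("line exceeds max_line bytes without a newline")
        return messages


if __name__ == "__main__":
    framer = JsonLineFramer()
    stream = encode_message({"type": "login", "username": "alice"})
    # Deliver the bytes in two pieces, as TCP is free to do
    print(framer.feed(stream[:10]))
    print(framer.feed(stream[10:]))
```

Separating framing from socket I/O makes it straightforward to test split and coalesced deliveries, which are the cases a plain `recv()` loop most often gets wrong.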
6. Performance Optimization and Best Practices
6.1 Network Programming Performance Optimization
python
# performance_optimization.py
"""
Network programming performance optimization techniques
"""
import socket
import struct
import threading
import time
import queue
import select
from typing import List, Dict, Tuple, Optional
import mmap
import os


class HighPerformanceTCPServer:
    """High-performance TCP server (Linux epoll, edge-triggered)."""

    def __init__(self, host='0.0.0.0', port=8888):
        self.host = host
        self.port = port
        self.running = False
        # The event loop below is written against epoll, which is Linux-only
        if not hasattr(select, 'epoll'):
            raise RuntimeError("HighPerformanceTCPServer requires select.epoll (Linux)")
        self.poller = select.epoll()
        self.poll_event = select.EPOLLIN | select.EPOLLET  # edge-triggered
        # Connection table
        self.connection_pool = {}
        # Send queues (avoid blocking on send)
        self.send_queues = {}
        # Pre-allocated buffer pool
        self.buffer_pool = BufferPool(page_size=4096, pool_size=1000)

    def start(self):
        """Start the server and run the event loop."""
        # Create the listening socket
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Non-blocking mode
        server_socket.setblocking(False)
        # TCP tuning
        server_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)  # disable Nagle's algorithm
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        # Buffer sizes
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
        server_socket.bind((self.host, self.port))
        server_socket.listen(1024)  # large accept backlog
        # Register the listening socket (level-triggered)
        self.poller.register(server_socket.fileno(), select.EPOLLIN)
        self.running = True
        print(f"High-performance server listening on {self.host}:{self.port}")
        # Event loop
        while self.running:
            try:
                # Wait up to one second for events
                events = self.poller.poll(timeout=1)
                for fd, event in events:
                    if fd == server_socket.fileno():
                        self._accept_connections(server_socket)
                    else:
                        self._handle_client_event(fd, event)
                # Flush pending outgoing data
                self._process_send_queues()
            except KeyboardInterrupt:
                break
            except Exception as e:
                print(f"Event loop error: {e}")
        # Release the listening socket and the epoll object on exit
        self.poller.unregister(server_socket.fileno())
        server_socket.close()
        self.poller.close()

def _accept_connections(self, server_socket: socket.socket):
"""接受新连接(边缘触发,需要循环接受)"""
while True:
try:
client_socket, client_address = server_socket.accept()
# 设置非阻塞
client_socket.setblocking(False)
# 设置TCP参数
client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# 注册到poller
fd = client_socket.fileno()
if hasattr(select, 'epoll'):
self.poller.register(fd, select.EPOLLIN | select.EPOLLET)
# 保存连接
self.connection_pool[fd] = {
'socket': client_socket,
'address': client_address,
'buffer': self.buffer_pool.get_buffer(),
'last_active': time.time()
}
# 初始化发送队列
self.send_queues[fd] = queue.Queue(maxsize=100)
print(f"新连接: {client_address}")
except BlockingIOError:
# 没有更多连接
break
except Exception as e:
print(f"接受连接错误: {e}")
break
def _handle_client_event(self, fd: int, event: int):
"""处理客户端事件"""
if fd not in self.connection_pool:
return
conn_info = self.connection_pool[fd]
if event & select.EPOLLIN:
# 可读事件
self._handle_client_read(fd, conn_info)
if event & select.EPOLLOUT:
# 可写事件
self._handle_client_write(fd, conn_info)
if event & (select.EPOLLHUP | select.EPOLLERR):
# 连接关闭或错误
self._close_connection(fd, conn_info)
def _handle_client_read(self, fd: int, conn_info: Dict):
"""处理客户端读取"""
try:
client_socket = conn_info['socket']
buffer = conn_info['buffer']
# 边缘触发,需要读取所有数据
while True:
try:
data = client_socket.recv(4096)
if not data:
# 连接关闭
self._close_connection(fd, conn_info)
return
# 处理数据
self._process_client_data(fd, conn_info, data)
# 更新活动时间
conn_info['last_active'] = time.time()
except BlockingIOError:
# 没有更多数据
break
except ConnectionResetError:
self._close_connection(fd, conn_info)
return
except Exception as e:
print(f"读取客户端数据错误: {e}")
self._close_connection(fd, conn_info)
def _process_client_data(self, fd: int, conn_info: Dict, data: bytes):
"""处理客户端数据"""
# 这里可以实现业务逻辑
# 例如:解析协议、处理请求等
# 示例:简单回显
response = f"收到 {len(data)} 字节数据\n".encode('utf-8')
# 添加到发送队列
if fd in self.send_queues:
try:
self.send_queues[fd].put_nowait(response)
except queue.Full:
print(f"发送队列已满,丢弃数据: {fd}")
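    def _process_send_queues(self):
        """Flush each connection's send queue as far as the socket allows."""
        for fd in list(self.send_queues.keys()):
            if fd not in self.connection_pool:
                continue
            queue_obj = self.send_queues[fd]
            conn_info = self.connection_pool[fd]
            while True:
                # Unsent bytes from an earlier attempt go first to preserve byte order
                data = conn_info.pop('pending', b'')
                if not data:
                    try:
                        data = queue_obj.get_nowait()
                    except queue.Empty:
                        break
                try:
                    sent = conn_info['socket'].send(data)
                except BlockingIOError:
                    # Socket buffer is full; keep the data and retry on the next pass
                    conn_info['pending'] = data
                    break
                except Exception as e:
                    print(f"Error sending data: {e}")
                    self._close_connection(fd, conn_info)
                    break
                if sent < len(data):
                    # Partial send; keep the remainder ahead of the queued items
                    conn_info['pending'] = data[sent:]
                    break
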
    def _close_connection(self, fd: int, conn_info: Dict):
        """Close a connection; safe to call more than once for the same fd."""
        if self.connection_pool.pop(fd, None) is None:
            return  # already closed
        self.send_queues.pop(fd, None)
        try:
            # Unregister while the file descriptor is still valid
            if hasattr(select, 'epoll'):
                self.poller.unregister(fd)
            # Close the socket
            conn_info['socket'].close()
            # Return the buffer to the pool
            self.buffer_pool.return_buffer(conn_info['buffer'])
            print(f"Connection closed: {conn_info['address']}")
        except Exception as e:
            print(f"Error closing connection: {e}")


class BufferPool:
"""缓冲区池(减少内存分配)"""
def __init__(self, page_size: int = 4096, pool_size: int = 100):
self.page_size = page_size
self.pool_size = pool_size
self.free_buffers = []
self.allocated_buffers = set()
# 预分配缓冲区
for _ in range(pool_size):
self.free_buffers.append(bytearray(page_size))
def get_buffer(self) -> bytearray:
"""获取缓冲区"""
if self.free_buffers:
buffer = self.free_buffers.pop()
else:
# 池为空,分配新缓冲区
buffer = bytearray(self.page_size)
self.allocated_buffers.add(id(buffer))
return buffer
def return_buffer(self, buffer: bytearray):
"""返回缓冲区到池"""
buffer_id = id(buffer)
if buffer_id in self.allocated_buffers:
# 清空缓冲区
buffer[:] = b'\x00' * len(buffer)
# 放回空闲列表(如果未满)
if len(self.free_buffers) < self.pool_size:
self.free_buffers.append(buffer)
self.allocated_buffers.remove(buffer_id)
class ZeroCopyFileSender:
    """Zero-copy file sender."""

    def send_file_zero_copy(self, socket_obj: socket.socket, file_path: str) -> bool:
        """Send a file with socket.sendfile().

        socket.sendfile() uses the sendfile system call where the platform
        supports it and falls back to plain send() otherwise.
        """
        try:
            import os
            import struct
            file_size = os.path.getsize(file_path)
            with open(file_path, 'rb') as f:
                # Send the file size as an 8-byte big-endian header
                socket_obj.sendall(struct.pack('!Q', file_size))
                # Send the file body; requires a blocking SOCK_STREAM socket
                socket_obj.sendfile(f)
            return True
        except Exception as e:
            print(f"Zero-copy file send failed: {e}")
            return False


class ConnectionRecycler:
    """Connection recycler (reduces TCP connection setup overhead)."""

    def __init__(self, keep_alive_timeout: int = 60):
        self.keep_alive_timeout = keep_alive_timeout
        self.connections = {}  # (host, port) -> connection info
        self.lock = threading.Lock()

    def get_connection(self, host: str, port: int) -> Optional[socket.socket]:
        """Return a stored connection if one exists and is still usable."""
        key = (host, port)
        with self.lock:
            conn_info = self.connections.get(key)
            if conn_info is None:
                return None
            if self._is_connection_alive(conn_info['socket']):
                conn_info['last_used'] = time.time()
                return conn_info['socket']
            # The connection is dead: close it and drop it
            try:
                conn_info['socket'].close()
            except OSError:
                pass
            del self.connections[key]
        return None

    def store_connection(self, host: str, port: int, sock: socket.socket):
        """Store a connection for later reuse."""
        key = (host, port)
        with self.lock:
            self.connections[key] = {
                'socket': sock,
                'last_used': time.time(),
                'created': time.time()
            }

    def cleanup_expired_connections(self):
        """Close and remove connections idle for longer than the timeout."""
        current_time = time.time()
        expired_keys = []
        with self.lock:
            for key, conn_info in self.connections.items():
                if current_time - conn_info['last_used'] > self.keep_alive_timeout:
                    expired_keys.append(key)
            for key in expired_keys:
                try:
                    self.connections[key]['socket'].close()
                except OSError:
                    pass
                del self.connections[key]

    def _is_connection_alive(self, sock: socket.socket) -> bool:
        """Check whether the peer has closed the connection.

        Sending zero bytes does not touch the network, so the check looks at
        readability instead: a readable socket with no data means EOF.
        """
        try:
            readable, _, _ = select.select([sock], [], [], 0)
            if not readable:
                return True  # idle, with no EOF or error pending
            # Peek without consuming; b'' means the peer closed the connection
            return sock.recv(1, socket.MSG_PEEK) != b''
        except (OSError, ValueError):
            return False


# Usage example
def _run_server(server_class, port: int):
    """Server process entry point. It lives at module level so that
    multiprocessing can pickle it under the spawn start method."""
    server_class(port=port).start()


def benchmark_optimizations():
    """Benchmark the performance optimizations."""
    import multiprocessing
    import statistics

    def test_server_performance(server_class, server_name, num_connections=100):
        """Measure per-connection latency and overall throughput of a server."""
        print(f"\n=== Testing {server_name} ===")
        # Start the server in a separate process
        server_process = multiprocessing.Process(
            target=_run_server, args=(server_class, 9991)
        )
        server_process.start()
        time.sleep(2)
        # Open client connections one after another
        latencies = []
        successful = 0
        run_start = time.perf_counter()
        for i in range(num_connections):
            try:
                start_time = time.perf_counter()
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.settimeout(5.0)
                    sock.connect(('localhost', 9991))
                    # Send test data
                    test_data = b'test' * 256  # 1 KB of data
                    sock.sendall(test_data)
                    # Receive the response
                    sock.recv(1024)
                end_time = time.perf_counter()
                latencies.append((end_time - start_time) * 1000)  # milliseconds
                successful += 1
                if i % 10 == 0:
                    print(f"Completed {i+1}/{num_connections} connections")
            except Exception as e:
                print(f"Connection {i+1} failed: {e}")
        total_elapsed = time.perf_counter() - run_start
        # Stop the server
        server_process.terminate()
        server_process.join()
        # Compute statistics
        if latencies:
            stats = {
                'successful': successful,
                'failed': num_connections - successful,
                'min_latency': min(latencies),
                'max_latency': max(latencies),
                'avg_latency': statistics.mean(latencies),
                'median_latency': statistics.median(latencies),
                # Connections completed per second of wall-clock time
                'throughput': successful / total_elapsed,
            }
            print(f"Successful connections: {stats['successful']}/{num_connections}")
            print(f"Average latency: {stats['avg_latency']:.2f} ms")
            print(f"Median latency: {stats['median_latency']:.2f} ms")
            print(f"Throughput: {stats['throughput']:.2f} connections/s")
            return stats
        else:
            print("All connections failed")
            return {}

    # Test the basic server
    from tcp_basic_server import TCPBasicServer
    basic_stats = test_server_performance(TCPBasicServer, "basic TCP server", 50)
    # Test the high-performance server
    highperf_stats = test_server_performance(HighPerformanceTCPServer, "high-performance TCP server", 50)
    # Compare the two
    if basic_stats and highperf_stats:
        print("\n=== Improvement ===")
        latency_improvement = (basic_stats['avg_latency'] - highperf_stats['avg_latency']) / basic_stats['avg_latency'] * 100
        throughput_improvement = (highperf_stats['throughput'] - basic_stats['throughput']) / basic_stats['throughput'] * 100
        print(f"Latency improvement: {latency_improvement:.1f}%")
        print(f"Throughput improvement: {throughput_improvement:.1f}%")


if __name__ == "__main__":
    benchmark_optimizations()
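The `ZeroCopyFileSender` earlier in this listing writes an 8-byte big-endian size (`struct.pack('!Q', file_size)`) before the file body, so the receiver has to read exactly that header and then exactly that many bytes. The following is a minimal sketch of the receiving side, not part of the original code; the names `recv_exact` and `recv_sized_file` are hypothetical and chosen for illustration only.

```python
import socket
import struct


def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes, or raise if the peer closes early."""
    buf = bytearray(n)
    view = memoryview(buf)
    received = 0
    while received < n:
        # recv_into may return fewer bytes than requested, so keep looping
        count = sock.recv_into(view[received:], n - received)
        if count == 0:
            raise ConnectionError(f"peer closed after {received} of {n} bytes")
        received += count
    return bytes(buf)


def recv_sized_file(sock: socket.socket, out_path: str,
                    chunk_size: int = 65536) -> int:
    """Receive a size-prefixed file and write it to out_path.

    Assumes the sender wrote an 8-byte '!Q' length before the body.
    Returns the number of body bytes written.
    """
    (file_size,) = struct.unpack('!Q', recv_exact(sock, 8))
    remaining = file_size
    with open(out_path, 'wb') as out:
        while remaining > 0:
            # Never read past the end of this file: more data may follow
            chunk = sock.recv(min(chunk_size, remaining))
            if not chunk:
                raise ConnectionError(f"peer closed with {remaining} bytes missing")
            out.write(chunk)
            remaining -= len(chunk)
    return file_size
```

Capping each `recv` at `remaining` matters on a reused connection: anything the sender writes after the file stays in the socket buffer for the next read instead of being appended to the file.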
6.2 Error Handling and Fault Tolerance
python
# error_handling.py
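"""
Error handling and fault-tolerance mechanisms for network programming
"""
import socket
import time
import random
from typing import Optional, Callable
import logging
from functools import wraps


class NetworkError(Exception):
    """Base class for network errors."""
    pass


# Note: within this module, the next two classes shadow the built-in
# ConnectionError and TimeoutError.
class ConnectionError(NetworkError):
    """Connection error."""
    pass


class TimeoutError(NetworkError):
    """Timeout error."""
    pass


class ProtocolError(NetworkError):
    """Protocol error."""
    pass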
def retry_on_failure(max_retries: int = 3, delay: float = 1.0,
                     backoff_factor: float = 2.0,
                     exceptions: tuple = (ConnectionError, TimeoutError)):
    """Retry decorator with exponential backoff and jitter."""
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_retries - 1:
                        logging.error(
                            f"{func.__name__} failed {max_retries} times, giving up. "
                            f"Last error: {e}"
                        )
                        raise
                    wait_time = delay * (backoff_factor ** attempt)
                    wait_time += random.uniform(0, 0.1 * wait_time)  # add jitter
                    logging.warning(
                        f"{func.__name__} failed (attempt {attempt + 1}), "
                        f"retrying in {wait_time:.2f} s. Error: {e}"
                    )
                    time.sleep(wait_time)
        return wrapper
    return decorator
def circuit_breaker(failure_threshold: int = 5, reset_timeout: float = 60.0):
    """Circuit breaker decorator."""
    def decorator(func: Callable):
        # This state is shared by every caller of the decorated function
        func_state = {
            'failures': 0,
            'last_failure': None,
            'state': 'CLOSED',  # CLOSED, OPEN, HALF_OPEN
        }

        @wraps(func)
        def wrapper(*args, **kwargs):
            current_time = time.time()
            # Check the breaker state
            if func_state['state'] == 'OPEN':
                # Check whether it is time to probe for recovery
                if (func_state['last_failure'] and
                        current_time - func_state['last_failure'] > reset_timeout):
                    func_state['state'] = 'HALF_OPEN'
                    logging.info(f"Circuit breaker half-open: {func.__name__}")
                else:
                    raise ConnectionError(f"Circuit breaker open, request rejected: {func.__name__}")
            try:
                result = func(*args, **kwargs)
                # A success resets the breaker, so only consecutive failures count
                if func_state['state'] == 'HALF_OPEN':
                    func_state['state'] = 'CLOSED'
                    logging.info(f"Circuit breaker closed: {func.__name__}")
                func_state['failures'] = 0
                return result
            except (ConnectionError, TimeoutError):
                # Record the failure
                func_state['failures'] += 1
                func_state['last_failure'] = current_time
                if func_state['state'] == 'HALF_OPEN':
                    # A failed probe reopens the breaker
                    func_state['state'] = 'OPEN'
                    logging.error(f"Probe failed in half-open state, reopening breaker: {func.__name__}")
                elif (func_state['state'] == 'CLOSED' and
                        func_state['failures'] >= failure_threshold):
                    # Failure threshold reached: open the breaker
                    func_state['state'] = 'OPEN'
                    logging.error(f"Failure threshold reached, opening breaker: {func.__name__}")
                raise
        return wrapper
    return decorator
class ResilientTCPClient:
    """Fault-tolerant TCP client."""

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.socket: Optional[socket.socket] = None
        # Per-client circuit breaker state
        self.circuit_state = 'CLOSED'
        self.failure_count = 0
        self.last_failure_time = None
        self.failure_threshold = 5
        self.reset_timeout = 30.0

    @retry_on_failure(max_retries=3, delay=1.0, backoff_factor=2.0)
    @circuit_breaker(failure_threshold=5, reset_timeout=30.0)
    def connect_with_resilience(self, timeout: float = 5.0) -> bool:
        """Connect with retries and circuit breaking."""
        # Check the breaker outside the try block so that a rejection
        # does not increment this client's failure count
        if self._check_circuit_breaker():
            raise ConnectionError("Circuit breaker open, connection rejected")
        # Create the socket
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            # Connect
            sock.connect((self.host, self.port))
        except socket.timeout:
            sock.close()
            self._record_failure()
            raise TimeoutError(f"Connection timed out: {self.host}:{self.port}")
        except ConnectionRefusedError:
            sock.close()
            self._record_failure()
            raise ConnectionError(f"Connection refused: {self.host}:{self.port}")
        except Exception as e:
            sock.close()
            self._record_failure()
            raise ConnectionError(f"Connection error: {e}")
        # Connected: keep the socket and reset the failure count
        self.socket = sock
        self._reset_failure()
        logging.info(f"Connected to {self.host}:{self.port}")
        return True

    def _check_circuit_breaker(self) -> bool:
        """Return True if the breaker is open and the call must be rejected."""
        if self.circuit_state == 'OPEN':
            # Move to half-open once the reset timeout has elapsed
            if (self.last_failure_time and
                    time.time() - self.last_failure_time > self.reset_timeout):
                self.circuit_state = 'HALF_OPEN'
                return False
            return True
        return False

    def _record_failure(self):
        """Record a failure and open the breaker at the threshold."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.failure_threshold:
            self.circuit_state = 'OPEN'
            logging.error(f"Failure threshold reached, opening breaker: {self.host}:{self.port}")

    def _reset_failure(self):
        """Reset the failure count and close the breaker."""
        self.failure_count = 0
        self.circuit_state = 'CLOSED'
        self.last_failure_time = None

    @retry_on_failure(max_retries=2, delay=0.5)
    def send_with_resilience(self, data: bytes) -> int:
        """Send data with retries; returns the number of bytes sent."""
        if not self.socket:
            raise ConnectionError("Not connected")
        try:
            # sendall() keeps writing until every byte has been sent
            self.socket.sendall(data)
            return len(data)
        except (BrokenPipeError, ConnectionResetError):
            self._record_failure()
            raise ConnectionError("Connection lost")
        except socket.timeout:
            self._record_failure()
            raise TimeoutError("Send timed out")

    def close(self):
        """Close the connection."""
        if self.socket:
            try:
                self.socket.close()
            except OSError:
                pass
            self.socket = None
class HealthChecker:
    """Health checker."""

    def __init__(self, check_interval: float = 30.0):
        self.check_interval = check_interval
        self.endpoints = {}  # endpoint -> health status
        self.running = False

    def add_endpoint(self, endpoint_id: str, host: str, port: int):
        """Add an endpoint to monitor."""
        self.endpoints[endpoint_id] = {
            'host': host,
            'port': port,
            'healthy': True,
            'last_check': None,
            'failures': 0,
            'response_time': None,
        }

    def start_monitoring(self):
        """Start monitoring in a background daemon thread."""
        self.running = True
        import threading
        monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        monitor_thread.start()

    def _monitor_loop(self):
        """Monitoring loop."""
        while self.running:
            # Iterate over a snapshot so add_endpoint() can run concurrently
            for endpoint_id, endpoint_info in list(self.endpoints.items()):
                self._check_endpoint(endpoint_id, endpoint_info)
            time.sleep(self.check_interval)

    def _check_endpoint(self, endpoint_id: str, endpoint_info: dict):
        """Check a single endpoint."""
        host = endpoint_info['host']
        port = endpoint_info['port']
        try:
            start_time = time.perf_counter()
            # Connect; the with block closes the socket even on failure
            with socket.create_connection((host, port), timeout=5.0) as sock:
                # Send the health-check request
                sock.sendall(b'HEALTH_CHECK\n')
                # Receive the response
                response = sock.recv(1024)
            if not response:
                raise ConnectionError("connection closed without a response")
            end_time = time.perf_counter()
            response_time = (end_time - start_time) * 1000  # milliseconds
            # Update the status
            endpoint_info['healthy'] = True
            endpoint_info['failures'] = 0
            endpoint_info['response_time'] = response_time
            endpoint_info['last_check'] = time.time()
            logging.debug(f"Endpoint {endpoint_id} healthy, response time: {response_time:.2f}ms")
        except Exception as e:
            # Update the status
            endpoint_info['healthy'] = False
            endpoint_info['failures'] += 1
            endpoint_info['last_check'] = time.time()
            logging.warning(f"Endpoint {endpoint_id} unhealthy: {e}")

    def get_healthy_endpoints(self) -> list:
        """Return the healthy endpoints, fastest first."""
        healthy = []
        for endpoint_id, endpoint_info in list(self.endpoints.items()):
            if endpoint_info['healthy']:
                healthy.append({
                    'id': endpoint_id,
                    'host': endpoint_info['host'],
                    'port': endpoint_info['port'],
                    'response_time': endpoint_info['response_time']
                })
        # Sort by response time; endpoints not yet measured go last
        healthy.sort(key=lambda x: x['response_time']
                     if x['response_time'] is not None else float('inf'))
        return healthy

    def stop_monitoring(self):
        """Stop monitoring."""
        self.running = False
# Usage example
def demonstrate_resilience():
    """Demonstrate the fault-tolerance mechanisms."""
    logging.basicConfig(level=logging.INFO)
    # Create the health checker
    health_checker = HealthChecker(check_interval=10.0)
    health_checker.add_endpoint('server1', 'localhost', 8888)
    health_checker.add_endpoint('server2', 'localhost', 8889)
    health_checker.start_monitoring()
    # Wait for the first health check to finish
    time.sleep(15)
    # Get the healthy endpoints
    healthy_servers = health_checker.get_healthy_endpoints()
    if healthy_servers:
        # Pick the fastest server
        best_server = healthy_servers[0]
        print(f"Selected server: {best_server['id']}, "
              f"response time: {best_server['response_time']:.2f}ms")
        # Create the fault-tolerant client
        client = ResilientTCPClient(best_server['host'], best_server['port'])
        try:
            # Connect (with retries and circuit breaking)
            if client.connect_with_resilience():
                print("Connected")
                # Send data (with retries)
                client.send_with_resilience(b"test data")
                print("Data sent")
        except Exception as e:
            print(f"Operation failed: {e}")
        finally:
            client.close()
    else:
        print("No healthy server available")
    health_checker.stop_monitoring()


if __name__ == "__main__":
    demonstrate_resilience()
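The decorator-based breaker above reads `time.time()` directly, so exercising the CLOSED, OPEN, and HALF_OPEN transitions means actually waiting out `reset_timeout`. The following is a minimal sketch of the same state machine with an injectable clock; the class `CircuitBreaker`, the exception `CircuitOpenError`, and the `clock` parameter are hypothetical names introduced for illustration, not part of the code above.

```python
import time
from typing import Any, Callable


class CircuitOpenError(Exception):
    """Raised when the breaker rejects a call without running it."""


class CircuitBreaker:
    """Circuit breaker whose time source can be replaced in tests."""

    def __init__(self, failure_threshold: int = 5, reset_timeout: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.clock = clock
        self.state = 'CLOSED'  # CLOSED, OPEN, HALF_OPEN
        self.failures = 0
        self.opened_at = 0.0

    def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        if self.state == 'OPEN':
            if self.clock() - self.opened_at < self.reset_timeout:
                raise CircuitOpenError("circuit open, call rejected")
            # Timeout elapsed: let one trial call through
            self.state = 'HALF_OPEN'
        try:
            result = func(*args, **kwargs)
        except Exception:
            self.failures += 1
            # A failed trial call, or too many consecutive failures, opens the breaker
            if self.state == 'HALF_OPEN' or self.failures >= self.failure_threshold:
                self.state = 'OPEN'
                self.opened_at = self.clock()
            raise
        # Any success closes the breaker and clears the failure count
        self.state = 'CLOSED'
        self.failures = 0
        return result
```

In a test, passing something like `clock=lambda: now[0]` and advancing `now[0]` by hand moves the breaker from OPEN to HALF_OPEN without sleeping. Using `time.monotonic` as the default also keeps the timeout immune to wall-clock adjustments.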
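7. Summary
7.1 Key Takeaways
This article's analysis and working code covered the core of TCP/UDP network programming:
7.1.1 TCP Programming Essentials
- Connection management: three-way handshake, four-way teardown, state machine
- Reliability: sequence numbers, acknowledgments, retransmission, flow control, congestion control
- Advanced features: connection pooling, load balancing, SSL/TLS encryption, file transfer
- Performance optimization: non-blocking I/O, thread pools, connection reuse, zero-copy
7.1.2 UDP Programming Essentials
- Connectionless operation: fast transmission, low overhead, stateless
- Reliability: sequence numbers, acknowledgments, retransmission (implemented at the application layer)
- Broadcast/multicast: one-to-many communication
- Real-time applications: audio/video streaming, game state synchronization, IoT communication
7.1.3 Advanced Networking Techniques
- Asynchronous programming: select/poll/epoll, asyncio, thread pools
- Network security: SSL/TLS, message authentication, data encryption
- Diagnostic tools: port scanning, network monitoring, performance analysis
- Fault tolerance: retry strategies, the circuit breaker pattern, health checks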
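7.2 Protocol Selection Decision Tree
[Flowchart. Nodes: Start network communication design → Requirements analysis → decisions "Reliable transmission needed?", "Ordered delivery needed?", "Flow control needed?", "High real-time requirement?" (each with Yes/No branches) → outcomes: TCP, UDP, custom reliable UDP, latency-sensitive UDP, bandwidth-sensitive UDP → Done]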
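The decision tree above can also be written as a small function. This is a sketch under a loud assumption: the flattened diagram preserves the node labels but not which branch leads where, so the wiring below is one plausible reading (reliable, ordered, and flow-controlled leads to TCP; reliable but missing either property leads to custom reliable UDP; unreliable traffic splits on the real-time requirement). The function name `choose_protocol` and its parameters are hypothetical.

```python
def choose_protocol(reliable: bool, ordered: bool = False,
                    flow_control: bool = False, realtime: bool = False) -> str:
    """Map the four yes/no questions of the decision tree to a protocol choice.

    The branch wiring is an assumption; adjust it to match your own requirements.
    """
    if reliable:
        # TCP provides reliability, ordering, and flow control as a bundle
        if ordered and flow_control:
            return "TCP"
        # Only part of the bundle is needed: build just that on top of UDP
        return "custom reliable UDP"
    # No reliability requirement: plain UDP, tuned for latency or bandwidth
    if realtime:
        return "latency-sensitive UDP"
    return "bandwidth-sensitive UDP"


if __name__ == "__main__":
    print(choose_protocol(reliable=True, ordered=True, flow_control=True))
    print(choose_protocol(reliable=False, realtime=True))
```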
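7.3 Performance Comparison Matrix
| Use case | Recommended protocol | Latency | Reliability | Bandwidth efficiency | Implementation complexity |
|---|---|---|---|---|---|
| Web browsing | TCP | Medium | High | Medium | Low |
| File transfer | TCP | High | High | High | Low |
| Real-time games | UDP | Low | Medium | High | High |
| Video streaming | UDP | Low | Low | High | Medium |
| DNS queries | UDP | Low | Medium | High | Low |
| IoT | UDP / MQTT (over TCP) | Low | Medium | High | Medium |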
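7.4 Best Practices Checklist
7.4.1 TCP Programming Best Practices
- Use connection pools to reduce connection setup overhead
- Implement timeouts and retries
- Handle TCP message framing (coalesced and split messages)
- Use heartbeats to keep connections alive
- Enable TCP_NODELAY to reduce latency for small, latency-sensitive messages
- Tune buffer sizes for performance
- Close connections gracefully
- Add SSL/TLS for secure transport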
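Three of the items above (TCP_NODELAY, keep-alive, and buffer sizes) come down to `setsockopt` calls. The following is a minimal sketch; the helper name `tune_tcp_socket` and its parameters are hypothetical, and suitable buffer sizes depend on the workload, so none are set by default.

```python
import socket
from typing import Optional


def tune_tcp_socket(sock: socket.socket, *, nodelay: bool = True,
                    keepalive: bool = True,
                    rcvbuf: Optional[int] = None,
                    sndbuf: Optional[int] = None) -> dict:
    """Apply common TCP options and return the values the OS reports back."""
    if nodelay:
        # Disable Nagle's algorithm so small writes are sent immediately
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    if keepalive:
        # Ask the kernel to probe idle connections
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    if rcvbuf is not None:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, rcvbuf)
    if sndbuf is not None:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, sndbuf)
    # Read the options back: the OS may round or clamp buffer sizes
    # (Linux, for example, reports double the requested value)
    return {
        'nodelay': bool(sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)),
        'keepalive': bool(sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)),
        'rcvbuf': sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF),
        'sndbuf': sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF),
    }
```

Kernel keep-alive probes only detect a dead peer after a long idle period by default, so an application-level heartbeat is still worth having when failures must be noticed quickly.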
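7.4.2 UDP Programming Best Practices
- Implement application-layer reliability (if needed)
- Add sequence numbers and acknowledgments
- Handle out-of-order and lost packets
- Implement flow control (if needed)
- Verify data integrity with checksums
- Keep datagrams small enough to avoid IP fragmentation
- Handle network address translation (NAT)
- Implement service discovery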
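The sequence-number, checksum, and datagram-size items above can be combined in a small packet header. The following is a minimal sketch of one possible layout, not a standard format: a 32-bit sequence number, a 16-bit payload length, and a CRC32 over both plus the payload. The names `pack_datagram`, `unpack_datagram`, and `MAX_PAYLOAD` are hypothetical, and the 1200-byte limit is an assumed conservative value below a typical 1500-byte Ethernet MTU.

```python
import struct
import zlib
from typing import Tuple

_PREFIX = struct.Struct('!IH')  # sequence number (uint32), payload length (uint16)
_CRC = struct.Struct('!I')      # CRC32 over prefix + payload
HEADER_SIZE = _PREFIX.size + _CRC.size
MAX_PAYLOAD = 1200              # assumed limit to stay clear of IP fragmentation


def pack_datagram(seq: int, payload: bytes) -> bytes:
    """Build a datagram: prefix, checksum, then payload."""
    if len(payload) > MAX_PAYLOAD:
        raise ValueError(f"payload of {len(payload)} bytes exceeds {MAX_PAYLOAD}")
    prefix = _PREFIX.pack(seq & 0xFFFFFFFF, len(payload))
    crc = zlib.crc32(prefix + payload)
    return prefix + _CRC.pack(crc) + payload


def unpack_datagram(datagram: bytes) -> Tuple[int, bytes]:
    """Validate a datagram and return (seq, payload); raise ValueError if damaged."""
    if len(datagram) < HEADER_SIZE:
        raise ValueError("truncated header")
    seq, length = _PREFIX.unpack_from(datagram, 0)
    (crc,) = _CRC.unpack_from(datagram, _PREFIX.size)
    payload = datagram[HEADER_SIZE:]
    if len(payload) != length:
        raise ValueError("length mismatch")
    if zlib.crc32(datagram[:_PREFIX.size] + payload) != crc:
        raise ValueError("checksum mismatch")
    return seq, payload
```

A receiver would pass each `recvfrom` result through `unpack_datagram`, drop datagrams that raise `ValueError`, and use the returned sequence number to detect loss, duplicates, and reordering.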
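7.4.3 General Best Practices
- Use non-blocking I/O to increase concurrency
- Monitor connections and collect statistics
- Add thorough error handling
- Use the circuit breaker pattern to improve fault tolerance
- Make parameters configurable
- Log in detail
- Set up performance monitoring and alerting
- Write unit tests and integration tests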
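7.5 Future Trends
- QUIC: reliable transport built on UDP, with lower connection setup latency
- HTTP/3: built on QUIC to further improve Web performance
- 5G networks: low latency and high bandwidth, driving real-time applications
- Edge computing: network processing moves toward the edge, reducing latency
- Zero-trust networking: a stronger network security model
With the material and exercises in this article, you have the core knowledge and hands-on skills of TCP/UDP network programming needed to design and build high-performance, highly available network applications. This knowledge underpins modern distributed systems and is a stepping stone to more advanced networking technologies such as HTTP/3, QUIC, and WebSocket.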