高级篇:Python脚本(101-150)
1. 并发编程 (5个脚本)
1.1 多线程编程.py
python
# 101.1 多线程编程.py
import threading
import time
import random
from concurrent.futures import ThreadPoolExecutor
class BankAccount:
"""银行账户类,演示线程安全"""
def __init__(self, initial_balance=0):
self.balance = initial_balance
self.lock = threading.Lock()
def deposit(self, amount):
"""存款 - 线程安全版本"""
with self.lock:
current_balance = self.balance
time.sleep(0.001) # 模拟处理时间
self.balance = current_balance + amount
return self.balance
def withdraw(self, amount):
"""取款 - 线程安全版本"""
with self.lock:
if self.balance >= amount:
current_balance = self.balance
time.sleep(0.001) # 模拟处理时间
self.balance = current_balance - amount
return self.balance
else:
raise ValueError("余额不足")
def unsafe_transaction(account, transactions):
"""不安全的交易操作"""
for amount in transactions:
if amount > 0:
account.deposit(amount)
else:
try:
account.withdraw(-amount)
except ValueError:
pass
def multi_threading_demo():
"""多线程编程演示"""
print("=== 多线程编程 ===")
# 创建账户
account = BankAccount(1000)
# 模拟交易
transactions = []
for _ in range(100):
amount = random.choice([-50, -20, 20, 50, 100])
transactions.append(amount)
# 单线程执行
print("单线程执行:")
single_thread_account = BankAccount(1000)
start_time = time.time()
unsafe_transaction(single_thread_account, transactions)
single_thread_time = time.time() - start_time
print(f" 最终余额: {single_thread_account.balance}")
print(f" 执行时间: {single_thread_time:.4f}秒")
# 多线程执行
print("\n多线程执行:")
multi_thread_account = BankAccount(1000)
def worker(trans_chunk):
unsafe_transaction(multi_thread_account, trans_chunk)
# 分割交易到多个线程
num_threads = 4
chunk_size = len(transactions) // num_threads
threads = []
start_time = time.time()
for i in range(num_threads):
start_idx = i * chunk_size
end_idx = start_idx + chunk_size if i < num_threads - 1 else len(transactions)
chunk = transactions[start_idx:end_idx]
thread = threading.Thread(target=worker, args=(chunk,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
multi_thread_time = time.time() - start_time
print(f" 最终余额: {multi_thread_account.balance}")
print(f" 执行时间: {multi_thread_time:.4f}秒")
print(f" 加速比: {single_thread_time/multi_thread_time:.2f}x")
# 使用线程池
print("\n使用线程池:")
pool_account = BankAccount(1000)
start_time = time.time()
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
for i in range(num_threads):
start_idx = i * chunk_size
end_idx = start_idx + chunk_size if i < num_threads - 1 else len(transactions)
chunk = transactions[start_idx:end_idx]
future = executor.submit(worker, chunk)
futures.append(future)
# 等待所有任务完成
for future in futures:
future.result()
pool_time = time.time() - start_time
print(f" 最终余额: {pool_account.balance}")
print(f" 执行时间: {pool_time:.4f}秒")
if __name__ == "__main__":
multi_threading_demo()
1.2 多进程编程.py
python
# 101.2 多进程编程.py
import multiprocessing
import time
import math
import os
def cpu_intensive_task(n):
"""CPU密集型任务 - 计算素数"""
def is_prime(num):
if num < 2:
return False
for i in range(2, int(math.sqrt(num)) + 1):
if num % i == 0:
return False
return True
primes = []
for i in range(2, n + 1):
if is_prime(i):
primes.append(i)
return primes
def io_intensive_task(filename, data):
"""I/O密集型任务 - 文件操作"""
time.sleep(0.1) # 模拟I/O延迟
with open(filename, 'w') as f:
f.write(data)
time.sleep(0.1)
with open(filename, 'r') as f:
content = f.read()
os.remove(filename)
return len(content)
def multi_processing_demo():
"""多进程编程演示"""
print("=== 多进程编程 ===")
# CPU密集型任务比较
print("1. CPU密集型任务比较:")
n = 10000
# 单进程
print("单进程执行:")
start_time = time.time()
result_single = cpu_intensive_task(n)
single_time = time.time() - start_time
print(f" 找到 {len(result_single)} 个素数")
print(f" 执行时间: {single_time:.4f}秒")
# 多进程
print("\n多进程执行 (4进程):")
start_time = time.time()
with multiprocessing.Pool(processes=4) as pool:
# 分割任务
chunks = [n // 4, n // 4, n // 4, n - 3 * (n // 4)]
ranges = []
start = 2
for chunk in chunks:
end = start + chunk - 1
ranges.append((start, end))
start = end + 1
# 执行任务
results = pool.starmap(cpu_intensive_task_range, ranges)
# 合并结果
result_multi = []
for res in results:
result_multi.extend(res)
multi_time = time.time() - start_time
print(f" 找到 {len(result_multi)} 个素数")
print(f" 执行时间: {multi_time:.4f}秒")
print(f" 加速比: {single_time/multi_time:.2f}x")
# I/O密集型任务
print("\n2. I/O密集型任务比较:")
test_data = "Hello, Multiprocessing!" * 1000
num_files = 16
# 单进程
print("单进程执行:")
start_time = time.time()
single_results = []
for i in range(num_files):
filename = f"temp_single_{i}.txt"
result = io_intensive_task(filename, test_data)
single_results.append(result)
single_io_time = time.time() - start_time
print(f" 执行时间: {single_io_time:.4f}秒")
# 多进程
print("\n多进程执行:")
start_time = time.time()
with multiprocessing.Pool(processes=4) as pool:
tasks = []
for i in range(num_files):
filename = f"temp_multi_{i}.txt"
tasks.append((filename, test_data))
multi_results = pool.starmap(io_intensive_task, tasks)
multi_io_time = time.time() - start_time
print(f" 执行时间: {multi_io_time:.4f}秒")
print(f" 加速比: {single_io_time/multi_io_time:.2f}x")
# 进程间通信
print("\n3. 进程间通信:")
def producer(queue, items):
"""生产者进程"""
for item in items:
queue.put(item)
print(f"生产者放入: {item}")
time.sleep(0.1)
queue.put(None) # 结束信号
def consumer(queue, name):
"""消费者进程"""
while True:
item = queue.get()
if item is None:
queue.put(None) # 传递给其他消费者
break
print(f"消费者{name}取出: {item}")
time.sleep(0.2)
# 创建队列
queue = multiprocessing.Queue()
# 创建进程
producer_process = multiprocessing.Process(
target=producer,
args=(queue, ['A', 'B', 'C', 'D', 'E'])
)
consumer_process1 = multiprocessing.Process(
target=consumer,
args=(queue, "1")
)
consumer_process2 = multiprocessing.Process(
target=consumer,
args=(queue, "2")
)
# 启动进程
producer_process.start()
consumer_process1.start()
consumer_process2.start()
# 等待完成
producer_process.join()
consumer_process1.join()
consumer_process2.join()
def cpu_intensive_task_range(start, end):
"""辅助函数:计算指定范围内的素数"""
def is_prime(num):
if num < 2:
return False
for i in range(2, int(math.sqrt(num)) + 1):
if num % i == 0:
return False
return True
primes = []
for i in range(start, end + 1):
if is_prime(i):
primes.append(i)
return primes
if __name__ == "__main__":
multi_processing_demo()
1.3 异步编程.py
python
# 101.3 异步编程.py
import asyncio
import aiohttp
import time
import asyncpg
import json
class AsyncWebCrawler:
"""异步网页爬虫"""
def __init__(self, max_concurrent=5):
self.semaphore = asyncio.Semaphore(max_concurrent)
self.visited_urls = set()
async def fetch_url(self, session, url):
"""获取URL内容"""
async with self.semaphore:
try:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
if response.status == 200:
content = await response.text()
return {
'url': url,
'content': content[:200], # 只取前200字符
'status': 'success',
'length': len(content)
}
else:
return {
'url': url,
'content': '',
'status': f'error_{response.status}',
'length': 0
}
except Exception as e:
return {
'url': url,
'content': '',
'status': f'exception_{type(e).__name__}',
'length': 0
}
async def crawl_urls(self, urls):
"""爬取多个URL"""
async with aiohttp.ClientSession() as session:
tasks = [self.fetch_url(session, url) for url in urls]
results = await asyncio.gather(*tasks)
return results
class AsyncDatabase:
"""异步数据库操作"""
def __init__(self, connection_string):
self.connection_string = connection_string
self.pool = None
async def connect(self):
"""连接数据库"""
self.pool = await asyncpg.create_pool(self.connection_string)
async def create_table(self):
"""创建表"""
async with self.pool.acquire() as connection:
await connection.execute('''
CREATE TABLE IF NOT EXISTS web_data (
id SERIAL PRIMARY KEY,
url TEXT NOT NULL,
content TEXT,
status TEXT,
length INTEGER,
created_at TIMESTAMP DEFAULT NOW()
)
''')
async def insert_data(self, data_list):
"""插入数据"""
async with self.pool.acquire() as connection:
for data in data_list:
await connection.execute('''
INSERT INTO web_data (url, content, status, length)
VALUES ($1, $2, $3, $4)
''', data['url'], data['content'], data['status'], data['length'])
async def get_stats(self):
"""获取统计信息"""
async with self.pool.acquire() as connection:
stats = await connection.fetchrow('''
SELECT
COUNT(*) as total,
COUNT(CASE WHEN status = 'success' THEN 1 END) as success,
COUNT(CASE WHEN status != 'success' THEN 1 END) as failed
FROM web_data
''')
return dict(stats)
async def async_programming_demo():
"""异步编程演示"""
print("=== 异步编程 ===")
# 异步网页爬虫
print("1. 异步网页爬虫:")
crawler = AsyncWebCrawler(max_concurrent=3)
test_urls = [
'https://httpbin.org/delay/1',
'https://httpbin.org/delay/2',
'https://httpbin.org/json',
'https://httpbin.org/html',
'https://httpbin.org/xml',
'https://httpbin.org/robots.txt'
]
start_time = time.time()
results = await crawler.crawl_urls(test_urls)
crawl_time = time.time() - start_time
print(f"爬取 {len(test_urls)} 个URL耗时: {crawl_time:.2f}秒")
for result in results:
print(f" {result['url']} - {result['status']} - {result['length']}字节")
# 异步数据库操作
print("\n2. 异步数据库操作:")
try:
# 注意:需要PostgreSQL数据库支持
db = AsyncDatabase('postgresql://user:password@localhost/testdb')
await db.connect()
await db.create_table()
await db.insert_data(results)
stats = await db.get_stats()
print(f"数据库统计: 总计{stats['total']}条, 成功{stats['success']}条, 失败{stats['failed']}条")
except Exception as e:
print(f"数据库操作失败: {e}")
print("请确保PostgreSQL服务运行且连接字符串正确")
# 异步任务协调
print("\n3. 异步任务协调:")
async def long_running_task(name, duration):
"""长时间运行的任务"""
print(f"任务 {name} 开始,预计耗时 {duration}秒")
await asyncio.sleep(duration)
print(f"任务 {name} 完成")
return f"任务 {name} 结果"
async def with_timeout(task, timeout):
"""带超时的任务"""
try:
result = await asyncio.wait_for(task, timeout)
return result
except asyncio.TimeoutError:
return f"任务超时(限制{timeout}秒)"
# 同时运行多个任务
tasks = [
long_running_task("A", 3),
long_running_task("B", 1),
with_timeout(long_running_task("C", 5), 2), # 这个会超时
long_running_task("D", 2)
]
print("开始执行多个异步任务...")
start_time = time.time()
results = await asyncio.gather(*tasks, return_exceptions=True)
total_time = time.time() - start_time
print(f"所有任务完成,总耗时: {total_time:.2f}秒")
for i, result in enumerate(results):
print(f" 任务{i+1}: {result}")
def run_async_demo():
"""运行异步演示"""
asyncio.run(async_programming_demo())
if __name__ == "__main__":
run_async_demo()
1.4 协程和生成器协程.py
python
# 101.4 协程和生成器协程.py
import asyncio
import time
from typing import AsyncGenerator, Generator
def traditional_generator() -> Generator[int, None, None]:
"""传统生成器"""
print("传统生成器开始")
for i in range(3):
print(f"传统生成器 yield {i}")
yield i
print("传统生成器结束")
async def async_generator() -> AsyncGenerator[int, None]:
"""异步生成器"""
print("异步生成器开始")
for i in range(3):
print(f"异步生成器准备 yield {i}")
await asyncio.sleep(0.5) # 模拟异步操作
print(f"异步生成器 yield {i}")
yield i
print("异步生成器结束")
async def simple_coroutine(name: str, delay: float):
"""简单协程"""
print(f"协程 {name} 开始")
await asyncio.sleep(delay)
print(f"协程 {name} 完成")
return f"{name}的结果"
class CoroutineManager:
"""协程管理器"""
def __init__(self):
self.tasks = []
async def add_task(self, coro):
"""添加任务"""
task = asyncio.create_task(coro)
self.tasks.append(task)
return task
async def wait_all(self):
"""等待所有任务完成"""
results = await asyncio.gather(*self.tasks, return_exceptions=True)
self.tasks.clear()
return results
async def process_with_async_gen(self):
"""使用异步生成器处理数据"""
async for value in async_generator():
print(f"处理异步值: {value}")
# 模拟处理
await asyncio.sleep(0.2)
def coroutine_demo():
"""协程演示"""
print("=== 协程和生成器协程 ===")
# 传统生成器
print("1. 传统生成器:")
gen = traditional_generator()
for value in gen:
print(f"主程序收到: {value}")
# 异步生成器和协程
print("\n2. 异步生成器和协程:")
async def async_main():
# 异步生成器
print("异步生成器演示:")
async for value in async_generator():
print(f"主程序收到异步值: {value}")
# 协程管理器
print("\n协程管理器演示:")
manager = CoroutineManager()
# 添加多个任务
await manager.add_task(simple_coroutine("任务1", 1.0))
await manager.add_task(simple_coroutine("任务2", 0.5))
await manager.add_task(simple_coroutine("任务3", 1.5))
# 等待所有任务完成
results = await manager.wait_all()
print(f"所有任务结果: {results}")
# 使用异步生成器处理
print("\n使用异步生成器处理:")
await manager.process_with_async_gen()
# 运行异步主函数
asyncio.run(async_main())
def advanced_coroutine_patterns():
"""高级协程模式"""
print("\n3. 高级协程模式:")
async def producer(queue: asyncio.Queue):
"""生产者协程"""
for i in range(5):
item = f"项目_{i}"
await queue.put(item)
print(f"生产者发送: {item}")
await asyncio.sleep(0.3)
await queue.put(None) # 结束信号
async def consumer(queue: asyncio.Queue, name: str):
"""消费者协程"""
while True:
item = await queue.get()
if item is None:
# 把结束信号放回,让其他消费者也能看到
await queue.put(None)
break
print(f"消费者{name}处理: {item}")
await asyncio.sleep(0.5) # 模拟处理时间
queue.task_done()
async def advanced_main():
# 创建队列
queue = asyncio.Queue(maxsize=3)
# 创建生产者和消费者任务
producer_task = asyncio.create_task(producer(queue))
consumer_tasks = [
asyncio.create_task(consumer(queue, "A")),
asyncio.create_task(consumer(queue, "B"))
]
# 等待生产者完成
await producer_task
# 等待所有项目被处理
await queue.join()
# 取消消费者任务
for task in consumer_tasks:
task.cancel()
# 等待消费者任务完成(被取消)
await asyncio.gather(*consumer_tasks, return_exceptions=True)
print("生产消费模式完成")
asyncio.run(advanced_main())
if __name__ == "__main__":
coroutine_demo()
advanced_coroutine_patterns()
1.5 并发编程实战.py
python
# 101.5 并发编程实战.py
import asyncio
import aiohttp
import asyncpg
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Dict, Any
import hashlib
@dataclass
class ProcessingResult:
"""处理结果数据类"""
url: str
content_hash: str
processing_time: float
word_count: int
status: str
class ConcurrentProcessor:
"""并发处理器"""
def __init__(self, max_workers: int = 4):
self.max_workers = max_workers
self.results: List[ProcessingResult] = []
async def process_urls_async(self, urls: List[str]) -> List[ProcessingResult]:
"""异步处理URLs"""
print("=== 异步处理模式 ===")
start_time = time.time()
async with aiohttp.ClientSession() as session:
tasks = [self._process_single_url_async(session, url) for url in urls]
results = await asyncio.gather(*tasks, return_exceptions=True)
total_time = time.time() - start_time
print(f"异步处理完成: {len(urls)}个URL, 耗时: {total_time:.2f}秒")
return [r for r in results if isinstance(r, ProcessingResult)]
async def _process_single_url_async(self, session: aiohttp.ClientSession, url: str) -> ProcessingResult:
"""异步处理单个URL"""
start_time = time.time()
try:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
if response.status == 200:
content = await response.text()
# 计算哈希
content_hash = hashlib.md5(content.encode()).hexdigest()
# 计算单词数
word_count = len(content.split())
processing_time = time.time() - start_time
return ProcessingResult(
url=url,
content_hash=content_hash,
processing_time=processing_time,
word_count=word_count,
status="success"
)
else:
return ProcessingResult(
url=url,
content_hash="",
processing_time=time.time() - start_time,
word_count=0,
status=f"http_error_{response.status}"
)
except Exception as e:
return ProcessingResult(
url=url,
content_hash="",
processing_time=time.time() - start_time,
word_count=0,
status=f"exception_{type(e).__name__}"
)
def process_urls_threaded(self, urls: List[str]) -> List[ProcessingResult]:
"""多线程处理URLs"""
print("=== 多线程处理模式 ===")
start_time = time.time()
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
results = list(executor.map(self._process_single_url_sync, urls))
total_time = time.time() - start_time
print(f"多线程处理完成: {len(urls)}个URL, 耗时: {total_time:.2f}秒")
return results
def _process_single_url_sync(self, url: str) -> ProcessingResult:
"""同步处理单个URL"""
import requests
start_time = time.time()
try:
response = requests.get(url, timeout=10)
if response.status_code == 200:
content = response.text
# 计算哈希
content_hash = hashlib.md5(content.encode()).hexdigest()
# 计算单词数
word_count = len(content.split())
processing_time = time.time() - start_time
return ProcessingResult(
url=url,
content_hash=content_hash,
processing_time=processing_time,
word_count=word_count,
status="success"
)
else:
return ProcessingResult(
url=url,
content_hash="",
processing_time=time.time() - start_time,
word_count=0,
status=f"http_error_{response.status_code}"
)
except Exception as e:
return ProcessingResult(
url=url,
content_hash="",
processing_time=time.time() - start_time,
word_count=0,
status=f"exception_{type(e).__name__}"
)
def process_data_multiprocess(self, data_list: List[str]) -> List[Dict[str, Any]]:
"""多进程数据处理"""
print("=== 多进程处理模式 ===")
start_time = time.time()
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
results = list(executor.map(self._cpu_intensive_processing, data_list))
total_time = time.time() - start_time
print(f"多进程处理完成: {len(data_list)}个数据项, 耗时: {total_time:.2f}秒")
return results
def _cpu_intensive_processing(self, data: str) -> Dict[str, Any]:
"""CPU密集型处理"""
import hashlib
import re
# 模拟复杂计算
start_time = time.time()
# 多次哈希
current_hash = data
for _ in range(10000):
current_hash = hashlib.md5(current_hash.encode()).hexdigest()
# 复杂字符串处理
words = re.findall(r'\b\w+\b', data)
unique_words = set(words)
# 排序和统计
sorted_words = sorted(unique_words)
word_stats = {
'total_words': len(words),
'unique_words': len(unique_words),
'avg_word_length': sum(len(word) for word in words) / len(words) if words else 0
}
processing_time = time.time() - start_time
return {
'original_length': len(data),
'final_hash': current_hash,
'word_stats': word_stats,
'processing_time': processing_time
}
async def concurrent_programming_practice():
"""并发编程实战"""
print("=== 并发编程实战 ===")
# 测试URLs
test_urls = [
'https://httpbin.org/json',
'https://httpbin.org/html',
'https://httpbin.org/xml',
'https://httpbin.org/robots.txt',
'https://httpbin.org/user-agent',
'https://httpbin.org/headers'
]
# 创建处理器
processor = ConcurrentProcessor(max_workers=3)
# 异步处理
async_results = await processor.process_urls_async(test_urls)
print("\n异步处理结果:")
for result in async_results:
print(f" {result.url}: {result.status}, {result.word_count}词, {result.processing_time:.2f}秒")
# 多线程处理
threaded_results = processor.process_urls_threaded(test_urls)
print("\n多线程处理结果:")
for result in threaded_results:
print(f" {result.url}: {result.status}, {result.word_count}词, {result.processing_time:.2f}秒")
# 多进程处理
sample_data = [
"Python is a programming language that lets you work quickly and integrate systems more effectively.",
"Concurrent programming in Python can be achieved through threading, multiprocessing, and asyncio.",
"Asyncio is a library to write concurrent code using the async/await syntax.",
"Threading is suitable for I/O-bound tasks, while multiprocessing is better for CPU-bound tasks."
] * 10 # 重复10次以增加数据量
multiprocess_results = processor.process_data_multiprocess(sample_data)
print("\n多进程处理结果 (前5项):")
for i, result in enumerate(multiprocess_results[:5]):
print(f" 数据{i+1}: {result['processing_time']:.2f}秒, {result['word_stats']['total_words']}词")
# 性能比较
print("\n=== 性能比较 ===")
# 模拟不同并发模式的选择建议
print("并发模式选择建议:")
print("1. I/O密集型任务 (网络请求、文件操作):")
print(" - 推荐: asyncio (最高效)")
print(" - 备选: 多线程")
print(" - 避免: 多进程 (创建成本高)")
print("\n2. CPU密集型任务 (数学计算、数据处理):")
print(" - 推荐: 多进程 (利用多核)")
print(" - 避免: 多线程 (GIL限制)")
print(" - 避免: asyncio (不提供真正的并行)")
print("\n3. 混合型任务:")
print(" - 推荐: asyncio + 多进程组合")
print(" - 使用线程池执行器运行CPU密集型任务")
def run_concurrent_practice():
"""运行并发实战"""
asyncio.run(concurrent_programming_practice())
if __name__ == "__main__":
run_concurrent_practice()
2. 网络编程 (5个脚本)
2.1 Socket编程基础.py
python
# 102.1 Socket编程基础.py
import socket
import threading
import time
import json
class TCPServer:
"""TCP服务器"""
def __init__(self, host='localhost', port=8888):
self.host = host
self.port = port
self.socket = None
self.clients = []
self.running = False
def start(self):
"""启动服务器"""
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
self.socket.bind((self.host, self.port))
self.socket.listen(5)
self.running = True
print(f"TCP服务器启动在 {self.host}:{self.port}")
# 接受连接线程
accept_thread = threading.Thread(target=self._accept_connections)
accept_thread.daemon = True
accept_thread.start()
# 控制台命令线程
command_thread = threading.Thread(target=self._command_interface)
command_thread.daemon = True
command_thread.start()
accept_thread.join()
except Exception as e:
print(f"服务器启动失败: {e}")
finally:
self.stop()
def _accept_connections(self):
"""接受客户端连接"""
while self.running:
try:
client_socket, client_address = self.socket.accept()
print(f"新的客户端连接: {client_address}")
# 为每个客户端创建处理线程
client_thread = threading.Thread(
target=self._handle_client,
args=(client_socket, client_address)
)
client_thread.daemon = True
client_thread.start()
self.clients.append({
'socket': client_socket,
'address': client_address,
'thread': client_thread
})
except socket.error:
if self.running:
print("接受连接时发生错误")
break
def _handle_client(self, client_socket, client_address):
"""处理客户端消息"""
try:
while self.running:
# 接收数据
data = client_socket.recv(1024).decode('utf-8')
if not data:
break
print(f"收到来自 {client_address} 的消息: {data}")
# 解析JSON消息
try:
message = json.loads(data)
response = self._process_message(message, client_address)
client_socket.send(json.dumps(response).encode('utf-8'))
except json.JSONDecodeError:
# 如果不是JSON,原样返回
response = {"status": "echo", "message": data}
client_socket.send(json.dumps(response).encode('utf-8'))
except socket.error as e:
print(f"客户端 {client_address} 连接错误: {e}")
finally:
client_socket.close()
self._remove_client(client_address)
print(f"客户端 {client_address} 断开连接")
def _process_message(self, message, client_address):
"""处理消息"""
msg_type = message.get('type', 'unknown')
if msg_type == 'echo':
return {"status": "success", "echo": message.get('data')}
elif msg_type == 'time':
return {"status": "success", "timestamp": time.time()}
elif msg_type == 'broadcast':
self._broadcast_message(message.get('data'), client_address)
return {"status": "success", "broadcasted": True}
else:
return {"status": "error", "message": "未知消息类型"}
def _broadcast_message(self, message, sender_address):
"""广播消息给所有客户端"""
broadcast_msg = json.dumps({
"type": "broadcast",
"from": str(sender_address),
"message": message,
"timestamp": time.time()
})
disconnected_clients = []
for client in self.clients:
try:
if client['address'] != sender_address:
client['socket'].send(broadcast_msg.encode('utf-8'))
except socket.error:
disconnected_clients.append(client['address'])
# 移除断开的客户端
for address in disconnected_clients:
self._remove_client(address)
def _remove_client(self, client_address):
"""移除客户端"""
self.clients = [c for c in self.clients if c['address'] != client_address]
print(f"当前客户端数量: {len(self.clients)}")
def _command_interface(self):
"""服务器命令接口"""
while self.running:
try:
command = input("服务器命令 (stop/status/broadcast): ").strip().lower()
if command == 'stop':
self.stop()
break
elif command == 'status':
print(f"运行状态: {self.running}")
print(f"客户端数量: {len(self.clients)}")
elif command.startswith('broadcast '):
message = command[10:]
self._broadcast_message(message, "SERVER")
print(f"广播消息: {message}")
except (EOFError, KeyboardInterrupt):
self.stop()
break
def stop(self):
"""停止服务器"""
self.running = False
# 关闭所有客户端连接
for client in self.clients:
try:
client['socket'].close()
except:
pass
if self.socket:
self.socket.close()
print("TCP服务器已停止")
class TCPClient:
"""TCP客户端"""
def __init__(self, host='localhost', port=8888):
self.host = host
self.port = port
self.socket = None
self.connected = False
def connect(self):
"""连接到服务器"""
try:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.connect((self.host, self.port))
self.connected = True
print(f"已连接到服务器 {self.host}:{self.port}")
# 启动接收消息线程
receive_thread = threading.Thread(target=self._receive_messages)
receive_thread.daemon = True
receive_thread.start()
return True
except Exception as e:
print(f"连接失败: {e}")
return False
def send_message(self, message_type, data=None):
"""发送消息到服务器"""
if not self.connected:
print("未连接到服务器")
return
message = {"type": message_type}
if data:
message["data"] = data
try:
self.socket.send(json.dumps(message).encode('utf-8'))
except socket.error as e:
print(f"发送消息失败: {e}")
self.connected = False
def _receive_messages(self):
"""接收服务器消息"""
while self.connected:
try:
data = self.socket.recv(1024).decode('utf-8')
if not data:
break
try:
message = json.loads(data)
self._handle_server_message(message)
except json.JSONDecodeError:
print(f"收到原始消息: {data}")
except socket.error:
break
print("与服务器的连接已断开")
self.connected = False
def _handle_server_message(self, message):
"""处理服务器消息"""
msg_type = message.get('type', 'unknown')
if msg_type == 'broadcast':
print(f"\n[广播来自 {message['from']}]: {message['message']}")
else:
print(f"\n[服务器响应]: {message}")
def disconnect(self):
"""断开连接"""
self.connected = False
if self.socket:
self.socket.close()
def socket_programming_demo():
"""Socket编程演示"""
import threading
import time
print("=== Socket编程基础 ===")
# 启动服务器
server = TCPServer(port=8888)
server_thread = threading.Thread(target=server.start)
server_thread.daemon = True
server_thread.start()
# 等待服务器启动
time.sleep(1)
# 创建客户端
clients = []
for i in range(3):
client = TCPClient(port=8888)
if client.connect():
clients.append(client)
time.sleep(0.1)
if clients:
# 测试消息发送
print("\n测试消息发送:")
clients[0].send_message("echo", "Hello, Server!")
time.sleep(0.5)
clients[1].send_message("time")
time.sleep(0.5)
clients[2].send_message("broadcast", "大家好,这是广播测试!")
time.sleep(1)
# 断开客户端
for client in clients:
client.disconnect()
# 等待一段时间后停止服务器
time.sleep(2)
server.stop()
if __name__ == "__main__":
socket_programming_demo()
2.2 HTTP客户端和服务端.py
python
# 102.2 HTTP客户端和服务端.py
from http.server import HTTPServer, BaseHTTPRequestHandler
import json
import threading
import time
import urllib.parse
from http.client import HTTPConnection
import requests
class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
"""简单的HTTP请求处理器"""
def do_GET(self):
"""处理GET请求"""
parsed_path = urllib.parse.urlparse(self.path)
path = parsed_path.path
# 路由处理
if path == '/':
self._send_html_response('<h1>欢迎来到Python HTTP服务器</h1>')
elif path == '/api/time':
self._send_json_response({'timestamp': time.time()})
elif path == '/api/status':
self._send_json_response({'status': 'running', 'version': '1.0'})
elif path.startswith('/api/echo/'):
message = path.split('/api/echo/')[1]
self._send_json_response({'echo': message})
else:
self._send_error(404, "页面未找到")
def do_POST(self):
"""处理POST请求"""
content_length = int(self.headers.get('Content-Length', 0))
post_data = self.rfile.read(content_length)
try:
data = json.loads(post_data.decode('utf-8'))
response = self._process_post_data(data)
self._send_json_response(response)
except json.JSONDecodeError:
self._send_error(400, "无效的JSON数据")
def do_PUT(self):
"""处理PUT请求"""
self._send_json_response({'method': 'PUT', 'message': '资源已更新'})
def do_DELETE(self):
"""处理DELETE请求"""
self._send_json_response({'method': 'DELETE', 'message': '资源已删除'})
def _process_post_data(self, data):
"""处理POST数据"""
action = data.get('action', 'unknown')
if action == 'add':
a = data.get('a', 0)
b = data.get('b', 0)
return {'result': a + b, 'action': 'add'}
elif action == 'multiply':
a = data.get('a', 1)
b = data.get('b', 1)
return {'result': a * b, 'action': 'multiply'}
else:
return {'error': '未知操作', 'received_data': data}
def _send_html_response(self, html_content):
"""发送HTML响应"""
self.send_response(200)
self.send_header('Content-type', 'text/html; charset=utf-8')
self.end_headers()
self.wfile.write(html_content.encode('utf-8'))
def _send_json_response(self, data):
"""发送JSON响应"""
self.send_response(200)
self.send_header('Content-type', 'application/json; charset=utf-8')
self.end_headers()
self.wfile.write(json.dumps(data, ensure_ascii=False).encode('utf-8'))
def _send_error(self, code, message):
"""发送错误响应"""
self.send_response(code)
self.send_header('Content-type', 'application/json; charset=utf-8')
self.end_headers()
error_response = {'error': message, 'code': code}
self.wfile.write(json.dumps(error_response).encode('utf-8'))
def log_message(self, format, *args):
"""自定义日志格式"""
print(f"[HTTP服务器] {self.address_string()} - {format % args}")
class HTTPClientDemo:
"""HTTP客户端演示"""
@staticmethod
def test_http_client():
"""测试内置http.client"""
print("=== 使用http.client ===")
conn = HTTPConnection('localhost', 8000)
try:
# GET请求
print("1. GET请求:")
conn.request("GET", "/api/status")
response = conn.getresponse()
print(f" 状态码: {response.status}")
print(f" 响应: {response.read().decode()}")
# POST请求
print("\n2. POST请求:")
post_data = json.dumps({'action': 'add', 'a': 5, 'b': 3})
conn.request("POST", "/api/data", body=post_data, headers={'Content-Type': 'application/json'})
response = conn.getresponse()
print(f" 状态码: {response.status}")
print(f" 响应: {response.read().decode()}")
except ConnectionRefusedError:
print("连接被拒绝,请确保HTTP服务器正在运行")
finally:
conn.close()
@staticmethod
def test_requests_library():
"""测试requests库"""
print("\n=== 使用requests库 ===")
base_url = "http://localhost:8000"
try:
# GET请求
print("1. GET请求:")
response = requests.get(f"{base_url}/api/time")
print(f" 状态码: {response.status_code}")
print(f" 响应: {response.json()}")
# POST请求
print("\n2. POST请求:")
response = requests.post(f"{base_url}/api/data",
json={'action': 'multiply', 'a': 4, 'b': 6})
print(f" 状态码: {response.status_code}")
print(f" 响应: {response.json()}")
# PUT请求
print("\n3. PUT请求:")
response = requests.put(f"{base_url}/api/resource")
print(f" 状态码: {response.status_code}")
print(f" 响应: {response.json()}")
# 错误请求
print("\n4. 错误请求:")
response = requests.get(f"{base_url}/api/nonexistent")
print(f" 状态码: {response.status_code}")
print(f" 响应: {response.json()}")
except requests.ConnectionError:
print("连接错误,请确保HTTP服务器正在运行")
@staticmethod
def advanced_requests_demo():
"""高级requests功能演示"""
print("\n=== 高级requests功能 ===")
# 会话对象(保持cookies)
session = requests.Session()
# 公共API测试
try:
print("1. 公共API测试:")
response = session.get('https://httpbin.org/json')
print(f" httpbin.org响应: {response.status_code}")
print("\n2. 带参数的请求:")
response = session.get('https://httpbin.org/get', params={'key1': 'value1', 'key2': 'value2'})
data = response.json()
print(f" 参数: {data['args']}")
print("\n3. 自定义头部:")
headers = {'User-Agent': 'MyPythonClient/1.0', 'X-Custom-Header': 'test'}
response = session.get('https://httpbin.org/headers', headers=headers)
print(f" 请求头: {response.json()['headers']}")
print("\n4. 超时设置:")
try:
response = session.get('https://httpbin.org/delay/5', timeout=2)
except requests.Timeout:
print(" 请求超时(预期内)")
except Exception as e:
print(f"请求错误: {e}")
def run_http_server():
"""运行HTTP服务器"""
server = HTTPServer(('localhost', 8000), SimpleHTTPRequestHandler)
print("启动HTTP服务器在 http://localhost:8000")
print("可用端点:")
print(" GET / - HTML页面")
print(" GET /api/time - 当前时间")
print(" GET /api/status - 服务器状态")
print(" GET /api/echo/<msg>- 回声测试")
print(" POST /api/data - 数据处理")
print(" PUT /api/resource - 更新资源")
print(" DELETE /api/resource - 删除资源")
server.serve_forever()
def http_demo():
"""HTTP客户端和服务端演示"""
print("=== HTTP客户端和服务端 ===")
# 启动HTTP服务器线程
server_thread = threading.Thread(target=run_http_server)
server_thread.daemon = True
server_thread.start()
# 等待服务器启动
time.sleep(1)
# 运行客户端测试
client = HTTPClientDemo()
client.test_http_client()
client.test_requests_library()
client.advanced_requests_demo()
# 保持运行
print("\n服务器仍在运行,按Ctrl+C停止...")
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("\n停止演示")
if __name__ == "__main__":
http_demo()
2.3 WebSocket实时通信.py
python
# 102.3 WebSocket实时通信.py
import asyncio
import websockets
import json
import time
import threading
from datetime import datetime
class WebSocketServer:
"""WebSocket服务器"""
def __init__(self, host='localhost', port=8765):
self.host = host
self.port = port
self.connected_clients = set()
self.chat_rooms = {}
async def handle_connection(self, websocket, path):
"""处理WebSocket连接"""
client_id = id(websocket)
self.connected_clients.add(websocket)
print(f"新的WebSocket连接: {client_id}")
try:
# 发送欢迎消息
welcome_msg = {
'type': 'system',
'message': '连接成功',
'client_id': client_id,
'timestamp': datetime.now().isoformat()
}
await websocket.send(json.dumps(welcome_msg))
# 处理消息
async for message in websocket:
await self.handle_message(websocket, client_id, message)
except websockets.exceptions.ConnectionClosed:
print(f"WebSocket连接关闭: {client_id}")
finally:
self.connected_clients.remove(websocket)
# 从所有聊天室移除
for room_name in list(self.chat_rooms.keys()):
if websocket in self.chat_rooms[room_name]:
self.chat_rooms[room_name].remove(websocket)
async def handle_message(self, websocket, client_id, message):
"""处理接收到的消息"""
try:
data = json.loads(message)
msg_type = data.get('type')
if msg_type == 'chat':
await self.handle_chat_message(websocket, client_id, data)
elif msg_type == 'join_room':
await self.handle_join_room(websocket, data.get('room'))
elif msg_type == 'leave_room':
await self.handle_leave_room(websocket, data.get('room'))
elif msg_type == 'broadcast':
await self.broadcast_message(data.get('message'), sender_id=client_id)
else:
await websocket.send(json.dumps({
'type': 'error',
'message': '未知消息类型'
}))
except json.JSONDecodeError:
await websocket.send(json.dumps({
'type': 'error',
'message': '无效的JSON格式'
}))
async def handle_chat_message(self, websocket, client_id, data):
"""处理聊天消息"""
room = data.get('room', 'general')
message = data.get('message', '')
username = data.get('username', f'用户{client_id}')
chat_msg = {
'type': 'chat',
'room': room,
'username': username,
'message': message,
'timestamp': datetime.now().isoformat()
}
# 发送到特定聊天室
await self.send_to_room(room, chat_msg)
print(f"聊天消息 [{room}]: {username}: {message}")
async def handle_join_room(self, websocket, room_name):
"""处理加入聊天室"""
if room_name not in self.chat_rooms:
self.chat_rooms[room_name] = set()
self.chat_rooms[room_name].add(websocket)
# 通知房间成员
join_msg = {
'type': 'system',
'message': f'新用户加入了房间 {room_name}',
'timestamp': datetime.now().isoformat()
}
await self.send_to_room(room_name, join_msg)
await websocket.send(json.dumps({
'type': 'system',
'message': f'已加入房间: {room_name}'
}))
async def handle_leave_room(self, websocket, room_name):
"""处理离开聊天室"""
if room_name in self.chat_rooms and websocket in self.chat_rooms[room_name]:
self.chat_rooms[room_name].remove(websocket)
await websocket.send(json.dumps({
'type': 'system',
'message': f'已离开房间: {room_name}'
}))
async def send_to_room(self, room_name, message):
"""发送消息到特定房间"""
if room_name in self.chat_rooms:
disconnected = []
for client in self.chat_rooms[room_name]:
try:
await client.send(json.dumps(message))
except websockets.exceptions.ConnectionClosed:
disconnected.append(client)
# 移除断开连接的客户端
for client in disconnected:
self.chat_rooms[room_name].remove(client)
async def broadcast_message(self, message, sender_id=None):
"""广播消息给所有客户端"""
broadcast_msg = {
'type': 'broadcast',
'message': message,
'sender_id': sender_id,
'timestamp': datetime.now().isoformat()
}
disconnected = []
for client in self.connected_clients:
try:
await client.send(json.dumps(broadcast_msg))
except websockets.exceptions.ConnectionClosed:
disconnected.append(client)
# 移除断开连接的客户端
for client in disconnected:
self.connected_clients.remove(client)
async def start(self):
"""启动WebSocket服务器"""
print(f"启动WebSocket服务器在 ws://{self.host}:{self.port}")
async with websockets.serve(self.handle_connection, self.host, self.port):
await asyncio.Future() # 永久运行
class WebSocketClient:
"""WebSocket客户端"""
def __init__(self, uri):
self.uri = uri
self.websocket = None
self.running = False
async def connect(self):
"""连接到WebSocket服务器"""
try:
self.websocket = await websockets.connect(self.uri)
self.running = True
print(f"已连接到 {self.uri}")
return True
except Exception as e:
print(f"连接失败: {e}")
return False
async def send_message(self, message_type, **kwargs):
"""发送消息"""
if not self.websocket:
print("未连接到服务器")
return
message = {'type': message_type, **kwargs}
await self.websocket.send(json.dumps(message))
async def receive_messages(self):
"""接收消息"""
try:
async for message in self.websocket:
data = json.loads(message)
self.handle_received_message(data)
except websockets.exceptions.ConnectionClosed:
print("连接已关闭")
self.running = False
def handle_received_message(self, data):
"""处理接收到的消息"""
msg_type = data.get('type')
if msg_type == 'chat':
print(f"[{data['room']}] {data['username']}: {data['message']}")
elif msg_type == 'system':
print(f"[系统] {data['message']}")
elif msg_type == 'broadcast':
print(f"[广播] {data['message']}")
else:
print(f"[未知消息] {data}")
async def chat_interface(self):
"""聊天界面"""
username = input("请输入用户名: ")
room = "general"
# 加入默认房间
await self.send_message('join_room', room=room)
print(f"\n已加入房间: {room}")
print("输入消息进行聊天,输入 'quit' 退出")
print("命令:")
print(" /join <room> - 加入房间")
print(" /leave <room> - 离开房间")
print(" /broadcast <msg> - 广播消息")
while self.running:
try:
user_input = await asyncio.get_event_loop().run_in_executor(None, input, "> ")
if user_input.lower() == 'quit':
break
elif user_input.startswith('/join '):
room = user_input[6:]
await self.send_message('join_room', room=room)
elif user_input.startswith('/leave '):
room_to_leave = user_input[7:]
await self.send_message('leave_room', room=room_to_leave)
elif user_input.startswith('/broadcast '):
message = user_input[11:]
await self.send_message('broadcast', message=message)
else:
await self.send_message('chat', room=room, username=username, message=user_input)
except (EOFError, KeyboardInterrupt):
break
async def start(self):
"""启动客户端"""
if await self.connect():
# 同时运行消息接收和用户输入
await asyncio.gather(
self.receive_messages(),
self.chat_interface()
)
await self.close()
async def close(self):
"""关闭连接"""
if self.websocket:
await self.websocket.close()
async def websocket_demo():
"""WebSocket演示"""
print("=== WebSocket实时通信 ===")
# 启动服务器
server = WebSocketServer()
server_task = asyncio.create_task(server.start())
# 等待服务器启动
await asyncio.sleep(1)
# 启动多个客户端进行测试
print("\n启动测试客户端...")
async def test_client(client_id):
client = WebSocketClient("ws://localhost:8765")
await client.connect()
# 发送测试消息
await client.send_message('join_room', room='test')
await client.send_message('chat', room='test', username=f'Client{client_id}', message=f'你好 from Client{client_id}')
await asyncio.sleep(1)
await client.send_message('broadcast', message=f'广播消息 from Client{client_id}')
await asyncio.sleep(0.5)
await client.close()
# 启动多个测试客户端
tasks = [test_client(i) for i in range(3)]
await asyncio.gather(*tasks)
# 运行真实客户端
print("\n启动交互式客户端...")
client = WebSocketClient("ws://localhost:8765")
await client.start()
# 取消服务器任务
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass
def run_websocket_demo():
"""运行WebSocket演示"""
asyncio.run(websocket_demo())
if __name__ == "__main__":
run_websocket_demo()
2.4 网络协议实现.py
python
# 102.4 网络协议实现.py
import smtplib
import poplib
import imaplib
import email
from email.mime.text import MimeText
from email.mime.multipart import MIMEMultipart
from email.header import decode_header
import ftplib
import socket
import ssl
class EmailClient:
"""邮件客户端"""
def __init__(self, smtp_server, smtp_port, pop3_server, pop3_port, imap_server, imap_port):
self.smtp_server = smtp_server
self.smtp_port = smtp_port
self.pop3_server = pop3_server
self.pop3_port = pop3_port
self.imap_server = imap_server
self.imap_port = imap_port
# 测试用的假凭证(实际使用时需要真实凭证)
self.username = "test@example.com"
self.password = "password"
def send_email_smtp(self, to_email, subject, body):
"""使用SMTP发送邮件"""
print(f"=== SMTP发送邮件 ===")
try:
# 创建邮件
msg = MIMEMultipart()
msg['From'] = self.username
msg['To'] = to_email
msg['Subject'] = subject
msg.attach(MimeText(body, 'plain', 'utf-8'))
# 连接SMTP服务器(这里使用模拟)
print(f"连接到SMTP服务器 {self.smtp_server}:{self.smtp_port}")
print(f"发件人: {self.username}")
print(f"收件人: {to_email}")
print(f"主题: {subject}")
print(f"正文: {body}")
# 实际代码(需要真实SMTP服务器):
# with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
# server.starttls()
# server.login(self.username, self.password)
# server.send_message(msg)
print("邮件发送成功(模拟)")
return True
except Exception as e:
print(f"邮件发送失败: {e}")
return False
def receive_email_pop3(self):
"""使用POP3接收邮件"""
print(f"\n=== POP3接收邮件 ===")
try:
# 连接POP3服务器(模拟)
print(f"连接到POP3服务器 {self.pop3_server}:{self.pop3_port}")
# 实际代码(需要真实POP3服务器):
# with poplib.POP3(self.pop3_server, self.pop3_port) as server:
# server.user(self.username)
# server.pass_(self.password)
#
# # 获取邮件统计
# num_messages = len(server.list()[1])
# print(f"邮箱中有 {num_messages} 封邮件")
#
# # 读取最新邮件
# if num_messages > 0:
# response, lines, octets = server.retr(num_messages)
# msg_content = b'\n'.join(lines).decode('utf-8')
# msg = email.message_from_string(msg_content)
#
# subject = decode_header(msg['Subject'])[0][0]
# if isinstance(subject, bytes):
# subject = subject.decode()
#
# print(f"最新邮件主题: {subject}")
# print(f"发件人: {msg['From']}")
print("POP3邮件接收完成(模拟)")
return True
except Exception as e:
print(f"POP3接收失败: {e}")
return False
def receive_email_imap(self):
"""使用IMAP接收邮件"""
print(f"\n=== IMAP接收邮件 ===")
try:
# 连接IMAP服务器(模拟)
print(f"连接到IMAP服务器 {self.imap_server}:{self.imap_port}")
# 实际代码(需要真实IMAP服务器):
# with imaplib.IMAP4_SSL(self.imap_server, self.imap_port) as server:
# server.login(self.username, self.password)
# server.select('INBOX')
#
# # 搜索未读邮件
# status, messages = server.search(None, 'UNSEEN')
# email_ids = messages[0].split()
#
# print(f"有 {len(email_ids)} 封未读邮件")
#
# # 读取最新未读邮件
# if email_ids:
# latest_email_id = email_ids[-1]
# status, msg_data = server.fetch(latest_email_id, '(RFC822)')
# msg = email.message_from_bytes(msg_data[0][1])
#
# subject = decode_header(msg['Subject'])[0][0]
# if isinstance(subject, bytes):
# subject = subject.decode()
#
# print(f"最新未读邮件主题: {subject}")
# print(f"发件人: {msg['From']}")
print("IMAP邮件接收完成(模拟)")
return True
except Exception as e:
print(f"IMAP接收失败: {e}")
return False
class FTPClient:
"""FTP客户端"""
def __init__(self, server, port=21):
self.server = server
self.port = port
self.ftp = None
def connect(self, username='anonymous', password=''):
"""连接到FTP服务器"""
try:
self.ftp = ftplib.FTP()
self.ftp.connect(self.server, self.port)
self.ftp.login(username, password)
print(f"已连接到FTP服务器 {self.server}:{self.port}")
return True
except Exception as e:
print(f"FTP连接失败: {e}")
return False
def list_files(self, directory=''):
"""列出目录文件"""
if not self.ftp:
print("未连接到FTP服务器")
return
try:
print(f"目录 {directory} 中的文件:")
files = []
self.ftp.dir(directory, files.append)
for file in files:
print(f" {file}")
except Exception as e:
print(f"列出文件失败: {e}")
def upload_file(self, local_file, remote_file):
"""上传文件"""
if not self.ftp:
print("未连接到FTP服务器")
return
try:
with open(local_file, 'rb') as f:
self.ftp.storbinary(f'STOR {remote_file}', f)
print(f"文件上传成功: {local_file} -> {remote_file}")
except Exception as e:
print(f"文件上传失败: {e}")
def download_file(self, remote_file, local_file):
"""下载文件"""
if not self.ftp:
print("未连接到FTP服务器")
return
try:
with open(local_file, 'wb') as f:
self.ftp.retrbinary(f'RETR {remote_file}', f.write)
print(f"文件下载成功: {remote_file} -> {local_file}")
except Exception as e:
print(f"文件下载失败: {e}")
def disconnect(self):
"""断开连接"""
if self.ftp:
self.ftp.quit()
print("FTP连接已关闭")
class DNSLookup:
"""DNS查询工具"""
@staticmethod
def lookup_hostname(hostname):
"""查询主机名的IP地址"""
try:
ip_address = socket.gethostbyname(hostname)
print(f"DNS查询: {hostname} -> {ip_address}")
return ip_address
except socket.gaierror as e:
print(f"DNS查询失败: {e}")
return None
@staticmethod
def reverse_lookup(ip_address):
"""反向DNS查询"""
try:
hostname = socket.gethostbyaddr(ip_address)
print(f"反向DNS查询: {ip_address} -> {hostname[0]}")
return hostname[0]
except socket.herror as e:
print(f"反向DNS查询失败: {e}")
return None
@staticmethod
def get_host_info(hostname):
"""获取主机信息"""
try:
info = socket.gethostbyname_ex(hostname)
print(f"主机信息 {hostname}:")
print(f" 主机名: {info[0]}")
print(f" 别名: {info[1]}")
print(f" IP地址: {info[2]}")
return info
except socket.gaierror as e:
print(f"获取主机信息失败: {e}")
return None
def network_protocols_demo():
"""网络协议实现演示"""
print("=== 网络协议实现 ===")
# 邮件协议演示
print("\n1. 邮件协议演示:")
email_client = EmailClient(
smtp_server="smtp.example.com",
smtp_port=587,
pop3_server="pop.example.com",
pop3_port=995,
imap_server="imap.example.com",
imap_port=993
)
email_client.send_email_smtp("recipient@example.com", "测试邮件", "这是一封测试邮件")
email_client.receive_email_pop3()
email_client.receive_email_imap()
# FTP协议演示
print("\n2. FTP协议演示:")
# 使用公共测试FTP服务器
ftp_client = FTPClient("ftp.dlptest.com")
if ftp_client.connect("dlpuser", "rNrKYTX9g7z3RgJRmxWuGHbeu"):
ftp_client.list_files()
# 注意:公共FTP服务器可能不允许上传下载
ftp_client.disconnect()
# DNS查询演示
print("\n3. DNS查询演示:")
dns = DNSLookup()
dns.lookup_hostname("www.google.com")
dns.lookup_hostname("www.github.com")
# 获取本地主机信息
print("\n4. 本地网络信息:")
hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname)
print(f"本地主机名: {hostname}")
print(f"本地IP地址: {local_ip}")
# 端口扫描演示(简单的)
print("\n5. 简单端口扫描演示:")
target_host = "localhost"
common_ports = [21, 22, 23, 25, 53, 80, 110, 143, 443, 993, 995]
print(f"扫描 {target_host} 的常用端口:")
for port in common_ports:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
result = sock.connect_ex((target_host, port))
if result == 0:
print(f" 端口 {port}: 开放")
sock.close()
if __name__ == "__main__":
network_protocols_demo()
2.5 网络安全编程.py
python
# 102.5 网络安全编程.py
import ssl
import hashlib
import hmac
import secrets
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
import base64
import socket
class SecureSocketServer:
"""安全Socket服务器(使用SSL/TLS)"""
def __init__(self, host='localhost', port=8443):
self.host = host
self.port = port
self.context = None
self.setup_ssl_context()
def setup_ssl_context(self):
"""设置SSL上下文"""
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.context.check_hostname = False
# 在实际应用中,这里需要真实的证书和私钥文件
# self.context.load_cert_chain('server.crt', 'server.key')
# 为了演示,我们创建自签名证书(在实际生产环境中不要这样做)
print("警告: 使用自签名证书仅用于演示")
def start(self):
"""启动安全服务器"""
try:
# 创建socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((self.host, self.port))
sock.listen(5)
print(f"安全服务器启动在 {self.host}:{self.port}")
while True:
try:
client_socket, client_address = sock.accept()
print(f"新的客户端连接: {client_address}")
# 包装为SSL socket
# 注意:由于没有真实证书,这里注释掉了实际的SSL包装
# ssl_socket = self.context.wrap_socket(client_socket, server_side=True)
# 模拟安全通信
self.handle_secure_client(client_socket, client_address)
except Exception as e:
print(f"处理客户端时发生错误: {e}")
except KeyboardInterrupt:
print("\n服务器停止")
finally:
sock.close()
def handle_secure_client(self, client_socket, client_address):
"""处理安全客户端连接"""
try:
# 模拟安全握手
client_socket.send(b"220 Secure Service Ready\n")
# 接收客户端数据
data = client_socket.recv(1024).decode('utf-8')
print(f"收到客户端数据: {data.strip()}")
# 发送响应(模拟加密通信)
response = "250 Secure communication established\n"
client_socket.send(response.encode('utf-8'))
# 模拟安全数据传输
secure_message = "这是通过安全连接传输的数据\n"
client_socket.send(secure_message.encode('utf-8'))
except Exception as e:
print(f"安全通信错误: {e}")
finally:
client_socket.close()
class EncryptionDemo:
"""加密演示"""
@staticmethod
def hash_demo():
"""哈希函数演示"""
print("=== 哈希函数 ===")
data = "Hello, Secure World!"
# MD5 (不推荐用于安全用途)
md5_hash = hashlib.md5(data.encode()).hexdigest()
print(f"MD5: {md5_hash}")
# SHA-256
sha256_hash = hashlib.sha256(data.encode()).hexdigest()
print(f"SHA-256: {sha256_hash}")
# SHA-512
sha512_hash = hashlib.sha512(data.encode()).hexdigest()
print(f"SHA-512: {sha512_hash}")
# 加盐哈希
salt = secrets.token_bytes(16)
salted_data = salt + data.encode()
salted_hash = hashlib.sha256(salted_data).hexdigest()
print(f"加盐SHA-256: {salted_hash}")
@staticmethod
def hmac_demo():
"""HMAC演示"""
print("\n=== HMAC (哈希消息认证码) ===")
message = "Important message"
key = secrets.token_bytes(32) # 256位密钥
# 创建HMAC
hmac_obj = hmac.new(key, message.encode(), hashlib.sha256)
hmac_digest = hmac_obj.hexdigest()
print(f"消息: {message}")
print(f"HMAC: {hmac_digest}")
# 验证HMAC
hmac_verify = hmac.new(key, message.encode(), hashlib.sha256)
try:
hmac_verify.hexverify(hmac_digest)
print("HMAC验证: 成功")
except Exception:
print("HMAC验证: 失败")
@staticmethod
def symmetric_encryption_demo():
"""对称加密演示"""
print("\n=== 对称加密 (AES) ===")
# 生成密钥
password = b"my_secret_password"
salt = secrets.token_bytes(16)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(password))
# 创建Fernet加密器
fernet = Fernet(key)
# 加密数据
original_data = "这是需要加密的敏感数据"
encrypted_data = fernet.encrypt(original_data.encode())
print(f"原始数据: {original_data}")
print(f"加密数据: {encrypted_data.decode()}")
# 解密数据
decrypted_data = fernet.decrypt(encrypted_data).decode()
print(f"解密数据: {decrypted_data}")
# 验证完整性
if original_data == decrypted_data:
print("加解密验证: 成功")
else:
print("加解密验证: 失败")
@staticmethod
def asymmetric_encryption_demo():
"""非对称加密演示"""
print("\n=== 非对称加密 (RSA) ===")
# 生成RSA密钥对
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
public_key = private_key.public_key()
# 要加密的消息
message = b"这是使用RSA加密的消息"
# 使用公钥加密
ciphertext = public_key.encrypt(
message,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
print(f"原始消息: {message.decode()}")
print(f"加密消息: {base64.b64encode(ciphertext).decode()}")
# 使用私钥解密
decrypted_message = private_key.decrypt(
ciphertext,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
print(f"解密消息: {decrypted_message.decode()}")
# 数字签名
print("\n=== 数字签名 ===")
# 创建签名
signature = private_key.sign(
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
print(f"签名: {base64.b64encode(signature).decode()}")
# 验证签名
try:
public_key.verify(
signature,
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
print("签名验证: 成功")
except Exception:
print("签名验证: 失败")
class SecurityUtilities:
"""安全工具"""
@staticmethod
def generate_secure_token(length=32):
"""生成安全随机令牌"""
token = secrets.token_urlsafe(length)
print(f"安全令牌: {token}")
return token
@staticmethod
def password_strength_check(password):
"""检查密码强度"""
score = 0
feedback = []
if len(password) >= 8:
score += 1
else:
feedback.append("密码长度至少8位")
if any(c.islower() for c in password):
score += 1
else:
feedback.append("需要小写字母")
if any(c.isupper() for c in password):
score += 1
else:
feedback.append("需要大写字母")
if any(c.isdigit() for c in password):
score += 1
else:
feedback.append("需要数字")
if any(not c.isalnum() for c in password):
score += 1
else:
feedback.append("需要特殊字符")
strength_levels = ["非常弱", "弱", "一般", "强", "非常强"]
strength = strength_levels[score - 1] if score > 0 else "无效"
print(f"密码强度: {strength} ({score}/5)")
if feedback:
print("改进建议:", ", ".join(feedback))
return score
def network_security_demo():
"""网络安全编程演示"""
print("=== 网络安全编程 ===")
# 启动安全服务器(在后台线程)
import threading
server = SecureSocketServer()
server_thread = threading.Thread(target=server.start)
server_thread.daemon = True
server_thread.start()
# 等待服务器启动
import time
time.sleep(1)
# 加密演示
encryption = EncryptionDemo()
encryption.hash_demo()
encryption.hmac_demo()
encryption.symmetric_encryption_demo()
encryption.asymmetric_encryption_demo()
# 安全工具演示
print("\n=== 安全工具 ===")
security_utils = SecurityUtilities()
# 生成安全令牌
security_utils.generate_secure_token()
# 密码强度检查
print("\n密码强度检查:")
test_passwords = ["123", "password", "Password1", "StrongPass123!", "V3ry$tr0ngP@ssw0rd"]
for pwd in test_passwords:
print(f"\n测试密码: {pwd}")
security_utils.password_strength_check(pwd)
# 安全连接测试
print("\n=== 安全连接测试 ===")
try:
# 创建SSL客户端连接
context = ssl.create_default_context()
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE # 仅用于测试
# 注意:由于是自签名证书,实际连接会失败
# with socket.create_connection(('localhost', 8443)) as sock:
# with context.wrap_socket(sock, server_hostname='localhost') as ssock:
# print(f"SSL连接信息: {ssock.version()}")
# ssock.send(b"HELO secure client\n")
# response = ssock.recv(1024)
# print(f"服务器响应: {response.decode()}")
print("安全连接测试完成(模拟)")
except Exception as e:
print(f"安全连接测试失败: {e}")
if __name__ == "__main__":
network_security_demo()
5. 数据分析 (5个脚本)
5.1 数据分析基础.py
python
# 105.1 数据分析基础.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')
class DataAnalyzer:
"""数据分析器基础类"""
def __init__(self):
plt.style.use('seaborn-v0_8')
self.df = None
def generate_sample_data(self, n_samples=1000):
"""生成示例销售数据"""
np.random.seed(42)
dates = pd.date_range('2023-01-01', periods=n_samples, freq='D')
products = ['Product_A', 'Product_B', 'Product_C', 'Product_D', 'Product_E']
regions = ['North', 'South', 'East', 'West']
categories = ['Electronics', 'Clothing', 'Food', 'Books', 'Home']
data = {
'date': np.random.choice(dates, n_samples),
'product': np.random.choice(products, n_samples),
'region': np.random.choice(regions, n_samples),
'category': np.random.choice(categories, n_samples),
'sales_amount': np.random.normal(1000, 300, n_samples).round(2),
'quantity': np.random.randint(1, 100, n_samples),
'customer_rating': np.random.uniform(1, 5, n_samples).round(1)
}
# 添加一些缺失值
for col in ['sales_amount', 'customer_rating']:
mask = np.random.random(n_samples) < 0.05
data[col] = np.where(mask, np.nan, data[col])
self.df = pd.DataFrame(data)
return self.df
def basic_analysis(self):
"""基础数据分析"""
print("=== 基础数据分析 ===")
# 数据基本信息
print("1. 数据基本信息:")
print(f"数据集形状: {self.df.shape}")
print(f"列名: {list(self.df.columns)}")
print("\n数据类型:")
print(self.df.dtypes)
# 数据预览
print("\n2. 数据预览:")
print("前5行数据:")
print(self.df.head())
print("\n后5行数据:")
print(self.df.tail())
# 统计描述
print("\n3. 数值列统计描述:")
print(self.df.describe())
# 缺失值分析
print("\n4. 缺失值分析:")
missing_data = self.df.isnull().sum()
missing_percent = (missing_data / len(self.df)) * 100
missing_df = pd.DataFrame({
'缺失数量': missing_data,
'缺失比例%': missing_percent.round(2)
})
print(missing_df[missing_df['缺失数量'] > 0])
# 唯一值分析
print("\n5. 分类变量唯一值:")
categorical_cols = self.df.select_dtypes(include=['object']).columns
for col in categorical_cols:
unique_count = self.df[col].nunique()
print(f"{col}: {unique_count} 个唯一值")
if unique_count <= 10: # 只显示少量唯一值的具体内容
print(f" 值: {self.df[col].unique()}")
def data_cleaning(self):
"""数据清洗"""
print("\n=== 数据清洗 ===")
# 处理缺失值
print("1. 处理缺失值:")
original_shape = self.df.shape
# 数值列用中位数填充
numeric_cols = self.df.select_dtypes(include=[np.number]).columns
for col in numeric_cols:
if self.df[col].isnull().sum() > 0:
median_val = self.df[col].median()
self.df[col].fillna(median_val, inplace=True)
print(f" {col}: 用中位数 {median_val:.2f} 填充")
# 分类列用众数填充
categorical_cols = self.df.select_dtypes(include=['object']).columns
for col in categorical_cols:
if self.df[col].isnull().sum() > 0:
mode_val = self.df[col].mode()[0]
self.df[col].fillna(mode_val, inplace=True)
print(f" {col}: 用众数 '{mode_val}' 填充")
print(f"清洗后数据形状: {self.df.shape} (无变化)")
# 数据类型转换
print("\n2. 数据类型优化:")
print("转换前内存使用:")
print(self.df.info(memory_usage='deep'))
# 优化数值类型
for col in self.df.select_dtypes(include=[np.number]).columns:
col_min = self.df[col].min()
col_max = self.df[col].max()
# 整数优化
if self.df[col].dtype == float and self.df[col].round().equals(self.df[col]):
if col_min >= 0:
if col_max < 255:
self.df[col] = self.df[col].astype(np.uint8)
elif col_max < 65535:
self.df[col] = self.df[col].astype(np.uint16)
else:
if col_max < 127 and col_min > -128:
self.df[col] = self.df[col].astype(np.int8)
elif col_max < 32767 and col_min > -32768:
self.df[col] = self.df[col].astype(np.int16)
print("\n转换后内存使用:")
print(self.df.info(memory_usage='deep'))
def exploratory_analysis(self):
"""探索性数据分析"""
print("\n=== 探索性数据分析 ===")
# 单变量分析
print("1. 单变量分析:")
# 数值变量分布
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
numeric_cols = ['sales_amount', 'quantity', 'customer_rating']
for i, col in enumerate(numeric_cols):
ax = axes[i//2, i%2]
self.df[col].hist(bins=30, ax=ax, alpha=0.7, color='skyblue')
ax.set_title(f'{col} 分布')
ax.set_xlabel(col)
ax.set_ylabel('频数')
# 添加箱线图
ax = axes[1, 1]
self.df[numeric_cols].boxplot(ax=ax)
ax.set_title('数值变量箱线图')
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()
# 分类变量分析
print("\n2. 分类变量分析:")
categorical_cols = ['product', 'region', 'category']
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for i, col in enumerate(categorical_cols):
value_counts = self.df[col].value_counts()
axes[i].pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%')
axes[i].set_title(f'{col} 分布')
plt.tight_layout()
plt.show()
# 相关性分析
print("\n3. 相关性分析:")
correlation_matrix = self.df[numeric_cols].corr()
plt.figure(figsize=(8, 6))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
square=True, fmt='.2f')
plt.title('数值变量相关性热力图')
plt.show()
def data_analysis_basics():
"""数据分析基础演示"""
analyzer = DataAnalyzer()
# 生成数据
print("生成示例数据...")
df = analyzer.generate_sample_data(1000)
# 基础分析
analyzer.basic_analysis()
# 数据清洗
analyzer.data_cleaning()
# 探索性分析
analyzer.exploratory_analysis()
return analyzer
if __name__ == "__main__":
analyzer = data_analysis_basics()
5.2 数据清洗和预处理.py
python
# 105.2 数据清洗和预处理.py
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.impute import SimpleImputer, KNNImputer
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
class DataPreprocessor:
"""数据预处理器"""
def __init__(self):
self.scalers = {}
self.encoders = {}
self.imputers = {}
def load_complex_data(self):
"""加载包含各种问题的复杂数据集"""
np.random.seed(42)
n_samples = 500
# 创建包含各种数据问题的数据集
data = {
'customer_id': range(1, n_samples + 1),
'age': np.random.normal(35, 10, n_samples).round(),
'income': np.random.lognormal(10, 1, n_samples).round(2),
'spending_score': np.random.randint(1, 100, n_samples),
'city': np.random.choice(['北京', '上海', '广州', '深圳', '杭州', '成都', '武汉', '西安'], n_samples),
'membership_level': np.random.choice(['青铜', '白银', '黄金', '铂金', '钻石'], n_samples),
'last_purchase_days': np.random.exponential(30, n_samples).round(),
'total_orders': np.random.poisson(15, n_samples)
}
df = pd.DataFrame(data)
# 故意添加数据问题
# 1. 异常值
df.loc[0, 'age'] = 150 # 不可能的年龄
df.loc[1, 'income'] = 1000000 # 异常高收入
df.loc[2, 'spending_score'] = -10 # 负分
# 2. 缺失值
missing_indices = np.random.choice(n_samples, 50, replace=False)
df.loc[missing_indices[:25], 'income'] = np.nan
df.loc[missing_indices[25:40], 'age'] = np.nan
df.loc[missing_indices[40:], 'city'] = None
# 3. 不一致的数据
df.loc[10, 'city'] = 'Beijing' # 英文而不是中文
df.loc[11, 'membership_level'] = 'Gold' # 英文而不是中文
# 4. 重复数据
duplicate_row = df.iloc[20].copy()
df = pd.concat([df, duplicate_row.to_frame().T], ignore_index=True)
return df
def detect_outliers(self, df, column):
"""检测异常值"""
Q1 = df[column].quantile(0.25)
Q3 = df[column].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
outliers = df[(df[column] < lower_bound) | (df[column] > upper_bound)]
print(f"{column} 异常值检测:")
print(f" 正常范围: [{lower_bound:.2f}, {upper_bound:.2f}]")
print(f" 发现 {len(outliers)} 个异常值")
return outliers
def handle_outliers(self, df, column, method='clip'):
"""处理异常值"""
original_stats = df[column].describe()
if method == 'clip':
# 缩尾处理
lower = df[column].quantile(0.01)
upper = df[column].quantile(0.99)
df[column] = df[column].clip(lower, upper)
print(f" 缩尾处理: [{lower:.2f}, {upper:.2f}]")
elif method == 'remove':
# 移除异常值
Q1 = df[column].quantile(0.25)
Q3 = df[column].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
mask = (df[column] >= lower_bound) & (df[column] <= upper_bound)
df = df[mask]
print(f" 移除异常值,剩余 {len(df)} 行")
elif method == 'transform':
# 对数变换
if df[column].min() > 0:
df[column] = np.log1p(df[column])
print(" 应用对数变换")
new_stats = df[column].describe()
print(f" 处理前: 均值={original_stats['mean']:.2f}, 标准差={original_stats['std']:.2f}")
print(f" 处理后: 均值={new_stats['mean']:.2f}, 标准差={new_stats['std']:.2f}")
return df
def advanced_imputation(self, df):
"""高级缺失值填充"""
print("\n=== 高级缺失值填充 ===")
# 数值列缺失值处理
numeric_cols = df.select_dtypes(include=[np.number]).columns
numeric_cols_with_na = numeric_cols[df[numeric_cols].isnull().any()]
for col in numeric_cols_with_na:
print(f"\n处理 {col} 的缺失值:")
# 方法1: KNN填充
knn_imputer = KNNImputer(n_neighbors=5)
df_knn = df.copy()
df_knn[col] = knn_imputer.fit_transform(df[[col]])
# 方法2: 多重填充(简化版)
df_multiple = df.copy()
for fill_method in ['mean', 'median', 'most_frequent']:
imputer = SimpleImputer(strategy=fill_method)
df_temp = df.copy()
df_temp[col] = imputer.fit_transform(df[[col]])
# 在实际应用中,这里会创建多个填充数据集并合并
# 使用中位数填充(实际选择)
imputer = SimpleImputer(strategy='median')
df[col] = imputer.fit_transform(df[[col]])
print(f" 使用中位数填充: {imputer.statistics_[0]:.2f}")
# 添加缺失值标志
df[f'{col}_was_missing'] = df[col].isnull()
def feature_engineering(self, df):
"""特征工程"""
print("\n=== 特征工程 ===")
# 1. 创建新特征
print("1. 创建新特征:")
# 基于业务逻辑创建特征
df['age_group'] = pd.cut(df['age'],
bins=[0, 25, 35, 45, 55, 100],
labels=['青年', '中青年', '中年', '中老年', '老年'])
df['income_level'] = pd.qcut(df['income'], q=4,
labels=['低收入', '中低收入', '中高收入', '高收入'])
# 交互特征
df['spending_per_order'] = df['spending_score'] / df['total_orders']
df['spending_per_order'] = df['spending_per_order'].replace([np.inf, -np.inf], np.nan)
df['spending_per_order'].fillna(0, inplace=True)
print(" 创建的特征: age_group, income_level, spending_per_order")
# 2. 编码分类变量
print("\n2. 编码分类变量:")
categorical_cols = df.select_dtypes(include=['object', 'category']).columns
for col in categorical_cols:
if df[col].nunique() <= 10: # 低基数使用one-hot
dummies = pd.get_dummies(df[col], prefix=col)
df = pd.concat([df, dummies], axis=1)
print(f" {col}: One-Hot编码 ({df[col].nunique()}个类别)")
else: # 高基数使用标签编码
le = LabelEncoder()
df[f'{col}_encoded'] = le.fit_transform(df[col].fillna('Unknown'))
self.encoders[col] = le
print(f" {col}: 标签编码")
# 3. 数值特征缩放
print("\n3. 数值特征缩放:")
numeric_cols = df.select_dtypes(include=[np.number]).columns
# 排除ID列和标志列
numeric_cols = [col for col in numeric_cols if not col.endswith('_was_missing')
and col != 'customer_id']
# 标准化
scaler = StandardScaler()
df_scaled = scaler.fit_transform(df[numeric_cols])
df_scaled = pd.DataFrame(df_scaled, columns=[f'{col}_standardized' for col in numeric_cols])
# 归一化
minmax_scaler = MinMaxScaler()
df_normalized = minmax_scaler.fit_transform(df[numeric_cols])
df_normalized = pd.DataFrame(df_normalized, columns=[f'{col}_normalized' for col in numeric_cols])
df = pd.concat([df, df_scaled, df_normalized], axis=1)
self.scalers['standard'] = scaler
self.scalers['minmax'] = minmax_scaler
print(f" 标准化特征: {len(numeric_cols)}个")
print(f" 归一化特征: {len(numeric_cols)}个")
return df
def data_quality_report(self, df):
"""数据质量报告"""
print("\n=== 数据质量报告 ===")
report_data = []
for col in df.columns:
col_data = {
'字段名': col,
'数据类型': df[col].dtype,
'总数': len(df),
'非空数': df[col].count(),
'缺失数': df[col].isnull().sum(),
'缺失比例%': (df[col].isnull().sum() / len(df) * 100).round(2),
'唯一值数': df[col].nunique()
}
if df[col].dtype in [np.number]:
col_data.update({
'均值': df[col].mean(),
'标准差': df[col].std(),
'最小值': df[col].min(),
'最大值': df[col].max()
})
else:
col_data.update({
'最常见值': df[col].mode()[0] if not df[col].empty else None,
'最常见值频数': df[col].value_counts().iloc[0] if not df[col].empty else 0
})
report_data.append(col_data)
report_df = pd.DataFrame(report_data)
print(report_df.to_string(index=False))
return report_df
def data_cleaning_demo():
"""数据清洗和预处理演示"""
preprocessor = DataPreprocessor()
# 加载数据
print("加载复杂数据集...")
df = preprocessor.load_complex_data()
print("原始数据形状:", df.shape)
print("\n原始数据预览:")
print(df.head())
# 数据质量报告
preprocessor.data_quality_report(df)
# 处理异常值
print("\n=== 异常值处理 ===")
numeric_cols = ['age', 'income', 'spending_score']
for col in numeric_cols:
outliers = preprocessor.detect_outliers(df, col)
if len(outliers) > 0:
df = preprocessor.handle_outliers(df, col, method='clip')
# 处理缺失值
preprocessor.advanced_imputation(df)
# 数据一致性处理
print("\n=== 数据一致性处理 ===")
df['city'] = df['city'].replace({'Beijing': '北京'})
df['membership_level'] = df['membership_level'].replace({'Gold': '黄金'})
# 去除重复数据
print("去除重复数据:")
before_dedup = len(df)
df = df.drop_duplicates()
after_dedup = len(df)
print(f" 去重前: {before_dedup} 行")
print(f" 去重后: {after_dedup} 行")
print(f" 移除: {before_dedup - after_dedup} 行重复数据")
# 特征工程
df = preprocessor.feature_engineering(df)
print(f"\n最终数据形状: {df.shape}")
print("预处理完成!")
return df, preprocessor
if __name__ == "__main__":
df, preprocessor = data_cleaning_demo()
5.3 数据可视化分析.py
python
# 105.3 数据可视化分析.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')
class DataVisualizer:
"""数据可视化分析器"""
def __init__(self):
plt.style.use('seaborn-v0_8')
self.df = None
def create_sample_data(self):
"""创建示例销售数据"""
np.random.seed(42)
n_samples = 1000
dates = pd.date_range('2023-01-01', periods=365, freq='D')
products = ['笔记本电脑', '智能手机', '平板电脑', '智能手表', '耳机']
regions = ['华东', '华南', '华北', '西南', '西北', '东北']
categories = ['电子产品', '配件', '服务']
data = {
'date': np.random.choice(dates, n_samples),
'product': np.random.choice(products, n_samples, p=[0.3, 0.25, 0.2, 0.15, 0.1]),
'region': np.random.choice(regions, n_samples),
'category': np.random.choice(categories, n_samples, p=[0.6, 0.3, 0.1]),
'sales': np.random.lognormal(8, 1, n_samples).round(2),
'quantity': np.random.poisson(3, n_samples) + 1,
'profit': np.random.normal(200, 50, n_samples).round(2),
'customer_rating': np.random.uniform(3, 5, n_samples).round(1)
}
self.df = pd.DataFrame(data)
# 添加月份和季度
self.df['month'] = self.df['date'].dt.month
self.df['quarter'] = self.df['date'].dt.quarter
self.df['weekday'] = self.df['date'].dt.day_name()
return self.df
def basic_plots(self):
"""基础图表"""
print("=== 基础数据可视化 ===")
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# 1. 直方图 - 销售额分布
axes[0, 0].hist(self.df['sales'], bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[0, 0].set_title('销售额分布')
axes[0, 0].set_xlabel('销售额')
axes[0, 0].set_ylabel('频数')
# 2. 箱线图 - 各地区利润
self.df.boxplot(column='profit', by='region', ax=axes[0, 1])
axes[0, 1].set_title('各地区利润分布')
axes[0, 1].set_xlabel('地区')
axes[0, 1].set_ylabel('利润')
# 3. 饼图 - 产品销量占比
product_sales = self.df['product'].value_counts()
axes[1, 0].pie(product_sales.values, labels=product_sales.index, autopct='%1.1f%%')
axes[1, 0].set_title('产品销量占比')
# 4. 散点图 - 销售额 vs 利润
axes[1, 1].scatter(self.df['sales'], self.df['profit'], alpha=0.6, color='green')
axes[1, 1].set_title('销售额 vs 利润')
axes[1, 1].set_xlabel('销售额')
axes[1, 1].set_ylabel('利润')
plt.tight_layout()
plt.show()
def advanced_plots(self):
"""高级图表"""
print("\n=== 高级数据可视化 ===")
# 1. 热力图 - 相关性矩阵
plt.figure(figsize=(10, 8))
numeric_cols = ['sales', 'quantity', 'profit', 'customer_rating']
correlation_matrix = self.df[numeric_cols].corr()
sns.heatmap(correlation_matrix, annot=True, cmap='RdYlBu', center=0,
square=True, fmt='.2f', linewidths=0.5)
plt.title('数值变量相关性热力图')
plt.tight_layout()
plt.show()
# 2. 时间序列图
plt.figure(figsize=(12, 6))
# 按月聚合销售数据
monthly_sales = self.df.groupby(self.df['date'].dt.to_period('M'))['sales'].sum()
monthly_sales.index = monthly_sales.index.to_timestamp()
plt.plot(monthly_sales.index, monthly_sales.values, marker='o', linewidth=2)
plt.title('月销售额趋势')
plt.xlabel('月份')
plt.ylabel('销售额')
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# 3. 多变量分析 - 小提琴图
plt.figure(figsize=(12, 6))
sns.violinplot(data=self.df, x='region', y='sales', hue='category')
plt.title('各地区-类别销售额分布')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
def seaborn_advanced_plots(self):
"""Seaborn高级图表"""
print("\n=== Seaborn高级可视化 ===")
# 1. 配对图
print("生成配对图...")
numeric_cols = ['sales', 'quantity', 'profit', 'customer_rating']
sns.pairplot(self.df[numeric_cols], diag_kind='kde', plot_kws={'alpha': 0.6})
plt.suptitle('数值变量配对图', y=1.02)
plt.show()
# 2. 分面网格
print("生成分面网格...")
g = sns.FacetGrid(self.df, col='region', col_wrap=3, height=4)
g.map(sns.histplot, 'sales', kde=True)
g.set_titles('{col_name}地区')
g.fig.suptitle('各地区销售额分布', y=1.02)
plt.show()
# 3. 聚类热力图
print("生成聚类热力图...")
# 创建透视表
pivot_table = self.df.pivot_table(values='sales', index='product', columns='region', aggfunc='mean')
plt.figure(figsize=(10, 8))
sns.clustermap(pivot_table.fillna(0), cmap='viridis', standard_scale=1)
plt.title('产品-地区销售额聚类热力图')
plt.show()
def interactive_plots(self):
"""交互式图表"""
print("\n=== 交互式可视化 ===")
# 1. 交互式散点图
fig = px.scatter(self.df, x='sales', y='profit', color='region',
size='quantity', hover_data=['product', 'customer_rating'],
title='销售额-利润关系图(按地区)')
fig.show()
# 2. 交互式时间序列
daily_sales = self.df.groupby('date')['sales'].sum().reset_index()
fig = px.line(daily_sales, x='date', y='sales',
title='日销售额趋势',
labels={'sales': '销售额', 'date': '日期'})
fig.update_xaxes(rangeslider_visible=True)
fig.show()
# 3. 交互式旭日图
fig = px.sunburst(self.df, path=['category', 'product', 'region'],
values='sales', title='销售数据旭日图')
fig.show()
# 4. 交互式平行坐标图
fig = px.parallel_categories(self.df, dimensions=['region', 'category', 'product'],
color='sales', title='多维度平行坐标图')
fig.show()
def dashboard_creation(self):
"""创建仪表板"""
print("\n=== 数据仪表板 ===")
# 使用plotly创建仪表板
fig = make_subplots(
rows=2, cols=2,
subplot_titles=('月销售额趋势', '产品销量分布', '地区利润对比', '客户评分分布'),
specs=[[{"type": "scatter"}, {"type": "bar"}],
[{"type": "box"}, {"type": "histogram"}]]
)
# 1. 月销售额趋势
monthly_sales = self.df.groupby(self.df['date'].dt.to_period('M'))['sales'].sum()
monthly_sales.index = monthly_sales.index.to_timestamp()
fig.add_trace(
go.Scatter(x=monthly_sales.index, y=monthly_sales.values, mode='lines+markers'),
row=1, col=1
)
# 2. 产品销量分布
product_quantity = self.df.groupby('product')['quantity'].sum().sort_values(ascending=False)
fig.add_trace(
go.Bar(x=product_quantity.index, y=product_quantity.values),
row=1, col=2
)
# 3. 地区利润对比
for region in self.df['region'].unique():
region_data = self.df[self.df['region'] == region]['profit']
fig.add_trace(
go.Box(y=region_data, name=region),
row=2, col=1
)
# 4. 客户评分分布
fig.add_trace(
go.Histogram(x=self.df['customer_rating'], nbinsx=20),
row=2, col=2
)
fig.update_layout(height=800, showlegend=False,
title_text="销售数据分析仪表板")
fig.show()
def statistical_plots(self):
"""统计图表"""
print("\n=== 统计可视化 ===")
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# 1. Q-Q图检验正态性
stats.probplot(self.df['sales'], dist="norm", plot=axes[0, 0])
axes[0, 0].set_title('销售额Q-Q图(正态性检验)')
# 2. 累积分布函数
sorted_sales = np.sort(self.df['sales'])
cdf = np.arange(1, len(sorted_sales) + 1) / len(sorted_sales)
axes[0, 1].plot(sorted_sales, cdf, linewidth=2)
axes[0, 1].set_title('销售额累积分布函数')
axes[0, 1].set_xlabel('销售额')
axes[0, 1].set_ylabel('CDF')
axes[0, 1].grid(True, alpha=0.3)
# 3. 核密度估计
for region in self.df['region'].unique()[:3]: # 只显示前3个地区
region_data = self.df[self.df['region'] == region]['sales']
sns.kdeplot(region_data, label=region, ax=axes[1, 0])
axes[1, 0].set_title('各地区销售额核密度估计')
axes[1, 0].legend()
# 4. 2D密度图
sns.kdeplot(data=self.df, x='sales', y='profit', fill=True, ax=axes[1, 1])
axes[1, 1].set_title('销售额-利润2D密度图')
plt.tight_layout()
plt.show()
def data_visualization_demo():
"""数据可视化演示"""
visualizer = DataVisualizer()
# 创建数据
print("创建示例数据...")
df = visualizer.create_sample_data()
print(f"数据形状: {df.shape}")
# 基础图表
visualizer.basic_plots()
# 高级图表
visualizer.advanced_plots()
# Seaborn图表
visualizer.seaborn_advanced_plots()
# 交互式图表
visualizer.interactive_plots()
# 统计图表
visualizer.statistical_plots()
# 仪表板
visualizer.dashboard_creation()
print("数据可视化演示完成!")
if __name__ == "__main__":
data_visualization_demo()
5.4 时间序列分析.py
python
# 105.4 时间序列分析.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.stattools import adfuller, acf, pacf
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_squared_error, mean_absolute_error
import warnings
warnings.filterwarnings('ignore')
class TimeSeriesAnalyzer:
"""时间序列分析器"""
def __init__(self):
self.df = None
self.train = None
self.test = None
def create_time_series_data(self):
"""创建时间序列数据"""
np.random.seed(42)
# 创建日期范围
dates = pd.date_range('2018-01-01', '2023-12-31', freq='D')
# 基础趋势
trend = np.linspace(100, 500, len(dates))
# 季节性成分(年度)
seasonal = 50 * np.sin(2 * np.pi * np.arange(len(dates)) / 365.25)
# 周期性成分(月度)
cycle = 30 * np.sin(2 * np.pi * np.arange(len(dates)) / 30.5)
# 随机噪声
noise = np.random.normal(0, 20, len(dates))
# 合成时间序列
sales = trend + seasonal + cycle + noise
# 创建DataFrame
self.df = pd.DataFrame({
'date': dates,
'sales': sales,
'temperature': np.random.normal(20, 10, len(dates)), # 外部变量
'promotion': np.random.choice([0, 1], len(dates), p=[0.8, 0.2]) # 促销活动
})
# 设置日期为索引
self.df.set_index('date', inplace=True)
# 添加一些缺失值(真实数据中常见)
missing_indices = np.random.choice(len(self.df), 50, replace=False)
self.df.loc[self.df.index[missing_indices], 'sales'] = np.nan
return self.df
def exploratory_time_series_analysis(self):
"""探索性时间序列分析"""
print("=== 探索性时间序列分析 ===")
# 填充缺失值
self.df['sales'].fillna(method='ffill', inplace=True)
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# 1. 原始时间序列
axes[0, 0].plot(self.df.index, self.df['sales'], linewidth=1)
axes[0, 0].set_title('原始销售数据')
axes[0, 0].set_xlabel('日期')
axes[0, 0].set_ylabel('销售额')
axes[0, 0].grid(True, alpha=0.3)
# 2. 滚动统计
rolling_mean = self.df['sales'].rolling(window=30).mean()
rolling_std = self.df['sales'].rolling(window=30).std()
axes[0, 1].plot(self.df.index, self.df['sales'], label='原始数据', alpha=0.5)
axes[0, 1].plot(self.df.index, rolling_mean, label='30天滚动均值', color='red')
axes[0, 1].plot(self.df.index, rolling_std, label='30天滚动标准差', color='green')
axes[0, 1].set_title('滚动统计量')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# 3. 季节性分析 - 按月份
monthly_sales = self.df['sales'].resample('M').mean()
axes[1, 0].bar(monthly_sales.index, monthly_sales.values, alpha=0.7)
axes[1, 0].set_title('月平均销售额')
axes[1, 0].set_xlabel('月份')
axes[1, 0].set_ylabel('平均销售额')
axes[1, 0].tick_params(axis='x', rotation=45)
# 4. 季节性分析 - 按星期
self.df['weekday'] = self.df.index.day_name()
weekday_sales = self.df.groupby('weekday')['sales'].mean()
weekday_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
weekday_sales = weekday_sales.reindex(weekday_order)
axes[1, 1].bar(weekday_sales.index, weekday_sales.values, alpha=0.7, color='orange')
axes[1, 1].set_title('星期平均销售额')
axes[1, 1].set_xlabel('星期')
axes[1, 1].set_ylabel('平均销售额')
axes[1, 1].tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()
return self.df
def time_series_decomposition(self):
"""时间序列分解"""
print("\n=== 时间序列分解 ===")
# 使用statsmodels进行季节性分解
# 注意:需要足够的数据点,这里使用月度数据
monthly_data = self.df['sales'].resample('M').mean()
# 季节性分解
decomposition = seasonal_decompose(monthly_data, model='additive', period=12)
fig, axes = plt.subplots(4, 1, figsize=(12, 10))
decomposition.observed.plot(ax=axes[0], title='原始序列')
decomposition.trend.plot(ax=axes[1], title='趋势成分')
decomposition.seasonal.plot(ax=axes[2], title='季节性成分')
decomposition.resid.plot(ax=axes[3], title='残差成分')
for ax in axes:
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
return decomposition
def stationarity_test(self):
"""平稳性检验"""
print("\n=== 平稳性检验 ===")
# Augmented Dickey-Fuller检验
result = adfuller(self.df['sales'].dropna())
print(f'ADF统计量: {result[0]:.6f}')
print(f'p值: {result[1]:.6f}')
print(f'使用的滞后数: {result[2]}')
print(f'观测数: {result[3]}')
print('临界值:')
for key, value in result[4].items():
print(f' {key}: {value:.6f}')
if result[1] <= 0.05:
print("结论: 序列是平稳的 (拒绝原假设)")
else:
print("结论: 序列是非平稳的 (不能拒绝原假设)")
return result
def autocorrelation_analysis(self):
"""自相关分析"""
print("\n=== 自相关分析 ===")
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
# 自相关函数 (ACF)
acf_values = acf(self.df['sales'], nlags=40)
axes[0].stem(acf_values)
axes[0].axhline(y=0, linestyle='--', color='gray')
axes[0].axhline(y=-1.96/np.sqrt(len(self.df)), linestyle='--', color='red')
axes[0].axhline(y=1.96/np.sqrt(len(self.df)), linestyle='--', color='red')
axes[0].set_title('自相关函数 (ACF)')
axes[0].set_xlabel('滞后')
axes[0].set_ylabel('ACF')
# 偏自相关函数 (PACF)
pacf_values = pacf(self.df['sales'], nlags=40)
axes[1].stem(pacf_values)
axes[1].axhline(y=0, linestyle='--', color='gray')
axes[1].axhline(y=-1.96/np.sqrt(len(self.df)), linestyle='--', color='red')
axes[1].axhline(y=1.96/np.sqrt(len(self.df)), linestyle='--', color='red')
axes[1].set_title('偏自相关函数 (PACF)')
axes[1].set_xlabel('滞后')
axes[1].set_ylabel('PACF')
plt.tight_layout()
plt.show()
return acf_values, pacf_values
def prepare_train_test(self, test_size=0.2):
"""准备训练集和测试集"""
# 按时间顺序分割
split_index = int(len(self.df) * (1 - test_size))
self.train = self.df['sales'][:split_index]
self.test = self.df['sales'][split_index:]
print(f"训练集大小: {len(self.train)}")
print(f"测试集大小: {len(self.test)}")
print(f"测试集比例: {test_size:.1%}")
return self.train, self.test
def arima_forecasting(self, order=(2,1,2)):
"""ARIMA模型预测"""
print("\n=== ARIMA模型预测 ===")
# 准备数据
self.prepare_train_test()
try:
# 训练ARIMA模型
model = ARIMA(self.train, order=order)
model_fit = model.fit()
print("ARIMA模型摘要:")
print(model_fit.summary())
# 预测
forecast = model_fit.forecast(steps=len(self.test))
forecast_index = self.test.index
# 计算评估指标
mse = mean_squared_error(self.test, forecast)
mae = mean_absolute_error(self.test, forecast)
rmse = np.sqrt(mse)
print(f"\n预测性能指标:")
print(f"均方误差 (MSE): {mse:.2f}")
print(f"均方根误差 (RMSE): {rmse:.2f}")
print(f"平均绝对误差 (MAE): {mae:.2f}")
# 绘制预测结果
plt.figure(figsize=(12, 6))
plt.plot(self.train.index, self.train, label='训练数据')
plt.plot(self.test.index, self.test, label='真实值', color='green')
plt.plot(forecast_index, forecast, label='预测值', color='red', linestyle='--')
plt.title('ARIMA模型预测')
plt.xlabel('日期')
plt.ylabel('销售额')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
return model_fit, forecast
except Exception as e:
print(f"ARIMA模型训练失败: {e}")
return None, None
def advanced_forecasting_methods(self):
"""高级预测方法"""
print("\n=== 高级预测方法 ===")
# 简单方法对比
methods = {
'朴素法': self.train.iloc[-1], # 最后一个观测值
'简单平均': self.train.mean(),
'移动平均': self.train.rolling(30).mean().iloc[-1]
}
print("简单预测方法对比:")
results = []
for method_name, prediction in methods.items():
# 对所有测试点使用相同的预测值
forecast = np.full(len(self.test), prediction)
mse = mean_squared_error(self.test, forecast)
mae = mean_absolute_error(self.test, forecast)
results.append({
'方法': method_name,
'MSE': mse,
'MAE': mae,
'预测值': prediction
})
print(f" {method_name}: MSE={mse:.2f}, MAE={mae:.2f}")
results_df = pd.DataFrame(results)
# 季节性朴素法
seasonal_naive = self.train.iloc[-365:] # 使用去年同期的数据
if len(seasonal_naive) >= len(self.test):
seasonal_forecast = seasonal_naive.values[:len(self.test)]
else:
# 如果不够长,重复使用
repeats = len(self.test) // len(seasonal_naive) + 1
seasonal_forecast = np.tile(seasonal_naive.values, repeats)[:len(self.test)]
seasonal_mse = mean_squared_error(self.test, seasonal_forecast)
seasonal_mae = mean_absolute_error(self.test, seasonal_forecast)
print(f" 季节性朴素法: MSE={seasonal_mse:.2f}, MAE={seasonal_mae:.2f}")
return results_df
def time_series_features(self):
"""时间序列特征工程"""
print("\n=== 时间序列特征工程 ===")
# 创建时间序列特征
self.df['year'] = self.df.index.year
self.df['month'] = self.df.index.month
self.df['quarter'] = self.df.index.quarter
self.df['dayofweek'] = self.df.index.dayofweek
self.df['weekend'] = self.df['dayofweek'].isin([5, 6]).astype(int)
# 滞后特征
for lag in [1, 7, 30]:
self.df[f'sales_lag_{lag}'] = self.df['sales'].shift(lag)
# 滚动统计特征
self.df['sales_rolling_mean_7'] = self.df['sales'].rolling(7).mean()
self.df['sales_rolling_std_7'] = self.df['sales'].rolling(7).std()
self.df['sales_rolling_min_7'] = self.df['sales'].rolling(7).min()
self.df['sales_rolling_max_7'] = self.df['sales'].rolling(7).max()
# 扩展窗口统计
self.df['sales_expanding_mean'] = self.df['sales'].expanding().mean()
# 季节性特征
self.df['dayofyear'] = self.df.index.dayofyear
self.df['weekofyear'] = self.df.index.isocalendar().week
print("创建的时间序列特征:")
new_features = [col for col in self.df.columns if col not in ['sales', 'temperature', 'promotion', 'weekday']]
for feature in new_features:
print(f" {feature}")
return self.df
def time_series_analysis_demo():
"""时间序列分析演示"""
analyzer = TimeSeriesAnalyzer()
# 创建时间序列数据
print("创建时间序列数据...")
df = analyzer.create_time_series_data()
print(f"时间序列数据形状: {df.shape}")
print(f"时间范围: {df.index.min()} 到 {df.index.max()}")
# 探索性分析
analyzer.exploratory_time_series_analysis()
# 时间序列分解
decomposition = analyzer.time_series_decomposition()
# 平稳性检验
stationarity_result = analyzer.stationarity_test()
# 自相关分析
acf_values, pacf_values = analyzer.autocorrelation_analysis()
# 特征工程
df_with_features = analyzer.time_series_features()
# ARIMA预测
arima_model, forecast = analyzer.arima_forecasting(order=(1,1,1))
# 高级预测方法
simple_methods = analyzer.advanced_forecasting_methods()
print("\n时间序列分析完成!")
return analyzer
if __name__ == "__main__":
analyzer = time_series_analysis_demo()
5.5 高级数据分析技术.py
python
# 105.5 高级数据分析技术.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from scipy import stats
from scipy.optimize import curve_fit
import warnings
warnings.filterwarnings('ignore')
class AdvancedDataAnalyzer:
"""高级数据分析技术"""
def __init__(self):
self.df = None
self.scaler = StandardScaler()
def create_advanced_dataset(self):
"""创建高级分析数据集"""
np.random.seed(42)
n_samples = 1000
# 创建多维度数据集
data = {
# 客户基本信息
'age': np.random.normal(35, 10, n_samples).round(),
'income': np.random.lognormal(10, 0.8, n_samples).round(2),
'education_years': np.random.randint(8, 20, n_samples),
# 消费行为
'monthly_spending': np.random.gamma(2, 500, n_samples).round(2),
'online_shopping_freq': np.random.poisson(8, n_samples),
'store_visits': np.random.poisson(4, n_samples),
# 产品偏好
'tech_interest': np.random.beta(2, 5, n_samples).round(3),
'fashion_interest': np.random.beta(3, 3, n_samples).round(3),
'food_interest': np.random.beta(5, 2, n_samples).round(3),
# 客户价值
'customer_lifetime_value': np.random.lognormal(9, 1, n_samples).round(2),
'retention_probability': np.random.beta(8, 2, n_samples).round(3)
}
self.df = pd.DataFrame(data)
# 添加一些异常值
self.df.loc[0, 'income'] = 500000 # 异常高收入
self.df.loc[1, 'monthly_spending'] = 10000 # 异常高消费
self.df.loc[2, 'age'] = 150 # 不可能年龄
return self.df
def clustering_analysis(self):
"""聚类分析"""
print("=== 聚类分析 ===")
# 选择用于聚类的特征
features = ['age', 'income', 'monthly_spending', 'tech_interest', 'fashion_interest', 'food_interest']
X = self.df[features].copy()
# 处理异常值
for col in features:
Q1 = X[col].quantile(0.25)
Q3 = X[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
X[col] = X[col].clip(lower_bound, upper_bound)
# 标准化
X_scaled = self.scaler.fit_transform(X)
# 1. K-means聚类
print("1. K-means聚类:")
# 使用肘部法则确定最佳K值
wcss = []
k_range = range(1, 11)
for k in k_range:
kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
kmeans.fit(X_scaled)
wcss.append(kmeans.inertia_)
# 绘制肘部图
plt.figure(figsize=(10, 6))
plt.plot(k_range, wcss, 'bo-')
plt.xlabel('聚类数量 (K)')
plt.ylabel('WCSS (Within-Cluster Sum of Squares)')
plt.title('K-means肘部法则')
plt.grid(True, alpha=0.3)
plt.show()
# 选择K=4进行聚类
optimal_k = 4
kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
clusters = kmeans.fit_predict(X_scaled)
self.df['kmeans_cluster'] = clusters
# 聚类分析
cluster_summary = self.df.groupby('kmeans_cluster')[features].mean()
print("\n各聚类特征均值:")
print(cluster_summary)
# 2. DBSCAN聚类
print("\n2. DBSCAN聚类:")
dbscan = DBSCAN(eps=0.5, min_samples=5)
dbscan_clusters = dbscan.fit_predict(X_scaled)
self.df['dbscan_cluster'] = dbscan_clusters
n_dbscan_clusters = len(set(dbscan_clusters)) - (1 if -1 in dbscan_clusters else 0)
n_noise = list(dbscan_clusters).count(-1)
print(f"DBSCAN发现 {n_dbscan_clusters} 个聚类")
print(f"噪声点数量: {n_noise}")
# 可视化聚类结果
self.visualize_clusters(X_scaled, clusters, dbscan_clusters)
return clusters, dbscan_clusters
def visualize_clusters(self, X_scaled, kmeans_clusters, dbscan_clusters):
"""可视化聚类结果"""
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# 使用PCA降维可视化
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
# K-means聚类可视化
scatter1 = axes[0, 0].scatter(X_pca[:, 0], X_pca[:, 1], c=kmeans_clusters, cmap='viridis')
axes[0, 0].set_title('K-means聚类 (PCA降维)')
axes[0, 0].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%}方差)')
axes[0, 0].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%}方差)')
plt.colorbar(scatter1, ax=axes[0, 0])
# DBSCAN聚类可视化
scatter2 = axes[0, 1].scatter(X_pca[:, 0], X_pca[:, 1], c=dbscan_clusters, cmap='Set1')
axes[0, 1].set_title('DBSCAN聚类 (PCA降维)')
axes[0, 1].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%}方差)')
axes[0, 1].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%}方差)')
plt.colorbar(scatter2, ax=axes[0, 1])
# 使用t-SNE降维
tsne = TSNE(n_components=2, random_state=42)
X_tsne = tsne.fit_transform(X_scaled)
scatter3 = axes[1, 0].scatter(X_tsne[:, 0], X_tsne[:, 1], c=kmeans_clusters, cmap='viridis')
axes[1, 0].set_title('K-means聚类 (t-SNE降维)')
axes[1, 0].set_xlabel('t-SNE 1')
axes[1, 0].set_ylabel('t-SNE 2')
plt.colorbar(scatter3, ax=axes[1, 0])
scatter4 = axes[1, 1].scatter(X_tsne[:, 0], X_tsne[:, 1], c=dbscan_clusters, cmap='Set1')
axes[1, 1].set_title('DBSCAN聚类 (t-SNE降维)')
axes[1, 1].set_xlabel('t-SNE 1')
axes[1, 1].set_ylabel('t-SNE 2')
plt.colorbar(scatter4, ax=axes[1, 1])
plt.tight_layout()
plt.show()
def dimensionality_reduction(self):
"""降维分析"""
print("\n=== 降维分析 ===")
features = ['age', 'income', 'monthly_spending', 'tech_interest',
'fashion_interest', 'food_interest', 'online_shopping_freq', 'store_visits']
X = self.df[features].copy()
# 处理异常值
for col in features:
Q1 = X[col].quantile(0.25)
Q3 = X[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
X[col] = X[col].clip(lower_bound, upper_bound)
X_scaled = self.scaler.fit_transform(X)
# 1. 主成分分析 (PCA)
print("1. 主成分分析 (PCA):")
pca = PCA()
X_pca = pca.fit_transform(X_scaled)
# 方差解释率
explained_variance = pca.explained_variance_ratio_
cumulative_variance = explained_variance.cumsum()
print("各主成分方差解释率:")
for i, (var, cum_var) in enumerate(zip(explained_variance, cumulative_variance)):
print(f" 主成分 {i+1}: {var:.3f} ({cum_var:.3f} 累积)")
# 绘制碎石图
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(explained_variance) + 1), explained_variance, 'bo-', label='单个方差')
plt.plot(range(1, len(cumulative_variance) + 1), cumulative_variance, 'ro-', label='累积方差')
plt.axhline(y=0.95, color='g', linestyle='--', label='95%方差线')
plt.xlabel('主成分数量')
plt.ylabel('解释方差比例')
plt.title('PCA碎石图')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# 2. t-SNE可视化
print("\n2. t-SNE降维可视化:")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
X_tsne = tsne.fit_transform(X_scaled)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1],
c=self.df['customer_lifetime_value'],
cmap='viridis', alpha=0.7)
plt.colorbar(scatter, label='客户终身价值')
plt.title('t-SNE降维可视化 (按客户终身价值着色)')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.show()
return pca, tsne
def statistical_analysis(self):
"""统计分析"""
print("\n=== 统计分析 ===")
# 1. 相关分析
print("1. 相关分析:")
numeric_cols = self.df.select_dtypes(include=[np.number]).columns
correlation_matrix = self.df[numeric_cols].corr()
plt.figure(figsize=(12, 10))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
square=True, fmt='.2f', linewidths=0.5)
plt.title('变量相关性热力图')
plt.tight_layout()
plt.show()
# 2. 假设检验
print("\n2. 假设检验:")
# 检验收入是否服从正态分布
income_data = self.df['income'].dropna()
stat, p_value = stats.normaltest(income_data)
print(f"收入正态性检验:")
print(f" 统计量: {stat:.4f}, p值: {p_value:.4f}")
if p_value > 0.05:
print(" 结论: 收入服从正态分布 (不能拒绝原假设)")
else:
print(" 结论: 收入不服从正态分布 (拒绝原假设)")
# 比较高低消费群体的收入差异
spending_median = self.df['monthly_spending'].median()
high_spenders = self.df[self.df['monthly_spending'] > spending_median]['income']
low_spenders = self.df[self.df['monthly_spending'] <= spending_median]['income']
t_stat, t_p_value = stats.ttest_ind(high_spenders, low_spenders, equal_var=False)
print(f"\n高低消费群体收入差异检验:")
print(f" t统计量: {t_stat:.4f}, p值: {t_p_value:.4f}")
if t_p_value < 0.05:
print(" 结论: 高低消费群体的收入有显著差异")
else:
print(" 结论: 高低消费群体的收入无显著差异")
# 3. 回归分析
print("\n3. 回归分析:")
# 简单的线性回归
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
X_reg = self.df[['income', 'age', 'education_years']]
y_reg = self.df['customer_lifetime_value']
# 处理异常值
for col in X_reg.columns:
Q1 = X_reg[col].quantile(0.25)
Q3 = X_reg[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
X_reg[col] = X_reg[col].clip(lower_bound, upper_bound)
# 标准化
X_reg_scaled = self.scaler.fit_transform(X_reg)
model = LinearRegression()
model.fit(X_reg_scaled, y_reg)
y_pred = model.predict(X_reg_scaled)
r2 = r2_score(y_reg, y_pred)
print(f"线性回归 R²分数: {r2:.4f}")
print("回归系数:")
for feature, coef in zip(X_reg.columns, model.coef_):
print(f" {feature}: {coef:.4f}")
return correlation_matrix, (t_stat, t_p_value), model
def advanced_analytics(self):
"""高级分析技术"""
print("\n=== 高级分析技术 ===")
# 1. 客户分群分析
print("1. 客户分群分析:")
if 'kmeans_cluster' in self.df.columns:
cluster_profiles = self.df.groupby('kmeans_cluster').agg({
'age': 'mean',
'income': 'mean',
'monthly_spending': 'mean',
'customer_lifetime_value': 'mean',
'retention_probability': 'mean'
}).round(2)
print("各客户群特征:")
print(cluster_profiles)
# 为每个群命名
cluster_names = {
0: '年轻价值型',
1: '高收入高消费型',
2: '普通消费型',
3: '潜在价值型'
}
self.df['cluster_name'] = self.df['kmeans_cluster'].map(cluster_names)
# 2. 趋势分析
print("\n2. 趋势分析:")
# 创建模拟时间序列
dates = pd.date_range('2020-01-01', periods=len(self.df), freq='D')
self.df['date'] = dates
self.df.set_index('date', inplace=True)
# 模拟销售额趋势
trend = np.linspace(1000, 5000, len(self.df))
seasonal = 500 * np.sin(2 * np.pi * np.arange(len(self.df)) / 365)
noise = np.random.normal(0, 200, len(self.df))
self.df['daily_sales'] = trend + seasonal + noise
# 移动平均分析
self.df['sales_ma_7'] = self.df['daily_sales'].rolling(7).mean()
self.df['sales_ma_30'] = self.df['daily_sales'].rolling(30).mean()
plt.figure(figsize=(12, 6))
plt.plot(self.df.index, self.df['daily_sales'], label='日销售额', alpha=0.3)
plt.plot(self.df.index, self.df['sales_ma_7'], label='7日移动平均', linewidth=2)
plt.plot(self.df.index, self.df['sales_ma_30'], label='30日移动平均', linewidth=2)
plt.title('销售额趋势分析')
plt.xlabel('日期')
plt.ylabel('销售额')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
# 3. 异常检测
print("\n3. 异常检测:")
# 使用Z-score检测异常
from scipy import stats
z_scores = stats.zscore(self.df[['income', 'monthly_spending', 'customer_lifetime_value']])
abs_z_scores = np.abs(z_scores)
# 标记异常值 (Z-score > 3)
outliers = (abs_z_scores > 3).any(axis=1)
n_outliers = outliers.sum()
print(f"检测到 {n_outliers} 个异常值")
print("异常值统计:")
outlier_stats = self.df[outliers].describe()
print(outlier_stats)
return cluster_profiles, outliers
def generate_insights_report(self):
"""生成分析洞察报告"""
print("\n=== 数据分析洞察报告 ===")
insights = []
# 基本统计洞察
avg_income = self.df['income'].mean()
avg_spending = self.df['monthly_spending'].mean()
avg_clv = self.df['customer_lifetime_value'].mean()
insights.append(f"• 客户平均收入: ¥{avg_income:,.2f}")
insights.append(f"• 客户月均消费: ¥{avg_spending:,.2f}")
insights.append(f"• 客户平均终身价值: ¥{avg_clv:,.2f}")
# 相关性洞察
income_spending_corr = self.df['income'].corr(self.df['monthly_spending'])
insights.append(f"• 收入与消费相关性: {income_spending_corr:.3f}")
if income_spending_corr > 0.5:
insights.append(" → 收入与消费呈强正相关,高收入客户消费能力更强")
elif income_spending_corr > 0.2:
insights.append(" → 收入与消费呈中等正相关")
else:
insights.append(" → 收入与消费相关性较弱")
# 聚类洞察
if 'kmeans_cluster' in self.df.columns:
cluster_sizes = self.df['kmeans_cluster'].value_counts()
insights.append(f"• 客户分群结果: {len(cluster_sizes)}个主要群体")
for cluster_id, size in cluster_sizes.items():
cluster_avg_income = self.df[self.df['kmeans_cluster'] == cluster_id]['income'].mean()
cluster_avg_clv = self.df[self.df['kmeans_cluster'] == cluster_id]['customer_lifetime_value'].mean()
insights.append(f" 群组{cluster_id}: {size}人, 平均收入¥{cluster_avg_income:,.0f}, 平均价值¥{cluster_avg_clv:,.0f}")
# 业务建议
insights.append("\n业务建议:")
insights.append("• 针对高价值客户群体,提供个性化服务和产品推荐")
insights.append("• 加强中低收入客户的忠诚度计划,提升其终身价值")
insights.append("• 利用聚类分析结果优化营销策略和资源配置")
# 打印洞察报告
for insight in insights:
print(insight)
return insights
def advanced_data_analysis_demo():
"""高级数据分析演示"""
analyzer = AdvancedDataAnalyzer()
# 创建数据集
print("创建高级分析数据集...")
df = analyzer.create_advanced_dataset()
print(f"数据集形状: {df.shape}")
print("\n数据概览:")
print(df.describe())
# 聚类分析
kmeans_clusters, dbscan_clusters = analyzer.clustering_analysis()
# 降维分析
pca, tsne = analyzer.dimensionality_reduction()
# 统计分析
correlation_matrix, t_test_results, regression_model = analyzer.statistical_analysis()
# 高级分析
cluster_profiles, outliers = analyzer.advanced_analytics()
# 生成洞察报告
insights = analyzer.generate_insights_report()
print("\n高级数据分析完成!")
return analyzer
if __name__ == "__main__":
analyzer = advanced_data_analysis_demo()
9. 设计模式 (5个脚本)
9.1 创建型设计模式.py
python
# 109.1 创建型设计模式.py
from abc import ABC, abstractmethod
from typing import Any, Dict
import copy
class SingletonMeta(type):
"""单例模式的元类实现"""
_instances: Dict[type, object] = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
instance = super().__call__(*args, **kwargs)
cls._instances[cls] = instance
return cls._instances[cls]
class DatabaseConnection(metaclass=SingletonMeta):
"""数据库连接单例"""
def __init__(self):
self.connection_string = "localhost:5432/mydb"
self.is_connected = False
print("创建数据库连接实例")
def connect(self):
if not self.is_connected:
self.is_connected = True
print("连接到数据库")
def disconnect(self):
if self.is_connected:
self.is_connected = False
print("断开数据库连接")
def execute_query(self, query):
if self.is_connected:
print(f"执行查询: {query}")
return f"结果: {query}"
else:
raise Exception("未连接到数据库")
class Product(ABC):
"""产品抽象类"""
@abstractmethod
def operation(self) -> str:
pass
class ConcreteProductA(Product):
"""具体产品A"""
def operation(self) -> str:
return "具体产品A的操作"
class ConcreteProductB(Product):
"""具体产品B"""
def operation(self) -> str:
return "具体产品B的操作"
class Creator(ABC):
"""创建者抽象类"""
@abstractmethod
def factory_method(self) -> Product:
pass
def some_operation(self) -> str:
product = self.factory_method()
return f"创建者: {product.operation()}"
class ConcreteCreatorA(Creator):
"""具体创建者A"""
def factory_method(self) -> Product:
return ConcreteProductA()
class ConcreteCreatorB(Creator):
"""具体创建者B"""
def factory_method(self) -> Product:
return ConcreteProductB()
class Prototype:
"""原型模式"""
def __init__(self):
self._components = []
self._special_value = None
def add_component(self, component):
self._components.append(component)
def set_special_value(self, value):
self._special_value = value
def clone(self):
"""深度拷贝克隆方法"""
return copy.deepcopy(self)
def __str__(self):
return f"Prototype(components={self._components}, special_value={self._special_value})"
class Builder:
"""建造者模式"""
def __init__(self):
self.reset()
def reset(self):
self._product = ProductComplex()
@property
def product(self):
product = self._product
self.reset()
return product
def produce_part_a(self):
self._product.add("部件A")
def produce_part_b(self):
self._product.add("部件B")
def produce_part_c(self):
self._product.add("部件C")
class ProductComplex:
"""复杂产品"""
def __init__(self):
self.parts = []
def add(self, part):
self.parts.append(part)
def list_parts(self):
return f"产品部件: {', '.join(self.parts)}"
class Director:
"""指挥者"""
def __init__(self):
self._builder = None
@property
def builder(self):
return self._builder
@builder.setter
def builder(self, builder):
self._builder = builder
def build_minimal_viable_product(self):
self.builder.produce_part_a()
def build_full_featured_product(self):
self.builder.produce_part_a()
self.builder.produce_part_b()
self.builder.produce_part_c()
def creational_patterns_demo():
"""创建型设计模式演示"""
print("=== 创建型设计模式 ===")
# 1. 单例模式
print("\n1. 单例模式演示:")
db1 = DatabaseConnection()
db2 = DatabaseConnection()
print(f"db1 is db2: {db1 is db2}")
print(f"db1 id: {id(db1)}")
print(f"db2 id: {id(db2)}")
db1.connect()
db2.execute_query("SELECT * FROM users")
# 2. 工厂方法模式
print("\n2. 工厂方法模式演示:")
creators = [ConcreteCreatorA(), ConcreteCreatorB()]
for creator in creators:
print(creator.some_operation())
# 3. 原型模式
print("\n3. 原型模式演示:")
original = Prototype()
original.add_component("基础组件")
original.set_special_value("特殊值")
clone1 = original.clone()
clone1.add_component("克隆添加的组件")
clone2 = original.clone()
clone2.set_special_value("修改的特殊值")
print(f"原始对象: {original}")
print(f"克隆1: {clone1}")
print(f"克隆2: {clone2}")
# 4. 建造者模式
print("\n4. 建造者模式演示:")
director = Director()
builder = Builder()
director.builder = builder
print("构建最小可行产品:")
director.build_minimal_viable_product()
print(builder.product.list_parts())
print("\n构建完整产品:")
director.build_full_featured_product()
print(builder.product.list_parts())
print("\n自定义构建:")
builder.produce_part_a()
builder.produce_part_c()
print(builder.product.list_parts())
if __name__ == "__main__":
creational_patterns_demo()
9.2 结构型设计模式.py
python
# 109.2 结构型设计模式.py
from abc import ABC, abstractmethod
from typing import List
class Target:
"""目标接口"""
def request(self) -> str:
return "目标: 默认行为"
class Adaptee:
"""需要适配的类"""
def specific_request(self) -> str:
return ".eetpadA eht fo roivaheb laicepS"
class Adapter(Target):
"""适配器"""
def __init__(self, adaptee: Adaptee):
self.adaptee = adaptee
def request(self) -> str:
return f"适配器: (转换后) {self.adaptee.specific_request()[::-1]}"
class Component(ABC):
"""组件接口"""
@abstractmethod
def operation(self) -> str:
pass
class ConcreteComponent(Component):
"""具体组件"""
def operation(self) -> str:
return "具体组件"
class Decorator(Component):
"""装饰器基类"""
def __init__(self, component: Component):
self._component = component
@abstractmethod
def operation(self) -> str:
pass
class ConcreteDecoratorA(Decorator):
"""具体装饰器A"""
def operation(self) -> str:
return f"具体装饰器A({self._component.operation()})"
class ConcreteDecoratorB(Decorator):
"""具体装饰器B"""
def operation(self) -> str:
return f"具体装饰器B({self._component.operation()})"
class Subject(ABC):
"""主题接口"""
@abstractmethod
def request(self) -> None:
pass
class RealSubject(Subject):
"""真实主题"""
def request(self) -> None:
print("真实主题: 处理请求")
class Proxy(Subject):
"""代理"""
def __init__(self, real_subject: RealSubject):
self._real_subject = real_subject
def request(self) -> None:
if self.check_access():
self._real_subject.request()
self.log_access()
def check_access(self) -> bool:
print("代理: 检查访问权限")
return True
def log_access(self) -> None:
print("代理: 记录请求时间")
class Flyweight:
"""享元"""
def __init__(self, shared_state: str):
self._shared_state = shared_state
def operation(self, unique_state: str) -> None:
print(f"享元: 共享({self._shared_state}) 和 唯一({unique_state})")
class FlyweightFactory:
"""享元工厂"""
_flyweights: Dict[str, Flyweight] = {}
def __init__(self, initial_flyweights: List[List[str]]):
for state in initial_flyweights:
self._flyweights[self.get_key(state)] = Flyweight(state)
def get_key(self, state: List) -> str:
return "_".join(sorted(state))
def get_flyweight(self, shared_state: List) -> Flyweight:
key = self.get_key(shared_state)
if not self._flyweights.get(key):
print("享元工厂: 创建新享元")
self._flyweights[key] = Flyweight(shared_state)
else:
print("享元工厂: 重用现有享元")
return self._flyweights[key]
def list_flyweights(self) -> None:
count = len(self._flyweights)
print(f"享元工厂: 我有 {count} 个享元:")
print("\n".join(self._flyweights.keys()))
class Composite(Component):
"""组合"""
def __init__(self) -> None:
self._children: List[Component] = []
def add(self, component: Component) -> None:
self._children.append(component)
def remove(self, component: Component) -> None:
self._children.remove(component)
def operation(self) -> str:
results = []
for child in self._children:
results.append(child.operation())
return f"分支({'+'.join(results)})"
class Bridge:
"""桥接模式实现"""
class Implementation(ABC):
@abstractmethod
def operation_implementation(self) -> str:
pass
class ConcreteImplementationA(Implementation):
def operation_implementation(self) -> str:
return "具体实现A: 结果"
class ConcreteImplementationB(Implementation):
def operation_implementation(self) -> str:
return "具体实现B: 结果"
class Abstraction:
def __init__(self, implementation: Implementation):
self._implementation = implementation
def operation(self) -> str:
return f"抽象: 基础操作与:\n{self._implementation.operation_implementation()}"
class ExtendedAbstraction(Abstraction):
def operation(self) -> str:
return f"扩展抽象: 扩展操作与:\n{self._implementation.operation_implementation()}"
def structural_patterns_demo():
"""结构型设计模式演示"""
print("=== 结构型设计模式 ===")
# 1. 适配器模式
print("\n1. 适配器模式演示:")
adaptee = Adaptee()
adapter = Adapter(adaptee)
print("客户端: 我可以正常使用目标接口:")
target = Target()
print(target.request())
print("\n客户端: 但我也可以使用适配器:")
print(adapter.request())
# 2. 装饰器模式
print("\n2. 装饰器模式演示:")
simple = ConcreteComponent()
print(f"结果: {simple.operation()}")
decorator1 = ConcreteDecoratorA(simple)
decorator2 = ConcreteDecoratorB(decorator1)
print(f"结果: {decorator2.operation()}")
# 3. 代理模式
print("\n3. 代理模式演示:")
real_subject = RealSubject()
proxy = Proxy(real_subject)
proxy.request()
# 4. 享元模式
print("\n4. 享元模式演示:")
factory = FlyweightFactory([
["Chevrolet", "Camaro2018", "pink"],
["Mercedes Benz", "C300", "black"],
["Mercedes Benz", "C500", "red"],
["BMW", "M5", "red"],
["BMW", "X6", "white"],
])
factory.list_flyweights()
def add_car_to_police_database(factory, plates, owner, brand, model, color):
print("\n客户端: 添加汽车到数据库")
flyweight = factory.get_flyweight([brand, model, color])
flyweight.operation([plates, owner])
add_car_to_police_database(factory, "CL234IR", "James Doe", "BMW", "M5", "red")
add_car_to_police_database(factory, "CL234IR", "James Doe", "BMW", "X1", "red")
print("\n")
factory.list_flyweights()
# 5. 组合模式
print("\n5. 组合模式演示:")
tree = Composite()
branch1 = Composite()
branch1.add(ConcreteComponent())
branch1.add(ConcreteComponent())
branch2 = Composite()
branch2.add(ConcreteComponent())
tree.add(branch1)
tree.add(branch2)
print(f"结果: {tree.operation()}")
# 6. 桥接模式
print("\n6. 桥接模式演示:")
implementation = Bridge.ConcreteImplementationA()
abstraction = Bridge.Abstraction(implementation)
print(abstraction.operation())
implementation = Bridge.ConcreteImplementationB()
abstraction = Bridge.ExtendedAbstraction(implementation)
print(abstraction.operation())
if __name__ == "__main__":
structural_patterns_demo()
9.3 行为型设计模式.py
python
# 109.3 行为型设计模式.py
from abc import ABC, abstractmethod
from typing import List, Dict
from enum import Enum
class Observer(ABC):
"""观察者接口"""
@abstractmethod
def update(self, subject) -> None:
pass
class Subject:
"""主题"""
def __init__(self):
self._observers: List[Observer] = []
self._state = None
def attach(self, observer: Observer) -> None:
print("主题: 附加了一个观察者")
self._observers.append(observer)
def detach(self, observer: Observer) -> None:
self._observers.remove(observer)
def notify(self) -> None:
print("主题: 通知观察者...")
for observer in self._observers:
observer.update(self)
@property
def state(self) -> int:
return self._state
@state.setter
def state(self, state: int) -> None:
self._state = state
print(f"主题: 状态改变为: {state}")
self.notify()
class ConcreteObserverA(Observer):
"""具体观察者A"""
def update(self, subject: Subject) -> None:
if subject.state < 3:
print("具体观察者A: 对状态变化做出反应")
class ConcreteObserverB(Observer):
"""具体观察者B"""
def update(self, subject: Subject) -> None:
if subject.state == 0 or subject.state >= 2:
print("具体观察者B: 对状态变化做出反应")
class Strategy(ABC):
"""策略接口"""
@abstractmethod
def do_algorithm(self, data: List):
pass
class ConcreteStrategyA(Strategy):
"""具体策略A"""
def do_algorithm(self, data: List):
return sorted(data)
class ConcreteStrategyB(Strategy):
"""具体策略B"""
def do_algorithm(self, data: List):
return reversed(sorted(data))
class Context:
"""上下文"""
def __init__(self, strategy: Strategy):
self._strategy = strategy
@property
def strategy(self) -> Strategy:
return self._strategy
@strategy.setter
def strategy(self, strategy: Strategy):
self._strategy = strategy
def do_some_business_logic(self):
print("上下文: 使用策略处理数据")
result = self._strategy.do_algorithm(["a", "b", "c", "d", "e"])
print(",".join(result))
class Handler(ABC):
"""处理器接口"""
@abstractmethod
def set_next(self, handler):
pass
@abstractmethod
def handle(self, request):
pass
class AbstractHandler(Handler):
"""抽象处理器"""
_next_handler = None
def set_next(self, handler):
self._next_handler = handler
return handler
def handle(self, request):
if self._next_handler:
return self._next_handler.handle(request)
return None
class MonkeyHandler(AbstractHandler):
"""猴子处理器"""
def handle(self, request):
if request == "Banana":
return f"猴子: 我会吃 {request}"
else:
return super().handle(request)
class SquirrelHandler(AbstractHandler):
"""松鼠处理器"""
def handle(self, request):
if request == "Nut":
return f"松鼠: 我会吃 {request}"
else:
return super().handle(request)
class DogHandler(AbstractHandler):
"""狗处理器"""
def handle(self, request):
if request == "MeatBall":
return f"狗: 我会吃 {request}"
else:
return super().handle(request)
class Command(ABC):
"""命令接口"""
@abstractmethod
def execute(self):
pass
class SimpleCommand(Command):
"""简单命令"""
def __init__(self, payload):
self._payload = payload
def execute(self):
print(f"简单命令: 打印 ({self._payload})")
class ComplexCommand(Command):
"""复杂命令"""
def __init__(self, receiver, a, b):
self._receiver = receiver
self._a = a
self._b = b
def execute(self):
print("复杂命令: 复杂操作应由接收者对象完成")
self._receiver.do_something(self._a)
self._receiver.do_something_else(self._b)
class Receiver:
"""接收者"""
def do_something(self, a):
print(f"接收者: 正在处理 {a}")
def do_something_else(self, b):
print(f"接收者: 也在处理 {b}")
class Invoker:
"""调用者"""
_on_start = None
_on_finish = None
def set_on_start(self, command):
self._on_start = command
def set_on_finish(self, command):
self._on_finish = command
def do_something_important(self):
print("调用者: 有人要我做事吗?")
if isinstance(self._on_start, Command):
self._on_start.execute()
print("调用者: ...做一些重要的事情...")
if isinstance(self._on_finish, Command):
self._on_finish.execute()
class State(ABC):
"""状态接口"""
@abstractmethod
def handle(self, context):
pass
class ConcreteStateA(State):
"""具体状态A"""
def handle(self, context):
print("具体状态A: 处理请求")
context.state = ConcreteStateB()
class ConcreteStateB(State):
"""具体状态B"""
def handle(self, context):
print("具体状态B: 处理请求")
context.state = ConcreteStateA()
class ContextState:
"""状态上下文"""
def __init__(self, state):
self._state = state
@property
def state(self):
return self._state
@state.setter
def state(self, state):
print(f"上下文: 状态改变为 {type(state).__name__}")
self._state = state
def request(self):
self._state.handle(self)
class Iterator:
"""迭代器接口"""
@abstractmethod
def next(self):
pass
@abstractmethod
def has_next(self) -> bool:
pass
class ConcreteIterator(Iterator):
"""具体迭代器"""
def __init__(self, collection):
self._collection = collection
self._position = 0
def next(self):
if self.has_next():
item = self._collection[self._position]
self._position += 1
return item
else:
raise StopIteration()
def has_next(self) -> bool:
return self._position < len(self._collection)
class Collection:
"""集合"""
def __init__(self):
self._items = []
def add_item(self, item):
self._items.append(item)
def create_iterator(self):
return ConcreteIterator(self._items)
def behavioral_patterns_demo():
"""行为型设计模式演示"""
print("=== 行为型设计模式 ===")
# 1. 观察者模式
print("\n1. 观察者模式演示:")
subject = Subject()
observer_a = ConcreteObserverA()
subject.attach(observer_a)
observer_b = ConcreteObserverB()
subject.attach(observer_b)
subject.state = 1
subject.state = 2
subject.state = 3
# 2. 策略模式
print("\n2. 策略模式演示:")
context = Context(ConcreteStrategyA())
print("客户端: 策略设置为正常排序")
context.do_some_business_logic()
print("客户端: 策略设置为反向排序")
context.strategy = ConcreteStrategyB()
context.do_some_business_logic()
# 3. 责任链模式
print("\n3. 责任链模式演示:")
monkey = MonkeyHandler()
squirrel = SquirrelHandler()
dog = DogHandler()
monkey.set_next(squirrel).set_next(dog)
foods = ["Nut", "Banana", "Cup of coffee"]
for food in foods:
print(f"\n客户端: 谁想要 {food}?")
result = monkey.handle(food)
if result:
print(f" {result}")
else:
print(f" {food} 没有被吃掉")
# 4. 命令模式
print("\n4. 命令模式演示:")
invoker = Invoker()
invoker.set_on_start(SimpleCommand("Say Hi!"))
receiver = Receiver()
invoker.set_on_finish(ComplexCommand(receiver, "发送邮件", "保存报告"))
invoker.do_something_important()
# 5. 状态模式
print("\n5. 状态模式演示:")
context_state = ContextState(ConcreteStateA())
context_state.request()
context_state.request()
context_state.request()
context_state.request()
# 6. 迭代器模式
print("\n6. 迭代器模式演示:")
collection = Collection()
collection.add_item("项目1")
collection.add_item("项目2")
collection.add_item("项目3")
iterator = collection.create_iterator()
while iterator.has_next():
print(iterator.next())
if __name__ == "__main__":
behavioral_patterns_demo()
9.4 Pythonic设计模式.py
python
# 109.4 Pythonic设计模式.py
from functools import wraps
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Any, List
from enum import Enum
import time
def singleton(cls):
"""使用装饰器实现单例模式"""
instances = {}
@wraps(cls)
def wrapper(*args, **kwargs):
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return wrapper
@singleton
class Configuration:
"""配置管理器(单例)"""
def __init__(self):
self.settings = {
'debug': True,
'database_url': 'sqlite:///app.db',
'max_connections': 100
}
def get(self, key):
return self.settings.get(key)
def set(self, key, value):
self.settings[key] = value
class LogLevel(Enum):
"""日志级别枚举"""
DEBUG = 1
INFO = 2
WARNING = 3
ERROR = 4
class Logger:
"""日志器(使用Pythonic方式实现观察者模式)"""
def __init__(self):
self._handlers = []
self._min_level = LogLevel.DEBUG
def add_handler(self, handler: Callable):
"""添加处理器"""
self._handlers.append(handler)
def set_level(self, level: LogLevel):
"""设置日志级别"""
self._min_level = level
def log(self, level: LogLevel, message: str):
"""记录日志"""
if level.value >= self._min_level.value:
for handler in self._handlers:
handler(level, message)
def debug(self, message):
self.log(LogLevel.DEBUG, message)
def info(self, message):
self.log(LogLevel.INFO, message)
def warning(self, message):
self.log(LogLevel.WARNING, message)
def error(self, message):
self.log(LogLevel.ERROR, message)
def console_handler(level: LogLevel, message: str):
"""控制台处理器"""
print(f"[{level.name}] {time.strftime('%Y-%m-%d %H:%M:%S')} - {message}")
def file_handler(level: LogLevel, message: str):
"""文件处理器(模拟)"""
# 在实际应用中会写入文件
pass
class StrategyRegistry:
"""策略注册表(使用Python字典实现策略模式)"""
def __init__(self):
self._strategies = {}
def register(self, name: str):
"""注册策略的装饰器"""
def decorator(func):
self._strategies[name] = func
return func
return decorator
def get_strategy(self, name: str):
"""获取策略"""
return self._strategies.get(name)
def execute(self, name: str, *args, **kwargs):
"""执行策略"""
strategy = self.get_strategy(name)
if strategy:
return strategy(*args, **kwargs)
else:
raise ValueError(f"未知策略: {name}")
# 创建策略注册表实例
registry = StrategyRegistry()
@registry.register('add')
def add_strategy(a, b):
return a + b
@registry.register('multiply')
def multiply_strategy(a, b):
return a * b
@registry.register('concat')
def concat_strategy(a, b):
return str(a) + str(b)
class Pipeline:
"""处理管道(Pythonic责任链模式)"""
def __init__(self):
self._filters = []
def add_filter(self, filter_func: Callable):
"""添加过滤器"""
self._filters.append(filter_func)
return self # 支持链式调用
def process(self, data: Any) -> Any:
"""处理数据"""
for filter_func in self._filters:
data = filter_func(data)
return data
def validate_data(data):
"""数据验证过滤器"""
if not isinstance(data, (int, float)):
raise ValueError("数据必须是数字")
return data
def transform_data(data):
"""数据转换过滤器"""
return data * 2
def format_data(data):
"""数据格式化过滤器"""
return f"结果: {data}"
@dataclass
class Product:
"""产品数据类(替代传统的建造者模式)"""
name: str
price: float
category: str = "general"
description: str = ""
@classmethod
def create_expensive_product(cls, name, description=""):
"""类方法作为工厂方法"""
return cls(name=name, price=999.99, category="luxury", description=description)
@classmethod
def create_cheap_product(cls, name):
"""另一个工厂方法"""
return cls(name=name, price=9.99, category="budget")
@contextmanager
def database_transaction():
"""上下文管理器实现(替代传统的模板方法模式)"""
print("开始数据库事务")
try:
yield "数据库连接"
print("提交事务")
except Exception as e:
print(f"回滚事务: {e}")
raise
finally:
print("关闭数据库连接")
class cached_property:
"""缓存属性装饰器(享元模式的变体)"""
def __init__(self, func):
self.func = func
self.attrname = None
self.__doc__ = func.__doc__
def __set_name__(self, owner, name):
self.attrname = name
def __get__(self, instance, owner=None):
if instance is None:
return self
if self.attrname is None:
raise TypeError("无法使用未绑定的 cached_property")
cache = instance.__dict__
val = cache.get(self.attrname, self)
if val is self:
val = self.func(instance)
cache[self.attrname] = val
return val
class ExpensiveComputation:
"""需要进行昂贵计算的类"""
def __init__(self, data):
self.data = data
@cached_property
def computed_value(self):
print("执行昂贵计算...")
time.sleep(1) # 模拟耗时计算
return sum(len(str(x)) for x in self.data) * 3.14
def pythonic_design_patterns_demo():
"""Pythonic设计模式演示"""
print("=== Pythonic设计模式 ===")
# 1. 装饰器单例
print("\n1. 装饰器单例模式:")
config1 = Configuration()
config2 = Configuration()
print(f"config1 is config2: {config1 is config2}")
print(f"数据库URL: {config1.get('database_url')}")
# 2. Pythonic观察者模式
print("\n2. Pythonic观察者模式:")
logger = Logger()
logger.add_handler(console_handler)
logger.add_handler(file_handler)
logger.set_level(LogLevel.INFO)
logger.debug("这条调试信息不会显示")
logger.info("这是一条信息")
logger.warning("这是一条警告")
logger.error("这是一条错误")
# 3. 字典策略模式
print("\n3. 字典策略模式:")
print(f"加法策略: {registry.execute('add', 5, 3)}")
print(f"乘法策略: {registry.execute('multiply', 5, 3)}")
print(f"连接策略: {registry.execute('concat', 5, 3)}")
# 4. 管道责任链模式
print("\n4. 管道责任链模式:")
pipeline = Pipeline()
pipeline.add_filter(validate_data).add_filter(transform_data).add_filter(format_data)
result = pipeline.process(42)
print(result)
# 5. 数据类工厂模式
print("\n5. 数据类工厂模式:")
luxury_product = Product.create_expensive_product("钻石戒指", "闪耀的钻石")
budget_product = Product.create_cheap_product("塑料戒指")
print(f"奢侈品: {luxury_product}")
print(f"廉价品: {budget_product}")
# 6. 上下文管理器模板方法模式
print("\n6. 上下文管理器模板方法模式:")
try:
with database_transaction() as connection:
print(f"使用 {connection} 执行操作")
# 模拟正常操作
print("执行数据库操作...")
except Exception as e:
print(f"捕获异常: {e}")
# 7. 缓存属性享元模式
print("\n7. 缓存属性享元模式:")
computation = ExpensiveComputation([1, 2, 3, 4, 5])
print("第一次访问计算属性:")
start_time = time.time()
result1 = computation.computed_value
elapsed1 = time.time() - start_time
print(f"结果: {result1}, 耗时: {elapsed1:.2f}秒")
print("第二次访问计算属性:")
start_time = time.time()
result2 = computation.computed_value
elapsed2 = time.time() - start_time
print(f"结果: {result2}, 耗时: {elapsed2:.2f}秒")
print(f"结果相同: {result1 == result2}")
print(f"性能提升: {elapsed1/elapsed2 if elapsed2 > 0 else '无限':.0f}x")
if __name__ == "__main__":
pythonic_design_patterns_demo()
9.5 设计模式实战应用.py
python
# 109.5 设计模式实战应用.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from dataclasses import dataclass
from enum import Enum
import json
import time
class PaymentMethod(Enum):
"""支付方式枚举"""
CREDIT_CARD = "credit_card"
PAYPAL = "paypal"
WECHAT = "wechat"
ALIPAY = "alipay"
class PaymentStrategy(ABC):
"""支付策略接口"""
@abstractmethod
def process_payment(self, amount: float) -> bool:
pass
@abstractmethod
def get_name(self) -> str:
pass
class CreditCardPayment(PaymentStrategy):
"""信用卡支付策略"""
def process_payment(self, amount: float) -> bool:
print(f"处理信用卡支付: ${amount:.2f}")
# 模拟支付处理
time.sleep(0.5)
return True
def get_name(self) -> str:
return "信用卡"
class PayPalPayment(PaymentStrategy):
"""PayPal支付策略"""
def process_payment(self, amount: float) -> bool:
print(f"处理PayPal支付: ${amount:.2f}")
# 模拟支付处理
time.sleep(0.3)
return True
def get_name(self) -> str:
return "PayPal"
class WeChatPayment(PaymentStrategy):
"""微信支付策略"""
def process_payment(self, amount: float) -> bool:
print(f"处理微信支付: ${amount:.2f}")
# 模拟支付处理
time.sleep(0.2)
return True
def get_name(self) -> str:
return "微信支付"
class PaymentProcessor:
"""支付处理器(策略模式 + 工厂模式)"""
_strategies = {
PaymentMethod.CREDIT_CARD: CreditCardPayment(),
PaymentMethod.PAYPAL: PayPalPayment(),
PaymentMethod.WECHAT: WeChatPayment(),
}
def __init__(self):
self._strategy = None
def set_payment_method(self, method: PaymentMethod):
self._strategy = self._strategies.get(method)
if not self._strategy:
raise ValueError(f"不支持的支付方式: {method}")
def process(self, amount: float) -> bool:
if not self._strategy:
raise ValueError("未设置支付方式")
print(f"使用 {self._strategy.get_name()} 进行支付...")
return self._strategy.process_payment(amount)
@dataclass
class Order:
"""订单数据类"""
order_id: str
items: List[str]
total_amount: float
status: str = "pending"
def to_dict(self) -> Dict[str, Any]:
return {
'order_id': self.order_id,
'items': self.items,
'total_amount': self.total_amount,
'status': self.status
}
class OrderObserver(ABC):
"""订单观察者接口"""
@abstractmethod
def update(self, order: Order):
pass
class EmailNotification(OrderObserver):
"""邮件通知观察者"""
def update(self, order: Order):
print(f"发送邮件通知: 订单 {order.order_id} 状态变为 {order.status}")
class SMSNotification(OrderObserver):
"""短信通知观察者"""
def update(self, order: Order):
print(f"发送短信通知: 订单 {order.order_id} 状态变为 {order.status}")
class InventoryUpdate(OrderObserver):
"""库存更新观察者"""
def update(self, order: Order):
print(f"更新库存: 处理订单 {order.order_id} 的商品")
class OrderManager:
"""订单管理器(观察者模式 + 单例模式)"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._observers = []
cls._instance._orders = {}
return cls._instance
def attach(self, observer: OrderObserver):
self._observers.append(observer)
def detach(self, observer: OrderObserver):
self._observers.remove(observer)
def notify(self, order: Order):
for observer in self._observers:
observer.update(order)
def create_order(self, items: List[str], total_amount: float) -> Order:
order_id = f"ORD{int(time.time())}"
order = Order(order_id=order_id, items=items, total_amount=total_amount)
self._orders[order_id] = order
self.notify(order)
return order
def update_order_status(self, order_id: str, status: str):
if order_id in self._orders:
order = self._orders[order_id]
order.status = status
self.notify(order)
def get_order(self, order_id: str) -> Order:
return self._orders.get(order_id)
class DiscountStrategy(ABC):
"""折扣策略接口"""
@abstractmethod
def calculate_discount(self, amount: float) -> float:
pass
class NoDiscount(DiscountStrategy):
"""无折扣策略"""
def calculate_discount(self, amount: float) -> float:
return 0.0
class PercentageDiscount(DiscountStrategy):
"""百分比折扣策略"""
def __init__(self, percentage: float):
self.percentage = percentage
def calculate_discount(self, amount: float) -> float:
return amount * self.percentage / 100
class FixedAmountDiscount(DiscountStrategy):
"""固定金额折扣策略"""
def __init__(self, amount: float):
self.amount = amount
def calculate_discount(self, amount: float) -> float:
return min(self.amount, amount)
class DiscountCalculator:
"""折扣计算器(策略模式)"""
def __init__(self):
self._strategy = NoDiscount()
def set_strategy(self, strategy: DiscountStrategy):
self._strategy = strategy
def calculate(self, amount: float) -> float:
return self._strategy.calculate_discount(amount)
class LoggerDecorator:
"""日志装饰器(装饰器模式)"""
def __init__(self, processor):
self._processor = processor
def process(self, amount: float) -> bool:
print(f"开始支付处理: ${amount:.2f}")
start_time = time.time()
result = self._processor.process(amount)
end_time = time.time()
duration = end_time - start_time
status = "成功" if result else "失败"
print(f"支付处理{status}, 耗时: {duration:.2f}秒")
return result
class OrderBuilder:
"""订单建造者(建造者模式)"""
def __init__(self):
self.reset()
def reset(self):
self._order_id = None
self._items = []
self._total_amount = 0.0
self._customer_name = None
self._shipping_address = None
def set_order_id(self, order_id: str):
self._order_id = order_id
return self
def add_item(self, item: str, price: float):
self._items.append(item)
self._total_amount += price
return self
def set_customer(self, name: str):
self._customer_name = name
return self
def set_shipping_address(self, address: str):
self._shipping_address = address
return self
def build(self) -> Dict[str, Any]:
if not self._order_id:
self._order_id = f"ORD{int(time.time())}"
order_data = {
'order_id': self._order_id,
'items': self._items,
'total_amount': round(self._total_amount, 2),
'customer_name': self._customer_name,
'shipping_address': self._shipping_address,
'status': 'created',
'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
}
self.reset()
return order_data
class ECommerceFacade:
"""电商系统外观(外观模式)"""
def __init__(self):
self.order_manager = OrderManager()
self.payment_processor = PaymentProcessor()
self.discount_calculator = DiscountCalculator()
# 设置默认观察者
self.order_manager.attach(EmailNotification())
self.order_manager.attach(SMSNotification())
self.order_manager.attach(InventoryUpdate())
def create_order(self, items: List[Dict], customer_name: str,
shipping_address: str, payment_method: PaymentMethod) -> Dict[str, Any]:
"""创建订单"""
print("\n=== 创建新订单 ===")
# 使用建造者构建订单
builder = OrderBuilder()
builder.set_customer(customer_name)
builder.set_shipping_address(shipping_address)
for item in items:
builder.add_item(item['name'], item['price'])
order_data = builder.build()
# 计算折扣
self.discount_calculator.set_strategy(PercentageDiscount(10)) # 10%折扣
discount = self.discount_calculator.calculate(order_data['total_amount'])
final_amount = order_data['total_amount'] - discount
print(f"订单金额: ${order_data['total_amount']:.2f}")
print(f"折扣: ${discount:.2f}")
print(f"最终金额: ${final_amount:.2f}")
# 创建订单记录
order = self.order_manager.create_order(
[item['name'] for item in items],
final_amount
)
# 处理支付
self.payment_processor.set_payment_method(payment_method)
# 使用装饰器添加日志
logged_processor = LoggerDecorator(self.payment_processor)
payment_success = logged_processor.process(final_amount)
if payment_success:
self.order_manager.update_order_status(order.order_id, "paid")
order_data['status'] = 'paid'
order_data['final_amount'] = final_amount
print("订单创建成功!")
else:
self.order_manager.update_order_status(order.order_id, "payment_failed")
order_data['status'] = 'payment_failed'
print("订单创建失败: 支付处理失败")
return order_data
def get_order_status(self, order_id: str) -> str:
"""获取订单状态"""
order = self.order_manager.get_order(order_id)
if order:
return order.status
return "not_found"
def design_patterns_practice():
"""设计模式实战应用演示"""
print("=== 设计模式实战应用 ===")
# 创建电商系统外观
ecommerce = ECommerceFacade()
# 模拟订单数据
order_items = [
{'name': '笔记本电脑', 'price': 999.99},
{'name': '无线鼠标', 'price': 29.99},
{'name': '电脑包', 'price': 49.99}
]
customer_name = "张三"
shipping_address = "北京市朝阳区某某街道123号"
# 创建订单并使用信用卡支付
print("\n1. 创建订单(信用卡支付):")
order1 = ecommerce.create_order(
items=order_items,
customer_name=customer_name,
shipping_address=shipping_address,
payment_method=PaymentMethod.CREDIT_CARD
)
print(f"\n订单详情: {json.dumps(order1, indent=2, ensure_ascii=False)}")
# 创建另一个订单使用微信支付
print("\n2. 创建订单(微信支付):")
order2 = ecommerce.create_order(
items=[{'name': '手机', 'price': 599.99}],
customer_name="李四",
shipping_address="上海市浦东新区某某路456号",
payment_method=PaymentMethod.WECHAT
)
# 查询订单状态
print("\n3. 查询订单状态:")
status1 = ecommerce.get_order_status(order1['order_id'])
status2 = ecommerce.get_order_status(order2['order_id'])
print(f"订单 {order1['order_id']} 状态: {status1}")
print(f"订单 {order2['order_id']} 状态: {status2}")
# 折扣策略演示
print("\n4. 折扣策略演示:")
calculator = DiscountCalculator()
test_amount = 100.0
# 无折扣
calculator.set_strategy(NoDiscount())
print(f"无折扣: ${calculator.calculate(test_amount):.2f}")
# 10%折扣
calculator.set_strategy(PercentageDiscount(10))
print(f"10%折扣: ${calculator.calculate(test_amount):.2f}")
# 固定金额折扣
calculator.set_strategy(FixedAmountDiscount(20))
print(f"$20固定折扣: ${calculator.calculate(test_amount):.2f}")
# 复杂折扣组合(在实际系统中可能使用组合模式)
print("\n5. 复杂业务场景:")
# 模拟批量订单处理
print("批量处理多个订单...")
orders_data = [
{
'items': [{'name': '商品A', 'price': 50}],
'customer': '王五',
'address': '广州',
'payment_method': PaymentMethod.PAYPAL
},
{
'items': [{'name': '商品B', 'price': 150}, {'name': '商品C', 'price': 75}],
'customer': '赵六',
'address': '深圳',
'payment_method': PaymentMethod.ALIPAY
}
]
for i, order_data in enumerate(orders_data, 1):
print(f"\n处理批量订单 #{i}:")
try:
order = ecommerce.create_order(
items=order_data['items'],
customer_name=order_data['customer'],
shipping_address=order_data['address'],
payment_method=order_data['payment_method']
)
print(f"批量订单 #{i} 处理完成")
except Exception as e:
print(f"批量订单 #{i} 处理失败: {e}")
if __name__ == "__main__":
design_patterns_practice()