# Python Multithreading and Multiprocessing in Depth
## 1. The GIL Problem and the Case for Multiprocessing
Python's Global Interpreter Lock (GIL) is a mechanism in the CPython interpreter that ensures only one thread executes Python bytecode at any given time. This design greatly simplifies CPython's memory management, but it also imposes a performance limit:
- **CPU-bound tasks**: Because of the GIL, a multithreaded program cannot truly run computation in parallel, even on a multi-core CPU. Typical examples are pure-computation workloads such as Fibonacci numbers or matrix operations:
```python
# Example of a CPU-bound task
def calculate_fibonacci(n):
    if n <= 1:
        return n
    return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)
```
- **I/O-bound tasks**: When a thread performs I/O (file reads/writes, network requests), it releases the GIL, so multithreading can still speed up this kind of work. Typical examples are web crawlers and batch file processing.
Solution matrix (a timing sketch follows the table):
| Task type | Recommended approach | Typical use cases |
|---|---|---|
| I/O-bound | Multithreading | Web crawlers, file processing, database queries |
| CPU-bound | Multiprocessing | Numerical computing, image processing, ML training |
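To see the matrix in action, here is a minimal timing sketch, assuming a 4-worker pool and fib(30) as a stand-in workload; on a multi-core machine the process pool should finish the CPU-bound jobs noticeably faster than the GIL-serialized thread pool, though exact numbers vary by hardware:
```python
# Sketch: the same CPU-bound task on a thread pool vs. a process pool.
import time
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

def calculate_fibonacci(n):
    if n <= 1:
        return n
    return calculate_fibonacci(n-1) + calculate_fibonacci(n-2)

def timed(executor_cls, label, tasks):
    start = time.time()
    with executor_cls(max_workers=4) as executor:
        list(executor.map(calculate_fibonacci, tasks))  # run all tasks to completion
    print(f"{label}: {time.time() - start:.2f}s")

if __name__ == "__main__":  # guard required for process pools on Windows/macOS
    tasks = [30] * 4  # four identical CPU-bound jobs
    timed(ThreadPoolExecutor, "threads (GIL-bound)", tasks)
    timed(ProcessPoolExecutor, "processes (parallel)", tasks)
```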
## 2. The threading Module: Basics
### 2.1 Creating Threads
```python
import threading
import time

def worker(name):
    print(f"Thread {name} started at {time.ctime()}")
    time.sleep(1)  # simulate an I/O operation
    print(f"Thread {name} finished at {time.ctime()}")

# Create and start the threads
threads = []
for i in ["A", "B", "C"]:
    t = threading.Thread(target=worker, args=(i,))
    threads.append(t)
    t.start()

# Wait for all threads to finish
for t in threads:
    t.join()
print(f"All threads done at {time.ctime()}")
```
### 2.2 Subclassing Thread
```python
import threading
import time
import random

class DownloadThread(threading.Thread):
    def __init__(self, url, save_path):
        super().__init__()
        self.url = url
        self.save_path = save_path

    def run(self):
        print(f"Downloading {self.url}")
        # Simulate the download
        time.sleep(random.uniform(0.5, 2.0))
        print(f"Finished {self.url}, saved to {self.save_path}")

# Usage
downloads = [
    ("http://example.com/file1", "data/file1.txt"),
    ("http://example.com/file2", "data/file2.txt")
]
for url, path in downloads:
    t = DownloadThread(url, path)
    t.start()
```
## 3. Thread Synchronization
### 3.1 Mutex Lock (Lock)
```python
import threading
import time

class BankAccount:
    def __init__(self):
        self.balance = 1000
        self.lock = threading.Lock()

    def transfer(self, amount):
        with self.lock:
            new_balance = self.balance + amount
            time.sleep(0.01)  # simulate processing latency
            self.balance = new_balance

account = BankAccount()

def client():
    for _ in range(100):
        account.transfer(10)

threads = [threading.Thread(target=client) for _ in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"Final balance: {account.balance}")  # should be 11000 (1000 + 10 threads × 100 × 10)
```
### 3.2 Reentrant Lock (RLock)
```python
import threading

class RecursiveCounter:
    def __init__(self):
        self.value = 0
        self.lock = threading.RLock()

    def increment(self, n):
        with self.lock:  # an RLock can be re-acquired by the thread that holds it
            if n <= 0:
                return
            self.value += 1
            self.increment(n-1)  # recursive call re-enters the lock

counter = RecursiveCounter()
counter.increment(5)
print(f"Counter value: {counter.value}")  # prints 5
```
### 3.3 Semaphore
```python
import threading
import time

class DatabaseConnectionPool:
    def __init__(self, max_connections=5):
        self.semaphore = threading.Semaphore(max_connections)
        self.connections = []

    def get_connection(self):
        self.semaphore.acquire()  # take a slot; blocks once the pool is exhausted
        conn = self._create_connection()
        self.connections.append(conn)
        return conn

    def release_connection(self):
        self.semaphore.release()  # give the slot back

    def _create_connection(self):
        time.sleep(0.5)  # simulate connection setup cost
        return f"Connection-{len(self.connections)+1}"

pool = DatabaseConnectionPool(3)

def query_data():
    conn = pool.get_connection()
    try:
        print(f"Querying with {conn}")
        time.sleep(1)  # simulate query time
    finally:
        print(f"Releasing {conn}")
        pool.release_connection()  # hold the slot for the whole query, then release

for i in range(10):
    threading.Thread(target=query_data).start()
```
### 3.4 Event
```python
import threading
import time
import random

class DownloadManager:
    def __init__(self):
        self.complete_event = threading.Event()
        self.error_event = threading.Event()

    def download_file(self, url):
        try:
            print(f"Downloading {url}")
            time.sleep(random.uniform(1, 3))  # simulate download time
            if random.random() < 0.2:  # simulate a 20% failure rate
                raise Exception("network error")
            print(f"Download complete: {url}")
            self.complete_event.set()
        except Exception as e:
            print(f"Download failed {url}: {e}")
            self.error_event.set()

manager = DownloadManager()

def wait_for_download():
    print("Waiting for the download to finish...")
    if manager.complete_event.wait(timeout=5):  # timeout so we don't block forever on failure
        print("Download-complete notification received!")

def handle_error():
    print("Waiting for a possible error...")
    if manager.error_event.wait(timeout=5):  # timeout so we don't block forever on success
        print("Download error detected!")

threading.Thread(target=manager.download_file, args=("http://example.com/file",)).start()
threading.Thread(target=wait_for_download).start()
threading.Thread(target=handle_error).start()
```
## 4. Inter-Thread Communication: Queues
```python
import queue
import threading
import time
import random

class TaskDispatcher:
    def __init__(self, num_workers=3):
        self.task_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.workers = []
        for i in range(num_workers):
            t = threading.Thread(target=self.worker_loop, args=(f"Worker-{i+1}",))
            t.start()
            self.workers.append(t)

    def worker_loop(self, name):
        while True:
            task = self.task_queue.get()
            if task is None:  # shutdown signal
                break
            print(f"{name} processing task: {task}")
            time.sleep(random.uniform(0.1, 0.5))  # simulate work
            self.result_queue.put((task, task * 2))  # (input, result)
            self.task_queue.task_done()

    def add_task(self, task):
        self.task_queue.put(task)

    def get_results(self):
        results = []
        while not self.result_queue.empty():
            results.append(self.result_queue.get())
        return results

    def shutdown(self):
        for _ in self.workers:
            self.task_queue.put(None)  # one shutdown signal per worker
        for t in self.workers:
            t.join()

# Usage
dispatcher = TaskDispatcher()
# Add tasks
for i in range(10):
    dispatcher.add_task(i)
# Block until every task has been processed
dispatcher.task_queue.join()
# Collect the results
print("Results:", dispatcher.get_results())
# Shut the dispatcher down
dispatcher.shutdown()
```
## 5. Thread Pools: ThreadPoolExecutor
```python
from concurrent.futures import ThreadPoolExecutor, as_completed
import urllib.request

def download_url(url):
    try:
        with urllib.request.urlopen(url, timeout=5) as response:
            return url, response.status, len(response.read())
    except Exception as e:
        return url, None, str(e)

urls = [
    "http://www.python.org",
    "http://www.google.com",
    "http://www.example.com",
    "http://www.invalid-url-123456.com"
]

with ThreadPoolExecutor(max_workers=3) as executor:
    # Submit tasks individually
    future_to_url = {executor.submit(download_url, url): url for url in urls}
    for future in as_completed(future_to_url):
        url = future_to_url[future]
        try:
            data = future.result()
            print(f"URL: {data[0]}, status: {data[1]}, size: {data[2]}")
        except Exception as e:
            print(f"{url} raised an exception: {e}")

# The simpler map-based version
with ThreadPoolExecutor(max_workers=3) as executor:
    results = executor.map(download_url, urls)
    for result in results:
        print(f"Result: {result}")
```
## 6. Multiprocessing Basics
### 6.1 Creating Processes
```python
import multiprocessing
import os
import time

def worker_process(name):
    print(f"Child {name} PID: {os.getpid()}, parent PID: {os.getppid()}")
    time.sleep(1)
    return name.upper()

if __name__ == "__main__":
    print(f"Main process PID: {os.getpid()}")
    processes = []
    for i in range(3):
        p = multiprocessing.Process(target=worker_process, args=(f"Process-{i}",))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
        print(f"Process {p.name} exit code: {p.exitcode}")
```
### 6.2 Sharing Data Between Processes
```python
from multiprocessing import Process, Value, Array, Manager
import os
import time

def modify_shared_data(n, arr, shared_dict):
    # Modify the shared integer (guard with its built-in lock)
    with n.get_lock():
        n.value += 1
    # Modify the shared array
    with arr.get_lock():
        for i in range(len(arr)):
            arr[i] *= 2
    # Modify the shared dict
    shared_dict[os.getpid()] = time.time()

if __name__ == "__main__":
    # Shared integer ('i' = signed int)
    shared_num = Value('i', 0)
    # Shared array ('d' = double-precision float)
    shared_arr = Array('d', [1.0, 2.0, 3.0])
    # Share richer data structures through a Manager
    manager = Manager()
    shared_dict = manager.dict()
    processes = []
    for _ in range(3):
        p = Process(target=modify_shared_data, args=(shared_num, shared_arr, shared_dict))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    print(f"Shared value: {shared_num.value}")  # should be 3
    print(f"Shared array: {list(shared_arr)}")  # should be [8.0, 16.0, 24.0]
    print(f"Shared dict: {dict(shared_dict)}")  # PID -> timestamp for all 3 processes
```
## 7. Inter-Process Communication
### 7.1 Pipes
```python
from multiprocessing import Pipe, Process
import time

def sender(conn):
    for i in range(5):
        conn.send(f"message {i}")
        time.sleep(0.5)
    conn.send(None)  # end-of-stream signal
    conn.close()

def receiver(conn):
    while True:
        msg = conn.recv()
        if msg is None:
            break
        print(f"Received: {msg}")
    conn.close()

if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    p1 = Process(target=sender, args=(child_conn,))
    p2 = Process(target=receiver, args=(parent_conn,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
```
### 7.2 Process-Safe Queue
```python
from multiprocessing import Process, Queue
import time
import random

def producer(q, items):
    for item in items:
        print(f"Producing: {item}")
        q.put(item)
        time.sleep(random.uniform(0.1, 0.3))
    q.put(None)  # end signal

def consumer(q, name):
    while True:
        item = q.get()
        if item is None:
            q.put(None)  # pass the end signal on to the next consumer
            break
        print(f"{name} consumed: {item}")
        time.sleep(random.uniform(0.2, 0.4))

if __name__ == "__main__":
    q = Queue()
    items = [f"Item-{i}" for i in range(10)]
    producers = [
        Process(target=producer, args=(q, items[:5])),
        Process(target=producer, args=(q, items[5:]))
    ]
    consumers = [
        Process(target=consumer, args=(q, "Consumer-1")),
        Process(target=consumer, args=(q, "Consumer-2"))
    ]
    for p in producers:
        p.start()
    for c in consumers:
        c.start()
    for p in producers:
        p.join()
    for c in consumers:
        c.join()
```
## 8. Process Pools
### 8.1 multiprocessing.Pool
```python
import multiprocessing
import math

def factorize(n):
    """Prime factorization by trial division."""
    factors = []
    while n % 2 == 0:
        factors.append(2)
        n = n // 2
    i = 3
    max_factor = math.sqrt(n)
    while i <= max_factor:
        while n % i == 0:
            factors.append(i)
            n = n // i
            max_factor = math.sqrt(n)
        i += 2
    if n > 1:
        factors.append(n)
    return factors

if __name__ == "__main__":
    numbers = [999999999999, 888888888888, 777777777777, 666666666666]
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        results = pool.map(factorize, numbers)
    for num, factors in zip(numbers, results):
        print(f"{num} = {' × '.join(map(str, factors))}")
```
### 8.2 ProcessPoolExecutor
```python
from concurrent.futures import ProcessPoolExecutor
import hashlib

def hash_file(file_path):
    """Compute a file's SHA-256 hash."""
    hash_obj = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(8192):
            hash_obj.update(chunk)
    return file_path, hash_obj.hexdigest()

if __name__ == "__main__":
    import glob
    # All .txt files in the current directory
    files = glob.glob("*.txt")
    with ProcessPoolExecutor(max_workers=4) as executor:
        # Submit tasks individually
        futures = {executor.submit(hash_file, file): file for file in files}
        for future in futures:
            file = futures[future]
            try:
                file_path, file_hash = future.result()
                print(f"{file_path}: {file_hash}")
            except Exception as e:
                print(f"Error processing {file}: {e}")
```
## 9. Practical Scenarios
### 9.1 I/O-Bound: A Multithreaded Crawler
```python
import threading
import requests
from bs4 import BeautifulSoup
from queue import Queue
from urllib.parse import urljoin

class WebCrawler:
    def __init__(self, base_url, max_threads=5):
        self.base_url = base_url
        self.visited = {base_url}  # count the start page and never re-enqueue it
        self.lock = threading.Lock()
        self.queue = Queue()
        self.queue.put(base_url)
        self.max_threads = max_threads

    def crawl(self):
        threads = []
        for i in range(self.max_threads):
            t = threading.Thread(target=self.worker, daemon=True)
            t.start()
            threads.append(t)
        self.queue.join()
        # Send the stop signals
        for _ in range(self.max_threads):
            self.queue.put(None)
        for t in threads:
            t.join()

    def worker(self):
        while True:
            url = self.queue.get()
            if url is None:
                break
            try:
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    print(f"Fetched: {url}")
                    # Extract every in-site link on the page
                    soup = BeautifulSoup(response.text, 'html.parser')
                    links = [urljoin(url, a['href'])
                             for a in soup.find_all('a', href=True)
                             if urljoin(url, a['href']).startswith(self.base_url)]
                    with self.lock:
                        for link in links:
                            if link not in self.visited:
                                self.visited.add(link)
                                self.queue.put(link)
            except Exception as e:
                print(f"Failed to fetch {url}: {e}")
            self.queue.task_done()

if __name__ == "__main__":
    crawler = WebCrawler("http://example.com")
    crawler.crawl()
    print(f"Crawled {len(crawler.visited)} pages in total")
```
### 9.2 CPU-Bound: Counting Primes with Multiprocessing
```python
import math
import multiprocessing
import time

def is_prime(n):
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    max_divisor = math.isqrt(n) + 1
    for i in range(3, max_divisor, 2):
        if n % i == 0:
            return False
    return True

def count_primes(start, end):
    count = 0
    for n in range(start, end):
        if is_prime(n):
            count += 1
    return count

def parallel_prime_count(n, num_processes=None):
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    chunk_size = n // num_processes
    ranges = [(i*chunk_size, (i+1)*chunk_size) for i in range(num_processes)]
    ranges[-1] = (ranges[-1][0], n + 1)  # extend the last chunk so n itself is covered
    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.starmap(count_primes, ranges)
    return sum(results)

if __name__ == "__main__":
    n = 10_000_000
    print(f"Counting primes between 1 and {n}...")
    start_time = time.time()
    prime_count = parallel_prime_count(n)
    duration = time.time() - start_time
    print(f"Total primes: {prime_count}")
    print(f"Elapsed: {duration:.2f}s")
```
### 9.3 Producer-Consumer Pattern
```python
from multiprocessing import Process, Queue
import time
import random

def producer(queue, product_id):
    for i in range(5):
        item = f"product-{product_id}-{i}"
        print(f"Producer {product_id} produced: {item}")
        queue.put(item)
        time.sleep(random.uniform(0.1, 0.5))
    queue.put(None)  # end signal

def consumer(queue, consumer_id):
    while True:
        item = queue.get()
        if item is None:
            queue.put(None)  # pass the end signal on to the next consumer
            break
        print(f"Consumer {consumer_id} consumed: {item}")
        time.sleep(random.uniform(0.2, 0.7))

if __name__ == "__main__":
    queue = Queue(maxsize=10)  # bounded queue applies backpressure to producers
    # Two producers
    producers = [
        Process(target=producer, args=(queue, i))
        for i in range(2)
    ]
    # Three consumers
    consumers = [
        Process(target=consumer, args=(queue, i))
        for i in range(3)
    ]
    # Start everything
    for p in producers:
        p.start()
    for c in consumers:
        c.start()
    # Wait for the producers
    for p in producers:
        p.join()
    # Wait for the consumers
    for c in consumers:
        c.join()
    print("All production and consumption finished")
```
## 10. Summary and Performance Comparison
| Scenario | Recommended approach | Why |
|---|---|---|
| I/O-bound (network, files) | Multithreading / async I/O | Cheap context switches; the GIL barely matters |
| CPU-bound (computation) | Multiprocessing | Bypasses the GIL; uses all cores |
| Many tiny tasks | Thread/process pools | Avoids creation/teardown overhead |
| Heavy data sharing | Multiprocessing + shared memory | Processes have separate address spaces, so IPC is required |
- **I/O-bound tasks**: Multithreading or async I/O effectively hides waiting time; the GIL has little impact here (see the asyncio sketch after this list).
- **CPU-bound tasks**: Multiprocessing sidesteps the GIL and exploits every core.
- **Task scheduling**: Thread and process pools avoid repeatedly creating and destroying workers, improving throughput.
- **Data sharing**: Multiprocessing requires shared memory or another IPC mechanism, and synchronization still needs care.
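Async I/O was recommended above but not demonstrated in the earlier sections. As a minimal sketch, assuming `asyncio.sleep` as a stand-in for a real network call, three concurrent one-second "requests" complete in roughly one second overall rather than three:
```python
# Sketch: async I/O as an alternative to threads for I/O-bound work.
import asyncio
import time

async def fetch(name, delay):
    print(f"{name}: request sent")
    await asyncio.sleep(delay)  # while "waiting", other coroutines run
    print(f"{name}: response received")
    return name

async def main():
    start = time.time()
    # Three 1-second "requests" run concurrently on a single thread
    results = await asyncio.gather(*(fetch(f"req-{i}", 1) for i in range(3)))
    print(f"{results} in {time.time() - start:.2f}s")

if __name__ == "__main__":
    asyncio.run(main())
```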
Remember: `threading` suits I/O-bound work and `multiprocessing` suits CPU-bound work; avoid mixing the two indiscriminately.