Rust 并发编程基础
概述
Rust 的所有权系统使得并发编程更加安全:编译器可以在编译时防止数据竞争,这是 Rust 的一大优势。
线程基础
Rust 标准库(`std::thread`)使用 1:1 线程模型,每个 Rust 线程对应一个操作系统线程。
简单示例
rust
use std::thread;
use std::time::Duration;
/// Spawns one child thread and interleaves its output with the main thread's.
///
/// The child prints 1..=9 and the main thread prints 1..=4, each sleeping 1 ms
/// between lines; the main thread then joins the child so it runs to completion.
fn basic_threading() {
    let pause = Duration::from_millis(1);
    let child = thread::spawn(move || {
        (1..10).for_each(|n| {
            println!("子线程: {}", n);
            thread::sleep(pause);
        });
    });
    for n in 1..5 {
        println!("主线程: {}", n);
        thread::sleep(pause);
    }
    child.join().unwrap();
}
复杂案例:实现一个并发任务调度器(线程池)
rust
use std::sync::{Arc, Mutex, Condvar};
use std::sync::mpsc::{self, Sender, Receiver};
use std::thread;
use std::time::Duration;
use std::collections::VecDeque;
/// A unit of work for the pool: a boxed closure that runs once and can be
/// sent to another thread.
type Task = Box<dyn FnOnce() + Send + 'static>;

/// Worker thread state, written by the worker itself around each task.
#[derive(Debug, Clone, Copy, PartialEq)]
enum WorkerState {
    /// Waiting for (or about to wait for) the next task.
    Idle,
    /// Running a task.
    Busy,
    /// The task channel is closed and the worker loop has exited.
    Stopped,
}

/// Handle to one worker thread of a `ThreadPool`.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns worker `id`, which pulls tasks from the shared `receiver`.
    ///
    /// Around each task the worker sets `state` to `Busy`, then back to
    /// `Idle`, and afterwards calls `notify_all` on `condvar`. When `recv`
    /// fails (every sender dropped and the queue drained) it sets `Stopped`
    /// and exits its loop.
    fn new(
        id: usize,
        receiver: Arc<Mutex<Receiver<Task>>>,
        state: Arc<Mutex<WorkerState>>,
        condvar: Arc<Condvar>,
    ) -> Self {
        let thread = thread::spawn(move || loop {
            // The receiver lock is released at the end of this block, so it is
            // held only while waiting for a task, never while running one.
            let task = {
                let receiver = receiver.lock().unwrap();
                receiver.recv()
            };
            match task {
                Ok(task) => {
                    *state.lock().unwrap() = WorkerState::Busy;
                    println!("工作线程 {} 正在执行任务", id);
                    task();
                    *state.lock().unwrap() = WorkerState::Idle;
                    condvar.notify_all();
                }
                Err(_) => {
                    println!("工作线程 {} 停止", id);
                    *state.lock().unwrap() = WorkerState::Stopped;
                    break;
                }
            }
        });
        Worker {
            id,
            thread: Some(thread),
        }
    }
}

/// Fixed-size pool of worker threads fed through a single mpsc channel.
///
/// Dropping the pool closes the channel, lets the workers finish every task
/// that is already queued, and joins them.
struct ThreadPool {
    workers: Vec<Worker>,
    /// `Some` for the pool's whole life; taken in `Drop` so that the channel
    /// actually closes (the workers only stop once every `Sender` is gone).
    sender: Option<Sender<Task>>,
    worker_states: Vec<Arc<Mutex<WorkerState>>>,
    /// Notified by workers after each task and by `PendingGuard` when a task
    /// finishes; `wait_completion` waits on it together with `pending`.
    condvar: Arc<Condvar>,
    /// Number of tasks passed to `execute` that have not finished yet.
    pending: Arc<Mutex<usize>>,
}

/// Decrements the pool's pending-task count when dropped and wakes waiters.
///
/// It is moved into each task's closure, so the decrement also happens if the
/// task panics (during unwinding) or is dropped without ever running.
struct PendingGuard {
    pending: Arc<Mutex<usize>>,
    condvar: Arc<Condvar>,
}

impl Drop for PendingGuard {
    fn drop(&mut self) {
        // This may run while a task is unwinding; never panic here.
        let mut pending = self.pending.lock().unwrap_or_else(|e| e.into_inner());
        *pending -= 1;
        drop(pending);
        self.condvar.notify_all();
    }
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads, all initially `Idle`.
    ///
    /// # Panics
    /// Panics if `size` is 0.
    fn new(size: usize) -> Self {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let condvar = Arc::new(Condvar::new());
        let mut workers = Vec::with_capacity(size);
        let mut worker_states = Vec::with_capacity(size);
        for id in 0..size {
            let state = Arc::new(Mutex::new(WorkerState::Idle));
            worker_states.push(Arc::clone(&state));
            workers.push(Worker::new(
                id,
                Arc::clone(&receiver),
                state,
                Arc::clone(&condvar),
            ));
        }
        ThreadPool {
            workers,
            sender: Some(sender),
            worker_states,
            condvar,
            pending: Arc::new(Mutex::new(0)),
        }
    }

    /// Queues `f` to run on some worker thread.
    ///
    /// # Panics
    /// Panics if every worker thread has already exited (the channel has no
    /// receiver left).
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        // Count the task before any worker can see it, so `wait_completion`
        // never observes zero while this task is still queued.
        *self.pending.lock().unwrap() += 1;
        let guard = PendingGuard {
            pending: Arc::clone(&self.pending),
            condvar: Arc::clone(&self.condvar),
        };
        let task: Task = Box::new(move || {
            let _guard = guard;
            f();
        });
        self.sender
            .as_ref()
            .expect("sender is only taken in Drop")
            .send(task)
            .expect("all worker threads have exited");
    }

    /// Returns how many workers are currently marked `Busy`.
    fn active_count(&self) -> usize {
        self.worker_states
            .iter()
            .filter(|state| *state.lock().unwrap() == WorkerState::Busy)
            .count()
    }

    /// Blocks until every task submitted via `execute` has finished running
    /// (including tasks that panicked).
    ///
    /// The predicate is checked under the same mutex the decrement happens
    /// under, so a completion notification cannot be lost between the check
    /// and the wait.
    fn wait_completion(&self) {
        let mut pending = self.pending.lock().unwrap();
        while *pending > 0 {
            pending = self.condvar.wait(pending).unwrap();
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Dropping the only Sender closes the channel: workers drain the
        // queued tasks, then `recv` returns Err and each worker loop exits.
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                // A worker whose task panicked has already died; report it
                // rather than panicking inside `drop`.
                if thread.join().is_err() {
                    eprintln!("工作线程 {} 异常退出", worker.id);
                }
            }
        }
    }
}
/// Submits ten 500 ms tasks to a four-thread pool, then calls
/// `wait_completion` before printing the completion message.
///
/// The pool is dropped (and its workers joined) when this function returns.
fn demonstrate_thread_pool() {
    let pool = ThreadPool::new(4);
    println!("提交 10 个任务到线程池");
    (0..10).for_each(|task_id| {
        pool.execute(move || {
            println!("任务 {} 开始执行", task_id);
            thread::sleep(Duration::from_millis(500));
            println!("任务 {} 完成", task_id);
        })
    });
    println!("等待所有任务完成...");
    pool.wait_completion();
    println!("所有任务已完成");
}
/// A thread-safe `i32` counter built on `Arc<Mutex<_>>`.
///
/// Handles created by `clone_counter` share one underlying value.
struct SharedCounter {
    count: Arc<Mutex<i32>>,
}

impl SharedCounter {
    /// Creates a counter starting at zero.
    fn new() -> Self {
        Self {
            count: Arc::new(Mutex::new(0)),
        }
    }

    /// Adds one to the shared value while holding the lock.
    fn increment(&self) {
        *self.count.lock().unwrap() += 1;
    }

    /// Returns the current shared value.
    fn get(&self) -> i32 {
        let guard = self.count.lock().unwrap();
        *guard
    }

    /// Returns another handle to the same counter (not a copy of its value).
    fn clone_counter(&self) -> Self {
        Self {
            count: self.count.clone(),
        }
    }
}
/// Has 10 threads each increment one shared counter 100 times, joins them
/// all, and prints the final total.
fn demonstrate_shared_state() {
    let counter = SharedCounter::new();
    let workers: Vec<_> = (0..10)
        .map(|_| {
            let local = counter.clone_counter();
            thread::spawn(move || (0..100).for_each(|_| local.increment()))
        })
        .collect();
    workers.into_iter().for_each(|w| w.join().unwrap());
    println!("最终计数: {}", counter.get());
}
/// Fans in messages from five sender threads over one channel and prints them.
///
/// Each thread sends 10 strings, 100 ms apart, through its own clone of `tx`.
/// Because the original `tx` is dropped, the receive loop ends once every
/// thread has finished and dropped its clone.
fn demonstrate_message_passing() {
    let (tx, rx) = mpsc::channel();
    for sender_id in 0..5 {
        let tx = tx.clone();
        thread::spawn(move || {
            for msg_id in 0..10 {
                tx.send(format!("线程 {} 发送消息 {}", sender_id, msg_id)).unwrap();
                thread::sleep(Duration::from_millis(100));
            }
        });
    }
    // Without this, `rx` would keep waiting on the main thread's sender forever.
    drop(tx);
    let total = rx
        .iter()
        .inspect(|msg| println!("收到: {}", msg))
        .count();
    println!("总共收到 {} 条消息", total);
}
/// Bounded FIFO buffer of `i32` shared between producer and consumer threads.
///
/// Handles created with `clone_pc` share the same queue and condition
/// variable. One `Condvar` serves both "not full" and "not empty" waiters, so
/// every state change uses `notify_all` to make sure the right waiter wakes.
struct ProducerConsumer {
    queue: Arc<Mutex<VecDeque<i32>>>,
    condvar: Arc<Condvar>,
    max_size: usize,
}

impl ProducerConsumer {
    /// Creates an empty buffer that holds at most `max_size` items.
    ///
    /// # Panics
    /// Panics if `max_size` is 0: such a buffer could never accept an item,
    /// so every `produce` call would block forever.
    fn new(max_size: usize) -> Self {
        assert!(max_size > 0, "max_size must be greater than 0");
        ProducerConsumer {
            queue: Arc::new(Mutex::new(VecDeque::new())),
            condvar: Arc::new(Condvar::new()),
            max_size,
        }
    }

    /// Appends `item`, blocking while the buffer is full, then wakes waiters.
    fn produce(&self, item: i32) {
        let mut queue = self.queue.lock().unwrap();
        // Loop rather than `if`: guards against spurious wakeups and against
        // another producer refilling the slot first.
        while queue.len() >= self.max_size {
            queue = self.condvar.wait(queue).unwrap();
        }
        queue.push_back(item);
        println!("生产: {}, 队列大小: {}", item, queue.len());
        self.condvar.notify_all();
    }

    /// Removes and returns the oldest item, blocking while the buffer is empty.
    ///
    /// The buffer has no "closed" state, so this always returns `Some` and
    /// blocks indefinitely if nothing is ever produced. Callers must not
    /// consume more items than will be produced.
    fn consume(&self) -> Option<i32> {
        let mut queue = self.queue.lock().unwrap();
        while queue.is_empty() {
            queue = self.condvar.wait(queue).unwrap();
        }
        let item = queue.pop_front();
        if let Some(i) = item {
            println!("消费: {}, 队列大小: {}", i, queue.len());
        }
        self.condvar.notify_all();
        item
    }

    /// Returns another handle to the same shared buffer.
    fn clone_pc(&self) -> Self {
        ProducerConsumer {
            queue: Arc::clone(&self.queue),
            condvar: Arc::clone(&self.condvar),
            max_size: self.max_size,
        }
    }
}
/// Runs one producer and three consumers over a buffer of capacity 5.
///
/// The producer pushes exactly `ITEMS` values, and the consumers' quotas are
/// split so they add up to exactly `ITEMS`. `consume` blocks until an item is
/// available and the buffer has no "closed" state, so a total quota larger
/// than `ITEMS` would leave a consumer blocked forever and hang the `join`.
fn demonstrate_producer_consumer() {
    const ITEMS: i32 = 20;
    const CONSUMERS: i32 = 3;

    let pc = ProducerConsumer::new(5);

    // Producer thread.
    let pc_producer = pc.clone_pc();
    let producer = thread::spawn(move || {
        for i in 0..ITEMS {
            pc_producer.produce(i);
            thread::sleep(Duration::from_millis(50));
        }
    });

    // Consumer threads: split ITEMS as evenly as possible (7, 7, 6).
    let mut consumers = Vec::new();
    for c in 0..CONSUMERS {
        let quota = ITEMS / CONSUMERS + if c < ITEMS % CONSUMERS { 1 } else { 0 };
        let pc_consumer = pc.clone_pc();
        consumers.push(thread::spawn(move || {
            for _ in 0..quota {
                pc_consumer.consume();
                thread::sleep(Duration::from_millis(150));
            }
        }));
    }

    producer.join().unwrap();
    for consumer in consumers {
        consumer.join().unwrap();
    }
}
/// Sums `data` by splitting it into contiguous chunks, one per spawned thread.
///
/// At most `num_threads` threads are spawned, and never more than there are
/// non-empty chunks. `num_threads == 0` is treated as 1 instead of dividing by
/// zero, and an empty `data` returns 0 without spawning anything.
///
/// # Panics
/// Panics if a worker thread panics (e.g. on `i32` overflow in a debug build).
fn parallel_sum(data: Vec<i32>, num_threads: usize) -> i32 {
    let len = data.len();
    if len == 0 {
        return 0;
    }
    let num_threads = num_threads.max(1);
    // Ceiling division written so it cannot overflow for huge `num_threads`;
    // always >= 1 because len >= 1.
    let chunk_size = (len - 1) / num_threads + 1;
    let data = Arc::new(data);
    let handles: Vec<_> = (0..len)
        .step_by(chunk_size)
        .map(|start| {
            let data = Arc::clone(&data);
            thread::spawn(move || {
                let end = (start + chunk_size).min(data.len());
                data[start..end].iter().sum::<i32>()
            })
        })
        .collect();
    handles
        .into_iter()
        .map(|h| h.join().expect("parallel_sum worker thread panicked"))
        .sum()
}
/// Sums 1..=1000 across four threads with `parallel_sum` and prints the result.
fn demonstrate_parallel_sum() {
    let numbers = (1..=1000).collect::<Vec<i32>>();
    println!("并行求和结果: {}", parallel_sum(numbers, 4));
}
fn main() {
println!("=== 基础线程 ===");
basic_threading();
println!("\n=== 线程池 ===");
demonstrate_thread_pool();
println!("\n=== 共享状态 ===");
demonstrate_shared_state();
println!("\n=== 消息传递 ===");
demonstrate_message_passing();
println!("\n=== 生产者消费者 ===");
demonstrate_producer_consumer();
println!("\n=== 并行求和 ===");
demonstrate_parallel_sum();
}
并发安全的数据结构
rust
use std::sync::RwLock;
/// A `HashMap` guarded by an `RwLock`: many concurrent readers or one writer.
struct ConcurrentMap<K, V> {
    data: Arc<RwLock<std::collections::HashMap<K, V>>>,
}

impl<K, V> ConcurrentMap<K, V>
where
    K: Eq + std::hash::Hash,
{
    /// Creates an empty map.
    fn new() -> Self {
        let inner = std::collections::HashMap::new();
        ConcurrentMap {
            data: Arc::new(RwLock::new(inner)),
        }
    }

    /// Stores `value` under `key`, replacing any previous value, while
    /// holding the write lock.
    fn insert(&self, key: K, value: V) {
        self.data.write().unwrap().insert(key, value);
    }

    /// Returns a clone of the value stored under `key`, if any, read under
    /// the shared read lock.
    fn get(&self, key: &K) -> Option<V>
    where
        V: Clone,
    {
        self.data.read().unwrap().get(key).cloned()
    }
}
总结
Rust 的并发模型通过所有权和类型系统(`Send`/`Sync` trait)在编译期排除了数据竞争。配合 Arc、Mutex、通道等工具,可以安全高效地编写并发程序;但死锁、条件变量的丢失唤醒等逻辑错误不在编译器的检查范围内,仍需开发者自行避免。