并发编程基础

Rust 并发编程基础

概述

Rust 的所有权系统使得并发编程更加安全。编译器可以在编译时防止数据竞争,这是 Rust 的一大优势。

线程基础

Rust 使用 1:1 线程模型,每个语言线程对应一个操作系统线程。

简单示例

rust 复制代码
use std::thread;
use std::time::Duration;

fn basic_threading() {
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("子线程: {}", i);
            thread::sleep(Duration::from_millis(1));
        }
    });
    
    for i in 1..5 {
        println!("主线程: {}", i);
        thread::sleep(Duration::from_millis(1));
    }
    
    handle.join().unwrap();
}

复杂案例:实现一个并发任务调度器

rust 复制代码
use std::sync::{Arc, Mutex, Condvar};
use std::sync::mpsc::{self, Sender, Receiver};
use std::thread;
use std::time::Duration;
use std::collections::VecDeque;

// 任务类型
type Task = Box<dyn FnOnce() + Send + 'static>;

// 工作线程状态
#[derive(Debug, Clone, Copy, PartialEq)]
enum WorkerState {
    Idle,
    Busy,
    Stopped,
}

// 工作线程
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(
        id: usize,
        receiver: Arc<Mutex<Receiver<Task>>>,
        state: Arc<Mutex<WorkerState>>,
        condvar: Arc<Condvar>,
    ) -> Self {
        let thread = thread::spawn(move || {
            loop {
                let task = {
                    let receiver = receiver.lock().unwrap();
                    receiver.recv()
                };
                
                match task {
                    Ok(task) => {
                        {
                            let mut s = state.lock().unwrap();
                            *s = WorkerState::Busy;
                        }
                        
                        println!("工作线程 {} 正在执行任务", id);
                        task();
                        
                        {
                            let mut s = state.lock().unwrap();
                            *s = WorkerState::Idle;
                        }
                        condvar.notify_all();
                    }
                    Err(_) => {
                        println!("工作线程 {} 停止", id);
                        let mut s = state.lock().unwrap();
                        *s = WorkerState::Stopped;
                        break;
                    }
                }
            }
        });
        
        Worker {
            id,
            thread: Some(thread),
        }
    }
}

// 线程池
struct ThreadPool {
    workers: Vec<Worker>,
    sender: Sender<Task>,
    worker_states: Vec<Arc<Mutex<WorkerState>>>,
    condvar: Arc<Condvar>,
}

impl ThreadPool {
    fn new(size: usize) -> Self {
        assert!(size > 0);
        
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let condvar = Arc::new(Condvar::new());
        
        let mut workers = Vec::with_capacity(size);
        let mut worker_states = Vec::with_capacity(size);
        
        for id in 0..size {
            let state = Arc::new(Mutex::new(WorkerState::Idle));
            worker_states.push(state.clone());
            
            workers.push(Worker::new(
                id,
                Arc::clone(&receiver),
                state,
                Arc::clone(&condvar),
            ));
        }
        
        ThreadPool {
            workers,
            sender,
            worker_states,
            condvar,
        }
    }
    
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let task = Box::new(f);
        self.sender.send(task).unwrap();
    }
    
    fn active_count(&self) -> usize {
        self.worker_states
            .iter()
            .filter(|state| {
                *state.lock().unwrap() == WorkerState::Busy
            })
            .count()
    }
    
    fn wait_completion(&self) {
        loop {
            let all_idle = self.worker_states
                .iter()
                .all(|state| {
                    let s = state.lock().unwrap();
                    *s == WorkerState::Idle || *s == WorkerState::Stopped
                });
            
            if all_idle {
                break;
            }
            
            let state = self.worker_states[0].lock().unwrap();
            let _guard = self.condvar.wait(state).unwrap();
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.clone());
        
        for worker in &mut self.workers {
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

// 演示线程池使用
fn demonstrate_thread_pool() {
    let pool = ThreadPool::new(4);
    
    println!("提交 10 个任务到线程池");
    
    for i in 0..10 {
        pool.execute(move || {
            println!("任务 {} 开始执行", i);
            thread::sleep(Duration::from_millis(500));
            println!("任务 {} 完成", i);
        });
    }
    
    println!("等待所有任务完成...");
    pool.wait_completion();
    println!("所有任务已完成");
}

// 使用 Arc 和 Mutex 实现共享状态
struct SharedCounter {
    count: Arc<Mutex<i32>>,
}

impl SharedCounter {
    fn new() -> Self {
        SharedCounter {
            count: Arc::new(Mutex::new(0)),
        }
    }
    
    fn increment(&self) {
        let mut count = self.count.lock().unwrap();
        *count += 1;
    }
    
    fn get(&self) -> i32 {
        *self.count.lock().unwrap()
    }
    
    fn clone_counter(&self) -> Self {
        SharedCounter {
            count: Arc::clone(&self.count),
        }
    }
}

fn demonstrate_shared_state() {
    let counter = SharedCounter::new();
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter_clone = counter.clone_counter();
        let handle = thread::spawn(move || {
            for _ in 0..100 {
                counter_clone.increment();
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("最终计数: {}", counter.get());
}

// 消息传递并发
fn demonstrate_message_passing() {
    let (tx, rx) = mpsc::channel();
    
    // 创建多个发送者
    for i in 0..5 {
        let tx_clone = tx.clone();
        thread::spawn(move || {
            for j in 0..10 {
                tx_clone.send(format!("线程 {} 发送消息 {}", i, j)).unwrap();
                thread::sleep(Duration::from_millis(100));
            }
        });
    }
    
    drop(tx); // 关闭原始发送者
    
    // 接收消息
    let mut count = 0;
    for received in rx {
        println!("收到: {}", received);
        count += 1;
    }
    
    println!("总共收到 {} 条消息", count);
}

// 生产者-消费者模式
struct ProducerConsumer {
    queue: Arc<Mutex<VecDeque<i32>>>,
    condvar: Arc<Condvar>,
    max_size: usize,
}

impl ProducerConsumer {
    fn new(max_size: usize) -> Self {
        ProducerConsumer {
            queue: Arc::new(Mutex::new(VecDeque::new())),
            condvar: Arc::new(Condvar::new()),
            max_size,
        }
    }
    
    fn produce(&self, item: i32) {
        let mut queue = self.queue.lock().unwrap();
        
        while queue.len() >= self.max_size {
            queue = self.condvar.wait(queue).unwrap();
        }
        
        queue.push_back(item);
        println!("生产: {}, 队列大小: {}", item, queue.len());
        self.condvar.notify_all();
    }
    
    fn consume(&self) -> Option<i32> {
        let mut queue = self.queue.lock().unwrap();
        
        while queue.is_empty() {
            queue = self.condvar.wait(queue).unwrap();
        }
        
        let item = queue.pop_front();
        if let Some(i) = item {
            println!("消费: {}, 队列大小: {}", i, queue.len());
        }
        self.condvar.notify_all();
        item
    }
    
    fn clone_pc(&self) -> Self {
        ProducerConsumer {
            queue: Arc::clone(&self.queue),
            condvar: Arc::clone(&self.condvar),
            max_size: self.max_size,
        }
    }
}

fn demonstrate_producer_consumer() {
    let pc = ProducerConsumer::new(5);
    
    // 生产者线程
    let pc_producer = pc.clone_pc();
    let producer = thread::spawn(move || {
        for i in 0..20 {
            pc_producer.produce(i);
            thread::sleep(Duration::from_millis(50));
        }
    });
    
    // 消费者线程
    let mut consumers = vec![];
    for _ in 0..3 {
        let pc_consumer = pc.clone_pc();
        let consumer = thread::spawn(move || {
            for _ in 0..7 {
                pc_consumer.consume();
                thread::sleep(Duration::from_millis(150));
            }
        });
        consumers.push(consumer);
    }
    
    producer.join().unwrap();
    for consumer in consumers {
        consumer.join().unwrap();
    }
}

// 并行计算示例:并行求和
fn parallel_sum(data: Vec<i32>, num_threads: usize) -> i32 {
    let chunk_size = (data.len() + num_threads - 1) / num_threads;
    let data = Arc::new(data);
    let mut handles = vec![];
    
    for i in 0..num_threads {
        let data_clone = Arc::clone(&data);
        let handle = thread::spawn(move || {
            let start = i * chunk_size;
            let end = ((i + 1) * chunk_size).min(data_clone.len());
            
            if start >= data_clone.len() {
                return 0;
            }
            
            data_clone[start..end].iter().sum::<i32>()
        });
        handles.push(handle);
    }
    
    handles.into_iter()
        .map(|h| h.join().unwrap())
        .sum()
}

fn demonstrate_parallel_sum() {
    let data: Vec<i32> = (1..=1000).collect();
    let sum = parallel_sum(data, 4);
    println!("并行求和结果: {}", sum);
}

fn main() {
    println!("=== 基础线程 ===");
    basic_threading();
    
    println!("\n=== 线程池 ===");
    demonstrate_thread_pool();
    
    println!("\n=== 共享状态 ===");
    demonstrate_shared_state();
    
    println!("\n=== 消息传递 ===");
    demonstrate_message_passing();
    
    println!("\n=== 生产者消费者 ===");
    demonstrate_producer_consumer();
    
    println!("\n=== 并行求和 ===");
    demonstrate_parallel_sum();
}

并发安全的数据结构

rust 复制代码
use std::sync::RwLock;

struct ConcurrentMap<K, V> {
    data: Arc<RwLock<std::collections::HashMap<K, V>>>,
}

impl<K: Eq + std::hash::Hash, V> ConcurrentMap<K, V> {
    fn new() -> Self {
        ConcurrentMap {
            data: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }
    
    fn insert(&self, key: K, value: V) {
        let mut map = self.data.write().unwrap();
        map.insert(key, value);
    }
    
    fn get(&self, key: &K) -> Option<V> 
    where 
        V: Clone 
    {
        let map = self.data.read().unwrap();
        map.get(key).cloned()
    }
}

总结

Rust 的并发模型通过所有权和类型系统保证了线程安全。使用 Arc、Mutex、通道等工具,可以安全高效地编写并发程序。

相关推荐
碧海银沙音频科技研究院5 小时前
ES7243E ADC模拟音频转i2S到 BES I2S1 Master输出播放到SPK精准分析
人工智能·算法·音视频
百度智能云5 小时前
MySQL最怕的IN大列表,被百度智能云GaiaDB治好了!查询速度提升60倍!
算法
信奥卷王5 小时前
[GESP202506 五级] 奖品兑换
数据结构·算法
奶茶树5 小时前
【数据结构】二叉搜索树
数据结构·算法
晨曦(zxr_0102)6 小时前
CSP-X 2024 复赛编程题全解(B4104+B4105+B4106+B4107)
数据结构·c++·算法
ai安歌6 小时前
【Rust编程:从新手到大师】 Rust 控制流深度详解
开发语言·算法·rust
Shinom1ya_6 小时前
算法 day 36
算法
·白小白6 小时前
力扣(LeetCode) ——15.三数之和(C++)
c++·算法·leetcode
海琴烟Sunshine6 小时前
leetcode 268. 丢失的数字 python
python·算法·leetcode