Rust并发编程实战技巧

Rust的并发模型是其最强大的特性之一。通过所有权系统和类型系统,Rust能够在编译时防止数据竞争,让开发者能够编写安全且高效的并发代码。

第二十二章:线程基础

22.1 创建和管理线程

rust 复制代码
use std::thread;
use std::time::Duration;

fn main() {
    // 创建新线程
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("线程中的数字: {}", i);
            thread::sleep(Duration::from_millis(100));
        }
    });

    // 主线程继续执行
    for i in 1..5 {
        println!("主线程中的数字: {}", i);
        thread::sleep(Duration::from_millis(200));
    }

    // 等待子线程结束
    handle.join().unwrap();
    println!("所有线程执行完成!");

    // 演示线程移动
    let data = vec![1, 2, 3, 4, 5];
    
    let handle = thread::spawn(move || {
        println!("在线程中使用数据: {:?}", data);
        // data 的所有权被移动到线程中
    });
    
    handle.join().unwrap();
    // println!("{:?}", data); // 错误!data 的所有权已经移动
}

22.2 线程间通信:通道

rust 复制代码
use std::sync::mpsc; // 多生产者,单消费者
use std::thread;
use std::time::Duration;

fn main() {
    // 创建通道
    let (tx, rx) = mpsc::channel();
    
    // 克隆发送端用于多个生产者
    let tx1 = tx.clone();
    let tx2 = tx.clone();
    
    // 生产者线程 1
    thread::spawn(move || {
        let messages = vec![
            String::from("你好"),
            String::from("从"),
            String::from("线程1"),
        ];
        
        for msg in messages {
            tx1.send(msg).unwrap();
            thread::sleep(Duration::from_millis(100));
        }
    });
    
    // 生产者线程 2
    thread::spawn(move || {
        let messages = vec![
            String::from("更多"),
            String::from("消息"),
            String::from("从线程2"),
        ];
        
        for msg in messages {
            tx2.send(msg).unwrap();
            thread::sleep(Duration::from_millis(150));
        }
    });
    
    // 在主线程中接收消息
    drop(tx); // 丢弃原始的发送端
    
    for received in rx {
        println!("收到: {}", received);
    }
    
    println!("所有消息接收完成!");
}

// 更复杂的通道使用
fn channel_example() {
    let (tx, rx) = mpsc::channel();
    
    thread::spawn(move || {
        for i in 0..10 {
            let message = format!("消息-{}", i);
            tx.send(message).unwrap();
        }
    });
    
    // 使用迭代器接收
    for message in rx {
        println!("接收: {}", message);
    }
}

// 带超时的通道操作
fn channel_with_timeout() {
    use std::sync::mpsc::{Receiver, RecvTimeoutError};
    
    let (tx, rx) = mpsc::channel();
    
    thread::spawn(move || {
        thread::sleep(Duration::from_secs(2));
        tx.send("延迟的消息".to_string()).unwrap();
    });
    
    match rx.recv_timeout(Duration::from_secs(1)) {
        Ok(msg) => println!("收到: {}", msg),
        Err(RecvTimeoutError::Timeout) => println!("接收超时"),
        Err(RecvTimeoutError::Disconnected) => println!("发送端已断开"),
    }
}

第二十三章:共享状态并发

23.1 Mutex:互斥锁

rust 复制代码
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // 使用 Arc 实现多线程共享
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终结果: {}", *counter.lock().unwrap());

    // 更复杂的 Mutex 使用
    complex_mutex_example();
}

fn complex_mutex_example() {
    #[derive(Debug)]
    struct BankAccount {
        balance: Mutex<f64>,
        name: String,
    }

    impl BankAccount {
        fn new(name: &str, initial_balance: f64) -> Self {
            Self {
                balance: Mutex::new(initial_balance),
                name: name.to_string(),
            }
        }

        fn deposit(&self, amount: f64) -> Result<(), String> {
            let mut balance = self.balance.lock().unwrap();
            if amount < 0.0 {
                return Err("存款金额不能为负".to_string());
            }
            *balance += amount;
            Ok(())
        }

        fn withdraw(&self, amount: f64) -> Result<f64, String> {
            let mut balance = self.balance.lock().unwrap();
            if amount < 0.0 {
                return Err("取款金额不能为负".to_string());
            }
            if *balance < amount {
                return Err("余额不足".to_string());
            }
            *balance -= amount;
            Ok(amount)
        }

        fn get_balance(&self) -> f64 {
            *self.balance.lock().unwrap()
        }

        fn transfer(&self, to: &BankAccount, amount: f64) -> Result<(), String> {
            // 注意:这可能导致死锁!实际应用中应该使用更安全的方法
            let _from_balance = self.balance.lock().unwrap();
            let _to_balance = to.balance.lock().unwrap();
            
            self.withdraw(amount)?;
            to.deposit(amount)?;
            
            Ok(())
        }
    }

    let account = Arc::new(BankAccount::new("主账户", 1000.0));
    let mut handles = vec![];

    // 多个存款线程
    for i in 0..5 {
        let account = Arc::clone(&account);
        let handle = thread::spawn(move || {
            match account.deposit((i + 1) as f64 * 100.0) {
                Ok(()) => println!("线程 {} 存款成功", i),
                Err(e) => println!("线程 {} 存款失败: {}", i, e),
            }
        });
        handles.push(handle);
    }

    // 多个取款线程
    for i in 0..3 {
        let account = Arc::clone(&account);
        let handle = thread::spawn(move || {
            match account.withdraw((i + 1) as f64 * 50.0) {
                Ok(amount) => println!("线程 {} 取款 {} 成功", i, amount),
                Err(e) => println!("线程 {} 取款失败: {}", i, e),
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终余额: {:.2}", account.get_balance());
}

23.2 RwLock:读写锁

rust 复制代码
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;

fn main() {
    let data = Arc::new(RwLock::new(vec![1, 2, 3, 4, 5]));
    let mut handles = vec![];

    // 多个读取线程
    for i in 0..3 {
        let data = Arc::clone(&data);
        let handle = thread::spawn(move || {
            for _ in 0..3 {
                let reader = data.read().unwrap();
                println!("读取线程 {}: {:?}", i, *reader);
                thread::sleep(Duration::from_millis(100));
                // 读锁在这里自动释放
            }
        });
        handles.push(handle);
    }

    // 写入线程
    let data_write = Arc::clone(&data);
    let write_handle = thread::spawn(move || {
        for i in 0..2 {
            {
                let mut writer = data_write.write().unwrap();
                println!("写入线程: 修改数据");
                writer.push(writer.len() as i32 + 1);
            } // 写锁在这里释放,允许读取线程继续
            thread::sleep(Duration::from_millis(300));
        }
    });
    handles.push(write_handle);

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终数据: {:?}", data.read().unwrap());
}

// 缓存系统示例
struct Cache<K, V> 
where 
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    data: Arc<RwLock<std::collections::HashMap<K, V>>>,
}

impl<K, V> Cache<K, V>
where 
    K: Eq + std::hash::Hash + Clone,
    V: Clone,
{
    fn new() -> Self {
        Self {
            data: Arc::new(RwLock::new(std::collections::HashMap::new())),
        }
    }

    fn get(&self, key: &K) -> Option<V> {
        let reader = self.data.read().unwrap();
        reader.get(key).cloned()
    }

    fn set(&self, key: K, value: V) {
        let mut writer = self.data.write().unwrap();
        writer.insert(key, value);
    }

    fn remove(&self, key: &K) -> Option<V> {
        let mut writer = self.data.write().unwrap();
        writer.remove(key)
    }

    fn clear(&self) {
        let mut writer = self.data.write().unwrap();
        writer.clear();
    }

    fn len(&self) -> usize {
        let reader = self.data.read().unwrap();
        reader.len()
    }
}

fn cache_example() {
    let cache = Arc::new(Cache::new());
    let mut handles = vec![];

    // 多个读取线程
    for i in 0..5 {
        let cache = Arc::clone(&cache);
        let handle = thread::spawn(move || {
            for j in 0..10 {
                let key = format!("key_{}_{}", i, j);
                if let Some(value) = cache.get(&key) {
                    println!("线程 {} 找到键 {}: {}", i, key, value);
                } else {
                    // 模拟缓存未命中,然后设置值
                    let value = format!("value_{}_{}", i, j);
                    cache.set(key.clone(), value.clone());
                    println!("线程 {} 设置键 {}: {}", i, key, value);
                }
                thread::sleep(Duration::from_millis(50));
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("缓存最终大小: {}", cache.len());
}

23.3 Atomic 类型

rust 复制代码
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // Atomic 类型示例
    let atomic_counter = Arc::new(AtomicUsize::new(0));
    let stop_flag = Arc::new(AtomicBool::new(false));
    
    let mut handles = vec![];

    // 启动多个增加计数的线程
    for i in 0..5 {
        let counter = Arc::clone(&atomic_counter);
        let stop = Arc::clone(&stop_flag);
        
        let handle = thread::spawn(move || {
            while !stop.load(Ordering::Relaxed) {
                let current = counter.fetch_add(1, Ordering::SeqCst);
                println!("线程 {} 增加计数到: {}", i, current + 1);
                thread::sleep(std::time::Duration::from_millis(100));
            }
            println!("线程 {} 停止", i);
        });
        handles.push(handle);
    }

    // 让线程运行一段时间
    thread::sleep(std::time::Duration::from_secs(2));
    
    // 设置停止标志
    stop_flag.store(true, Ordering::SeqCst);

    for handle in handles {
        handle.join().unwrap();
    }

    println!("最终计数: {}", atomic_counter.load(Ordering::SeqCst));

    // 更复杂的原子操作示例
    atomic_operations_example();
}

fn atomic_operations_example() {
    let shared_value = Arc::new(AtomicI32::new(0));
    let mut handles = vec![];

    // 使用比较并交换 (CAS) 操作
    for i in 0..10 {
        let shared_value = Arc::clone(&shared_value);
        let handle = thread::spawn(move || {
            loop {
                let current = shared_value.load(Ordering::Acquire);
                let new = current + 1;
                
                // 尝试原子性地更新值
                match shared_value.compare_exchange(
                    current, 
                    new, 
                    Ordering::SeqCst, 
                    Ordering::Relaxed
                ) {
                    Ok(_) => {
                        println!("线程 {} 成功更新: {} -> {}", i, current, new);
                        break;
                    }
                    Err(_) => {
                        println!("线程 {} 更新冲突,重试", i);
                        thread::sleep(std::time::Duration::from_millis(10));
                    }
                }
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("CAS 操作后的最终值: {}", shared_value.load(Ordering::SeqCst));
}

// 无锁栈实现
use std::ptr;

struct Node<T> {
    value: T,
    next: *mut Node<T>,
}

pub struct LockFreeStack<T> {
    head: AtomicPtr<Node<T>>,
}

impl<T> LockFreeStack<T> {
    pub fn new() -> Self {
        Self {
            head: AtomicPtr::new(ptr::null_mut()),
        }
    }

    pub fn push(&self, value: T) {
        let new_node = Box::into_raw(Box::new(Node {
            value,
            next: ptr::null_mut(),
        }));

        loop {
            let current_head = self.head.load(Ordering::Acquire);
            unsafe {
                (*new_node).next = current_head;
            }
            
            if self.head.compare_exchange(
                current_head, 
                new_node, 
                Ordering::Release, 
                Ordering::Relaxed
            ).is_ok() {
                break;
            }
        }
    }

    pub fn pop(&self) -> Option<T> {
        loop {
            let current_head = self.head.load(Ordering::Acquire);
            if current_head.is_null() {
                return None;
            }

            let next = unsafe { (*current_head).next };
            
            if self.head.compare_exchange(
                current_head, 
                next, 
                Ordering::Release, 
                Ordering::Relaxed
            ).is_ok() {
                let node = unsafe { Box::from_raw(current_head) };
                return Some(node.value);
            }
        }
    }
}

impl<T> Drop for LockFreeStack<T> {
    fn drop(&mut self) {
        while self.pop().is_some() {}
    }
}

fn lock_free_stack_example() {
    let stack = Arc::new(LockFreeStack::new());
    let mut handles = vec![];

    // 生产者线程
    for i in 0..3 {
        let stack = Arc::clone(&stack);
        let handle = thread::spawn(move || {
            for j in 0..5 {
                let value = format!("值_{}_{}", i, j);
                stack.push(value);
                println!("生产者 {} 推入值", i);
                thread::sleep(std::time::Duration::from_millis(50));
            }
        });
        handles.push(handle);
    }

    // 消费者线程
    for i in 0..2 {
        let stack = Arc::clone(&stack);
        let handle = thread::spawn(move || {
            for _ in 0..7 {
                if let Some(value) = stack.pop() {
                    println!("消费者 {} 弹出: {}", i, value);
                } else {
                    println!("消费者 {} 发现栈为空", i);
                }
                thread::sleep(std::time::Duration::from_millis(80));
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    // 清空栈
    while let Some(value) = stack.pop() {
        println!("剩余值: {}", value);
    }
}

第二十四章:高级并发模式

24.1 工作窃取线程池

rust 复制代码
use std::sync::{Arc, Mutex, Condvar};
use std::collections::VecDeque;
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

struct ThreadPool {
    workers: Vec<Worker>,
    sender: Option<std::sync::mpsc::Sender<Job>>,
}

impl ThreadPool {
    fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = std::sync::mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let mut workers = Vec::with_capacity(size);

        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }

        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.as_ref().unwrap().send(job).unwrap();
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        drop(self.sender.take());

        for worker in &mut self.workers {
            println!("关闭工作线程 {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
    }
}

struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<std::sync::mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let job = receiver.lock().unwrap().recv();

            match job {
                Ok(job) => {
                    println!("工作线程 {} 执行任务", id);
                    job();
                }
                Err(_) => {
                    println!("工作线程 {} 断开连接,关闭", id);
                    break;
                }
            }
        });

        Worker {
            id,
            thread: Some(thread),
        }
    }
}

fn thread_pool_example() {
    let pool = ThreadPool::new(4);

    for i in 0..8 {
        pool.execute(move || {
            println!("执行任务 {}", i);
            thread::sleep(std::time::Duration::from_secs(1));
            println!("完成任务 {}", i);
        });
    }

    thread::sleep(std::time::Duration::from_secs(5));
}

24.2 异步屏障和条件变量

rust 复制代码
use std::sync::{Arc, Barrier, Condvar, Mutex};
use std::thread;
use std::time::{Duration, Instant};

fn barrier_example() {
    let num_threads = 5;
    let barrier = Arc::new(Barrier::new(num_threads));
    let mut handles = vec![];

    for i in 0..num_threads {
        let barrier = Arc::clone(&barrier);
        let handle = thread::spawn(move || {
            println!("线程 {} 开始阶段1", i);
            thread::sleep(Duration::from_millis(i * 100));
            println!("线程 {} 完成阶段1,等待其他线程", i);
            
            // 等待所有线程到达屏障
            barrier.wait();
            
            println!("线程 {} 开始阶段2", i);
            thread::sleep(Duration::from_millis((num_threads - i) * 100));
            println!("线程 {} 完成阶段2", i);
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }
}

fn condition_variable_example() {
    #[derive(Debug)]
    struct SharedData {
        data: Vec<i32>,
        ready: bool,
    }

    let shared = Arc::new((Mutex::new(SharedData { data: vec![], ready: false }), Condvar::new()));
    let mut handles = vec![];

    // 生产者线程
    let shared_producer = Arc::clone(&shared);
    let producer_handle = thread::spawn(move || {
        let (lock, cvar) = &*shared_producer;
        
        println!("生产者: 准备数据...");
        thread::sleep(Duration::from_secs(2));
        
        {
            let mut data = lock.lock().unwrap();
            data.data = vec![1, 2, 3, 4, 5];
            data.ready = true;
            println!("生产者: 数据准备完成,通知消费者");
            cvar.notify_all();
        }
    });

    // 消费者线程
    for i in 0..3 {
        let shared_consumer = Arc::clone(&shared);
        let handle = thread::spawn(move || {
            let (lock, cvar) = &*shared_consumer;
            
            println!("消费者 {}: 等待数据...", i);
            let mut data = lock.lock().unwrap();
            
            while !data.ready {
                data = cvar.wait(data).unwrap();
            }
            
            println!("消费者 {}: 收到数据 {:?}", i, data.data);
        });
        handles.push(handle);
    }

    handles.push(producer_handle);

    for handle in handles {
        handle.join().unwrap();
    }
}

// 速率限制器
struct RateLimiter {
    last_check: Mutex<Instant>,
    interval: Duration,
    condvar: Condvar,
}

impl RateLimiter {
    fn new(interval: Duration) -> Self {
        Self {
            last_check: Mutex::new(Instant::now()),
            interval,
            condvar: Condvar::new(),
        }
    }

    fn wait(&self) {
        let mut last_check = self.last_check.lock().unwrap();
        
        loop {
            let now = Instant::now();
            let elapsed = now.duration_since(*last_check);
            
            if elapsed >= self.interval {
                *last_check = now;
                break;
            }
            
            let remaining = self.interval - elapsed;
            let (new_guard, _) = self.condvar.wait_timeout(last_check, remaining).unwrap();
            last_check = new_guard;
        }
    }
}

fn rate_limiter_example() {
    let limiter = Arc::new(RateLimiter::new(Duration::from_millis(500)));
    let mut handles = vec![];

    for i in 0..5 {
        let limiter = Arc::clone(&limiter);
        let handle = thread::spawn(move || {
            for j in 0..3 {
                limiter.wait();
                println!("线程 {} 执行操作 {}", i, j);
            }
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }
}

第二十五章:实际应用案例

25.1 并发Web服务器

rust 复制代码
use std::net::{TcpListener, TcpStream};
use std::io::prelude::*;
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use std::thread;
use std::time::Duration;

type Cache = Arc<Mutex<HashMap<String, String>>>;

struct HttpServer {
    cache: Cache,
    thread_pool: ThreadPool,
}

impl HttpServer {
    fn new() -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            thread_pool: ThreadPool::new(10),
        }
    }

    fn handle_connection(&self, mut stream: TcpStream) {
        let cache = Arc::clone(&self.cache);
        
        self.thread_pool.execute(move || {
            let mut buffer = [0; 1024];
            stream.read(&mut buffer).unwrap();

            let request = String::from_utf8_lossy(&buffer[..]);
            let request_line = request.lines().next().unwrap_or("");

            println!("收到请求: {}", request_line);

            let response = match request_line {
                "GET / HTTP/1.1" => Self::handle_root(),
                s if s.starts_with("GET /cache/") => Self::handle_cache(s, &cache),
                s if s.starts_with("POST /cache/") => Self::handle_cache_post(s, &cache, &request),
                "GET /slow HTTP/1.1" => Self::handle_slow_request(),
                _ => Self::handle_not_found(),
            };

            stream.write(response.as_bytes()).unwrap();
            stream.flush().unwrap();
        });
    }

    fn handle_root() -> String {
        format!(
            "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
            <h1>欢迎来到并发服务器</h1>\
            <p><a href='/slow'>慢请求</a></p>\
            <p><a href='/cache/test'>获取缓存</a></p>\
            <form method='POST' action='/cache/test'>\
            <input type='text' name='value' value='缓存值'>\
            <input type='submit' value='设置缓存'>\
            </form>"
        )
    }

    fn handle_cache(request_line: &str, cache: &Cache) -> String {
        let key = request_line.split_whitespace().nth(1)
            .unwrap_or("")
            .trim_start_matches("/cache/");
        
        let cache_lock = cache.lock().unwrap();
        if let Some(value) = cache_lock.get(key) {
            format!(
                "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n\
                键 '{}' 的值是: {}",
                key, value
            )
        } else {
            format!(
                "HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\n\r\n\
                键 '{}' 未找到",
                key
            )
        }
    }

    fn handle_cache_post(request_line: &str, cache: &Cache, request: &str) -> String {
        let key = request_line.split_whitespace().nth(1)
            .unwrap_or("")
            .trim_start_matches("/cache/");
        
        // 简单的表单数据解析
        let value = if let Some(body_start) = request.find("\r\n\r\n") {
            let body = &request[body_start + 4..];
            body.split('=')
                .nth(1)
                .unwrap_or("默认值")
                .split('&')
                .next()
                .unwrap_or("默认值")
                .to_string()
        } else {
            "默认值".to_string()
        };

        {
            let mut cache_lock = cache.lock().unwrap();
            cache_lock.insert(key.to_string(), value.clone());
        }

        format!(
            "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n\
            已设置键 '{}' 的值为: {}",
            key, value
        )
    }

    fn handle_slow_request() -> String {
        // 模拟慢请求
        thread::sleep(Duration::from_secs(3));
        "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n慢请求完成!".to_string()
    }

    fn handle_not_found() -> String {
        "HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\n\r\n页面未找到".to_string()
    }

    fn start(&self, address: &str) {
        let listener = TcpListener::bind(address).unwrap();
        println!("服务器运行在 {}", address);

        for stream in listener.incoming() {
            match stream {
                Ok(stream) => {
                    self.handle_connection(stream);
                }
                Err(e) => {
                    eprintln!("连接错误: {}", e);
                }
            }
        }
    }
}

fn main() {
    let server = HttpServer::new();
    server.start("127.0.0.1:8080");
}

25.2 并发数据处理管道

rust 复制代码
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

#[derive(Debug, Clone)]
struct DataItem {
    id: u64,
    value: f64,
    timestamp: u64,
}

struct DataProcessor {
    input_rx: mpsc::Receiver<DataItem>,
    output_tx: mpsc::Sender<ProcessedData>,
}

#[derive(Debug)]
struct ProcessedData {
    original_id: u64,
    normalized_value: f64,
    category: String,
    processed_at: u64,
}

impl DataProcessor {
    fn new(input_rx: mpsc::Receiver<DataItem>, output_tx: mpsc::Sender<ProcessedData>) -> Self {
        Self { input_rx, output_tx }
    }

    fn start(self, num_workers: usize) {
        let mut handles = vec![];

        for worker_id in 0..num_workers {
            let input_rx = self.input_rx.clone();
            let output_tx = self.output_tx.clone();
            
            let handle = thread::spawn(move || {
                while let Ok(item) = input_rx.recv() {
                    let processed = Self::process_item(item, worker_id);
                    if output_tx.send(processed).is_err() {
                        break;
                    }
                }
                println!("工作线程 {} 结束", worker_id);
            });
            handles.push(handle);
        }

        // 不再需要原始的接收端
        drop(self.input_rx);

        for handle in handles {
            handle.join().unwrap();
        }
    }

    fn process_item(item: DataItem, worker_id: usize) -> ProcessedData {
        // 模拟处理时间
        thread::sleep(Duration::from_millis(10));

        let normalized_value = item.value / 100.0;
        let category = if normalized_value < 0.3 {
            "低".to_string()
        } else if normalized_value < 0.7 {
            "中".to_string()
        } else {
            "高".to_string()
        };

        ProcessedData {
            original_id: item.id,
            normalized_value,
            category,
            processed_at: item.timestamp + 1,
        }
    }
}

struct DataAggregator {
    input_rx: mpsc::Receiver<ProcessedData>,
}

impl DataAggregator {
    fn new(input_rx: mpsc::Receiver<ProcessedData>) -> Self {
        Self { input_rx }
    }

    fn start(self) {
        let mut stats = std::collections::HashMap::new();
        let mut total_count = 0;

        while let Ok(processed) = self.input_rx.recv() {
            total_count += 1;
            *stats.entry(processed.category.clone()).or_insert(0) += 1;

            if total_count % 100 == 0 {
                println!("已处理 {} 个数据项", total_count);
                println!("分类统计: {:?}", stats);
            }
        }

        println!("聚合器结束,总共处理 {} 个数据项", total_count);
        println!("最终统计: {:?}", stats);
    }
}

fn data_processing_pipeline() {
    let (data_tx, data_rx) = mpsc::channel();
    let (processed_tx, processed_rx) = mpsc::channel();

    // 启动数据生成器
    let data_tx_clone = data_tx.clone();
    let generator_handle = thread::spawn(move || {
        for i in 0..1000 {
            let item = DataItem {
                id: i,
                value: (i as f64 * 0.1).sin().abs() * 100.0,
                timestamp: i,
            };
            
            if data_tx_clone.send(item).is_err() {
                break;
            }
            
            if i % 100 == 0 {
                println!("已生成 {} 个数据项", i);
            }
            
            thread::sleep(Duration::from_millis(1));
        }
        println!("数据生成完成");
    });

    // 启动处理器
    let processor = DataProcessor::new(data_rx, processed_tx);
    let processor_handle = thread::spawn(move || {
        processor.start(4); // 使用4个工作线程
    });

    // 启动聚合器
    let aggregator = DataAggregator::new(processed_rx);
    let aggregator_handle = thread::spawn(move || {
        aggregator.start();
    });

    generator_handle.join().unwrap();
    drop(data_tx); // 关闭数据通道
    
    processor_handle.join().unwrap();
    aggregator_handle.join().unwrap();

    println!("数据处理管道运行完成");
}

fn main() {
    println!("启动并发数据处理管道...");
    data_processing_pipeline();
}

并发编程最佳实践

  1. 选择合适的并发原语

    • 消息传递:使用通道进行线程间通信
    • 共享状态:使用 Arc<Mutex<T>>Arc<RwLock<T>>
    • 无锁编程:使用原子类型
  2. 避免死锁

    • 按固定顺序获取锁
    • 使用超时机制
    • 避免在持有锁时调用未知代码
  3. 性能优化

    • 减少锁的持有时间
    • 使用读写锁替代互斥锁
    • 考虑无锁数据结构
  4. 错误处理

    • 正确处理 Mutex 中毒
    • 使用 Result 类型进行错误传播
    • 实现优雅的关闭机制

Rust的并发模型提供了强大的安全保障,让开发者能够编写高效且安全的并发代码。通过合理使用这些工具和模式,你可以构建能够充分利用多核处理器优势的应用程序。

继续构建高效的并发应用!⚡


版权声明:本教程仅供学习使用,转载请注明出处。

相关推荐
Lisonseekpan2 小时前
Linux 常用命令详解与使用规则
linux·服务器·后端
Yurko132 小时前
【C语言】选择结构和循环结构的进阶
c语言·开发语言·学习
小白学大数据3 小时前
构建1688店铺商品数据集:Python爬虫数据采集与格式化实践
开发语言·爬虫·python
林太白3 小时前
rust15-菜单模块
后端·rust
大邳草民3 小时前
深入理解 Python 的“左闭右开”设计哲学
开发语言·笔记·python
调试人生的显微镜3 小时前
iOS 上架费用全解析 开发者账号、App 审核、工具使用与开心上架(Appuploader)免 Mac 成本优化指南
后端
实心儿儿3 小时前
C++ —— list
开发语言·c++
LSTM973 小时前
使用 Spire.XLS for Python 将 Excel 转换为 PDF
后端
ACGkaka_3 小时前
SpringBoot 实战(四十)集成 Statemachine
java·spring boot·后端