线程池学习(六)实现工作窃取线程池(WorkStealingThreadPool)

WorkStealingPool核心思想

一句话:每个线程有自己的任务队列,空闲时偷别人的任务

复制代码
线程1队列:[任务1.1, 任务1.2, 任务1.3]
线程2队列:[任务2.1, 任务2.2]
线程3队列:[任务3.1]

线程3先做完 → 偷线程1的任务1.3 → 继续工作

为什么需要WorkStealingPool?

传统线程池的问题:

复制代码
一个全局队列 + 多个线程
↓
所有线程竞争同一个锁
↓
锁竞争成为性能瓶颈

WorkStealingPool的解决方案:

复制代码
每个线程有自己的队列
↓
大部分时间操作自己的队列(无竞争)
↓
只有空闲时才偷别人队列(偶尔竞争)

需求:

工作窃取算法:WorkStealingPool采用了工作窃取算法,具体来说就是当某个线程执行完自己队列中的任务后,会从其他线程的队列中"偷取"任务来执行。这种算法可以提高线程利用率,减少线程之间的竞争,以及减少线程的等待时间。

WorkStealingPool可以设定多个工作线程,每个工作线程都有一个自己的任务队列,每个线程在执行任务时会首先从自己的队列中获取任务,如果自己队列为空,则从其他线程的队列中获取任务。这种设计可以充分发挥多核处理器的并行能力,提高整体的任务处理效率。

syncqueue.hpp

cpp 复制代码
#ifndef WORKSTEALSYNCQUEUE.HPP
#define WORKSTEALSYNCQUEUE.HPP
#include<iostream>
using namespace std;
#include<mutex>
#include<queue>
#include<vector>
#include<condition_variable>
#include<atomic>
#include<thread>
#include<functional>
#include<list>
template<typename T>
class syncqueue{    
private:
    deque<T>tasks;
    mutable mutex mtx;
    condition_variable notempty;
  
public:
    syncqueue()=default;
    syncqueue(const syncqueue&)=delete;
    syncqueue& operator=(const syncqueue&)=delete;
    //本地线程用
    void pushback(T task){//从尾部放入任务
        lock_guard<mutex>lock(mtx);
        tasks.push_back(move(task));
        notempty.notify_one();
    }
    bool popback(T&task){//从尾部取出任务
        lock_guard<mutex>lock(mtx);
        if(tasks.empty()){
            return false;
        }
        task=move(tasks.back());
        tasks.pop_back();
        return true;
    }
    //窃取操作(其他线程用)
    bool stealfront(T&task){
        lock_guard<mutex>lock(mtx);
        if(tasks.empty()){
            return false;
        }
        task=move(tasks.front());
        tasks.pop_front();
        return true;
    }
    
    bool trystealfront(T&task){//尝试从头部窃取 非阻塞
        lock_guard<mutex>lock(mtx);
        if(tasks.empty()){
            return false;
        }
        task=move(tasks.front());
        tasks.pop_front();
        return true;
    }

    void empty()const{
        lock_guard<mutex>lock(mtx);
        return tasks.empty();
    }
    size_t getsize()const{
        lock_guard<mutex>lock(mtx);
        return tasks.size();
    }
    void waitfortask(){
        unique_lock<mutex>lock(mtx);
        notempty.wait(lock,[this](){return !tasks.empty();});
    }
};
#endif

WorkStealingQueue的独特设计

1. 为什么成员变量这么少?

复制代码
// WorkStealingQueue只有:
deque<T> tasks;          // 双端队列
mutex mtx;              // 一个互斥锁
condition_variable notempty; // 一个条件变量

// 相比FixedThreadPool的SyncQueue缺少:
// - condition_variable notfull (不需要队列满控制)
// - int m_maxSize (不需要队列大小限制)
// - bool m_needStop (停止标志可以合并)

原因:工作窃取队列设计哲学不同:

  • 每个线程有自己的队列 → 不会出现全局队列那种积压

  • 本地操作多,窃取操作少 → 竞争少,可以简化

  • 任务窃取是"尽力而为" → 偷不到就等,不需要复杂状态

2. 为什么用deque而不用queuelist

复制代码
deque<T> tasks;  // ✅ 双端队列,两端都能快速操作

// 本地线程操作尾部(LIFO,缓存友好):
push_back() / pop_back()

// 其他线程窃取头部(FIFO,公平):
pop_front()

3. 为什么禁用拷贝构造函数?

复制代码
syncqueue(const syncqueue&)=delete;  // ❌ 禁止拷贝
syncqueue& operator=(const syncqueue&)=delete;

原因:

  1. 队列包含互斥锁:锁不能拷贝

  2. 队列包含条件变量:条件变量不能拷贝

  3. 语义上不应拷贝:队列是资源,应该移动而不是拷贝

4. 可以移动吗?应该允许!

复制代码
// 应该添加移动构造和移动赋值
syncqueue(syncqueue&&) = default;           // ✅ 允许移动构造
syncqueue& operator=(syncqueue&&) = default; // ✅ 允许移动赋值

workstealingpool

cpp 复制代码
// WorkStealingThreadPool.h
#ifndef WORK_STEALING_THREAD_POOL_H
#define WORK_STEALING_THREAD_POOL_H

#include "WorkStealingQueue.h"
#include <vector>
#include <thread>
#include <atomic>
#include <random>
#include <functional>
#include <future>
#include <iostream>

class WorkStealingThreadPool {
public:
    using Task = std::function<void()>;
    
private:
    // 每个线程一个队列
    std::vector<std::unique_ptr<WorkStealingQueue<Task>>> threadQueues_;
    
    // 工作线程
    std::vector<std::thread> workers_;
    
    // 控制标志
    std::atomic<bool> running_;
    std::once_flag stop_flag_;
    
    // 随机数生成器(用于随机选择偷取目标)
    static thread_local std::mt19937 random_generator_;
    
    // 获取当前线程的队列索引
    static thread_local size_t thread_index_;
    
    // 工作线程主函数
    void workerMain(size_t index) {
        thread_index_ = index;
        
        while (running_) {
            Task task;
            
            // 第一步:从自己的队列取任务
            if (threadQueues_[index]->popBack(task)) {
                task();
                continue;
            }
            
            // 第二步:自己的队列空,尝试窃取
            bool stolen = false;
            size_t queue_count = threadQueues_.size();
            
            // 随机顺序尝试窃取
            std::vector<size_t> steal_order(queue_count);
            for (size_t i = 0; i < queue_count; i++) {
                steal_order[i] = i;
            }
            std::shuffle(steal_order.begin(), steal_order.end(), random_generator_);
            
            for (size_t i : steal_order) {
                if (i == index) continue;  // 不偷自己
                
                if (threadQueues_[i]->stealFront(task)) {
                    stolen = true;
                    break;
                }
            }
            
            // 第三步:如果窃取成功,执行任务
            if (stolen) {
                task();
                continue;
            }
            
            // 第四步:所有队列都空,等待一会儿
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
    }
    
    // 停止所有线程
    void stopAll() {
        running_ = false;
        
        for (auto& thread : workers_) {
            if (thread.joinable()) {
                thread.join();
            }
        }
        
        workers_.clear();
        threadQueues_.clear();
    }
    
public:
    // 构造函数
    WorkStealingThreadPool(size_t thread_count = std::thread::hardware_concurrency()) 
        : running_(true) {
        
        if (thread_count == 0) {
            thread_count = std::thread::hardware_concurrency();
        }
        
        // 为每个线程创建队列
        threadQueues_.reserve(thread_count);
        for (size_t i = 0; i < thread_count; i++) {
            threadQueues_.emplace_back(std::make_unique<WorkStealingQueue<Task>>());
        }
        
        // 创建工作线程
        workers_.reserve(thread_count);
        for (size_t i = 0; i < thread_count; i++) {
            workers_.emplace_back(&WorkStealingThreadPool::workerMain, this, i);
        }
        
        std::cout << "[WorkStealingPool] 启动 " << thread_count 
                  << " 个工作线程" << std::endl;
    }
    
    // 析构函数
    ~WorkStealingThreadPool() {
        stop();
    }
    
    // 禁止拷贝
    WorkStealingThreadPool(const WorkStealingThreadPool&) = delete;
    WorkStealingThreadPool& operator=(const WorkStealingThreadPool&) = delete;
    
    // 停止线程池
    void stop() {
        std::call_once(stop_flag_, [this]() {
            stopAll();
            std::cout << "[WorkStealingPool] 已停止" << std::endl;
        });
    }
    
    // ========== 任务提交接口 ==========
    
    // 提交任务到随机队列
    template<typename Func, typename... Args>
    void execute(Func&& func, Args&&... args) {
        if (!running_) {
            throw std::runtime_error("线程池已停止");
        }
        
        // 包装任务
        auto task = std::bind(std::forward<Func>(func), 
                              std::forward<Args>(args)...);
        
        // 随机选择一个队列放入
        static std::random_device rd;
        static std::mt19937 gen(rd());
        std::uniform_int_distribution<size_t> dist(0, threadQueues_.size() - 1);
        
        size_t index = dist(gen);
        threadQueues_[index]->pushBack(task);
    }
    
    // 提交任务到指定线程的队列
    template<typename Func, typename... Args>
    void executeToThread(size_t thread_index, Func&& func, Args&&... args) {
        if (!running_) {
            throw std::runtime_error("线程池已停止");
        }
        
        if (thread_index >= threadQueues_.size()) {
            throw std::out_of_range("线程索引超出范围");
        }
        
        auto task = std::bind(std::forward<Func>(func), 
                              std::forward<Args>(args)...);
        
        threadQueues_[thread_index]->pushBack(task);
    }
    
    // 提交任务(有返回值)
    template<typename Func, typename... Args>
    auto submit(Func&& func, Args&&... args) 
        -> std::future<decltype(func(args...))> {
        
        using ReturnType = decltype(func(args...));
        
        // 创建packaged_task
        auto task = std::make_shared<std::packaged_task<ReturnType()>>(
            std::bind(std::forward<Func>(func), 
                      std::forward<Args>(args)...)
        );
        
        // 获取future
        std::future<ReturnType> result = task->get_future();
        
        // 随机选择队列
        static std::random_device rd;
        static std::mt19937 gen(rd());
        std::uniform_int_distribution<size_t> dist(0, threadQueues_.size() - 1);
        
        size_t index = dist(gen);
        
        // 包装成void()函数
        threadQueues_[index]->pushBack([task]() {
            (*task)();
        });
        
        return result;
    }
    
    // ========== 状态查询 ==========
    
    // 获取线程数量
    size_t threadCount() const {
        return workers_.size();
    }
    
    // 获取各队列大小
    std::vector<size_t> queueSizes() const {
        std::vector<size_t> sizes;
        sizes.reserve(threadQueues_.size());
        
        for (const auto& queue : threadQueues_) {
            sizes.push_back(queue->size());
        }
        
        return sizes;
    }
    
    // 总任务数
    size_t totalTaskCount() const {
        size_t total = 0;
        for (const auto& queue : threadQueues_) {
            total += queue->size();
        }
        return total;
    }
};

// 初始化线程局部变量
thread_local std::mt19937 WorkStealingThreadPool::random_generator_(std::random_device{}());
thread_local size_t WorkStealingThreadPool::thread_index_ = 0;

#endif

WorkStealingPool核心要点

1. 双端队列策略

复制代码
// 本地线程操作尾部(LIFO)
queue_.push_back(task);  // 放任务
queue_.pop_back();       // 取任务(最近放的,缓存热)

// 其他线程窃取头部(FIFO)  
queue_.pop_front();      // 偷任务(最老的,公平)

2. 随机窃取算法

复制代码
// 随机选择偷窃目标,避免所有线程偷同一个队列
std::shuffle(steal_order.begin(), steal_order.end(), random_generator_);

3. 线程局部存储

复制代码
// 每个线程记住自己的队列索引
static thread_local size_t thread_index_;

适用场景

WorkStealingPool适用于以下场景:

  1. 任务分解型应用:当一个任务需要被分解成多个子任务进行并行处理时,WorkStealingPool可以自动管理任务的分配和调度,充分利用多核处理器的并行能力,提高任务处理效率。例如,图像处理、数据处理、并行排序等。

  2. 递归型任务:对于递归型的任务,WorkStealingPool能够适应任务的动态变化,根据需要创建和调度子任务,以实现更高效的递归执行。例如,斐波那契数列计算、归并排序等。

  3. 高吞吐量任务:WorkStealingPool的工作窃取算法可以减少线程之间的竞争,并且能够在任务队列为空时从其他线程窃取任务,从而减少线程的等待时间,提高整体的任务处理吞吐量。适用于需要高吞吐量的任务场景。

  4. CPU密集型任务:对于需要大量的CPU计算而没有I/O阻塞的任务,使用WorkStealingPool可以更好地充分利用CPU核心,并且可以根据需要增加或减少线程数量,以适应任务的计算量。

需要注意的是,WorkStealingPool在任务数较少或任务之间存在I/O等阻塞时可能不如其他类型的线程池效果好,因为工作窃取算法适用于CPU密集型任务。在实际应用中,根据具体情况选择合适的线程池类型和参数才能达到最佳的性能和效果。

相关推荐
一条咸鱼_SaltyFish21 小时前
[Day10] contract-management初期开发避坑指南:合同模块 DDD 架构规划的教训与调整
开发语言·经验分享·微服务·架构·bug·开源软件·ai编程
额呃呃21 小时前
STL内存分配器
开发语言·c++
七点半77021 小时前
c++基本内容
开发语言·c++·算法
嵌入式进阶行者21 小时前
【算法】基于滑动窗口的区间问题求解算法与实例:华为OD机考双机位A卷 - 最长的顺子
开发语言·c++·算法
No0d1es21 小时前
2025年12月 GESP CCF编程能力等级认证Python三级真题
开发语言·php
lalala_lulu1 天前
什么是事务,事务有什么特性?
java·开发语言·数据库
CCPC不拿奖不改名1 天前
python基础:python语言中的函数与模块+面试习题
开发语言·python·面试·职场和发展·蓝桥杯
毕设源码-朱学姐1 天前
【开题答辩全过程】以 基于Python语言的疫情数据可视化系统为例,包含答辩的问题和答案
开发语言·python·信息可视化
哥只是传说中的小白1 天前
Nano Banana Pro高并发接入Grsai Api实战!0.09/张无限批量生成(附接入实战+开源工具)
开发语言·数据库·ai作画·开源·aigc·php·api