线程池:工作窃取线程池WorkingStealingPool

一、工作窃取线程池的概念

工作窃取是一种高效的任务调度策略,核心思想是:每个工作线程拥有自己的独立任务队列。当一个线程自己的队列为空(闲置)时,它会主动"窃取"其他忙碌线程队列中的任务,从而实现动态负载均衡,避免某些线程闲置、某些线程堆积的情况。

与传统线程池(如固定大小线程池使用单一共享队列)相比,工作窃取的优势在于:

  • 减少全局锁竞争(每个线程优先访问自己的队列)。
  • 天然支持不均衡负载(长短任务混合、递归任务等场景)。
  • 线程数量固定,资源可控,同时又能接近"按需执行"的效果。

工作窃取线程池的工作原理:

提交任务时,按一定策略(通常 round-robin 或随机)放入某个线程的私有队列。

每个工作线程循环:先处理自己的队列 → 如果自己的队列为空,就去其他线程的队列"偷"一批任务。

窃取成功后,窃取线程和被窃取线程继续各自工作,实现自动负载均衡。

二、为什么需要设计工作窃取线程池

传统线程池有两种极端:

  • FixedThreadPool(固定线程数 + 单一阻塞队列):线程少时容易排队,线程多时锁竞争严重。
  • CachedThreadPool(按需创建线程,无界队列):短任务场景下线程会疯狂膨胀,容易 OOM,且线程创建/销毁开销大。

工作窃取线程池正是为了解决这两者痛点而设计的:

  • 线程数量固定(避免 CachedThreadPool 的线程爆炸)。
  • 任务"缓存"在多个独立队列中(每个线程一个 bucket),而不是单一队列。
  • 通过"窃取"机制实现动态均衡,既不会让线程闲置,也不会让单个队列无限堆积。

简单说来说,它兼具 Fixed 的资源可控性 + Cached 的动态适应性,特别适合现代多核 CPU 和不均匀工作负载。

三、代码设计

同步队列

SyncQueue.hpp:

cpp 复制代码
#include<vector>
#include<list>
#include<mutex>
#include<condition_variable>
#include<iostream>

using namespace std;

template<class T>
class SyncQueue
{
private:
    //std::list<T> m_queue;
    std::vector<std::list<T>>  m_taskQueues;
    size_t m_bucketSize;   // vector size;桶的大小
    size_t m_maxSize;      // 队列的大小
    mutable std::mutex m_mutex;
    std::condition_variable m_notEmpty; //对应于消费者
    std::condition_variable m_notFull;  //对应于生产者
    size_t m_waitTime;                  //任务队列满等待时间 s
    bool m_needStop; //  true 同步队列停止工作

    bool IsFull(const int index) const
    {
        bool full = m_taskQueues[index].size() >= m_maxSize;
        if (full)
        {
            //clog << " m_queue 已经满了,需要等待..." << endl;
        }
        return full;
    }

    bool IsEmpty(const int index) const
    {
        bool empty = m_taskQueues[index].empty();
        if (empty)
        {
            // clog << "m_queue 已经空了,需要等待..." << endl;
        }
        return empty;
    }

    template<class F>
    int Add(F&& task, const int index)
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        bool waitret = m_notFull.wait_for(locker,
            std::chrono::seconds(m_waitTime),
            [this, index] { return m_needStop || !IsFull(index); });
        if (!waitret)
        {
            return 1;
        }
        if (m_needStop)
        {
            return 2;
        }
        m_taskQueues[index].push_back(std::forward<F>(task));
        m_notEmpty.notify_all();
        return 0;
    }

public:
    SyncQueue(int bucketsize, int maxsize = 200, size_t timeout = 1)
        :m_bucketSize(bucketsize),
        m_maxSize(maxsize),
        m_needStop(false),
        m_waitTime(timeout)
    {
        m_taskQueues.resize(m_bucketSize);
    }

    ~SyncQueue()
    {}

    int Put(const T& task, const int index)  // 0 ..m_bucketSize-1
    {
        return Add(task, index);
    }

    int Put(T&& task, const int index)
    {
        return Add(std::forward<T>(task), index);
    }

    int Take(std::list<T>& list, const int index)
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        bool waitret = m_notEmpty.wait_for(locker,
            std::chrono::seconds(m_waitTime),
            [this, index] { return m_needStop || !IsEmpty(index); });
        if (!waitret)
        {
            return 1;
        }
        if (m_needStop)
        {
            return 2;
        }
        list = std::move(m_taskQueues[index]);
        m_notFull.notify_all();
        return 0;
    }

    int Take(T& task, const int index)
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        bool waitret = m_notEmpty.wait_for(locker,
            std::chrono::seconds(m_waitTime),
            [this, index] { return m_needStop || !IsEmpty(index); });
        if (!waitret)
        {
            return 1;
        }
        if (m_needStop)
        {
            return 2;
        }
        task = m_taskQueues[index].front();
        m_taskQueues[index].pop_front();
        m_notFull.notify_all();
        return 0;
    }

    void Stop()
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        for (int i = 0; i < m_bucketSize; ++i)
        {
            while (!m_needStop && !IsEmpty(i))
            {
                m_notFull.wait(locker);
            }
        }
        m_needStop = true;
        m_notEmpty.notify_all();
        m_notFull.notify_all();
    }

    bool Empty() const
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        size_t sum = 0;
        for (auto& xlist : m_taskQueues)
        {
            sum += xlist.size();
        }
        return sum == 0;
    }

    bool Full() const
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        size_t sum = 0;
        for (auto& xlist : m_taskQueues)
        {
            sum += xlist.size();
        }
        return sum >= m_maxSize;
    }

    size_t size() const
    {
        std::unique_lock<std::mutex> locker(m_mutex);
        size_t sum = 0;
        for (auto& xlist : m_taskQueues)
        {
            sum += xlist.size();
        }
        return sum;
    }
};

工作窃取线程池的SyncQueue ,它不再是单个 list,而是 vector<list> m_taskQueues,桶数量 = 线程数(m_bucketSize)。

设计关键点:

  • 每个桶独立容量:IsFull(index) 只检查 m_taskQueues[index].size() >= m_maxSize,即每个线程的私有队列有独立上限。
  • 同一个 mutex + 两把 condition_variable:m_notEmpty(消费者用)、m_notFull(生产者用)。
  • Take 支持两种模式:
    Take(T&):取单个任务(未在池中使用)。
    Take(list&):一次性把整个桶的任务 move 走(批量处理,提升效率)。
  • Put 时指定 index,任务被精准投递到某个桶。

为什么这么设计?

  • 实现"每个线程拥有自己的任务队列",为工作窃取提供基础。
  • 单一 mutex 简化实现(生产环境可改成 per-bucket 细粒度锁或 lock-free deque)。
  • Stop() 中对所有桶逐个等待清空,保证优雅关闭。

工作窃取线程池

WorkingStealingPool.hpp:

cpp 复制代码
#include"SyncQueue.hpp"
#include<functional>
#include<future>
#include<memory>
#include<vector>
using namespace std;

class WorkStealingPool
{
public:
    using Task = std::function<void(void)>;

private:
    size_t m_numThreads;  // 
    SyncQueue<Task> m_queue;  // std::vector<std::list<T>>  m_taskQueues;
    std::vector<std::shared_ptr<std::thread>> m_threadgroup;
    std::atomic_bool m_running; // false;  // true;
    std::once_flag m_flag;

    void Start(int numthreads)
    {
        m_running = true;
        for (int i = 0; i < numthreads; ++i)
        {
            m_threadgroup.push_back(std::make_shared<std::thread>(std::thread(&WorkStealingPool::RunInThread, this, i)));
        }
    }

    void RunInThread(const int index) // 0 // 1
    {
        while (m_running)
        {
            std::list<Task> tasklist;
            if (m_queue.Take(tasklist, index) == 0)
            {
                for (auto& task : tasklist)
                {
                    if (!m_running)
                        return;
                    task();
                }
            }
            else
            {
                int i = threadIndex();
                if (i != index && m_queue.Take(tasklist, i) == 0)
                {
                    clog << "偷取任务成功..." << endl;
                    for (auto& task : tasklist)
                    {
                        if (!m_running)
                            return;
                        task();
                    }
                }
            }
        }
    }

    void StopThreadGroup()
    {
        m_queue.Stop();
        m_running = false;
        for (auto& tha : m_threadgroup)
        {
            if (tha && tha->joinable())
            {
                tha->join();
            }
        }
        m_threadgroup.clear();
    }

    int threadIndex()
    {
        static int num = 0;
        return ++num % m_numThreads; // 8 // 0~7
    }

public:
    WorkStealingPool(const int qusize = 100, const int numthreads = 8)
        :m_numThreads(numthreads),
        m_queue(m_numThreads, qusize),
        m_running(false)
    {
        std::call_once(m_flag, &WorkStealingPool::Start, this, numthreads);
    }

    ~WorkStealingPool()
    {
        Stop();
    }

    void Stop()
    {
        std::call_once(m_flag, [this]() { StopThreadGroup(); });
    }

    template<class Func, class... Args>
    auto submit(Func&& func, Args&& ... args)
    {
        using RetType = decltype(func(args...));
        auto task = std::make_shared<std::packaged_task<RetType(void)>>(
            std::bind(std::forward<Func>(func), std::forward<Args>(args)...)
        );
        std::future<RetType> result = task->get_future();
        if (m_queue.Put([task]() { (*task)(); }, threadIndex()) != 0)
        {
            (*task)();
        }
        return result;
    }

    void Execute(Task&& task)
    {
        if (m_queue.Put(std::forward<Task>(task), threadIndex()) != 0)
        {
            cout << "task queue is full, add task fail" << endl;
            task();
        }
    }

    void Execute(const Task& task)
    {
        if (m_queue.Put(task, threadIndex()) != 0)
        {
            cout << "task queue is full, add task fail" << endl;
            task();
        }
    }
};

WorkingStealingPool.hpp中RunInThread()是窃取逻辑的核心:

cpp 复制代码
while (m_running) {
    std::list<Task> tasklist;
    if (m_queue.Take(tasklist, index) == 0) {  // 优先处理自己的桶
        // 执行所有任务
    } else {
        int i = threadIndex();                 // 选择受害者
        if (i != index && m_queue.Take(tasklist, i) == 0) {
            clog << "偷取任务成功...";
            // 执行偷来的任务
        }
    }
}

为什么这样设计?

  • 优先 own queue:本地性最好,减少锁竞争和缓存 miss。
  • 窃取时使用 threadIndex():全局 round-robin 选择目标(简单有效,避免每次 random)。
  • Take 整个 list:一次性偷一批任务,减少窃取频率和锁争用。
  • submit / Execute:
    用 threadIndex() 实现 round-robin 投递。
    队列满时直接在提交线程执行(back-pressure 机制,避免任务丢失)。
    submit 使用 packaged_task + future,完美支持返回值和异常传递。
  • threadIndex() 用静态变量实现全局轮询,保证任务均匀分布,也被窃取逻辑复用。

测试代码

test.cpp:

cpp 复制代码
#include "WorkingStealingPool.hpp" 
#include <chrono>
#include <random>
#include <string>
#include <iomanip>

using namespace std;

// ==================== 测试用的任务函数 ====================

int add(int a, int b, int sleep_ms = 0)
{
    if (sleep_ms > 0) {
        this_thread::sleep_for(chrono::milliseconds(sleep_ms));
    }
    return a + b;
}

void print_task(int id)
{
    cout << "Task " << id << " executed by thread " << this_thread::get_id() << endl;
}

void throwing_task(int id)
{
    cout << "Task " << id << " will throw exception" << endl;
    throw runtime_error("Test exception from task " + to_string(id));
}

// ==================== 测试用例函数 ====================

// 测试1: 基本功能 + 返回值正确性
void Test_BasicFunctionality(WorkStealingPool& pool)
{
    cout << "\n=== 测试1: 基本功能 ===\n";
    auto fut1 = pool.submit(add, 10, 20, 0);
    auto fut2 = pool.submit(add, 5, 15, 10);   // 带一点延时

    cout << "10 + 20 = " << fut1.get() << endl;
    cout << "5 + 15 = " << fut2.get() << endl;
}

// 测试2: 大量任务提交(验证并发与完成度)
void Test_ManyTasks(WorkStealingPool& pool, int task_count = 1000)
{
    cout << "\n=== 测试2: 大量任务 (" << task_count << " 个) ===\n";
    vector<future<int>> results;
    results.reserve(task_count);

    for (int i = 0; i < task_count; ++i) {
        results.push_back(pool.submit(add, i, 100, rand() % 5));  // 随机小延时
    }

    int success = 0;
    for (auto& f : results) {
        try {
            f.get();
            ++success;
        }
        catch (...) {
            // 忽略(本测试不抛异常)
        }
    }
    cout << "完成 " << success << " / " << task_count << " 个任务\n";
}

// 测试3: 观察工作窃取机制(通过日志)
void Test_WorkStealing(WorkStealingPool& pool, int task_count = 500)
{
    cout << "\n=== 测试3: 工作窃取机制(请观察日志中的 '偷取任务成功') ===\n";
    vector<future<void>> results;
    results.reserve(task_count);

    // 先提交一些长任务,让某些队列堆积
    for (int i = 0; i < task_count; ++i) {
        results.push_back(pool.submit([i]() {
            this_thread::sleep_for(chrono::milliseconds(10 + (i % 30)));
            // clog << "任务 " << i << " 执行完成\n";
            }));
    }

    for (auto& f : results) f.get();
    cout << "测试3 完成(请查看控制台是否有 '偷取任务成功...' 日志)\n";
}

// 测试4: 异常处理
void Test_ExceptionHandling(WorkStealingPool& pool)
{
    cout << "\n=== 测试4: 异常处理 ===\n";
    auto fut = pool.submit(throwing_task, 999);

    try {
        fut.get();
        cout << "未捕获到异常!(错误)\n";
    }
    catch (const exception& e) {
        cout << "成功捕获异常: " << e.what() << endl;
    }
    catch (...) {
        cout << "捕获到未知异常\n";
    }
}

// 测试5: 队列满回退(直接在提交线程执行)
void Test_QueueFull(WorkStealingPool& pool)
{
    cout << "\n=== 测试5: 队列满回退 ===\n";
    // 故意用小队列容量测试(构造时可调整 qusize)
    for (int i = 0; i < 100; ++i) {
        pool.submit(print_task, i);   // 如果队列满,会在当前线程直接执行
    }
}

// 测试6: 不均匀负载(长短任务混合,更容易触发窃取)
void Test_UnbalancedLoad(WorkStealingPool& pool)
{
    cout << "\n=== 测试6: 不均匀负载(验证窃取效果) ===\n";
    vector<future<void>> futures;

    // 提交一些非常短的任务 + 少数长任务
    for (int i = 0; i < 800; ++i) {
        if (i % 50 == 0) {
            // 长任务
            futures.push_back(pool.submit([i]() {
                this_thread::sleep_for(chrono::milliseconds(80));
                cout << "长任务 " << i << " 完成\n";
                }));
        }
        else {
            // 短任务
            futures.push_back(pool.submit([i]() {
                this_thread::sleep_for(chrono::milliseconds(1));
                }));
        }
    }

    for (auto& f : futures) f.get();
    cout << "不均匀负载测试完成\n";
}

// ==================== main 测试入口 ====================

int main()
{
    srand(static_cast<unsigned>(time(nullptr)));

    cout << "=== WorkStealingPool 综合测试开始 ===\n\n";

    // 创建线程池:8个工作线程,队列容量适中
    WorkStealingPool pool(300, 8);   // qusize=300, numthreads=8

    try {
        Test_BasicFunctionality(pool);
        Test_ManyTasks(pool, 800);
        Test_WorkStealing(pool, 600);
        Test_ExceptionHandling(pool);
        Test_QueueFull(pool);
        Test_UnbalancedLoad(pool);

    }
    catch (const exception& e) {
        cout << "测试过程中发生未捕获异常: " << e.what() << endl;
    }
    catch (...) {
        cout << "测试过程中发生未知异常\n";
    }

    cout << "\n=== 所有测试执行完毕,正在停止线程池 ===\n";
    pool.Stop();        // 显式停止,确保 worker 线程退出

    cout << "=== 测试结束 ===\n";
    return 0;
}

6 个测试用例设计覆盖了功能、性能、机制、边界、异常等维度:

Test_BasicFunctionality:验证基本提交、future 返回值正确性、最小可用性。

Test_ManyTasks:压力测试(800 个任务),检查高并发下是否全部完成、无遗漏。

Test_WorkStealing:专门观察"偷取任务成功..."日志,验证窃取机制是否真的被触发(需配合长任务制造不均衡)。

Test_ExceptionHandling:验证异常是否能通过 future 正确传播到提交者(packaged_task 的核心价值)。

Test_QueueFull:测试 back-pressure 机制------队列满时任务是否能在提交线程直接执行(防止系统崩溃)。

Test_UnbalancedLoad:最能体现工作窃取价值的场景(长短任务混合 800 个),短任务线程会快速窃取长任务线程的剩余工作,验证负载均衡效果。

相关推荐
Z1Jxxx2 小时前
C++ P1150 Peter 的烟
数据结构·c++·算法
CheerWWW2 小时前
C++学习笔记——函数指针、Lambda表达式、谨慎使用using namespace std、命名空间
c++·笔记·学习
夜猫子ing2 小时前
如何编写一个CMakelists文件
开发语言·c++
踮起脚看烟花2 小时前
chapter10_泛型算法
c++·算法
山栀shanzhi2 小时前
C++四大常见排序对比
c++·算法·排序算法
云栖梦泽2 小时前
Linux内核与驱动:8.ioctl驱动基础
linux·c++
云栖梦泽2 小时前
Linux内核与驱动:7.从应用层 lseek() 到驱动层 .llseek,Linux 字符设备偏移控制详解
linux·c++
steins_甲乙2 小时前
从0做一个小型内存泄露检测器(2): elf文件的动态链接
c++
charlie1145141912 小时前
通用GUI编程技术——图形渲染实战(二十八)——图像格式与编解码:PNG/JPEG全掌握
开发语言·c++·windows·学习·图形渲染·win32