一、工作窃取线程池的概念
工作窃取是一种高效的任务调度策略,核心思想是:每个工作线程拥有自己的独立任务队列。当一个线程自己的队列为空(闲置)时,它会主动"窃取"其他忙碌线程队列中的任务,从而实现动态负载均衡,避免某些线程闲置、某些线程堆积的情况。
与传统线程池(如固定大小线程池使用单一共享队列)相比,工作窃取的优势在于:
- 减少全局锁竞争(每个线程优先访问自己的队列)。
- 天然支持不均衡负载(长短任务混合、递归任务等场景)。
- 线程数量固定,资源可控,同时又能接近"按需执行"的效果。
工作窃取线程池的工作原理:
提交任务时,按一定策略(通常 round-robin 或随机)放入某个线程的私有队列。
每个工作线程循环:先处理自己的队列 → 如果自己的队列为空,就去其他线程的队列"偷"一批任务。
窃取成功后,窃取线程和被窃取线程继续各自工作,实现自动负载均衡。
二、为什么需要设计工作窃取线程池
传统线程池有两种极端:
- FixedThreadPool(固定线程数 + 单一阻塞队列):线程少时容易排队,线程多时锁竞争严重。
- CachedThreadPool(按需创建线程,无界队列):短任务场景下线程会疯狂膨胀,容易 OOM,且线程创建/销毁开销大。
工作窃取线程池正是为了解决这两者痛点而设计的:
- 线程数量固定(避免 CachedThreadPool 的线程爆炸)。
- 任务"缓存"在多个独立队列中(每个线程一个 bucket),而不是单一队列。
- 通过"窃取"机制实现动态均衡,既不会让线程闲置,也不会让单个队列无限堆积。
简单说来说,它兼具 Fixed 的资源可控性 + Cached 的动态适应性,特别适合现代多核 CPU 和不均匀工作负载。
三、代码设计
同步队列
SyncQueue.hpp:
cpp
#include<vector>
#include<list>
#include<mutex>
#include<condition_variable>
#include<iostream>
using namespace std;
template<class T>
class SyncQueue
{
private:
//std::list<T> m_queue;
std::vector<std::list<T>> m_taskQueues;
size_t m_bucketSize; // vector size;桶的大小
size_t m_maxSize; // 队列的大小
mutable std::mutex m_mutex;
std::condition_variable m_notEmpty; //对应于消费者
std::condition_variable m_notFull; //对应于生产者
size_t m_waitTime; //任务队列满等待时间 s
bool m_needStop; // true 同步队列停止工作
bool IsFull(const int index) const
{
bool full = m_taskQueues[index].size() >= m_maxSize;
if (full)
{
//clog << " m_queue 已经满了,需要等待..." << endl;
}
return full;
}
bool IsEmpty(const int index) const
{
bool empty = m_taskQueues[index].empty();
if (empty)
{
// clog << "m_queue 已经空了,需要等待..." << endl;
}
return empty;
}
template<class F>
int Add(F&& task, const int index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notFull.wait_for(locker,
std::chrono::seconds(m_waitTime),
[this, index] { return m_needStop || !IsFull(index); });
if (!waitret)
{
return 1;
}
if (m_needStop)
{
return 2;
}
m_taskQueues[index].push_back(std::forward<F>(task));
m_notEmpty.notify_all();
return 0;
}
public:
SyncQueue(int bucketsize, int maxsize = 200, size_t timeout = 1)
:m_bucketSize(bucketsize),
m_maxSize(maxsize),
m_needStop(false),
m_waitTime(timeout)
{
m_taskQueues.resize(m_bucketSize);
}
~SyncQueue()
{}
int Put(const T& task, const int index) // 0 ..m_bucketSize-1
{
return Add(task, index);
}
int Put(T&& task, const int index)
{
return Add(std::forward<T>(task), index);
}
int Take(std::list<T>& list, const int index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notEmpty.wait_for(locker,
std::chrono::seconds(m_waitTime),
[this, index] { return m_needStop || !IsEmpty(index); });
if (!waitret)
{
return 1;
}
if (m_needStop)
{
return 2;
}
list = std::move(m_taskQueues[index]);
m_notFull.notify_all();
return 0;
}
int Take(T& task, const int index)
{
std::unique_lock<std::mutex> locker(m_mutex);
bool waitret = m_notEmpty.wait_for(locker,
std::chrono::seconds(m_waitTime),
[this, index] { return m_needStop || !IsEmpty(index); });
if (!waitret)
{
return 1;
}
if (m_needStop)
{
return 2;
}
task = m_taskQueues[index].front();
m_taskQueues[index].pop_front();
m_notFull.notify_all();
return 0;
}
void Stop()
{
std::unique_lock<std::mutex> locker(m_mutex);
for (int i = 0; i < m_bucketSize; ++i)
{
while (!m_needStop && !IsEmpty(i))
{
m_notFull.wait(locker);
}
}
m_needStop = true;
m_notEmpty.notify_all();
m_notFull.notify_all();
}
bool Empty() const
{
std::unique_lock<std::mutex> locker(m_mutex);
size_t sum = 0;
for (auto& xlist : m_taskQueues)
{
sum += xlist.size();
}
return sum == 0;
}
bool Full() const
{
std::unique_lock<std::mutex> locker(m_mutex);
size_t sum = 0;
for (auto& xlist : m_taskQueues)
{
sum += xlist.size();
}
return sum >= m_maxSize;
}
size_t size() const
{
std::unique_lock<std::mutex> locker(m_mutex);
size_t sum = 0;
for (auto& xlist : m_taskQueues)
{
sum += xlist.size();
}
return sum;
}
};
工作窃取线程池的SyncQueue ,它不再是单个 list,而是 vector<list> m_taskQueues,桶数量 = 线程数(m_bucketSize)。
设计关键点:
- 每个桶独立容量:IsFull(index) 只检查 m_taskQueues[index].size() >= m_maxSize,即每个线程的私有队列有独立上限。
- 同一个 mutex + 两把 condition_variable:m_notEmpty(消费者用)、m_notFull(生产者用)。
- Take 支持两种模式:
Take(T&):取单个任务(未在池中使用)。
Take(list&):一次性把整个桶的任务 move 走(批量处理,提升效率)。 - Put 时指定 index,任务被精准投递到某个桶。
为什么这么设计?
- 实现"每个线程拥有自己的任务队列",为工作窃取提供基础。
- 单一 mutex 简化实现(生产环境可改成 per-bucket 细粒度锁或 lock-free deque)。
- Stop() 中对所有桶逐个等待清空,保证优雅关闭。
工作窃取线程池
WorkingStealingPool.hpp:
cpp
#include"SyncQueue.hpp"
#include<functional>
#include<future>
#include<memory>
#include<vector>
using namespace std;
class WorkStealingPool
{
public:
using Task = std::function<void(void)>;
private:
size_t m_numThreads; //
SyncQueue<Task> m_queue; // std::vector<std::list<T>> m_taskQueues;
std::vector<std::shared_ptr<std::thread>> m_threadgroup;
std::atomic_bool m_running; // false; // true;
std::once_flag m_flag;
void Start(int numthreads)
{
m_running = true;
for (int i = 0; i < numthreads; ++i)
{
m_threadgroup.push_back(std::make_shared<std::thread>(std::thread(&WorkStealingPool::RunInThread, this, i)));
}
}
void RunInThread(const int index) // 0 // 1
{
while (m_running)
{
std::list<Task> tasklist;
if (m_queue.Take(tasklist, index) == 0)
{
for (auto& task : tasklist)
{
if (!m_running)
return;
task();
}
}
else
{
int i = threadIndex();
if (i != index && m_queue.Take(tasklist, i) == 0)
{
clog << "偷取任务成功..." << endl;
for (auto& task : tasklist)
{
if (!m_running)
return;
task();
}
}
}
}
}
void StopThreadGroup()
{
m_queue.Stop();
m_running = false;
for (auto& tha : m_threadgroup)
{
if (tha && tha->joinable())
{
tha->join();
}
}
m_threadgroup.clear();
}
int threadIndex()
{
static int num = 0;
return ++num % m_numThreads; // 8 // 0~7
}
public:
WorkStealingPool(const int qusize = 100, const int numthreads = 8)
:m_numThreads(numthreads),
m_queue(m_numThreads, qusize),
m_running(false)
{
std::call_once(m_flag, &WorkStealingPool::Start, this, numthreads);
}
~WorkStealingPool()
{
Stop();
}
void Stop()
{
std::call_once(m_flag, [this]() { StopThreadGroup(); });
}
template<class Func, class... Args>
auto submit(Func&& func, Args&& ... args)
{
using RetType = decltype(func(args...));
auto task = std::make_shared<std::packaged_task<RetType(void)>>(
std::bind(std::forward<Func>(func), std::forward<Args>(args)...)
);
std::future<RetType> result = task->get_future();
if (m_queue.Put([task]() { (*task)(); }, threadIndex()) != 0)
{
(*task)();
}
return result;
}
void Execute(Task&& task)
{
if (m_queue.Put(std::forward<Task>(task), threadIndex()) != 0)
{
cout << "task queue is full, add task fail" << endl;
task();
}
}
void Execute(const Task& task)
{
if (m_queue.Put(task, threadIndex()) != 0)
{
cout << "task queue is full, add task fail" << endl;
task();
}
}
};
WorkingStealingPool.hpp中RunInThread()是窃取逻辑的核心:
cpp
while (m_running) {
std::list<Task> tasklist;
if (m_queue.Take(tasklist, index) == 0) { // 优先处理自己的桶
// 执行所有任务
} else {
int i = threadIndex(); // 选择受害者
if (i != index && m_queue.Take(tasklist, i) == 0) {
clog << "偷取任务成功...";
// 执行偷来的任务
}
}
}
为什么这样设计?
- 优先 own queue:本地性最好,减少锁竞争和缓存 miss。
- 窃取时使用 threadIndex():全局 round-robin 选择目标(简单有效,避免每次 random)。
- Take 整个 list:一次性偷一批任务,减少窃取频率和锁争用。
- submit / Execute:
用 threadIndex() 实现 round-robin 投递。
队列满时直接在提交线程执行(back-pressure 机制,避免任务丢失)。
submit 使用 packaged_task + future,完美支持返回值和异常传递。 - threadIndex() 用静态变量实现全局轮询,保证任务均匀分布,也被窃取逻辑复用。
测试代码
test.cpp:
cpp
#include "WorkingStealingPool.hpp"
#include <chrono>
#include <random>
#include <string>
#include <iomanip>
using namespace std;
// ==================== 测试用的任务函数 ====================
int add(int a, int b, int sleep_ms = 0)
{
if (sleep_ms > 0) {
this_thread::sleep_for(chrono::milliseconds(sleep_ms));
}
return a + b;
}
void print_task(int id)
{
cout << "Task " << id << " executed by thread " << this_thread::get_id() << endl;
}
void throwing_task(int id)
{
cout << "Task " << id << " will throw exception" << endl;
throw runtime_error("Test exception from task " + to_string(id));
}
// ==================== 测试用例函数 ====================
// 测试1: 基本功能 + 返回值正确性
void Test_BasicFunctionality(WorkStealingPool& pool)
{
cout << "\n=== 测试1: 基本功能 ===\n";
auto fut1 = pool.submit(add, 10, 20, 0);
auto fut2 = pool.submit(add, 5, 15, 10); // 带一点延时
cout << "10 + 20 = " << fut1.get() << endl;
cout << "5 + 15 = " << fut2.get() << endl;
}
// 测试2: 大量任务提交(验证并发与完成度)
void Test_ManyTasks(WorkStealingPool& pool, int task_count = 1000)
{
cout << "\n=== 测试2: 大量任务 (" << task_count << " 个) ===\n";
vector<future<int>> results;
results.reserve(task_count);
for (int i = 0; i < task_count; ++i) {
results.push_back(pool.submit(add, i, 100, rand() % 5)); // 随机小延时
}
int success = 0;
for (auto& f : results) {
try {
f.get();
++success;
}
catch (...) {
// 忽略(本测试不抛异常)
}
}
cout << "完成 " << success << " / " << task_count << " 个任务\n";
}
// 测试3: 观察工作窃取机制(通过日志)
void Test_WorkStealing(WorkStealingPool& pool, int task_count = 500)
{
cout << "\n=== 测试3: 工作窃取机制(请观察日志中的 '偷取任务成功') ===\n";
vector<future<void>> results;
results.reserve(task_count);
// 先提交一些长任务,让某些队列堆积
for (int i = 0; i < task_count; ++i) {
results.push_back(pool.submit([i]() {
this_thread::sleep_for(chrono::milliseconds(10 + (i % 30)));
// clog << "任务 " << i << " 执行完成\n";
}));
}
for (auto& f : results) f.get();
cout << "测试3 完成(请查看控制台是否有 '偷取任务成功...' 日志)\n";
}
// 测试4: 异常处理
void Test_ExceptionHandling(WorkStealingPool& pool)
{
cout << "\n=== 测试4: 异常处理 ===\n";
auto fut = pool.submit(throwing_task, 999);
try {
fut.get();
cout << "未捕获到异常!(错误)\n";
}
catch (const exception& e) {
cout << "成功捕获异常: " << e.what() << endl;
}
catch (...) {
cout << "捕获到未知异常\n";
}
}
// 测试5: 队列满回退(直接在提交线程执行)
void Test_QueueFull(WorkStealingPool& pool)
{
cout << "\n=== 测试5: 队列满回退 ===\n";
// 故意用小队列容量测试(构造时可调整 qusize)
for (int i = 0; i < 100; ++i) {
pool.submit(print_task, i); // 如果队列满,会在当前线程直接执行
}
}
// 测试6: 不均匀负载(长短任务混合,更容易触发窃取)
void Test_UnbalancedLoad(WorkStealingPool& pool)
{
cout << "\n=== 测试6: 不均匀负载(验证窃取效果) ===\n";
vector<future<void>> futures;
// 提交一些非常短的任务 + 少数长任务
for (int i = 0; i < 800; ++i) {
if (i % 50 == 0) {
// 长任务
futures.push_back(pool.submit([i]() {
this_thread::sleep_for(chrono::milliseconds(80));
cout << "长任务 " << i << " 完成\n";
}));
}
else {
// 短任务
futures.push_back(pool.submit([i]() {
this_thread::sleep_for(chrono::milliseconds(1));
}));
}
}
for (auto& f : futures) f.get();
cout << "不均匀负载测试完成\n";
}
// ==================== main 测试入口 ====================
int main()
{
srand(static_cast<unsigned>(time(nullptr)));
cout << "=== WorkStealingPool 综合测试开始 ===\n\n";
// 创建线程池:8个工作线程,队列容量适中
WorkStealingPool pool(300, 8); // qusize=300, numthreads=8
try {
Test_BasicFunctionality(pool);
Test_ManyTasks(pool, 800);
Test_WorkStealing(pool, 600);
Test_ExceptionHandling(pool);
Test_QueueFull(pool);
Test_UnbalancedLoad(pool);
}
catch (const exception& e) {
cout << "测试过程中发生未捕获异常: " << e.what() << endl;
}
catch (...) {
cout << "测试过程中发生未知异常\n";
}
cout << "\n=== 所有测试执行完毕,正在停止线程池 ===\n";
pool.Stop(); // 显式停止,确保 worker 线程退出
cout << "=== 测试结束 ===\n";
return 0;
}
6 个测试用例设计覆盖了功能、性能、机制、边界、异常等维度:
Test_BasicFunctionality:验证基本提交、future 返回值正确性、最小可用性。
Test_ManyTasks:压力测试(800 个任务),检查高并发下是否全部完成、无遗漏。
Test_WorkStealing:专门观察"偷取任务成功..."日志,验证窃取机制是否真的被触发(需配合长任务制造不均衡)。
Test_ExceptionHandling:验证异常是否能通过 future 正确传播到提交者(packaged_task 的核心价值)。
Test_QueueFull:测试 back-pressure 机制------队列满时任务是否能在提交线程直接执行(防止系统崩溃)。
Test_UnbalancedLoad:最能体现工作窃取价值的场景(长短任务混合 800 个),短任务线程会快速窃取长任务线程的剩余工作,验证负载均衡效果。