// NOTE: removed stray web-paste artifacts ("cpp" / "复制代码") that made this header uncompilable.
#pragma once
#include "concurrent/defs.h"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#if defined(__linux__)
#include <pthread.h>
#include <semaphore.h>
#else
#include <condition_variable>
#include <mutex>
#endif
namespace topsun::concurrent {
/**
 * @brief Bounded Chase–Lev style work-stealing deque (single owner + multiple stealers).
 *
 * - The owner (the worker this queue belongs to) pushes/pops only at the **bottom**
 *   (LIFO, good cache locality for freshly spawned child tasks).
 * - Other threads call try_steal at the **top**; the classic race against the owner's
 *   final pop is resolved by a CAS on top_.
 *
 * @tparam T   Must be default-constructible, movable, and copyable: a stealer copies
 *             the slot *before* claiming it with the CAS (reading after the CAS would
 *             race with the owner wrapping bottom_ back onto the same slot). The copy
 *             should not throw in practice, since try_steal is noexcept.
 * @tparam Cap Capacity; must be a power of two and >= 2. try_push fails when full.
 */
template <typename T, std::size_t Cap>
class WorkStealingDeque {
  static_assert(Cap >= 2 && (Cap & (Cap - 1)) == 0, "Cap must be power of 2 and >= 2");

public:
  WorkStealingDeque() : buf_(Cap) {}
  WorkStealingDeque(const WorkStealingDeque &) = delete;
  WorkStealingDeque &operator=(const WorkStealingDeque &) = delete;

  /// Owner-only. Returns false when the deque is full.
  [[nodiscard]] bool try_push(T &&v) noexcept {
    const std::size_t b = bottom_.load(std::memory_order_relaxed);
    const std::size_t t = top_.load(std::memory_order_acquire);
    if (b - t >= Cap) {  // indices grow monotonically, so b - t is the current size
      return false;
    }
    buf_[b & kMask] = std::move(v);
    // Release: a stealer that acquire-loads bottom_ must also see the slot contents.
    bottom_.store(b + 1, std::memory_order_release);
    return true;
  }

  /// Owner-only LIFO pop. Returns false when empty (or the last element was stolen).
  [[nodiscard]] bool try_pop_local(T &out) noexcept {
    std::size_t b = bottom_.load(std::memory_order_relaxed);
    if (b == 0) {  // guard against unsigned underflow of b - 1 (never pushed yet)
      return false;
    }
    b = b - 1;
    bottom_.store(b, std::memory_order_relaxed);
    // Full fence: publish the decremented bottom_ before reading top_, so the owner
    // and a concurrent stealer cannot both take the same last element.
    std::atomic_thread_fence(std::memory_order_seq_cst);
    std::size_t t = top_.load(std::memory_order_relaxed);
    if (t > b) {  // deque was already empty; restore bottom_
      bottom_.store(b + 1, std::memory_order_relaxed);
      return false;
    }
    if (t < b) {  // more than one element left: no stealer can touch slot b
      out = std::move(buf_[b & kMask]);
      return true;
    }
    // Exactly one element left: race with stealers, resolved by the CAS on top_.
    if (top_.compare_exchange_strong(t, t + 1, std::memory_order_seq_cst,
                                     std::memory_order_relaxed)) {
      out = std::move(buf_[b & kMask]);
      bottom_.store(b + 1, std::memory_order_relaxed);
      return true;
    }
    bottom_.store(b + 1, std::memory_order_relaxed);  // lost the race to a stealer
    return false;
  }

  /// Called by other workers. false means empty or a lost race; the caller may
  /// retry with a different victim.
  [[nodiscard]] bool try_steal(T &out) noexcept {
    std::size_t t = top_.load(std::memory_order_acquire);
    for (;;) {
      // Order the top_ load before the bottom_ load (pairs with the seq_cst fence
      // in try_pop_local); otherwise we could observe an inconsistent (t, b) pair.
      std::atomic_thread_fence(std::memory_order_seq_cst);
      const std::size_t b = bottom_.load(std::memory_order_acquire);
      if (t >= b) {
        return false;  // empty
      }
      // Copy the slot BEFORE the CAS. Once top_ is advanced, the owner may reuse
      // this very slot for a new push (b - t drops below Cap), so reading it after
      // a successful CAS races with that overwrite and can lose/corrupt the task.
      T taken = buf_[t & kMask];
      if (!top_.compare_exchange_weak(t, t + 1, std::memory_order_seq_cst,
                                      std::memory_order_relaxed)) {
        continue;  // t was refreshed by the failed CAS; discard the copy and retry
      }
      out = std::move(taken);
      return true;
    }
  }

  /// Approximate element count (observation / idleness heuristics only, not a snapshot).
  [[nodiscard]] std::size_t approx_size() const noexcept {
    const std::size_t b = bottom_.load(std::memory_order_relaxed);
    const std::size_t t = top_.load(std::memory_order_relaxed);
    return b >= t ? b - t : 0;
  }

private:
  static constexpr std::size_t kMask = Cap - 1;
  std::vector<T> buf_;
  // bottom_ (owner-hot) and top_ (stealer-hot) on separate cache lines to avoid
  // false sharing.
  alignas(CACHE_LINE_SIZE) std::atomic<std::size_t> bottom_{0};
  alignas(CACHE_LINE_SIZE) std::atomic<std::size_t> top_{0};
};
/**
 * @brief Work-stealing thread pool: one local deque per worker. Workers pop locally
 *        first and, when empty, round-robin steal from the other queues.
 *
 * Best suited for: many short tasks and recursive / tree-shaped parallelism where
 * most tasks are spawned by pool threads themselves (hot local queues, low global
 * contention). External submissions are distributed round-robin across the local
 * queues; compared to a single global MPMC queue this may not win when many outside
 * producers hammer one queue, but it usually reduces contention when work spawns
 * sub-work locally.
 *
 * Synchronization: tasks live in the per-worker WorkStealingDeque instances; a single
 * semaphore (Linux) or condition_variable (elsewhere) signals "work may exist" so idle
 * workers do not spin in user space. Every successful enqueue does sem_post / notify
 * (same strategy as ThreadPool).
 *
 * Difference vs ThreadPool: no queue() accessor for a global lock-free queue (the
 * structure differs); the remaining API mirrors ThreadPool to ease A/B swapping.
 */
template <std::size_t PerWorkerCapacity = 1024>
class WorkStealingThreadPool {
public:
  using Task = std::function<void()>;
  using LocalQueue = WorkStealingDeque<Task, PerWorkerCapacity>;

  /// Process-wide singleton sized to hardware_concurrency().
  static WorkStealingThreadPool &instance() {
    static WorkStealingThreadPool pool(std::thread::hardware_concurrency(),
                                       "wstp");
    return pool;
  }

  /// @param worker_count number of workers; 0 is promoted to 1.
  /// @param name_prefix  thread-name prefix (used on Linux only; truncated to fit
  ///                     the 15-character kernel limit).
  explicit WorkStealingThreadPool(std::size_t worker_count,
                                  std::string name_prefix = "wstp")
      : name_prefix_(std::move(name_prefix)) {
    if (worker_count == 0) {
      worker_count = 1;
    }
    worker_count_ = worker_count;
    queues_ = std::make_unique<LocalQueue[]>(worker_count_);
#if defined(__linux__)
    if (sem_init(&work_sem_, 0, 0) != 0) {
      std::abort();  // the pool cannot function without its wakeup semaphore
    }
    work_sem_ready_ = true;
#endif
    workers_.reserve(worker_count_);
    for (std::size_t i = 0; i < worker_count_; ++i) {
      workers_.emplace_back([this, i] {
        worker_loop(static_cast<unsigned>(i));
      });
    }
  }

  WorkStealingThreadPool(const WorkStealingThreadPool &) = delete;
  WorkStealingThreadPool &operator=(const WorkStealingThreadPool &) = delete;
  WorkStealingThreadPool(WorkStealingThreadPool &&) = delete;
  WorkStealingThreadPool &operator=(WorkStealingThreadPool &&) = delete;

  ~WorkStealingThreadPool() {
    stop();
#if defined(__linux__)
    if (work_sem_ready_) {
      sem_destroy(&work_sem_);
      work_sem_ready_ = false;
    }
#endif
  }

  /// True while submit/try_submit may still accept tasks.
  [[nodiscard]] bool accepting() const noexcept {
    return accepting_.load(std::memory_order_acquire);
  }

  /// True once stop() has been requested.
  [[nodiscard]] bool stop_requested() const noexcept {
    return stop_requested_.load(std::memory_order_acquire);
  }

  [[nodiscard]] std::size_t worker_count() const noexcept {
    return worker_count_;
  }

  /// Conservative approximation: sum of every local queue's approx_size().
  [[nodiscard]] std::size_t pending_approx() const noexcept {
    std::size_t sum = 0;
    for (std::size_t i = 0; i < worker_count_; ++i) {
      sum += queues_[i].approx_size();
    }
    return sum;
  }

  /// Round-robin placement over the local queues; tries the next queue when one is
  /// full. @return false when stopped, the callable is empty, or all queues are full.
  template <typename F>
  bool try_submit(F &&f) {
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false;  // empty callable would be a silent no-op
    }
    // Re-check after the potentially slow Task construction.
    if (!accepting_.load(std::memory_order_relaxed)) {
      return false;
    }
    return try_dispatch(task);
  }

  /// Blocks until some local queue accepts the task or the pool stops accepting.
  template <typename F>
  bool submit(F &&f) {
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false;
    }
    unsigned spins = 0;
    for (;;) {
      if (!accepting_.load(std::memory_order_acquire)) {
        return false;
      }
      if (try_dispatch(task)) {
        return true;
      }
      // Every queue is full: back off — spin briefly, then sleep.
      if (++spins > 500) {
        std::this_thread::sleep_for(std::chrono::microseconds(200));
      } else {
        spin_loop_pause();
      }
    }
  }

  /// For **this pool's worker threads**: push into the caller's own local queue
  /// (typical for continuations / child tasks).
  template <typename F>
  bool try_submit_local(unsigned worker_index, F &&f) {
    if (worker_index >= worker_count_) {
      return false;
    }
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false;
    }
    if (!queues_[worker_index].try_push(std::move(task))) {
      return false;
    }
    notify_workers();
    return true;
  }

  /// Stop accepting work, wake every worker, and join them. Safe to call repeatedly.
  void stop() noexcept {
    accepting_.store(false, std::memory_order_release);
    stop_requested_.store(true, std::memory_order_release);
#if defined(__linux__)
    for (std::size_t i = 0; i < workers_.size(); ++i) {
      (void)sem_post(&work_sem_);  // one token per worker so none stays blocked
    }
#else
    // Lock/unlock idle_mtx_ before notifying: without it the notification can fire
    // between a waiter's predicate check and its block, and be lost.
    {
      std::lock_guard<std::mutex> lk(idle_mtx_);
    }
    idle_cv_.notify_all();
#endif
    for (auto &t : workers_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

private:
  /// Round-robin attempt to place `task` into some local queue. On success the task
  /// is moved from and one worker is woken. Shared by try_submit() and submit().
  bool try_dispatch(Task &task) {
    const auto n = static_cast<unsigned>(worker_count_);
    const unsigned start = static_cast<unsigned>(
        submit_rr_.fetch_add(1u, std::memory_order_relaxed) % n);
    for (unsigned k = 0; k < n; ++k) {
      const unsigned idx = (start + k) % n;
      if (queues_[idx].try_push(std::move(task))) {
        notify_workers();
        return true;
      }
    }
    return false;
  }

  /// Run a task, swallowing exceptions so a throwing task cannot kill a worker.
  static void run_task(Task &task) noexcept {
    if (task) {
      try {
        task();
      } catch (...) {
        // Deliberately best-effort: user tasks must not take down the pool.
      }
    }
  }

  void notify_workers() noexcept {
#if defined(__linux__)
    (void)sem_post(&work_sem_);
#else
    // Synchronize with waiters before notifying (see stop() for the rationale);
    // otherwise a wakeup issued in the window between the waiter's predicate check
    // and its block is lost and the pool can stall until the next submit.
    {
      std::lock_guard<std::mutex> lk(idle_mtx_);
    }
    idle_cv_.notify_one();
#endif
  }

  void worker_loop(unsigned index) {
#if defined(__linux__)
    if (!name_prefix_.empty()) {
      // Linux caps thread names at 15 characters + NUL: keep at most 10 chars of
      // the prefix plus a ".<index>" suffix, then hard-trim to 15.
      std::string nm = name_prefix_;
      if (nm.size() > 10) {
        nm.resize(10);
      }
      if (index < 100) {
        char buf[16];
        std::snprintf(buf, sizeof(buf), ".%u", index);
        nm += buf;
      }
      nm.resize(std::min<std::size_t>(nm.size(), 15U));
      pthread_setname_np(pthread_self(), nm.c_str());
    }
#else
    (void)index;
#endif
    Task task;
    unsigned stable_empty_after_stop = 0;
    const auto n = static_cast<unsigned>(worker_count_);
    for (;;) {
      // 1) Drain the local queue first (LIFO favors cache-hot child tasks).
      while (queues_[index].try_pop_local(task)) {
        stable_empty_after_stop = 0;
        run_task(task);
      }
      // 2) Local queue empty: try to steal from the other workers, round robin.
      bool stole = false;
      for (unsigned r = 1; r < n; ++r) {
        const unsigned victim = (index + r) % n;
        if (queues_[victim].try_steal(task)) {
          stole = true;
          break;
        }
      }
      if (stole) {
        run_task(task);
        continue;
      }
      // 3) Nothing anywhere. During shutdown, exit only after the whole pool has
      //    looked empty repeatedly (tasks may still be migrating between queues).
      if (stop_requested_.load(std::memory_order_acquire)) {
        if (pending_approx() == 0) {
          if (++stable_empty_after_stop > 100) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::microseconds(200));
        } else {
          stable_empty_after_stop = 0;
          spin_loop_pause();
        }
        continue;
      }
      stable_empty_after_stop = 0;
      // 4) Block until "work may exist" is signalled.
#if defined(__linux__)
      for (;;) {
        if (sem_wait(&work_sem_) == 0) {
          break;
        }
        if (errno != EINTR) {  // retry only when interrupted by a signal
          break;
        }
      }
#else
      std::unique_lock<std::mutex> lk(idle_mtx_);
      idle_cv_.wait(lk, [this] {
        return stop_requested_.load(std::memory_order_acquire) ||
               pending_approx() > 0;
      });
#endif
    }
  }

  std::string name_prefix_;
  std::size_t worker_count_{0};
  std::unique_ptr<LocalQueue[]> queues_;
  std::vector<std::thread> workers_;
  std::atomic<unsigned> submit_rr_{0};  // round-robin cursor for dispatch
#if defined(__linux__)
  sem_t work_sem_{};
  bool work_sem_ready_{false};  // guards sem_destroy in the destructor
#else
  std::mutex idle_mtx_{};
  std::condition_variable idle_cv_{};
#endif
  std::atomic<bool> accepting_{true};
  std::atomic<bool> stop_requested_{false};
};
} // namespace topsun::concurrent