steal_work_thread_pool — a work-stealing thread pool (C++ header)

cpp
#pragma once

#include "concurrent/defs.h"

#include <algorithm>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#if defined(__linux__)
#include <pthread.h>
#include <semaphore.h>
#else
#include <condition_variable>
#include <mutex>
#endif

namespace topsun::concurrent {

/**
 * @brief Chase-Lev-style bounded work-stealing deque (single owner + many stealers).
 *
 * - The owner (the worker this queue belongs to) pushes / pops only at the
 *   **bottom** end (LIFO, which favors cache locality).
 * - Other threads call try_steal at the **top** end (the classic race with the
 *   owner's final pop is resolved by a CAS on top_).
 *
 * NOTE(review): the bottom end is single-threaded by contract — try_push and
 * try_pop_local both read-modify bottom_ without mutual exclusion, so they
 * must never run concurrently from two different threads. The pool below
 * pushes from submitter threads while the owning worker pops locally; confirm
 * that combination is safe or serialize bottom-end access.
 *
 * @tparam T must be default-constructible and movable (slots get overwritten
 *           or moved out)
 * @tparam Cap capacity; must be a power of two and >= 2; try_push fails when full.
 */
template <typename T, std::size_t Cap>
class WorkStealingDeque {
  static_assert(Cap >= 2 && (Cap & (Cap - 1)) == 0, "Cap must be power of 2 and >= 2");

public:
  WorkStealingDeque() : buf_(Cap) {}

  WorkStealingDeque(const WorkStealingDeque &) = delete;
  WorkStealingDeque &operator=(const WorkStealingDeque &) = delete;

  /// Push at the bottom end; returns false when the ring is full.
  [[nodiscard]] bool try_push(T &&v) noexcept {
    const std::size_t b = bottom_.load(std::memory_order_relaxed);
    const std::size_t t = top_.load(std::memory_order_acquire);
    if (b - t >= Cap) {
      return false; // full: b - t is the element count (indices only grow)
    }
    // T is a fixed class template parameter (not deduced), so std::forward<T>
    // here is equivalent to std::move.
    buf_[b & kMask] = std::forward<T>(v);
    bottom_.store(b + 1, std::memory_order_release); // publish the slot
    return true;
  }

  /// Called only by the worker that owns this queue: LIFO pop at the bottom.
  [[nodiscard]] bool try_pop_local(T &out) noexcept {
    std::size_t b = bottom_.load(std::memory_order_relaxed);
    if (b == 0) {
      return false; // nothing was ever pushed
    }
    b = b - 1;
    // Reserve the bottom slot, then fence so the reservation is globally
    // visible before we read top_ (the Chase-Lev owner/stealer handshake).
    bottom_.store(b, std::memory_order_relaxed);
    std::atomic_thread_fence(std::memory_order_seq_cst);
    std::size_t t = top_.load(std::memory_order_relaxed);
    if (t > b) {
      // Queue was already empty; undo the reservation.
      bottom_.store(b +1, std::memory_order_relaxed);
      return false;
    }
    if (t < b) {
      // At least two elements remained: no stealer can contend for this slot.
      out = std::move(buf_[b & kMask]);
      return true;
    }
    // Exactly one element left (t == b): race the stealers for it via CAS.
    if (top_.compare_exchange_strong(t, t +1, std::memory_order_seq_cst,
                                     std::memory_order_relaxed)) {
      out = std::move(buf_[b & kMask]);
      bottom_.store(b +1, std::memory_order_relaxed);
      return true;
    }
    // A stealer won the race; restore bottom_ so bottom == top again.
    bottom_.store(b +1, std::memory_order_relaxed);
    return false;
  }

  /// Called by other workers; false means empty or a lost race — the caller
  /// may retry with another victim.
  [[nodiscard]] bool try_steal(T &out) noexcept {
    std::size_t t = top_.load(std::memory_order_acquire);
    for (;;) {
      std::size_t b = bottom_.load(std::memory_order_acquire);
      if (t >= b) {
        return false; // observed empty
      }
      const std::size_t stolen = t;
      if (!top_.compare_exchange_weak(t, t +1, std::memory_order_seq_cst,
                                       std::memory_order_relaxed)) {
        continue; // failed CAS reloaded t; re-check emptiness with fresh t
      }
      // NOTE(review): the slot is read *after* the CAS claims it. If the
      // owner refills the ring fast enough to wrap Cap elements past `stolen`
      // before this move completes, the slot could be overwritten
      // concurrently — the classic Chase-Lev algorithm reads the element
      // before the CAS. Verify under load.
      out = std::move(buf_[stolen & kMask]);
      return true;
    }
  }

  /// Approximate size (for observation / idleness checks only; not a snapshot).
  [[nodiscard]] std::size_t approx_size() const noexcept {
    const std::size_t b = bottom_.load(std::memory_order_relaxed);
    const std::size_t t = top_.load(std::memory_order_relaxed);
    return b >= t ? b - t : 0; // bottom transiently trails top during a pop
  }

private:
  static constexpr std::size_t kMask = Cap - 1;
  std::vector<T> buf_;
  // bottom_ / top_ on separate cache lines to avoid false sharing between the
  // owner (bottom) and the stealers (top).
  alignas(CACHE_LINE_SIZE) std::atomic<std::size_t> bottom_{0};
  alignas(CACHE_LINE_SIZE) std::atomic<std::size_t> top_{0};
};

/**
 * @brief Work-stealing thread pool: one local deque per worker; a worker pops
 * its own queue first (LIFO), and when empty round-robin steals from peers.
 *
 * Best for **many short tasks** and **recursive / tree-shaped parallelism**
 * where tasks mostly spawn further tasks from inside the pool (hot local
 * queues, low global contention). External submit round-robins across the
 * local queues; versus a single global MPMC queue the win comes mainly from
 * **pool-internal fan-out**, not from many producers hammering one queue.
 *
 * Synchronization: tasks live in the per-worker WorkStealingDeque instances;
 * a single semaphore (Linux) or condition_variable (elsewhere) only signals
 * "work may exist" so idle workers don't spin in user space. Every successful
 * enqueue does one sem_post / notify (same strategy as ThreadPool).
 *
 * Difference vs ThreadPool: no queue() accessor for a global lock-free queue
 * (different structure); the rest of the API is kept aligned to ease swap-in
 * experiments.
 *
 * NOTE(review): external submitters call LocalQueue::try_push (bottom end)
 * concurrently with the owning worker's try_pop_local; the deque's bottom end
 * is not multi-producer safe — audit this combination.
 */
template <std::size_t PerWorkerCapacity = 1024>
class WorkStealingThreadPool {
public:
  using Task = std::function<void()>;
  using LocalQueue = WorkStealingDeque<Task, PerWorkerCapacity>;

  /// Process-wide singleton sized to the hardware concurrency.
  static WorkStealingThreadPool &instance() {
    static WorkStealingThreadPool pool(std::thread::hardware_concurrency(),
                                       "wstp");
    return pool;
  }

  /// @param worker_count number of workers (0 is clamped to 1)
  /// @param name_prefix  thread-name prefix (Linux only; truncated to fit)
  explicit WorkStealingThreadPool(std::size_t worker_count,
                                  std::string name_prefix = "wstp")
      : name_prefix_(std::move(name_prefix)) {
    if (worker_count == 0) {
      worker_count = 1; // always run at least one worker
    }
    worker_count_ = worker_count;
    queues_ = std::make_unique<LocalQueue[]>(worker_count_);
#if defined(__linux__)
    if (sem_init(&work_sem_, 0, 0) != 0) {
      std::abort(); // cannot operate without the wakeup primitive
    }
    work_sem_ready_ = true;
#endif
    workers_.reserve(worker_count_);
    for (std::size_t i = 0; i < worker_count_; ++i) {
      workers_.emplace_back([this, i] {
        worker_loop(static_cast<unsigned>(i));
      });
    }
  }

  WorkStealingThreadPool(const WorkStealingThreadPool &) = delete;
  WorkStealingThreadPool &operator=(const WorkStealingThreadPool &) = delete;
  WorkStealingThreadPool(WorkStealingThreadPool &&) = delete;
  WorkStealingThreadPool &operator=(WorkStealingThreadPool &&) = delete;

  ~WorkStealingThreadPool() {
    stop();
#if defined(__linux__)
    if (work_sem_ready_) {
      sem_destroy(&work_sem_);
      work_sem_ready_ = false;
    }
#endif
  }

  /// True while the pool still accepts new tasks.
  [[nodiscard]] bool accepting() const noexcept {
    return accepting_.load(std::memory_order_acquire);
  }

  /// True once stop() has been requested.
  [[nodiscard]] bool stop_requested() const noexcept {
    return stop_requested_.load(std::memory_order_acquire);
  }

  [[nodiscard]] std::size_t worker_count() const noexcept {
    return worker_count_;
  }

  /// Conservative approximation: sum of every local queue's approx_size().
  [[nodiscard]] std::size_t pending_approx() const noexcept {
    std::size_t sum = 0;
    for (std::size_t i = 0; i < worker_count_; ++i) {
      sum += queues_[i].approx_size();
    }
    return sum;
  }

  /// Round-robin mount onto the local queues: if the chosen slot is full, try
  /// the next; returns false when every queue is full, the task is empty, or
  /// the pool no longer accepts work.
  template <typename F>
  bool try_submit(F &&f) {
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false; // reject empty std::function
    }
    const auto n = static_cast<unsigned>(worker_count_);
    const unsigned start = static_cast<unsigned>(
        submit_rr_.fetch_add(1u, std::memory_order_relaxed) % n);
    for (unsigned k = 0; k < n; ++k) {
      const unsigned idx = (start + k) % n;
      // try_push only consumes `task` on success, so retrying is safe.
      if (queues_[idx].try_push(std::move(task))) {
        notify_workers();
        return true;
      }
    }
    return false;
  }

  /// Block until some local queue accepts the task or the pool stops
  /// accepting. Spins briefly, then backs off with 200us sleeps.
  template <typename F>
  bool submit(F &&f) {
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false; // reject empty std::function
    }
    unsigned spins = 0;
    for (;;) {
      if (!accepting_.load(std::memory_order_acquire)) {
        return false;
      }
      const auto n = static_cast<unsigned>(worker_count_);
      const unsigned start = static_cast<unsigned>(
          submit_rr_.fetch_add(1u, std::memory_order_relaxed) % n);
      for (unsigned k = 0; k < n; ++k) {
        const unsigned idx = (start + k) % n;
        if (queues_[idx].try_push(std::move(task))) {
          notify_workers();
          return true;
        }
      }
      if (++spins > 500) {
        std::this_thread::sleep_for(std::chrono::microseconds(200));
      } else {
        spin_loop_pause();
      }
    }
  }

  /// For calls made **from a pool worker thread**: the task goes straight into
  /// that worker's own local queue (typical continuation / subtask path).
  template <typename F>
  bool try_submit_local(unsigned worker_index, F &&f) {
    if (worker_index >= worker_count_) {
      return false;
    }
    if (!accepting_.load(std::memory_order_acquire)) {
      return false;
    }
    Task task(std::forward<F>(f));
    if (!task) {
      return false;
    }
    if (!queues_[worker_index].try_push(std::move(task))) {
      return false;
    }
    notify_workers();
    return true;
  }

  /// Stop accepting work, wake every worker, and join them. Idempotent.
  void stop() noexcept {
    accepting_.store(false, std::memory_order_release);
    stop_requested_.store(true, std::memory_order_release);
#if defined(__linux__)
    for (std::size_t i = 0; i < workers_.size(); ++i) {
      (void)sem_post(&work_sem_); // one token per worker so all can exit wait
    }
#else
    // BUGFIX: acquire/release idle_mtx_ before notifying. A worker between its
    // predicate check and the actual wait would otherwise miss this one-shot
    // notification and block forever (lost wakeup).
    {
      std::lock_guard<std::mutex> lk(idle_mtx_);
    }
    idle_cv_.notify_all();
#endif
    for (auto &t : workers_) {
      if (t.joinable()) {
        t.join();
      }
    }
  }

private:
  /// Signal "work may exist" to one idle worker.
  void notify_workers() noexcept {
#if defined(__linux__)
    (void)sem_post(&work_sem_);
#else
    // BUGFIX: producers push tasks outside idle_mtx_, so notifying without
    // touching the mutex races with a worker that has evaluated the wait
    // predicate (seeing no work) but not yet blocked — the notify is lost and
    // the task sits until the next submit. Acquiring and releasing the mutex
    // first guarantees the worker is either before its predicate check (and
    // will observe the push) or already inside wait (and will receive the
    // notification).
    {
      std::lock_guard<std::mutex> lk(idle_mtx_);
    }
    idle_cv_.notify_one();
#endif
  }

  /// Main loop of worker `index`: drain local queue, steal, then sleep.
  void worker_loop(unsigned index) {
#if defined(__linux__)
    if (!name_prefix_.empty()) {
      // pthread thread names are limited to 15 chars + NUL; keep up to 10
      // chars of the prefix plus a ".<index>" suffix for indices < 100.
      std::string nm = name_prefix_;
      if (nm.size() > 10) {
        nm.resize(10);
      }
      if (index < 100) {
        char buf[16];
        std::snprintf(buf, sizeof(buf), ".%u", index);
        nm += buf;
      }
      nm.resize(std::min<std::size_t>(nm.size(), 15U));
      pthread_setname_np(pthread_self(), nm.c_str());
    }
#else
    (void)index;
#endif

    Task task;
    unsigned stable_empty_after_stop = 0;
    const auto n = static_cast<unsigned>(worker_count_);

    for (;;) {
      // 1) Drain the local queue (LIFO: freshest task first).
      while (queues_[index].try_pop_local(task)) {
        stable_empty_after_stop = 0;
        if (task) { // a slot can hold an empty function object
          try {
            task();
          } catch (...) {
            // Task exceptions must not take the worker down; swallow them.
          }
        }
      }

      // 2) Local queue empty: try each peer once, starting just after us.
      bool stole = false;
      for (unsigned r = 1; r < n; ++r) {
        const unsigned victim = (index + r) % n;
        if (queues_[victim].try_steal(task)) {
          stole = true;
          break;
        }
      }
      if (stole) {
        if (task) {
          try {
            task();
          } catch (...) {
          }
        }
        continue; // local queue may have refilled; drain it again
      }

      // 3) Nothing anywhere. During shutdown, require >100 consecutive empty
      //    observations (with 200us naps) before exiting so in-flight pushes
      //    get a grace period to drain.
      if (stop_requested_.load(std::memory_order_acquire)) {
        if (pending_approx() == 0) {
          if (++stable_empty_after_stop > 100) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::microseconds(200));
        } else {
          stable_empty_after_stop = 0;
          spin_loop_pause();
        }
        continue;
      }

      // 4) Idle: block until someone signals "work may exist".
      stable_empty_after_stop = 0;
#if defined(__linux__)
      for (;;) {
        if (sem_wait(&work_sem_) == 0) {
          break; // consumed one wakeup token
        }
        if (errno != EINTR) {
          break; // unexpected error: fall through and re-scan the queues
        }
      }
#else
      std::unique_lock<std::mutex> lk(idle_mtx_);
      idle_cv_.wait(lk, [this] {
        return stop_requested_.load(std::memory_order_acquire) ||
               pending_approx() > 0;
      });
#endif
    }
  }

  std::string name_prefix_;
  std::size_t worker_count_{0};
  std::unique_ptr<LocalQueue[]> queues_;
  std::vector<std::thread> workers_;
  std::atomic<unsigned> submit_rr_{0}; // round-robin cursor for submissions

#if defined(__linux__)
  sem_t work_sem_{};
  bool work_sem_ready_{false};
#else
  std::mutex idle_mtx_{};
  std::condition_variable idle_cv_{};
#endif
  std::atomic<bool> accepting_{true};
  std::atomic<bool> stop_requested_{false};
};

} // namespace topsun::concurrent
Related posts
zzzsde1 小时前
【Linux】线程概念与控制(3):线程ID&&C++封装线程
linux·运维·服务器·开发语言·算法
w_t_y_y1 小时前
VUE组件配置项(零)概述
前端·javascript·vue.js
水云桐程序员1 小时前
Web应用的分类
前端·javascript·vue.js·react.js·webkit
Jack N1 小时前
2026 Web 网站性能优化指南
前端·性能优化
UXbot1 小时前
支持移动端原型绘制的 AI 工具核心功能对比(2026):5 款主流平台能力横向评测
前端·低代码·ui·交互·原型模式·web app
handler011 小时前
滑动窗口(同向双指针)算法:模板与例题解析
c语言·c++·笔记·算法·蓝桥杯·双指针·滑动窗口
不做无法实现的梦~1 小时前
Linux 新手到日常运维操作指南
linux·运维·服务器
Brilliantwxx1 小时前
【算法题】基础计算器的不同实现方式
c++·算法
Sunsets_Red1 小时前
P12375 「LAOI-12」MST? 题解
c++·算法·洛谷·信息学·oier·洛谷题解