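The previous article,
Rust: Quantitative Strategy Backtesting, Building a Simple Thread Pool, and Observing Worker-Thread Execution,
introduced a custom thread-pool framework built on MPSC + Arc + Mutex.
This article introduces an MPMC-based thread pool instead.
Thanks to the article
A Simple Thread Pool Implementation in Rust
for the MPMC framework it provides; see that article for the details.
Compared with the previous article, the MPMC approach here is technically harder and has more key points to master.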
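I. Key Techniques
1. Strategy encapsulation
Unlike the strategy type used in the previous article,
rust
type Job = Box<dyn FnOnce() -> OutPut + Send + 'static>;
this article wraps strategies as follows:
rust
type Job = Box<dyn FnOnce() -> OutPut + Send + 'static + UnwindSafe>;
You do not need to think much about the extra UnwindSafe bound. The framework needs it because every task is run through std::panic::catch_unwind, which only accepts UnwindSafe closures.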
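A minimal sketch of why the bound matters. It assumes a simplified, hypothetical one-variant `OutPut` and a hypothetical `run_job` helper. Because `UnwindSafe` is an auto trait, it can be added to the trait object next to `Send`. The resulting boxed `Job` is itself `FnOnce` and `UnwindSafe`, so it can be passed directly to `catch_unwind`.

```rust
use std::panic::{self, UnwindSafe};

// Hypothetical, simplified stand-in for the article's OutPut enum.
#[derive(Debug, PartialEq)]
enum OutPut {
    TradeFlow0(f32),
}

// Same shape as the article's Job: UnwindSafe is an auto trait,
// so it can be listed on the trait object like Send.
type Job = Box<dyn FnOnce() -> OutPut + Send + 'static + UnwindSafe>;

// Hypothetical helper that builds a job capturing a parameter by value.
fn make_job(x: f32) -> Job {
    Box::new(move || OutPut::TradeFlow0(x * 2.0))
}

// Box<dyn FnOnce() -> OutPut> implements FnOnce, and the UnwindSafe bound on
// the trait object makes the Box UnwindSafe, so catch_unwind accepts it.
fn run_job(job: Job) -> std::thread::Result<OutPut> {
    panic::catch_unwind(job)
}
```

Without `+ UnwindSafe` in the alias, the `catch_unwind(job)` call would not compile.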
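2. Message passing
Two message-passing mechanisms are used here: an MPMC channel and a oneshot channel.
Unlike MPSC, MPMC is multi-producer and multi-consumer. Its Receiver implements Clone, so every worker thread can hold its own copy and pull tasks from the same queue.
rust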
let (task_sender, task_receiver) = std::sync::mpmc::channel::<Box<dyn FnOnce() + Send>>();
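`std::sync::mpmc` requires nightly. The following stable-Rust sketch shows what the difference means in practice. With `std::sync::mpsc`, the single `Receiver` cannot be cloned, so several consumers have to share it behind `Arc<Mutex<_>>`, which is the approach of the previous article. With an `mpmc::Receiver`, each worker would simply get its own clone. The function name `sum_with_workers` is hypothetical.

```rust
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

// Several consumers pulling from ONE mpsc channel. The Receiver is not Clone,
// so it is shared through Arc<Mutex<...>>. With nightly std::sync::mpmc,
// each worker would hold its own cloned Receiver instead.
fn sum_with_workers(workers: usize, jobs: Vec<u64>) -> u64 {
    let (tx, rx) = mpsc::channel::<u64>();
    let rx = Arc::new(Mutex::new(rx));

    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let rx = Arc::clone(&rx);
            thread::spawn(move || {
                let mut local = 0;
                loop {
                    // The lock guard is a temporary, released at the end of this statement.
                    let msg = rx.lock().unwrap().recv();
                    match msg {
                        Ok(v) => local += v,
                        Err(_) => break, // all senders dropped: channel closed
                    }
                }
                local
            })
        })
        .collect();

    for j in jobs {
        tx.send(j).unwrap();
    }
    drop(tx); // close the channel so every worker's recv() returns Err

    handles.into_iter().map(|h| h.join().unwrap()).sum()
}
```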
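A oneshot channel, by contrast, is single-producer and single-consumer and carries exactly one value. Each spawned task sends its result (wrapped in std::thread::Result) back through its own oneshot channel, and the JoinHandle stores the receiving end:
rust
pub struct JoinHandle<R> {
receiver: oneshot::Receiver<std::thread::Result<R>>,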
}
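The `oneshot` crate is not part of std. This sketch therefore emulates its role with a one-slot `std::sync::mpsc::sync_channel(1)`. Each task gets its own channel, the worker sends exactly one `thread::Result<R>`, and `join` blocks until that single value arrives. `MyJoinHandle` and `spawn_one` are hypothetical names, and a fresh `std::thread` stands in for a pool worker. In the article, the real handle wraps `oneshot::Receiver`.

```rust
use std::sync::mpsc::{self, Receiver, RecvError};
use std::thread;

// Hypothetical handle mirroring the article's JoinHandle<R>,
// backed by a one-slot std channel instead of oneshot::Receiver.
struct MyJoinHandle<R> {
    receiver: Receiver<thread::Result<R>>,
}

impl<R> MyJoinHandle<R> {
    // Blocks until the single result arrives; Err(RecvError) if the
    // sending side was dropped without sending anything.
    fn join(self) -> Result<thread::Result<R>, RecvError> {
        self.receiver.recv()
    }
}

fn spawn_one<F, R>(task: F) -> MyJoinHandle<R>
where
    F: FnOnce() -> R + Send + std::panic::UnwindSafe + 'static,
    R: Send + 'static,
{
    // Capacity 1: exactly one value is ever sent on this channel.
    let (tx, rx) = mpsc::sync_channel::<thread::Result<R>>(1);
    thread::spawn(move || {
        let ret = std::panic::catch_unwind(task);
        // An error here only means the handle was dropped, so ignore it.
        let _ = tx.send(ret);
    });
    MyJoinHandle { receiver: rx }
}
```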
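3. Dispatching tasks with scope and spawn
Scope::spawn boxes each task and sends it into the MPMC channel, where any idle worker picks it up. ThreadPool::scope runs the user closure and then waits until every task spawned in that scope has finished. The details follow in the items below.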
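4. PhantomData
Scope holds two zero-sized PhantomData markers. The PhantomData<&'a mut &'a ()> pattern makes 'scope and 'env invariant (std::thread::Scope uses the same fields), so the compiler can neither shrink nor stretch these lifetimes when it checks what spawned tasks may borrow.
rust
pub struct Scope<'scope, 'env: 'scope> {
running_tasks: std::sync::atomic::AtomicUsize,
task_sender: &'scope std::sync::mpmc::Sender<Box<dyn FnOnce() + Send>>,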
scope: std::marker::PhantomData<&'scope mut &'scope ()>,
env: std::marker::PhantomData<&'env mut &'env ()>,
}
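A small sketch of what these markers cost. It uses a hypothetical `Marker` type that copies the two `PhantomData` fields of `Scope`. The markers carry only type-level information, the invariant lifetimes, so they occupy zero bytes at runtime.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Hypothetical struct reusing the marker fields of the article's Scope.
struct Marker<'scope, 'env: 'scope> {
    // &'a mut &'a () is invariant in 'a: no subtyping on these lifetimes.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope> Marker<'scope, 'env> {
    fn new() -> Self {
        Marker { scope: PhantomData, env: PhantomData }
    }
}

// The markers are pure type-level information: zero bytes at runtime.
fn marker_size() -> usize {
    size_of::<Marker<'static, 'static>>()
}
```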
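5. transmute: forced type conversion
The task channel carries Box<dyn FnOnce() + Send>, whose trait object is implicitly 'static, but a scoped task may borrow data that only lives for 'scope. Scope::spawn therefore uses transmute to erase that lifetime. This relies on ThreadPool::scope not returning until every spawned task has finished. Note that, unlike std::thread::scope, this version does not wait if the closure passed to scope itself panics.
rust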
let boxed_task: Box<dyn FnOnce() + Send + 'scope> = Box::new(task);
let boxed_task: Box<dyn FnOnce() + Send + 'static> =
unsafe { std::mem::transmute(boxed_task) };
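A minimal, hypothetical sketch of the same lifetime-erasure trick outside the pool. Closures that borrow local data are transmuted to `'static` so they fit a queue of `'static` jobs, and they are then run before the borrowed data goes out of scope. `StaticJob`, `erase_lifetime`, and `sum_via_static_queue` are illustrative names. Draining the queue before returning plays the role that parking until `running_tasks == 0` plays in the article.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// A queue type that, like the pool's channel, only accepts 'static jobs.
type StaticJob = Box<dyn FnOnce() + Send + 'static>;

// Erases the lifetime of a boxed job. Unsafe: the caller must run or drop
// the job before anything it borrows goes out of scope.
unsafe fn erase_lifetime<'a>(job: Box<dyn FnOnce() + Send + 'a>) -> StaticJob {
    unsafe { std::mem::transmute(job) }
}

fn sum_via_static_queue(data: &[u64]) -> u64 {
    let total = AtomicU64::new(0);
    let mut queue: Vec<StaticJob> = Vec::new();

    for chunk in data.chunks(2) {
        let total_ref = &total;
        // Borrows `chunk` and `total`, so its real lifetime is local.
        let job = Box::new(move || {
            total_ref.fetch_add(chunk.iter().sum::<u64>(), Ordering::Relaxed);
        });
        // SAFETY: every job is executed below, while `data` and `total`
        // are still alive, so the erased borrows never dangle.
        queue.push(unsafe { erase_lifetime(job) });
    }

    // Drain the queue before returning, i.e. before the borrows end.
    for job in queue {
        job();
    }
    total.load(Ordering::Relaxed)
}
```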
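6. Blocking and waking threads
Blocking: after the user closure returns, scope parks the calling thread until running_tasks drops to zero. The while loop re-checks the counter because park can return spuriously.
rust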
pub fn scope<'env, F, T>(&'env self, f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
running_tasks: std::sync::atomic::AtomicUsize::new(0),
task_sender: &self.task_sender,
scope: PhantomData,
env: PhantomData,
};
let ret = f(&scope);
while scope.running_tasks.load(std::sync::atomic::Ordering::Acquire) > 0 {
std::thread::park();
}
return ret;
}
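Waking: spawn increments running_tasks before sending the task. When a task finishes, its worker decrements the counter, and the worker that takes it from 1 to 0 unparks the thread waiting in scope.
rust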
impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
pub fn spawn<F, R>(&'scope self, task: F) -> ScopeJoinHandle<'scope, R>
where
F: FnOnce() -> R + Send + 'scope,
R: Send + 'scope,
{
/* ... */
self.running_tasks.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let main_thread = std::thread::current();
let task = move || {
let ret = task();
let _ = oneshot_sender.send(ret); // the handle may already have been dropped
if self.running_tasks.fetch_sub(1, std::sync::atomic::Ordering::Release) == 1 {
main_thread.unpark();
}
};
/* ... */
}
}
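The park/unpark handshake can be tried on stable Rust. In this hypothetical sketch, plain `std::thread::spawn` stands in for the pool workers and `run_and_wait` is an illustrative name. As in `Scope::spawn`, the counter is incremented before each task starts. Each task decrements it when done, and whichever task takes it from 1 to 0 unparks the waiting thread. An `unpark` that arrives before `park` is not lost, because it leaves a token that makes the next `park` return at once. The `while` loop covers spurious wakeups.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;

// Runs `n` tiny tasks on fresh threads and waits for all of them with
// park/unpark instead of JoinHandle::join. Returns the sorted results.
fn run_and_wait(n: usize) -> Vec<usize> {
    let running = Arc::new(AtomicUsize::new(0));
    let results = Arc::new(Mutex::new(Vec::new()));
    let waiter = thread::current();

    for i in 0..n {
        // Count the task BEFORE it can possibly finish.
        running.fetch_add(1, Ordering::Relaxed);
        let running = Arc::clone(&running);
        let results = Arc::clone(&results);
        let waiter = waiter.clone();
        thread::spawn(move || {
            results.lock().unwrap().push(i * i);
            // Release: publish the work above before the count drops.
            if running.fetch_sub(1, Ordering::Release) == 1 {
                waiter.unpark(); // last running task: wake the waiter
            }
        });
    }

    // Acquire pairs with the Release above; the loop handles spurious wakeups.
    while running.load(Ordering::Acquire) > 0 {
        thread::park();
    }

    let mut out = results.lock().unwrap().clone();
    out.sort();
    out
}
```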
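7. Atomic operations
running_tasks is an AtomicUsize, so the spawning thread and the worker threads can update it concurrently without a lock:
rust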
use std::sync::atomic::{AtomicUsize, Ordering};
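The `== 1` test in the worker relies on one detail of the atomic API: `fetch_add` and `fetch_sub` return the value *before* the update. So `fetch_sub(1, ..) == 1` means "this call took the counter from 1 to 0", i.e. the current task was the last one running. A single-threaded sketch, where `TaskCounter` is a hypothetical wrapper:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical wrapper around the same counter pattern as Scope::running_tasks.
struct TaskCounter {
    running: AtomicUsize,
}

impl TaskCounter {
    fn new() -> Self {
        TaskCounter { running: AtomicUsize::new(0) }
    }

    // Called when a task is spawned; Relaxed is enough for the increment.
    fn start(&self) {
        self.running.fetch_add(1, Ordering::Relaxed);
    }

    // Called when a task finishes; returns true if it was the last one.
    // fetch_sub returns the PREVIOUS value, so 1 means "now 0".
    fn finish(&self) -> bool {
        self.running.fetch_sub(1, Ordering::Release) == 1
    }

    fn running(&self) -> usize {
        self.running.load(Ordering::Acquire)
    }
}
```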
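8. Catching panics
Each task is wrapped in std::panic::catch_unwind. A panicking strategy therefore produces an Err value that is sent to its join handle instead of tearing down the worker thread, and the counter is still decremented afterwards.
rust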
let task = move || {
let ret = std::panic::catch_unwind(task);
let _ = oneshot_sender.send(ret); // the handle may already have been dropped
if self.running_tasks.fetch_sub(1, std::sync::atomic::Ordering::Release) == 1 {
main_thread.unpark();
}
};
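A minimal sketch of what `catch_unwind` gives a worker, using a hypothetical `run_task` helper. A panicking task becomes an `Err` carrying the panic payload, which can be forwarded through the result channel, and the calling thread keeps running. The default panic hook still prints the panic message to stderr.

```rust
use std::panic;

// Runs a task the way the pool's workers do and reports the outcome as text.
fn run_task<F>(task: F) -> Result<i32, String>
where
    F: FnOnce() -> i32 + panic::UnwindSafe,
{
    match panic::catch_unwind(task) {
        Ok(v) => Ok(v),
        // The payload is Box<dyn Any + Send>: panic!("literal") carries a &str,
        // a formatted panic!("{}", x) carries a String.
        Err(payload) => {
            if let Some(s) = payload.downcast_ref::<&str>() {
                Err(s.to_string())
            } else if let Some(s) = payload.downcast_ref::<String>() {
                Err(s.clone())
            } else {
                Err("unknown panic".to_string())
            }
        }
    }
}
```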
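9. Switching to the nightly toolchain
std::sync::mpmc is not yet stable in the standard library, so the crate root needs:
rust
#![feature(mpmc_channel)]
If you are currently on the stable toolchain, switch to nightly before running the code below, for example:
shell
rustup default nightly
rustc --version
After switching and confirming the version, you can run the code below. To use nightly for a single command without changing the default toolchain, cargo +nightly run also works.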
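II. Full Code
1. Cargo.toml (only oneshot is used by the code below; chrono and rayon are not needed, and the rayon imports are commented out)
toml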
[dependencies]
oneshot = "0.1.1"
chrono = "0.4.24"
rayon = "1.7.0"
2. Code
rust
#![feature(mpmc_channel)]
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::panic::UnwindSafe;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpmc;
use std::thread;
use std::time::{Duration,Instant};
use oneshot; // third-party crate
// use rayon::ThreadPoolBuilder;
// use rayon::iter::IntoParallelRefMutIterator;
// MPMC thread pool
#[derive(Debug)]
pub struct ThreadPool {
task_sender: mpmc::Sender<Box<dyn FnOnce() + Send>>,
}
impl ThreadPool {
pub fn new(count: NonZeroUsize) -> Self {
let (task_sender, task_receiver) = mpmc::channel::<Box<dyn FnOnce() + Send>>();
for _ in 0..count.get() {
let task_receiver = task_receiver.clone();
std::thread::spawn(move || {
loop {
let Ok(task) = task_receiver.recv() else {
return;
};
task();
}
});
}
Self { task_sender }
}
pub fn spawn<F, R>(&self, task: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + UnwindSafe + 'static,
R: Send + 'static,
{
let (ret_sender, ret_receiver) = oneshot::channel::<std::thread::Result<R>>();
let task = move || {
let ret = std::panic::catch_unwind(task);
// If the receiver was dropped that means the caller was not interested in the result
_ = ret_sender.send(ret);
};
self.task_sender.send(Box::new(task)).expect("Unexpected error while sending tasks. This should never happen unless all threads are dropped");
JoinHandle {
receiver: ret_receiver,
}
}
pub fn scope<'env, F, T>(&'env self, f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
running_tasks: AtomicUsize::new(0),
task_sender: &self.task_sender,
scope: PhantomData,
env: PhantomData,
};
let ret = f(&scope);
while scope.running_tasks.load(Ordering::Acquire) > 0 {
std::thread::park();
}
return ret;
}
}
#[derive(Debug)]
pub struct JoinHandle<R> {
receiver: oneshot::Receiver<std::thread::Result<R>>,
}
impl<R> JoinHandle<R> {
pub fn join(self) -> Result<std::thread::Result<R>, oneshot::RecvError> {
self.receiver.recv()
}
}
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
running_tasks: AtomicUsize,
task_sender: &'scope mpmc::Sender<Box<dyn FnOnce() + Send>>,
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}
impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
pub fn spawn<F, R>(&'scope self, task: F) -> ScopeJoinHandle<'scope, R>
where
F: FnOnce() -> R + Send + UnwindSafe + 'scope,
R: Send + 'scope,
{
let (ret_sender, ret_receiver) = oneshot::channel::<std::thread::Result<R>>();
let main_thread = std::thread::current();
self.running_tasks.fetch_add(1, Ordering::Relaxed);
let task = move || {
let ret = std::panic::catch_unwind(task);
// If the receiver was dropped that means the caller was not interested in the result
_ = ret_sender.send(ret);
// Make sure the previous operation is completed before decrementing the counter
if self.running_tasks.fetch_sub(1, Ordering::Release) == 1 {
main_thread.unpark();
}
};
let boxed_task: Box<dyn FnOnce() + Send + 'scope> = Box::new(task);
let boxed_task: Box<dyn FnOnce() + Send + 'static> =
unsafe { std::mem::transmute(boxed_task) };
self.task_sender.send(boxed_task).expect("Unexpected error while sending tasks. This should never happen unless all threads are dropped");
ScopeJoinHandle {
receiver: ret_receiver,
scope: PhantomData,
}
}
}
#[derive(Debug)]
pub struct ScopeJoinHandle<'scope, R> {
receiver: oneshot::Receiver<std::thread::Result<R>>,
scope: PhantomData<&'scope mut &'scope ()>,
}
impl<'scope, R> ScopeJoinHandle<'scope, R> {
pub fn join(self) -> Result<std::thread::Result<R>, oneshot::RecvError> {
self.receiver.recv()
}
}
#[derive(Debug,Clone)]
enum Parameter{
P0(()),
P1(f32,f32),
P2(f32,f32,f32),
P3(f32,f32,f32,f32),
}
#[derive(Debug,Clone,PartialEq)]
enum OutPut{
TradeFlow0(f32),
TradeFlow1(Vec<f32>,Vec<f32>),
TradeFlow2(Vec<f32>),
}
type Job = Box<dyn FnOnce() -> OutPut + Send +'static +UnwindSafe>;
fn strategy_follow_trend(p:Parameter) ->OutPut{
println!("thread:{:?} run 【follow_trend】 strategy {:?}",thread::current().id(),p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow1(vec![1.0,2.0,3.0],vec![4.0,5.0,6.0])
}
fn strategy_bolling(p:Parameter) -> OutPut{
println!("thread:{:?} run 【bolling】 strategy {:?}",thread::current().id(),p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow2(vec![1.0,2.0,3.0])
}
fn strategy_high_freq(p:Parameter)->OutPut{
println!("thread:{:?} run 【high_freq】 strategy {:?}",thread::current().id(),p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow0(0.0)
}
fn get_strategies() -> Vec<Job>{
let p1 = Parameter::P1(2.0,3.0);
let s1: Job = Box::new(move || strategy_follow_trend(p1));
let p2 = Parameter::P2(2.0,3.0,4.0);
let s2: Job = Box::new(move || strategy_bolling(p2));
let p3 = Parameter::P0(());
let s3: Job = Box::new(move || strategy_high_freq(p3));
let p4 = Parameter::P0(());
let s4: Job = Box::new(move || strategy_high_freq(p4));
let p5 = Parameter::P1(2.0,3.0);
let s5: Job = Box::new(move || strategy_follow_trend(p5));
let p6 = Parameter::P2(2.0,3.0,4.0);
let s6: Job = Box::new(move || strategy_bolling(p6));
let p7 = Parameter::P0(());
let s7: Job = Box::new(move || strategy_high_freq(p7));
let p8 = Parameter::P0(());
let s8: Job = Box::new(move || strategy_high_freq(p8));
let strategies:Vec<Job> = vec![s1,s2,s3,s4,s5,s6,s7,s8];
strategies
}
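// Job (defined above) is a boxed closure that takes no arguments and is sent to the pool as a task.
// The Send bound lets it be moved into a worker thread, which spawn requires.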
fn main(){
thread_pool_run();
thread_pool_run_single();
thread_pool_run2();
}
fn thread_pool_run(){
println!("----------------thread_pool_run-----------------");
let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
let start = Instant::now();
//type Job = Box<dyn FnOnce() -> OutPut + Send +'static +UnwindSafe >;
pool.scope(|scope|{
let strategies = get_strategies();//vec<Job>
for strategy in strategies{
let job = strategy;
scope.spawn(||{
job();
});
}
});
let duration = start.elapsed();
println!("run duration : {:?}",duration);
}
fn thread_pool_run_single(){
println!("----------------thread_pool_run_single-----------------");
let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
let start = Instant::now();
pool.scope(|scope|{
let strategies = get_strategies();
for strategy in strategies{
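// Pitfall: strategy() executes right here, on the calling thread, before spawn.
let job = strategy();
scope.spawn(move||{
// The worker only receives the finished OutPut and drops it.
drop(job);
});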
}
});
let duration = start.elapsed();
println!("run duration : {:?}",duration);
}
fn thread_pool_run2(){
println!("----------------thread_pool_run2-----------------");
let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
let start = Instant::now();
pool.scope(|scope|{
scope.spawn(||{
let p0 = Parameter::P0(());
strategy_high_freq(p0);
});
scope.spawn(||{
let p1 = Parameter::P1(2.0,3.0);
strategy_follow_trend(p1);
});
scope.spawn(||{
let p2 = Parameter::P2(2.0,3.0,4.0);
strategy_bolling(p2);
});
scope.spawn(||{
let p3 = Parameter::P0(());
strategy_high_freq(p3);
});
scope.spawn(||{
let p4 = Parameter::P1(2.0,3.0);
strategy_follow_trend(p4);
});
scope.spawn(||{
let p5 = Parameter::P2(2.0,3.0,4.0);
strategy_bolling(p5);
});
scope.spawn(||{
let p6 = Parameter::P0(());
strategy_high_freq(p6);
});
scope.spawn(||{
let p7 = Parameter::P0(());
strategy_high_freq(p7);
});
});
let duration = start.elapsed();
println!("run duration : {:?}",duration);
}
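Note that thread_pool_run_single() is effectively a single-threaded version. Its `let job = strategy();` executes each strategy immediately on the calling (main) thread, and only the finished OutPut is handed to scope.spawn. That is why all eight of its output lines show ThreadId(1) and the run takes about 8 s instead of about 3 s.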
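The pitfall can be reproduced in isolation. In this hypothetical sketch, `work` reports the id of the thread that runs it, and plain `std::thread::spawn` stands in for the pool. Calling `work()` before spawning, as `thread_pool_run_single` does with `strategy()`, runs it on the caller. Calling it inside the closure runs it on the spawned thread.

```rust
use std::thread::{self, ThreadId};

// Hypothetical stand-in for a strategy: reports which thread executed it.
fn work() -> ThreadId {
    thread::current().id()
}

// Mirrors thread_pool_run_single: the work runs BEFORE spawn, on the caller;
// the spawned thread only receives the finished result.
fn eager() -> ThreadId {
    let result = work();
    thread::spawn(move || result).join().unwrap()
}

// Mirrors thread_pool_run: the closure itself calls the work,
// so it executes on the spawned thread.
fn lazy() -> ThreadId {
    thread::spawn(|| work()).join().unwrap()
}
```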
III. Output
text
----------------thread_pool_run-----------------
thread:ThreadId(2) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(3) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(3) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(2) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(3) run 【follow_trend】 strategy P0(())
run duration : 3.0030121s
----------------thread_pool_run_single-----------------
thread:ThreadId(1) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(1) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(1) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
run duration : 8.0092952s
----------------thread_pool_run2-----------------
thread:ThreadId(8) run 【follow_trend】 strategy P0(())
thread:ThreadId(10) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(9) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(10) run 【follow_trend】 strategy P0(())
thread:ThreadId(8) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(9) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(10) run 【follow_trend】 strategy P0(())
thread:ThreadId(8) run 【follow_trend】 strategy P0(())
run duration : 3.003468s
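The output above clearly shows how the tasks are distributed across the threads. Every line carries the 【follow_trend】 label because all three strategy functions printed that same label when this output was recorded. The parameter variant tells the strategies apart: P1 is follow_trend, P2 is bolling, and P0 is high_freq.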
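The measured durations match a simple back-of-the-envelope model. It assumes every strategy sleeps exactly 1 s and ignores scheduling overhead. With 3 workers, 8 tasks need ⌈8/3⌉ = 3 rounds; on one thread they need 8. `expected_makespan` is a hypothetical helper:

```rust
use std::time::Duration;

// Idealized makespan for `tasks` equal-length jobs on `workers` threads:
// ceil(tasks / workers) rounds, each lasting `per_task`.
fn expected_makespan(tasks: u32, workers: u32, per_task: Duration) -> Duration {
    assert!(workers > 0, "need at least one worker");
    let rounds = (tasks + workers - 1) / workers; // integer ceiling division
    per_task * rounds
}
```

This gives 3 s for thread_pool_run and thread_pool_run2 and 8 s for the single-threaded case, close to the measured 3.003 s and 8.009 s.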
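rust
thread_pool_run(); // correct
thread_pool_run_single(); // incorrect: a separate pitfall, the strategies run on the main thread
thread_pool_run2(); // correct
Overall, the behavior is essentially the same as that of the custom version and the tokio version.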