Rust: Quantitative Strategy Backtesting and Building a Simple Thread Pool (MPMC)

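In the previous article
Rust: Quantitative Strategy Backtesting, Building a Simple Thread Pool, and Observing Child-Thread Execution

we introduced a custom thread pool built on MPSC + Arc + Mutex.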

This article introduces an MPMC-based thread pool.

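Thanks to the following article:
A Simple Thread Pool Implementation in Rust

which provided the MPMC framework used here. See that article for further details.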

Compared with the previous article, the MPMC approach here is technically more demanding and involves more key points.

I. Key Techniques

1. Strategy encapsulation

The previous article used this strategy type:

type Job = Box<dyn FnOnce() -> OutPut + Send + 'static>;

This time the strategy is wrapped in a type with an extra UnwindSafe bound:

type Job = Box<dyn FnOnce() -> OutPut + Send + 'static + UnwindSafe>;

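The framework requires the UnwindSafe bound. The pool runs every task inside std::panic::catch_unwind, and that function only accepts UnwindSafe closures. A strategy closure that captures its Parameter by value satisfies the bound automatically, so in practice you rarely need to think about it.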
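Below is a minimal sketch of why the bound matters. It assumes a one-variant stand-in for OutPut and uses a hypothetical run_job helper. Because Job carries + UnwindSafe, the boxed closure can be handed directly to catch_unwind. A panicking strategy then comes back as Err instead of tearing down the thread.

```rust
use std::panic;
use std::panic::UnwindSafe;

// Simplified stand-in for the article's OutPut enum (assumption: one variant suffices here).
#[derive(Debug, PartialEq)]
enum OutPut {
    TradeFlow0(f32),
}

// Same shape as the article's Job alias.
type Job = Box<dyn FnOnce() -> OutPut + Send + 'static + UnwindSafe>;

// Hypothetical helper: runs one job and turns a panic into Err(..).
// This compiles only because Job is UnwindSafe, which catch_unwind requires.
fn run_job(job: Job) -> std::thread::Result<OutPut> {
    panic::catch_unwind(job)
}
```

In the article's pool, the catch_unwind call lives inside the spawn wrapper rather than in a separate helper.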

2. Message passing

Two message-passing mechanisms are used here: an MPMC channel and a oneshot channel.

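Unlike MPSC, MPMC is multi-producer, multi-consumer. The receiver can be cloned, so every worker thread pulls tasks from the same queue, and each task is received by exactly one worker.

let (task_sender, task_receiver) = std::sync::mpmc::channel::<Box<dyn FnOnce() + Send>>();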

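A oneshot channel, by contrast, carries exactly one value from a single producer to a single consumer. Each spawned task gets one to send back its result. The result is wrapped in std::thread::Result because the task runs under catch_unwind:

pub struct JoinHandle<R> {
    receiver: oneshot::Receiver<std::thread::Result<R>>,
}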
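std::sync::mpmc needs nightly, so the sketch below stays on stable. It emulates the shared receiver with Arc<Mutex<mpsc::Receiver>>, which is the previous article's approach. On nightly, each worker would simply get task_rx.clone(). Each task also gets its own single-use mpsc channel for its result, standing in for the oneshot crate. The square_all helper is hypothetical.

```rust
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

// Task type matching the article's channel element type.
type Task = Box<dyn FnOnce() + Send>;

// Hypothetical helper: squares each input on `workers` threads (must be > 0)
// and returns the results in submission order.
// Assumption: std::sync::mpmc is nightly-only, so the multi-consumer receiver is
// emulated with Arc<Mutex<mpsc::Receiver>>; each task's single-use result
// channel plays the role of the oneshot crate.
fn square_all(inputs: Vec<u64>, workers: usize) -> Vec<u64> {
    let (task_tx, task_rx) = mpsc::channel::<Task>();
    let task_rx = Arc::new(Mutex::new(task_rx));

    let handles: Vec<thread::JoinHandle<()>> = (0..workers)
        .map(|_| {
            let task_rx = Arc::clone(&task_rx);
            thread::spawn(move || loop {
                // The lock guard is a temporary dropped at the end of this
                // statement, so the task below runs without holding the lock.
                let next = task_rx.lock().unwrap().recv();
                match next {
                    Ok(task) => task(),
                    Err(_) => return, // every sender is gone: shut down
                }
            })
        })
        .collect();

    // Producer side: one result channel per task.
    let result_rxs: Vec<mpsc::Receiver<u64>> = inputs
        .into_iter()
        .map(|x| {
            let (ret_tx, ret_rx) = mpsc::channel();
            let task: Task = Box::new(move || {
                // Ignore the error if the receiver was dropped.
                let _ = ret_tx.send(x * x);
            });
            task_tx.send(task).unwrap();
            ret_rx
        })
        .collect();

    drop(task_tx); // close the queue so idle workers exit
    let results: Vec<u64> = result_rxs.into_iter().map(|rx| rx.recv().unwrap()).collect();
    for h in handles {
        h.join().unwrap();
    }
    results
}
```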

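3. Dispatch via scope and spawn

ThreadPool::scope builds a Scope, runs the caller's closure with it, and then blocks until every task spawned through Scope::spawn has finished. Scope::spawn boxes each task and sends it to the workers over the MPMC channel. The pieces involved are described below.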
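For comparison, the standard library's scoped threads expose the same scope/spawn shape: std::thread::scope takes a closure over &Scope<'scope, 'env>. The sketch below uses a hypothetical parallel_sum helper and one OS thread per chunk rather than a pool. The spawned closures borrow slices of data, which is allowed because scope does not return until they have all finished.

```rust
use std::thread;

// Hypothetical helper: sums `data` in parallel, one scoped thread per chunk.
// `chunk` must be non-zero, since slice::chunks panics on 0.
fn parallel_sum(data: &[u64], chunk: usize) -> u64 {
    thread::scope(|s| {
        let handles: Vec<_> = data
            .chunks(chunk)
            // Each closure borrows a sub-slice of `data`; no 'static needed.
            .map(|part| s.spawn(move || part.iter().sum::<u64>()))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}
```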

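4. PhantomData

The two PhantomData fields store no data. A marker of type &'a mut &'a () makes Scope invariant over 'scope and 'env, so the compiler cannot silently shorten or extend those lifetimes. std::thread::Scope uses the same two markers. The running_tasks counter is used by the blocking logic in section 6.

pub struct Scope<'scope, 'env: 'scope> {
    running_tasks: std::sync::atomic::AtomicUsize,
    task_sender: &'scope std::sync::mpmc::Sender<Box<dyn FnOnce() + Send>>,
    scope: std::marker::PhantomData<&'scope mut &'scope ()>,
    env: std::marker::PhantomData<&'env mut &'env ()>,
}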
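Here is a quick check that the two markers add nothing to the layout and exist only to constrain lifetimes. It uses a hypothetical ScopeLike struct with a plain &u64 in place of the Sender reference.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Hypothetical struct with the same lifetime markers as the article's Scope;
// `data` stands in for the `&'scope mpmc::Sender<..>` field.
#[allow(dead_code)]
struct ScopeLike<'scope, 'env: 'scope> {
    data: &'scope u64,
    // Zero-sized markers; `&'a mut &'a ()` makes the struct invariant in 'a.
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

// Returns (size of ScopeLike, size of a plain reference).
fn marker_sizes() -> (usize, usize) {
    (size_of::<ScopeLike<'static, 'static>>(), size_of::<&u64>())
}
```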

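5. transmute: forced type conversion

Scope::spawn accepts closures that only live for 'scope. The channel, however, carries Box<dyn FnOnce() + Send>, which implicitly means + 'static. transmute erases the lifetime:

let boxed_task: Box<dyn FnOnce() + Send + 'scope> = Box::new(task);
let boxed_task: Box<dyn FnOnce() + Send + 'static> =
    unsafe { std::mem::transmute(boxed_task) };

This is sound only because ThreadPool::scope does not return until running_tasks drops to zero, so everything a task borrows outlives the task. One caveat applies: if the closure passed to scope panics, scope unwinds without waiting, and tasks still in flight may touch freed data. std::thread::scope avoids this by catching the panic, waiting for all threads, and then resuming the panic.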
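The sketch below shows the same lifetime erasure outside the pool, using a hypothetical run_borrowing_on_thread helper. The closure borrows local data, and transmute relabels it as 'static so that std::thread::spawn accepts it. The immediate join plays the role of scope's wait; without that join, the code would be unsound.

```rust
use std::thread;

// Hypothetical helper: erases the closure's lifetime the way Scope::spawn does,
// runs it on a new OS thread, and joins before returning. The join guarantees
// the thread is finished before the caller's borrows end.
fn run_borrowing_on_thread<'a>(task: Box<dyn FnOnce() + Send + 'a>) {
    // SAFETY: the spawned thread is joined below, before 'a can end.
    let task: Box<dyn FnOnce() + Send + 'static> = unsafe { std::mem::transmute(task) };
    thread::spawn(task).join().expect("task panicked");
}
```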

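6. Blocking and waking threads

Blocking: after the caller's closure returns, scope parks the current thread until running_tasks reaches zero. The condition is re-checked in a loop because std::thread::park can also return spuriously.

pub fn scope<'env, F, T>(&'env self, f: F) -> T
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
    let scope = Scope {
        running_tasks: std::sync::atomic::AtomicUsize::new(0),
        task_sender: &self.task_sender,
        scope: PhantomData,
        env: PhantomData,
    };
    let ret = f(&scope);
    while scope.running_tasks.load(std::sync::atomic::Ordering::Acquire) > 0 {
        std::thread::park();
    }
    ret
}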

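Waking: spawn increments running_tasks and records the current thread, which is the thread that owns the scope. When a task finishes, it decrements the counter, and the task that brings it from 1 to 0 unparks that thread.

impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
    pub fn spawn<F, R>(&'scope self, task: F) -> ScopeJoinHandle<'scope, R>
    where
        F: FnOnce() -> R + Send + 'scope,
        R: Send + 'scope,
    {
        /* ... */
        self.running_tasks.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let main_thread = std::thread::current();
        let task = move || {
            let ret = task();
            // Ignore a send error (the caller dropped the handle); unwrapping here
            // would panic before the decrement and leave scope waiting forever.
            let _ = ret_sender.send(ret);
            if self.running_tasks.fetch_sub(1, std::sync::atomic::Ordering::Release) == 1 {
                main_thread.unpark();
            }
        };
        /* ... */
    }
}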
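The same blocking/waking protocol can be tried with plain std threads. In this minimal sketch, built around a hypothetical wait_for_tasks helper, the waiter counts tasks up front. Each task decrements the counter when done, and the task that brings it to zero unparks the waiter.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

// Hypothetical helper reproducing the article's wait/wake protocol with plain
// std threads. Returns how many tasks had finished when the waiter woke up.
fn wait_for_tasks(n: usize) -> usize {
    let running = Arc::new(AtomicUsize::new(0));
    let finished = Arc::new(AtomicUsize::new(0));
    let waiter = thread::current();

    for i in 0..n {
        // Count the task before it can possibly finish.
        running.fetch_add(1, Ordering::Relaxed);
        let running = Arc::clone(&running);
        let finished = Arc::clone(&finished);
        let waiter = waiter.clone();
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(5 * i as u64)); // simulated work
            finished.fetch_add(1, Ordering::Relaxed);
            // fetch_sub returns the old value: 1 means this was the last task.
            if running.fetch_sub(1, Ordering::Release) == 1 {
                waiter.unpark();
            }
        });
    }

    // Re-check after every wake-up: park() may return spuriously, and an early
    // unpark (from a task that finished before the rest were spawned) is not final.
    while running.load(Ordering::Acquire) > 0 {
        thread::park();
    }
    finished.load(Ordering::Relaxed)
}
```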

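7. Atomic operations

running_tasks is an AtomicUsize shared by the scope thread and the workers. spawn increments it with Relaxed. That is enough because the increment happens on the scope thread before the task is even sent. Each finished task decrements the counter with Release, and scope reads it with Acquire. Once scope sees zero, everything the tasks did before decrementing (including sending their results) is visible to it. fetch_sub returns the previous value, so the task that sees 1 knows it was the last one running.

use std::sync::atomic::{AtomicUsize, Ordering};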
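fetch_sub returns the value before the subtraction, which is what makes the "last task unparks" rule work. Among n concurrent decrements of a counter that starts at n, exactly one sees 1. Below is a small check with a hypothetical count_last_finishers helper; std::thread::scope is used only to run the threads.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;

// Hypothetical helper: n threads each decrement a counter that starts at n and
// count how many of them saw the previous value 1, i.e. believed they were last.
fn count_last_finishers(n: usize) -> usize {
    let running = AtomicUsize::new(n);
    let last = AtomicUsize::new(0);
    thread::scope(|s| {
        for _ in 0..n {
            s.spawn(|| {
                // Release pairs with an Acquire load in a waiter (not shown here).
                if running.fetch_sub(1, Ordering::Release) == 1 {
                    last.fetch_add(1, Ordering::Relaxed);
                }
            });
        }
    }); // all scoped threads are joined here
    last.load(Ordering::Relaxed)
}
```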

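8. Catching panics

Each task runs inside std::panic::catch_unwind. A panic becomes Err(..) and is delivered through the oneshot channel, and the counter is still decremented. Without catch_unwind, a panicking task would end its worker thread, skip the decrement, and leave scope parked forever.

let task = move || {
    let ret = std::panic::catch_unwind(task);
    // The receiver may have been dropped; ignore the error instead of unwrapping.
    let _ = ret_sender.send(ret);
    if self.running_tasks.fetch_sub(1, std::sync::atomic::Ordering::Release) == 1 {
        main_thread.unpark();
    }
};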
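Here is a single-threaded sketch of the wrapper's bookkeeping, using a hypothetical run_guarded helper. Because the task runs under catch_unwind, the decrement after it happens even when the task panics, so the counter returns to zero.

```rust
use std::panic::{self, UnwindSafe};
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical helper mirroring the article's task wrapper on a single thread:
// the task runs under catch_unwind, so the decrement that follows it happens
// whether the task returns normally or panics.
fn run_guarded<F>(task: F, running: &AtomicUsize) -> std::thread::Result<u32>
where
    F: FnOnce() -> u32 + UnwindSafe,
{
    running.fetch_add(1, Ordering::Relaxed);
    let ret = panic::catch_unwind(task); // a panic becomes Err(payload)
    running.fetch_sub(1, Ordering::Release);
    ret
}
```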

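9. Switching to nightly

The MPMC channel (std::sync::mpmc) is not yet stable in the standard library, so the crate root needs:

#![feature(mpmc_channel)]

If you are on the stable toolchain, switch to nightly before running the code below:

rustup default nightly
rustc --version

After the switch, rustc --version should report a -nightly version. Alternatively, once nightly is installed, cargo +nightly run uses it for a single command without changing the global default.

With the toolchain switched and confirmed, you can run the code below.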

II. Full Code

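1. Cargo.toml

Only oneshot is used by the code below. chrono and rayon are not referenced (the rayon imports are commented out).
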
[dependencies]
oneshot = "0.1.1"
chrono = "0.4.24"
rayon = "1.7.0"

2. Code

#![feature(mpmc_channel)]
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::panic::UnwindSafe;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpmc;
use std::thread;
use std::time::{Duration,Instant};
use oneshot; // third-party crate
// use rayon::ThreadPoolBuilder;
// use rayon::iter::IntoParallelRefMutIterator;

// MPMC thread pool
#[derive(Debug)]
pub struct ThreadPool {
    task_sender: mpmc::Sender<Box<dyn FnOnce() + Send>>,
}

impl ThreadPool {
    pub fn new(count: NonZeroUsize) -> Self {
        let (task_sender, task_receiver) = mpmc::channel::<Box<dyn FnOnce() + Send>>();
        for _ in 0..count.get() {
            let task_receiver = task_receiver.clone();
            std::thread::spawn(move || {
                loop {
                    let Ok(task) = task_receiver.recv() else {
                        return;
                    };
                    task();
                }
            });
        }
        Self { task_sender }
    }
    pub fn spawn<F, R>(&self, task: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + UnwindSafe + 'static,
        R: Send + 'static,
    {
        let (ret_sender, ret_receiver) = oneshot::channel::<std::thread::Result<R>>();
        let task = move || {
            let ret = std::panic::catch_unwind(task);
            // If the receiver was dropped that means the caller was not interested in the result
            _ = ret_sender.send(ret);
        };
        self.task_sender.send(Box::new(task)).expect("Unexpected error while sending tasks. This should never happen unless all threads are dropped");
        JoinHandle {
            receiver: ret_receiver,
        }
    }
    pub fn scope<'env, F, T>(&'env self, f: F) -> T
    where
        F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
    {
        let scope = Scope {
            running_tasks: AtomicUsize::new(0),
            task_sender: &self.task_sender,
            scope: PhantomData,
            env: PhantomData,
        };
        let ret = f(&scope);
        while scope.running_tasks.load(Ordering::Acquire) > 0 {
            std::thread::park();
        }
        return ret;
    }
}

#[derive(Debug)]
pub struct JoinHandle<R> {
    receiver: oneshot::Receiver<std::thread::Result<R>>,
}

impl<R> JoinHandle<R> {
    pub fn join(self) -> Result<std::thread::Result<R>, oneshot::RecvError> {
        self.receiver.recv()
    }
}

#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
    running_tasks: AtomicUsize,
    task_sender: &'scope mpmc::Sender<Box<dyn FnOnce() + Send>>,
    scope: PhantomData<&'scope mut &'scope ()>,
    env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
    pub fn spawn<F, R>(&'scope self, task: F) -> ScopeJoinHandle<'scope, R>
    where
        F: FnOnce() -> R + Send + UnwindSafe + 'scope,
        R: Send + 'scope,
    {
        let (ret_sender, ret_receiver) = oneshot::channel::<std::thread::Result<R>>();
        let main_thread = std::thread::current();
        self.running_tasks.fetch_add(1, Ordering::Relaxed);
        let task = move || {
            let ret = std::panic::catch_unwind(task);
            // If the receiver was dropped that means the caller was not interested in the result
            _ = ret_sender.send(ret);
            // Make sure the previous operation is completed before decrementing the counter
            if self.running_tasks.fetch_sub(1, Ordering::Release) == 1 {
                main_thread.unpark();
            }
        };
        let boxed_task: Box<dyn FnOnce() + Send + 'scope> = Box::new(task);
        let boxed_task: Box<dyn FnOnce() + Send + 'static> =
            unsafe { std::mem::transmute(boxed_task) };
        self.task_sender.send(boxed_task).expect("Unexpected error while sending tasks. This should never happen unless all threads are dropped");
        ScopeJoinHandle {
            receiver: ret_receiver,
            scope: PhantomData,
        }
    }
}

#[derive(Debug)]
pub struct ScopeJoinHandle<'scope, R> {
    receiver: oneshot::Receiver<std::thread::Result<R>>,
    scope: PhantomData<&'scope mut &'scope ()>,
}

impl<'scope, R> ScopeJoinHandle<'scope, R> {
    pub fn join(self) -> Result<std::thread::Result<R>, oneshot::RecvError> {
        self.receiver.recv()
    }
}

#[derive(Debug,Clone)]
enum Parameter{
    P0(()),
    P1(f32,f32),
    P2(f32,f32,f32),
    P3(f32,f32,f32,f32),
}
#[derive(Debug,Clone,PartialEq)]
enum OutPut{
    TradeFlow0(f32),
    TradeFlow1(Vec<f32>,Vec<f32>),
    TradeFlow2(Vec<f32>),
}
type Job = Box<dyn FnOnce() -> OutPut + Send +'static +UnwindSafe>;
fn strategy_follow_trend(p:Parameter) ->OutPut{
    println!("thread:{:?} run 【follow_trend】 strategy {:?}",thread::current().id(),p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow1(vec![1.0,2.0,3.0],vec![4.0,5.0,6.0])
    
}
fn strategy_bolling(p:Parameter) -> OutPut{
    println!("thread:{:?} run 【follow_trend】 strategy {:?}",thread::current().id(),p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow2(vec![1.0,2.0,3.0])
    
}
fn strategy_high_freq(p:Parameter)->OutPut{
    println!("thread:{:?} run 【follow_trend】 strategy {:?}",thread::current().id(),p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow0(0.0)
    
}

fn get_strategies() -> Vec<Job>{
    let p1 = Parameter::P1(2.0,3.0);
    let s1: Job  = Box::new(move || strategy_follow_trend(p1));
    let p2 = Parameter::P2(2.0,3.0,4.0);
    let s2: Job  = Box::new(move || strategy_bolling(p2));
    let p3 = Parameter::P0(());
    let s3: Job  = Box::new(move || strategy_high_freq(p3));
    let p4 = Parameter::P0(());
    let s4: Job  = Box::new(move || strategy_high_freq(p4));
    let p5 = Parameter::P1(2.0,3.0);
    let s5: Job  = Box::new(move || strategy_follow_trend(p5));
    let p6 = Parameter::P2(2.0,3.0,4.0);
    let s6: Job  = Box::new(move || strategy_bolling(p6));
    let p7 = Parameter::P0(());
    let s7: Job  = Box::new(move || strategy_high_freq(p7));
    let p8 = Parameter::P0(());
    let s8: Job  = Box::new(move || strategy_high_freq(p8));
    let strategies:Vec<Job> = vec![s1,s2,s3,s4,s5,s6,s7,s8];
    strategies
}
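// Job (defined above) is a boxed closure with no input arguments, used as the unit of work sent to the pool.
// The Send bound is needed so the closure can be moved to a worker thread, as spawn requires.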

fn main(){
    thread_pool_run();
    thread_pool_run_single();
    thread_pool_run2();
}
fn thread_pool_run(){
    println!("----------------thread_pool_run-----------------");
    let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
    let start = Instant::now();
    //type Job = Box<dyn FnOnce() -> OutPut + Send +'static +UnwindSafe >;
    pool.scope(|scope|{
        let strategies = get_strategies();//vec<Job>
        for strategy in strategies{
            let job = strategy;
            scope.spawn(||{
                job();
            });
        } 
        
    });
        
    let duration = start.elapsed();
    println!("run duration : {:?}",duration);
}


fn thread_pool_run_single(){
    println!("----------------thread_pool_run_single-----------------");
    let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
    let start = Instant::now();
    
    pool.scope(|scope|{
        let strategies = get_strategies();
        for strategy in strategies{
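            // Pitfall: strategy() runs here, on the calling thread, before spawn,
            // so only its finished OutPut is moved into the pool.
            let job = strategy();
            scope.spawn(move||{
                drop(job);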
            });
        } 
        
    });
        
    let duration = start.elapsed();
    println!("run duration : {:?}",duration);
}
fn thread_pool_run2(){
    println!("----------------thread_pool_run2-----------------");
    let pool = ThreadPool::new(NonZeroUsize::new(3).unwrap());
    let start = Instant::now();
    pool.scope(|scope|{
        scope.spawn(||{
            let p0 = Parameter::P0(());
            strategy_high_freq(p0);
        });
        scope.spawn(||{
            let p1 = Parameter::P1(2.0,3.0);
            strategy_follow_trend(p1);
        });
        scope.spawn(||{
            let p2 = Parameter::P2(2.0,3.0,4.0);
            strategy_bolling(p2);
        });
        scope.spawn(||{
            let p3 = Parameter::P0(());
            strategy_high_freq(p3);
        });
        scope.spawn(||{
            let p4 = Parameter::P1(2.0,3.0);
            strategy_follow_trend(p4);
        });
        scope.spawn(||{
            let p5 = Parameter::P2(2.0,3.0,4.0);
            strategy_bolling(p5);
        });
        scope.spawn(||{
            let p6 = Parameter::P0(());
            strategy_high_freq(p6);
        });
        scope.spawn(||{
            let p7 = Parameter::P0(());
            strategy_high_freq(p7);
        });
    });
        
    let duration = start.elapsed();
    println!("run duration : {:?}",duration);
}

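Note that thread_pool_run_single() effectively runs single-threaded. Its loop calls strategy() on the calling thread before scope.spawn, so the eight strategies execute one after another on ThreadId(1), and only their already-computed results are handed to the pool.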
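The pitfall is easy to reproduce. The sketch below uses std::thread::scope as a stand-in for the article's pool, and the where_jobs_run helper and its eager flag are hypothetical. It records which thread each job actually ran on.

```rust
use std::thread::{self, ThreadId};

// Hypothetical helper built on std::thread::scope (a stand-in for the article's pool).
// eager = true mirrors thread_pool_run_single: the job is called before spawn.
// eager = false mirrors thread_pool_run: the job itself is moved into the task.
// Returns the id of the thread on which each job actually ran.
fn where_jobs_run(n: usize, eager: bool) -> Vec<ThreadId> {
    let job = || thread::current().id();
    thread::scope(|s| {
        let handles: Vec<_> = (0..n)
            .map(|_| {
                if eager {
                    let done = job(); // runs here, on the calling thread
                    s.spawn(move || done)
                } else {
                    s.spawn(move || job()) // runs on the spawned thread
                }
            })
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    })
}
```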

III. Output

----------------thread_pool_run-----------------
thread:ThreadId(2) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(3) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(3) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(2) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(4) run 【follow_trend】 strategy P0(())
thread:ThreadId(3) run 【follow_trend】 strategy P0(())
run duration : 3.0030121s
----------------thread_pool_run_single-----------------
thread:ThreadId(1) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(1) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(1) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
thread:ThreadId(1) run 【follow_trend】 strategy P0(())
run duration : 8.0092952s
----------------thread_pool_run2-----------------
thread:ThreadId(8) run 【follow_trend】 strategy P0(())
thread:ThreadId(10) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(9) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(10) run 【follow_trend】 strategy P0(())
thread:ThreadId(8) run 【follow_trend】 strategy P2(2.0, 3.0, 4.0)
thread:ThreadId(9) run 【follow_trend】 strategy P1(2.0, 3.0)
thread:ThreadId(10) run 【follow_trend】 strategy P0(())
thread:ThreadId(8) run 【follow_trend】 strategy P0(())
run duration : 3.003468s

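The output shows which worker thread ran each strategy. With three workers and eight strategies that each sleep 1 s, thread_pool_run and thread_pool_run2 finish in about 3 s, while thread_pool_run_single takes about 8 s. Every line is labelled 【follow_trend】 because strategy_bolling and strategy_high_freq reuse the same println! text. The Parameter variant shows which strategy actually ran (P1: follow_trend, P2: bolling, P0: high_freq).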
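The measured times fit a simple round-based model, sketched here with a hypothetical expected_secs helper. It assumes every strategy sleeps exactly 1 s and the pool adds no overhead.

```rust
// Back-of-envelope model (assumption: every strategy sleeps exactly 1 s and the
// pool itself adds no overhead). With `workers` threads, `tasks` jobs need
// ceil(tasks / workers) one-second rounds. `workers` must be non-zero.
fn expected_secs(tasks: u64, workers: u64) -> u64 {
    (tasks + workers - 1) / workers
}
```

For the runs above, expected_secs(8, 3) gives 3 and expected_secs(8, 1) gives 8, matching the measured 3.003 s and 8.009 s up to overhead.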

thread_pool_run();        // OK: strategies run in parallel
thread_pool_run_single(); // not parallel: a separate pitfall (strategy() runs on the calling thread)
thread_pool_run2();       // OK

This behavior is essentially the same as that of the custom (MPSC) version and the tokio version.
