Rust: Quant Strategy Backtesting with a Simple Custom Thread Pool and Worker-Thread Observation

The previous article, "Rust: Quant Strategy Backtesting and a Multithreading Prototype", showed how to use multiple threads to build a multi-strategy backtesting framework. This article explores how to build the backtesting framework on top of a simple, custom thread pool.

1. Technical Points

Besides the multithreading points already covered (see the previous article for details), we also need a message-passing mechanism to dispatch strategy tasks to the different worker threads.

For message passing, two approaches are available:

1. The MPSC + Arc + Mutex approach

MPSC is by design a multi-producer, single-consumer channel, but by wrapping the single Receiver in Arc<Mutex<...>> it can be shared across threads, so strategy tasks can still be dispatched to different worker threads.
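
A minimal, standalone sketch of that receiver-sharing pattern (a toy example, separate from the pool built in the next section):

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel::<u32>();
    // Wrap the single Receiver in Arc<Mutex<...>> so several workers can share it.
    let rx = Arc::new(Mutex::new(rx));
    let mut handles = Vec::new();
    for id in 0..3 {
        let rx = Arc::clone(&rx);
        handles.push(thread::spawn(move || loop {
            // Hold the lock only long enough to pull one task; the guard is dropped
            // at the end of this statement, before the task is processed.
            let task = rx.lock().unwrap().recv();
            match task {
                Ok(task) => println!("worker {} got task {}", id, task),
                Err(_) => break, // all senders dropped: channel closed
            }
        }));
    }
    for t in 0..6 {
        tx.send(t).unwrap();
    }
    drop(tx); // closing the channel lets every worker exit its loop
    for h in handles {
        h.join().unwrap();
    }
}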

2. The MPMC approach

MPMC is a multi-producer, multi-consumer model, which naturally covers the case of a single sender and multiple consumers.
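
The standard library's own MPMC channel (std::sync::mpmc) is, as far as I know, still unstable, so the usual stable choice is an external crate such as crossbeam-channel. A minimal sketch, assuming crossbeam-channel = "0.5" has been added to Cargo.toml:

use std::thread;

fn main() {
    let (tx, rx) = crossbeam_channel::unbounded::<u32>();
    let mut handles = Vec::new();
    for id in 0..3 {
        // The Receiver itself is Clone, so no Arc<Mutex<...>> is needed.
        let rx = rx.clone();
        handles.push(thread::spawn(move || {
            // iter() blocks until a task arrives and ends once all senders are gone.
            for task in rx.iter() {
                println!("worker {} got task {}", id, task);
            }
        }));
    }
    for t in 0..6 {
        tx.send(t).unwrap();
    }
    drop(tx); // disconnect the channel so the workers finish
    for h in handles {
        h.join().unwrap();
    }
}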

2. Code

The following builds a thread pool based on the MPSC + Arc + Mutex approach (the listing also depends on the chrono crate, e.g. chrono = "0.4", for the completion timestamps):

// This thread pool is a fairly rough framework.

use std::sync::atomic::AtomicU32;
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use chrono::{Local, NaiveDateTime};

#[derive(Debug,Clone)]
enum Parameter{
    P0(()),
    P1(f32,f32),
    P2(f32,f32,f32),
    P3(f32,f32,f32,f32),
}
#[derive(Debug,Clone,PartialEq)]
enum OutPut{
    TradeFlow0(f32),
    TradeFlow1(Vec<f32>,Vec<f32>),
    TradeFlow2(Vec<f32>),
}
// Work distribution across workers can be unbalanced => work stealing: todo!
// A simple thread pool; no Condvar is used.
// send_num is not strictly needed; it is only kept here to record how many tasks were sent.
pub struct ThreadPool {
    threads_num: usize,
    workers: Vec<Worker>,
    job_sender: mpsc::Sender<Message>,
    send_num: Arc<Mutex<usize>>,
    feedback_receiver: mpsc::Receiver<Message>,
    finished_num:Arc<AtomicU32>,
}

// Message types sent between the pool and the workers

enum Message{
    Task(Task),         // dispatch a strategy task to a worker
    Feedback(Feedback), // report the execution result of a strategy task
    Shutdown,
}

struct Task{
    job :Job,
    task_id :usize,
}
#[derive(Debug)]
struct Feedback{
    thread_id:usize,             // worker thread that ran the task
    task_id:usize,               // id of the finished task
    finished_time:NaiveDateTime, // time the task finished
}

type Job = Box<dyn FnOnce() -> OutPut + Send +'static>;

impl ThreadPool {
    pub fn new(threads_num: usize) -> ThreadPool {
        assert!(threads_num > 0);
        let (job_sender, job_receiver) = mpsc::channel::<Message>();
        let (feedback_sender, feedback_receiver) = mpsc::channel::<Message>();
        let job_receiver = Arc::new(Mutex::new(job_receiver));
        let feedback_sender = Arc::new(Mutex::new(feedback_sender));
        let mut workers = Vec::with_capacity(threads_num);
        let send_num = Arc::new(Mutex::new(0_usize));
        let finished_num = Arc::new(AtomicU32::new(0));
        for id in 0..threads_num {
            let mut worker = Worker::new(id);
            worker.run(&job_receiver,&feedback_sender);
            workers.push(worker);
        }
        ThreadPool { threads_num,workers, job_sender,send_num,feedback_receiver,finished_num}
    }

    pub fn execute(&mut self, task_id:usize,job: Job)
    {

        let task = Task{job:job,task_id:task_id};
        let mut send_num = self.send_num.lock().unwrap();
        *send_num = *send_num +1;
        self.job_sender.send(Message::Task(task)).unwrap();

    }
    pub fn wait_feedback(&mut self){
        loop{
            let msg = self.feedback_receiver.recv().unwrap();
            match msg{
                Message::Feedback(feedback) => {
                    self.finished_num.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    let finished_num = self.finished_num.load(std::sync::atomic::Ordering::Relaxed);
                    println!("Worker {} -> {:?}", feedback.thread_id,feedback);
                    if finished_num == *self.send_num.lock().unwrap() as u32{
                        break;
                    }
                    
                },
                _ =>{
                    break;
                },

            }
        }
    }

}

pub struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    pub fn new(id: usize)-> Self{
        Worker {
            id:id,
            thread: None,
        }
    }
    pub fn run(&mut self, job_receiver: &Arc<Mutex<mpsc::Receiver<Message>>>,feedback_sender: &Arc<Mutex<mpsc::Sender<Message>>>){
        let id = self.id;
        let job_receiver = job_receiver.clone();
        let feedback_sender = feedback_sender.clone();
        let thread = thread::spawn(move || loop {
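            // The MutexGuard returned by lock() is dropped at the end of this
            // statement, so the job below runs without holding the lock.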
            let msg = job_receiver.lock().unwrap().recv().unwrap();
            match msg{
                Message::Task(task) => {
                    //println!("Worker {} -> executing task {}.", id,task.task_id);
                    let _output = (task.job)(); // the strategy's OutPut is discarded here; it could be sent back via the feedback channel
                    feedback_sender.lock().unwrap().send(Message::Feedback(Feedback{thread_id:id,task_id:task.task_id,finished_time:Local::now().naive_local()})).unwrap();
                },
                Message::Shutdown => {
                    //println!("Worker {} received shutdown message.", id);
                    break; // critical: exit the loop so the worker thread can finish and be joined
                },
                _ =>{},
            }          

        });
        self.thread = Some(thread);

    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        let start = Instant::now();
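        // Send one Shutdown message per worker; each worker consumes exactly one and exits its loop.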
        for _ in 0..self.threads_num{
            self.job_sender.send(Message::Shutdown).unwrap();
        }
        //println!("drop worker :{:?}",self.workers.len());

        for (i, worker) in self.workers.iter_mut().enumerate() {
            //println!("------Shutting down worker {}  i:{} -----------", worker.id,i);
            if let Some(thread) = worker.thread.take() {
                thread.join().unwrap();
            }
        }
        let duration = start.elapsed();
        //println!("drop thread pool duration : {:?}",duration);

    }
}

fn strategy_follow_trend(p:Parameter) ->OutPut{
    //println!("run 【follow_trend】 strategy {:?}",p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow1(vec![1.0,2.0,3.0],vec![4.0,5.0,6.0])
    
}
fn strategy_bolling(p:Parameter) -> OutPut{
    //println!("run 【bolling】 strategy {:?}",p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow2(vec![1.0,2.0,3.0])
    
}
fn strategy_high_freq(p:Parameter)->OutPut{
    //println!("run 【high_freq】 strategy {:?}",p);
    thread::sleep(Duration::from_secs(1));
    OutPut::TradeFlow0(0.0)
    
}

fn get_strategies() -> Vec<Job>{
    let p1 = Parameter::P1(2.0,3.0);
    let s1: Job  = Box::new(move || strategy_follow_trend(p1));
    let p2 = Parameter::P2(2.0,3.0,4.0);
    let s2: Job  = Box::new(move || strategy_bolling(p2));
    let p3 = Parameter::P0(());
    let s3: Job  = Box::new(move || strategy_high_freq(p3));
    let p4 = Parameter::P0(());
    let s4: Job  = Box::new(move || strategy_high_freq(p4));
    let p5 = Parameter::P1(2.0,3.0);
    let s5: Job  = Box::new(move || strategy_follow_trend(p5));
    let p6 = Parameter::P2(2.0,3.0,4.0);
    let s6: Job  = Box::new(move || strategy_bolling(p6));
    let p7 = Parameter::P0(());
    let s7: Job  = Box::new(move || strategy_high_freq(p7));
    let p8 = Parameter::P0(());
    let s8: Job  = Box::new(move || strategy_high_freq(p8));
    let strategies:Vec<Job> = vec![s1,s2,s3,s4,s5,s6,s7,s8];
    strategies
}
pub fn main() {
    let thread_num =3;
    let mut pool = ThreadPool::new(thread_num);
    let start = Instant::now();
    let strategies = get_strategies();
    let num = strategies.len();
    for (task_id,strategy) in strategies.into_iter().enumerate(){
        pool.execute(task_id, strategy);
    }
    pool.wait_feedback();
    println!("strategies num :  {:?}",num);
    println!("toal send_num  :  {:?}",*pool.send_num.lock().unwrap());
    let duration = start.elapsed();
    println!("thread pool num: {:?}, time cost sec: {:?}",thread_num,duration.as_secs_f64());
}

The distinguishing feature of this custom thread pool is that it makes it easy to observe how tasks are executed and distributed across the individual worker threads. If that observability is not needed, the pool can be trimmed down further, as sketched below.
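
For reference, here is a minimal sketch of such a trimmed-down variant (my own simplification, not part of the original code): the feedback channel and the counters are removed, and completion is only guaranteed once the pool is dropped and the workers are joined.

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

type Job = Box<dyn FnOnce() + Send + 'static>;

enum Message {
    Task(Job),
    Shutdown,
}

pub struct MiniPool {
    sender: mpsc::Sender<Message>,
    handles: Vec<thread::JoinHandle<()>>,
}

impl MiniPool {
    pub fn new(n: usize) -> Self {
        let (sender, receiver) = mpsc::channel::<Message>();
        let receiver = Arc::new(Mutex::new(receiver));
        let handles = (0..n)
            .map(|_| {
                let receiver = Arc::clone(&receiver);
                thread::spawn(move || loop {
                    // Lock only to pull one message; the guard is dropped before the job runs.
                    let msg = receiver.lock().unwrap().recv().unwrap();
                    match msg {
                        Message::Task(job) => job(),
                        Message::Shutdown => break,
                    }
                })
            })
            .collect();
        MiniPool { sender, handles }
    }

    pub fn execute(&self, job: Job) {
        self.sender.send(Message::Task(job)).unwrap();
    }
}

impl Drop for MiniPool {
    fn drop(&mut self) {
        // One Shutdown per worker, then join them all.
        for _ in 0..self.handles.len() {
            self.sender.send(Message::Shutdown).unwrap();
        }
        for handle in self.handles.drain(..) {
            handle.join().unwrap();
        }
    }
}

Note that execute now only accepts a plain closure; if the strategy results are still needed, the caller has to collect them in some other way, for example through a channel captured inside each closure.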

3. Running and Verifying

The code above produces the following output:

Worker 2 -> Feedback { thread_id: 2, task_id: 2, finished_time: 2025-11-12T09:12:08.364818600 }
Worker 1 -> Feedback { thread_id: 1, task_id: 1, finished_time: 2025-11-12T09:12:08.369362600 }
Worker 0 -> Feedback { thread_id: 0, task_id: 0, finished_time: 2025-11-12T09:12:08.369393 }
Worker 2 -> Feedback { thread_id: 2, task_id: 3, finished_time: 2025-11-12T09:12:09.369849100 }
Worker 0 -> Feedback { thread_id: 0, task_id: 5, finished_time: 2025-11-12T09:12:09.369859800 }
Worker 1 -> Feedback { thread_id: 1, task_id: 4, finished_time: 2025-11-12T09:12:09.369868500 }
Worker 0 -> Feedback { thread_id: 0, task_id: 7, finished_time: 2025-11-12T09:12:10.370486800 }
Worker 2 -> Feedback { thread_id: 2, task_id: 6, finished_time: 2025-11-12T09:12:10.370496400 }
strategies num :  8
total send_num :  8
thread pool num: 3, time cost sec: 3.0131664

As the output shows, 8 strategies that each take 1 second to run finish in roughly 3 seconds on this thread pool, which is in line with expectations.

With 3 worker threads in the pool running 8 strategies, the work was distributed as follows:

Thread 0 ran tasks [0, 5, 7], taking 3 seconds;
Thread 1 ran tasks [1, 4], taking 2 seconds;
Thread 2 ran tasks [2, 3, 6], taking 3 seconds;

Threads 0 and 2 each took 3 seconds and thread 1 took 2 seconds; the wall-clock time is the maximum of the three, so the pool finishes in about 3 seconds. The pool also shows a reasonable degree of load balancing.
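
As a quick sanity check on the timing: 8 one-second tasks on 3 workers need at least ceil(8 / 3) × 1 s = 3 s of wall-clock time, so the measured 3.01 s is essentially this lower bound plus a little scheduling overhead.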
