中,介绍了如何利用多线程来构建多策略的回测框架。本文将探讨一下,如何自定义构建简易线程池的回测框架。
一、技术要点
除了多线程中的技术要点外(具体可参考前文),还需要通过构建一套消息传递机制,把策略任务分发给不同的子线程。
消息传递,在现有的标准库中,可以选择:
1、MPSC +Arc+Mutex的方案
本来MPSC是多发送者,单消费者模式,但是利用Arc和Mutex的互斥方案,也可以将策略任务发送到不同的子线程进行运行。
2、MPMC的方案
MPMC是多发送者,多消费者模式,这种模式可以胜任一个发送者,多个消费者的场景。
二、相关代码
下面构建MPSC +Arc +Mutex方案的线程池:
rust
// 这个线程池是一个比较粗的框架;
use std::sync::atomic::AtomicU32;
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
use std::time::Instant;
use std::collections::HashMap;
use chrono::{DateTime, Local,NaiveDateTime};
#[derive(Debug,Clone)]
enum Parameter{
P0(()),
P1(f32,f32),
P2(f32,f32,f32),
P3(f32,f32,f32,f32),
}
#[derive(Debug,Clone,PartialEq)]
enum OutPut{
TradeFlow0(f32),
TradeFlow1(Vec<f32>,Vec<f32>),
TradeFlow2(Vec<f32>),
}
// 存在不均衡的情况 => steal work todo!
// 简易的threadpool,没有用condvar.
// send_num可以不要,这里只是记录一下发送的情况
pub struct ThreadPool {
threads_num: usize,
workers: Vec<Worker>,
job_sender: mpsc::Sender<Message>,
send_num: Arc<Mutex<usize>>,
feedback_receiver: mpsc::Receiver<Message>,
finished_num:Arc<AtomicU32>,
}
// 定义发送内容
enum Message{
Task(Task),//发送策略任务
Feedback(Feedback),//接收子策略任务的执行结果
Shutdown,
}
struct Task{
job :Job,
task_id :usize,
}
#[derive(Debug)]
struct Feedback{
thread_id:usize,//子线程
task_id:usize,// 子任务
finished_time:NaiveDateTime,//子任务完成时间
}
type Job = Box<dyn FnOnce() -> OutPut + Send +'static>;
impl ThreadPool {
pub fn new(threads_num: usize) -> ThreadPool {
assert!(threads_num > 0);
let (job_sender, job_receiver) = mpsc::channel::<Message>();
let (feedback_sender, feedback_receiver) = mpsc::channel::<Message>();
let job_receiver = Arc::new(Mutex::new(job_receiver));
let feedback_sender = Arc::new(Mutex::new(feedback_sender));
let mut workers = Vec::with_capacity(threads_num);
let send_num = Arc::new(Mutex::new(0_usize));
let finished_num = Arc::new(AtomicU32::new(0));
for id in 0..threads_num {
let mut worker = Worker::new(id);
worker.run(&job_receiver,&feedback_sender);
workers.push(worker);
}
ThreadPool { threads_num,workers, job_sender,send_num,feedback_receiver,finished_num}
}
pub fn execute(&mut self, task_id:usize,job: Job)
{
let task = Task{job:job,task_id:task_id};
let mut send_num = self.send_num.lock().unwrap();
*send_num = *send_num +1;
self.job_sender.send(Message::Task(task)).unwrap();
}
pub fn wait_feedback(&mut self){
loop{
let msg = self.feedback_receiver.recv().unwrap();
match msg{
Message::Feedback(feedback) => {
self.finished_num.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let finished_num = self.finished_num.load(std::sync::atomic::Ordering::Relaxed);
println!("Worker {} -> {:?}", feedback.thread_id,feedback);
if finished_num == *self.send_num.lock().unwrap() as u32{
break;
}
},
_ =>{
break;
},
}
}
}
}
pub struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
pub fn new(id: usize)-> Self{
Worker {
id:id,
thread: None,
}
}
pub fn run(&mut self, job_receiver: &Arc<Mutex<mpsc::Receiver<Message>>>,feedback_sender: &Arc<Mutex<mpsc::Sender<Message>>>){
let id = self.id;
let job_receiver = job_receiver.clone();
let feedback_sender = feedback_sender.clone();
let thread = thread::spawn(move || loop {
let msg = job_receiver.lock().unwrap().recv().unwrap();
match msg{
Message::Task(task) => {
//println!("Worker {} -> executing task {}.", id,task.task_id);
(task.job)();
feedback_sender.lock().unwrap().send(Message::Feedback(Feedback{thread_id:id,task_id:task.task_id,finished_time:Local::now().naive_local()})).unwrap();
},
Message::Shutdown => {
//println!("Worker {} received shutdown message.", id);
break;//很关键
},
_ =>{},
}
});
self.thread = Some(thread);
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
let start = Instant::now();
for _ in 0..self.threads_num{
self.job_sender.send(Message::Shutdown).unwrap();
}
//println!("drop worker :{:?}",self.workers.len());
for (i,worker) in (&mut self.workers).into_iter().enumerate() {
//println!("------Shutting down worker {} i:{} -----------", worker.id,i);
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
let duration = start.elapsed();
//println!("drop thread pool duration : {:?}",duration);
}
}
fn strategy_follow_trend(p:Parameter) ->OutPut{
//println!("run 【follow_trend】 strategy {:?}",p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow1(vec![1.0,2.0,3.0],vec![4.0,5.0,6.0])
}
fn strategy_bolling(p:Parameter) -> OutPut{
//println!("run 【bolling】 strategy {:?}",p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow2(vec![1.0,2.0,3.0])
}
fn strategy_high_freq(p:Parameter)->OutPut{
//println!("run 【high_freq】 strategy {:?}",p);
thread::sleep(Duration::from_secs(1));
OutPut::TradeFlow0(0.0)
}
fn get_strategies() -> Vec<Job>{
let p1 = Parameter::P1(2.0,3.0);
let s1: Job = Box::new(move || strategy_follow_trend(p1));
let p2 = Parameter::P2(2.0,3.0,4.0);
let s2: Job = Box::new(move || strategy_bolling(p2));
let p3 = Parameter::P0(());
let s3: Job = Box::new(move || strategy_high_freq(p3));
let p4 = Parameter::P0(());
let s4: Job = Box::new(move || strategy_high_freq(p4));
let p5 = Parameter::P1(2.0,3.0);
let s5: Job = Box::new(move || strategy_follow_trend(p5));
let p6 = Parameter::P2(2.0,3.0,4.0);
let s6: Job = Box::new(move || strategy_bolling(p6));
let p7 = Parameter::P0(());
let s7: Job = Box::new(move || strategy_high_freq(p7));
let p8 = Parameter::P0(());
let s8: Job = Box::new(move || strategy_high_freq(p8));
let strategies:Vec<Job> = vec![s1,s2,s3,s4,s5,s6,s7,s8];
strategies
}
pub fn main() {
let thread_num =3;
let mut pool = ThreadPool::new(thread_num);
let start = Instant::now();
let strategies = get_strategies();
let num = strategies.len();
for (task_id,strategy) in strategies.into_iter().enumerate(){
pool.execute(task_id as usize,strategy);
}
pool.wait_feedback();
println!("strategies num : {:?}",num);
println!("toal send_num : {:?}",*pool.send_num.lock().unwrap());
let duration = start.elapsed();
println!("thread pool num: {:?}, time cost sec: {:?}",thread_num,duration.as_secs_f64());
}
上面自定义线程池的特点是,可以有效观测,各个子线程的执行和分配情况。如果不需要观测子线程的情况,还可以更精简一些。
三、运行验证
上面代码输出如下:
text
Worker 2 -> Feedback { thread_id: 2, task_id: 2, finished_time: 2025-11-12T09:12:08.364818600 }
Worker 1 -> Feedback { thread_id: 1, task_id: 1, finished_time: 2025-11-12T09:12:08.369362600 }
Worker 0 -> Feedback { thread_id: 0, task_id: 0, finished_time: 2025-11-12T09:12:08.369393 }
Worker 2 -> Feedback { thread_id: 2, task_id: 3, finished_time: 2025-11-12T09:12:09.369849100 }
Worker 0 -> Feedback { thread_id: 0, task_id: 5, finished_time: 2025-11-12T09:12:09.369859800 }
Worker 1 -> Feedback { thread_id: 1, task_id: 4, finished_time: 2025-11-12T09:12:09.369868500 }
Worker 0 -> Feedback { thread_id: 0, task_id: 7, finished_time: 2025-11-12T09:12:10.370486800 }
Worker 2 -> Feedback { thread_id: 2, task_id: 6, finished_time: 2025-11-12T09:12:10.370496400 }
strategies num : 8
total send_num : 8
thread pool num: 3, time cost sec: 3.0131664
可以看到,8个运行花时1秒的单策略,在这套线程池中花时约3秒。总体上符合预期。
其中,3个子线程构成的线程池,运行8个策略,分配如下:
text
线程0运行【0,5,7】子策略,花时3秒;
线程1运行【1,4】子策略,花时2秒;
线程2运行【2,3,6】子策略,花时3秒;
其中,线程0和线程2各花3秒,线程1花2秒,取最大值,故线程池共花时约3秒。此外,这个线程池还具备一定的负载均衡能力。