1use crate::engine::{
7 BacktestConfig, BacktestStats, DefaultBacktestEngine, MarketDataEvent, Strategy,
8};
9use crate::latency::LatencyModel;
10use crate::matching::QueueModel;
11use parking_lot::{Mutex, RwLock};
12use rayon::prelude::*;
13use rust_decimal::Decimal;
14use rusty_common::SmartString;
15use rusty_common::collections::{FxHashMap, SmallStrategyVec, SmallSymbolVec};
16use smallvec::SmallVec;
17use std::sync::Arc;
19
20pub trait ParallelStrategy: Send + Sync {
22 fn clone_for_symbol(&self, symbol: &str) -> Box<dyn Strategy>;
24
25 fn on_completion(&mut self, results: &FxHashMap<SmartString, BacktestStats>);
27}
28
29pub struct ParallelConfig {
31 pub start_time_ns: u64,
33 pub end_time_ns: u64,
35 pub queue_model: QueueModel,
37 pub allow_partial_fills: bool,
39 pub order_latency: Box<dyn LatencyModel>,
41 pub market_data_latency: Box<dyn LatencyModel>,
43}
44
45#[derive(Clone)]
47pub struct SymbolConfig {
48 pub symbol: SmartString,
50 pub tick_size: Decimal,
52 pub lot_size: Decimal,
54}
55
56pub trait MarketDataProvider: Send + Sync {
58 fn get_events(
60 &self,
61 symbol: &str,
62 start_ns: u64,
63 end_ns: u64,
64 ) -> SmallVec<[(u64, MarketDataEvent); 256]>;
65}
66
67pub struct ParallelEngine {
69 config: ParallelConfig,
70 symbols: SmallSymbolVec<SymbolConfig>,
71 strategies: Arc<RwLock<SmallStrategyVec<Box<dyn ParallelStrategy>>>>,
72 data_provider: Arc<dyn MarketDataProvider>,
73}
74
75impl ParallelEngine {
76 #[must_use]
78 pub fn new(
79 config: ParallelConfig,
80 symbols: SmallSymbolVec<SymbolConfig>,
81 data_provider: Arc<dyn MarketDataProvider>,
82 ) -> Self {
83 Self {
84 config,
85 symbols,
86 strategies: Arc::new(RwLock::new(SmallStrategyVec::new())),
87 data_provider,
88 }
89 }
90
91 pub fn add_strategy(&self, strategy: Box<dyn ParallelStrategy>) {
93 self.strategies.write().push(strategy);
94 }
95
96 #[must_use]
98 pub fn run(&self) -> FxHashMap<SmartString, BacktestStats> {
99 let results: Vec<(SmartString, BacktestStats)> = self
101 .symbols
102 .par_iter()
103 .map(|symbol_config| self.run_single_symbol(symbol_config))
104 .collect();
105
106 let results_map: FxHashMap<SmartString, BacktestStats> = results.into_iter().collect();
108
109 let mut strategies = self.strategies.write();
111 for strategy in strategies.iter_mut() {
112 strategy.on_completion(&results_map);
113 }
114
115 results_map
116 }
117
118 fn run_single_symbol(&self, symbol_config: &SymbolConfig) -> (SmartString, BacktestStats) {
120 let mut config = BacktestConfig {
122 start_time_ns: self.config.start_time_ns,
123 end_time_ns: self.config.end_time_ns,
124 #[allow(clippy::disallowed_names)]
125 symbols: smallvec::smallvec![symbol_config.symbol.clone()],
126 tick_sizes: FxHashMap::default(),
127 lot_sizes: FxHashMap::default(),
128 queue_model: self.config.queue_model,
129 allow_partial_fills: self.config.allow_partial_fills,
130 order_latency: self.config.order_latency.clone_box(),
131 market_data_latency: self.config.market_data_latency.clone_box(),
132 conservative_mode: false,
133 market_impact: None,
134 conservative_params: None,
135 };
136
137 config
138 .tick_sizes
139 .insert(symbol_config.symbol.clone(), symbol_config.tick_size);
140 config
141 .lot_sizes
142 .insert(symbol_config.symbol.clone(), symbol_config.lot_size);
143
144 let engine = DefaultBacktestEngine::new(config);
146
147 let strategies = self.strategies.read();
149 for strategy in strategies.iter() {
150 let symbol_strategy = strategy.clone_for_symbol(&symbol_config.symbol);
151 engine.add_strategy(symbol_strategy);
152 }
153 drop(strategies);
154
155 let events = self.data_provider.get_events(
157 &symbol_config.symbol,
158 self.config.start_time_ns,
159 self.config.end_time_ns,
160 );
161
162 for (timestamp_ns, event) in events {
164 engine.add_market_data(timestamp_ns, symbol_config.symbol.clone(), event);
165 }
166
167 let stats = engine.run();
169
170 (symbol_config.symbol.clone(), stats)
171 }
172
173 #[must_use]
175 pub fn run_with_progress<F>(&self, progress_fn: F) -> FxHashMap<SmartString, BacktestStats>
176 where
177 F: Fn(&str, f64) + Sync + Send,
178 {
179 let total_symbols = self.symbols.len();
180 let completed = Arc::new(Mutex::new(0usize));
181
182 let results: Vec<(SmartString, BacktestStats)> = self
183 .symbols
184 .par_iter()
185 .map(|symbol_config| {
186 let result = self.run_single_symbol(symbol_config);
187
188 let mut count = completed.lock();
190 *count += 1;
191 let progress = (*count as f64) / (total_symbols as f64);
192 progress_fn(&symbol_config.symbol, progress);
193
194 result
195 })
196 .collect();
197
198 let results_map: FxHashMap<SmartString, BacktestStats> = results.into_iter().collect();
200
201 let mut strategies = self.strategies.write();
203 for strategy in strategies.iter_mut() {
204 strategy.on_completion(&results_map);
205 }
206
207 results_map
208 }
209}
210
211pub struct MultiSymbolBacktest {
213 engines: FxHashMap<SmartString, Arc<DefaultBacktestEngine>>,
215 shared_state: Arc<RwLock<SharedState>>,
217}
218
219#[derive(Default)]
221pub struct SharedState {
222 pub positions: FxHashMap<SmartString, f64>,
224 pub total_pnl: f64,
226 pub total_volume: f64,
228}
229
230impl Default for MultiSymbolBacktest {
231 fn default() -> Self {
232 Self::new()
233 }
234}
235
236impl MultiSymbolBacktest {
237 #[must_use]
239 pub fn new() -> Self {
240 Self {
241 engines: FxHashMap::default(),
242 shared_state: Arc::new(RwLock::new(SharedState::default())),
243 }
244 }
245
246 pub fn add_symbol(&mut self, symbol: SmartString, engine: Arc<DefaultBacktestEngine>) {
248 self.engines.insert(symbol, engine);
249 }
250
251 pub fn shared_state(&self) -> Arc<RwLock<SharedState>> {
253 self.shared_state.clone()
254 }
255
256 #[must_use]
258 pub fn run_parallel(&self) -> FxHashMap<SmartString, BacktestStats> {
259 let results: FxHashMap<SmartString, BacktestStats> = self
260 .engines
261 .par_iter()
262 .map(|(symbol, engine)| {
263 let stats = engine.run();
264 (symbol.clone(), stats)
265 })
266 .collect();
267
268 results.into_iter().collect()
269 }
270}
271
272pub struct SimpleMarketDataProvider {
274 data: Arc<RwLock<MarketDataStore>>,
275}
276
277type MarketDataStore = FxHashMap<SmartString, SmallVec<[(u64, MarketDataEvent); 256]>>;
278
279impl Default for SimpleMarketDataProvider {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285impl SimpleMarketDataProvider {
286 #[must_use]
288 pub fn new() -> Self {
289 Self {
290 data: Arc::new(RwLock::new(FxHashMap::default())),
291 }
292 }
293
294 pub fn add_data(&self, symbol: SmartString, events: SmallVec<[(u64, MarketDataEvent); 256]>) {
296 self.data.write().insert(symbol, events);
297 }
298}
299
300impl MarketDataProvider for SimpleMarketDataProvider {
301 fn get_events(
302 &self,
303 symbol: &str,
304 start_ns: u64,
305 end_ns: u64,
306 ) -> SmallVec<[(u64, MarketDataEvent); 256]> {
307 self.data
308 .read()
309 .get(symbol)
310 .map(|events| {
311 events
312 .iter()
313 .filter(|(ts, _)| *ts >= start_ns && *ts <= end_ns)
314 .cloned()
315 .collect()
316 })
317 .unwrap_or_default()
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::Strategy;
    use crate::latency::FixedLatency;
    use crate::matching::{Order, OrderSide, OrderType, QueuePosition, TimeInForce};
    use crate::orderbook::OrderBook;
    use rust_decimal_macros::dec;
    use rusty_common::SmallOrderVec;

    /// Test template: creates one `TestSymbolStrategy` per symbol and prints the final results.
    struct TestParallelStrategy {}

    impl ParallelStrategy for TestParallelStrategy {
        fn clone_for_symbol(&self, symbol: &str) -> Box<dyn Strategy> {
            Box::new(TestSymbolStrategy {
                symbol: SmartString::from(symbol),
                orders_sent: false,
            })
        }

        fn on_completion(&mut self, results: &FxHashMap<SmartString, BacktestStats>) {
            println!("Backtest complete: {results:?}");
        }
    }

    /// Per-symbol test strategy: emits a single buy order on "BTC-USD" and never trades
    /// any other symbol.
    struct TestSymbolStrategy {
        symbol: SmartString,
        // Set once the one-shot order has been handed out by `get_orders`.
        orders_sent: bool,
    }

    impl Strategy for TestSymbolStrategy {
        // Market data, order responses and timers are ignored; the test only checks that the
        // one-shot order is submitted and filled.
        fn on_market_data(&mut self, _symbol: &str, _event: &MarketDataEvent, _book: &OrderBook) {
        }

        fn on_order_response(&mut self, _response: &crate::engine::OrderResponse) {
        }

        fn on_timer(&mut self, _timer_id: u64) {
        }

        fn get_orders(&mut self) -> SmallOrderVec<Order> {
            // Only BTC-USD trades, and only on the first call.
            if !self.orders_sent && self.symbol == "BTC-USD" {
                self.orders_sent = true;
                smallvec::smallvec![Order {
                    id: 1,
                    symbol: self.symbol.clone(),
                    side: OrderSide::Buy,
                    order_type: OrderType::Limit,
                    price: dec!(50000),
                    quantity: Decimal::ONE,
                    remaining_quantity: Decimal::ONE,
                    time_in_force: TimeInForce::GTC,
                    timestamp_ns: 1_000_000,
                    queue_position: QueuePosition::RiskAverse,
                }]
            } else {
                SmallOrderVec::new()
            }
        }

        fn get_cancels(&mut self) -> SmallOrderVec<u64> {
            SmallOrderVec::new()
        }
    }

    /// End-to-end check: two symbols run in parallel; only BTC-USD submits an order.
    #[test]
    fn test_parallel_engine() {
        // Shared settings. Timestamps and latencies are presumably nanoseconds, making the
        // window 10 s long — confirm against `FixedLatency`.
        let config = ParallelConfig {
            start_time_ns: 0,
            end_time_ns: 10_000_000_000,
            queue_model: QueueModel::FIFO,
            allow_partial_fills: true,
            order_latency: Box::new(FixedLatency::new(100_000)),
            market_data_latency: Box::new(FixedLatency::new(50_000)),
        };

        let symbols = smallvec::smallvec![
            SymbolConfig {
                symbol: "BTC-USD".into(),
                tick_size: dec!(0.01),
                lot_size: Decimal::new(1, 3), // 0.001
            },
            SymbolConfig {
                symbol: "ETH-USD".into(),
                tick_size: dec!(0.01),
                lot_size: Decimal::new(1, 3), // 0.001
            },
        ];

        let provider = Arc::new(SimpleMarketDataProvider::new());

        // BTC-USD: a sell-side depth level at 50001, then a sell-side trade at 50000, which
        // is the test order's limit price.
        provider.add_data(
            "BTC-USD".into(),
            smallvec::smallvec![
                (
                    1_000_000,
                    MarketDataEvent::DepthUpdate {
                        side: OrderSide::Sell,
                        price: dec!(50001),
                        quantity: Decimal::TWO,
                        order_count: 1,
                    },
                ),
                (
                    2_000_000,
                    MarketDataEvent::Trade {
                        side: OrderSide::Sell,
                        price: dec!(50000),
                        quantity: Decimal::new(5, 1), // 0.5
                    },
                ),
            ],
        );

        // ETH-USD: a single buy-side depth update; the test strategy never trades it.
        provider.add_data(
            "ETH-USD".into(),
            smallvec::smallvec![(
                1_000_000,
                MarketDataEvent::DepthUpdate {
                    side: OrderSide::Buy,
                    price: dec!(3000),
                    quantity: Decimal::from(10),
                    order_count: 2,
                },
            )],
        );

        let engine = ParallelEngine::new(config, symbols, provider);
        engine.add_strategy(Box::new(TestParallelStrategy {}));

        let results = engine.run_with_progress(|symbol, progress| {
            println!("{}: {:.1}%", symbol, progress * 100.0);
        });

        // One result entry per configured symbol.
        assert_eq!(results.len(), 2);
        assert!(results.contains_key("BTC-USD"));
        assert!(results.contains_key("ETH-USD"));

        // The test expects the single BTC-USD order to be submitted and counted as filled.
        let btc_stats = &results["BTC-USD"];
        assert_eq!(btc_stats.orders_submitted, 1);
        assert_eq!(btc_stats.orders_filled, 1);
    }
}