rusty_strategy/
engine.rs

1use rusty_common::collections::{FxHashMap, SmallSignalVec, SmallStrategyVec};
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use flume::{Receiver, Sender};
6use parking_lot::RwLock;
7use quanta::Clock;
8use smartstring::alias::String;
9use tokio::sync::mpsc;
10use tokio::time::{self, Duration};
11
12use rusty_model::{
13    data::{bar::Bar, book_snapshot::OrderBookSnapshot, market_trade::MarketTrade},
14    instruments::InstrumentId,
15};
16
17use crate::signals::Signal;
18use crate::strategy::Strategy;
19
/// Configuration options for the strategy engine.
///
/// Capacities are channel buffer sizes, i.e. how many messages may be queued before a
/// sender has to wait. `timer_interval_ms` and the three data-channel capacities must be
/// non-zero: Tokio's bounded `mpsc::channel` and `time::interval` panic when given zero.
#[derive(Debug, Clone)]
pub struct StrategyEngineConfig {
    /// Period, in milliseconds, of the timer that drives `Strategy::on_timer`.
    /// Must be non-zero (`tokio::time::interval` panics on a zero period).
    pub timer_interval_ms: u64,

    /// Buffer size of the Tokio channel carrying trade messages into the engine.
    /// Must be non-zero.
    pub trade_channel_capacity: usize,

    /// Buffer size of the Tokio channel carrying order book snapshots into the engine.
    /// Must be non-zero.
    pub depth_channel_capacity: usize,

    /// Buffer size of the Tokio channel carrying bars into the engine.
    /// Must be non-zero.
    pub bar_channel_capacity: usize,

    /// Buffer size of the flume channel carrying signals out of the engine. Zero is
    /// accepted and turns it into a rendezvous channel (every send waits for a receiver).
    pub signal_channel_capacity: usize,
}
38
39impl Default for StrategyEngineConfig {
40    fn default() -> Self {
41        Self {
42            timer_interval_ms: 100,
43            trade_channel_capacity: 10000,
44            depth_channel_capacity: 1000,
45            bar_channel_capacity: 1000,
46            signal_channel_capacity: 1000,
47        }
48    }
49}
50
51/// Strategy engine that coordinates data flow and strategy execution
52pub struct StrategyEngine {
53    /// Engine configuration
54    config: StrategyEngineConfig,
55
56    /// Registered strategies
57    strategies: Arc<RwLock<FxHashMap<String, Arc<dyn Strategy>>>>,
58
59    /// Map of instrument IDs to interested strategies (using `SmallStrategyVec` for performance)
60    instrument_map: Arc<RwLock<FxHashMap<InstrumentId, SmallStrategyVec<String>>>>,
61
62    /// Clock for precise timing
63    clock: Clock,
64
65    /// Sender for trade messages
66    trade_sender: mpsc::Sender<MarketTrade>,
67
68    /// Sender for order book depth messages
69    depth_sender: mpsc::Sender<OrderBookSnapshot>,
70
71    /// Sender for bar messages
72    bar_sender: mpsc::Sender<Bar>,
73
74    /// Sender for signal messages (output)
75    signal_sender: Sender<Signal>,
76
77    /// Receiver for signal messages (output)
78    signal_receiver: Receiver<Signal>,
79
80    /// Flag indicating if the engine is running
81    running: Arc<RwLock<bool>>,
82}
83
84impl Default for StrategyEngine {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl StrategyEngine {
91    /// Creates a new strategy engine with the default configuration
92    #[must_use]
93    pub fn new() -> Self {
94        Self::with_config(StrategyEngineConfig::default())
95    }
96
97    /// Creates a new strategy engine with a custom configuration
98    #[must_use]
99    pub fn with_config(config: StrategyEngineConfig) -> Self {
100        // Create channels for data flow
101        let (trade_sender, _trade_receiver) = mpsc::channel(config.trade_channel_capacity);
102        let (depth_sender, _depth_receiver) = mpsc::channel(config.depth_channel_capacity);
103        let (bar_sender, _bar_receiver) = mpsc::channel(config.bar_channel_capacity);
104        let (signal_sender, signal_receiver) = flume::bounded(config.signal_channel_capacity);
105
106        Self {
107            config,
108            strategies: Arc::new(RwLock::new(FxHashMap::default())),
109            instrument_map: Arc::new(RwLock::new(FxHashMap::default())),
110            clock: Clock::new(),
111            trade_sender,
112            depth_sender,
113            bar_sender,
114            signal_sender,
115            signal_receiver,
116            running: Arc::new(RwLock::new(false)),
117        }
118    }
119
120    /// Registers a strategy with the engine
121    pub fn register_strategy(&self, strategy: Arc<dyn Strategy>) -> Result<()> {
122        let strategy_id: String = String::from(strategy.id());
123
124        // Register strategy
125        self.strategies
126            .write()
127            .insert(strategy_id.clone(), strategy.clone());
128
129        // Update instrument map
130        let instruments = strategy.instruments();
131        let mut instrument_map = self.instrument_map.write();
132
133        for instrument in instruments {
134            let entry = instrument_map.entry(instrument.clone()).or_default();
135            entry.push(strategy_id.clone());
136        }
137
138        // Initialize the strategy
139        tokio::spawn(async move {
140            if let Err(e) = strategy.initialize().await {
141                log::error!("Failed to initialize strategy {strategy_id}: {e}");
142            }
143        });
144
145        Ok(())
146    }
147
148    /// Unregisters a strategy from the engine
149    pub async fn unregister_strategy(&self, strategy_id: &str) -> Result<()> {
150        // Get the strategy first to shut it down
151        let strategy = match self.strategies.read().get(strategy_id) {
152            Some(s) => s.clone(),
153            None => return Err(anyhow!("Strategy not found: {}", strategy_id)),
154        };
155
156        // Shut down the strategy
157        strategy.shutdown().await?;
158
159        // Remove from strategies map
160        self.strategies.write().remove(strategy_id);
161
162        // Update instrument map
163        let mut instrument_map = self.instrument_map.write();
164        for strategies in instrument_map.values_mut() {
165            strategies.retain(|id| id != strategy_id);
166        }
167
168        // Clean up empty entries
169        instrument_map.retain(|_, strategies| !strategies.is_empty());
170
171        Ok(())
172    }
173
174    /// Starts the strategy engine
175    pub async fn start(&mut self) -> Result<()> {
176        // Check if already running
177        if *self.running.read() {
178            return Err(anyhow!("Strategy engine is already running"));
179        }
180
181        // Set running flag
182        *self.running.write() = true;
183
184        // Create separate clones for each task to avoid ownership issues
185        let running_trade = self.running.clone();
186        let strategies_trade = self.strategies.clone();
187        let instrument_map_trade = self.instrument_map.clone();
188        let signal_sender_trade = self.signal_sender.clone();
189
190        let running_depth = self.running.clone();
191        let strategies_depth = self.strategies.clone();
192        let instrument_map_depth = self.instrument_map.clone();
193        let signal_sender_depth = self.signal_sender.clone();
194
195        let running_bar = self.running.clone();
196        let strategies_bar = self.strategies.clone();
197        let instrument_map_bar = self.instrument_map.clone();
198        let signal_sender_bar = self.signal_sender.clone();
199
200        let timer_interval = self.config.timer_interval_ms;
201
202        // Create new channel pairs since we need to move the receivers
203        let (new_trade_sender, mut trade_receiver) =
204            mpsc::channel::<MarketTrade>(self.config.trade_channel_capacity);
205        let (new_depth_sender, mut depth_receiver) =
206            mpsc::channel::<OrderBookSnapshot>(self.config.depth_channel_capacity);
207        let (new_bar_sender, mut bar_receiver) =
208            mpsc::channel::<Bar>(self.config.bar_channel_capacity);
209
210        // Spawn trade processing task
211        tokio::spawn(async move {
212            while *running_trade.read() {
213                if let Some(trade) = trade_receiver.recv().await {
214                    // Find strategies interested in this instrument
215                    let interested_strategies = {
216                        let map = instrument_map_trade.read();
217                        match map.get(&trade.instrument_id) {
218                            Some(strategies) => strategies.clone(),
219                            None => continue,
220                        }
221                    };
222
223                    // Process trade with each interested strategy
224                    for strategy_id in interested_strategies {
225                        let strategies_guard = strategies_trade.read();
226                        let strategy = match strategies_guard.get(&strategy_id) {
227                            Some(s) => s.clone(),
228                            None => continue,
229                        };
230
231                        // Clone values for async task
232                        let trade_clone = trade.clone();
233                        let signal_sender_clone = signal_sender_trade.clone();
234
235                        // Process trade asynchronously
236                        tokio::spawn(Self::process_trade_with_signals(
237                            strategy,
238                            trade_clone,
239                            signal_sender_clone,
240                            strategy_id,
241                        ));
242                    }
243                }
244            }
245        });
246
247        // Spawn depth processing task
248        tokio::spawn(async move {
249            while *running_depth.read() {
250                if let Some(depth) = depth_receiver.recv().await {
251                    // Find strategies interested in this instrument
252                    let interested_strategies = {
253                        let map = instrument_map_depth.read();
254                        match map.get(&depth.instrument_id) {
255                            Some(strategies) => strategies.clone(),
256                            None => continue,
257                        }
258                    };
259
260                    // Process depth with each interested strategy
261                    for strategy_id in interested_strategies {
262                        let strategies_guard = strategies_depth.read();
263                        let strategy = match strategies_guard.get(&strategy_id) {
264                            Some(s) => s.clone(),
265                            None => continue,
266                        };
267
268                        // Clone values for async task
269                        let depth_clone = depth.clone();
270                        let signal_sender_clone = signal_sender_depth.clone();
271
272                        // Process depth asynchronously
273                        tokio::spawn(Self::process_depth_with_signals(
274                            strategy,
275                            depth_clone,
276                            signal_sender_clone,
277                            strategy_id,
278                        ));
279                    }
280                }
281            }
282        });
283
284        // Spawn bar processing task
285        tokio::spawn(async move {
286            while *running_bar.read() {
287                if let Some(bar) = bar_receiver.recv().await {
288                    // Find strategies interested in this instrument
289                    // For now, we'll check all strategies since Bar doesn't have venue info
290                    // TODO: Consider adding venue to Bar or using a different mapping strategy
291                    let interested_strategies = {
292                        let map = instrument_map_bar.read();
293                        map.iter()
294                            .filter(|(instrument_id, _)| {
295                                instrument_id.symbol == bar.bar_type.symbol
296                            })
297                            .flat_map(|(_, strategy_list)| strategy_list.clone())
298                            .collect::<SmallStrategyVec<_>>()
299                    };
300
301                    // Process bar with each interested strategy
302                    for strategy_id in interested_strategies {
303                        let strategies_guard = strategies_bar.read();
304                        let strategy = match strategies_guard.get(&strategy_id) {
305                            Some(s) => s.clone(),
306                            None => continue,
307                        };
308
309                        // Clone values for async task
310                        let bar_clone = bar.clone();
311                        let signal_sender_clone = signal_sender_bar.clone();
312
313                        // Process bar asynchronously
314                        tokio::spawn(Self::process_bar_with_signals(
315                            strategy,
316                            bar_clone,
317                            signal_sender_clone,
318                            strategy_id,
319                        ));
320                    }
321                }
322            }
323        });
324
325        // Clone for timer task
326        let running_timer = self.running.clone();
327        let strategies_timer = self.strategies.clone();
328        let signal_sender_timer = self.signal_sender.clone();
329        let clock = self.clock.clone();
330
331        // Spawn timer task
332        tokio::spawn(async move {
333            let mut interval = time::interval(Duration::from_millis(timer_interval));
334
335            while *running_timer.read() {
336                interval.tick().await;
337
338                // Get current timestamp in nanoseconds
339                let timestamp_ns = clock.raw();
340
341                // Process timer event with each strategy
342                let strategy_ids: SmallStrategyVec<String> =
343                    strategies_timer.read().keys().cloned().collect();
344
345                for strategy_id in strategy_ids {
346                    let strategies_guard = strategies_timer.read();
347                    let strategy = match strategies_guard.get(&strategy_id) {
348                        Some(s) => s.clone(),
349                        None => continue,
350                    };
351
352                    // Clone values for async task
353                    let signal_sender_clone = signal_sender_timer.clone();
354
355                    // Process timer event asynchronously
356                    tokio::spawn(Self::process_timer_signals(
357                        strategy,
358                        timestamp_ns,
359                        signal_sender_clone,
360                        strategy_id,
361                    ));
362                }
363            }
364        });
365
366        // Save the new channel senders
367        self.trade_sender = new_trade_sender;
368        self.depth_sender = new_depth_sender;
369        self.bar_sender = new_bar_sender;
370
371        Ok(())
372    }
373
374    /// Stops the strategy engine
375    pub async fn stop(&self) -> Result<()> {
376        // Check if already stopped
377        if !*self.running.read() {
378            return Err(anyhow!("Strategy engine is not running"));
379        }
380
381        // Set running flag to false
382        *self.running.write() = false;
383
384        // Collect all strategies first before awaiting to avoid holding the lock across await points
385        let strategy_pairs: SmallStrategyVec<(String, Arc<dyn Strategy>)> = {
386            let strategies = self.strategies.read();
387            strategies
388                .iter()
389                .map(|(id, strategy)| (id.clone(), strategy.clone()))
390                .collect()
391        };
392
393        // Shut down all strategies
394        for (id, strategy) in strategy_pairs {
395            if let Err(e) = strategy.shutdown().await {
396                log::error!("Error shutting down strategy {id}: {e}");
397            }
398        }
399
400        Ok(())
401    }
402
403    /// Gets the trade sender channel
404    #[must_use]
405    pub fn trade_sender(&self) -> mpsc::Sender<MarketTrade> {
406        self.trade_sender.clone()
407    }
408
409    /// Gets the depth sender channel
410    #[must_use]
411    pub fn depth_sender(&self) -> mpsc::Sender<OrderBookSnapshot> {
412        self.depth_sender.clone()
413    }
414
415    /// Gets the bar sender channel
416    #[must_use]
417    pub fn bar_sender(&self) -> mpsc::Sender<Bar> {
418        self.bar_sender.clone()
419    }
420
421    /// Gets the signal receiver channel
422    #[must_use]
423    pub fn signal_receiver(&self) -> Receiver<Signal> {
424        self.signal_receiver.clone()
425    }
426
427    // Helper methods to reduce nesting
428
429    async fn process_trade_with_signals(
430        strategy: Arc<dyn Strategy>,
431        trade: MarketTrade,
432        signal_sender: Sender<Signal>,
433        strategy_id: String,
434    ) {
435        match strategy.process_trade(trade).await {
436            Ok(signals) => Self::send_signals(signals, signal_sender).await,
437            Err(e) => log::error!("Error processing trade in strategy {strategy_id}: {e}"),
438        }
439    }
440
441    async fn process_depth_with_signals(
442        strategy: Arc<dyn Strategy>,
443        depth: OrderBookSnapshot,
444        signal_sender: Sender<Signal>,
445        strategy_id: String,
446    ) {
447        match strategy.process_depth(depth).await {
448            Ok(signals) => Self::send_signals(signals, signal_sender).await,
449            Err(e) => log::error!("Error processing depth in strategy {strategy_id}: {e}"),
450        }
451    }
452
453    async fn process_bar_with_signals(
454        strategy: Arc<dyn Strategy>,
455        bar: Bar,
456        signal_sender: Sender<Signal>,
457        strategy_id: String,
458    ) {
459        match strategy.process_bar(bar).await {
460            Ok(signals) => Self::send_signals(signals, signal_sender).await,
461            Err(e) => log::error!("Error processing bar in strategy {strategy_id}: {e}"),
462        }
463    }
464
465    async fn process_timer_signals(
466        strategy: Arc<dyn Strategy>,
467        timestamp_ns: u64,
468        signal_sender: Sender<Signal>,
469        strategy_id: String,
470    ) {
471        match strategy.on_timer(timestamp_ns).await {
472            Ok(signals) => Self::send_signals(signals, signal_sender).await,
473            Err(e) => log::error!("Error processing timer in strategy {strategy_id}: {e}"),
474        }
475    }
476
477    async fn send_signals(signals: SmallSignalVec<Signal>, signal_sender: Sender<Signal>) {
478        for signal in signals {
479            if let Err(e) = signal_sender.send_async(signal).await {
480                log::error!("Failed to send signal: {e}");
481            }
482        }
483    }
484}