1use rusty_common::collections::{FxHashMap, SmallSignalVec, SmallStrategyVec};
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use flume::{Receiver, Sender};
6use parking_lot::RwLock;
7use quanta::Clock;
8use smartstring::alias::String;
9use tokio::sync::mpsc;
10use tokio::time::{self, Duration};
11
12use rusty_model::{
13 data::{bar::Bar, book_snapshot::OrderBookSnapshot, market_trade::MarketTrade},
14 instruments::InstrumentId,
15};
16
17use crate::signals::Signal;
18use crate::strategy::Strategy;
19
/// Tuning knobs for [`StrategyEngine`]: bounded channel sizes and the timer period.
#[derive(Clone, Debug)]
pub struct StrategyEngineConfig {
    /// Period of the strategy timer, in milliseconds.
    pub timer_interval_ms: u64,

    /// Capacity of the bounded channel that feeds market trades into the engine.
    pub trade_channel_capacity: usize,

    /// Capacity of the bounded channel that feeds order-book snapshots into the engine.
    pub depth_channel_capacity: usize,

    /// Capacity of the bounded channel that feeds bars into the engine.
    pub bar_channel_capacity: usize,

    /// Capacity of the bounded channel that carries signals out of the engine.
    pub signal_channel_capacity: usize,
}
38
39impl Default for StrategyEngineConfig {
40 fn default() -> Self {
41 Self {
42 timer_interval_ms: 100,
43 trade_channel_capacity: 10000,
44 depth_channel_capacity: 1000,
45 bar_channel_capacity: 1000,
46 signal_channel_capacity: 1000,
47 }
48 }
49}
50
51pub struct StrategyEngine {
53 config: StrategyEngineConfig,
55
56 strategies: Arc<RwLock<FxHashMap<String, Arc<dyn Strategy>>>>,
58
59 instrument_map: Arc<RwLock<FxHashMap<InstrumentId, SmallStrategyVec<String>>>>,
61
62 clock: Clock,
64
65 trade_sender: mpsc::Sender<MarketTrade>,
67
68 depth_sender: mpsc::Sender<OrderBookSnapshot>,
70
71 bar_sender: mpsc::Sender<Bar>,
73
74 signal_sender: Sender<Signal>,
76
77 signal_receiver: Receiver<Signal>,
79
80 running: Arc<RwLock<bool>>,
82}
83
84impl Default for StrategyEngine {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl StrategyEngine {
91 #[must_use]
93 pub fn new() -> Self {
94 Self::with_config(StrategyEngineConfig::default())
95 }
96
97 #[must_use]
99 pub fn with_config(config: StrategyEngineConfig) -> Self {
100 let (trade_sender, _trade_receiver) = mpsc::channel(config.trade_channel_capacity);
102 let (depth_sender, _depth_receiver) = mpsc::channel(config.depth_channel_capacity);
103 let (bar_sender, _bar_receiver) = mpsc::channel(config.bar_channel_capacity);
104 let (signal_sender, signal_receiver) = flume::bounded(config.signal_channel_capacity);
105
106 Self {
107 config,
108 strategies: Arc::new(RwLock::new(FxHashMap::default())),
109 instrument_map: Arc::new(RwLock::new(FxHashMap::default())),
110 clock: Clock::new(),
111 trade_sender,
112 depth_sender,
113 bar_sender,
114 signal_sender,
115 signal_receiver,
116 running: Arc::new(RwLock::new(false)),
117 }
118 }
119
120 pub fn register_strategy(&self, strategy: Arc<dyn Strategy>) -> Result<()> {
122 let strategy_id: String = String::from(strategy.id());
123
124 self.strategies
126 .write()
127 .insert(strategy_id.clone(), strategy.clone());
128
129 let instruments = strategy.instruments();
131 let mut instrument_map = self.instrument_map.write();
132
133 for instrument in instruments {
134 let entry = instrument_map.entry(instrument.clone()).or_default();
135 entry.push(strategy_id.clone());
136 }
137
138 tokio::spawn(async move {
140 if let Err(e) = strategy.initialize().await {
141 log::error!("Failed to initialize strategy {strategy_id}: {e}");
142 }
143 });
144
145 Ok(())
146 }
147
148 pub async fn unregister_strategy(&self, strategy_id: &str) -> Result<()> {
150 let strategy = match self.strategies.read().get(strategy_id) {
152 Some(s) => s.clone(),
153 None => return Err(anyhow!("Strategy not found: {}", strategy_id)),
154 };
155
156 strategy.shutdown().await?;
158
159 self.strategies.write().remove(strategy_id);
161
162 let mut instrument_map = self.instrument_map.write();
164 for strategies in instrument_map.values_mut() {
165 strategies.retain(|id| id != strategy_id);
166 }
167
168 instrument_map.retain(|_, strategies| !strategies.is_empty());
170
171 Ok(())
172 }
173
174 pub async fn start(&mut self) -> Result<()> {
176 if *self.running.read() {
178 return Err(anyhow!("Strategy engine is already running"));
179 }
180
181 *self.running.write() = true;
183
184 let running_trade = self.running.clone();
186 let strategies_trade = self.strategies.clone();
187 let instrument_map_trade = self.instrument_map.clone();
188 let signal_sender_trade = self.signal_sender.clone();
189
190 let running_depth = self.running.clone();
191 let strategies_depth = self.strategies.clone();
192 let instrument_map_depth = self.instrument_map.clone();
193 let signal_sender_depth = self.signal_sender.clone();
194
195 let running_bar = self.running.clone();
196 let strategies_bar = self.strategies.clone();
197 let instrument_map_bar = self.instrument_map.clone();
198 let signal_sender_bar = self.signal_sender.clone();
199
200 let timer_interval = self.config.timer_interval_ms;
201
202 let (new_trade_sender, mut trade_receiver) =
204 mpsc::channel::<MarketTrade>(self.config.trade_channel_capacity);
205 let (new_depth_sender, mut depth_receiver) =
206 mpsc::channel::<OrderBookSnapshot>(self.config.depth_channel_capacity);
207 let (new_bar_sender, mut bar_receiver) =
208 mpsc::channel::<Bar>(self.config.bar_channel_capacity);
209
210 tokio::spawn(async move {
212 while *running_trade.read() {
213 if let Some(trade) = trade_receiver.recv().await {
214 let interested_strategies = {
216 let map = instrument_map_trade.read();
217 match map.get(&trade.instrument_id) {
218 Some(strategies) => strategies.clone(),
219 None => continue,
220 }
221 };
222
223 for strategy_id in interested_strategies {
225 let strategies_guard = strategies_trade.read();
226 let strategy = match strategies_guard.get(&strategy_id) {
227 Some(s) => s.clone(),
228 None => continue,
229 };
230
231 let trade_clone = trade.clone();
233 let signal_sender_clone = signal_sender_trade.clone();
234
235 tokio::spawn(Self::process_trade_with_signals(
237 strategy,
238 trade_clone,
239 signal_sender_clone,
240 strategy_id,
241 ));
242 }
243 }
244 }
245 });
246
247 tokio::spawn(async move {
249 while *running_depth.read() {
250 if let Some(depth) = depth_receiver.recv().await {
251 let interested_strategies = {
253 let map = instrument_map_depth.read();
254 match map.get(&depth.instrument_id) {
255 Some(strategies) => strategies.clone(),
256 None => continue,
257 }
258 };
259
260 for strategy_id in interested_strategies {
262 let strategies_guard = strategies_depth.read();
263 let strategy = match strategies_guard.get(&strategy_id) {
264 Some(s) => s.clone(),
265 None => continue,
266 };
267
268 let depth_clone = depth.clone();
270 let signal_sender_clone = signal_sender_depth.clone();
271
272 tokio::spawn(Self::process_depth_with_signals(
274 strategy,
275 depth_clone,
276 signal_sender_clone,
277 strategy_id,
278 ));
279 }
280 }
281 }
282 });
283
284 tokio::spawn(async move {
286 while *running_bar.read() {
287 if let Some(bar) = bar_receiver.recv().await {
288 let interested_strategies = {
292 let map = instrument_map_bar.read();
293 map.iter()
294 .filter(|(instrument_id, _)| {
295 instrument_id.symbol == bar.bar_type.symbol
296 })
297 .flat_map(|(_, strategy_list)| strategy_list.clone())
298 .collect::<SmallStrategyVec<_>>()
299 };
300
301 for strategy_id in interested_strategies {
303 let strategies_guard = strategies_bar.read();
304 let strategy = match strategies_guard.get(&strategy_id) {
305 Some(s) => s.clone(),
306 None => continue,
307 };
308
309 let bar_clone = bar.clone();
311 let signal_sender_clone = signal_sender_bar.clone();
312
313 tokio::spawn(Self::process_bar_with_signals(
315 strategy,
316 bar_clone,
317 signal_sender_clone,
318 strategy_id,
319 ));
320 }
321 }
322 }
323 });
324
325 let running_timer = self.running.clone();
327 let strategies_timer = self.strategies.clone();
328 let signal_sender_timer = self.signal_sender.clone();
329 let clock = self.clock.clone();
330
331 tokio::spawn(async move {
333 let mut interval = time::interval(Duration::from_millis(timer_interval));
334
335 while *running_timer.read() {
336 interval.tick().await;
337
338 let timestamp_ns = clock.raw();
340
341 let strategy_ids: SmallStrategyVec<String> =
343 strategies_timer.read().keys().cloned().collect();
344
345 for strategy_id in strategy_ids {
346 let strategies_guard = strategies_timer.read();
347 let strategy = match strategies_guard.get(&strategy_id) {
348 Some(s) => s.clone(),
349 None => continue,
350 };
351
352 let signal_sender_clone = signal_sender_timer.clone();
354
355 tokio::spawn(Self::process_timer_signals(
357 strategy,
358 timestamp_ns,
359 signal_sender_clone,
360 strategy_id,
361 ));
362 }
363 }
364 });
365
366 self.trade_sender = new_trade_sender;
368 self.depth_sender = new_depth_sender;
369 self.bar_sender = new_bar_sender;
370
371 Ok(())
372 }
373
374 pub async fn stop(&self) -> Result<()> {
376 if !*self.running.read() {
378 return Err(anyhow!("Strategy engine is not running"));
379 }
380
381 *self.running.write() = false;
383
384 let strategy_pairs: SmallStrategyVec<(String, Arc<dyn Strategy>)> = {
386 let strategies = self.strategies.read();
387 strategies
388 .iter()
389 .map(|(id, strategy)| (id.clone(), strategy.clone()))
390 .collect()
391 };
392
393 for (id, strategy) in strategy_pairs {
395 if let Err(e) = strategy.shutdown().await {
396 log::error!("Error shutting down strategy {id}: {e}");
397 }
398 }
399
400 Ok(())
401 }
402
403 #[must_use]
405 pub fn trade_sender(&self) -> mpsc::Sender<MarketTrade> {
406 self.trade_sender.clone()
407 }
408
409 #[must_use]
411 pub fn depth_sender(&self) -> mpsc::Sender<OrderBookSnapshot> {
412 self.depth_sender.clone()
413 }
414
415 #[must_use]
417 pub fn bar_sender(&self) -> mpsc::Sender<Bar> {
418 self.bar_sender.clone()
419 }
420
421 #[must_use]
423 pub fn signal_receiver(&self) -> Receiver<Signal> {
424 self.signal_receiver.clone()
425 }
426
427 async fn process_trade_with_signals(
430 strategy: Arc<dyn Strategy>,
431 trade: MarketTrade,
432 signal_sender: Sender<Signal>,
433 strategy_id: String,
434 ) {
435 match strategy.process_trade(trade).await {
436 Ok(signals) => Self::send_signals(signals, signal_sender).await,
437 Err(e) => log::error!("Error processing trade in strategy {strategy_id}: {e}"),
438 }
439 }
440
441 async fn process_depth_with_signals(
442 strategy: Arc<dyn Strategy>,
443 depth: OrderBookSnapshot,
444 signal_sender: Sender<Signal>,
445 strategy_id: String,
446 ) {
447 match strategy.process_depth(depth).await {
448 Ok(signals) => Self::send_signals(signals, signal_sender).await,
449 Err(e) => log::error!("Error processing depth in strategy {strategy_id}: {e}"),
450 }
451 }
452
453 async fn process_bar_with_signals(
454 strategy: Arc<dyn Strategy>,
455 bar: Bar,
456 signal_sender: Sender<Signal>,
457 strategy_id: String,
458 ) {
459 match strategy.process_bar(bar).await {
460 Ok(signals) => Self::send_signals(signals, signal_sender).await,
461 Err(e) => log::error!("Error processing bar in strategy {strategy_id}: {e}"),
462 }
463 }
464
465 async fn process_timer_signals(
466 strategy: Arc<dyn Strategy>,
467 timestamp_ns: u64,
468 signal_sender: Sender<Signal>,
469 strategy_id: String,
470 ) {
471 match strategy.on_timer(timestamp_ns).await {
472 Ok(signals) => Self::send_signals(signals, signal_sender).await,
473 Err(e) => log::error!("Error processing timer in strategy {strategy_id}: {e}"),
474 }
475 }
476
477 async fn send_signals(signals: SmallSignalVec<Signal>, signal_sender: Sender<Signal>) {
478 for signal in signals {
479 if let Err(e) = signal_sender.send_async(signal).await {
480 log::error!("Failed to send signal: {e}");
481 }
482 }
483 }
484}