1use std::fmt::Debug;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use anyhow::{Result, anyhow};
11use flume;
12use log::{error, info, warn};
13use parking_lot::RwLock;
14use quanta::Clock;
15use rusty_common::SmallVec;
16use serde::de::DeserializeOwned;
17use smartstring::alias::String;
18use tokio::sync::watch;
19use tokio::task::JoinHandle;
20
21use crate::provider::prelude::*;
22use rusty_common::websocket::connector::WebSocketConnector;
23use rusty_common::websocket::heartbeat::{HeartbeatMonitor, HeartbeatStatus};
24
/// Lifecycle state of a single managed session.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Initial state assigned at construction, before the session has been started.
    Uninitialized,
    /// Set while the session's connection task is being created.
    Initializing,
    /// The session currently designated as the active one.
    Active,
    /// A previously active session whose heartbeat monitor reported an unhealthy status.
    Unhealthy,
    /// A session that is not the active one: a standby, a stopped session, or one whose
    /// heartbeat monitor reported it dead.
    Inactive,
}
39
/// Role label attached to a session configuration.
///
/// The manager logic in this file does not read this value; which session becomes active on
/// start is decided by position (index 0).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionType {
    /// Label for the primary session.
    Primary,
    /// Label for a backup session.
    Backup,
}
48
49#[derive(Debug, Clone)]
51pub struct SessionConfig {
52 pub session_type: SessionType,
54 pub url: String,
56 pub subscription_message: String,
58 pub connection_config: ConnectionConfig,
60}
61
62#[derive(Debug, Clone)]
64pub struct SessionStats {
65 pub connection_stats: ConnectionStats,
67 pub state: SessionState,
69 pub last_state_change: Instant,
71 pub activations: u32,
73 pub deactivations: u32,
75 pub total_active_time: Duration,
77 pub last_activation: Option<Instant>,
79}
80
81impl Default for SessionStats {
82 fn default() -> Self {
83 Self {
84 connection_stats: ConnectionStats::default(),
85 state: SessionState::Uninitialized,
86 last_state_change: Instant::now(),
87 activations: 0,
88 deactivations: 0,
89 total_active_time: Duration::from_secs(0),
90 last_activation: None,
91 }
92 }
93}
94
95#[derive(Debug)]
103pub struct SessionManager<T: Clone + serde::de::DeserializeOwned, const N: usize = 8> {
104 configs: SmallVec<[SessionConfig; N]>,
106 active_session: Arc<RwLock<usize>>,
108 session_states: SmallVec<[Arc<RwLock<SessionState>>; N]>,
110 session_stats: SmallVec<[Arc<RwLock<SessionStats>>; N]>,
112 connectors: SmallVec<[WebSocketConnector; N]>,
114 heartbeat_monitors: SmallVec<[Option<HeartbeatMonitor>; N]>,
116 websocket_handles: SmallVec<[Option<JoinHandle<()>>; N]>,
118 message_channels: SmallVec<[Option<flume::Sender<T>>; N]>,
120 output_sender: flume::Sender<T>,
122 output_receiver: flume::Receiver<T>,
124 stop_sender: watch::Sender<bool>,
126 stop_receiver: watch::Receiver<bool>,
128 clock: Clock,
130 last_health_check: Arc<RwLock<Instant>>,
132 health_check_interval_milliseconds: u64,
134 health_check_handle: Option<JoinHandle<()>>,
136 fast_switching_enabled: bool,
138 running: Arc<RwLock<bool>>,
140}
141
142impl<T: Clone + DeserializeOwned + Send + Sync + 'static, const N: usize> SessionManager<T, N> {
143 #[must_use]
145 pub fn new(
146 configs: SmallVec<[SessionConfig; N]>,
147 health_check_interval_milliseconds: u64,
148 channel_buffer_size: usize,
149 fast_switching_enabled: bool,
150 ) -> Self {
151 let (output_sender, output_receiver) = flume::bounded(channel_buffer_size);
152 let (stop_sender, stop_receiver) = watch::channel(false);
153 let clock = Clock::new();
154 let now = Instant::now();
155
156 let mut session_states = SmallVec::with_capacity(configs.len());
157 let mut session_stats = SmallVec::with_capacity(configs.len());
158 let mut connectors = SmallVec::with_capacity(configs.len());
159 let mut heartbeat_monitors = SmallVec::with_capacity(configs.len());
160 let mut websocket_handles = SmallVec::with_capacity(configs.len());
161 let mut message_channels = SmallVec::with_capacity(configs.len());
162
163 for config in &configs {
164 let connection_status = Arc::new(RwLock::new(ConnectionState::Disconnected));
165 let stats = Arc::new(RwLock::new(ConnectionStats::default()));
166 let session_state = Arc::new(RwLock::new(SessionState::Uninitialized));
167 let session_stat = Arc::new(RwLock::new(SessionStats::default()));
168
169 let ws_config = rusty_common::websocket::WebSocketConfig::new(
171 rusty_common::types::Exchange::Binance, config
173 .connection_config
174 .websocket_config
175 .base_url
176 .to_string(),
177 );
178 let connector = WebSocketConnector::new(
179 ws_config,
180 Arc::new(RwLock::new(
181 rusty_common::websocket::ConnectionStats::default(),
182 )),
183 Arc::new(RwLock::new(
184 rusty_common::websocket::ConnectionState::Disconnected,
185 )),
186 );
187
188 let heartbeat_monitor = if config
189 .connection_config
190 .websocket_config
191 .heartbeat_interval_milliseconds
192 > 0
193 {
194 Some(HeartbeatMonitor::new(
195 config
196 .connection_config
197 .websocket_config
198 .heartbeat_interval_milliseconds,
199 config
200 .connection_config
201 .websocket_config
202 .heartbeat_timeout_milliseconds,
203 config
204 .connection_config
205 .websocket_config
206 .max_missed_heartbeats,
207 Arc::new(RwLock::new(
208 rusty_common::websocket::ConnectionState::Disconnected,
209 )),
210 Arc::new(RwLock::new(
211 rusty_common::websocket::ConnectionStats::default(),
212 )),
213 ))
214 } else {
215 None
216 };
217
218 session_states.push(session_state);
219 session_stats.push(session_stat);
220 connectors.push(connector);
221 heartbeat_monitors.push(heartbeat_monitor);
222 websocket_handles.push(None);
223 message_channels.push(None);
224 }
225
226 Self {
227 configs,
228 active_session: Arc::new(RwLock::new(0)), session_states,
230 session_stats,
231 connectors,
232 heartbeat_monitors,
233 websocket_handles,
234 message_channels,
235 output_sender,
236 output_receiver,
237 stop_sender,
238 stop_receiver,
239 clock,
240 last_health_check: Arc::new(RwLock::new(now)),
241 health_check_interval_milliseconds,
242 health_check_handle: None,
243 fast_switching_enabled,
244 running: Arc::new(RwLock::new(false)),
245 }
246 }
247
248 pub async fn start(&mut self) -> Result<flume::Receiver<T>> {
250 if *self.running.read() {
251 return Err(anyhow!("Session manager is already running"));
252 }
253
254 if self.configs.is_empty() {
255 return Err(anyhow!("No session configurations provided"));
256 }
257
258 *self.running.write() = true;
260
261 for i in 0..self.configs.len() {
263 self.start_session(i).await?;
264 }
265
266 self.start_health_check_task();
268
269 Ok(self.output_receiver.clone())
271 }
272
273 pub async fn stop(&mut self) -> Result<()> {
275 if !*self.running.read() {
276 return Ok(());
277 }
278
279 *self.running.write() = false;
281
282 let _ = self.stop_sender.send(true);
284
285 if let Some(handle) = self.health_check_handle.take() {
287 handle.abort();
288 }
289
290 for i in 0..self.configs.len() {
292 self.stop_session(i).await?;
293 }
294
295 Ok(())
296 }
297
298 async fn start_session(&mut self, session_index: usize) -> Result<()> {
300 if session_index >= self.configs.len() {
301 return Err(anyhow!("Invalid session index"));
302 }
303
304 *self.session_states[session_index].write() = SessionState::Initializing;
306 self.session_stats[session_index].write().state = SessionState::Initializing;
307 self.session_stats[session_index].write().last_state_change = Instant::now();
308
309 let (tx, _rx) = flume::bounded(1024);
311 self.message_channels[session_index] = Some(tx.clone());
312
313 let config = &self.configs[session_index];
315 let url = config.url.clone();
316 let subscription_message = config.subscription_message.clone();
317
318 let handle = self
320 .create_websocket_connection(session_index, url, subscription_message, tx)
321 .await?;
322
323 self.websocket_handles[session_index] = Some(handle);
325
326 if session_index == 0 {
328 *self.active_session.write() = 0;
329 *self.session_states[session_index].write() = SessionState::Active;
330 self.session_stats[session_index].write().state = SessionState::Active;
331 self.session_stats[session_index].write().last_state_change = Instant::now();
332 self.session_stats[session_index].write().activations += 1;
333 self.session_stats[session_index].write().last_activation = Some(Instant::now());
334 } else {
335 *self.session_states[session_index].write() = SessionState::Inactive;
336 self.session_stats[session_index].write().state = SessionState::Inactive;
337 self.session_stats[session_index].write().last_state_change = Instant::now();
338 }
339
340 Ok(())
341 }
342
343 async fn stop_session(&mut self, session_index: usize) -> Result<()> {
345 if session_index >= self.configs.len() {
346 return Err(anyhow!("Invalid session index"));
347 }
348
349 if let Some(handle) = self.websocket_handles[session_index].take() {
351 handle.abort();
352 }
353
354 *self.session_states[session_index].write() = SessionState::Inactive;
356 self.session_stats[session_index].write().state = SessionState::Inactive;
357 self.session_stats[session_index].write().last_state_change = Instant::now();
358
359 if *self.active_session.read() == session_index {
361 self.session_stats[session_index].write().deactivations += 1;
362 if let Some(last_activation) = self.session_stats[session_index].write().last_activation
363 {
364 let active_duration = Instant::now().duration_since(last_activation);
365 self.session_stats[session_index].write().total_active_time += active_duration;
366 }
367 }
368
369 Ok(())
370 }
371
372 async fn create_websocket_connection(
374 &self,
375 session_index: usize,
376 _url: String,
377 _subscription_message: String,
378 _tx: flume::Sender<T>,
379 ) -> Result<JoinHandle<()>> {
380 let session_state = self.session_states[session_index].clone();
382 let session_stats = self.session_stats[session_index].clone();
383 let stop_receiver = self.stop_receiver.clone();
384 let running = self.running.clone();
385
386 let handle = tokio::spawn(async move {
387 loop {
388 if *stop_receiver.borrow() || !*running.read() {
390 break;
391 }
392
393 *session_state.write() = SessionState::Inactive;
395 session_stats.write().state = SessionState::Inactive;
396 session_stats.write().last_state_change = Instant::now();
397
398 tokio::time::sleep(Duration::from_secs(5)).await;
400
401 error!("WebSocket connection not implemented - migration to yawc in progress");
402 }
403 });
404
405 Ok(handle)
406 }
407
408 fn start_health_check_task(&mut self) {
410 let session_states = self.session_states.clone();
412 let session_stats = self.session_stats.clone();
413 let heartbeat_monitors = self.heartbeat_monitors.clone();
414 let active_session = self.active_session.clone();
415 let last_health_check = self.last_health_check.clone();
416 let health_check_interval_milliseconds = self.health_check_interval_milliseconds;
417 let stop_receiver = self.stop_receiver.clone();
418 let running = self.running.clone();
419 let fast_switching_enabled = self.fast_switching_enabled;
420
421 self.health_check_handle = Some(tokio::spawn(async move {
423 let mut interval =
424 tokio::time::interval(Duration::from_millis(health_check_interval_milliseconds));
425
426 while *running.read() {
427 interval.tick().await;
429
430 if *stop_receiver.borrow() || !*running.read() {
432 break;
433 }
434
435 *last_health_check.write() = Instant::now();
437
438 for i in 0..session_states.len() {
440 if let Some(ref hb) = heartbeat_monitors[i] {
441 let status = hb.check_heartbeat();
442 match status {
443 HeartbeatStatus::Healthy => {
444 if *active_session.read() == i {
446 *session_states[i].write() = SessionState::Active;
447 session_stats[i].write().state = SessionState::Active;
448 }
449 }
450 HeartbeatStatus::Unhealthy => {
451 if *session_states[i].read() == SessionState::Active {
453 *session_states[i].write() = SessionState::Unhealthy;
454 session_stats[i].write().state = SessionState::Unhealthy;
455 session_stats[i].write().last_state_change = Instant::now();
456 warn!("Session {i} is unhealthy");
457 }
458
459 if *active_session.read() == i && fast_switching_enabled {
462 for j in 0..session_states.len() {
464 if j != i
465 && *session_states[j].read() == SessionState::Inactive
466 {
467 if let Some(ref backup_hb) = heartbeat_monitors[j]
469 && backup_hb.check_heartbeat()
470 == HeartbeatStatus::Healthy
471 {
472 info!(
474 "Switching from unhealthy session {i} to healthy backup session {j}"
475 );
476 *active_session.write() = j;
477 *session_states[j].write() = SessionState::Active;
478 session_stats[j].write().state =
479 SessionState::Active;
480 session_stats[j].write().last_state_change =
481 Instant::now();
482 session_stats[j].write().activations += 1;
483 session_stats[j].write().last_activation =
484 Some(Instant::now());
485
486 session_stats[i].write().deactivations += 1;
488 if let Some(last_activation) =
489 session_stats[i].write().last_activation
490 {
491 let active_duration = Instant::now()
492 .duration_since(last_activation);
493 session_stats[i].write().total_active_time +=
494 active_duration;
495 }
496 break;
497 }
498 }
499 }
500 }
501 }
502 HeartbeatStatus::Dead => {
503 *session_states[i].write() = SessionState::Inactive;
505 session_stats[i].write().state = SessionState::Inactive;
506 session_stats[i].write().last_state_change = Instant::now();
507 error!("Session {i} is dead");
508
509 if *active_session.read() == i {
511 let mut found_backup = false;
513 for j in 0..session_states.len() {
514 if j != i
515 && (*session_states[j].read() == SessionState::Inactive
516 || *session_states[j].read()
517 == SessionState::Active)
518 {
519 if let Some(ref backup_hb) = heartbeat_monitors[j] {
521 let backup_status = backup_hb.check_heartbeat();
522 if backup_status == HeartbeatStatus::Healthy {
523 info!(
525 "Switching from dead session {i} to healthy backup session {j}"
526 );
527 *active_session.write() = j;
528 *session_states[j].write() =
529 SessionState::Active;
530 session_stats[j].write().state =
531 SessionState::Active;
532 session_stats[j].write().last_state_change =
533 Instant::now();
534 session_stats[j].write().activations += 1;
535 session_stats[j].write().last_activation =
536 Some(Instant::now());
537
538 session_stats[i].write().deactivations += 1;
540 if let Some(last_activation) =
541 session_stats[i].write().last_activation
542 {
543 let active_duration = Instant::now()
544 .duration_since(last_activation);
545 session_stats[i]
546 .write()
547 .total_active_time += active_duration;
548 }
549
550 found_backup = true;
551 break;
552 }
553 }
554 }
555 }
556
557 if !found_backup {
559 warn!(
560 "No healthy backup session found for dead session {i}, attempting to reconnect"
561 );
562 }
563 }
564 }
565 }
566 }
567 }
568 }
569 }));
570 }
571
572 pub fn active_session_index(&self) -> usize {
574 *self.active_session.read()
575 }
576
577 pub fn session_stats(&self, session_index: usize) -> Result<SessionStats> {
579 if session_index >= self.configs.len() {
580 return Err(anyhow!("Invalid session index"));
581 }
582
583 Ok(self.session_stats[session_index].read().clone())
584 }
585
586 pub fn all_session_stats(&self) -> Vec<SessionStats> {
588 self.session_stats
589 .iter()
590 .map(|stats| stats.read().clone())
591 .collect()
592 }
593
594 pub fn session_state(&self, session_index: usize) -> Result<SessionState> {
596 if session_index >= self.configs.len() {
597 return Err(anyhow!("Invalid session index"));
598 }
599
600 Ok(*self.session_states[session_index].read())
601 }
602
603 pub fn all_session_states(&self) -> Vec<SessionState> {
605 self.session_states
606 .iter()
607 .map(|state| *state.read())
608 .collect()
609 }
610
611 pub async fn switch_to_session(&mut self, session_index: usize) -> Result<()> {
613 if session_index >= self.configs.len() {
614 return Err(anyhow!("Invalid session index"));
615 }
616
617 let current_active = *self.active_session.read();
618
619 if current_active == session_index {
621 return Ok(());
622 }
623
624 *self.active_session.write() = session_index;
626
627 *self.session_states[current_active].write() = SessionState::Inactive;
629 self.session_stats[current_active].write().state = SessionState::Inactive;
630 self.session_stats[current_active].write().last_state_change = Instant::now();
631 self.session_stats[current_active].write().deactivations += 1;
632
633 if let Some(last_activation) = self.session_stats[current_active].write().last_activation {
634 let active_duration = Instant::now().duration_since(last_activation);
635 self.session_stats[current_active].write().total_active_time += active_duration;
636 }
637
638 *self.session_states[session_index].write() = SessionState::Active;
639 self.session_stats[session_index].write().state = SessionState::Active;
640 self.session_stats[session_index].write().last_state_change = Instant::now();
641 self.session_stats[session_index].write().activations += 1;
642 self.session_stats[session_index].write().last_activation = Some(Instant::now());
643
644 info!("Manually switched from session {current_active} to session {session_index}");
645
646 Ok(())
647 }
648
649 pub fn is_running(&self) -> bool {
651 *self.running.read()
652 }
653
654 pub fn time_since_last_health_check(&self) -> Duration {
656 Instant::now().duration_since(*self.last_health_check.read())
657 }
658
659 pub fn session_count(&self) -> usize {
661 self.configs.len()
662 }
663
664 pub fn session_config(&self, session_index: usize) -> Result<&SessionConfig> {
666 if session_index >= self.configs.len() {
667 return Err(anyhow!("Invalid session index"));
668 }
669
670 Ok(&self.configs[session_index])
671 }
672
673 pub fn all_session_configs(&self) -> &[SessionConfig] {
675 &self.configs
676 }
677}
678
/// Session manager with inline storage for 8 sessions.
pub type DefaultSessionManager<T> = SessionManager<T, 8>;
/// Session manager with inline storage for 4 sessions.
pub type SessionManager4<T> = SessionManager<T, 4>;
/// Session manager with inline storage for 16 sessions.
pub type SessionManager16<T> = SessionManager<T, 16>;
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal message type used to instantiate the generic manager in tests.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct MockMessage {
        pub content: String,
    }

    /// Builds a session configuration with the shared test subscription payload and default
    /// connection settings.
    fn config_for(session_type: SessionType, url: &str) -> SessionConfig {
        SessionConfig {
            session_type,
            url: url.into(),
            subscription_message:
                r#"{"type":"subscribe","channels":["ticker"],"product_ids":["BTC-USD"]}"#.into(),
            connection_config: ConnectionConfig::default(),
        }
    }

    /// A freshly constructed manager reports its sessions, is not running, and has every
    /// session in the `Uninitialized` state.
    #[tokio::test]
    async fn test_session_manager_creation() {
        let mut configs: SmallVec<[SessionConfig; 8]> = SmallVec::new();
        configs.push(config_for(
            SessionType::Primary,
            "wss://example.com/ws/primary",
        ));
        configs.push(config_for(
            SessionType::Backup,
            "wss://example.com/ws/backup",
        ));

        let session_manager = SessionManager::<MockMessage>::new(configs, 1000, 100, true);

        assert_eq!(session_manager.session_count(), 2);
        assert_eq!(session_manager.active_session_index(), 0);
        assert!(!session_manager.is_running());

        let states = session_manager.all_session_states();
        assert_eq!(states.len(), 2);
        for state in states {
            assert_eq!(state, SessionState::Uninitialized);
        }
    }
}