rusty_feeder/
session_manager.rs

1//! Session management for WebSocket connections
2//!
3//! This module provides a session manager for maintaining multiple WebSocket connections
4//! and performing fast switching between them when a connection is detected as dead.
5
6use std::fmt::Debug;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use anyhow::{Result, anyhow};
11use flume;
12use log::{error, info, warn};
13use parking_lot::RwLock;
14use quanta::Clock;
15use rusty_common::SmallVec;
16use serde::de::DeserializeOwned;
17use smartstring::alias::String;
18use tokio::sync::watch;
19use tokio::task::JoinHandle;
20
21use crate::provider::prelude::*;
22use rusty_common::websocket::connector::WebSocketConnector;
23use rusty_common::websocket::heartbeat::{HeartbeatMonitor, HeartbeatStatus};
24
/// Lifecycle state of a single session.
///
/// `Uninitialized` is the default: it is the state every session has between
/// `SessionManager::new` and `SessionManager::start`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SessionState {
    /// Session has not been started yet
    #[default]
    Uninitialized,
    /// Session is being started
    Initializing,
    /// Session is the active session and healthy
    Active,
    /// Session is the active session but its heartbeat monitor reports missed heartbeats
    Unhealthy,
    /// Session is not the active session, or has been stopped / declared dead
    Inactive,
}
39
/// Role a session is configured for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionType {
    /// Primary session
    Primary,
    /// Backup session
    Backup,
}
48
/// Configuration for one session managed by `SessionManager`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Declared role of the session.
    /// NOTE(review): `SessionManager` does not read this field; it treats the
    /// session at index 0 as the primary regardless of its declared type.
    pub session_type: SessionType,
    /// WebSocket URL for this session
    pub url: String,
    /// Subscription message sent for this session
    pub subscription_message: String,
    /// Connection settings; its `websocket_config` supplies the base URL used for the
    /// connector and the heartbeat interval/timeout/max-missed values
    pub connection_config: ConnectionConfig,
}
61
/// Per-session statistics maintained by `SessionManager`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Connection statistics
    pub connection_stats: ConnectionStats,
    /// Current session state (mirrors the session's state lock)
    pub state: SessionState,
    /// Time of the most recent state change
    pub last_state_change: Instant,
    /// Number of times this session has become the active session
    pub activations: u32,
    /// Number of times this session has stopped being the active session
    pub deactivations: u32,
    /// Accumulated time spent as the active session, added on each deactivation
    pub total_active_time: Duration,
    /// Time of the most recent activation, if the session was ever activated
    pub last_activation: Option<Instant>,
}
80
81impl Default for SessionStats {
82    fn default() -> Self {
83        Self {
84            connection_stats: ConnectionStats::default(),
85            state: SessionState::Uninitialized,
86            last_state_change: Instant::now(),
87            activations: 0,
88            deactivations: 0,
89            total_active_time: Duration::from_secs(0),
90            last_activation: None,
91        }
92    }
93}
94
95/// Session manager for handling multiple concurrent connections to exchanges
96/// with failover, load balancing, and session persistence.
97///
98/// # Type Parameters
99/// - `T`: Message type that the session manager handles
100/// - `N`: Maximum number of concurrent sessions (default: 8)
101///   Optimized for typical HFT scenarios with multiple exchange connections
102#[derive(Debug)]
103pub struct SessionManager<T: Clone + serde::de::DeserializeOwned, const N: usize = 8> {
104    /// Session configurations
105    configs: SmallVec<[SessionConfig; N]>,
106    /// Active session index
107    active_session: Arc<RwLock<usize>>,
108    /// Session states
109    session_states: SmallVec<[Arc<RwLock<SessionState>>; N]>,
110    /// Session statistics
111    session_stats: SmallVec<[Arc<RwLock<SessionStats>>; N]>,
112    /// WebSocket connectors
113    connectors: SmallVec<[WebSocketConnector; N]>,
114    /// Heartbeat monitors
115    heartbeat_monitors: SmallVec<[Option<HeartbeatMonitor>; N]>,
116    /// WebSocket task handles
117    websocket_handles: SmallVec<[Option<JoinHandle<()>>; N]>,
118    /// Message channels
119    message_channels: SmallVec<[Option<flume::Sender<T>>; N]>,
120    /// Output channel
121    output_sender: flume::Sender<T>,
122    /// Output receiver
123    output_receiver: flume::Receiver<T>,
124    /// Stop signal
125    stop_sender: watch::Sender<bool>,
126    /// Stop receiver
127    stop_receiver: watch::Receiver<bool>,
128    /// Shared clock
129    clock: Clock,
130    /// Last health check time
131    last_health_check: Arc<RwLock<Instant>>,
132    /// Health check interval in milliseconds
133    health_check_interval_milliseconds: u64,
134    /// Health check task handle
135    health_check_handle: Option<JoinHandle<()>>,
136    /// Whether fast switching is enabled
137    fast_switching_enabled: bool,
138    /// Whether the session manager is running
139    running: Arc<RwLock<bool>>,
140}
141
142impl<T: Clone + DeserializeOwned + Send + Sync + 'static, const N: usize> SessionManager<T, N> {
143    /// Create a new session manager
144    #[must_use]
145    pub fn new(
146        configs: SmallVec<[SessionConfig; N]>,
147        health_check_interval_milliseconds: u64,
148        channel_buffer_size: usize,
149        fast_switching_enabled: bool,
150    ) -> Self {
151        let (output_sender, output_receiver) = flume::bounded(channel_buffer_size);
152        let (stop_sender, stop_receiver) = watch::channel(false);
153        let clock = Clock::new();
154        let now = Instant::now();
155
156        let mut session_states = SmallVec::with_capacity(configs.len());
157        let mut session_stats = SmallVec::with_capacity(configs.len());
158        let mut connectors = SmallVec::with_capacity(configs.len());
159        let mut heartbeat_monitors = SmallVec::with_capacity(configs.len());
160        let mut websocket_handles = SmallVec::with_capacity(configs.len());
161        let mut message_channels = SmallVec::with_capacity(configs.len());
162
163        for config in &configs {
164            let connection_status = Arc::new(RwLock::new(ConnectionState::Disconnected));
165            let stats = Arc::new(RwLock::new(ConnectionStats::default()));
166            let session_state = Arc::new(RwLock::new(SessionState::Uninitialized));
167            let session_stat = Arc::new(RwLock::new(SessionStats::default()));
168
169            // Create WebSocket configuration from connection config
170            let ws_config = rusty_common::websocket::WebSocketConfig::new(
171                rusty_common::types::Exchange::Binance, // Default exchange, can be overridden
172                config
173                    .connection_config
174                    .websocket_config
175                    .base_url
176                    .to_string(),
177            );
178            let connector = WebSocketConnector::new(
179                ws_config,
180                Arc::new(RwLock::new(
181                    rusty_common::websocket::ConnectionStats::default(),
182                )),
183                Arc::new(RwLock::new(
184                    rusty_common::websocket::ConnectionState::Disconnected,
185                )),
186            );
187
188            let heartbeat_monitor = if config
189                .connection_config
190                .websocket_config
191                .heartbeat_interval_milliseconds
192                > 0
193            {
194                Some(HeartbeatMonitor::new(
195                    config
196                        .connection_config
197                        .websocket_config
198                        .heartbeat_interval_milliseconds,
199                    config
200                        .connection_config
201                        .websocket_config
202                        .heartbeat_timeout_milliseconds,
203                    config
204                        .connection_config
205                        .websocket_config
206                        .max_missed_heartbeats,
207                    Arc::new(RwLock::new(
208                        rusty_common::websocket::ConnectionState::Disconnected,
209                    )),
210                    Arc::new(RwLock::new(
211                        rusty_common::websocket::ConnectionStats::default(),
212                    )),
213                ))
214            } else {
215                None
216            };
217
218            session_states.push(session_state);
219            session_stats.push(session_stat);
220            connectors.push(connector);
221            heartbeat_monitors.push(heartbeat_monitor);
222            websocket_handles.push(None);
223            message_channels.push(None);
224        }
225
226        Self {
227            configs,
228            active_session: Arc::new(RwLock::new(0)), // Default to first session (primary)
229            session_states,
230            session_stats,
231            connectors,
232            heartbeat_monitors,
233            websocket_handles,
234            message_channels,
235            output_sender,
236            output_receiver,
237            stop_sender,
238            stop_receiver,
239            clock,
240            last_health_check: Arc::new(RwLock::new(now)),
241            health_check_interval_milliseconds,
242            health_check_handle: None,
243            fast_switching_enabled,
244            running: Arc::new(RwLock::new(false)),
245        }
246    }
247
248    /// Start the session manager
249    pub async fn start(&mut self) -> Result<flume::Receiver<T>> {
250        if *self.running.read() {
251            return Err(anyhow!("Session manager is already running"));
252        }
253
254        if self.configs.is_empty() {
255            return Err(anyhow!("No session configurations provided"));
256        }
257
258        // Set running flag
259        *self.running.write() = true;
260
261        // Start all sessions
262        for i in 0..self.configs.len() {
263            self.start_session(i).await?;
264        }
265
266        // Start health check task
267        self.start_health_check_task();
268
269        // Return the output receiver
270        Ok(self.output_receiver.clone())
271    }
272
273    /// Stop the session manager
274    pub async fn stop(&mut self) -> Result<()> {
275        if !*self.running.read() {
276            return Ok(());
277        }
278
279        // Set running flag to false
280        *self.running.write() = false;
281
282        // Send stop signal
283        let _ = self.stop_sender.send(true);
284
285        // Stop health check task
286        if let Some(handle) = self.health_check_handle.take() {
287            handle.abort();
288        }
289
290        // Stop all sessions
291        for i in 0..self.configs.len() {
292            self.stop_session(i).await?;
293        }
294
295        Ok(())
296    }
297
298    /// Start a session
299    async fn start_session(&mut self, session_index: usize) -> Result<()> {
300        if session_index >= self.configs.len() {
301            return Err(anyhow!("Invalid session index"));
302        }
303
304        // Update session state
305        *self.session_states[session_index].write() = SessionState::Initializing;
306        self.session_stats[session_index].write().state = SessionState::Initializing;
307        self.session_stats[session_index].write().last_state_change = Instant::now();
308
309        // Create message channel
310        let (tx, _rx) = flume::bounded(1024);
311        self.message_channels[session_index] = Some(tx.clone());
312
313        // Get configuration
314        let config = &self.configs[session_index];
315        let url = config.url.clone();
316        let subscription_message = config.subscription_message.clone();
317
318        // Create WebSocket connection
319        let handle = self
320            .create_websocket_connection(session_index, url, subscription_message, tx)
321            .await?;
322
323        // Save task handle
324        self.websocket_handles[session_index] = Some(handle);
325
326        // If this is the primary session, set it as active
327        if session_index == 0 {
328            *self.active_session.write() = 0;
329            *self.session_states[session_index].write() = SessionState::Active;
330            self.session_stats[session_index].write().state = SessionState::Active;
331            self.session_stats[session_index].write().last_state_change = Instant::now();
332            self.session_stats[session_index].write().activations += 1;
333            self.session_stats[session_index].write().last_activation = Some(Instant::now());
334        } else {
335            *self.session_states[session_index].write() = SessionState::Inactive;
336            self.session_stats[session_index].write().state = SessionState::Inactive;
337            self.session_stats[session_index].write().last_state_change = Instant::now();
338        }
339
340        Ok(())
341    }
342
343    /// Stop a session
344    async fn stop_session(&mut self, session_index: usize) -> Result<()> {
345        if session_index >= self.configs.len() {
346            return Err(anyhow!("Invalid session index"));
347        }
348
349        // Abort task if running
350        if let Some(handle) = self.websocket_handles[session_index].take() {
351            handle.abort();
352        }
353
354        // Update session state
355        *self.session_states[session_index].write() = SessionState::Inactive;
356        self.session_stats[session_index].write().state = SessionState::Inactive;
357        self.session_stats[session_index].write().last_state_change = Instant::now();
358
359        // If this was the active session, update stats
360        if *self.active_session.read() == session_index {
361            self.session_stats[session_index].write().deactivations += 1;
362            if let Some(last_activation) = self.session_stats[session_index].write().last_activation
363            {
364                let active_duration = Instant::now().duration_since(last_activation);
365                self.session_stats[session_index].write().total_active_time += active_duration;
366            }
367        }
368
369        Ok(())
370    }
371
372    /// Create a WebSocket connection (simplified implementation)
373    async fn create_websocket_connection(
374        &self,
375        session_index: usize,
376        _url: String,
377        _subscription_message: String,
378        _tx: flume::Sender<T>,
379    ) -> Result<JoinHandle<()>> {
380        // Simplified implementation until WebSocket migration is complete
381        let session_state = self.session_states[session_index].clone();
382        let session_stats = self.session_stats[session_index].clone();
383        let stop_receiver = self.stop_receiver.clone();
384        let running = self.running.clone();
385
386        let handle = tokio::spawn(async move {
387            loop {
388                // Check for stop signal
389                if *stop_receiver.borrow() || !*running.read() {
390                    break;
391                }
392
393                // Mark session as inactive until WebSocket migration is complete
394                *session_state.write() = SessionState::Inactive;
395                session_stats.write().state = SessionState::Inactive;
396                session_stats.write().last_state_change = Instant::now();
397
398                // Wait before retrying
399                tokio::time::sleep(Duration::from_secs(5)).await;
400
401                error!("WebSocket connection not implemented - migration to yawc in progress");
402            }
403        });
404
405        Ok(handle)
406    }
407
408    /// Start the health check task
409    fn start_health_check_task(&mut self) {
410        // Clone needed values for the task
411        let session_states = self.session_states.clone();
412        let session_stats = self.session_stats.clone();
413        let heartbeat_monitors = self.heartbeat_monitors.clone();
414        let active_session = self.active_session.clone();
415        let last_health_check = self.last_health_check.clone();
416        let health_check_interval_milliseconds = self.health_check_interval_milliseconds;
417        let stop_receiver = self.stop_receiver.clone();
418        let running = self.running.clone();
419        let fast_switching_enabled = self.fast_switching_enabled;
420
421        // Spawn the task
422        self.health_check_handle = Some(tokio::spawn(async move {
423            let mut interval =
424                tokio::time::interval(Duration::from_millis(health_check_interval_milliseconds));
425
426            while *running.read() {
427                // Wait for next interval
428                interval.tick().await;
429
430                // Check for stop signal
431                if *stop_receiver.borrow() || !*running.read() {
432                    break;
433                }
434
435                // Update last health check time
436                *last_health_check.write() = Instant::now();
437
438                // Check heartbeat status for all sessions
439                for i in 0..session_states.len() {
440                    if let Some(ref hb) = heartbeat_monitors[i] {
441                        let status = hb.check_heartbeat();
442                        match status {
443                            HeartbeatStatus::Healthy => {
444                                // If this is the active session, ensure it's marked as active
445                                if *active_session.read() == i {
446                                    *session_states[i].write() = SessionState::Active;
447                                    session_stats[i].write().state = SessionState::Active;
448                                }
449                            }
450                            HeartbeatStatus::Unhealthy => {
451                                // Mark session as unhealthy
452                                if *session_states[i].read() == SessionState::Active {
453                                    *session_states[i].write() = SessionState::Unhealthy;
454                                    session_stats[i].write().state = SessionState::Unhealthy;
455                                    session_stats[i].write().last_state_change = Instant::now();
456                                    warn!("Session {i} is unhealthy");
457                                }
458
459                                // If this is the active session and fast switching is enabled,
460                                // try to switch to a healthy backup session
461                                if *active_session.read() == i && fast_switching_enabled {
462                                    // Find a healthy backup session
463                                    for j in 0..session_states.len() {
464                                        if j != i
465                                            && *session_states[j].read() == SessionState::Inactive
466                                        {
467                                            // Check if this backup session is healthy
468                                            if let Some(ref backup_hb) = heartbeat_monitors[j]
469                                                && backup_hb.check_heartbeat()
470                                                    == HeartbeatStatus::Healthy
471                                            {
472                                                // Switch to this backup session
473                                                info!(
474                                                    "Switching from unhealthy session {i} to healthy backup session {j}"
475                                                );
476                                                *active_session.write() = j;
477                                                *session_states[j].write() = SessionState::Active;
478                                                session_stats[j].write().state =
479                                                    SessionState::Active;
480                                                session_stats[j].write().last_state_change =
481                                                    Instant::now();
482                                                session_stats[j].write().activations += 1;
483                                                session_stats[j].write().last_activation =
484                                                    Some(Instant::now());
485
486                                                // Update stats for the previous active session
487                                                session_stats[i].write().deactivations += 1;
488                                                if let Some(last_activation) =
489                                                    session_stats[i].write().last_activation
490                                                {
491                                                    let active_duration = Instant::now()
492                                                        .duration_since(last_activation);
493                                                    session_stats[i].write().total_active_time +=
494                                                        active_duration;
495                                                }
496                                                break;
497                                            }
498                                        }
499                                    }
500                                }
501                            }
502                            HeartbeatStatus::Dead => {
503                                // Mark session as inactive
504                                *session_states[i].write() = SessionState::Inactive;
505                                session_stats[i].write().state = SessionState::Inactive;
506                                session_stats[i].write().last_state_change = Instant::now();
507                                error!("Session {i} is dead");
508
509                                // If this is the active session, switch to a backup session
510                                if *active_session.read() == i {
511                                    // Find a healthy backup session
512                                    let mut found_backup = false;
513                                    for j in 0..session_states.len() {
514                                        if j != i
515                                            && (*session_states[j].read() == SessionState::Inactive
516                                                || *session_states[j].read()
517                                                    == SessionState::Active)
518                                        {
519                                            // Check if this backup session is healthy
520                                            if let Some(ref backup_hb) = heartbeat_monitors[j] {
521                                                let backup_status = backup_hb.check_heartbeat();
522                                                if backup_status == HeartbeatStatus::Healthy {
523                                                    // Switch to this backup session
524                                                    info!(
525                                                        "Switching from dead session {i} to healthy backup session {j}"
526                                                    );
527                                                    *active_session.write() = j;
528                                                    *session_states[j].write() =
529                                                        SessionState::Active;
530                                                    session_stats[j].write().state =
531                                                        SessionState::Active;
532                                                    session_stats[j].write().last_state_change =
533                                                        Instant::now();
534                                                    session_stats[j].write().activations += 1;
535                                                    session_stats[j].write().last_activation =
536                                                        Some(Instant::now());
537
538                                                    // Update stats for the previous active session
539                                                    session_stats[i].write().deactivations += 1;
540                                                    if let Some(last_activation) =
541                                                        session_stats[i].write().last_activation
542                                                    {
543                                                        let active_duration = Instant::now()
544                                                            .duration_since(last_activation);
545                                                        session_stats[i]
546                                                            .write()
547                                                            .total_active_time += active_duration;
548                                                    }
549
550                                                    found_backup = true;
551                                                    break;
552                                                }
553                                            }
554                                        }
555                                    }
556
557                                    // If no healthy backup session was found, try to reconnect this session
558                                    if !found_backup {
559                                        warn!(
560                                            "No healthy backup session found for dead session {i}, attempting to reconnect"
561                                        );
562                                    }
563                                }
564                            }
565                        }
566                    }
567                }
568            }
569        }));
570    }
571
572    /// Get the active session index
573    pub fn active_session_index(&self) -> usize {
574        *self.active_session.read()
575    }
576
577    /// Get session statistics
578    pub fn session_stats(&self, session_index: usize) -> Result<SessionStats> {
579        if session_index >= self.configs.len() {
580            return Err(anyhow!("Invalid session index"));
581        }
582
583        Ok(self.session_stats[session_index].read().clone())
584    }
585
586    /// Get all session statistics
587    pub fn all_session_stats(&self) -> Vec<SessionStats> {
588        self.session_stats
589            .iter()
590            .map(|stats| stats.read().clone())
591            .collect()
592    }
593
594    /// Get session state
595    pub fn session_state(&self, session_index: usize) -> Result<SessionState> {
596        if session_index >= self.configs.len() {
597            return Err(anyhow!("Invalid session index"));
598        }
599
600        Ok(*self.session_states[session_index].read())
601    }
602
603    /// Get all session states
604    pub fn all_session_states(&self) -> Vec<SessionState> {
605        self.session_states
606            .iter()
607            .map(|state| *state.read())
608            .collect()
609    }
610
611    /// Manually switch to a different session
612    pub async fn switch_to_session(&mut self, session_index: usize) -> Result<()> {
613        if session_index >= self.configs.len() {
614            return Err(anyhow!("Invalid session index"));
615        }
616
617        let current_active = *self.active_session.read();
618
619        // If already active, do nothing
620        if current_active == session_index {
621            return Ok(());
622        }
623
624        // Update active session
625        *self.active_session.write() = session_index;
626
627        // Update session states
628        *self.session_states[current_active].write() = SessionState::Inactive;
629        self.session_stats[current_active].write().state = SessionState::Inactive;
630        self.session_stats[current_active].write().last_state_change = Instant::now();
631        self.session_stats[current_active].write().deactivations += 1;
632
633        if let Some(last_activation) = self.session_stats[current_active].write().last_activation {
634            let active_duration = Instant::now().duration_since(last_activation);
635            self.session_stats[current_active].write().total_active_time += active_duration;
636        }
637
638        *self.session_states[session_index].write() = SessionState::Active;
639        self.session_stats[session_index].write().state = SessionState::Active;
640        self.session_stats[session_index].write().last_state_change = Instant::now();
641        self.session_stats[session_index].write().activations += 1;
642        self.session_stats[session_index].write().last_activation = Some(Instant::now());
643
644        info!("Manually switched from session {current_active} to session {session_index}");
645
646        Ok(())
647    }
648
649    /// Check if the session manager is running
650    pub fn is_running(&self) -> bool {
651        *self.running.read()
652    }
653
654    /// Get the time since the last health check
655    pub fn time_since_last_health_check(&self) -> Duration {
656        Instant::now().duration_since(*self.last_health_check.read())
657    }
658
659    /// Get the number of sessions
660    pub fn session_count(&self) -> usize {
661        self.configs.len()
662    }
663
664    /// Get a reference to the session configuration
665    pub fn session_config(&self, session_index: usize) -> Result<&SessionConfig> {
666        if session_index >= self.configs.len() {
667            return Err(anyhow!("Invalid session index"));
668        }
669
670        Ok(&self.configs[session_index])
671    }
672
673    /// Get all session configurations
674    pub fn all_session_configs(&self) -> &[SessionConfig] {
675        &self.configs
676    }
677}
678
// Type aliases for backward compatibility; `N` is the inline (non-heap) capacity of
// the per-session storage, not a hard limit on the number of sessions.
/// Session manager with inline storage for 8 sessions (same as `SessionManager<T>`)
pub type DefaultSessionManager<T> = SessionManager<T, 8>;
/// Session manager with inline storage for 4 sessions
pub type SessionManager4<T> = SessionManager<T, 4>;
/// Session manager with inline storage for 16 sessions
pub type SessionManager16<T> = SessionManager<T, 16>;
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal message type used to instantiate the generic manager.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct MockMessage {
        pub content: String,
    }

    /// Subscription payload shared by the sample sessions.
    const SUBSCRIPTION: &str =
        r#"{"type":"subscribe","channels":["ticker"],"product_ids":["BTC-USD"]}"#;

    /// One primary and one backup session configuration.
    fn sample_configs() -> SmallVec<[SessionConfig; 8]> {
        smallvec::smallvec![
            SessionConfig {
                session_type: SessionType::Primary,
                url: "wss://example.com/ws/primary".into(),
                subscription_message: SUBSCRIPTION.into(),
                connection_config: ConnectionConfig::default(),
            },
            SessionConfig {
                session_type: SessionType::Backup,
                url: "wss://example.com/ws/backup".into(),
                subscription_message: SUBSCRIPTION.into(),
                connection_config: ConnectionConfig::default(),
            },
        ]
    }

    /// Manager over `sample_configs` with a 1 s health-check interval, a
    /// 100-message buffer and fast switching enabled.
    fn sample_manager() -> SessionManager<MockMessage> {
        SessionManager::<MockMessage>::new(sample_configs(), 1000, 100, true)
    }

    #[tokio::test]
    async fn test_session_manager_creation() {
        let session_manager = sample_manager();

        assert_eq!(session_manager.session_count(), 2);
        assert_eq!(session_manager.active_session_index(), 0);
        assert!(!session_manager.is_running());

        let states = session_manager.all_session_states();
        assert_eq!(states.len(), 2);
        assert_eq!(states[0], SessionState::Uninitialized);
        assert_eq!(states[1], SessionState::Uninitialized);
    }

    #[tokio::test]
    async fn test_start_rejects_empty_configuration() {
        let no_configs: SmallVec<[SessionConfig; 8]> = SmallVec::new();
        let mut manager = SessionManager::<MockMessage>::new(no_configs, 1000, 100, true);

        let Err(err) = manager.start().await else {
            panic!("starting with no sessions must fail");
        };
        assert_eq!(err.to_string(), "No session configurations provided");
        assert!(!manager.is_running());
    }

    #[tokio::test]
    async fn test_stop_when_not_running_is_noop() {
        let mut manager = sample_manager();
        assert!(manager.stop().await.is_ok());
        assert!(!manager.is_running());
    }

    #[tokio::test]
    async fn test_out_of_range_index_is_rejected() {
        let mut manager = sample_manager();
        assert!(manager.session_stats(2).is_err());
        assert!(manager.session_state(2).is_err());
        assert!(manager.session_config(2).is_err());
        assert!(manager.switch_to_session(2).await.is_err());
        assert_eq!(manager.active_session_index(), 0);
    }

    #[tokio::test]
    async fn test_switch_to_current_session_is_noop() {
        let mut manager = sample_manager();
        manager
            .switch_to_session(0)
            .await
            .expect("switching to the active session must succeed");
        assert_eq!(manager.active_session_index(), 0);
        let stats = manager
            .session_stats(0)
            .expect("index 0 is a configured session");
        assert_eq!(stats.deactivations, 0);
        assert_eq!(stats.activations, 0);
    }
}