// rusty_common/websocket/connector.rs

//! WebSocket connector with advanced features
//!
//! Provides connection management with batch processing, heartbeat monitoring,
//! failover support, and high-level task creation utilities.
5
6use futures_util::{SinkExt, StreamExt};
7use parking_lot::RwLock;
8use quanta::{Clock, Instant};
9use smartstring::alias::String;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{mpsc, watch};
13use tokio::task::JoinHandle;
14use tokio::time::{sleep, timeout};
15use yawc::{Options, WebSocket, frame::FrameView};
16
17use super::batch::{BatchProcessingMetrics, BatchProcessor};
18use super::exchanges::ExchangeConfig;
19use super::heartbeat::HeartbeatMonitor;
20use super::stats::SharedStats;
21use super::{ConnectionState, Message, WebSocketConfig, WebSocketError, WebSocketResult};
22
23/// Type alias for WebSocket split sink.
24pub type WebSocketSink = futures_util::stream::SplitSink<WebSocket, FrameView>;
25/// Type alias for WebSocket split stream.
26pub type WebSocketStream = futures_util::stream::SplitStream<WebSocket>;
27
/// Reconnection bookkeeping used by `WebSocketConnector::connect_with_retry`.
///
/// Tracks how many attempts have been made, how long to space the next one,
/// and the inputs used to grow that spacing after failures.
#[derive(Debug)]
struct ReconnectState {
    /// Number of reconnection attempts made since the last successful
    /// connection; compared against `reconnect.max_attempts`.
    attempts: u32,
    /// Current backoff delay in milliseconds: the minimum spacing enforced
    /// between consecutive connection attempts.
    backoff_delay_milliseconds: u64,
    /// When the most recent connection attempt started.
    last_attempt: Instant,
    /// When the most recent successful connection was established, if ever.
    last_success: Option<Instant>,
    /// Failures since the last success; scales how fast the backoff grows
    /// (the scaling is capped at five failures).
    consecutive_failures: u32,
    /// Fraction (0.0-1.0) of the current backoff used as the upper bound of
    /// the pseudo-random jitter added when the backoff grows.
    jitter_factor: f64,
}
44
45impl Default for ReconnectState {
46    fn default() -> Self {
47        Self {
48            attempts: 0,
49            backoff_delay_milliseconds: 100,
50            last_attempt: Instant::now(),
51            last_success: None,
52            consecutive_failures: 0,
53            jitter_factor: 0.1,
54        }
55    }
56}
57
/// WebSocket connector with reconnection backoff, URL failover, batch
/// processing metrics, and optional heartbeat monitoring.
///
/// Statistics and connection status are shared handles supplied to
/// [`WebSocketConnector::new`]; clones of them are handed to the heartbeat
/// monitor and to tasks spawned by `create_websocket_task`.
#[derive(Debug)]
#[repr(align(64))] // Cache-line alignment (64 bytes)
pub struct WebSocketConnector {
    /// Connection configuration (timeouts, reconnect, ping, failover settings).
    config: WebSocketConfig,
    /// Attempt/backoff bookkeeping used by `connect_with_retry`.
    reconnect_state: ReconnectState,
    /// Clock used for stats timestamps and batch timing.
    clock: Clock,
    /// Connection statistics shared with other components.
    stats: SharedStats,
    /// Current connection state, shared with the heartbeat monitor and
    /// spawned tasks.
    connection_status: Arc<RwLock<ConnectionState>>,
    /// Records batch processing metrics.
    batch_processor: BatchProcessor,
    /// Heartbeat monitor; `None` when `heartbeat_interval_milliseconds` is 0.
    heartbeat_monitor: Option<HeartbeatMonitor>,
    /// Failover rotation position: 0 selects the primary URL, `i > 0`
    /// selects `failover_urls[i - 1]`.
    current_url_index: Arc<RwLock<usize>>,
}
79
80impl WebSocketConnector {
81    /// Create a new WebSocket connector
82    #[must_use]
83    pub fn new(
84        config: WebSocketConfig,
85        stats: SharedStats,
86        connection_status: Arc<RwLock<ConnectionState>>,
87    ) -> Self {
88        // Create heartbeat monitor if enabled
89        let heartbeat_monitor = if config.heartbeat_interval_milliseconds > 0 {
90            Some(HeartbeatMonitor::new(
91                config.heartbeat_interval_milliseconds,
92                config.heartbeat_timeout_milliseconds,
93                config.max_missed_heartbeats,
94                connection_status.clone(),
95                stats.clone(),
96            ))
97        } else {
98            None
99        };
100
101        Self {
102            batch_processor: BatchProcessor::new(config.batch_size),
103            config,
104            reconnect_state: ReconnectState::default(),
105            clock: Clock::new(),
106            stats,
107            connection_status,
108            heartbeat_monitor,
109            current_url_index: Arc::new(RwLock::new(0)),
110        }
111    }
112
113    /// Get the configuration
114    pub const fn config(&self) -> &WebSocketConfig {
115        &self.config
116    }
117
118    /// Get batch processing metrics
119    pub fn get_batch_metrics(&self) -> BatchProcessingMetrics {
120        self.batch_processor.get_metrics()
121    }
122
123    /// Create WebSocket options with exchange-specific settings
124    pub fn create_websocket_options(exchange_config: Option<&ExchangeConfig>) -> Options {
125        let mut options = Options::default();
126
127        if let Some(config) = exchange_config
128            && config.compression.enabled
129        {
130            // TODO: Configure compression when yawc Options API is available
131            // Currently yawc Options doesn't expose compression configuration
132            options.compression = Some(Default::default());
133        }
134
135        options
136    }
137
138    /// Process a batch of messages
139    pub async fn process_message_batch<T, P, R>(
140        &self,
141        messages: Vec<Message>,
142        process_message: P,
143        sender: &mpsc::Sender<T>,
144    ) -> WebSocketResult<usize>
145    where
146        T: Send + 'static,
147        P: Fn(&str, &Clock) -> Option<R> + Send + Sync,
148        R: Into<T> + Send,
149    {
150        if messages.is_empty() {
151            return Ok(0);
152        }
153
154        let start_time = self.clock.now();
155        let message_count = messages.len();
156        let mut processed_count = 0;
157        let mut results = Vec::with_capacity(message_count);
158
159        // Process all messages
160        for message in messages {
161            match message {
162                Message::Text(text) => {
163                    // Update stats
164                    {
165                        let mut s = self.stats.write();
166                        s.messages_received += 1;
167                        s.bytes_received += text.len() as u64;
168                        s.last_message_time = self.clock.raw();
169                    }
170
171                    // Process the message
172                    if let Some(result) = process_message(&text, &self.clock) {
173                        results.push(result.into());
174                        processed_count += 1;
175                    }
176                }
177                Message::Binary(bin) => {
178                    let text: String = std::string::String::from_utf8_lossy(&bin)
179                        .into_owned()
180                        .into();
181
182                    // Update stats
183                    {
184                        let mut s = self.stats.write();
185                        s.messages_received += 1;
186                        s.bytes_received += bin.len() as u64;
187                        s.last_message_time = self.clock.raw();
188                    }
189
190                    // Process the message
191                    if let Some(result) = process_message(&text, &self.clock) {
192                        results.push(result.into());
193                        processed_count += 1;
194                    }
195                }
196                Message::Pong(_) => {
197                    self.stats.write().last_pong_time = self.clock.raw();
198                }
199                Message::Close(close_frame) => {
200                    log::info!("WebSocket connection closed: {close_frame:?}");
201                    *self.connection_status.write() = ConnectionState::Disconnected;
202                }
203                Message::Ping(data) => {
204                    log::trace!("Received ping message with {} bytes", data.len());
205                }
206                Message::Frame(_) => {
207                    log::trace!("Received frame message");
208                }
209            }
210        }
211
212        // Send all results
213        for result in results {
214            if let Err(e) = sender.send(result).await {
215                log::error!("Failed to send processed message: {e}");
216                return Err(WebSocketError::MessageProcessingError(e.to_string()));
217            }
218        }
219
220        // Update batch metrics
221        let elapsed_ns = (self.clock.now() - start_time).as_nanos() as u64;
222        self.batch_processor
223            .update_metrics(message_count, elapsed_ns);
224
225        Ok(processed_count)
226    }
227
228    /// Connect with retry and failover support
229    pub async fn connect_with_retry(
230        &mut self,
231        url: &str,
232    ) -> WebSocketResult<(WebSocketSink, WebSocketStream)> {
233        // Check backoff timing
234        let now = Instant::now();
235        if now.duration_since(self.reconnect_state.last_attempt)
236            < Duration::from_millis(self.reconnect_state.backoff_delay_milliseconds)
237        {
238            sleep(Duration::from_millis(
239                self.reconnect_state.backoff_delay_milliseconds,
240            ))
241            .await;
242        }
243
244        // Update connection state
245        *self.connection_status.write() = ConnectionState::Connecting;
246        self.reconnect_state.last_attempt = Instant::now();
247
248        // Get the current URL (primary or failover)
249        let current_url =
250            if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
251                let current_index = *self.current_url_index.read();
252                if current_index == 0 {
253                    url.to_string()
254                } else {
255                    let failover_index = (current_index - 1) % self.config.failover_urls.len();
256                    self.config.failover_urls[failover_index].clone()
257                }
258            } else {
259                url.to_string()
260            };
261
262        // Parse URL
263        let parsed_url = current_url
264            .parse()
265            .map_err(|e| WebSocketError::ConnectionError(format!("Invalid URL: {e}")))?;
266
267        // Create options
268        let options = Self::create_websocket_options(None); // TODO: Pass exchange config
269
270        // Add jitter to backoff
271        let jitter_factor = self.reconnect_state.jitter_factor;
272        let jitter_range = self.reconnect_state.backoff_delay_milliseconds as f64 * jitter_factor;
273        let now_ns = self.clock.raw();
274        let jitter = ((now_ns % 1000) as f64 / 1000.0 * jitter_range) as u64;
275
276        // Connect with timeout
277        match timeout(
278            self.config.connect_timeout,
279            WebSocket::connect(parsed_url).with_options(options),
280        )
281        .await
282        {
283            Ok(Ok(ws_stream)) => {
284                // Reset reconnect state on success
285                self.reconnect_state.attempts = 0;
286                self.reconnect_state.backoff_delay_milliseconds =
287                    self.config.reconnect.initial_delay.as_millis() as u64;
288                self.reconnect_state.consecutive_failures = 0;
289                self.reconnect_state.last_success = Some(Instant::now());
290
291                // Update connection status
292                *self.connection_status.write() = ConnectionState::Connected;
293                self.stats.write().connected_time = self.clock.raw();
294
295                // Reset heartbeat monitor
296                if let Some(ref monitor) = self.heartbeat_monitor {
297                    monitor.reset();
298                }
299
300                // Split the stream
301                Ok(ws_stream.split())
302            }
303            Ok(Err(e)) => {
304                // Connection failed
305                *self.connection_status.write() = ConnectionState::Error;
306                self.reconnect_state.consecutive_failures += 1;
307
308                // Try failover if enabled
309                if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
310                    {
311                        let mut current_index = self.current_url_index.write();
312                        *current_index =
313                            (*current_index + 1) % (self.config.failover_urls.len() + 1);
314                    }
315
316                    let current_index = *self.current_url_index.read();
317                    if current_index == 0 {
318                        log::warn!("All failover URLs failed, resetting to primary");
319                    } else {
320                        log::info!("Trying failover URL index: {current_index}");
321                        return Box::pin(self.connect_with_retry(url)).await;
322                    }
323                }
324
325                // Update backoff
326                if self.config.reconnect.enabled {
327                    self.reconnect_state.attempts += 1;
328
329                    if self.config.reconnect.max_attempts > 0
330                        && self.reconnect_state.attempts >= self.config.reconnect.max_attempts
331                    {
332                        return Err(WebSocketError::ConnectionError(format!(
333                            "Maximum reconnection attempts exceeded: {e}"
334                        )));
335                    }
336
337                    let consecutive_factor =
338                        std::cmp::min(self.reconnect_state.consecutive_failures, 5) as f64 * 0.5
339                            + 1.0;
340
341                    self.reconnect_state.backoff_delay_milliseconds = std::cmp::min(
342                        ((self.reconnect_state.backoff_delay_milliseconds as f64
343                            * consecutive_factor
344                            * self.config.reconnect.multiplier) as u64)
345                            .saturating_add(jitter),
346                        self.config.reconnect.max_delay.as_millis() as u64,
347                    );
348
349                    log::warn!(
350                        "WebSocket connection failed, attempt {}/{}. Next retry in {}ms. Error: {}",
351                        self.reconnect_state.attempts,
352                        if self.config.reconnect.max_attempts > 0 {
353                            self.config.reconnect.max_attempts.to_string()
354                        } else {
355                            "∞".to_string()
356                        },
357                        self.reconnect_state.backoff_delay_milliseconds,
358                        e
359                    );
360                }
361
362                Err(WebSocketError::ConnectionError(format!(
363                    "Failed to connect: {e}"
364                )))
365            }
366            Err(_) => {
367                // Timeout
368                *self.connection_status.write() = ConnectionState::Error;
369                self.reconnect_state.consecutive_failures += 1;
370                self.reconnect_state.attempts += 1;
371
372                // Try failover if enabled
373                if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
374                    {
375                        let mut current_index = self.current_url_index.write();
376                        *current_index =
377                            (*current_index + 1) % (self.config.failover_urls.len() + 1);
378                    }
379
380                    let current_index = *self.current_url_index.read();
381                    if current_index == 0 {
382                        log::warn!("All failover URLs failed, resetting to primary");
383                    } else {
384                        log::info!("Trying failover URL index: {current_index}");
385                        return Box::pin(self.connect_with_retry(url)).await;
386                    }
387                }
388
389                // Double the timeout for next attempt
390                self.reconnect_state.backoff_delay_milliseconds = std::cmp::min(
391                    self.reconnect_state.backoff_delay_milliseconds * 2,
392                    self.config.reconnect.max_delay.as_millis() as u64,
393                );
394
395                Err(WebSocketError::Timeout(
396                    self.config.connect_timeout.as_millis() as u64,
397                ))
398            }
399        }
400    }
401
402    /// Create a WebSocket handler task
403    pub fn create_websocket_task<T, P, R>(
404        &self,
405        url: String,
406        subscription_message: String,
407        process_message: P,
408        sender: mpsc::Sender<T>,
409        stop_receiver: &mut watch::Receiver<bool>,
410    ) -> JoinHandle<()>
411    where
412        T: Send + 'static,
413        P: Fn(&str, &Clock) -> Option<R> + Send + Sync + 'static,
414        R: Into<T> + Send + 'static,
415    {
416        // Clone needed values
417        let clock = self.clock.clone();
418        let timeout_ms = self.config.timeout.as_millis() as u64;
419        let stats = self.stats.clone();
420        let connection_status = self.connection_status.clone();
421        let reconnect_config = self.config.reconnect.clone();
422        let ping_interval = self.config.ping_interval;
423        let custom_ping = self.config.custom_ping_message.clone();
424        let custom_pong = self.config.custom_pong_response.clone();
425        let _use_compression = self.config.compression.enabled;
426        let mut stop_receiver = stop_receiver.clone();
427
428        tokio::spawn(async move {
429            let mut reconnect_attempts = 0;
430            let mut backoff_delay = reconnect_config.initial_delay.as_millis() as u64;
431
432            'connection: loop {
433                // Check for stop signal
434                if *stop_receiver.borrow() {
435                    break;
436                }
437
438                // Parse URL
439                let parsed_url = match url.parse() {
440                    Ok(u) => u,
441                    Err(e) => {
442                        log::error!("Failed to parse URL: {e}");
443                        tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
444                        continue;
445                    }
446                };
447
448                // Create options
449                let options = WebSocketConnector::create_websocket_options(None);
450
451                // Connect
452                match WebSocket::connect(parsed_url).with_options(options).await {
453                    Ok(ws_stream) => {
454                        // Reset reconnect state
455                        reconnect_attempts = 0;
456                        backoff_delay = reconnect_config.initial_delay.as_millis() as u64;
457
458                        // Update status
459                        *connection_status.write() = ConnectionState::Connected;
460                        stats.write().connected_time = clock.raw();
461
462                        // Split stream
463                        let (mut websocket_sender, mut websocket_receiver) = ws_stream.split();
464
465                        // Send subscription
466                        if let Err(e) = websocket_sender
467                            .send(FrameView::text(subscription_message.to_string()))
468                            .await
469                        {
470                            log::error!("Failed to send subscription: {e}");
471                            tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
472                            continue 'connection;
473                        }
474
475                        // Setup ping interval
476                        let mut ping_interval = tokio::time::interval(ping_interval);
477
478                        // Handle messages
479                        loop {
480                            tokio::select! {
481                                // Check for stop signal
482                                _ = stop_receiver.changed() => {
483                                    if *stop_receiver.borrow() {
484                                        break 'connection;
485                                    }
486                                }
487
488                                // Send ping
489                                _ = ping_interval.tick() => {
490                                    let ping_frame = if let Some(ref custom) = custom_ping {
491                                        FrameView::text(custom.clone())
492                                    } else {
493                                        FrameView::ping(vec![])
494                                    };
495
496                                    if let Err(e) = websocket_sender.send(ping_frame).await {
497                                        log::error!("Failed to send ping: {e}");
498                                        break;
499                                    }
500                                    stats.write().last_ping_time = clock.raw();
501                                }
502
503                                // Receive message
504                                message_result = timeout(
505                                    Duration::from_millis(timeout_ms),
506                                    websocket_receiver.next()
507                                ) => {
508                                    match message_result {
509                                        Ok(Some(frame)) => {
510                                            let message = Message::from_frame_view(frame);
511
512                                            match &message {
513                                                Message::Text(text) => {
514                                                    // Update stats
515                                                    {
516                                                        let mut s = stats.write();
517                                                        s.messages_received += 1;
518                                                        s.bytes_received += text.len() as u64;
519                                                        s.last_message_time = clock.raw();
520                                                    }
521
522                                                    // Process message
523                                                    if let Some(result) = process_message(text, &clock)
524                                                        && let Err(e) = sender.send(result.into()).await {
525                                                            log::error!("Failed to send message: {e}");
526                                                        }
527                                                }
528                                                Message::Binary(bin) => {
529                                                    let text = std::string::String::from_utf8_lossy(bin);
530
531                                                    // Update stats
532                                                    {
533                                                        let mut s = stats.write();
534                                                        s.messages_received += 1;
535                                                        s.bytes_received += bin.len() as u64;
536                                                        s.last_message_time = clock.raw();
537                                                    }
538
539                                                    // Process message
540                                                    if let Some(result) = process_message(&text, &clock)
541                                                        && let Err(e) = sender.send(result.into()).await {
542                                                            log::error!("Failed to send message: {e}");
543                                                        }
544                                                }
545                                                Message::Pong(_) => {
546                                                    stats.write().last_pong_time = clock.raw();
547                                                }
548                                                Message::Close(close_frame) => {
549                                                    log::info!("WebSocket closed: {close_frame:?}");
550                                                    *connection_status.write() = ConnectionState::Disconnected;
551                                                    break;
552                                                }
553                                                Message::Ping(_) => {
554                                                    let pong_frame = if let Some(ref custom) = custom_pong {
555                                                        FrameView::text(custom.clone())
556                                                    } else {
557                                                        FrameView::pong(vec![])
558                                                    };
559
560                                                    if let Err(e) = websocket_sender.send(pong_frame).await {
561                                                        log::error!("Failed to send pong: {e}");
562                                                        break;
563                                                    }
564                                                }
565                                                &Message::Frame(_) => {
566                                                    log::trace!("Received frame message");
567                                                }
568                                            }
569                                        }
570                                        Ok(None) => {
571                                            log::info!("WebSocket connection closed");
572                                            break;
573                                        }
574                                        Err(_) => {
575                                            log::error!("WebSocket message timeout");
576                                            break;
577                                        }
578                                    }
579                                }
580                            }
581                        }
582                    }
583                    Err(e) => {
584                        log::error!("Failed to connect: {e}");
585                        *connection_status.write() = ConnectionState::Error;
586                    }
587                }
588
589                // Reconnect if enabled
590                if reconnect_config.enabled
591                    && (reconnect_config.max_attempts == 0
592                        || reconnect_attempts < reconnect_config.max_attempts)
593                {
594                    reconnect_attempts += 1;
595                    *connection_status.write() = ConnectionState::Reconnecting;
596                    stats.write().reconnections += 1;
597
598                    // Calculate backoff
599                    backoff_delay = std::cmp::min(
600                        (backoff_delay as f64 * reconnect_config.multiplier) as u64,
601                        reconnect_config.max_delay.as_millis() as u64,
602                    );
603
604                    tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
605                } else {
606                    *connection_status.write() = ConnectionState::Disconnected;
607                    break 'connection;
608                }
609            }
610        })
611    }
612}