// rusty_ems/exchanges/websocket_unified.rs

//! Unified WebSocket Trading Components
//!
//! This module provides common abstractions, traits, and utilities for WebSocket
//! trading implementations across all supported exchanges. It removes duplicated
//! code and keeps behavior consistent from one exchange to the next.
//!
//! # Features
//!
//! - **Unified Connection State Management**: Standardized state machine for all exchanges
//! - **Common Authentication Interface**: Pluggable authentication mechanisms
//! - **Standardized Health Monitoring**: Consistent ping/pong and health tracking
//! - **Shared Reconnection Logic**: Exponential backoff with optional jitter and configurable parameters
//! - **Unified Rate Limiting**: Configurable sliding-window rate limiting for all exchanges
//! - **Common Request Management**: Standardized request ID generation and tracking

use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::Duration;

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use flume::Sender;
use quanta::{Clock, Instant};
use rand::{Rng, rng};
use rusty_common::SmartString;
use rusty_common::collections::FxHashMap;
use rusty_common::websocket::connector::WebSocketSink;
use simd_json::StaticNode;
use simd_json::prelude::{ValueAsScalar, ValueObjectAccess};
use simd_json::value::owned::Value as JsonValue;
use tokio::task::JoinHandle;
use uuid::Uuid;

use crate::execution_engine::ExecutionReport;

/// Lifecycle of a trading WebSocket connection, shared by every exchange.
///
/// Every variant has an explicit `u8` discriminant so the state can be kept in
/// an `AtomicU8` and converted back with the `From` impls below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WebSocketConnectionState {
    /// No socket is open.
    Disconnected = 0,
    /// The socket is being opened.
    Connecting = 1,
    /// The socket is open, but the session is not authenticated yet.
    Connected = 2,
    /// Authentication is in progress.
    Authenticating = 3,
    /// The session is authenticated and ready for trading.
    Authenticated = 4,
    /// The socket is being closed gracefully.
    Disconnecting = 5,
}

impl From<u8> for WebSocketConnectionState {
    /// Decodes a raw discriminant.
    ///
    /// Values outside `0..=5` decode to `Disconnected`.
    fn from(value: u8) -> Self {
        // Indexed by discriminant; must stay in declaration order.
        const BY_DISCRIMINANT: [WebSocketConnectionState; 6] = [
            WebSocketConnectionState::Disconnected,
            WebSocketConnectionState::Connecting,
            WebSocketConnectionState::Connected,
            WebSocketConnectionState::Authenticating,
            WebSocketConnectionState::Authenticated,
            WebSocketConnectionState::Disconnecting,
        ];
        BY_DISCRIMINANT
            .get(usize::from(value))
            .copied()
            .unwrap_or(Self::Disconnected)
    }
}

impl From<WebSocketConnectionState> for u8 {
    /// Returns the variant's `#[repr(u8)]` discriminant.
    fn from(state: WebSocketConnectionState) -> Self {
        state as u8
    }
}

77/// Standard health information for WebSocket connections
78///
79/// Provides consistent health monitoring metrics across all exchanges
80#[derive(Debug, Clone)]
81pub struct ConnectionHealth {
82    /// Current connection state
83    pub state: WebSocketConnectionState,
84    /// Whether the WebSocket is connected
85    pub is_connected: bool,
86    /// Whether authentication is completed
87    pub is_authenticated: bool,
88    /// Time since last ping was sent
89    pub time_since_last_ping: Option<Duration>,
90    /// Time since last pong was received
91    pub time_since_last_pong: Option<Duration>,
92    /// Time since authentication completed
93    pub time_since_auth: Option<Duration>,
94    /// Whether the connection is considered healthy
95    pub is_healthy: bool,
96    /// Current reconnection attempt count
97    pub reconnection_attempts: u32,
98    /// Time until next reconnection attempt
99    pub next_reconnection_in: Option<Duration>,
100}
101
102/// Connection state manager for thread-safe state management
103///
104/// Provides atomic operations for connection state changes with proper
105/// ordering guarantees for concurrent access.
106#[derive(Debug, Clone)]
107pub struct ConnectionStateManager {
108    state: Arc<AtomicU8>,
109}
110
111impl ConnectionStateManager {
112    /// Create a new connection state manager
113    #[must_use]
114    pub fn new() -> Self {
115        Self {
116            state: Arc::new(AtomicU8::new(WebSocketConnectionState::Disconnected as u8)),
117        }
118    }
119
120    /// Get the current connection state
121    #[must_use]
122    pub fn get_state(&self) -> WebSocketConnectionState {
123        WebSocketConnectionState::from(self.state.load(Ordering::Acquire))
124    }
125
126    /// Set the connection state
127    pub fn set_state(&self, new_state: WebSocketConnectionState) {
128        self.state.store(new_state as u8, Ordering::Release);
129    }
130
131    /// Check if currently connected
132    #[must_use]
133    pub fn is_connected(&self) -> bool {
134        matches!(
135            self.get_state(),
136            WebSocketConnectionState::Connected
137                | WebSocketConnectionState::Authenticating
138                | WebSocketConnectionState::Authenticated
139        )
140    }
141
142    /// Check if authenticated
143    #[must_use]
144    pub fn is_authenticated(&self) -> bool {
145        self.get_state() == WebSocketConnectionState::Authenticated
146    }
147
148    /// Attempt to transition state (returns success)
149    #[must_use]
150    pub fn try_transition(
151        &self,
152        from: WebSocketConnectionState,
153        to: WebSocketConnectionState,
154    ) -> bool {
155        self.state
156            .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
157            .is_ok()
158    }
159}
160
161impl Default for ConnectionStateManager {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
/// Keep-alive settings for a WebSocket connection.
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// Time between consecutive pings.
    pub ping_interval: Duration,
    /// Maximum time to wait for a pong reply.
    pub pong_timeout: Duration,
    /// Send protocol-level WebSocket ping frames.
    pub use_websocket_ping: bool,
    /// Send application-level (exchange message) pings.
    pub use_app_level_ping: bool,
}

impl Default for HealthConfig {
    /// Ping every 30 s, wait up to 10 s for a pong, and use WebSocket-level pings only.
    fn default() -> Self {
        let ping_interval = Duration::from_secs(30);
        let pong_timeout = Duration::from_secs(10);
        Self {
            ping_interval,
            pong_timeout,
            use_websocket_ping: true,
            use_app_level_ping: false,
        }
    }
}

/// Exponential-backoff settings used when re-establishing a connection.
#[derive(Debug, Clone)]
pub struct ReconnectionConfig {
    /// Delay before the first reconnection attempt.
    pub initial_backoff: Duration,
    /// Upper bound on the backoff delay.
    pub max_backoff: Duration,
    /// Number of attempts after which reconnection gives up.
    pub max_attempts: u32,
    /// Factor the delay grows by after each attempt.
    pub backoff_multiplier: f64,
    /// Randomize each delay to spread out simultaneous reconnects.
    pub use_jitter: bool,
}

impl Default for ReconnectionConfig {
    /// Start at 1 s, double per attempt up to 60 s, allow 10 attempts, with jitter.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(60);
        Self {
            initial_backoff,
            max_backoff,
            max_attempts: 10,
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}

/// Limits enforced by the sliding-window rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Most requests allowed inside one window.
    pub max_requests: u32,
    /// Length of the sliding window.
    pub window_duration: Duration,
    /// Largest number of items allowed in one bulk request.
    pub max_batch_size: usize,
}

impl Default for RateLimitConfig {
    /// 300 requests per 10 s window, with batches of up to 50 items.
    fn default() -> Self {
        let window_duration = Duration::from_secs(10);
        Self {
            max_requests: 300,
            window_duration,
            max_batch_size: 50,
        }
    }
}

239/// Authentication mechanism trait for different exchange authentication methods
240#[async_trait]
241pub trait AuthenticationMechanism: Send + Sync {
242    /// Perform authentication with the exchange
243    async fn authenticate(&self, ws_sink: &mut WebSocketSink) -> Result<()>;
244
245    /// Check if authentication needs to be refreshed
246    fn needs_refresh(&self) -> bool;
247
248    /// Refresh authentication credentials
249    async fn refresh(&self) -> Result<()>;
250
251    /// Get authentication timeout duration
252    fn auth_timeout(&self) -> Duration {
253        Duration::from_secs(10)
254    }
255}
256
257/// Unified rate limiter for WebSocket requests
258///
259/// Provides configurable rate limiting with sliding window approach
260/// to prevent API violations across all exchanges.
261#[derive(Debug)]
262pub struct UnifiedRateLimiter {
263    request_times: VecDeque<u64>,
264    config: RateLimitConfig,
265    clock: Clock,
266}
267
268impl UnifiedRateLimiter {
269    /// Create a new rate limiter with the given configuration
270    #[must_use]
271    pub const fn new(config: RateLimitConfig, clock: Clock) -> Self {
272        Self {
273            request_times: VecDeque::new(),
274            config,
275            clock,
276        }
277    }
278
279    /// Remove expired timestamps from the sliding window
280    fn cleanup_expired(&mut self) {
281        let now = self.clock.raw() / 1_000_000; // Convert to milliseconds
282        let window_start = now.saturating_sub(self.config.window_duration.as_millis() as u64);
283
284        while let Some(&front) = self.request_times.front() {
285            if front < window_start {
286                self.request_times.pop_front();
287            } else {
288                break;
289            }
290        }
291    }
292
293    /// Check if the given number of requests can be made
294    pub fn can_make_requests(&mut self, count: usize) -> bool {
295        self.cleanup_expired();
296        self.request_times.len() + count <= self.config.max_requests as usize
297    }
298
299    /// Record the given number of requests
300    pub fn record_requests(&mut self, count: usize) {
301        let now = self.clock.raw() / 1_000_000; // Convert to milliseconds
302        for _ in 0..count {
303            self.request_times.push_back(now);
304        }
305    }
306
307    /// Get current usage statistics
308    pub fn get_usage(&mut self) -> (usize, usize) {
309        self.cleanup_expired();
310        (self.request_times.len(), self.config.max_requests as usize)
311    }
312
313    /// Wait until requests can be made (for async rate limiting)
314    pub async fn acquire_permits(&mut self, count: usize) -> Result<()> {
315        while !self.can_make_requests(count) {
316            tokio::time::sleep(Duration::from_millis(100)).await;
317        }
318        self.record_requests(count);
319        Ok(())
320    }
321}
322
323/// Request ID management for tracking WebSocket requests
324#[derive(Debug, Clone)]
325pub enum RequestId {
326    /// Sequential numeric ID
327    Sequential(u64),
328    /// UUID-based ID
329    Uuid(SmartString),
330}
331
332impl RequestId {
333    /// Generate a new sequential request ID
334    pub fn new_sequential(counter: &AtomicU64) -> Self {
335        Self::Sequential(counter.fetch_add(1, Ordering::Relaxed))
336    }
337
338    /// Generate a new UUID-based request ID
339    #[must_use]
340    pub fn new_uuid() -> Self {
341        Self::Uuid(Uuid::new_v4().to_string().into())
342    }
343
344    /// Get the request ID as a string
345    #[must_use]
346    pub fn as_string(&self) -> SmartString {
347        match self {
348            Self::Sequential(id) => id.to_string().into(),
349            Self::Uuid(id) => id.clone(),
350        }
351    }
352
353    /// Get the request ID as a JSON value
354    #[must_use]
355    pub fn as_json_value(&self) -> JsonValue {
356        match self {
357            Self::Sequential(id) => JsonValue::from(*id as i64),
358            Self::Uuid(id) => JsonValue::from(id.as_str()),
359        }
360    }
361}
362
363/// Request management for tracking pending WebSocket requests
364pub type PendingRequestsMap = FxHashMap<SmartString, tokio::sync::oneshot::Sender<JsonValue>>;
365
366/// Standard task handles for WebSocket connections
367#[derive(Debug)]
368pub struct WebSocketTaskHandles {
369    /// Task for handling WebSocket responses
370    pub response_handler: Option<JoinHandle<()>>,
371    /// Task for sending ping messages
372    pub ping_handler: Option<JoinHandle<()>>,
373    /// Task for monitoring reconnections
374    pub reconnection_monitor: Option<JoinHandle<()>>,
375    /// Task for cleanup operations
376    pub cleanup_task: Option<JoinHandle<()>>,
377}
378
379impl WebSocketTaskHandles {
380    /// Create new empty task handles
381    #[must_use]
382    pub const fn new() -> Self {
383        Self {
384            response_handler: None,
385            ping_handler: None,
386            reconnection_monitor: None,
387            cleanup_task: None,
388        }
389    }
390
391    /// Abort all running tasks
392    pub async fn abort_all(&mut self) {
393        if let Some(handle) = self.response_handler.take() {
394            handle.abort();
395        }
396        if let Some(handle) = self.ping_handler.take() {
397            handle.abort();
398        }
399        if let Some(handle) = self.reconnection_monitor.take() {
400            handle.abort();
401        }
402        if let Some(handle) = self.cleanup_task.take() {
403            handle.abort();
404        }
405    }
406
407    /// Check if any tasks are running
408    #[must_use]
409    pub const fn has_running_tasks(&self) -> bool {
410        self.response_handler.is_some()
411            || self.ping_handler.is_some()
412            || self.reconnection_monitor.is_some()
413            || self.cleanup_task.is_some()
414    }
415}
416
417impl Default for WebSocketTaskHandles {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
423/// Reconnection manager with exponential backoff
424#[derive(Debug)]
425pub struct ReconnectionManager {
426    config: ReconnectionConfig,
427    current_attempt: AtomicU32,
428    backoff_ms: Arc<AtomicU64>,
429    should_stop: Arc<AtomicBool>,
430}
431
432impl ReconnectionManager {
433    /// Create a new reconnection manager
434    #[must_use]
435    pub fn new(config: ReconnectionConfig) -> Self {
436        let initial_backoff = config.initial_backoff.as_millis() as u64;
437        Self {
438            config,
439            current_attempt: AtomicU32::new(0),
440            backoff_ms: Arc::new(AtomicU64::new(initial_backoff)),
441            should_stop: Arc::new(AtomicBool::new(false)),
442        }
443    }
444
445    /// Reset the reconnection state
446    pub fn reset(&self) {
447        self.current_attempt.store(0, Ordering::Release);
448        self.backoff_ms.store(
449            self.config.initial_backoff.as_millis() as u64,
450            Ordering::Release,
451        );
452        self.should_stop.store(false, Ordering::Release);
453    }
454
455    /// Check if reconnection should continue
456    pub fn should_reconnect(&self) -> bool {
457        !self.should_stop.load(Ordering::Acquire)
458            && self.current_attempt.load(Ordering::Acquire) < self.config.max_attempts
459    }
460
461    /// Get the current backoff delay
462    pub fn get_backoff_delay(&self) -> Duration {
463        let current_backoff = self.backoff_ms.load(Ordering::Acquire);
464        let delay = if self.config.use_jitter {
465            // Add ±25% jitter
466            let jitter_range = current_backoff / 4;
467            let jitter = rng().random_range(0..jitter_range * 2);
468            current_backoff
469                .saturating_sub(jitter_range)
470                .saturating_add(jitter)
471        } else {
472            current_backoff
473        };
474        Duration::from_millis(delay)
475    }
476
477    /// Record a reconnection attempt and update backoff
478    pub fn record_attempt(&self) {
479        let attempt = self.current_attempt.fetch_add(1, Ordering::AcqRel);
480        if attempt < self.config.max_attempts {
481            let current_backoff = self.backoff_ms.load(Ordering::Acquire);
482            let new_backoff = ((current_backoff as f64) * self.config.backoff_multiplier) as u64;
483            let capped_backoff = new_backoff.min(self.config.max_backoff.as_millis() as u64);
484            self.backoff_ms.store(capped_backoff, Ordering::Release);
485        }
486    }
487
488    /// Stop reconnection attempts
489    pub fn stop(&self) {
490        self.should_stop.store(true, Ordering::Release);
491    }
492
493    /// Get current attempt count
494    pub fn get_attempt_count(&self) -> u32 {
495        self.current_attempt.load(Ordering::Acquire)
496    }
497}
498
499/// Common WebSocket trading interface
500///
501/// This trait defines the standard interface that all exchange WebSocket
502/// trading implementations should follow for consistent behavior.
503#[async_trait]
504pub trait WebSocketTrader: Send + Sync {
505    /// Exchange-specific configuration type
506    type Config: Send + Sync;
507
508    /// Connect to the exchange WebSocket API
509    async fn connect(&self, report_tx: Sender<ExecutionReport>) -> Result<()>;
510
511    /// Disconnect from the exchange
512    async fn disconnect(&self) -> Result<()>;
513
514    /// Get current connection health information
515    fn get_connection_health(&self) -> ConnectionHealth;
516
517    /// Check if connected and ready for trading
518    fn is_authenticated(&self) -> bool;
519
520    /// Send a ping message to keep the connection alive
521    async fn send_ping(&self) -> Result<()>;
522
523    /// Get current rate limit usage
524    fn get_rate_limit_usage(&self) -> (usize, usize);
525}
526
527/// Helper function to create exponential backoff with jitter
528#[must_use]
529pub fn calculate_backoff_with_jitter(
530    base_delay: Duration,
531    attempt: u32,
532    max_delay: Duration,
533    multiplier: f64,
534) -> Duration {
535    let delay_ms = (base_delay.as_millis() as f64 * multiplier.powi(attempt as i32)) as u64;
536    let capped_delay = delay_ms.min(max_delay.as_millis() as u64);
537
538    // Add ±25% jitter
539    let jitter_range = capped_delay / 4;
540    let jitter = rng().random_range(0..jitter_range * 2);
541    let final_delay = capped_delay
542        .saturating_sub(jitter_range)
543        .saturating_add(jitter);
544
545    Duration::from_millis(final_delay)
546}
547
548/// Utility function to validate WebSocket response structure
549pub fn validate_websocket_response(response: &JsonValue) -> Result<()> {
550    // Basic validation that applies to most exchange responses
551    if response.get("error").is_some() {
552        let error_msg = response
553            .get("error")
554            .and_then(|e| e.get("message"))
555            .and_then(|m| m.as_str())
556            .unwrap_or("Unknown WebSocket error");
557        return Err(anyhow!("WebSocket error: {}", error_msg));
558    }
559    Ok(())
560}