1use std::collections::VecDeque;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
19use std::time::Duration;
20
21use anyhow::{Result, anyhow};
22use async_trait::async_trait;
23use flume::Sender;
24use quanta::Clock;
25use rand::{Rng, rng};
26use rusty_common::SmartString;
27use rusty_common::collections::FxHashMap;
28use rusty_common::websocket::connector::WebSocketSink;
29use simd_json::prelude::{ValueAsScalar, ValueObjectAccess};
30use simd_json::value::owned::Value as JsonValue;
31use tokio::task::JoinHandle;
32use uuid::Uuid;
33
34use crate::execution_engine::ExecutionReport;
35
/// Lifecycle state of a WebSocket connection.
///
/// The enum is `#[repr(u8)]` with explicit discriminants so a state can be
/// stored in a single byte (for example inside an `AtomicU8`) and converted
/// back with [`From<u8>`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WebSocketConnectionState {
    /// Discriminant `0`.
    Disconnected = 0,
    /// Discriminant `1`.
    Connecting = 1,
    /// Discriminant `2`.
    Connected = 2,
    /// Discriminant `3`.
    Authenticating = 3,
    /// Discriminant `4`.
    Authenticated = 4,
    /// Discriminant `5`.
    Disconnecting = 5,
}

impl From<u8> for WebSocketConnectionState {
    /// Decodes a raw discriminant.
    ///
    /// Values outside `0..=5` fall back to [`Self::Disconnected`] rather than
    /// failing, so the conversion is total.
    fn from(value: u8) -> Self {
        // Indexed by discriminant; the order must mirror the enum definition.
        const BY_DISCRIMINANT: [WebSocketConnectionState; 6] = [
            WebSocketConnectionState::Disconnected,
            WebSocketConnectionState::Connecting,
            WebSocketConnectionState::Connected,
            WebSocketConnectionState::Authenticating,
            WebSocketConnectionState::Authenticated,
            WebSocketConnectionState::Disconnecting,
        ];

        BY_DISCRIMINANT
            .get(usize::from(value))
            .copied()
            .unwrap_or(Self::Disconnected)
    }
}

impl From<WebSocketConnectionState> for u8 {
    /// Returns the state's `repr(u8)` discriminant.
    fn from(state: WebSocketConnectionState) -> Self {
        state as Self
    }
}
76
77#[derive(Debug, Clone)]
81pub struct ConnectionHealth {
82 pub state: WebSocketConnectionState,
84 pub is_connected: bool,
86 pub is_authenticated: bool,
88 pub time_since_last_ping: Option<Duration>,
90 pub time_since_last_pong: Option<Duration>,
92 pub time_since_auth: Option<Duration>,
94 pub is_healthy: bool,
96 pub reconnection_attempts: u32,
98 pub next_reconnection_in: Option<Duration>,
100}
101
102#[derive(Debug, Clone)]
107pub struct ConnectionStateManager {
108 state: Arc<AtomicU8>,
109}
110
111impl ConnectionStateManager {
112 #[must_use]
114 pub fn new() -> Self {
115 Self {
116 state: Arc::new(AtomicU8::new(WebSocketConnectionState::Disconnected as u8)),
117 }
118 }
119
120 #[must_use]
122 pub fn get_state(&self) -> WebSocketConnectionState {
123 WebSocketConnectionState::from(self.state.load(Ordering::Acquire))
124 }
125
126 pub fn set_state(&self, new_state: WebSocketConnectionState) {
128 self.state.store(new_state as u8, Ordering::Release);
129 }
130
131 #[must_use]
133 pub fn is_connected(&self) -> bool {
134 matches!(
135 self.get_state(),
136 WebSocketConnectionState::Connected
137 | WebSocketConnectionState::Authenticating
138 | WebSocketConnectionState::Authenticated
139 )
140 }
141
142 #[must_use]
144 pub fn is_authenticated(&self) -> bool {
145 self.get_state() == WebSocketConnectionState::Authenticated
146 }
147
148 #[must_use]
150 pub fn try_transition(
151 &self,
152 from: WebSocketConnectionState,
153 to: WebSocketConnectionState,
154 ) -> bool {
155 self.state
156 .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
157 .is_ok()
158 }
159}
160
161impl Default for ConnectionStateManager {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
/// Settings for ping/pong based connection health checks.
#[derive(Debug, Clone)]
pub struct HealthConfig {
    /// Interval between pings (default: 30 seconds).
    pub ping_interval: Duration,
    /// Time allowed for a pong before it counts as missed (default: 10 seconds).
    pub pong_timeout: Duration,
    /// Enables WebSocket-level pings (default: `true`).
    pub use_websocket_ping: bool,
    /// Enables application-level pings (default: `false`).
    pub use_app_level_ping: bool,
}

impl Default for HealthConfig {
    /// 30 s ping interval, 10 s pong timeout, WebSocket pings on,
    /// application-level pings off.
    fn default() -> Self {
        let ping_interval = Duration::from_secs(30);
        let pong_timeout = Duration::from_secs(10);

        Self {
            ping_interval,
            pong_timeout,
            use_websocket_ping: true,
            use_app_level_ping: false,
        }
    }
}
190
/// Settings for exponential-backoff reconnection.
#[derive(Debug, Clone)]
pub struct ReconnectionConfig {
    /// Backoff used for the first attempt (default: 1 second).
    pub initial_backoff: Duration,
    /// Upper bound for the backoff (default: 60 seconds).
    pub max_backoff: Duration,
    /// Maximum number of reconnection attempts (default: 10).
    pub max_attempts: u32,
    /// Factor the backoff is multiplied by after each attempt (default: 2.0).
    pub backoff_multiplier: f64,
    /// Whether random jitter is applied to the backoff (default: `true`).
    pub use_jitter: bool,
}

impl Default for ReconnectionConfig {
    /// 1 s initial backoff doubling up to 60 s, at most 10 attempts, jitter on.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(60);

        Self {
            initial_backoff,
            max_backoff,
            max_attempts: 10,
            backoff_multiplier: 2.0,
            use_jitter: true,
        }
    }
}
217
/// Settings for a sliding-window request rate limit.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed inside one window (default: 300).
    pub max_requests: u32,
    /// Length of the window (default: 10 seconds).
    pub window_duration: Duration,
    /// Maximum batch size (default: 50).
    pub max_batch_size: usize,
}

impl Default for RateLimitConfig {
    /// 300 requests per 10-second window, batches of up to 50.
    fn default() -> Self {
        let window_duration = Duration::from_secs(10);

        Self {
            max_requests: 300,
            window_duration,
            max_batch_size: 50,
        }
    }
}
238
239#[async_trait]
241pub trait AuthenticationMechanism: Send + Sync {
242 async fn authenticate(&self, ws_sink: &mut WebSocketSink) -> Result<()>;
244
245 fn needs_refresh(&self) -> bool;
247
248 async fn refresh(&self) -> Result<()>;
250
251 fn auth_timeout(&self) -> Duration {
253 Duration::from_secs(10)
254 }
255}
256
257#[derive(Debug)]
262pub struct UnifiedRateLimiter {
263 request_times: VecDeque<u64>,
264 config: RateLimitConfig,
265 clock: Clock,
266}
267
268impl UnifiedRateLimiter {
269 #[must_use]
271 pub const fn new(config: RateLimitConfig, clock: Clock) -> Self {
272 Self {
273 request_times: VecDeque::new(),
274 config,
275 clock,
276 }
277 }
278
279 fn cleanup_expired(&mut self) {
281 let now = self.clock.raw() / 1_000_000; let window_start = now.saturating_sub(self.config.window_duration.as_millis() as u64);
283
284 while let Some(&front) = self.request_times.front() {
285 if front < window_start {
286 self.request_times.pop_front();
287 } else {
288 break;
289 }
290 }
291 }
292
293 pub fn can_make_requests(&mut self, count: usize) -> bool {
295 self.cleanup_expired();
296 self.request_times.len() + count <= self.config.max_requests as usize
297 }
298
299 pub fn record_requests(&mut self, count: usize) {
301 let now = self.clock.raw() / 1_000_000; for _ in 0..count {
303 self.request_times.push_back(now);
304 }
305 }
306
307 pub fn get_usage(&mut self) -> (usize, usize) {
309 self.cleanup_expired();
310 (self.request_times.len(), self.config.max_requests as usize)
311 }
312
313 pub async fn acquire_permits(&mut self, count: usize) -> Result<()> {
315 while !self.can_make_requests(count) {
316 tokio::time::sleep(Duration::from_millis(100)).await;
317 }
318 self.record_requests(count);
319 Ok(())
320 }
321}
322
323#[derive(Debug, Clone)]
325pub enum RequestId {
326 Sequential(u64),
328 Uuid(SmartString),
330}
331
332impl RequestId {
333 pub fn new_sequential(counter: &AtomicU64) -> Self {
335 Self::Sequential(counter.fetch_add(1, Ordering::Relaxed))
336 }
337
338 #[must_use]
340 pub fn new_uuid() -> Self {
341 Self::Uuid(Uuid::new_v4().to_string().into())
342 }
343
344 #[must_use]
346 pub fn as_string(&self) -> SmartString {
347 match self {
348 Self::Sequential(id) => id.to_string().into(),
349 Self::Uuid(id) => id.clone(),
350 }
351 }
352
353 #[must_use]
355 pub fn as_json_value(&self) -> JsonValue {
356 match self {
357 Self::Sequential(id) => JsonValue::from(*id as i64),
358 Self::Uuid(id) => JsonValue::from(id.as_str()),
359 }
360 }
361}
362
363pub type PendingRequestsMap = FxHashMap<SmartString, tokio::sync::oneshot::Sender<JsonValue>>;
365
366#[derive(Debug)]
368pub struct WebSocketTaskHandles {
369 pub response_handler: Option<JoinHandle<()>>,
371 pub ping_handler: Option<JoinHandle<()>>,
373 pub reconnection_monitor: Option<JoinHandle<()>>,
375 pub cleanup_task: Option<JoinHandle<()>>,
377}
378
379impl WebSocketTaskHandles {
380 #[must_use]
382 pub const fn new() -> Self {
383 Self {
384 response_handler: None,
385 ping_handler: None,
386 reconnection_monitor: None,
387 cleanup_task: None,
388 }
389 }
390
391 pub async fn abort_all(&mut self) {
393 if let Some(handle) = self.response_handler.take() {
394 handle.abort();
395 }
396 if let Some(handle) = self.ping_handler.take() {
397 handle.abort();
398 }
399 if let Some(handle) = self.reconnection_monitor.take() {
400 handle.abort();
401 }
402 if let Some(handle) = self.cleanup_task.take() {
403 handle.abort();
404 }
405 }
406
407 #[must_use]
409 pub const fn has_running_tasks(&self) -> bool {
410 self.response_handler.is_some()
411 || self.ping_handler.is_some()
412 || self.reconnection_monitor.is_some()
413 || self.cleanup_task.is_some()
414 }
415}
416
417impl Default for WebSocketTaskHandles {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
423#[derive(Debug)]
425pub struct ReconnectionManager {
426 config: ReconnectionConfig,
427 current_attempt: AtomicU32,
428 backoff_ms: Arc<AtomicU64>,
429 should_stop: Arc<AtomicBool>,
430}
431
432impl ReconnectionManager {
433 #[must_use]
435 pub fn new(config: ReconnectionConfig) -> Self {
436 let initial_backoff = config.initial_backoff.as_millis() as u64;
437 Self {
438 config,
439 current_attempt: AtomicU32::new(0),
440 backoff_ms: Arc::new(AtomicU64::new(initial_backoff)),
441 should_stop: Arc::new(AtomicBool::new(false)),
442 }
443 }
444
445 pub fn reset(&self) {
447 self.current_attempt.store(0, Ordering::Release);
448 self.backoff_ms.store(
449 self.config.initial_backoff.as_millis() as u64,
450 Ordering::Release,
451 );
452 self.should_stop.store(false, Ordering::Release);
453 }
454
455 pub fn should_reconnect(&self) -> bool {
457 !self.should_stop.load(Ordering::Acquire)
458 && self.current_attempt.load(Ordering::Acquire) < self.config.max_attempts
459 }
460
461 pub fn get_backoff_delay(&self) -> Duration {
463 let current_backoff = self.backoff_ms.load(Ordering::Acquire);
464 let delay = if self.config.use_jitter {
465 let jitter_range = current_backoff / 4;
467 let jitter = rng().random_range(0..jitter_range * 2);
468 current_backoff
469 .saturating_sub(jitter_range)
470 .saturating_add(jitter)
471 } else {
472 current_backoff
473 };
474 Duration::from_millis(delay)
475 }
476
477 pub fn record_attempt(&self) {
479 let attempt = self.current_attempt.fetch_add(1, Ordering::AcqRel);
480 if attempt < self.config.max_attempts {
481 let current_backoff = self.backoff_ms.load(Ordering::Acquire);
482 let new_backoff = ((current_backoff as f64) * self.config.backoff_multiplier) as u64;
483 let capped_backoff = new_backoff.min(self.config.max_backoff.as_millis() as u64);
484 self.backoff_ms.store(capped_backoff, Ordering::Release);
485 }
486 }
487
488 pub fn stop(&self) {
490 self.should_stop.store(true, Ordering::Release);
491 }
492
493 pub fn get_attempt_count(&self) -> u32 {
495 self.current_attempt.load(Ordering::Acquire)
496 }
497}
498
499#[async_trait]
504pub trait WebSocketTrader: Send + Sync {
505 type Config: Send + Sync;
507
508 async fn connect(&self, report_tx: Sender<ExecutionReport>) -> Result<()>;
510
511 async fn disconnect(&self) -> Result<()>;
513
514 fn get_connection_health(&self) -> ConnectionHealth;
516
517 fn is_authenticated(&self) -> bool;
519
520 async fn send_ping(&self) -> Result<()>;
522
523 fn get_rate_limit_usage(&self) -> (usize, usize);
525}
526
527#[must_use]
529pub fn calculate_backoff_with_jitter(
530 base_delay: Duration,
531 attempt: u32,
532 max_delay: Duration,
533 multiplier: f64,
534) -> Duration {
535 let delay_ms = (base_delay.as_millis() as f64 * multiplier.powi(attempt as i32)) as u64;
536 let capped_delay = delay_ms.min(max_delay.as_millis() as u64);
537
538 let jitter_range = capped_delay / 4;
540 let jitter = rng().random_range(0..jitter_range * 2);
541 let final_delay = capped_delay
542 .saturating_sub(jitter_range)
543 .saturating_add(jitter);
544
545 Duration::from_millis(final_delay)
546}
547
548pub fn validate_websocket_response(response: &JsonValue) -> Result<()> {
550 if response.get("error").is_some() {
552 let error_msg = response
553 .get("error")
554 .and_then(|e| e.get("message"))
555 .and_then(|m| m.as_str())
556 .unwrap_or("Unknown WebSocket error");
557 return Err(anyhow!("WebSocket error: {}", error_msg));
558 }
559 Ok(())
560}