1use futures_util::{SinkExt, StreamExt};
7use parking_lot::RwLock;
8use quanta::{Clock, Instant};
9use smartstring::alias::String;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{mpsc, watch};
13use tokio::task::JoinHandle;
14use tokio::time::{sleep, timeout};
15use yawc::{Options, WebSocket, frame::FrameView};
16
17use super::batch::{BatchProcessingMetrics, BatchProcessor};
18use super::exchanges::ExchangeConfig;
19use super::heartbeat::HeartbeatMonitor;
20use super::stats::SharedStats;
21use super::{ConnectionState, Message, WebSocketConfig, WebSocketError, WebSocketResult};
22
23pub type WebSocketSink = futures_util::stream::SplitSink<WebSocket, FrameView>;
25pub type WebSocketStream = futures_util::stream::SplitStream<WebSocket>;
27
28#[derive(Debug)]
30struct ReconnectState {
31 attempts: u32,
33 backoff_delay_milliseconds: u64,
35 last_attempt: Instant,
37 last_success: Option<Instant>,
39 consecutive_failures: u32,
41 jitter_factor: f64,
43}
44
45impl Default for ReconnectState {
46 fn default() -> Self {
47 Self {
48 attempts: 0,
49 backoff_delay_milliseconds: 100,
50 last_attempt: Instant::now(),
51 last_success: None,
52 consecutive_failures: 0,
53 jitter_factor: 0.1,
54 }
55 }
56}
57
58#[derive(Debug)]
60#[repr(align(64))] pub struct WebSocketConnector {
62 config: WebSocketConfig,
64 reconnect_state: ReconnectState,
66 clock: Clock,
68 stats: SharedStats,
70 connection_status: Arc<RwLock<ConnectionState>>,
72 batch_processor: BatchProcessor,
74 heartbeat_monitor: Option<HeartbeatMonitor>,
76 current_url_index: Arc<RwLock<usize>>,
78}
79
80impl WebSocketConnector {
81 #[must_use]
83 pub fn new(
84 config: WebSocketConfig,
85 stats: SharedStats,
86 connection_status: Arc<RwLock<ConnectionState>>,
87 ) -> Self {
88 let heartbeat_monitor = if config.heartbeat_interval_milliseconds > 0 {
90 Some(HeartbeatMonitor::new(
91 config.heartbeat_interval_milliseconds,
92 config.heartbeat_timeout_milliseconds,
93 config.max_missed_heartbeats,
94 connection_status.clone(),
95 stats.clone(),
96 ))
97 } else {
98 None
99 };
100
101 Self {
102 batch_processor: BatchProcessor::new(config.batch_size),
103 config,
104 reconnect_state: ReconnectState::default(),
105 clock: Clock::new(),
106 stats,
107 connection_status,
108 heartbeat_monitor,
109 current_url_index: Arc::new(RwLock::new(0)),
110 }
111 }
112
113 pub const fn config(&self) -> &WebSocketConfig {
115 &self.config
116 }
117
118 pub fn get_batch_metrics(&self) -> BatchProcessingMetrics {
120 self.batch_processor.get_metrics()
121 }
122
123 pub fn create_websocket_options(exchange_config: Option<&ExchangeConfig>) -> Options {
125 let mut options = Options::default();
126
127 if let Some(config) = exchange_config
128 && config.compression.enabled
129 {
130 options.compression = Some(Default::default());
133 }
134
135 options
136 }
137
138 pub async fn process_message_batch<T, P, R>(
140 &self,
141 messages: Vec<Message>,
142 process_message: P,
143 sender: &mpsc::Sender<T>,
144 ) -> WebSocketResult<usize>
145 where
146 T: Send + 'static,
147 P: Fn(&str, &Clock) -> Option<R> + Send + Sync,
148 R: Into<T> + Send,
149 {
150 if messages.is_empty() {
151 return Ok(0);
152 }
153
154 let start_time = self.clock.now();
155 let message_count = messages.len();
156 let mut processed_count = 0;
157 let mut results = Vec::with_capacity(message_count);
158
159 for message in messages {
161 match message {
162 Message::Text(text) => {
163 {
165 let mut s = self.stats.write();
166 s.messages_received += 1;
167 s.bytes_received += text.len() as u64;
168 s.last_message_time = self.clock.raw();
169 }
170
171 if let Some(result) = process_message(&text, &self.clock) {
173 results.push(result.into());
174 processed_count += 1;
175 }
176 }
177 Message::Binary(bin) => {
178 let text: String = std::string::String::from_utf8_lossy(&bin)
179 .into_owned()
180 .into();
181
182 {
184 let mut s = self.stats.write();
185 s.messages_received += 1;
186 s.bytes_received += bin.len() as u64;
187 s.last_message_time = self.clock.raw();
188 }
189
190 if let Some(result) = process_message(&text, &self.clock) {
192 results.push(result.into());
193 processed_count += 1;
194 }
195 }
196 Message::Pong(_) => {
197 self.stats.write().last_pong_time = self.clock.raw();
198 }
199 Message::Close(close_frame) => {
200 log::info!("WebSocket connection closed: {close_frame:?}");
201 *self.connection_status.write() = ConnectionState::Disconnected;
202 }
203 Message::Ping(data) => {
204 log::trace!("Received ping message with {} bytes", data.len());
205 }
206 Message::Frame(_) => {
207 log::trace!("Received frame message");
208 }
209 }
210 }
211
212 for result in results {
214 if let Err(e) = sender.send(result).await {
215 log::error!("Failed to send processed message: {e}");
216 return Err(WebSocketError::MessageProcessingError(e.to_string()));
217 }
218 }
219
220 let elapsed_ns = (self.clock.now() - start_time).as_nanos() as u64;
222 self.batch_processor
223 .update_metrics(message_count, elapsed_ns);
224
225 Ok(processed_count)
226 }
227
228 pub async fn connect_with_retry(
230 &mut self,
231 url: &str,
232 ) -> WebSocketResult<(WebSocketSink, WebSocketStream)> {
233 let now = Instant::now();
235 if now.duration_since(self.reconnect_state.last_attempt)
236 < Duration::from_millis(self.reconnect_state.backoff_delay_milliseconds)
237 {
238 sleep(Duration::from_millis(
239 self.reconnect_state.backoff_delay_milliseconds,
240 ))
241 .await;
242 }
243
244 *self.connection_status.write() = ConnectionState::Connecting;
246 self.reconnect_state.last_attempt = Instant::now();
247
248 let current_url =
250 if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
251 let current_index = *self.current_url_index.read();
252 if current_index == 0 {
253 url.to_string()
254 } else {
255 let failover_index = (current_index - 1) % self.config.failover_urls.len();
256 self.config.failover_urls[failover_index].clone()
257 }
258 } else {
259 url.to_string()
260 };
261
262 let parsed_url = current_url
264 .parse()
265 .map_err(|e| WebSocketError::ConnectionError(format!("Invalid URL: {e}")))?;
266
267 let options = Self::create_websocket_options(None); let jitter_factor = self.reconnect_state.jitter_factor;
272 let jitter_range = self.reconnect_state.backoff_delay_milliseconds as f64 * jitter_factor;
273 let now_ns = self.clock.raw();
274 let jitter = ((now_ns % 1000) as f64 / 1000.0 * jitter_range) as u64;
275
276 match timeout(
278 self.config.connect_timeout,
279 WebSocket::connect(parsed_url).with_options(options),
280 )
281 .await
282 {
283 Ok(Ok(ws_stream)) => {
284 self.reconnect_state.attempts = 0;
286 self.reconnect_state.backoff_delay_milliseconds =
287 self.config.reconnect.initial_delay.as_millis() as u64;
288 self.reconnect_state.consecutive_failures = 0;
289 self.reconnect_state.last_success = Some(Instant::now());
290
291 *self.connection_status.write() = ConnectionState::Connected;
293 self.stats.write().connected_time = self.clock.raw();
294
295 if let Some(ref monitor) = self.heartbeat_monitor {
297 monitor.reset();
298 }
299
300 Ok(ws_stream.split())
302 }
303 Ok(Err(e)) => {
304 *self.connection_status.write() = ConnectionState::Error;
306 self.reconnect_state.consecutive_failures += 1;
307
308 if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
310 {
311 let mut current_index = self.current_url_index.write();
312 *current_index =
313 (*current_index + 1) % (self.config.failover_urls.len() + 1);
314 }
315
316 let current_index = *self.current_url_index.read();
317 if current_index == 0 {
318 log::warn!("All failover URLs failed, resetting to primary");
319 } else {
320 log::info!("Trying failover URL index: {current_index}");
321 return Box::pin(self.connect_with_retry(url)).await;
322 }
323 }
324
325 if self.config.reconnect.enabled {
327 self.reconnect_state.attempts += 1;
328
329 if self.config.reconnect.max_attempts > 0
330 && self.reconnect_state.attempts >= self.config.reconnect.max_attempts
331 {
332 return Err(WebSocketError::ConnectionError(format!(
333 "Maximum reconnection attempts exceeded: {e}"
334 )));
335 }
336
337 let consecutive_factor =
338 std::cmp::min(self.reconnect_state.consecutive_failures, 5) as f64 * 0.5
339 + 1.0;
340
341 self.reconnect_state.backoff_delay_milliseconds = std::cmp::min(
342 ((self.reconnect_state.backoff_delay_milliseconds as f64
343 * consecutive_factor
344 * self.config.reconnect.multiplier) as u64)
345 .saturating_add(jitter),
346 self.config.reconnect.max_delay.as_millis() as u64,
347 );
348
349 log::warn!(
350 "WebSocket connection failed, attempt {}/{}. Next retry in {}ms. Error: {}",
351 self.reconnect_state.attempts,
352 if self.config.reconnect.max_attempts > 0 {
353 self.config.reconnect.max_attempts.to_string()
354 } else {
355 "∞".to_string()
356 },
357 self.reconnect_state.backoff_delay_milliseconds,
358 e
359 );
360 }
361
362 Err(WebSocketError::ConnectionError(format!(
363 "Failed to connect: {e}"
364 )))
365 }
366 Err(_) => {
367 *self.connection_status.write() = ConnectionState::Error;
369 self.reconnect_state.consecutive_failures += 1;
370 self.reconnect_state.attempts += 1;
371
372 if self.config.enable_session_failover && !self.config.failover_urls.is_empty() {
374 {
375 let mut current_index = self.current_url_index.write();
376 *current_index =
377 (*current_index + 1) % (self.config.failover_urls.len() + 1);
378 }
379
380 let current_index = *self.current_url_index.read();
381 if current_index == 0 {
382 log::warn!("All failover URLs failed, resetting to primary");
383 } else {
384 log::info!("Trying failover URL index: {current_index}");
385 return Box::pin(self.connect_with_retry(url)).await;
386 }
387 }
388
389 self.reconnect_state.backoff_delay_milliseconds = std::cmp::min(
391 self.reconnect_state.backoff_delay_milliseconds * 2,
392 self.config.reconnect.max_delay.as_millis() as u64,
393 );
394
395 Err(WebSocketError::Timeout(
396 self.config.connect_timeout.as_millis() as u64,
397 ))
398 }
399 }
400 }
401
402 pub fn create_websocket_task<T, P, R>(
404 &self,
405 url: String,
406 subscription_message: String,
407 process_message: P,
408 sender: mpsc::Sender<T>,
409 stop_receiver: &mut watch::Receiver<bool>,
410 ) -> JoinHandle<()>
411 where
412 T: Send + 'static,
413 P: Fn(&str, &Clock) -> Option<R> + Send + Sync + 'static,
414 R: Into<T> + Send + 'static,
415 {
416 let clock = self.clock.clone();
418 let timeout_ms = self.config.timeout.as_millis() as u64;
419 let stats = self.stats.clone();
420 let connection_status = self.connection_status.clone();
421 let reconnect_config = self.config.reconnect.clone();
422 let ping_interval = self.config.ping_interval;
423 let custom_ping = self.config.custom_ping_message.clone();
424 let custom_pong = self.config.custom_pong_response.clone();
425 let _use_compression = self.config.compression.enabled;
426 let mut stop_receiver = stop_receiver.clone();
427
428 tokio::spawn(async move {
429 let mut reconnect_attempts = 0;
430 let mut backoff_delay = reconnect_config.initial_delay.as_millis() as u64;
431
432 'connection: loop {
433 if *stop_receiver.borrow() {
435 break;
436 }
437
438 let parsed_url = match url.parse() {
440 Ok(u) => u,
441 Err(e) => {
442 log::error!("Failed to parse URL: {e}");
443 tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
444 continue;
445 }
446 };
447
448 let options = WebSocketConnector::create_websocket_options(None);
450
451 match WebSocket::connect(parsed_url).with_options(options).await {
453 Ok(ws_stream) => {
454 reconnect_attempts = 0;
456 backoff_delay = reconnect_config.initial_delay.as_millis() as u64;
457
458 *connection_status.write() = ConnectionState::Connected;
460 stats.write().connected_time = clock.raw();
461
462 let (mut websocket_sender, mut websocket_receiver) = ws_stream.split();
464
465 if let Err(e) = websocket_sender
467 .send(FrameView::text(subscription_message.to_string()))
468 .await
469 {
470 log::error!("Failed to send subscription: {e}");
471 tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
472 continue 'connection;
473 }
474
475 let mut ping_interval = tokio::time::interval(ping_interval);
477
478 loop {
480 tokio::select! {
481 _ = stop_receiver.changed() => {
483 if *stop_receiver.borrow() {
484 break 'connection;
485 }
486 }
487
488 _ = ping_interval.tick() => {
490 let ping_frame = if let Some(ref custom) = custom_ping {
491 FrameView::text(custom.clone())
492 } else {
493 FrameView::ping(vec![])
494 };
495
496 if let Err(e) = websocket_sender.send(ping_frame).await {
497 log::error!("Failed to send ping: {e}");
498 break;
499 }
500 stats.write().last_ping_time = clock.raw();
501 }
502
503 message_result = timeout(
505 Duration::from_millis(timeout_ms),
506 websocket_receiver.next()
507 ) => {
508 match message_result {
509 Ok(Some(frame)) => {
510 let message = Message::from_frame_view(frame);
511
512 match &message {
513 Message::Text(text) => {
514 {
516 let mut s = stats.write();
517 s.messages_received += 1;
518 s.bytes_received += text.len() as u64;
519 s.last_message_time = clock.raw();
520 }
521
522 if let Some(result) = process_message(text, &clock)
524 && let Err(e) = sender.send(result.into()).await {
525 log::error!("Failed to send message: {e}");
526 }
527 }
528 Message::Binary(bin) => {
529 let text = std::string::String::from_utf8_lossy(bin);
530
531 {
533 let mut s = stats.write();
534 s.messages_received += 1;
535 s.bytes_received += bin.len() as u64;
536 s.last_message_time = clock.raw();
537 }
538
539 if let Some(result) = process_message(&text, &clock)
541 && let Err(e) = sender.send(result.into()).await {
542 log::error!("Failed to send message: {e}");
543 }
544 }
545 Message::Pong(_) => {
546 stats.write().last_pong_time = clock.raw();
547 }
548 Message::Close(close_frame) => {
549 log::info!("WebSocket closed: {close_frame:?}");
550 *connection_status.write() = ConnectionState::Disconnected;
551 break;
552 }
553 Message::Ping(_) => {
554 let pong_frame = if let Some(ref custom) = custom_pong {
555 FrameView::text(custom.clone())
556 } else {
557 FrameView::pong(vec![])
558 };
559
560 if let Err(e) = websocket_sender.send(pong_frame).await {
561 log::error!("Failed to send pong: {e}");
562 break;
563 }
564 }
565 &Message::Frame(_) => {
566 log::trace!("Received frame message");
567 }
568 }
569 }
570 Ok(None) => {
571 log::info!("WebSocket connection closed");
572 break;
573 }
574 Err(_) => {
575 log::error!("WebSocket message timeout");
576 break;
577 }
578 }
579 }
580 }
581 }
582 }
583 Err(e) => {
584 log::error!("Failed to connect: {e}");
585 *connection_status.write() = ConnectionState::Error;
586 }
587 }
588
589 if reconnect_config.enabled
591 && (reconnect_config.max_attempts == 0
592 || reconnect_attempts < reconnect_config.max_attempts)
593 {
594 reconnect_attempts += 1;
595 *connection_status.write() = ConnectionState::Reconnecting;
596 stats.write().reconnections += 1;
597
598 backoff_delay = std::cmp::min(
600 (backoff_delay as f64 * reconnect_config.multiplier) as u64,
601 reconnect_config.max_delay.as_millis() as u64,
602 );
603
604 tokio::time::sleep(Duration::from_millis(backoff_delay)).await;
605 } else {
606 *connection_status.write() = ConnectionState::Disconnected;
607 break 'connection;
608 }
609 }
610 })
611 }
612}