rusty_common/websocket/
client.rs

1//! WebSocket client implementation
2//!
3//! Provides a high-performance WebSocket client using yawc.
4
5use futures_util::{SinkExt, StreamExt};
6use parking_lot::RwLock;
7use std::sync::Arc;
8use tokio::sync::{mpsc, watch};
9use tokio::time::timeout;
10use url::Url;
11use yawc::{DeflateOptions, Options, WebSocket};
12
13use super::{
14    Message, MessageHandler, ReconnectStrategy, WebSocketConfig, WebSocketError, WebSocketResult,
15};
16
/// Lifecycle state of a WebSocket client connection.
///
/// The value is a small `Copy` tag, so it can be read out of the shared lock and compared
/// cheaply.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection is established (this is the state a fresh client starts in).
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// A connection is established and frames are being exchanged.
    Connected,
    /// Waiting out the reconnect delay before the next connection attempt.
    Reconnecting,
    /// The most recent connection attempt failed.
    Error,
    /// The run loop has finished; the client will not reconnect on its own.
    Closed,
}
38
/// Running counters collected by the WebSocket client.
///
/// A snapshot is obtained by cloning; every counter starts at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct ClientStats {
    /// Number of queued messages written to the socket.
    pub messages_sent: u64,
    /// Number of frames read from the socket.
    pub messages_received: u64,
    /// Payload bytes of sent text/binary messages (other message kinds count as 0).
    pub bytes_sent: u64,
    /// Payload bytes of received text/binary messages (other message kinds count as 0).
    pub bytes_received: u64,
    /// Number of connection attempts started.
    pub connection_attempts: u32,
    /// Number of attempts that produced an established connection.
    pub successful_connections: u32,
    /// Number of times a reconnect was scheduled.
    pub reconnections: u32,
    /// Timestamp of the last ping, in nanoseconds.
    ///
    /// NOTE(review): not written by the client code in this module — confirm where it is set.
    pub last_ping_time: u64,
    /// Timestamp of the last pong, in nanoseconds.
    ///
    /// NOTE(review): not written by the client code in this module — confirm where it is set.
    pub last_pong_time: u64,
}
69
70/// WebSocket stream type
71pub type WebSocketStream = WebSocket;
72
73/// WebSocket client
74pub struct WebSocketClient {
75    /// Configuration
76    config: WebSocketConfig,
77
78    /// Current connection state
79    state: Arc<RwLock<ConnectionState>>,
80
81    /// Statistics
82    stats: Arc<RwLock<ClientStats>>,
83
84    /// Reconnection strategy
85    reconnect: ReconnectStrategy,
86
87    /// Shutdown signal
88    shutdown_tx: watch::Sender<bool>,
89    shutdown_rx: watch::Receiver<bool>,
90
91    /// Send channel
92    send_tx: Option<mpsc::UnboundedSender<Message>>,
93}
94
95impl WebSocketClient {
96    /// Create yawc Options from our WebSocketConfig
97    fn create_yawc_options(&self) -> Options {
98        let mut options = Options::default();
99
100        // Configure compression
101        if self.config.compression.enabled {
102            // Set compression option (yawc currently only supports enabling/disabling)
103            // Window bits configuration is not exposed in the current yawc API
104            options.compression = Some(DeflateOptions::default());
105        }
106
107        // Configure size limits
108        options.max_payload_read = Some(self.config.max_message_size);
109        options.max_read_buffer = Some(self.config.max_message_size * 2); // Allow buffer for fragmented messages
110
111        // Note: yawc 0.2 doesn't expose ping_interval and pong_timeout in Options.
112        // Ping/pong is handled internally by yawc at the protocol level.
113        // For exchanges that require custom application-level ping messages,
114        // use the custom_ping_message field in WebSocketConfig.
115
116        options
117    }
118
119    /// Create a new WebSocket client
120    #[must_use]
121    pub fn new(config: WebSocketConfig) -> Self {
122        let (shutdown_tx, shutdown_rx) = watch::channel(false);
123
124        Self {
125            reconnect: ReconnectStrategy::new(config.reconnect.clone()),
126            config,
127            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
128            stats: Arc::new(RwLock::new(ClientStats::default())),
129            shutdown_tx,
130            shutdown_rx,
131            send_tx: None,
132        }
133    }
134
135    /// Connect and run the client with a message handler
136    pub async fn run<H: MessageHandler + 'static>(
137        &mut self,
138        mut handler: H,
139    ) -> WebSocketResult<()> {
140        let mut shutdown_rx = self.shutdown_rx.clone();
141
142        loop {
143            // Check for shutdown
144            if *shutdown_rx.borrow() {
145                break;
146            }
147
148            // Update state
149            *self.state.write() = ConnectionState::Connecting;
150            self.stats.write().connection_attempts += 1;
151
152            // Try to connect
153            match self.connect().await {
154                Ok(ws) => {
155                    // Update state
156                    *self.state.write() = ConnectionState::Connected;
157                    self.stats.write().successful_connections += 1;
158                    self.reconnect.reset();
159
160                    // Create a new channel for this connection (high-performance approach)
161                    let (conn_send_tx, conn_send_rx) = mpsc::unbounded_channel();
162
163                    // Store the sender in the client
164                    self.send_tx = Some(conn_send_tx.clone());
165
166                    // Set sender for handler so it can send messages
167                    // Note: This clone is necessary as both the client and handler may need
168                    // to send messages independently
169                    handler.set_sender(conn_send_tx);
170
171                    // Notify handler (after setting sender so it can send subscription messages)
172                    if let Err(e) = handler.on_connected().await {
173                        log::error!("Handler error on connected: {e}");
174                    }
175
176                    // Run the connection
177                    if let Err(e) = self.run_connection(ws, &mut handler, conn_send_rx).await {
178                        log::error!("Connection error: {e}");
179
180                        // Notify handler
181                        if let Err(e) = handler.on_error(e).await {
182                            log::error!("Handler error on error: {e}");
183                        }
184                    }
185
186                    // Clear send channel
187                    self.send_tx = None;
188
189                    // Notify handler
190                    if let Err(e) = handler.on_disconnected().await {
191                        log::error!("Handler error on disconnected: {e}");
192                    }
193                }
194                Err(e) => {
195                    log::error!("Failed to connect: {e}");
196                    *self.state.write() = ConnectionState::Error;
197                    self.send_tx = None;
198                }
199            }
200
201            // Check if we should reconnect
202            if !self.reconnect.should_reconnect() {
203                break;
204            }
205
206            // Get reconnection delay
207            if let Some(delay) = self.reconnect.next_delay() {
208                *self.state.write() = ConnectionState::Reconnecting;
209                self.stats.write().reconnections += 1;
210
211                log::info!(
212                    "Reconnecting in {:?} (attempt {})",
213                    delay,
214                    self.reconnect.attempts()
215                );
216
217                // Wait with cancellation
218                tokio::select! {
219                    _ = tokio::time::sleep(delay) => {}
220                    _ = shutdown_rx.changed() => {
221                        if *shutdown_rx.borrow() {
222                            break;
223                        }
224                    }
225                }
226            } else {
227                break;
228            }
229        }
230
231        *self.state.write() = ConnectionState::Closed;
232        Ok(())
233    }
234
235    /// Connect to the WebSocket server
236    async fn connect(&self) -> WebSocketResult<WebSocketStream> {
237        // Parse URL
238        let url = Url::parse(&self.config.url)?;
239
240        // Create yawc Options from our config
241        let options = self.create_yawc_options();
242
243        // Connect with timeout and options
244        let ws = timeout(
245            self.config.connect_timeout,
246            WebSocket::connect(url).with_options(options),
247        )
248        .await
249        .map_err(|_| WebSocketError::Timeout(self.config.connect_timeout.as_millis() as u64))?
250        .map_err(WebSocketError::from)?;
251
252        Ok(ws)
253    }
254
255    /// Run the WebSocket connection
256    async fn run_connection<H: MessageHandler>(
257        &self,
258        mut ws: WebSocketStream,
259        handler: &mut H,
260        mut send_rx: mpsc::UnboundedReceiver<Message>,
261    ) -> WebSocketResult<()> {
262        let mut shutdown_rx = self.shutdown_rx.clone();
263
264        loop {
265            tokio::select! {
266                // Shutdown signal
267                _ = shutdown_rx.changed() => {
268                    if *shutdown_rx.borrow() {
269                        // Send close frame
270                        ws.send(yawc::frame::FrameView::close(
271                            yawc::close::CloseCode::Iana(1000),
272                            "Client shutdown"
273                        )).await?;
274                        break;
275                    }
276                }
277
278                // Send message (high-performance unbounded channel)
279                Some(message) = send_rx.recv() => {
280                    // Track bytes for stats
281                    let bytes_sent = match &message {
282                        Message::Text(text) => text.len() as u64,
283                        Message::Binary(data) => data.len() as u64,
284                        _ => 0,
285                    };
286
287                    // Convert to owned data for yawc's Bytes requirement
288                    match message {
289                        Message::Text(text) => {
290                            // Convert SmartString to Vec<u8> then to Bytes
291                            let data = text.as_bytes().to_vec();
292                            ws.send(yawc::frame::FrameView::text(data)).await?;
293                        }
294                        Message::Binary(data) => {
295                            // Pass owned Vec<u8> directly, it will convert to Bytes
296                            ws.send(yawc::frame::FrameView::binary(data)).await?;
297                        }
298                        Message::Close(close_frame) => {
299                            match close_frame {
300                                Some((code, reason)) => {
301                                    ws.send(yawc::frame::FrameView::close(
302                                        yawc::close::CloseCode::Iana(code),
303                                        &reason
304                                    )).await?;
305                                }
306                                None => {
307                                    ws.send(yawc::frame::FrameView::close(
308                                        yawc::close::CloseCode::Iana(1000),
309                                        "Normal closure"
310                                    )).await?;
311                                }
312                            }
313                        }
314                        Message::Ping(data) => {
315                            ws.send(yawc::frame::FrameView::ping(data)).await?;
316                        }
317                        Message::Pong(data) => {
318                            ws.send(yawc::frame::FrameView::pong(data)).await?;
319                        }
320                        Message::Frame(frame) => {
321                            ws.send(frame).await?;
322                        }
323                    }
324
325                    // Update stats
326                    let mut stats = self.stats.write();
327                    stats.messages_sent += 1;
328                    stats.bytes_sent += bytes_sent;
329                }
330
331                // Note: Ping/pong is handled automatically by yawc when Options are configured
332
333                // Receive message
334                frame = timeout(self.config.timeout, ws.next()) => {
335                    match frame {
336                        Ok(Some(frame)) => {
337                            let message = Message::from_frame_view(frame);
338
339                            // Update stats
340                            {
341                                let mut stats = self.stats.write();
342                                stats.messages_received += 1;
343                                match &message {
344                                    Message::Text(text) => stats.bytes_received += text.len() as u64,
345                                    Message::Binary(data) => stats.bytes_received += data.len() as u64,
346                                    _ => {}
347                                }
348                            } // stats guard is dropped here
349
350                            // Handle message
351                            handler.on_message(message).await?;
352                        }
353                        Ok(None) => {
354                            // Connection closed
355                            break;
356                        }
357                        Err(_) => {
358                            // Timeout
359                            return Err(WebSocketError::Timeout(self.config.timeout.as_millis() as u64));
360                        }
361                    }
362                }
363            }
364        }
365
366        Ok(())
367    }
368
369    /// Send a message (high-performance, non-blocking)
370    pub async fn send(&self, message: Message) -> WebSocketResult<()> {
371        match &self.send_tx {
372            Some(tx) => {
373                // Use try_send for non-blocking operation
374                tx.send(message).map_err(|_| WebSocketError::NotConnected)?;
375                Ok(())
376            }
377            None => Err(WebSocketError::NotConnected),
378        }
379    }
380
381    /// Try to send a message (non-blocking, returns false if channel is full)
382    pub fn try_send(&self, message: Message) -> bool {
383        match &self.send_tx {
384            Some(tx) => {
385                // Since we use unbounded channel, this should always succeed
386                // unless the receiver is dropped (disconnected)
387                tx.send(message).is_ok()
388            }
389            None => false,
390        }
391    }
392
393    /// Get the current connection state
394    pub fn state(&self) -> ConnectionState {
395        *self.state.read()
396    }
397
398    /// Get statistics
399    pub fn stats(&self) -> ClientStats {
400        (*self.stats.read()).clone()
401    }
402
403    /// Shutdown the client
404    pub fn shutdown(&self) {
405        let _ = self.shutdown_tx.send(true);
406    }
407}
408
409/// WebSocket client builder
410pub struct WebSocketClientBuilder {
411    config: WebSocketConfig,
412}
413
414impl WebSocketClientBuilder {
415    /// Create a new builder
416    #[must_use]
417    pub const fn new(config: WebSocketConfig) -> Self {
418        Self { config }
419    }
420
421    /// Build the client
422    pub fn build(self) -> WebSocketClient {
423        WebSocketClient::new(self.config)
424    }
425}