rusty_common/websocket/
client.rs1use futures_util::{SinkExt, StreamExt};
6use parking_lot::RwLock;
7use std::sync::Arc;
8use tokio::sync::{mpsc, watch};
9use tokio::time::timeout;
10use url::Url;
11use yawc::{DeflateOptions, Options, WebSocket};
12
13use super::{
14 Message, MessageHandler, ReconnectStrategy, WebSocketConfig, WebSocketError, WebSocketResult,
15};
16
/// Lifecycle state of a `WebSocketClient`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Not connected; a freshly constructed client starts in this state.
    Disconnected,

    /// A connection attempt is in progress.
    Connecting,

    /// A connection is established and messages are being exchanged.
    Connected,

    /// Waiting for the reconnect delay to elapse before the next attempt.
    Reconnecting,

    /// An error occurred (for example, the last connection attempt failed).
    Error,

    /// The client's run loop has finished.
    Closed,
}
38
/// Running counters collected by a `WebSocketClient`; every field starts at zero.
#[derive(Clone, Debug, Default)]
pub struct ClientStats {
    /// Number of outgoing messages written to the socket.
    pub messages_sent: u64,

    /// Number of frames read from the socket.
    pub messages_received: u64,

    /// Sum of outgoing text/binary payload lengths; other message kinds add 0.
    pub bytes_sent: u64,

    /// Sum of incoming text/binary payload lengths; other message kinds add 0.
    pub bytes_received: u64,

    /// Number of connection attempts started.
    pub connection_attempts: u32,

    /// Number of connection attempts that succeeded.
    pub successful_connections: u32,

    /// Number of times a reconnect delay was scheduled.
    pub reconnections: u32,

    /// NOTE(review): not written anywhere in this file; units unknown — confirm.
    pub last_ping_time: u64,

    /// NOTE(review): not written anywhere in this file; units unknown — confirm.
    pub last_pong_time: u64,
}
69
70pub type WebSocketStream = WebSocket;
72
73pub struct WebSocketClient {
75 config: WebSocketConfig,
77
78 state: Arc<RwLock<ConnectionState>>,
80
81 stats: Arc<RwLock<ClientStats>>,
83
84 reconnect: ReconnectStrategy,
86
87 shutdown_tx: watch::Sender<bool>,
89 shutdown_rx: watch::Receiver<bool>,
90
91 send_tx: Option<mpsc::UnboundedSender<Message>>,
93}
94
95impl WebSocketClient {
96 fn create_yawc_options(&self) -> Options {
98 let mut options = Options::default();
99
100 if self.config.compression.enabled {
102 options.compression = Some(DeflateOptions::default());
105 }
106
107 options.max_payload_read = Some(self.config.max_message_size);
109 options.max_read_buffer = Some(self.config.max_message_size * 2); options
117 }
118
119 #[must_use]
121 pub fn new(config: WebSocketConfig) -> Self {
122 let (shutdown_tx, shutdown_rx) = watch::channel(false);
123
124 Self {
125 reconnect: ReconnectStrategy::new(config.reconnect.clone()),
126 config,
127 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
128 stats: Arc::new(RwLock::new(ClientStats::default())),
129 shutdown_tx,
130 shutdown_rx,
131 send_tx: None,
132 }
133 }
134
135 pub async fn run<H: MessageHandler + 'static>(
137 &mut self,
138 mut handler: H,
139 ) -> WebSocketResult<()> {
140 let mut shutdown_rx = self.shutdown_rx.clone();
141
142 loop {
143 if *shutdown_rx.borrow() {
145 break;
146 }
147
148 *self.state.write() = ConnectionState::Connecting;
150 self.stats.write().connection_attempts += 1;
151
152 match self.connect().await {
154 Ok(ws) => {
155 *self.state.write() = ConnectionState::Connected;
157 self.stats.write().successful_connections += 1;
158 self.reconnect.reset();
159
160 let (conn_send_tx, conn_send_rx) = mpsc::unbounded_channel();
162
163 self.send_tx = Some(conn_send_tx.clone());
165
166 handler.set_sender(conn_send_tx);
170
171 if let Err(e) = handler.on_connected().await {
173 log::error!("Handler error on connected: {e}");
174 }
175
176 if let Err(e) = self.run_connection(ws, &mut handler, conn_send_rx).await {
178 log::error!("Connection error: {e}");
179
180 if let Err(e) = handler.on_error(e).await {
182 log::error!("Handler error on error: {e}");
183 }
184 }
185
186 self.send_tx = None;
188
189 if let Err(e) = handler.on_disconnected().await {
191 log::error!("Handler error on disconnected: {e}");
192 }
193 }
194 Err(e) => {
195 log::error!("Failed to connect: {e}");
196 *self.state.write() = ConnectionState::Error;
197 self.send_tx = None;
198 }
199 }
200
201 if !self.reconnect.should_reconnect() {
203 break;
204 }
205
206 if let Some(delay) = self.reconnect.next_delay() {
208 *self.state.write() = ConnectionState::Reconnecting;
209 self.stats.write().reconnections += 1;
210
211 log::info!(
212 "Reconnecting in {:?} (attempt {})",
213 delay,
214 self.reconnect.attempts()
215 );
216
217 tokio::select! {
219 _ = tokio::time::sleep(delay) => {}
220 _ = shutdown_rx.changed() => {
221 if *shutdown_rx.borrow() {
222 break;
223 }
224 }
225 }
226 } else {
227 break;
228 }
229 }
230
231 *self.state.write() = ConnectionState::Closed;
232 Ok(())
233 }
234
235 async fn connect(&self) -> WebSocketResult<WebSocketStream> {
237 let url = Url::parse(&self.config.url)?;
239
240 let options = self.create_yawc_options();
242
243 let ws = timeout(
245 self.config.connect_timeout,
246 WebSocket::connect(url).with_options(options),
247 )
248 .await
249 .map_err(|_| WebSocketError::Timeout(self.config.connect_timeout.as_millis() as u64))?
250 .map_err(WebSocketError::from)?;
251
252 Ok(ws)
253 }
254
255 async fn run_connection<H: MessageHandler>(
257 &self,
258 mut ws: WebSocketStream,
259 handler: &mut H,
260 mut send_rx: mpsc::UnboundedReceiver<Message>,
261 ) -> WebSocketResult<()> {
262 let mut shutdown_rx = self.shutdown_rx.clone();
263
264 loop {
265 tokio::select! {
266 _ = shutdown_rx.changed() => {
268 if *shutdown_rx.borrow() {
269 ws.send(yawc::frame::FrameView::close(
271 yawc::close::CloseCode::Iana(1000),
272 "Client shutdown"
273 )).await?;
274 break;
275 }
276 }
277
278 Some(message) = send_rx.recv() => {
280 let bytes_sent = match &message {
282 Message::Text(text) => text.len() as u64,
283 Message::Binary(data) => data.len() as u64,
284 _ => 0,
285 };
286
287 match message {
289 Message::Text(text) => {
290 let data = text.as_bytes().to_vec();
292 ws.send(yawc::frame::FrameView::text(data)).await?;
293 }
294 Message::Binary(data) => {
295 ws.send(yawc::frame::FrameView::binary(data)).await?;
297 }
298 Message::Close(close_frame) => {
299 match close_frame {
300 Some((code, reason)) => {
301 ws.send(yawc::frame::FrameView::close(
302 yawc::close::CloseCode::Iana(code),
303 &reason
304 )).await?;
305 }
306 None => {
307 ws.send(yawc::frame::FrameView::close(
308 yawc::close::CloseCode::Iana(1000),
309 "Normal closure"
310 )).await?;
311 }
312 }
313 }
314 Message::Ping(data) => {
315 ws.send(yawc::frame::FrameView::ping(data)).await?;
316 }
317 Message::Pong(data) => {
318 ws.send(yawc::frame::FrameView::pong(data)).await?;
319 }
320 Message::Frame(frame) => {
321 ws.send(frame).await?;
322 }
323 }
324
325 let mut stats = self.stats.write();
327 stats.messages_sent += 1;
328 stats.bytes_sent += bytes_sent;
329 }
330
331 frame = timeout(self.config.timeout, ws.next()) => {
335 match frame {
336 Ok(Some(frame)) => {
337 let message = Message::from_frame_view(frame);
338
339 {
341 let mut stats = self.stats.write();
342 stats.messages_received += 1;
343 match &message {
344 Message::Text(text) => stats.bytes_received += text.len() as u64,
345 Message::Binary(data) => stats.bytes_received += data.len() as u64,
346 _ => {}
347 }
348 } handler.on_message(message).await?;
352 }
353 Ok(None) => {
354 break;
356 }
357 Err(_) => {
358 return Err(WebSocketError::Timeout(self.config.timeout.as_millis() as u64));
360 }
361 }
362 }
363 }
364 }
365
366 Ok(())
367 }
368
369 pub async fn send(&self, message: Message) -> WebSocketResult<()> {
371 match &self.send_tx {
372 Some(tx) => {
373 tx.send(message).map_err(|_| WebSocketError::NotConnected)?;
375 Ok(())
376 }
377 None => Err(WebSocketError::NotConnected),
378 }
379 }
380
381 pub fn try_send(&self, message: Message) -> bool {
383 match &self.send_tx {
384 Some(tx) => {
385 tx.send(message).is_ok()
388 }
389 None => false,
390 }
391 }
392
393 pub fn state(&self) -> ConnectionState {
395 *self.state.read()
396 }
397
398 pub fn stats(&self) -> ClientStats {
400 (*self.stats.read()).clone()
401 }
402
403 pub fn shutdown(&self) {
405 let _ = self.shutdown_tx.send(true);
406 }
407}
408
409pub struct WebSocketClientBuilder {
411 config: WebSocketConfig,
412}
413
414impl WebSocketClientBuilder {
415 #[must_use]
417 pub const fn new(config: WebSocketConfig) -> Self {
418 Self { config }
419 }
420
421 pub fn build(self) -> WebSocketClient {
423 WebSocketClient::new(self.config)
424 }
425}