// rusty_common/websocket/handler.rs

//! WebSocket message handlers.
//!
//! Provides traits for handling WebSocket messages in an exchange-agnostic way.
4
5use super::{Message, WebSocketError, WebSocketResult};
6use crate::types::Exchange;
7use async_trait::async_trait;
8use tokio::sync::mpsc;
9
10/// Message handler trait
11#[async_trait]
12pub trait MessageHandler: Send + Sync {
13    /// Handle a received message
14    async fn on_message(&mut self, message: Message) -> WebSocketResult<()>;
15
16    /// Handle connection established
17    async fn on_connected(&mut self) -> WebSocketResult<()> {
18        Ok(())
19    }
20
21    /// Handle connection closed
22    async fn on_disconnected(&mut self) -> WebSocketResult<()> {
23        Ok(())
24    }
25
26    /// Handle errors
27    async fn on_error(&mut self, error: WebSocketError) -> WebSocketResult<()> {
28        log::error!("WebSocket error: {}", error);
29        Ok(())
30    }
31
32    /// Set the sender for outgoing messages (optional, default implementation does nothing)
33    fn set_sender(&mut self, _sender: mpsc::UnboundedSender<Message>) {
34        // Default implementation does nothing
35    }
36
37    /// Send a message through the WebSocket (optional, default implementation returns error)
38    fn send_message(&self, _message: Message) -> WebSocketResult<()> {
39        Err(WebSocketError::NotConnected)
40    }
41}
42
43/// Exchange-specific handler trait
44#[async_trait]
45pub trait ExchangeHandler: MessageHandler {
46    /// Get the exchange type
47    fn exchange(&self) -> Exchange;
48
49    /// Get subscription messages
50    async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>>;
51
52    /// Handle authentication if required
53    async fn authenticate(&mut self) -> WebSocketResult<Option<Message>> {
54        Ok(None)
55    }
56
57    /// Check if the connection needs authentication
58    fn requires_auth(&self) -> bool {
59        false
60    }
61
62    /// Get heartbeat message if required
63    fn get_heartbeat(&self) -> Option<Message> {
64        None
65    }
66
67    /// Get heartbeat interval
68    fn heartbeat_interval(&self) -> Option<std::time::Duration> {
69        None
70    }
71}
72
73/// Simple message handler that collects messages
74pub struct CollectingHandler {
75    /// A vector of received messages.
76    pub messages: Vec<Message>,
77}
78
79impl CollectingHandler {
80    /// Create a new collecting handler
81    #[must_use]
82    pub const fn new() -> Self {
83        Self {
84            messages: Vec::new(),
85        }
86    }
87}
88
89impl Default for CollectingHandler {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95#[async_trait]
96impl MessageHandler for CollectingHandler {
97    async fn on_message(&mut self, message: Message) -> WebSocketResult<()> {
98        self.messages.push(message);
99        Ok(())
100    }
101}
102
103/// Logging message handler
104pub struct LoggingHandler {
105    prefix: String,
106}
107
108impl LoggingHandler {
109    /// Create a new logging handler
110    #[must_use]
111    pub fn new(prefix: &str) -> Self {
112        Self {
113            prefix: prefix.to_string(),
114        }
115    }
116}
117
118#[async_trait]
119impl MessageHandler for LoggingHandler {
120    async fn on_message(&mut self, message: Message) -> WebSocketResult<()> {
121        match &message {
122            Message::Text(text) => log::info!("{}: Text: {}", self.prefix, text),
123            Message::Binary(data) => log::info!("{}: Binary: {} bytes", self.prefix, data.len()),
124            Message::Ping(data) => log::debug!("{}: Ping: {} bytes", self.prefix, data.len()),
125            Message::Pong(data) => log::debug!("{}: Pong: {} bytes", self.prefix, data.len()),
126            Message::Close(frame) => {
127                if let Some((code, reason)) = frame {
128                    log::info!("{}: Close: {} - {}", self.prefix, code, reason);
129                } else {
130                    log::info!("{}: Close", self.prefix);
131                }
132            }
133            Message::Frame(_) => log::trace!("{}: Frame", self.prefix),
134        }
135        Ok(())
136    }
137
138    async fn on_connected(&mut self) -> WebSocketResult<()> {
139        log::info!("{}: Connected", self.prefix);
140        Ok(())
141    }
142
143    async fn on_disconnected(&mut self) -> WebSocketResult<()> {
144        log::info!("{}: Disconnected", self.prefix);
145        Ok(())
146    }
147}