// rusty_common/websocket/bridge/channel_bridge.rs

//! Channel bridge handler implementation.
//!
//! Bridges between `WebSocketClient`'s callback pattern and channel-based consumers.
4
5use super::super::{ExchangeHandler, Message, MessageHandler, WebSocketError, WebSocketResult};
6use super::message_router::{MessageRouter, SubscriptionType};
7use crate::collections::FxHashMap;
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::any::Any;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
14/// Channel information for a subscription
15pub struct SubscriptionChannel {
16    /// Channel sender
17    pub sender: mpsc::UnboundedSender<Box<dyn Any + Send>>,
18    /// Type of subscription
19    pub subscription_type: SubscriptionType,
20}
21
22/// Channel bridge handler
23pub struct ChannelBridgeHandler<H: ExchangeHandler, R: MessageRouter> {
24    /// Inner exchange handler
25    inner: H,
26    /// Message router
27    router: R,
28    /// Active subscriptions
29    subscriptions: Arc<RwLock<FxHashMap<String, SubscriptionChannel>>>,
30}
31
32impl<H: ExchangeHandler, R: MessageRouter> ChannelBridgeHandler<H, R> {
33    /// Create a new channel bridge handler
34    pub fn new(inner: H, router: R) -> Self {
35        Self {
36            inner,
37            router,
38            subscriptions: Arc::new(RwLock::new(FxHashMap::default())),
39        }
40    }
41
42    /// Register a subscription channel
43    pub fn register_subscription(
44        &self,
45        key: String,
46        sender: mpsc::UnboundedSender<Box<dyn Any + Send>>,
47        subscription_type: SubscriptionType,
48    ) {
49        let channel = SubscriptionChannel {
50            sender,
51            subscription_type,
52        };
53        self.subscriptions.write().insert(key, channel);
54    }
55
56    /// Unregister a subscription
57    pub fn unregister_subscription(&self, key: &str) -> Option<SubscriptionChannel> {
58        self.subscriptions.write().remove(key)
59    }
60
61    /// Get the number of active subscriptions
62    pub fn subscription_count(&self) -> usize {
63        self.subscriptions.read().len()
64    }
65
66    /// Clear all subscriptions
67    pub fn clear_subscriptions(&self) {
68        self.subscriptions.write().clear();
69    }
70
71    /// Get inner handler reference
72    pub const fn inner(&self) -> &H {
73        &self.inner
74    }
75
76    /// Get inner handler mutable reference
77    pub const fn inner_mut(&mut self) -> &mut H {
78        &mut self.inner
79    }
80}
81
82#[async_trait]
83impl<H: ExchangeHandler, R: MessageRouter> MessageHandler for ChannelBridgeHandler<H, R> {
84    async fn on_message(&mut self, message: Message) -> WebSocketResult<()> {
85        // Let inner handler process first (for auth, heartbeat, etc.)
86        self.inner.on_message(message.clone()).await?;
87
88        // Route to appropriate channel
89        if let Some(routed) = self.router.route_message(&message)? {
90            let subs = self.subscriptions.read();
91            if let Some(channel) = subs.get(&routed.subscription_key) {
92                // Send to channel, ignore if receiver dropped
93                let _ = channel.sender.send(routed.payload);
94            }
95        }
96
97        Ok(())
98    }
99
100    async fn on_connected(&mut self) -> WebSocketResult<()> {
101        self.inner.on_connected().await
102    }
103
104    async fn on_disconnected(&mut self) -> WebSocketResult<()> {
105        self.inner.on_disconnected().await
106    }
107
108    async fn on_error(&mut self, error: WebSocketError) -> WebSocketResult<()> {
109        self.inner.on_error(error).await
110    }
111}
112
113#[async_trait]
114impl<H: ExchangeHandler, R: MessageRouter> ExchangeHandler for ChannelBridgeHandler<H, R> {
115    fn exchange(&self) -> crate::types::Exchange {
116        self.inner.exchange()
117    }
118
119    async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>> {
120        self.inner.get_subscriptions().await
121    }
122
123    async fn authenticate(&mut self) -> WebSocketResult<Option<Message>> {
124        self.inner.authenticate().await
125    }
126
127    fn requires_auth(&self) -> bool {
128        self.inner.requires_auth()
129    }
130
131    fn get_heartbeat(&self) -> Option<Message> {
132        self.inner.get_heartbeat()
133    }
134
135    fn heartbeat_interval(&self) -> Option<std::time::Duration> {
136        self.inner.heartbeat_interval()
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Exchange;

    /// Routes any text frame containing "trade" to the `trades:TEST` key,
    /// carrying the frame text as the payload; everything else is unrouted.
    struct TestRouter;

    impl MessageRouter for TestRouter {
        fn route_message(
            &self,
            message: &Message,
        ) -> WebSocketResult<Option<super::super::RoutedMessage>> {
            match message {
                Message::Text(text) if text.contains("trade") => {
                    Ok(Some(super::super::RoutedMessage {
                        subscription_key: "trades:TEST".to_string(),
                        subscription_type: SubscriptionType::Trades,
                        payload: Box::new(text.clone()),
                    }))
                }
                _ => Ok(None),
            }
        }
    }

    /// Inner handler that accepts every message and has no subscriptions.
    struct TestHandler;

    #[async_trait]
    impl MessageHandler for TestHandler {
        async fn on_message(&mut self, _message: Message) -> WebSocketResult<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl ExchangeHandler for TestHandler {
        fn exchange(&self) -> Exchange {
            Exchange::Binance
        }

        async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>> {
            Ok(vec![])
        }
    }

    /// Builds a bridge over the test handler and router.
    fn trade_bridge() -> ChannelBridgeHandler<TestHandler, TestRouter> {
        ChannelBridgeHandler::new(TestHandler, TestRouter)
    }

    #[tokio::test]
    async fn test_channel_bridge() {
        let mut bridge = trade_bridge();

        // Create channel
        let (tx, mut rx) = mpsc::unbounded_channel();
        bridge.register_subscription("trades:TEST".to_string(), tx, SubscriptionType::Trades);

        // Send trade message
        let trade_msg = Message::Text("trade data".into());
        bridge.on_message(trade_msg).await.unwrap();

        // Check received
        let received = rx.try_recv().unwrap();
        let text = received
            .downcast_ref::<smartstring::alias::String>()
            .unwrap();
        assert_eq!(text.as_str(), "trade data");

        // Non-trade message should not be routed
        let other_msg = Message::Text("other data".into());
        bridge.on_message(other_msg).await.unwrap();
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn test_subscription_bookkeeping() {
        let bridge = trade_bridge();
        assert_eq!(bridge.subscription_count(), 0);

        let (tx_a, _rx_a) = mpsc::unbounded_channel();
        let (tx_b, _rx_b) = mpsc::unbounded_channel();
        bridge.register_subscription("trades:A".to_string(), tx_a, SubscriptionType::Trades);
        bridge.register_subscription("trades:B".to_string(), tx_b, SubscriptionType::Trades);
        assert_eq!(bridge.subscription_count(), 2);

        assert!(bridge.unregister_subscription("trades:A").is_some());
        assert!(bridge.unregister_subscription("trades:A").is_none());
        assert_eq!(bridge.subscription_count(), 1);

        bridge.clear_subscriptions();
        assert_eq!(bridge.subscription_count(), 0);
    }

    #[test]
    fn test_reregistration_replaces_channel() {
        let bridge = trade_bridge();
        let (old_tx, mut old_rx) = mpsc::unbounded_channel();
        let (new_tx, _new_rx) = mpsc::unbounded_channel();
        bridge.register_subscription("trades:TEST".to_string(), old_tx, SubscriptionType::Trades);
        bridge.register_subscription("trades:TEST".to_string(), new_tx, SubscriptionType::Trades);
        assert_eq!(bridge.subscription_count(), 1);

        // The replaced sender was dropped, so the old consumer sees a closed channel.
        assert!(matches!(
            old_rx.try_recv(),
            Err(mpsc::error::TryRecvError::Disconnected)
        ));
    }

    #[tokio::test]
    async fn test_dropped_receiver_is_not_an_error() {
        let mut bridge = trade_bridge();
        let (tx, rx) = mpsc::unbounded_channel();
        bridge.register_subscription("trades:TEST".to_string(), tx, SubscriptionType::Trades);
        drop(rx);

        // Delivery is best-effort: a closed consumer must not fail the connection.
        bridge
            .on_message(Message::Text("trade data".into()))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_routed_message_without_subscriber_is_discarded() {
        let mut bridge = trade_bridge();
        bridge
            .on_message(Message::Text("trade data".into()))
            .await
            .unwrap();
        assert_eq!(bridge.subscription_count(), 0);
    }
}