rusty_common/websocket/bridge/
channel_bridge.rs1use super::super::{ExchangeHandler, Message, MessageHandler, WebSocketError, WebSocketResult};
6use super::message_router::{MessageRouter, SubscriptionType};
7use crate::collections::FxHashMap;
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::any::Any;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
14pub struct SubscriptionChannel {
16 pub sender: mpsc::UnboundedSender<Box<dyn Any + Send>>,
18 pub subscription_type: SubscriptionType,
20}
21
22pub struct ChannelBridgeHandler<H: ExchangeHandler, R: MessageRouter> {
24 inner: H,
26 router: R,
28 subscriptions: Arc<RwLock<FxHashMap<String, SubscriptionChannel>>>,
30}
31
32impl<H: ExchangeHandler, R: MessageRouter> ChannelBridgeHandler<H, R> {
33 pub fn new(inner: H, router: R) -> Self {
35 Self {
36 inner,
37 router,
38 subscriptions: Arc::new(RwLock::new(FxHashMap::default())),
39 }
40 }
41
42 pub fn register_subscription(
44 &self,
45 key: String,
46 sender: mpsc::UnboundedSender<Box<dyn Any + Send>>,
47 subscription_type: SubscriptionType,
48 ) {
49 let channel = SubscriptionChannel {
50 sender,
51 subscription_type,
52 };
53 self.subscriptions.write().insert(key, channel);
54 }
55
56 pub fn unregister_subscription(&self, key: &str) -> Option<SubscriptionChannel> {
58 self.subscriptions.write().remove(key)
59 }
60
61 pub fn subscription_count(&self) -> usize {
63 self.subscriptions.read().len()
64 }
65
66 pub fn clear_subscriptions(&self) {
68 self.subscriptions.write().clear();
69 }
70
71 pub const fn inner(&self) -> &H {
73 &self.inner
74 }
75
76 pub const fn inner_mut(&mut self) -> &mut H {
78 &mut self.inner
79 }
80}
81
82#[async_trait]
83impl<H: ExchangeHandler, R: MessageRouter> MessageHandler for ChannelBridgeHandler<H, R> {
84 async fn on_message(&mut self, message: Message) -> WebSocketResult<()> {
85 self.inner.on_message(message.clone()).await?;
87
88 if let Some(routed) = self.router.route_message(&message)? {
90 let subs = self.subscriptions.read();
91 if let Some(channel) = subs.get(&routed.subscription_key) {
92 let _ = channel.sender.send(routed.payload);
94 }
95 }
96
97 Ok(())
98 }
99
100 async fn on_connected(&mut self) -> WebSocketResult<()> {
101 self.inner.on_connected().await
102 }
103
104 async fn on_disconnected(&mut self) -> WebSocketResult<()> {
105 self.inner.on_disconnected().await
106 }
107
108 async fn on_error(&mut self, error: WebSocketError) -> WebSocketResult<()> {
109 self.inner.on_error(error).await
110 }
111}
112
113#[async_trait]
114impl<H: ExchangeHandler, R: MessageRouter> ExchangeHandler for ChannelBridgeHandler<H, R> {
115 fn exchange(&self) -> crate::types::Exchange {
116 self.inner.exchange()
117 }
118
119 async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>> {
120 self.inner.get_subscriptions().await
121 }
122
123 async fn authenticate(&mut self) -> WebSocketResult<Option<Message>> {
124 self.inner.authenticate().await
125 }
126
127 fn requires_auth(&self) -> bool {
128 self.inner.requires_auth()
129 }
130
131 fn get_heartbeat(&self) -> Option<Message> {
132 self.inner.get_heartbeat()
133 }
134
135 fn heartbeat_interval(&self) -> Option<std::time::Duration> {
136 self.inner.heartbeat_interval()
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Exchange;

    /// Routes any text frame containing "trade" to the `trades:TEST` key,
    /// carrying the raw text as the payload; everything else is unrouted.
    struct TestRouter;

    impl MessageRouter for TestRouter {
        fn route_message(
            &self,
            message: &Message,
        ) -> WebSocketResult<Option<super::super::RoutedMessage>> {
            let Message::Text(text) = message else {
                return Ok(None);
            };
            if !text.contains("trade") {
                return Ok(None);
            }
            Ok(Some(super::super::RoutedMessage {
                subscription_key: "trades:TEST".to_string(),
                subscription_type: SubscriptionType::Trades,
                payload: Box::new(text.clone()),
            }))
        }
    }

    /// Inner handler that accepts every message and subscribes to nothing.
    struct TestHandler;

    #[async_trait]
    impl MessageHandler for TestHandler {
        async fn on_message(&mut self, _message: Message) -> WebSocketResult<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl ExchangeHandler for TestHandler {
        fn exchange(&self) -> Exchange {
            Exchange::Binance
        }

        async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>> {
            Ok(vec![])
        }
    }

    #[tokio::test]
    async fn test_channel_bridge() {
        let mut bridge = ChannelBridgeHandler::new(TestHandler, TestRouter);

        let (trades_tx, mut trades_rx) = mpsc::unbounded_channel();
        bridge.register_subscription("trades:TEST".to_string(), trades_tx, SubscriptionType::Trades);

        // A message the router recognises is delivered to the registered channel.
        bridge
            .on_message(Message::Text("trade data".into()))
            .await
            .expect("routing a trade message should succeed");
        let payload = trades_rx
            .try_recv()
            .expect("trade payload should have been delivered");
        let delivered = payload
            .downcast_ref::<smartstring::alias::String>()
            .expect("payload should be the raw message text");
        assert_eq!(delivered.as_str(), "trade data");

        // A message the router ignores produces nothing on the channel.
        bridge
            .on_message(Message::Text("other data".into()))
            .await
            .expect("an unrouted message should still be accepted");
        assert!(trades_rx.try_recv().is_err());
    }
}