// rusty_feeder/exchange/coinbase/websocket_handler.rs

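//! Coinbase WebSocket handler implementation
//!
//! Provides Coinbase-specific message handling and routing: a handler that tracks
//! subscriptions and logs control messages, and a router that converts market-data
//! messages into typed payloads.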
5use super::data::orderbook::Level2Change;
6use async_trait::async_trait;
7use parking_lot::RwLock;
8use rusty_common::collections::FxHashMap;
9use rusty_common::types::Exchange;
10use rusty_common::websocket::{
11    ExchangeHandler, Message, MessageHandler, MessageRouter, RoutedMessage, SubscriptionType,
12    WebSocketResult,
13};
14use simd_json::prelude::*;
15use smartstring::alias::String;
16use std::sync::Arc;
17
18use super::data::{orderbook::Level2Update, trade::TradeMessage};
19
20// Type aliases for compatibility
21type CoinbaseTradeMessage = TradeMessage;
22type CoinbaseDepthMessage = Level2Update;
23
/// Subscription information
///
/// The product IDs and channel names of one Coinbase subscription. Each registered value
/// is turned into one subscription message by
/// `CoinbaseWebSocketHandler::get_subscriptions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubscriptionInfo {
    /// Coinbase product identifiers to subscribe to.
    product_ids: Vec<String>,
    /// Coinbase channel names to subscribe to.
    channels: Vec<String>,
}

impl SubscriptionInfo {
    /// Creates subscription info from product IDs and channel names.
    ///
    /// Accepts any iterables of string-like values, for example
    /// `SubscriptionInfo::new(["BTC-USD"], ["ticker"])`. Without this constructor the
    /// private fields made the type impossible to build outside this module.
    pub fn new<P, C>(product_ids: P, channels: C) -> Self
    where
        P: IntoIterator,
        P::Item: Into<String>,
        C: IntoIterator,
        C::Item: Into<String>,
    {
        Self {
            product_ids: product_ids.into_iter().map(Into::into).collect(),
            channels: channels.into_iter().map(Into::into).collect(),
        }
    }

    /// Product IDs covered by this subscription.
    pub fn product_ids(&self) -> &[String] {
        &self.product_ids
    }

    /// Channel names covered by this subscription.
    pub fn channels(&self) -> &[String] {
        &self.channels
    }
}
30
31/// Coinbase WebSocket handler
32pub struct CoinbaseWebSocketHandler {
33    /// Active subscriptions
34    subscriptions: Arc<RwLock<FxHashMap<String, SubscriptionInfo>>>,
35    /// Whether authenticated
36    authenticated: bool,
37}
38
39impl CoinbaseWebSocketHandler {
40    /// Create a new handler
41    pub fn new() -> Self {
42        Self {
43            subscriptions: Arc::new(RwLock::new(FxHashMap::default())),
44            authenticated: false,
45        }
46    }
47
48    /// Register a subscription
49    pub fn register_subscription(&self, key: String, info: SubscriptionInfo) {
50        self.subscriptions.write().insert(key, info);
51    }
52
53    /// Parse message type from JSON
54    fn get_message_type<'a>(json: &'a simd_json::BorrowedValue<'a>) -> Option<&'a str> {
55        json.get("type")?.as_str()
56    }
57
58    /// Parse product ID from JSON
59    fn get_product_id<'a>(json: &'a simd_json::BorrowedValue<'a>) -> Option<&'a str> {
60        json.get("product_id")?.as_str()
61    }
62}
63
64impl Default for CoinbaseWebSocketHandler {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70#[async_trait]
71impl MessageHandler for CoinbaseWebSocketHandler {
72    async fn on_message(&mut self, message: Message) -> WebSocketResult<()> {
73        // Coinbase-specific message handling (auth responses, etc.)
74        if let Message::Text(text) = &message {
75            // Parse to check for system messages
76            let mut text_bytes = text.as_bytes().to_vec();
77            match simd_json::to_borrowed_value(&mut text_bytes) {
78                Ok(json) => {
79                    if let Some(msg_type) = Self::get_message_type(&json) {
80                        match msg_type {
81                            "subscriptions" => {
82                                log::debug!("Coinbase subscription confirmed");
83                            }
84                            "error" => {
85                                let error_msg = json
86                                    .get("message")
87                                    .and_then(|m| m.as_str())
88                                    .unwrap_or("Unknown error");
89                                log::error!("Coinbase error: {error_msg}");
90                            }
91                            _ => {
92                                // Market data messages are handled by router
93                            }
94                        }
95                    }
96                }
97                Err(e) => {
98                    log::warn!("Failed to parse Coinbase message: {e}");
99                }
100            }
101        }
102        Ok(())
103    }
104
105    async fn on_connected(&mut self) -> WebSocketResult<()> {
106        log::info!("Connected to Coinbase WebSocket");
107        Ok(())
108    }
109
110    async fn on_disconnected(&mut self) -> WebSocketResult<()> {
111        log::info!("Disconnected from Coinbase WebSocket");
112        self.authenticated = false;
113        Ok(())
114    }
115}
116
117#[async_trait]
118impl ExchangeHandler for CoinbaseWebSocketHandler {
119    fn exchange(&self) -> Exchange {
120        Exchange::Coinbase
121    }
122
123    async fn get_subscriptions(&self) -> WebSocketResult<Vec<Message>> {
124        // Return current subscriptions as messages
125        let subs = self.subscriptions.read();
126        let mut messages = Vec::new();
127
128        for (_, info) in subs.iter() {
129            let product_ids: Vec<std::string::String> =
130                info.product_ids.iter().map(|s| s.to_string()).collect();
131            let channels: Vec<std::string::String> =
132                info.channels.iter().map(|s| s.to_string()).collect();
133            let sub_msg = rusty_common::websocket::exchanges::coinbase::create_subscription(
134                product_ids,
135                channels,
136            );
137            messages.push(Message::Text(sub_msg.to_string().into()));
138        }
139
140        Ok(messages)
141    }
142
143    async fn authenticate(&mut self) -> WebSocketResult<Option<Message>> {
144        // TODO: Implement authentication when needed
145        Ok(None)
146    }
147
148    fn requires_auth(&self) -> bool {
149        false // Public feeds don't require auth
150    }
151}
152
/// Coinbase message router
///
/// Stateless unit type that converts Coinbase market-data frames into routed payloads.
/// Being a zero-sized value, it is cheap to copy and debug-print.
#[derive(Debug, Clone, Copy)]
pub struct CoinbaseMessageRouter;
155
156impl CoinbaseMessageRouter {
157    /// Create a new router
158    pub const fn new() -> Self {
159        Self
160    }
161
162    /// Parse trade message
163    fn parse_trade(
164        json: &simd_json::BorrowedValue,
165        product_id: &str,
166    ) -> WebSocketResult<CoinbaseTradeMessage> {
167        use rust_decimal::Decimal;
168
169        let trade_id = json.get("trade_id").and_then(|v| v.as_u64()).unwrap_or(0);
170
171        let size = json
172            .get("last_size")
173            .and_then(|v| v.as_str())
174            .and_then(|s| s.parse::<Decimal>().ok())
175            .unwrap_or_default();
176
177        let price = json
178            .get("price")
179            .and_then(|v| v.as_str())
180            .and_then(|s| s.parse::<Decimal>().ok())
181            .unwrap_or_default();
182
183        let sequence = json.get("sequence").and_then(|v| v.as_u64()).unwrap_or(0);
184
185        let msg = CoinbaseTradeMessage {
186            message_type: json
187                .get("type")
188                .and_then(|v| v.as_str())
189                .unwrap_or("")
190                .into(),
191            trade_id,
192            maker_order_id: json
193                .get("maker_order_id")
194                .and_then(|v| v.as_str())
195                .unwrap_or("")
196                .into(),
197            taker_order_id: json
198                .get("taker_order_id")
199                .and_then(|v| v.as_str())
200                .unwrap_or("")
201                .into(),
202            side: json
203                .get("side")
204                .and_then(|v| v.as_str())
205                .unwrap_or("")
206                .into(),
207            size,
208            price,
209            product_id: product_id.into(),
210            sequence,
211            time: json
212                .get("time")
213                .and_then(|v| v.as_str())
214                .unwrap_or("")
215                .into(),
216        };
217        Ok(msg)
218    }
219
220    /// Parse depth update
221    fn parse_depth(
222        json: &simd_json::BorrowedValue,
223        product_id: &str,
224    ) -> WebSocketResult<CoinbaseDepthMessage> {
225        use smallvec::SmallVec;
226
227        let changes = json
228            .get("changes")
229            .and_then(|v| v.as_array())
230            .map(|arr| {
231                let mut result = SmallVec::<[Level2Change; 32]>::new();
232                for change in arr {
233                    if let Some(change_array) = change.as_array()
234                        && change_array.len() >= 3
235                    {
236                        let side = change_array[0].as_str().unwrap_or("").into();
237                        let price = change_array[1].as_str().unwrap_or("0").into();
238                        let size = change_array[2].as_str().unwrap_or("0").into();
239                        result.push([side, price, size]);
240                    }
241                }
242                result
243            })
244            .unwrap_or_default();
245
246        let msg = CoinbaseDepthMessage {
247            message_type: json
248                .get("type")
249                .and_then(|v| v.as_str())
250                .unwrap_or("")
251                .into(),
252            product_id: product_id.into(),
253            time: json
254                .get("time")
255                .and_then(|v| v.as_str())
256                .unwrap_or("")
257                .into(),
258            changes,
259        };
260        Ok(msg)
261    }
262}
263
264impl Default for CoinbaseMessageRouter {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270impl MessageRouter for CoinbaseMessageRouter {
271    fn route_message(&self, message: &Message) -> WebSocketResult<Option<RoutedMessage>> {
272        match message {
273            Message::Text(text) => {
274                // Parse JSON
275                let mut text_bytes = text.as_bytes().to_vec();
276                match simd_json::to_borrowed_value(&mut text_bytes) {
277                    Ok(json) => {
278                        // Get message type and product ID
279                        if let (Some(msg_type), Some(product_id)) = (
280                            CoinbaseWebSocketHandler::get_message_type(&json),
281                            CoinbaseWebSocketHandler::get_product_id(&json),
282                        ) {
283                            match msg_type {
284                                "ticker" => {
285                                    // Parse as trade message
286                                    match Self::parse_trade(&json, product_id) {
287                                        Ok(trade) => {
288                                            let key = self.get_subscription_key(
289                                                product_id,
290                                                SubscriptionType::Trades,
291                                            );
292                                            Ok(Some(RoutedMessage {
293                                                subscription_key: key,
294                                                subscription_type: SubscriptionType::Trades,
295                                                payload: Box::new(trade),
296                                            }))
297                                        }
298                                        Err(e) => {
299                                            log::warn!("Failed to parse trade: {e}");
300                                            Ok(None)
301                                        }
302                                    }
303                                }
304                                "l2update" => {
305                                    // Parse as depth message
306                                    match Self::parse_depth(&json, product_id) {
307                                        Ok(depth) => {
308                                            let key = self.get_subscription_key(
309                                                product_id,
310                                                SubscriptionType::OrderBook,
311                                            );
312                                            Ok(Some(RoutedMessage {
313                                                subscription_key: key,
314                                                subscription_type: SubscriptionType::OrderBook,
315                                                payload: Box::new(depth),
316                                            }))
317                                        }
318                                        Err(e) => {
319                                            log::warn!("Failed to parse depth: {e}");
320                                            Ok(None)
321                                        }
322                                    }
323                                }
324                                _ => Ok(None), // Other message types not routed
325                            }
326                        } else {
327                            Ok(None) // System messages
328                        }
329                    }
330                    Err(e) => {
331                        log::warn!("Failed to parse JSON: {e}");
332                        Ok(None)
333                    }
334                }
335            }
336            _ => Ok(None), // Binary, ping/pong not routed
337        }
338    }
339}
340
// Unit tests live in `websocket_handler_test.rs` next to this file and are only
// compiled for test builds.
#[cfg(test)]
#[path = "websocket_handler_test.rs"]
mod websocket_handler_test;