1use rusty_common::collections::FxHashMap;
7use smartstring::alias::String;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use super::data::{
12 orderbook::Level2Update, subscription::SubscriptionMessage, trade::TradeMessage,
13};
14use super::types::{COINBASE_API_URL, COINBASE_RATE_LIMITS, COINBASE_WS_ADVANCED_URL};
15use anyhow::{Context, Result, anyhow};
16use async_trait::async_trait;
17use parking_lot::RwLock;
18use quanta::Clock;
19use reqwest::header::{HeaderMap, HeaderValue};
20use rusty_common::json::Value;
21use rusty_common::websocket::{Message, WebSocketClient, WebSocketConfig, WebSocketError};
22use rusty_model::{
23 data::{
24 book_snapshot::OrderBookSnapshot, market_trade::MarketTrade,
25 simd_orderbook::SharedSimdOrderBook,
26 },
27 instruments::{Instrument, InstrumentId},
28 venues::Venue,
29};
30use simd_json::prelude::{ValueAsArray, ValueAsScalar};
31use smallvec::SmallVec;
32use tokio::sync::{mpsc, watch};
33use tokio::task::JoinHandle;
34
35use crate::provider::prelude::*;
36
37#[derive(Debug, Clone)]
39pub struct CoinbaseInstrument {
40 pub id: InstrumentId,
42}
43
44impl Instrument for CoinbaseInstrument {
45 fn id(&self) -> InstrumentId {
46 self.id.clone()
47 }
48
49 fn symbol(&self) -> String {
50 self.id.symbol.clone()
51 }
52
53 fn venue(&self) -> Venue {
54 self.id.venue
55 }
56
57 fn as_any(&self) -> &dyn std::any::Any {
58 self
59 }
60
61 fn clone_box(&self) -> Box<dyn Instrument> {
62 Box::new(self.clone())
63 }
64}
65
66#[derive(Debug)]
71pub struct CoinbaseProvider {
72 config: ConnectionConfig,
74
75 connection_status: Arc<RwLock<ConnectionState>>,
77
78 stats: Arc<RwLock<ConnectionStats>>,
80
81 clock: Clock,
83
84 subscriptions: Arc<RwLock<FxHashMap<String, watch::Sender<bool>>>>,
86
87 ws_handles: Arc<RwLock<FxHashMap<String, JoinHandle<()>>>>,
89
90 last_connection_attempt: Arc<RwLock<Instant>>,
92
93 http_client: reqwest::Client,
95}
96
97impl Default for CoinbaseProvider {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl CoinbaseProvider {
104 #[inline]
106 #[must_use]
107 pub fn new() -> Self {
108 Self::with_config(None)
109 }
110
111 #[inline]
113 #[must_use]
114 pub fn with_config(config: Option<ConnectionConfig>) -> Self {
115 let mut default_config = ConnectionConfig::default();
116
117 default_config.websocket_config.base_url = COINBASE_WS_ADVANCED_URL.into();
119 default_config.websocket_config.ping_interval_milliseconds = 30000; default_config.websocket_config.use_compression = false;
121
122 default_config.rest_config.base_url = COINBASE_API_URL.into();
124 default_config.rest_config.timeout_milliseconds = 5000; let config = config.unwrap_or(default_config);
128 let clock = config.clock.clone();
129
130 let mut headers = HeaderMap::new();
132 headers.insert(
133 "User-Agent",
134 HeaderValue::from_str(&config.rest_config.user_agent)
135 .unwrap_or_else(|_| HeaderValue::from_static("RustyHFT/1.0")),
136 );
137
138 let http_client = reqwest::Client::builder()
139 .timeout(Duration::from_millis(
140 config.rest_config.timeout_milliseconds,
141 ))
142 .connect_timeout(Duration::from_millis(
143 config.rest_config.timeout_milliseconds / 2,
144 ))
145 .pool_max_idle_per_host(config.rest_config.connection_pool_size)
146 .pool_idle_timeout(Duration::from_millis(
147 config.rest_config.keep_alive_milliseconds,
148 ))
149 .default_headers(headers)
150 .build()
151 .unwrap_or_default();
152
153 Self {
154 config,
155 connection_status: Arc::new(RwLock::new(ConnectionState::Disconnected)),
156 stats: Arc::new(RwLock::new(ConnectionStats::default())),
157 clock,
158 subscriptions: Arc::new(RwLock::new(FxHashMap::default())),
159 ws_handles: Arc::new(RwLock::new(FxHashMap::default())),
160 last_connection_attempt: Arc::new(RwLock::new(Instant::now())),
161 http_client,
162 }
163 }
164
165 #[inline]
167 fn update_receive_stats(
168 stats: Arc<RwLock<ConnectionStats>>,
169 message_size: usize,
170 local_time: u64,
171 ) {
172 let mut s = stats.write();
173 s.messages_received += 1;
174 s.bytes_received += message_size as u64;
175 s.last_message_time = local_time;
176 }
177
178 fn create_websocket_config(
180 url: String,
181 use_compression: bool,
182 ping_interval_milliseconds: u64,
183 timeout_milliseconds: u64,
184 ) -> WebSocketConfig {
185 WebSocketConfig::builder(rusty_common::types::Exchange::Coinbase, url.to_string())
186 .ping_interval(Duration::from_millis(ping_interval_milliseconds))
187 .timeout(Duration::from_millis(timeout_milliseconds))
188 .compression(if use_compression {
189 rusty_common::websocket::CompressionConfig::default()
190 } else {
191 rusty_common::websocket::CompressionConfig::disabled()
192 })
193 .build()
194 }
195
196 #[inline]
198 async fn connect(&self) -> Result<()> {
199 if *self.connection_status.read() == ConnectionState::Connected {
200 return Ok(());
201 }
202
203 {
205 let now = Instant::now();
206 let last_attempt = *self.last_connection_attempt.read();
207 let backoff_ms = self
208 .config
209 .websocket_config
210 .reconnect
211 .backoff_initial_milliseconds;
212
213 if now.duration_since(last_attempt) < Duration::from_millis(backoff_ms) {
214 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
216 }
217
218 *self.last_connection_attempt.write() = Instant::now();
220 }
221
222 *self.connection_status.write() = ConnectionState::Connecting;
223
224 *self.connection_status.write() = ConnectionState::Connected;
227 self.stats.write().connected_time = self.clock.raw();
228 Ok(())
229 }
230}
231
232#[async_trait]
233impl Provider for CoinbaseProvider {
234 type TradeMessage = TradeMessage;
235 type DepthMessage = Level2Update;
236 type InstrumentMessage = Value;
237
238 fn name(&self) -> &'static str {
239 "Coinbase"
240 }
241
242 fn venue(&self) -> Venue {
243 Venue::Coinbase
244 }
245
246 fn config(&self) -> &ConnectionConfig {
247 &self.config
248 }
249
250 async fn init(&mut self) -> Result<()> {
251 self.connect().await
252 }
253
254 async fn shutdown(&mut self) -> Result<()> {
255 let mut keys_to_remove = Vec::new();
257
258 for (key, _) in self.subscriptions.read().iter() {
259 keys_to_remove.push(key.clone());
260 }
261
262 for key in keys_to_remove {
264 if let Some(tx) = self.subscriptions.write().remove(&key) {
265 let _ = tx.send(true);
266 }
267
268 if let Some(handle) = self.ws_handles.write().remove(&key) {
270 handle.abort();
271 }
272 }
273
274 *self.connection_status.write() = ConnectionState::Disconnected;
276
277 Ok(())
278 }
279
280 #[inline]
281 async fn subscribe_trades(
282 &self,
283 symbols: SmallVec<[String; 8]>,
284 ) -> Result<mpsc::Receiver<Self::TradeMessage>> {
285 self.connect().await?;
287
288 let (tx, rx) = mpsc::channel(1024);
290
291 let subscription = SubscriptionMessage::trades(symbols.to_vec());
293 let connection_id = format!("trade-{}", symbols.join(","));
294
295 let handle = self
297 .create_trade_connection(symbols, tx, subscription)
298 .await?;
299
300 self.ws_handles.write().insert(connection_id.into(), handle);
302
303 Ok(rx)
304 }
305
306 #[inline]
307 async fn unsubscribe_trades(&self) -> Result<()> {
308 let mut keys_to_remove = Vec::new();
310
311 for (key, _) in self.subscriptions.read().iter() {
312 if key.starts_with("trade-") {
313 keys_to_remove.push(key.clone());
314 }
315 }
316
317 for key in keys_to_remove {
319 if let Some(tx) = self.subscriptions.write().remove(&key) {
320 let _ = tx.send(true);
321 }
322
323 if let Some(handle) = self.ws_handles.write().remove(&key) {
325 handle.abort();
326 }
327 }
328
329 Ok(())
330 }
331
332 #[inline]
333 async fn subscribe_orderbook(
334 &self,
335 symbols: SmallVec<[String; 8]>,
336 ) -> Result<mpsc::Receiver<Self::DepthMessage>> {
337 self.connect().await?;
339
340 let (tx, rx) = mpsc::channel(1024);
342
343 let subscription = SubscriptionMessage::orderbook(symbols.to_vec());
345 let connection_id = format!("orderbook-{}", symbols.join(","));
346
347 let handle = self
349 .create_orderbook_connection(symbols, tx, subscription)
350 .await?;
351
352 self.ws_handles.write().insert(connection_id.into(), handle);
354
355 Ok(rx)
356 }
357
358 #[inline]
359 async fn unsubscribe_orderbook(&self) -> Result<()> {
360 let mut keys_to_remove = Vec::new();
362
363 for (key, _) in self.subscriptions.read().iter() {
364 if key.starts_with("orderbook-") {
365 keys_to_remove.push(key.clone());
366 }
367 }
368
369 for key in keys_to_remove {
371 if let Some(tx) = self.subscriptions.write().remove(&key) {
372 let _ = tx.send(true);
373 }
374
375 if let Some(handle) = self.ws_handles.write().remove(&key) {
377 handle.abort();
378 }
379 }
380
381 Ok(())
382 }
383
384 #[inline]
385 async fn get_realtime_orderbook(&self, _symbol: &str) -> Result<SharedSimdOrderBook> {
386 Err(anyhow!("Real-time orderbook not implemented yet"))
388 }
389
390 #[inline]
391 async fn get_instruments(&self) -> Result<Vec<Box<dyn Instrument>>> {
392 let url = format!("{}/products", self.config.rest_config.base_url);
394
395 let response = self
396 .http_client
397 .get(&url)
398 .send()
399 .await
400 .context("Failed to fetch Coinbase products info")?;
401
402 if !response.status().is_success() {
404 return Err(anyhow!(
405 "Failed to fetch Coinbase products info: HTTP {}",
406 response.status()
407 ));
408 }
409
410 let bytes = response
412 .bytes()
413 .await
414 .context("Failed to get response bytes")?;
415 let mut bytes_vec = bytes.to_vec();
416 let products_info: simd_json::OwnedValue = simd_json::from_slice(&mut bytes_vec)
417 .context("Failed to parse Coinbase products info response")?;
418
419 let products = products_info
421 .as_array()
422 .ok_or_else(|| anyhow!("Invalid response format: expected array of products"))?;
423
424 let mut instruments = Vec::with_capacity(products.len());
426
427 for product_data in products {
428 let status = product_data["status"].as_str().unwrap_or("");
430 if status != "online" {
431 continue;
432 }
433
434 let product_id = product_data["product_id"].as_str().ok_or_else(|| {
435 anyhow!("Invalid product data: 'product_id' field not found or not a String")
436 })?;
437
438 let instrument_id = InstrumentId::new(product_id, Venue::Coinbase);
440 let instrument =
441 Box::new(CoinbaseInstrument { id: instrument_id }) as Box<dyn Instrument>;
442 instruments.push(instrument);
443 }
444
445 Ok(instruments)
446 }
447
448 #[inline]
449 async fn get_historical_trades(
450 &self,
451 _symbol: &str,
452 _limit: Option<u32>,
453 ) -> Result<Vec<MarketTrade>> {
454 Err(anyhow!("Historical trades not implemented yet"))
457 }
458
459 #[inline]
460 async fn get_orderbook_snapshot(
461 &self,
462 symbol: &str,
463 depth: Option<u32>,
464 ) -> Result<OrderBookSnapshot> {
465 let _limit = depth.unwrap_or(50);
466 let url = format!(
467 "{}/products/{}/book?level=2",
468 self.config.rest_config.base_url, symbol
469 );
470
471 let response = self.http_client.get(&url).send().await?;
472
473 if !response.status().is_success() {
474 let error_text = response.text().await?;
475 return Err(anyhow!(
476 "Failed to fetch orderbook snapshot: {}",
477 error_text
478 ));
479 }
480
481 let bytes = response
483 .bytes()
484 .await
485 .map_err(|e| anyhow!("Failed to get response bytes: {}", e))?;
486 let mut bytes_vec = bytes.to_vec();
487 let snapshot_data: simd_json::OwnedValue = simd_json::from_slice(&mut bytes_vec)
488 .map_err(|e| anyhow!("Failed to parse JSON: {}", e))?;
489
490 let timestamp = self.clock.raw();
492 let instrument_id = InstrumentId::new(symbol, Venue::Coinbase);
493
494 let sequence = snapshot_data["sequence"].as_u64().unwrap_or(0);
495 let mut order_book_snapshot =
496 OrderBookSnapshot::new_empty(instrument_id, timestamp, sequence);
497
498 if let Some(bids) = snapshot_data["bids"].as_array() {
500 for bid in bids {
501 if let Some(bid_array) = bid.as_array()
502 && bid_array.len() >= 2
503 && let (Some(price_str), Some(quantity_str)) =
504 (bid_array[0].as_str(), bid_array[1].as_str())
505 && let (Ok(price), Ok(quantity)) = (
506 price_str.parse::<rust_decimal::Decimal>(),
507 quantity_str.parse::<rust_decimal::Decimal>(),
508 )
509 {
510 order_book_snapshot.add_bid(price, quantity);
511 }
512 }
513 }
514
515 if let Some(asks) = snapshot_data["asks"].as_array() {
517 for ask in asks {
518 if let Some(ask_array) = ask.as_array()
519 && ask_array.len() >= 2
520 && let (Some(price_str), Some(quantity_str)) =
521 (ask_array[0].as_str(), ask_array[1].as_str())
522 && let (Ok(price), Ok(quantity)) = (
523 price_str.parse::<rust_decimal::Decimal>(),
524 quantity_str.parse::<rust_decimal::Decimal>(),
525 )
526 {
527 order_book_snapshot.add_ask(price, quantity);
528 }
529 }
530 }
531
532 Ok(order_book_snapshot)
533 }
534
535 #[inline]
536 async fn is_connected(&self) -> bool {
537 *self.connection_status.read() == ConnectionState::Connected
538 }
539
540 #[inline]
541 async fn connection_status(&self) -> ConnectionState {
542 *self.connection_status.read()
543 }
544
545 #[inline]
546 async fn get_stats(&self) -> ConnectionStats {
547 self.stats.read().clone()
548 }
549
550 #[inline]
551 async fn ping(&self) -> Result<u64> {
552 let start = self.clock.raw();
555
556 let response = self
557 .http_client
558 .get(format!("{}/time", self.config.rest_config.base_url))
559 .send()
560 .await?;
561
562 if response.status().is_success() {
563 Ok(self.clock.raw().saturating_sub(start))
564 } else {
565 Err(anyhow!("Ping failed with status: {}", response.status()))
566 }
567 }
568
569 #[inline]
570 async fn reset_connection(&self) -> Result<()> {
571 let mut keys_to_remove = Vec::new();
573
574 for (key, _) in self.subscriptions.read().iter() {
575 keys_to_remove.push(key.clone());
576 }
577
578 for key in &keys_to_remove {
580 if let Some(tx) = self.subscriptions.write().remove(key) {
581 let _ = tx.send(true);
582 }
583
584 if let Some(handle) = self.ws_handles.write().remove(key) {
586 handle.abort();
587 }
588 }
589
590 *self.connection_status.write() = ConnectionState::Disconnected;
592
593 *self.stats.write() = ConnectionStats::default();
595
596 self.connect().await
598 }
599
600 #[inline]
601 fn get_rate_limits(&self) -> Vec<RateLimit> {
602 COINBASE_RATE_LIMITS.to_vec()
603 }
604}
605
606impl CoinbaseProvider {
607 #[inline]
609 async fn create_trade_connection(
610 &self,
611 symbols: SmallVec<[String; 8]>,
612 tx: mpsc::Sender<TradeMessage>,
613 subscription: SubscriptionMessage,
614 ) -> Result<JoinHandle<()>> {
615 let connection_id = format!("trade-{}", symbols.join(","));
616
617 let (stop_tx, stop_rx) = watch::channel(false);
619 self.subscriptions
620 .write()
621 .insert(connection_id.clone().into(), stop_tx);
622
623 let url = self.config.websocket_config.base_url.clone();
625 let clock = self.clock.clone();
626 let timeout_milliseconds = self.config.websocket_config.timeout_milliseconds;
627 let ping_interval_milliseconds = self.config.websocket_config.ping_interval_milliseconds;
628 let stats = self.stats.clone();
629 let connection_status = self.connection_status.clone();
630 let use_compression = self.config.websocket_config.use_compression;
631
632 let handle = tokio::spawn(async move {
634 loop {
635 if *stop_rx.borrow() {
637 break;
638 }
639
640 let ws_config = Self::create_websocket_config(
642 url.clone(),
643 use_compression,
644 ping_interval_milliseconds,
645 timeout_milliseconds,
646 );
647
648 let mut client = WebSocketClient::new(ws_config);
650
651 *connection_status.write() = ConnectionState::Connecting;
653
654 let handler = CoinbaseMessageHandler::new_trade_handler(
656 clock.clone(),
657 stats.clone(),
658 tx.clone(),
659 subscription.clone(),
660 );
661
662 if let Err(e) = client.run(handler).await {
664 log::error!("WebSocket client error: {e}");
665 *connection_status.write() = ConnectionState::Error;
666 }
667
668 if *stop_rx.borrow() {
670 break;
671 }
672
673 tokio::time::sleep(Duration::from_millis(1000)).await;
675 }
676 });
677
678 Ok(handle)
679 }
680
681 #[inline]
683 async fn create_orderbook_connection(
684 &self,
685 symbols: SmallVec<[String; 8]>,
686 tx: mpsc::Sender<Level2Update>,
687 subscription: SubscriptionMessage,
688 ) -> Result<JoinHandle<()>> {
689 let connection_id = format!("orderbook-{}", symbols.join(","));
690
691 let (stop_tx, stop_rx) = watch::channel(false);
693 self.subscriptions
694 .write()
695 .insert(connection_id.clone().into(), stop_tx);
696
697 let url = self.config.websocket_config.base_url.clone();
699 let clock = self.clock.clone();
700 let timeout_milliseconds = self.config.websocket_config.timeout_milliseconds;
701 let ping_interval_milliseconds = self.config.websocket_config.ping_interval_milliseconds;
702 let stats = self.stats.clone();
703 let connection_status = self.connection_status.clone();
704 let use_compression = self.config.websocket_config.use_compression;
705
706 let handle = tokio::spawn(async move {
708 loop {
709 if *stop_rx.borrow() {
711 break;
712 }
713
714 let ws_config = Self::create_websocket_config(
716 url.clone(),
717 use_compression,
718 ping_interval_milliseconds,
719 timeout_milliseconds,
720 );
721
722 let mut client = WebSocketClient::new(ws_config);
724
725 *connection_status.write() = ConnectionState::Connecting;
727
728 let handler = CoinbaseMessageHandler::new_orderbook_handler(
730 clock.clone(),
731 stats.clone(),
732 tx.clone(),
733 subscription.clone(),
734 );
735
736 if let Err(e) = client.run(handler).await {
738 log::error!("WebSocket client error: {e}");
739 *connection_status.write() = ConnectionState::Error;
740 }
741
742 if *stop_rx.borrow() {
744 break;
745 }
746
747 tokio::time::sleep(Duration::from_millis(1000)).await;
749 }
750 });
751
752 Ok(handle)
753 }
754}
755
756#[derive(Debug)]
758struct CoinbaseMessageHandler {
759 handler_type: HandlerType,
760 clock: Clock,
761 stats: Arc<RwLock<ConnectionStats>>,
762 trade_tx: Option<mpsc::Sender<TradeMessage>>,
763 orderbook_tx: Option<mpsc::Sender<Level2Update>>,
764 subscription: SubscriptionMessage,
765}
766
767#[derive(Debug)]
768enum HandlerType {
769 Trade,
770 Orderbook,
771}
772
773impl CoinbaseMessageHandler {
774 const fn new_trade_handler(
775 clock: Clock,
776 stats: Arc<RwLock<ConnectionStats>>,
777 tx: mpsc::Sender<TradeMessage>,
778 subscription: SubscriptionMessage,
779 ) -> Self {
780 Self {
781 handler_type: HandlerType::Trade,
782 clock,
783 stats,
784 trade_tx: Some(tx),
785 orderbook_tx: None,
786 subscription,
787 }
788 }
789
790 const fn new_orderbook_handler(
791 clock: Clock,
792 stats: Arc<RwLock<ConnectionStats>>,
793 tx: mpsc::Sender<Level2Update>,
794 subscription: SubscriptionMessage,
795 ) -> Self {
796 Self {
797 handler_type: HandlerType::Orderbook,
798 clock,
799 stats,
800 trade_tx: None,
801 orderbook_tx: Some(tx),
802 subscription,
803 }
804 }
805}
806
807#[async_trait]
808impl rusty_common::websocket::MessageHandler for CoinbaseMessageHandler {
809 async fn on_message(&mut self, message: Message) -> std::result::Result<(), WebSocketError> {
810 let local_time = self.clock.raw();
811
812 if let Some(text) = message.as_text() {
814 CoinbaseProvider::update_receive_stats(self.stats.clone(), text.len(), local_time);
815
816 let mut message_bytes = text.as_bytes().to_vec();
818 let json_value: simd_json::OwnedValue = simd_json::from_slice(&mut message_bytes)
819 .map_err(|e| {
820 WebSocketError::MessageProcessingError(format!("Failed to parse JSON: {e}"))
821 })?;
822
823 if let Some(message_type) = json_value["type"].as_str() {
825 match self.handler_type {
826 HandlerType::Trade
827 if message_type == "match" || message_type == "last_match" =>
828 {
829 if let Ok(trade_response) =
830 simd_json::from_slice::<TradeMessage>(&mut message_bytes)
831 && let Some(ref tx) = self.trade_tx
832 && let Err(e) = tx.send(trade_response).await
833 {
834 log::warn!("Failed to send trade message: {e}");
835 }
836 }
837 HandlerType::Orderbook if message_type == "l2update" => {
838 if let Ok(orderbook_response) =
839 simd_json::from_slice::<Level2Update>(&mut message_bytes)
840 && let Some(ref tx) = self.orderbook_tx
841 && let Err(e) = tx.send(orderbook_response).await
842 {
843 log::warn!("Failed to send orderbook message: {e}");
844 }
845 }
846 _ => {
847 match message_type {
849 "subscriptions" => {
850 log::info!("Coinbase subscription confirmed: {text}");
851 }
852 "heartbeat" => {
853 log::trace!("Coinbase heartbeat received");
854 }
855 "error" => {
856 log::error!("Coinbase WebSocket error: {text}");
857 }
858 _ => {
859 log::debug!("Unhandled Coinbase message type: {message_type}");
860 }
861 }
862 }
863 }
864 }
865 }
866
867 Ok(())
868 }
869
870 async fn on_connected(&mut self) -> std::result::Result<(), WebSocketError> {
871 let subscription_json = simd_json::to_string(&self.subscription).map_err(|e| {
873 WebSocketError::MessageProcessingError(format!("Failed to serialize subscription: {e}"))
874 })?;
875
876 log::info!("Sending Coinbase subscription: {subscription_json}");
877
878 Ok(())
881 }
882
883 async fn on_disconnected(&mut self) -> std::result::Result<(), WebSocketError> {
884 log::info!("Disconnected from Coinbase WebSocket");
885 Ok(())
886 }
887
888 async fn on_error(&mut self, error: WebSocketError) -> std::result::Result<(), WebSocketError> {
889 log::error!("Coinbase WebSocket error: {error}");
890 Ok(())
891 }
892}