1use anyhow::{Result, bail};
12use async_trait::async_trait;
13use flume::Sender;
14use log::{debug, error, info};
15use parking_lot::RwLock;
16use rusty_common::auth::exchanges::coinbase::CoinbaseAuth;
17use rusty_common::utils::id_generation;
18use rusty_model::{
19 enums::OrderStatus, instruments::InstrumentId, trading_order::Order, venues::Venue,
20};
21use smallvec::SmallVec;
22use smartstring::alias::String as SmartString;
23use std::sync::Arc;
24
25use super::websocket_trading::CoinbaseWebsocketTrading as CoinbaseWebSocketTrader;
26use super::{
27 rest_client::{CoinbaseOrderRequest, CoinbaseRestClient},
28 };
30use crate::execution_engine::{Exchange as ExchangeTrait, ExecutionReport};
31use crate::instrument_registry::{InstrumentRegistry, OrderMetadata};
32use rust_decimal::Decimal;
33
/// Transport the unified client may use for order traffic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CoinbaseProtocol {
    /// Send orders through the REST client.
    Rest,
    /// Send orders through the WebSocket trading client.
    WebSocket,
    /// FIX transport.
    ///
    /// NOTE(review): no FIX session is created in this file; requests configured
    /// for FIX are currently served by the REST path.
    Fix,
    /// Let the client decide per order: WebSocket when it is connected,
    /// REST otherwise.
    Auto,
}
46
47#[derive(Debug, Clone)]
49pub struct UnifiedClientConfig {
50 pub order_protocol: CoinbaseProtocol,
52 pub enable_websocket: bool,
54 pub enable_fix: bool,
56 pub sandbox: bool,
58}
59
60impl Default for UnifiedClientConfig {
61 fn default() -> Self {
62 Self {
63 order_protocol: CoinbaseProtocol::Auto,
64 enable_websocket: true,
65 enable_fix: false,
66 sandbox: false,
67 }
68 }
69}
70
71pub struct CoinbaseUnifiedClient {
76 auth: Arc<CoinbaseAuth>,
78
79 rest_client: Arc<CoinbaseRestClient>,
81
82 ws_client: Option<Arc<CoinbaseWebSocketTrader>>,
84
85 config: UnifiedClientConfig,
90
91 state: Arc<RwLock<ConnectionState>>,
93
94 instrument_registry: Arc<dyn InstrumentRegistry>,
96}
97
/// Connection flags for each protocol; every flag starts out `false`.
#[derive(Default, Debug)]
struct ConnectionState {
    /// Set when the client is connected, cleared on disconnect.
    rest_connected: bool,
    /// Set after the WebSocket client connects successfully.
    ws_connected: bool,
    /// Never written in this file; stays `false` unless set elsewhere.
    fix_connected: bool,
}
104
105impl CoinbaseUnifiedClient {
106 pub fn new(
108 auth: Arc<CoinbaseAuth>,
109 config: UnifiedClientConfig,
110 instrument_registry: Arc<dyn InstrumentRegistry>,
111 ) -> Result<Self> {
112 let rest_client = Arc::new(
114 CoinbaseRestClient::new(auth.clone(), config.sandbox)
115 .map_err(|e| anyhow::anyhow!("Failed to create REST client: {}", e))?,
116 );
117
118 let ws_client = if config.enable_websocket {
120 let client = CoinbaseWebSocketTrader::new(
121 auth.clone(),
122 config.sandbox,
123 false, );
125 Some(Arc::new(client))
126 } else {
127 None
128 };
129
130 Ok(Self {
138 auth,
139 rest_client,
140 ws_client,
141 config,
143 state: Arc::new(RwLock::new(ConnectionState::default())),
144 instrument_registry,
145 })
146 }
147
148 pub async fn connect_with_report_sender(
150 &self,
151 report_tx: Sender<ExecutionReport>,
152 ) -> Result<()> {
153 info!(
154 "Connecting Coinbase unified client with protocols: REST={}, WS={}, FIX={}",
155 true, self.config.enable_websocket, self.config.enable_fix
156 );
157
158 self.state.write().rest_connected = true;
160
161 if let Some(ws_client) = &self.ws_client {
163 match ws_client.connect(report_tx.clone()).await {
164 Ok(()) => {
165 self.state.write().ws_connected = true;
166 info!("WebSocket connected successfully");
167 }
168 Err(e) => {
169 error!("WebSocket connection failed: {e}");
170 }
172 }
173 }
174
175 Ok(())
190 }
191
192 pub async fn disconnect(&self) -> Result<()> {
194 info!("Disconnecting Coinbase unified client");
195
196 if let Some(ws_client) = &self.ws_client {
198 if let Err(e) = ws_client.disconnect().await {
199 error!("WebSocket disconnect error: {e}");
200 }
201 self.state.write().ws_connected = false;
202 }
203
204 self.state.write().rest_connected = false;
213
214 Ok(())
215 }
216
217 fn select_protocol(&self, order: &Order) -> CoinbaseProtocol {
219 match self.config.order_protocol {
220 CoinbaseProtocol::Auto => {
221 let state = self.state.read();
222
223 if state.ws_connected && self.ws_client.is_some() {
228 CoinbaseProtocol::WebSocket
229 } else {
230 CoinbaseProtocol::Rest
231 }
232 }
233 protocol => protocol,
234 }
235 }
236
237 #[must_use]
239 pub fn get_connection_status(&self) -> (bool, bool, bool) {
240 let state = self.state.read();
241 (
242 state.rest_connected,
243 state.ws_connected,
244 state.fix_connected,
245 )
246 }
247
248 #[must_use]
250 pub fn is_connected(&self) -> bool {
251 let state = self.state.read();
252 state.rest_connected || state.ws_connected || state.fix_connected
253 }
254}
255
256impl CoinbaseUnifiedClient {
257 fn order_to_coinbase_request(order: &Order) -> Result<CoinbaseOrderRequest> {
259 use super::rest_client::{
260 CoinbaseLimitOrder, CoinbaseMarketOrder, CoinbaseOrderConfiguration,
261 CoinbaseOrderRequest, CoinbaseRestClient, CoinbaseStopLimitOrder,
262 };
263 use rusty_model::enums::OrderType;
264
265 let client_order_id = order.id.to_string().into();
266 let product_id = order.symbol.clone();
267 let side = CoinbaseRestClient::order_side_to_string(order.side);
268
269 let order_configuration = match order.order_type {
270 OrderType::Market => CoinbaseOrderConfiguration::MarketIoc {
271 market_market_ioc: CoinbaseMarketOrder {
272 base_size: Some(order.quantity.to_string().into()),
273 quote_size: None,
274 },
275 },
276 OrderType::Limit => {
277 let limit_price = order
278 .price
279 .ok_or_else(|| anyhow::anyhow!("Limit order requires price"))?
280 .to_string()
281 .into();
282
283 CoinbaseOrderConfiguration::LimitGtc {
284 limit_limit_gtc: CoinbaseLimitOrder {
285 base_size: order.quantity.to_string().into(),
286 limit_price,
287 post_only: None,
288 },
289 }
290 }
291 OrderType::StopLimit => {
292 let limit_price = order
293 .price
294 .ok_or_else(|| anyhow::anyhow!("Stop limit order requires price"))?
295 .to_string()
296 .into();
297 let stop_price = order
298 .stop_price
299 .ok_or_else(|| anyhow::anyhow!("Stop limit order requires stop price"))?
300 .to_string()
301 .into();
302
303 let stop_direction = if order.side == rusty_model::enums::OrderSide::Buy {
305 "STOP_DIRECTION_STOP_UP".into()
306 } else {
307 "STOP_DIRECTION_STOP_DOWN".into()
308 };
309
310 CoinbaseOrderConfiguration::StopLimitGtc {
311 stop_limit_stop_limit_gtc: CoinbaseStopLimitOrder {
312 base_size: order.quantity.to_string().into(),
313 limit_price,
314 stop_price,
315 stop_direction,
316 },
317 }
318 }
319 _ => bail!("Unsupported order type: {:?}", order.order_type),
320 };
321
322 Ok(CoinbaseOrderRequest {
323 client_order_id,
324 product_id,
325 side,
326 order_configuration,
327 })
328 }
329
330 async fn place_order_internal(
332 &self,
333 order: Order,
334 report_tx: Sender<ExecutionReport>,
335 ) -> Result<()> {
336 let coinbase_request = Self::order_to_coinbase_request(&order)?;
338
339 let order_id = order.id;
341 let instrument_id = InstrumentId::new(order.symbol.clone(), order.venue);
342 let quantity = order.quantity;
343
344 match self.rest_client.place_order(&coinbase_request).await {
346 Ok(response) => {
347 if response.success {
348 let metadata = OrderMetadata {
350 size: Some(quantity),
351 order_type: None, side: Some(coinbase_request.side.clone()),
353 };
354 self.instrument_registry.cache_order_mapping(
355 &response.order_id,
356 instrument_id.clone(),
357 Some(metadata),
358 );
359
360 let report = ExecutionReport {
362 id: id_generation::generate_ack_id(&order_id.to_string()),
363 order_id: response.order_id.clone(),
364 exchange_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
365 system_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
366 instrument_id,
367 status: OrderStatus::New,
368 filled_quantity: Decimal::ZERO,
369 remaining_quantity: quantity,
370 execution_price: None,
371 reject_reason: None,
372 exchange_execution_id: Some(response.order_id),
373 is_final: false,
374 };
375
376 if let Err(e) = report_tx.send_async(report).await {
377 error!("Failed to send acceptance report: {e}");
378 }
379 } else {
380 let report = ExecutionReport {
382 id: id_generation::generate_rejection_id(&order_id.to_string()),
383 order_id: coinbase_request.client_order_id.clone(),
384 exchange_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
385 system_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
386 instrument_id,
387 status: OrderStatus::Rejected,
388 filled_quantity: Decimal::ZERO,
389 remaining_quantity: quantity,
390 execution_price: None,
391 reject_reason: Some(response.failure_reason),
392 exchange_execution_id: None,
393 is_final: true,
394 };
395
396 if let Err(e) = report_tx.send_async(report).await {
397 error!("Failed to send rejection report: {e}");
398 }
399 }
400 Ok(())
401 }
402 Err(e) => {
403 let report = ExecutionReport {
405 id: id_generation::generate_report_id("error", &order_id.to_string()),
406 order_id: coinbase_request.client_order_id,
407 exchange_timestamp: 0,
408 system_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
409 instrument_id,
410 status: OrderStatus::Rejected,
411 filled_quantity: Decimal::ZERO,
412 remaining_quantity: quantity,
413 execution_price: None,
414 reject_reason: Some(e.to_string().into()),
415 exchange_execution_id: None,
416 is_final: true,
417 };
418
419 if let Err(send_err) = report_tx.send_async(report).await {
420 error!("Failed to send error report: {send_err}");
421 }
422
423 Err(e)
424 }
425 }
426 }
427
428 async fn cancel_order_internal(
430 &self,
431 order_id: SmartString,
432 report_tx: Sender<ExecutionReport>,
433 ) -> Result<()> {
434 match self.rest_client.cancel_order(&order_id).await {
435 Ok(response) => {
436 if response.success {
437 let instrument_id = self
439 .instrument_registry
440 .lookup_by_order_id(&order_id)
441 .unwrap_or_else(|| {
442 let normalized_symbol = self
444 .instrument_registry
445 .normalize_symbol(&order_id, Venue::Coinbase);
446 InstrumentId::new(normalized_symbol, Venue::Coinbase)
447 });
448
449 self.instrument_registry.remove_mapping(&order_id);
451
452 let report = ExecutionReport {
454 id: id_generation::generate_cancel_id(&order_id),
455 order_id: order_id.clone(),
456 exchange_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
457 system_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
458 instrument_id,
459 status: OrderStatus::Cancelled,
460 filled_quantity: Decimal::ZERO,
461 remaining_quantity: Decimal::ZERO,
462 execution_price: None,
463 reject_reason: None,
464 exchange_execution_id: None,
465 is_final: true,
466 };
467
468 if let Err(e) = report_tx.send_async(report).await {
469 error!("Failed to send cancellation report: {e}");
470 }
471 } else {
472 let instrument_id = self
474 .instrument_registry
475 .lookup_by_order_id(&order_id)
476 .unwrap_or_else(|| {
477 let normalized_symbol = self
479 .instrument_registry
480 .normalize_symbol(&order_id, Venue::Coinbase);
481 InstrumentId::new(normalized_symbol, Venue::Coinbase)
482 });
483
484 let report = ExecutionReport {
486 id: id_generation::generate_report_id("cancel_fail", order_id.as_ref()),
487 order_id: order_id.clone(),
488 exchange_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
489 system_timestamp: rusty_common::time::get_epoch_timestamp_ns(),
490 instrument_id,
491 status: OrderStatus::Rejected,
492 filled_quantity: Decimal::ZERO,
493 remaining_quantity: Decimal::ZERO,
494 execution_price: None,
495 reject_reason: Some(response.failure_reason),
496 exchange_execution_id: None,
497 is_final: false,
498 };
499
500 if let Err(e) = report_tx.send_async(report).await {
501 error!("Failed to send cancel failure report: {e}");
502 }
503 }
504 Ok(())
505 }
506 Err(e) => Err(e),
507 }
508 }
509}
510
511#[async_trait]
512impl ExchangeTrait for CoinbaseUnifiedClient {
513 async fn place_order(&self, order: Order, report_tx: Sender<ExecutionReport>) -> Result<()> {
514 let protocol = self.select_protocol(&order);
515 debug!("Placing order {} via {:?} protocol", order.id, protocol);
516
517 match protocol {
518 CoinbaseProtocol::Rest => self.place_order_internal(order, report_tx).await,
519 CoinbaseProtocol::WebSocket => {
520 if let Some(ws_client) = &self.ws_client {
522 ws_client.place_order(order, report_tx).await
523 } else {
524 self.place_order_internal(order, report_tx).await
526 }
527 }
528 CoinbaseProtocol::Fix => {
529 self.place_order_internal(order, report_tx).await
534 }
536 CoinbaseProtocol::Auto => {
537 self.place_order_internal(order, report_tx).await
539 }
540 }
541 }
542
543 async fn cancel_order(
544 &self,
545 order_id: SmartString,
546 report_tx: Sender<ExecutionReport>,
547 ) -> Result<()> {
548 let protocol = self.config.order_protocol;
550 debug!("Cancelling order {} via {:?} protocol", order_id, protocol);
551
552 match protocol {
553 CoinbaseProtocol::Rest | CoinbaseProtocol::Auto => {
554 self.cancel_order_internal(order_id, report_tx).await
555 }
556 CoinbaseProtocol::WebSocket => {
557 if let Some(ws_client) = &self.ws_client {
558 ws_client.cancel_order(order_id, report_tx).await
559 } else {
560 self.cancel_order_internal(order_id, report_tx).await
561 }
562 }
563 CoinbaseProtocol::Fix => {
564 self.cancel_order_internal(order_id, report_tx).await
568 }
570 }
571 }
572
573 async fn modify_order(
574 &self,
575 order_id: SmartString,
576 new_quantity: Option<rust_decimal::Decimal>,
577 new_price: Option<rust_decimal::Decimal>,
578 report_tx: Sender<ExecutionReport>,
579 ) -> Result<()> {
580 bail!("Order modification not supported by Coinbase. Please cancel and place a new order.")
583 }
584
585 async fn get_order_status(&self, order_id: &str) -> Result<OrderStatus> {
586 let order = self.rest_client.get_order(order_id).await?;
588 Ok(CoinbaseRestClient::string_to_order_status(&order.status))
589 }
590
591 fn venue(&self) -> Venue {
593 Venue::Coinbase
594 }
595
596 async fn cancel_all_orders(
598 &self,
599 instrument_id: Option<InstrumentId>,
600 report_tx: Sender<ExecutionReport>,
601 ) -> Result<()> {
602 bail!("Cancel all orders not yet implemented for Coinbase unified client")
605 }
606
607 async fn connect(&self, report_sender: Sender<ExecutionReport>) -> Result<()> {
609 self.connect_with_report_sender(report_sender).await
611 }
612
613 async fn disconnect(&self) -> Result<()> {
615 info!("Disconnecting Coinbase unified client");
617
618 if let Some(ws_client) = &self.ws_client {
620 if let Err(e) = ws_client.disconnect().await {
621 error!("WebSocket disconnect error: {e}");
622 }
623 self.state.write().ws_connected = false;
624 }
625
626 self.state.write().rest_connected = false;
635
636 Ok(())
637 }
638
639 async fn is_connected(&self) -> bool {
641 let state = self.state.read();
642 state.rest_connected || state.ws_connected || state.fix_connected
643 }
644
645 async fn get_instruments(&self) -> Result<SmallVec<[InstrumentId; 32]>> {
647 let products = self.rest_client.get_products().await?;
649 let instruments: SmallVec<[InstrumentId; 32]> = products
650 .into_iter()
651 .filter(|p| p.status == "online" && !p.trading_disabled)
652 .map(|p| InstrumentId::new(p.product_id, Venue::Coinbase))
653 .collect();
654 Ok(instruments)
655 }
656
657 async fn send_fix_message(&self, _message: Vec<u8>) -> Result<()> {
659 bail!("FIX message sending not supported through unified client")
660 }
661
662 async fn receive_fix_message(&self) -> Result<Vec<u8>> {
664 bail!("FIX message receiving not supported through unified client")
665 }
666}
667
668#[cfg(test)]
669mod tests {
670 use super::*;
671 use rusty_common::auth::exchanges::coinbase::CoinbaseAuth;
672 use std::str::FromStr;
673
674 #[test]
675 fn test_unified_client_creation() {
676 let auth = Arc::new(CoinbaseAuth::new_hmac(
677 "test_key".into(),
678 "test_secret".into(),
679 ));
680
681 let config = UnifiedClientConfig {
682 order_protocol: CoinbaseProtocol::Auto,
683 enable_websocket: true,
684 enable_fix: false,
685 sandbox: true,
686 };
687
688 let registry = crate::instrument_registry::create_shared_registry();
689 let client = CoinbaseUnifiedClient::new(auth, config, registry);
690 assert!(client.is_ok());
691
692 let client = client.unwrap();
693 assert!(!client.is_connected());
694 }
695
696 #[test]
697 fn test_protocol_selection() {
698 let auth = Arc::new(CoinbaseAuth::new_hmac(
699 "test_key".into(),
700 "test_secret".into(),
701 ));
702
703 let config = UnifiedClientConfig {
704 order_protocol: CoinbaseProtocol::Rest,
705 enable_websocket: true,
706 enable_fix: false,
707 sandbox: true,
708 };
709
710 let registry = crate::instrument_registry::create_shared_registry();
711 let client = CoinbaseUnifiedClient::new(auth, config, registry).unwrap();
712
713 let order = Order {
714 id: rusty_model::types::OrderId::new(),
715 client_id: rusty_model::types::ClientId::new("test"),
716 symbol: "BTC-USD".into(),
717 side: rusty_model::enums::OrderSide::Buy,
718 order_type: rusty_model::enums::OrderType::Limit,
719 quantity: rust_decimal::Decimal::from_str("0.01").unwrap(),
720 price: Some(rust_decimal::Decimal::from_str("50000").unwrap()),
721 stop_price: None,
722 exchange_order_id: None,
723 venue: rusty_model::venues::Venue::Coinbase,
724 filled_quantity: rust_decimal::Decimal::ZERO,
725 average_fill_price: None,
726 status: OrderStatus::New,
727 creation_time_ns: rusty_common::time::get_epoch_timestamp_ns(),
728 update_time_ns: rusty_common::time::get_epoch_timestamp_ns(),
729 time_in_force: rusty_model::enums::TimeInForce::GTC,
730 metadata: simd_json::json!(null),
731 };
732
733 let protocol = client.select_protocol(&order);
734 assert_eq!(protocol, CoinbaseProtocol::Rest);
735 }
736}