1use rusty_common::collections::FxHashMap;
7use smartstring::alias::String;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use super::auth::{UpbitAuth, UpbitAuthConfig};
12use super::data::{orderbook::OrderbookMessage, prelude::*, trade::TradeMessage};
13use super::types::{UPBIT_API_URL, UPBIT_RATE_LIMITS, UPBIT_WS_URL};
14use anyhow::{Result, anyhow};
15use async_trait::async_trait;
16use parking_lot::RwLock;
17use quanta::Clock;
18use reqwest::header::{HeaderMap, HeaderValue};
19use rust_decimal::prelude::FromPrimitive;
20use rusty_common::json::Value;
21use rusty_common::websocket::{Message, WebSocketClient, WebSocketConfig, WebSocketError};
22use rusty_model::{
23 data::{
24 book_snapshot::OrderBookSnapshot,
25 market_trade::MarketTrade,
26 orderbook::OrderBook,
27 simd_orderbook::{SharedSimdOrderBook, SimdPriceLevels},
28 },
29 enums::OrderSide,
30 instruments::Instrument,
31};
32use simd_json::prelude::{ValueAsArray, ValueAsScalar, ValueObjectAccess};
33use smallvec::SmallVec;
34use std::str::FromStr;
35use tokio::sync::{mpsc, watch};
36use tokio::task::JoinHandle;
37
38use crate::provider::prelude::*;
39
/// Upbit market-data provider.
///
/// Offers REST access (instruments, recent trades, order-book snapshots) and
/// WebSocket trade / order-book streams, each stream run by a spawned task.
#[derive(Debug)]
pub struct UpbitProvider {
    /// Connection settings (WebSocket + REST), resolved in `with_config`.
    config: ConnectionConfig,

    /// Provider-level connection state; also cloned into the spawned stream tasks.
    connection_status: Arc<RwLock<ConnectionState>>,

    /// Receive counters, shared with every `UpbitMessageHandler`.
    stats: Arc<RwLock<ConnectionStats>>,

    /// Clock taken from `config.clock`, used for local timestamps.
    clock: Clock,

    /// Stop signals of running stream tasks, keyed by connection id
    /// (`"trade-<symbols>"` or `"orderbook-<symbols>"`, symbols joined by `,`).
    subscriptions: Arc<RwLock<FxHashMap<String, watch::Sender<bool>>>>,

    /// Instrument cache keyed by Upbit market id, populated by `get_instruments`.
    instruments: Arc<RwLock<FxHashMap<String, Box<dyn Instrument>>>>,

    /// Join handles of stream tasks, keyed by the same ids as `subscriptions`.
    ws_handles: Arc<RwLock<FxHashMap<String, JoinHandle<()>>>>,

    /// Time of the last `connect` attempt; consecutive attempts are spaced by the
    /// configured initial backoff.
    last_connection_attempt: Arc<RwLock<Instant>>,

    /// Shared REST client configured from `config.rest_config`.
    http_client: reqwest::Client,

    /// API credentials; `Some` only when both keys are configured and non-empty.
    auth: Option<UpbitAuth>,
}
76
77impl Default for UpbitProvider {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83impl UpbitProvider {
84 #[inline]
86 #[must_use]
87 pub fn new() -> Self {
88 Self::with_config(None)
89 }
90
91 #[inline]
93 #[must_use]
94 pub fn with_config(config: Option<ConnectionConfig>) -> Self {
95 let mut default_config = ConnectionConfig::default();
96
97 default_config.websocket_config.base_url = UPBIT_WS_URL.into();
99 default_config.websocket_config.ping_interval_milliseconds = 30000; default_config.websocket_config.use_compression = true;
101
102 default_config.rest_config.base_url = UPBIT_API_URL.into();
104 default_config.rest_config.timeout_milliseconds = 5000; let config = config.unwrap_or(default_config);
108 let clock = config.clock.clone();
109
110 let mut headers = HeaderMap::new();
112 headers.insert(
113 "User-Agent",
114 HeaderValue::from_str(&config.rest_config.user_agent)
115 .unwrap_or_else(|_| HeaderValue::from_static("RustyHFT/1.0")),
116 );
117
118 let http_client = reqwest::Client::builder()
119 .timeout(Duration::from_millis(
120 config.rest_config.timeout_milliseconds,
121 ))
122 .connect_timeout(Duration::from_millis(
123 config.rest_config.timeout_milliseconds / 2,
124 ))
125 .pool_max_idle_per_host(config.rest_config.connection_pool_size)
126 .pool_idle_timeout(Duration::from_millis(
127 config.rest_config.keep_alive_milliseconds,
128 ))
129 .default_headers(headers)
130 .build()
131 .unwrap_or_default();
132
133 let auth = if let (Some(api_key), Some(secret_key)) =
135 (&config.auth_config.api_key, &config.auth_config.secret_key)
136 {
137 if !api_key.is_empty() && !secret_key.is_empty() {
138 Some(UpbitAuth::new(UpbitAuthConfig::new(
139 api_key.clone(),
140 secret_key.clone(),
141 )))
142 } else {
143 None
144 }
145 } else {
146 None
147 };
148
149 Self {
150 config,
151 connection_status: Arc::new(RwLock::new(ConnectionState::Disconnected)),
152 stats: Arc::new(RwLock::new(ConnectionStats::default())),
153 clock,
154 subscriptions: Arc::new(RwLock::new(FxHashMap::default())),
155 instruments: Arc::new(RwLock::new(FxHashMap::default())),
156 ws_handles: Arc::new(RwLock::new(FxHashMap::default())),
157 last_connection_attempt: Arc::new(RwLock::new(Instant::now())),
158 http_client,
159 auth,
160 }
161 }
162
163 #[inline]
165 fn update_receive_stats(
166 stats: Arc<RwLock<ConnectionStats>>,
167 message_size: usize,
168 local_time: u64,
169 ) {
170 let mut s = stats.write();
171 s.messages_received += 1;
172 s.bytes_received += message_size as u64;
173 s.last_message_time = local_time;
174 }
175
176 fn create_websocket_config(
178 url: String,
179 use_compression: bool,
180 ping_interval_milliseconds: u64,
181 timeout_milliseconds: u64,
182 ) -> WebSocketConfig {
183 WebSocketConfig::builder(rusty_common::types::Exchange::Upbit, url.to_string())
184 .ping_interval(Duration::from_millis(ping_interval_milliseconds))
185 .timeout(Duration::from_millis(timeout_milliseconds))
186 .compression(if use_compression {
187 rusty_common::websocket::CompressionConfig::default()
188 } else {
189 rusty_common::websocket::CompressionConfig::disabled()
190 })
191 .build()
192 }
193
194 #[inline]
196 async fn connect(&self) -> Result<()> {
197 if *self.connection_status.read() == ConnectionState::Connected {
198 return Ok(());
199 }
200
201 {
203 let now = Instant::now();
204 let last_attempt = *self.last_connection_attempt.read();
205 let backoff_ms = self
206 .config
207 .websocket_config
208 .reconnect
209 .backoff_initial_milliseconds;
210
211 if now.duration_since(last_attempt) < Duration::from_millis(backoff_ms) {
212 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
214 }
215
216 *self.last_connection_attempt.write() = Instant::now();
218 }
219
220 *self.connection_status.write() = ConnectionState::Connecting;
221
222 *self.connection_status.write() = ConnectionState::Connected;
225 self.stats.write().connected_time = self.clock.raw();
226 Ok(())
227 }
228
229 #[inline]
231 async fn create_trade_connection(
232 &self,
233 symbols: SmallVec<[String; 8]>,
234 tx: mpsc::Sender<TradeMessage>,
235 subscription: Vec<simd_json::OwnedValue>,
236 ) -> Result<JoinHandle<()>> {
237 let connection_id = format!("trade-{}", symbols.join(","));
238
239 let (stop_tx, stop_rx) = watch::channel(false);
241 self.subscriptions
242 .write()
243 .insert(connection_id.clone().into(), stop_tx);
244
245 let url = self.config.websocket_config.base_url.clone();
247 let clock = self.clock.clone();
248 let timeout_milliseconds = self.config.websocket_config.timeout_milliseconds;
249 let ping_interval_milliseconds = self.config.websocket_config.ping_interval_milliseconds;
250 let stats = self.stats.clone();
251 let connection_status = self.connection_status.clone();
252 let use_compression = self.config.websocket_config.use_compression;
253
254 let handle = tokio::spawn(async move {
256 loop {
257 if *stop_rx.borrow() {
259 break;
260 }
261
262 let ws_config = Self::create_websocket_config(
264 url.clone(),
265 use_compression,
266 ping_interval_milliseconds,
267 timeout_milliseconds,
268 );
269
270 let mut client = WebSocketClient::new(ws_config);
272
273 *connection_status.write() = ConnectionState::Connecting;
275
276 let handler = UpbitMessageHandler::new_trade_handler(
278 clock.clone(),
279 stats.clone(),
280 tx.clone(),
281 subscription.clone(),
282 );
283
284 if let Err(e) = client.run(handler).await {
286 log::error!("WebSocket client error: {e}");
287 *connection_status.write() = ConnectionState::Error;
288 }
289
290 if *stop_rx.borrow() {
292 break;
293 }
294
295 tokio::time::sleep(Duration::from_millis(1000)).await;
297 }
298 });
299
300 Ok(handle)
301 }
302
303 #[inline]
305 async fn create_orderbook_connection(
306 &self,
307 symbols: SmallVec<[String; 8]>,
308 tx: mpsc::Sender<OrderbookMessage>,
309 subscription: Vec<simd_json::OwnedValue>,
310 ) -> Result<JoinHandle<()>> {
311 let connection_id = format!("orderbook-{}", symbols.join(","));
312
313 let (stop_tx, stop_rx) = watch::channel(false);
315 self.subscriptions
316 .write()
317 .insert(connection_id.clone().into(), stop_tx);
318
319 let url = self.config.websocket_config.base_url.clone();
321 let clock = self.clock.clone();
322 let timeout_milliseconds = self.config.websocket_config.timeout_milliseconds;
323 let ping_interval_milliseconds = self.config.websocket_config.ping_interval_milliseconds;
324 let stats = self.stats.clone();
325 let connection_status = self.connection_status.clone();
326 let use_compression = self.config.websocket_config.use_compression;
327
328 let handle = tokio::spawn(async move {
330 loop {
331 if *stop_rx.borrow() {
333 break;
334 }
335
336 let ws_config = Self::create_websocket_config(
338 url.clone(),
339 use_compression,
340 ping_interval_milliseconds,
341 timeout_milliseconds,
342 );
343
344 let mut client = WebSocketClient::new(ws_config);
346
347 *connection_status.write() = ConnectionState::Connecting;
349
350 let handler = UpbitMessageHandler::new_orderbook_handler(
352 clock.clone(),
353 stats.clone(),
354 tx.clone(),
355 subscription.clone(),
356 );
357
358 if let Err(e) = client.run(handler).await {
360 log::error!("WebSocket client error: {e}");
361 *connection_status.write() = ConnectionState::Error;
362 }
363
364 if *stop_rx.borrow() {
366 break;
367 }
368
369 tokio::time::sleep(Duration::from_millis(1000)).await;
371 }
372 });
373
374 Ok(handle)
375 }
376}
377
378impl HttpClientProvider for UpbitProvider {
379 fn http_client(&self) -> &reqwest::Client {
380 &self.http_client
381 }
382}
383
384#[async_trait]
385impl Provider for UpbitProvider {
    /// Payload delivered by `subscribe_trades`.
    type TradeMessage = TradeMessage;
    /// Payload delivered by `subscribe_orderbook`.
    type DepthMessage = OrderbookMessage;
    /// Instrument payload type, a generic JSON `Value`.
    type InstrumentMessage = Value;
389
390 fn name(&self) -> &'static str {
391 "Upbit"
392 }
393
394 fn venue(&self) -> rusty_model::venues::Venue {
395 rusty_model::venues::Venue::Upbit
396 }
397
398 fn config(&self) -> &ConnectionConfig {
399 &self.config
400 }
401
402 async fn init(&mut self) -> Result<()> {
403 match self.get_instruments().await {
405 Ok(_) => Ok(()),
406 Err(e) => Err(anyhow!("Failed to initialize Upbit provider: {}", e)),
407 }
408 }
409
410 #[inline]
411 async fn subscribe_trades(
412 &self,
413 symbols: SmallVec<[String; 8]>,
414 ) -> Result<mpsc::Receiver<Self::TradeMessage>> {
415 self.connect().await?;
417
418 let (tx, rx) = mpsc::channel(1024);
420
421 let connection_id = format!("trade-{}", symbols.join(","));
423 let subscription =
424 create_trade_subscription(&connection_id, symbols.clone(), None, None, false);
425
426 let handle = self
428 .create_trade_connection(symbols, tx, subscription)
429 .await?;
430
431 self.ws_handles.write().insert(connection_id.into(), handle);
433
434 Ok(rx)
435 }
436
437 #[inline]
438 async fn unsubscribe_trades(&self) -> Result<()> {
439 let mut keys_to_remove = Vec::new();
441
442 for (key, _) in self.subscriptions.read().iter() {
443 if key.starts_with("trade-") {
444 keys_to_remove.push(key.clone());
445 }
446 }
447
448 for key in keys_to_remove {
450 if let Some(tx) = self.subscriptions.write().remove(&key) {
451 let _ = tx.send(true);
452 }
453
454 if let Some(handle) = self.ws_handles.write().remove(&key) {
456 handle.abort();
457 }
458 }
459
460 Ok(())
461 }
462
463 #[inline]
464 async fn subscribe_orderbook(
465 &self,
466 symbols: SmallVec<[String; 8]>,
467 ) -> Result<mpsc::Receiver<Self::DepthMessage>> {
468 self.connect().await?;
470
471 let (tx, rx) = mpsc::channel(1024);
473
474 let connection_id = format!("orderbook-{}", symbols.join(","));
476 let subscription =
477 create_orderbook_subscription(&connection_id, symbols.clone(), None, None, None, false);
478
479 let handle = self
481 .create_orderbook_connection(symbols, tx, subscription)
482 .await?;
483
484 self.ws_handles.write().insert(connection_id.into(), handle);
486
487 Ok(rx)
488 }
489
490 #[inline]
491 async fn unsubscribe_orderbook(&self) -> Result<()> {
492 let mut keys_to_remove = Vec::new();
494
495 for (key, _) in self.subscriptions.read().iter() {
496 if key.starts_with("orderbook-") {
497 keys_to_remove.push(key.clone());
498 }
499 }
500
501 for key in keys_to_remove {
503 if let Some(tx) = self.subscriptions.write().remove(&key) {
504 let _ = tx.send(true);
505 }
506
507 if let Some(handle) = self.ws_handles.write().remove(&key) {
509 handle.abort();
510 }
511 }
512
513 Ok(())
514 }
515
516 #[inline]
517 async fn get_instruments(&self) -> Result<Vec<Box<dyn Instrument>>> {
518 if !self.instruments.read().is_empty() {
520 let mut instruments = Vec::new();
522 for instrument in self.instruments.read().values() {
523 instruments.push(instrument.clone_box());
524 }
525 return Ok(instruments);
526 }
527
528 let url = format!(
530 "{}/market/all?isDetails=true",
531 self.config.rest_config.base_url
532 );
533
534 let response = self.http_client.get(&url).send().await?;
535
536 if !response.status().is_success() {
537 let error_text = response.text().await?;
538 return Err(anyhow!("Failed to fetch instruments: {}", error_text));
539 }
540
541 let bytes = response
543 .bytes()
544 .await
545 .map_err(|e| anyhow!("Failed to get response bytes: {}", e))?;
546 let mut bytes_vec = bytes.to_vec();
547 let markets: Vec<simd_json::OwnedValue> = simd_json::from_slice(&mut bytes_vec)
548 .map_err(|e| anyhow!("Failed to parse JSON: {}", e))?;
549
550 let mut instruments = Vec::with_capacity(markets.len());
552 let mut instruments_map = self.instruments.write();
553
554 for market in markets {
555 let market_id = market["market"]
557 .as_str()
558 .ok_or_else(|| anyhow!("Missing market field"))?;
559 let base_currency = market["base_asset"].as_str().unwrap_or_else(|| {
560 market["baseCurrency"]
561 .as_str()
562 .unwrap_or_else(|| market_id.split('-').nth(1).unwrap_or(""))
563 });
564 let quote_currency = market["quote_asset"].as_str().unwrap_or_else(|| {
565 market["quoteCurrency"]
566 .as_str()
567 .unwrap_or_else(|| market_id.split('-').nth(0).unwrap_or(""))
568 });
569
570 let instrument = Box::new(rusty_model::instruments::SpotInstrument::new(
572 market_id,
573 base_currency,
574 quote_currency,
575 rusty_model::venues::Venue::Upbit,
576 ));
577
578 instruments.push(instrument.clone_box());
580 instruments_map.insert(market_id.into(), instrument);
581 }
582
583 Ok(instruments)
584 }
585
586 #[inline]
587 async fn get_historical_trades(
588 &self,
589 symbol: &str,
590 limit: Option<u32>,
591 ) -> Result<Vec<MarketTrade>> {
592 let limit = limit.unwrap_or(100).min(500); let url = format!(
595 "{}/trades/ticks?market={}&count={}",
596 self.config.rest_config.base_url, symbol, limit
597 );
598
599 let response = self.http_client.get(&url).send().await?;
601
602 if !response.status().is_success() {
603 let error_text = response.text().await?;
604 return Err(anyhow!("Failed to fetch historical trades: {}", error_text));
605 }
606
607 let bytes = response
609 .bytes()
610 .await
611 .map_err(|e| anyhow!("Failed to get response bytes: {}", e))?;
612 let mut bytes_vec = bytes.to_vec();
613 let trades_data: Vec<simd_json::OwnedValue> = simd_json::from_slice(&mut bytes_vec)
614 .map_err(|e| anyhow!("Failed to parse JSON: {}", e))?;
615
616 let mut trades = Vec::with_capacity(trades_data.len());
618 let instrument_id =
619 rusty_model::instruments::InstrumentId::new(symbol, rusty_model::venues::Venue::Upbit);
620
621 for trade_data in trades_data {
622 let timestamp_ms = match trade_data["timestamp"].as_u64() {
624 Some(ts) => ts,
625 None => {
626 if let Some(trade_time_str) = trade_data["trade_time_utc"].as_str() {
628 if let Some(ts) = timestamp::iso8601_to_nanos(trade_time_str) {
630 ts / 1_000_000 } else {
632 self.clock.raw() / 1_000_000 }
635 } else {
636 self.clock.raw() / 1_000_000 }
639 }
640 };
641
642 let exchange_time = self.convert_exchange_timestamp(timestamp_ms);
644
645 let price_str_value: std::string::String;
647 let price_str = match trade_data["trade_price"].as_str() {
648 Some(s) => s,
649 None => {
650 price_str_value = trade_data["trade_price"]
651 .as_f64()
652 .map(|p| p.to_string())
653 .unwrap_or_else(|| "0.0".to_string());
654 &price_str_value
655 }
656 };
657
658 let quantity_str_value: std::string::String;
659 let quantity_str = match trade_data["trade_volume"].as_str() {
660 Some(s) => s,
661 None => {
662 quantity_str_value = trade_data["trade_volume"]
663 .as_f64()
664 .map(|q| q.to_string())
665 .unwrap_or_else(|| "0.0".to_string());
666 &quantity_str_value
667 }
668 };
669
670 let price =
671 rust_decimal::Decimal::from_str(price_str).unwrap_or(rust_decimal::Decimal::ZERO);
672 let quantity = rust_decimal::Decimal::from_str(quantity_str)
673 .unwrap_or(rust_decimal::Decimal::ZERO);
674
675 let direction = if trade_data["ask_bid"].as_str().unwrap_or("") == "ASK" {
677 OrderSide::Sell
678 } else {
679 OrderSide::Buy
680 };
681
682 let trade = MarketTrade {
684 timestamp: self.clock.now(),
685 exchange_time_ns: exchange_time,
686 price,
687 quantity,
688 direction,
689 instrument_id: instrument_id.clone(),
690 };
691
692 trades.push(trade);
693 }
694
695 Ok(trades)
696 }
697
698 #[inline]
699 async fn get_orderbook_snapshot(
700 &self,
701 symbol: &str,
702 depth: Option<u32>,
703 ) -> Result<OrderBookSnapshot> {
704 let url = format!(
706 "{}/orderbook?markets={}",
707 self.config.rest_config.base_url, symbol
708 );
709
710 let response = self.http_client.get(&url).send().await?;
712
713 if !response.status().is_success() {
714 let error_text = response.text().await?;
715 return Err(anyhow!(
716 "Failed to fetch orderbook snapshot: {}",
717 error_text
718 ));
719 }
720
721 let bytes = response
723 .bytes()
724 .await
725 .map_err(|e| anyhow!("Failed to get response bytes: {}", e))?;
726 let mut bytes_vec = bytes.to_vec();
727 let orderbooks: Vec<simd_json::OwnedValue> = simd_json::from_slice(&mut bytes_vec)
728 .map_err(|e| anyhow!("Failed to parse JSON: {}", e))?;
729
730 if orderbooks.is_empty() {
732 return Err(anyhow!("Empty orderbook response"));
733 }
734
735 let orderbook = &orderbooks[0];
737
738 let timestamp_ms = orderbook["timestamp"]
740 .as_u64()
741 .unwrap_or_else(|| self.clock.raw() / 1_000_000);
742
743 let exchange_time = self.convert_exchange_timestamp(timestamp_ms);
745
746 let instrument_id =
748 rusty_model::instruments::InstrumentId::new(symbol, rusty_model::venues::Venue::Upbit);
749
750 let mut order_book_depth =
752 OrderBookSnapshot::new_empty(instrument_id, self.clock.raw(), exchange_time);
753
754 let orderbook_units = orderbook["orderbook_units"]
756 .as_array()
757 .ok_or_else(|| anyhow!("Missing orderbook_units"))?;
758
759 let limit = depth.unwrap_or(orderbook_units.len() as u32) as usize;
761 let orderbook_units = &orderbook_units[0..std::cmp::min(limit, orderbook_units.len())];
762
763 for unit in orderbook_units {
765 if let (Some(bid_price), Some(bid_size)) =
767 (unit["bid_price"].as_f64(), unit["bid_size"].as_f64())
768 {
769 let bid_price_dec = rust_decimal::Decimal::from_f64(bid_price)
770 .unwrap_or(rust_decimal::Decimal::ZERO);
771 let bid_size_dec = rust_decimal::Decimal::from_f64(bid_size)
772 .unwrap_or(rust_decimal::Decimal::ZERO);
773
774 order_book_depth.add_bid(bid_price_dec, bid_size_dec);
775 }
776
777 if let (Some(ask_price), Some(ask_size)) =
779 (unit["ask_price"].as_f64(), unit["ask_size"].as_f64())
780 {
781 let ask_price_dec = rust_decimal::Decimal::from_f64(ask_price)
782 .unwrap_or(rust_decimal::Decimal::ZERO);
783 let ask_size_dec = rust_decimal::Decimal::from_f64(ask_size)
784 .unwrap_or(rust_decimal::Decimal::ZERO);
785
786 order_book_depth.add_ask(ask_price_dec, ask_size_dec);
787 }
788 }
789
790 Ok(order_book_depth)
791 }
792
793 #[inline]
794 async fn get_realtime_orderbook(&self, symbol: &str) -> Result<SharedSimdOrderBook> {
795 let snapshot = self.get_orderbook_snapshot(symbol, None).await?;
797
798 let bids: SmallVec<[rusty_model::data::orderbook::PriceLevel; 64]> = snapshot
801 .bids
802 .iter()
803 .map(|level| rusty_model::data::orderbook::PriceLevel::new(level.price, level.quantity))
804 .collect();
805
806 let asks: SmallVec<[rusty_model::data::orderbook::PriceLevel; 64]> = snapshot
807 .asks
808 .iter()
809 .map(|level| rusty_model::data::orderbook::PriceLevel::new(level.price, level.quantity))
810 .collect();
811
812 let orderbook = OrderBook::new(
813 snapshot.instrument_id.symbol.clone(),
814 snapshot.timestamp_event,
815 snapshot.timestamp_init,
816 bids,
817 asks,
818 );
819 let shared_orderbook = SharedSimdOrderBook::from_orderbook(&orderbook);
820
821 let (_, _depth_rx) = mpsc::channel::<OrderbookMessage>(1024);
823
824 let mut symbols = SmallVec::<[String; 8]>::new();
826 symbols.push(symbol.into());
827 let mut msg_rx = self.subscribe_orderbook(symbols).await?;
828
829 let shared_orderbook_clone = shared_orderbook.clone();
831 let clock = self.clock.clone();
832
833 tokio::spawn(async move {
834 while let Some(msg) = msg_rx.recv().await {
835 {
838 let data = &msg; let instrument_id = rusty_model::instruments::InstrumentId::new(
841 data.code.clone(),
842 rusty_model::venues::Venue::Upbit,
843 );
844
845 let timestamp_ms = if data.timestamp == 0 {
847 clock.raw() / 1_000_000 } else {
849 data.timestamp
850 };
851
852 let exchange_time = timestamp::exchange_time_to_nanos(
854 timestamp_ms,
855 TimestampFormat::Milliseconds,
856 );
857
858 shared_orderbook_clone.write(|ob| {
862 ob.exchange_timestamp_ns = exchange_time;
864 ob.system_timestamp_ns = clock.raw();
865
866 ob.bids.clear();
868 ob.asks.clear();
869
870 let mut bids =
873 SmallVec::<[rusty_model::data::orderbook::PriceLevel; 64]>::new();
874 let mut asks =
875 SmallVec::<[rusty_model::data::orderbook::PriceLevel; 64]>::new();
876
877 for unit in &data.orderbook_units {
878 if unit.bid_size > rust_decimal::Decimal::ZERO {
880 bids.push(rusty_model::data::orderbook::PriceLevel::new(
881 unit.bid_price,
882 unit.bid_size,
883 ));
884 }
885
886 if unit.ask_size > rust_decimal::Decimal::ZERO {
888 asks.push(rusty_model::data::orderbook::PriceLevel::new(
889 unit.ask_price,
890 unit.ask_size,
891 ));
892 }
893 }
894
895 bids.sort_by(|a, b| b.price.cmp(&a.price));
897
898 asks.sort_by(|a, b| a.price.cmp(&b.price));
900
901 ob.bids = SimdPriceLevels::from_smallvec(&bids);
903 ob.asks = SimdPriceLevels::from_smallvec(&asks);
904 });
905 }
906 }
907 });
908
909 Ok(shared_orderbook)
910 }
911
912 #[inline]
913 async fn is_connected(&self) -> bool {
914 *self.connection_status.read() == ConnectionState::Connected
915 }
916
917 #[inline]
918 async fn connection_status(&self) -> ConnectionState {
919 *self.connection_status.read()
920 }
921
922 #[inline]
923 async fn get_stats(&self) -> ConnectionStats {
924 self.stats.read().clone()
925 }
926
927 #[inline]
928 async fn ping(&self) -> Result<u64> {
929 let start = self.clock.raw();
932
933 let response = self
934 .http_client
935 .get(format!(
936 "{}/v1/market/all",
937 self.config.rest_config.base_url
938 ))
939 .send()
940 .await?;
941
942 if response.status().is_success() {
943 Ok(self.clock.raw().saturating_sub(start))
944 } else {
945 Err(anyhow!("Ping failed with status: {}", response.status()))
946 }
947 }
948
949 #[inline]
950 async fn reset_connection(&self) -> Result<()> {
951 let mut keys_to_remove = Vec::new();
953
954 for (key, _) in self.subscriptions.read().iter() {
955 keys_to_remove.push(key.clone());
956 }
957
958 for key in &keys_to_remove {
960 if let Some(tx) = self.subscriptions.write().remove(key) {
961 let _ = tx.send(true);
962 }
963
964 if let Some(handle) = self.ws_handles.write().remove(key) {
966 handle.abort();
967 }
968 }
969
970 *self.connection_status.write() = ConnectionState::Disconnected;
972
973 *self.stats.write() = ConnectionStats::default();
975
976 self.connect().await
978 }
979
980 #[inline]
981 fn get_rate_limits(&self) -> Vec<RateLimit> {
982 UPBIT_RATE_LIMITS.to_vec()
983 }
984
985 async fn shutdown(&mut self) -> Result<()> {
986 let mut keys_to_remove = Vec::new();
988
989 for (key, _) in self.subscriptions.read().iter() {
990 keys_to_remove.push(key.clone());
991 }
992
993 for key in keys_to_remove {
995 if let Some(tx) = self.subscriptions.write().remove(&key) {
996 let _ = tx.send(true);
997 }
998
999 if let Some(handle) = self.ws_handles.write().remove(&key) {
1001 handle.abort();
1002 }
1003 }
1004
1005 *self.connection_status.write() = ConnectionState::Disconnected;
1007
1008 Ok(())
1009 }
1010}
1011
/// WebSocket message handler for one Upbit stream.
///
/// Forwards either trade or order-book messages (selected by `handler_type`) to
/// the matching channel.
#[derive(Debug)]
struct UpbitMessageHandler {
    /// Which message `type` this handler forwards.
    handler_type: HandlerType,
    /// Clock used for local receive timestamps.
    clock: Clock,
    /// Shared receive statistics, updated for every text frame.
    stats: Arc<RwLock<ConnectionStats>>,
    /// Destination for trade messages; `Some` only for trade handlers.
    trade_tx: Option<mpsc::Sender<TradeMessage>>,
    /// Destination for order-book messages; `Some` only for order-book handlers.
    orderbook_tx: Option<mpsc::Sender<OrderbookMessage>>,
    /// Subscription frames prepared for this stream (serialized in `on_connected`).
    subscription: Vec<simd_json::OwnedValue>,
}
1022
/// Stream kind served by an `UpbitMessageHandler`.
#[derive(Debug)]
enum HandlerType {
    /// Forwards messages whose `type` is `"trade"`.
    Trade,
    /// Forwards messages whose `type` is `"orderbook"`.
    Orderbook,
}
1028
1029impl UpbitMessageHandler {
1030 const fn new_trade_handler(
1031 clock: Clock,
1032 stats: Arc<RwLock<ConnectionStats>>,
1033 tx: mpsc::Sender<TradeMessage>,
1034 subscription: Vec<simd_json::OwnedValue>,
1035 ) -> Self {
1036 Self {
1037 handler_type: HandlerType::Trade,
1038 clock,
1039 stats,
1040 trade_tx: Some(tx),
1041 orderbook_tx: None,
1042 subscription,
1043 }
1044 }
1045
1046 const fn new_orderbook_handler(
1047 clock: Clock,
1048 stats: Arc<RwLock<ConnectionStats>>,
1049 tx: mpsc::Sender<OrderbookMessage>,
1050 subscription: Vec<simd_json::OwnedValue>,
1051 ) -> Self {
1052 Self {
1053 handler_type: HandlerType::Orderbook,
1054 clock,
1055 stats,
1056 trade_tx: None,
1057 orderbook_tx: Some(tx),
1058 subscription,
1059 }
1060 }
1061
1062 async fn parse_and_send_message<T>(
1075 message_bytes: &mut [u8],
1076 sender: Option<&mpsc::Sender<T>>,
1077 message_context: &str,
1078 ) where
1079 T: serde::de::DeserializeOwned + Send,
1080 {
1081 if let Ok(parsed_message) = simd_json::from_slice::<T>(message_bytes)
1082 && let Some(tx) = sender
1083 && let Err(e) = tx.send(parsed_message).await
1084 {
1085 log::warn!("Failed to send {message_context} message: {e}");
1086 }
1087 }
1088}
1089
1090#[async_trait]
1091impl rusty_common::websocket::MessageHandler for UpbitMessageHandler {
1092 async fn on_message(&mut self, message: Message) -> std::result::Result<(), WebSocketError> {
1093 let local_time = self.clock.raw();
1094
1095 if let Some(text) = message.as_text() {
1097 UpbitProvider::update_receive_stats(self.stats.clone(), text.len(), local_time);
1098
1099 let mut message_bytes = text.as_bytes().to_vec();
1101 let json_value: simd_json::OwnedValue = simd_json::from_slice(&mut message_bytes)
1102 .map_err(|e| {
1103 WebSocketError::MessageProcessingError(format!("Failed to parse JSON: {e}"))
1104 })?;
1105
1106 if let Some(msg_type) = json_value.get("type").and_then(|v| v.as_str()) {
1108 match self.handler_type {
1109 HandlerType::Trade if msg_type == "trade" => {
1110 Self::parse_and_send_message::<TradeMessage>(
1111 &mut message_bytes,
1112 self.trade_tx.as_ref(),
1113 "trade",
1114 )
1115 .await;
1116 }
1117 HandlerType::Orderbook if msg_type == "orderbook" => {
1118 Self::parse_and_send_message::<OrderbookMessage>(
1119 &mut message_bytes,
1120 self.orderbook_tx.as_ref(),
1121 "orderbook",
1122 )
1123 .await;
1124 }
1125 _ => {
1126 log::debug!("Unhandled Upbit message type: {msg_type}");
1128 }
1129 }
1130 }
1131 }
1132
1133 Ok(())
1134 }
1135
1136 async fn on_connected(&mut self) -> std::result::Result<(), WebSocketError> {
1137 for sub_msg in &self.subscription {
1139 let subscription_json = simd_json::to_string(sub_msg).map_err(|e| {
1140 WebSocketError::MessageProcessingError(format!(
1141 "Failed to serialize subscription: {e}"
1142 ))
1143 })?;
1144
1145 log::info!("Sending Upbit subscription: {subscription_json}");
1146 }
1147
1148 Ok(())
1151 }
1152
1153 async fn on_disconnected(&mut self) -> std::result::Result<(), WebSocketError> {
1154 log::info!("Disconnected from Upbit WebSocket");
1155 Ok(())
1156 }
1157
1158 async fn on_error(&mut self, error: WebSocketError) -> std::result::Result<(), WebSocketError> {
1159 log::error!("Upbit WebSocket error: {error}");
1160 Ok(())
1161 }
1162}