// rusty_ems/exchanges/binance_websocket.rs

use std::sync::Arc;

use anyhow::{Result, anyhow};
use flume::Sender;
use futures::StreamExt;
use log::error;
use quanta::Clock;
use rust_decimal::Decimal;
use rusty_common::auth::exchanges::binance::BinanceAuth;
use rusty_common::utils::id_generation;
use rusty_common::websocket::Message;
use rusty_model::{enums::OrderStatus, instruments::InstrumentId, venues::Venue};
use simd_json::prelude::*;
use simd_json::value::owned::Value as JsonValue;
use smartstring::alias::String as SmartString;

use crate::exchanges::websocket_common::{
    BinanceListenKeyManager, BinanceWebSocketManager, get_binance_ws_url,
};
use crate::execution_engine::ExecutionReport;

/// Base URL for Binance Spot WebSocket combined streams.
///
/// Nothing in this module references this constant: `start_user_data_stream` derives
/// its URL from `get_binance_ws_url` instead. `dead_code` is allowed so the unused
/// private item does not fail `-D warnings` builds; delete it if combined-stream
/// support is not planned.
#[allow(dead_code)]
const BINANCE_WS_COMBINED_URL: &str = "wss://stream.binance.com:9443/stream";

25/// Standalone function to map Binance order status to internal `OrderStatus`
26fn map_order_status(status: &str) -> OrderStatus {
27    match status {
28        "NEW" => OrderStatus::New,
29        "PARTIALLY_FILLED" => OrderStatus::PartiallyFilled,
30        "FILLED" => OrderStatus::Filled,
31        "CANCELED" | "CANCELLED" => OrderStatus::Cancelled,
32        "REJECTED" => OrderStatus::Rejected,
33        "EXPIRED" => OrderStatus::Expired,
34        "PENDING_CANCEL" => OrderStatus::Pending,
35        "PENDING_NEW" => OrderStatus::Pending,
36        "REPLACED" => OrderStatus::Unknown, // Binance-specific status for replaced orders
37        _ => {
38            error!("Unknown Binance order status: {status}");
39            OrderStatus::Unknown
40        }
41    }
42}
43
44/// WebSocket client for Binance user data stream
45pub struct BinanceWebSocketClient {
46    /// Authentication handler - direct reference for maximum performance
47    auth: Arc<BinanceAuth>,
48
49    /// WebSocket connection manager
50    connection_manager: BinanceWebSocketManager,
51
52    /// Listen key manager
53    listen_key_manager: BinanceListenKeyManager,
54
55    /// High-precision clock
56    clock: Clock,
57}
58
59impl Clone for BinanceWebSocketClient {
60    fn clone(&self) -> Self {
61        Self {
62            auth: self.auth.clone(),
63            connection_manager: self.connection_manager.clone(),
64            listen_key_manager: BinanceListenKeyManager::new("https://api.binance.com".to_string()),
65            clock: self.clock.clone(),
66        }
67    }
68}
69
70impl BinanceWebSocketClient {
71    /// Create a new Binance WebSocket client with direct auth for maximum performance
72    #[must_use]
73    pub fn new(auth: Arc<BinanceAuth>) -> Self {
74        Self {
75            auth,
76            connection_manager: BinanceWebSocketManager::new(),
77            listen_key_manager: BinanceListenKeyManager::new("https://api.binance.com".to_string()),
78            clock: Clock::new(),
79        }
80    }
81
82    /// Create a new Binance WebSocket client with custom API URL
83    #[must_use]
84    pub fn with_api_url(self, api_url: &str) -> Self {
85        Self {
86            auth: self.auth,
87            connection_manager: self.connection_manager,
88            listen_key_manager: BinanceListenKeyManager::new(api_url.to_string()),
89            clock: self.clock,
90        }
91    }
92
93    /// Generate headers for authenticated requests - optimized for performance
94    fn generate_headers(
95        &self,
96        method: &str,
97        path: &str,
98        body: Option<&str>,
99    ) -> Result<reqwest::header::HeaderMap> {
100        let auth_headers = self
101            .auth
102            .generate_headers(method, path, None, body)
103            .map_err(|e| anyhow!("Auth header generation failed: {}", e))?;
104
105        let mut headers = reqwest::header::HeaderMap::new();
106
107        for (key, value) in auth_headers {
108            let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
109                .map_err(|e| anyhow!("Invalid header name: {}", e))?;
110            let header_value = reqwest::header::HeaderValue::from_str(&value)
111                .map_err(|e| anyhow!("Invalid header value: {}", e))?;
112            headers.insert(header_name, header_value);
113        }
114
115        Ok(headers)
116    }
117
118    /// Start WebSocket connection for user data stream
119    pub async fn start_user_data_stream(
120        &self,
121        report_tx: Sender<ExecutionReport>,
122        market_type: Option<&str>,
123    ) -> Result<()> {
124        // Get or create listen key using the manager
125        let listen_key = self
126            .listen_key_manager
127            .get_or_create_listen_key(|method, path, body| {
128                self.generate_headers(method, path, body)
129            })
130            .await?;
131
132        // Determine the appropriate WebSocket URL based on the market type
133        let base_ws_url = get_binance_ws_url(market_type);
134
135        // Connect to WebSocket
136        let ws_url = format!("{base_ws_url}/{listen_key}");
137
138        // Connect using the connection manager
139        self.connection_manager.connect_with_retry(&ws_url).await?;
140
141        // Create standalone WebSocket handler to avoid capturing self
142        let ws_stream = self.connection_manager.get_stream_handle();
143        let report_tx = report_tx.clone();
144        let clock = self.clock.clone();
145
146        // Pass the mapping function as a static function reference
147        tokio::spawn(async move {
148            if let Some(mut ws_stream) = ws_stream.write().await.take() {
149                while let Some(frame) = ws_stream.next().await {
150                    let msg = Message::from_frame_view(frame);
151
152                    match msg {
153                        Message::Text(text) => {
154                            let mut text_copy = text.to_string();
155                            match unsafe { simd_json::from_str::<JsonValue>(&mut text_copy) } {
156                                Ok(json) => {
157                                    if let Some(event_type) = json.get("e").and_then(|v| v.as_str())
158                                        && event_type == "executionReport"
159                                    {
160                                        // Process execution report
161                                        let symbol: SmartString = json
162                                            .get("s")
163                                            .and_then(|v| v.as_str())
164                                            .unwrap_or("")
165                                            .into();
166                                        let order_id: SmartString = json
167                                            .get("i")
168                                            .and_then(simd_json::prelude::ValueAsScalar::as_u64)
169                                            .unwrap_or(0)
170                                            .to_string()
171                                            .into();
172                                        let client_order_id: SmartString = json
173                                            .get("c")
174                                            .and_then(|v| v.as_str())
175                                            .unwrap_or("")
176                                            .into();
177                                        let price: SmartString = json
178                                            .get("p")
179                                            .and_then(|v| v.as_str())
180                                            .unwrap_or("0")
181                                            .into();
182                                        let original_qty: SmartString = json
183                                            .get("q")
184                                            .and_then(|v| v.as_str())
185                                            .unwrap_or("0")
186                                            .into();
187                                        let executed_qty: SmartString = json
188                                            .get("z")
189                                            .and_then(|v| v.as_str())
190                                            .unwrap_or("0")
191                                            .into();
192                                        let status =
193                                            json.get("X").and_then(|v| v.as_str()).unwrap_or("");
194                                        let transaction_time = json
195                                            .get("T")
196                                            .and_then(simd_json::prelude::ValueAsScalar::as_u64)
197                                            .unwrap_or(0);
198
199                                        // Create instrument ID
200                                        let instrument = InstrumentId {
201                                            symbol,
202                                            venue: Venue::Binance,
203                                        };
204
205                                        // Create execution report
206                                        let report = ExecutionReport {
207                                            id: id_generation::generate_exchange_report_id(
208                                                "binance",
209                                                &client_order_id,
210                                            ),
211                                            order_id: client_order_id,
212                                            exchange_timestamp: transaction_time * 1_000_000, // Convert ms to ns
213                                            system_timestamp:
214                                                rusty_common::time::get_epoch_timestamp_ns(),
215                                            instrument_id: instrument,
216                                            status: map_order_status(status),
217                                            filled_quantity: Decimal::from_str_exact(&executed_qty)
218                                                .unwrap_or(Decimal::ZERO),
219                                            remaining_quantity: {
220                                                let original =
221                                                    Decimal::from_str_exact(&original_qty)
222                                                        .unwrap_or(Decimal::ZERO);
223                                                let executed =
224                                                    Decimal::from_str_exact(&executed_qty)
225                                                        .unwrap_or(Decimal::ZERO);
226                                                original - executed
227                                            },
228                                            execution_price: Decimal::from_str_exact(&price).ok(),
229                                            reject_reason: None,
230                                            exchange_execution_id: Some(order_id),
231                                            is_final: matches!(
232                                                map_order_status(status),
233                                                OrderStatus::Filled
234                                                    | OrderStatus::Cancelled
235                                                    | OrderStatus::Rejected
236                                                    | OrderStatus::Expired
237                                            ),
238                                        };
239
240                                        if let Err(e) = report_tx.send_async(report).await {
241                                            error!("Failed to send execution report: {e}");
242                                        }
243                                    }
244                                }
245                                Err(e) => {
246                                    error!("Failed to parse WebSocket message: {e}");
247                                }
248                            }
249                        }
250                        Message::Close(_) => {
251                            break;
252                        }
253                        _ => {}
254                    }
255                }
256            }
257        });
258
259        // Start listen key refresh task using the manager
260        let auth = self.auth.clone();
261        self.listen_key_manager
262            .start_refresh_task(move |method, path, body| {
263                let auth_headers = auth.generate_headers(method, path, None, body)?;
264                let mut headers = reqwest::header::HeaderMap::new();
265
266                for (key, value) in auth_headers {
267                    let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
268                        .map_err(|e| anyhow!("Invalid header name: {}", e))?;
269                    let header_value = reqwest::header::HeaderValue::from_str(&value)
270                        .map_err(|e| anyhow!("Invalid header value: {}", e))?;
271                    headers.insert(header_name, header_value);
272                }
273
274                Ok(headers)
275            });
276
277        Ok(())
278    }
279
280    /// Check if WebSocket is connected
281    pub fn is_connected(&self) -> bool {
282        self.connection_manager.is_connected()
283    }
284
285    /// Disconnect WebSocket
286    pub async fn disconnect(&self) -> Result<()> {
287        // Disconnect using the connection manager
288        self.connection_manager.disconnect().await?;
289
290        // Delete listen key if one exists using the manager
291        // Note: The listen key manager handles key cleanup internally
292
293        Ok(())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;

    use std::env;

    /// Load `.env`, then build an HMAC auth handle from `BINANCE_API_KEY` and
    /// `BINANCE_SECRET_KEY`.
    ///
    /// If either variable is missing or empty, prints a skip notice and returns `None`
    /// so the calling test can return early.
    fn auth_from_env() -> Option<Arc<BinanceAuth>> {
        dotenv().ok();

        let api_key = env::var("BINANCE_API_KEY").unwrap_or_default();
        let secret_key = env::var("BINANCE_SECRET_KEY").unwrap_or_default();

        if !api_key.is_empty() && !secret_key.is_empty() {
            return Some(Arc::new(BinanceAuth::new_hmac(
                api_key.into(),
                secret_key.into(),
            )));
        }

        println!("Skipping test: BINANCE_API_KEY and BINANCE_SECRET_KEY must be set");
        None
    }

    /// A freshly constructed client must not report itself as connected.
    #[tokio::test]
    async fn test_websocket_client_creation() {
        let Some(auth) = auth_from_env() else {
            return;
        };

        let client = BinanceWebSocketClient::new(auth);
        assert!(!client.is_connected());
        println!("WebSocket client created successfully");
    }

    /// Create a listen key against the live API, then refresh it.
    #[tokio::test]
    async fn test_listen_key_creation() {
        let Some(auth) = auth_from_env() else {
            return;
        };

        let client = BinanceWebSocketClient::new(auth);
        let manager = &client.listen_key_manager;

        let result = manager
            .get_or_create_listen_key(|method, path, body| {
                client.generate_headers(method, path, body)
            })
            .await;
        assert!(result.is_ok(), "Failed to create listen key: {result:?}");

        let listen_key = result.unwrap();
        assert!(!listen_key.is_empty(), "Listen key should not be empty");
        println!("Listen key created successfully: {listen_key}");

        // Refresh the key that was just created.
        let refresh_result = manager
            .refresh_listen_key(&listen_key, |method, path, body| {
                client.generate_headers(method, path, body)
            })
            .await;
        assert!(
            refresh_result.is_ok(),
            "Failed to refresh listen key: {refresh_result:?}"
        );
        println!("Listen key refreshed successfully");
    }
}