// File: rusty_ems/exchanges/websocket_common.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
3use std::time::Duration;
4
5use anyhow::{Result, anyhow};
6use parking_lot::RwLock as ParkingLotRwLock;
7use rusty_common::websocket::connector::{WebSocketSink, WebSocketStream};
8use rusty_common::websocket::{WebSocketConfig, WebSocketConnector};
9use simd_json::prelude::{ValueAsScalar, ValueObjectAccess};
10use tokio::sync::RwLock as AsyncRwLock;
11
/// WebSocket connection state for Binance connections
///
/// `#[repr(u8)]` pins every discriminant to a `u8`, which is how the state is stored
/// in `BinanceWebSocketManager::state` (an `AtomicU8`). A discriminant that did not fit
/// would be a compile error here instead of a silent truncation in the `as u8` cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ConnectionState {
    /// WebSocket connection is not established
    Disconnected = 0,
    /// WebSocket connection is being established
    Connecting = 1,
    /// WebSocket connection is established and operational
    Connected = 2,
    /// WebSocket connection is attempting to reconnect after disconnection
    Reconnecting = 3,
    /// WebSocket connection has failed and cannot be established
    Failed = 4,
}

impl From<u8> for ConnectionState {
    /// Decode a raw state value.
    ///
    /// Any value outside `0..=4` decodes as [`ConnectionState::Disconnected`].
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Disconnected,
            1 => Self::Connecting,
            2 => Self::Connected,
            3 => Self::Reconnecting,
            4 => Self::Failed,
            _ => Self::Disconnected,
        }
    }
}

impl From<ConnectionState> for u8 {
    /// Encode a state as its `u8` discriminant (lossless thanks to `#[repr(u8)]`).
    fn from(state: ConnectionState) -> Self {
        state as Self
    }
}
45
/// Shared WebSocket connection manager for Binance connections
///
/// Holds the sink and stream halves of one WebSocket connection together with its
/// connection state. The `Clone` impl shares `ws_sink`, `ws_stream` and `state` (all
/// `Arc`s) between clones, but gives each clone its own copy of `connected`.
pub struct BinanceWebSocketManager {
    /// WebSocket sink for sending messages
    ///
    /// `None` until a connect succeeds and again after `disconnect`.
    pub ws_sink: Arc<AsyncRwLock<Option<WebSocketSink>>>,

    /// WebSocket stream for receiving messages
    ///
    /// `None` until a connect succeeds and again after `disconnect`.
    pub ws_stream: Arc<AsyncRwLock<Option<WebSocketStream>>>,

    /// Connection state (atomic for lock-free access)
    ///
    /// Holds a `ConnectionState` encoded as `u8`; shared by all clones.
    pub state: Arc<AtomicU8>,

    /// Simple connected flag for basic checks
    ///
    /// NOTE(review): this is not behind an `Arc`, so it is NOT shared between clones;
    /// a clone only sees updates made through itself.
    pub connected: AtomicBool,
}
60
61impl Clone for BinanceWebSocketManager {
62    fn clone(&self) -> Self {
63        Self {
64            ws_sink: self.ws_sink.clone(),
65            ws_stream: self.ws_stream.clone(),
66            state: self.state.clone(),
67            connected: AtomicBool::new(self.connected.load(Ordering::Relaxed)),
68        }
69    }
70}
71
72impl Default for BinanceWebSocketManager {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl BinanceWebSocketManager {
79    /// Create a new WebSocket connection manager
80    #[must_use]
81    pub fn new() -> Self {
82        Self {
83            ws_sink: Arc::new(AsyncRwLock::new(None)),
84            ws_stream: Arc::new(AsyncRwLock::new(None)),
85            state: Arc::new(AtomicU8::new(ConnectionState::Disconnected.into())),
86            connected: AtomicBool::new(false),
87        }
88    }
89
90    /// Get current connection state
91    pub fn get_state(&self) -> ConnectionState {
92        ConnectionState::from(self.state.load(Ordering::Relaxed))
93    }
94
95    /// Set connection state
96    pub fn set_state(&self, state: ConnectionState) {
97        self.state.store(state.into(), Ordering::SeqCst);
98        self.connected
99            .store(state == ConnectionState::Connected, Ordering::SeqCst);
100    }
101
102    /// Check if connected
103    pub fn is_connected(&self) -> bool {
104        self.connected.load(Ordering::Relaxed)
105    }
106
107    /// Connect to WebSocket with automatic retry
108    pub async fn connect_with_retry(&self, ws_url: &str) -> Result<()> {
109        self.set_state(ConnectionState::Connecting);
110
111        // Create WebSocket configuration
112        let ws_config =
113            WebSocketConfig::new(rusty_common::types::Exchange::Binance, ws_url.to_string());
114
115        let mut connector = WebSocketConnector::new(
116            ws_config,
117            Arc::new(ParkingLotRwLock::new(
118                rusty_common::websocket::ConnectionStats::default(),
119            )),
120            Arc::new(ParkingLotRwLock::new(
121                rusty_common::websocket::ConnectionState::Disconnected,
122            )),
123        );
124
125        match connector.connect_with_retry(ws_url).await {
126            Ok((ws_sink, ws_stream)) => {
127                *self.ws_sink.write().await = Some(ws_sink);
128                *self.ws_stream.write().await = Some(ws_stream);
129                self.set_state(ConnectionState::Connected);
130                Ok(())
131            }
132            Err(e) => {
133                self.set_state(ConnectionState::Failed);
134                Err(anyhow!("Failed to connect to WebSocket: {}", e))
135            }
136        }
137    }
138
139    /// Disconnect WebSocket connection
140    pub async fn disconnect(&self) -> Result<()> {
141        self.set_state(ConnectionState::Disconnected);
142
143        // Close WebSocket connection if open
144        if let Some(mut ws_sink) = self.ws_sink.write().await.take() {
145            use futures::SinkExt;
146            if let Err(e) = ws_sink.close().await {
147                log::error!("Error closing WebSocket connection: {e}");
148            }
149        }
150        *self.ws_stream.write().await = None;
151
152        Ok(())
153    }
154
155    /// Get a cloned stream handle for message processing
156    pub fn get_stream_handle(&self) -> Arc<AsyncRwLock<Option<WebSocketStream>>> {
157        self.ws_stream.clone()
158    }
159
160    /// Get a cloned sink handle for sending messages
161    pub fn get_sink_handle(&self) -> Arc<AsyncRwLock<Option<WebSocketSink>>> {
162        self.ws_sink.clone()
163    }
164}
165
/// Common Binance listen key management
///
/// Creates, refreshes and deletes user-data-stream listen keys through the REST
/// endpoint `/api/v3/userDataStream`, caching the most recently created key.
pub struct BinanceListenKeyManager {
    /// HTTP client for listen key management
    http_client: reqwest::Client,

    /// Base API URL; `/api/v3/userDataStream` is appended to it for every request
    api_url: String,

    /// Current listen key
    ///
    /// `None` until a key is created and again after a successful delete. The `Arc`
    /// is shared with the background task spawned by `start_refresh_task`.
    listen_key: Arc<ParkingLotRwLock<Option<String>>>,
}
177
178impl BinanceListenKeyManager {
179    /// Create a new listen key manager
180    #[must_use]
181    pub fn new(api_url: String) -> Self {
182        Self {
183            http_client: reqwest::Client::new(),
184            api_url,
185            listen_key: Arc::new(ParkingLotRwLock::new(None)),
186        }
187    }
188
189    /// Get or create listen key
190    pub async fn get_or_create_listen_key<F>(&self, generate_headers: F) -> Result<String>
191    where
192        F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
193    {
194        // Check if we have a cached listen key
195        if let Some(key) = self.listen_key.read().clone() {
196            return Ok(key);
197        }
198
199        // Create new listen key
200        let url = format!("{}/api/v3/userDataStream", self.api_url);
201        let headers = generate_headers("POST", "/api/v3/userDataStream", None)?;
202
203        let response = self.http_client.post(&url).headers(headers).send().await?;
204
205        if response.status().is_success() {
206            let mut response_text = response.text().await?;
207            // SAFETY: response_text comes from response.text().await which is UTF-8 validated
208            let json = unsafe { simd_json::from_str::<simd_json::OwnedValue>(&mut response_text) }
209                .map_err(|e| anyhow!("Failed to parse listen key response: {}", e))?;
210
211            let listen_key = json
212                .get("listenKey")
213                .and_then(|v| v.as_str())
214                .ok_or_else(|| anyhow!("Invalid listen key response"))?
215                .to_string();
216
217            *self.listen_key.write() = Some(listen_key.clone());
218            Ok(listen_key)
219        } else {
220            let error_text = response.text().await?;
221            Err(anyhow!("Failed to create listen key: {}", error_text))
222        }
223    }
224
225    /// Refresh listen key
226    pub async fn refresh_listen_key<F>(&self, listen_key: &str, generate_headers: F) -> Result<()>
227    where
228        F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
229    {
230        let url = format!("{}/api/v3/userDataStream", self.api_url);
231        let headers = generate_headers("PUT", "/api/v3/userDataStream", None)?;
232
233        let response = self
234            .http_client
235            .put(&url)
236            .headers(headers)
237            .query(&[("listenKey", listen_key)])
238            .send()
239            .await?;
240
241        if response.status().is_success() {
242            Ok(())
243        } else {
244            let error_text = response.text().await?;
245            Err(anyhow!("Failed to refresh listen key: {}", error_text))
246        }
247    }
248
249    /// Delete listen key
250    pub async fn delete_listen_key<F>(&self, listen_key: &str, generate_headers: F) -> Result<()>
251    where
252        F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
253    {
254        let url = format!("{}/api/v3/userDataStream", self.api_url);
255        let headers = generate_headers("DELETE", "/api/v3/userDataStream", None)?;
256
257        let response = self
258            .http_client
259            .delete(&url)
260            .headers(headers)
261            .query(&[("listenKey", listen_key)])
262            .send()
263            .await?;
264
265        if response.status().is_success() {
266            *self.listen_key.write() = None;
267            Ok(())
268        } else {
269            let error_text = response.text().await?;
270            Err(anyhow!("Failed to delete listen key: {}", error_text))
271        }
272    }
273
274    /// Start automatic listen key refresh task
275    pub fn start_refresh_task<F>(&self, generate_headers: F) -> tokio::task::JoinHandle<()>
276    where
277        F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap> + Send + 'static,
278    {
279        let listen_key = self.listen_key.clone();
280        let http_client = self.http_client.clone();
281        let api_url = self.api_url.clone();
282
283        tokio::spawn(async move {
284            loop {
285                // Refresh listen key every 30 minutes (Binance requires refresh every 60 minutes)
286                tokio::time::sleep(Duration::from_secs(30 * 60)).await;
287
288                let key = {
289                    let guard = listen_key.read();
290                    guard.clone()
291                };
292
293                if let Some(key) = key {
294                    let url = format!("{api_url}/api/v3/userDataStream");
295
296                    if let Ok(headers) = generate_headers("PUT", "/api/v3/userDataStream", None) {
297                        if let Err(e) = http_client
298                            .put(&url)
299                            .headers(headers)
300                            .query(&[("listenKey", &key)])
301                            .send()
302                            .await
303                        {
304                            log::error!("Failed to refresh listen key: {e}");
305                        }
306                    } else {
307                        log::error!("Failed to generate auth headers for listen key refresh");
308                    }
309                }
310            }
311        })
312    }
313}
314
/// Choose the Binance WebSocket endpoint for a market type.
///
/// - `"futures"`, `"usdt_futures"`, `"usd_futures"` → the `fstream` futures host
/// - `"coin_futures"`, `"delivery"` → the `dstream` coin/delivery futures host
/// - any other value, and `None` → the spot host
///
/// Matching is exact and case-sensitive.
#[must_use]
pub fn get_binance_ws_url(market_type: Option<&str>) -> &'static str {
    const SPOT_URL: &str = "wss://stream.binance.com:9443/ws";
    const FUTURES_URL: &str = "wss://fstream.binance.com/ws";
    const COIN_FUTURES_URL: &str = "wss://dstream.binance.com/ws";

    let Some(kind) = market_type else {
        return SPOT_URL;
    };

    match kind {
        "futures" | "usdt_futures" | "usd_futures" => FUTURES_URL,
        "coin_futures" | "delivery" => COIN_FUTURES_URL,
        _ => SPOT_URL,
    }
}