1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
3use std::time::Duration;
4
5use anyhow::{Result, anyhow};
6use parking_lot::RwLock as ParkingLotRwLock;
7use rusty_common::websocket::connector::{WebSocketSink, WebSocketStream};
8use rusty_common::websocket::{WebSocketConfig, WebSocketConnector};
9use simd_json::prelude::{ValueAsScalar, ValueObjectAccess};
10use tokio::sync::RwLock as AsyncRwLock;
11
/// Lifecycle state of a Binance WebSocket connection.
///
/// The explicit discriminants are the values that get stored in (and read back
/// from) an `AtomicU8`, so they must stay stable. `#[repr(u8)]` pins the
/// in-memory representation to the same width that is used for that storage.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is established. This is also the fallback for unknown
    /// raw values when converting from `u8`.
    Disconnected = 0,
    /// A connection attempt is in progress.
    Connecting = 1,
    /// The connection is established.
    Connected = 2,
    /// A reconnect is in progress.
    // NOTE(review): nothing in this file sets this state; presumably used by callers.
    Reconnecting = 3,
    /// The most recent connection attempt failed.
    Failed = 4,
}
26
27impl From<u8> for ConnectionState {
28 fn from(value: u8) -> Self {
29 match value {
30 0 => Self::Disconnected,
31 1 => Self::Connecting,
32 2 => Self::Connected,
33 3 => Self::Reconnecting,
34 4 => Self::Failed,
35 _ => Self::Disconnected,
36 }
37 }
38}
39
40impl From<ConnectionState> for u8 {
41 fn from(state: ConnectionState) -> Self {
42 state as Self
43 }
44}
45
46pub struct BinanceWebSocketManager {
48 pub ws_sink: Arc<AsyncRwLock<Option<WebSocketSink>>>,
50
51 pub ws_stream: Arc<AsyncRwLock<Option<WebSocketStream>>>,
53
54 pub state: Arc<AtomicU8>,
56
57 pub connected: AtomicBool,
59}
60
61impl Clone for BinanceWebSocketManager {
62 fn clone(&self) -> Self {
63 Self {
64 ws_sink: self.ws_sink.clone(),
65 ws_stream: self.ws_stream.clone(),
66 state: self.state.clone(),
67 connected: AtomicBool::new(self.connected.load(Ordering::Relaxed)),
68 }
69 }
70}
71
72impl Default for BinanceWebSocketManager {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl BinanceWebSocketManager {
79 #[must_use]
81 pub fn new() -> Self {
82 Self {
83 ws_sink: Arc::new(AsyncRwLock::new(None)),
84 ws_stream: Arc::new(AsyncRwLock::new(None)),
85 state: Arc::new(AtomicU8::new(ConnectionState::Disconnected.into())),
86 connected: AtomicBool::new(false),
87 }
88 }
89
90 pub fn get_state(&self) -> ConnectionState {
92 ConnectionState::from(self.state.load(Ordering::Relaxed))
93 }
94
95 pub fn set_state(&self, state: ConnectionState) {
97 self.state.store(state.into(), Ordering::SeqCst);
98 self.connected
99 .store(state == ConnectionState::Connected, Ordering::SeqCst);
100 }
101
102 pub fn is_connected(&self) -> bool {
104 self.connected.load(Ordering::Relaxed)
105 }
106
107 pub async fn connect_with_retry(&self, ws_url: &str) -> Result<()> {
109 self.set_state(ConnectionState::Connecting);
110
111 let ws_config =
113 WebSocketConfig::new(rusty_common::types::Exchange::Binance, ws_url.to_string());
114
115 let mut connector = WebSocketConnector::new(
116 ws_config,
117 Arc::new(ParkingLotRwLock::new(
118 rusty_common::websocket::ConnectionStats::default(),
119 )),
120 Arc::new(ParkingLotRwLock::new(
121 rusty_common::websocket::ConnectionState::Disconnected,
122 )),
123 );
124
125 match connector.connect_with_retry(ws_url).await {
126 Ok((ws_sink, ws_stream)) => {
127 *self.ws_sink.write().await = Some(ws_sink);
128 *self.ws_stream.write().await = Some(ws_stream);
129 self.set_state(ConnectionState::Connected);
130 Ok(())
131 }
132 Err(e) => {
133 self.set_state(ConnectionState::Failed);
134 Err(anyhow!("Failed to connect to WebSocket: {}", e))
135 }
136 }
137 }
138
139 pub async fn disconnect(&self) -> Result<()> {
141 self.set_state(ConnectionState::Disconnected);
142
143 if let Some(mut ws_sink) = self.ws_sink.write().await.take() {
145 use futures::SinkExt;
146 if let Err(e) = ws_sink.close().await {
147 log::error!("Error closing WebSocket connection: {e}");
148 }
149 }
150 *self.ws_stream.write().await = None;
151
152 Ok(())
153 }
154
155 pub fn get_stream_handle(&self) -> Arc<AsyncRwLock<Option<WebSocketStream>>> {
157 self.ws_stream.clone()
158 }
159
160 pub fn get_sink_handle(&self) -> Arc<AsyncRwLock<Option<WebSocketSink>>> {
162 self.ws_sink.clone()
163 }
164}
165
166pub struct BinanceListenKeyManager {
168 http_client: reqwest::Client,
170
171 api_url: String,
173
174 listen_key: Arc<ParkingLotRwLock<Option<String>>>,
176}
177
178impl BinanceListenKeyManager {
179 #[must_use]
181 pub fn new(api_url: String) -> Self {
182 Self {
183 http_client: reqwest::Client::new(),
184 api_url,
185 listen_key: Arc::new(ParkingLotRwLock::new(None)),
186 }
187 }
188
189 pub async fn get_or_create_listen_key<F>(&self, generate_headers: F) -> Result<String>
191 where
192 F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
193 {
194 if let Some(key) = self.listen_key.read().clone() {
196 return Ok(key);
197 }
198
199 let url = format!("{}/api/v3/userDataStream", self.api_url);
201 let headers = generate_headers("POST", "/api/v3/userDataStream", None)?;
202
203 let response = self.http_client.post(&url).headers(headers).send().await?;
204
205 if response.status().is_success() {
206 let mut response_text = response.text().await?;
207 let json = unsafe { simd_json::from_str::<simd_json::OwnedValue>(&mut response_text) }
209 .map_err(|e| anyhow!("Failed to parse listen key response: {}", e))?;
210
211 let listen_key = json
212 .get("listenKey")
213 .and_then(|v| v.as_str())
214 .ok_or_else(|| anyhow!("Invalid listen key response"))?
215 .to_string();
216
217 *self.listen_key.write() = Some(listen_key.clone());
218 Ok(listen_key)
219 } else {
220 let error_text = response.text().await?;
221 Err(anyhow!("Failed to create listen key: {}", error_text))
222 }
223 }
224
225 pub async fn refresh_listen_key<F>(&self, listen_key: &str, generate_headers: F) -> Result<()>
227 where
228 F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
229 {
230 let url = format!("{}/api/v3/userDataStream", self.api_url);
231 let headers = generate_headers("PUT", "/api/v3/userDataStream", None)?;
232
233 let response = self
234 .http_client
235 .put(&url)
236 .headers(headers)
237 .query(&[("listenKey", listen_key)])
238 .send()
239 .await?;
240
241 if response.status().is_success() {
242 Ok(())
243 } else {
244 let error_text = response.text().await?;
245 Err(anyhow!("Failed to refresh listen key: {}", error_text))
246 }
247 }
248
249 pub async fn delete_listen_key<F>(&self, listen_key: &str, generate_headers: F) -> Result<()>
251 where
252 F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap>,
253 {
254 let url = format!("{}/api/v3/userDataStream", self.api_url);
255 let headers = generate_headers("DELETE", "/api/v3/userDataStream", None)?;
256
257 let response = self
258 .http_client
259 .delete(&url)
260 .headers(headers)
261 .query(&[("listenKey", listen_key)])
262 .send()
263 .await?;
264
265 if response.status().is_success() {
266 *self.listen_key.write() = None;
267 Ok(())
268 } else {
269 let error_text = response.text().await?;
270 Err(anyhow!("Failed to delete listen key: {}", error_text))
271 }
272 }
273
274 pub fn start_refresh_task<F>(&self, generate_headers: F) -> tokio::task::JoinHandle<()>
276 where
277 F: Fn(&str, &str, Option<&str>) -> Result<reqwest::header::HeaderMap> + Send + 'static,
278 {
279 let listen_key = self.listen_key.clone();
280 let http_client = self.http_client.clone();
281 let api_url = self.api_url.clone();
282
283 tokio::spawn(async move {
284 loop {
285 tokio::time::sleep(Duration::from_secs(30 * 60)).await;
287
288 let key = {
289 let guard = listen_key.read();
290 guard.clone()
291 };
292
293 if let Some(key) = key {
294 let url = format!("{api_url}/api/v3/userDataStream");
295
296 if let Ok(headers) = generate_headers("PUT", "/api/v3/userDataStream", None) {
297 if let Err(e) = http_client
298 .put(&url)
299 .headers(headers)
300 .query(&[("listenKey", &key)])
301 .send()
302 .await
303 {
304 log::error!("Failed to refresh listen key: {e}");
305 }
306 } else {
307 log::error!("Failed to generate auth headers for listen key refresh");
308 }
309 }
310 }
311 })
312 }
313}
314
/// Picks the Binance WebSocket base URL for a market type.
///
/// * `"futures"`, `"usdt_futures"`, `"usd_futures"` select the `fstream` host.
/// * `"coin_futures"`, `"delivery"` select the `dstream` host.
/// * `None` and every other string (matching is case-sensitive) select the
///   `stream` host on port 9443.
#[must_use]
pub fn get_binance_ws_url(market_type: Option<&str>) -> &'static str {
    const STREAM_URL: &str = "wss://stream.binance.com:9443/ws";
    const FSTREAM_URL: &str = "wss://fstream.binance.com/ws";
    const DSTREAM_URL: &str = "wss://dstream.binance.com/ws";

    // A missing market type behaves like an unrecognised one, so fold `None`
    // into the empty string and let the fallback arm handle both.
    match market_type.unwrap_or_default() {
        "futures" | "usdt_futures" | "usd_futures" => FSTREAM_URL,
        "coin_futures" | "delivery" => DSTREAM_URL,
        _ => STREAM_URL,
    }
}