rusty_feeder/exchange/binance/common/
rate_limit.rs

1//! Common rate limit handling for Binance API
2
3use crate::provider::prelude::*;
4use anyhow::{Result, anyhow};
5use reqwest::header::HeaderMap;
6use rusty_common::collections::FxHashMap;
7use smartstring::alias::String as SmartString;
8use std::time::Duration;
9use thiserror::Error;
10
11/// Rate limits for Binance API
12/// These rate limits are shared between API endpoints
13/// Converted to const array for compile-time evaluation
14pub const BINANCE_RATE_LIMITS: [RateLimit; 3] = [
15    RateLimit {
16        limit_type: "REQUEST_WEIGHT",
17        interval: "MINUTE",
18        interval_num: 1,
19        limit: 1200,
20    },
21    RateLimit {
22        limit_type: "ORDERS",
23        interval: "SECOND",
24        interval_num: 10,
25        limit: 50,
26    },
27    RateLimit {
28        limit_type: "ORDERS",
29        interval: "DAY",
30        interval_num: 1,
31        limit: 160000,
32    },
33];
34
35/// Binance-specific error types for rate limiting
36#[derive(Debug, Error)]
37pub enum BinanceError {
38    /// Rate limit exceeded error with retry information
39    #[error("Rate limit exceeded. Retry after: {retry_after:?}, Used weight: {used_weight:?}")]
40    RateLimit {
41        /// Duration to wait before retrying
42        retry_after: Option<Duration>,
43        /// Map of rate limit types to their used weights
44        used_weight: FxHashMap<SmartString, u64>,
45    },
46
47    /// IP address has been banned
48    #[error("IP banned for {ban_duration:?}")]
49    IpBanned {
50        /// Duration of the IP ban
51        ban_duration: Duration,
52    },
53
54    /// Invalid header value received
55    #[error("Invalid header value: {0}")]
56    InvalidHeader(String),
57}
58
59/// Parse order count headers from Binance response
60/// Headers are in format: X-MBX-ORDER-COUNT-{interval}
61/// where interval can be 1S, 10S, 1M, 1H, 1D
62pub fn parse_order_count_headers(headers: &HeaderMap) -> FxHashMap<SmartString, u64> {
63    const PREFIX: &str = "x-mbx-order-count-";
64    let mut order_counts = FxHashMap::default();
65
66    for (name, value) in headers {
67        let name_str = name.as_str();
68        // Check if header matches the pattern (case-insensitive, allocation-free)
69        if name_str.len() >= PREFIX.len() && name_str[..PREFIX.len()].eq_ignore_ascii_case(PREFIX) {
70            // Extract the interval part (e.g., "1s", "1m", "1h", "1d")
71            let interval = &name_str[PREFIX.len()..];
72            let interval_upper = interval.to_uppercase();
73
74            // Try to parse the value as u64
75            if let Ok(value_str) = value.to_str()
76                && let Ok(count) = value_str.parse::<u64>()
77            {
78                order_counts.insert(interval_upper.into(), count);
79            }
80        }
81    }
82
83    order_counts
84}
85
86/// Extract Retry-After header value from response headers
87pub fn extract_retry_after(headers: &HeaderMap) -> Option<Duration> {
88    headers
89        .get("retry-after")
90        .and_then(|value| value.to_str().ok())
91        .and_then(|value_str| value_str.parse::<u64>().ok())
92        .map(Duration::from_secs)
93}
94
95/// Parse used weight headers from Binance response
96pub fn parse_used_weight_headers(headers: &HeaderMap) -> FxHashMap<SmartString, u64> {
97    let mut used_weight = FxHashMap::default();
98
99    // Parse the basic x-mbx-used-weight header
100    if let Some(value) = headers.get("x-mbx-used-weight")
101        && let Ok(value_str) = value.to_str()
102        && let Ok(weight) = value_str.parse::<u64>()
103    {
104        used_weight.insert("total".into(), weight);
105    }
106
107    // Parse interval-specific weight headers (e.g., x-mbx-used-weight-1m)
108    for (name, value) in headers {
109        let name_str = name.as_str();
110        if name_str.len() > 18 && name_str[..18].eq_ignore_ascii_case("x-mbx-used-weight-") {
111            // Extract interval (e.g., "1m")
112            let interval = &name_str[18..];
113            let interval_upper = interval.to_uppercase();
114
115            if let Ok(value_str) = value.to_str()
116                && let Ok(weight) = value_str.parse::<u64>()
117            {
118                used_weight.insert(interval_upper.into(), weight);
119            }
120        }
121    }
122
123    used_weight
124}
125
126/// Handle rate limit errors based on HTTP status code and headers
127pub fn handle_rate_limit_error(status: u16, headers: &HeaderMap) -> Result<()> {
128    match status {
129        429 => {
130            // Rate limit exceeded
131            let retry_after = extract_retry_after(headers);
132            let used_weight = parse_used_weight_headers(headers);
133
134            Err(anyhow!(BinanceError::RateLimit {
135                retry_after,
136                used_weight,
137            }))
138        }
139        418 => {
140            // IP banned
141            let ban_duration =
142                extract_retry_after(headers).unwrap_or_else(|| Duration::from_secs(120)); // Default 2 minutes if not specified
143
144            Err(anyhow!(BinanceError::IpBanned { ban_duration }))
145        }
146        _ => Ok(()), // Not a rate limit error
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderValue;
    use std::time::Duration;

    /// Minimal stand-in for an HTTP response: only the parts the helpers read.
    struct MockResponse {
        status: u16,
        headers: HeaderMap,
    }

    /// Build a mock response with the given status and, optionally, a
    /// `retry-after` header holding `retry_after_secs`.
    fn create_error_response(status: u16, retry_after_secs: Option<u64>) -> MockResponse {
        let mut headers = HeaderMap::new();
        if let Some(secs) = retry_after_secs {
            headers.insert(
                "retry-after",
                HeaderValue::from_str(&secs.to_string()).unwrap(),
            );
        }

        MockResponse { status, headers }
    }

    #[test]
    fn test_parse_order_count_headers() {
        // Given: Response headers with order count information
        let mut headers = HeaderMap::new();
        headers.insert("x-mbx-order-count-1s", HeaderValue::from_static("2"));
        headers.insert("x-mbx-order-count-1m", HeaderValue::from_static("45"));
        headers.insert("x-mbx-order-count-1h", HeaderValue::from_static("1200"));
        headers.insert("x-mbx-order-count-1d", HeaderValue::from_static("15000"));

        // When: Parsing the headers
        let order_counts = parse_order_count_headers(&headers);

        // Then: Should extract all order counts correctly
        assert_eq!(order_counts.get("1S"), Some(&2));
        assert_eq!(order_counts.get("1M"), Some(&45));
        assert_eq!(order_counts.get("1H"), Some(&1200));
        assert_eq!(order_counts.get("1D"), Some(&15000));
    }

    #[test]
    fn test_parse_order_count_headers_mixed_case() {
        // Given: Headers with mixed case
        let mut headers = HeaderMap::new();
        headers.insert("X-MBX-ORDER-COUNT-1S", HeaderValue::from_static("5"));
        headers.insert("x-mbx-order-count-10s", HeaderValue::from_static("50"));

        // When: Parsing the headers
        let order_counts = parse_order_count_headers(&headers);

        // Then: Should handle case-insensitive headers
        assert_eq!(order_counts.get("1S"), Some(&5));
        assert_eq!(order_counts.get("10S"), Some(&50));
    }

    #[test]
    fn test_parse_order_count_headers_invalid_values() {
        // Given: Headers with invalid values
        let mut headers = HeaderMap::new();
        headers.insert("x-mbx-order-count-1s", HeaderValue::from_static("invalid"));
        headers.insert("x-mbx-order-count-1m", HeaderValue::from_static("45"));

        // When: Parsing the headers
        let order_counts = parse_order_count_headers(&headers);

        // Then: Should skip invalid values and parse valid ones
        assert_eq!(order_counts.get("1S"), None);
        assert_eq!(order_counts.get("1M"), Some(&45));
    }

    #[test]
    fn test_parse_order_count_headers_ignores_unrelated_headers() {
        // Given: Headers that do not carry order counts
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/json"));
        headers.insert("x-mbx-used-weight-1m", HeaderValue::from_static("10"));

        // When: Parsing the headers
        let order_counts = parse_order_count_headers(&headers);

        // Then: Nothing should be extracted
        assert!(order_counts.is_empty());
    }

    #[test]
    fn test_handle_retry_after_header() {
        // Given: 429 response with Retry-After header
        let mut headers = HeaderMap::new();
        headers.insert("retry-after", HeaderValue::from_static("30"));

        let response = MockResponse {
            status: 429,
            headers,
        };

        // When: Processing the rate limit response
        let retry_after = extract_retry_after(&response.headers);

        // Then: Should extract retry duration
        assert_eq!(response.status, 429);
        assert_eq!(retry_after, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_handle_retry_after_header_missing() {
        // Given: Response without Retry-After header
        let headers = HeaderMap::new();

        // When: Processing the response
        let retry_after = extract_retry_after(&headers);

        // Then: Should return None
        assert_eq!(retry_after, None);
    }

    #[test]
    fn test_handle_retry_after_header_invalid() {
        // Given: Response with invalid Retry-After header
        let mut headers = HeaderMap::new();
        headers.insert("retry-after", HeaderValue::from_static("not-a-number"));

        // When: Processing the response
        let retry_after = extract_retry_after(&headers);

        // Then: Should return None for invalid values
        assert_eq!(retry_after, None);
    }

    #[test]
    fn test_ip_ban_vs_rate_limit_handling() {
        // Test 1: Rate limit (429)
        let rate_limit_response = create_error_response(429, Some(30));
        let error =
            handle_rate_limit_error(rate_limit_response.status, &rate_limit_response.headers)
                .unwrap_err();

        match error.downcast_ref::<BinanceError>() {
            Some(BinanceError::RateLimit { retry_after, .. }) => {
                assert_eq!(*retry_after, Some(Duration::from_secs(30)));
            }
            _ => panic!("Expected RateLimit error"),
        }

        // Test 2: IP ban (418)
        let ip_ban_response = create_error_response(418, Some(120));
        let error =
            handle_rate_limit_error(ip_ban_response.status, &ip_ban_response.headers).unwrap_err();

        match error.downcast_ref::<BinanceError>() {
            Some(BinanceError::IpBanned { ban_duration, .. }) => {
                assert_eq!(*ban_duration, Duration::from_secs(120));
            }
            _ => panic!("Expected IpBanned error"),
        }
    }

    #[test]
    fn test_ip_ban_uses_default_duration_without_retry_after() {
        // Given: 418 response that does not say how long the ban lasts
        let response = create_error_response(418, None);

        // When: Handling the response
        let error = handle_rate_limit_error(response.status, &response.headers).unwrap_err();

        // Then: Should fall back to the 2 minute default
        match error.downcast_ref::<BinanceError>() {
            Some(BinanceError::IpBanned { ban_duration }) => {
                assert_eq!(*ban_duration, Duration::from_secs(120));
            }
            _ => panic!("Expected IpBanned error"),
        }
    }

    #[test]
    fn test_rate_limit_error_carries_used_weight() {
        // Given: 429 response with a used-weight header but no Retry-After
        let mut response = create_error_response(429, None);
        response
            .headers
            .insert("x-mbx-used-weight-1m", HeaderValue::from_static("1200"));

        // When: Handling the response
        let error = handle_rate_limit_error(response.status, &response.headers).unwrap_err();

        // Then: The error should expose the parsed weight and no retry hint
        match error.downcast_ref::<BinanceError>() {
            Some(BinanceError::RateLimit {
                retry_after,
                used_weight,
            }) => {
                assert_eq!(*retry_after, None);
                assert_eq!(used_weight.get("1M"), Some(&1200));
            }
            _ => panic!("Expected RateLimit error"),
        }
    }

    #[test]
    fn test_non_rate_limit_status_is_ok() {
        // Given: Statuses unrelated to rate limiting, even with Retry-After set
        for status in [200_u16, 400, 404, 500, 503] {
            let response = create_error_response(status, Some(30));

            // When / Then: They should not be reported as rate limit errors
            assert!(handle_rate_limit_error(response.status, &response.headers).is_ok());
        }
    }

    #[test]
    fn test_parse_used_weight_header() {
        // Given: Response with used weight header
        let mut headers = HeaderMap::new();
        headers.insert("x-mbx-used-weight", HeaderValue::from_static("75"));
        headers.insert("x-mbx-used-weight-1m", HeaderValue::from_static("1100"));

        // When: Parsing used weight
        let used_weight = parse_used_weight_headers(&headers);

        // Then: Should extract weight values
        assert_eq!(used_weight.get("total"), Some(&75));
        assert_eq!(used_weight.get("1M"), Some(&1100));
    }

    #[test]
    fn test_parse_used_weight_header_invalid_and_unrelated() {
        // Given: An unparsable weight and a header from another family
        let mut headers = HeaderMap::new();
        headers.insert("x-mbx-used-weight-1m", HeaderValue::from_static("lots"));
        headers.insert("x-mbx-order-count-1s", HeaderValue::from_static("3"));

        // When: Parsing used weight
        let used_weight = parse_used_weight_headers(&headers);

        // Then: Nothing should be extracted
        assert!(used_weight.is_empty());
    }
}
309}