rusty_ems/auth/
manager.rs1use super::traits::{AuthenticationContext, AuthenticationManager, AuthenticationResult};
4use crate::error::Result;
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use quanta::Clock;
8use reqwest::header::HeaderMap;
9use rusty_common::{SmartString, collections::FxHashMap};
10use sha2::{Digest, Sha256};
11use std::sync::Arc;
12use std::time::Duration;
13
14pub struct UnifiedAuthManager {
16 exchange_name: SmartString,
18 adapter: Arc<dyn AuthenticationManager>,
20 last_auth_time: Arc<RwLock<Option<u64>>>,
22 clock: Clock,
24 cached_rest_auth: Arc<RwLock<FxHashMap<SmartString, (AuthenticationResult, u64)>>>,
26 cached_ws_auth: Arc<RwLock<Option<(AuthenticationResult, u64)>>>,
28 cache_validity_ns: u64,
30}
31
32impl UnifiedAuthManager {
33 pub fn new(
35 exchange_name: impl Into<SmartString>,
36 adapter: Arc<dyn AuthenticationManager>,
37 ) -> Self {
38 Self {
39 exchange_name: exchange_name.into(),
40 adapter,
41 last_auth_time: Arc::new(RwLock::new(None)),
42 clock: Clock::new(),
43 cached_rest_auth: Arc::new(RwLock::new(FxHashMap::default())),
44 cached_ws_auth: Arc::new(RwLock::new(None)),
45 cache_validity_ns: 60_000_000_000, }
47 }
48
49 #[must_use]
51 pub const fn with_cache_validity(mut self, duration: Duration) -> Self {
52 self.cache_validity_ns = duration.as_nanos() as u64;
53 self
54 }
55
56 fn should_cache_request(&self, context: &AuthenticationContext) -> bool {
58 match context.method.as_str() {
61 "GET" | "HEAD" | "OPTIONS" => {
62 context.timestamp.is_none()
64 }
65 _ => false, }
67 }
68
69 pub fn clear_cache(&self) {
71 self.cached_rest_auth.write().clear();
72 *self.cached_ws_auth.write() = None;
73 *self.last_auth_time.write() = None;
74 }
75
76 fn is_cache_valid(&self, timestamp: u64) -> bool {
78 let now = self.clock.raw();
79 now.saturating_sub(timestamp) < self.cache_validity_ns
80 }
81
82 fn generate_rest_cache_key(&self, context: &AuthenticationContext) -> SmartString {
84 let mut key = SmartString::new();
85 key.push_str(&context.method);
86 key.push(':');
87 key.push_str(&context.path);
88
89 if let Some(ref params) = context.query_params
91 && !params.is_empty()
92 {
93 key.push('?');
94 let mut sorted_params = params.clone();
95 sorted_params.sort_by(|a, b| a.0.cmp(&b.0));
96 for (i, (k, v)) in sorted_params.iter().enumerate() {
97 if i > 0 {
98 key.push('&');
99 }
100 key.push_str(k);
101 key.push('=');
102 key.push_str(v);
103 }
104 }
105
106 if let Some(ref body) = context.body {
108 key.push('#');
109 key.push_str(&format!("{:016x}", self.hash_body(body)));
110 }
111
112 key
113 }
114
115 fn hash_body(&self, body: &str) -> u64 {
118 let mut hasher = Sha256::new();
119 hasher.update(body.as_bytes());
120 let hash_result = hasher.finalize();
121
122 let mut bytes = [0u8; 8];
124 bytes.copy_from_slice(&hash_result[..8]);
125 u64::from_be_bytes(bytes)
126 }
127
128 fn get_cached_rest_auth(
130 &self,
131 context: &AuthenticationContext,
132 ) -> Option<AuthenticationResult> {
133 let cache_key = self.generate_rest_cache_key(context);
134 let cache = self.cached_rest_auth.read();
135 if let Some((result, timestamp)) = cache.get(&cache_key)
136 && self.is_cache_valid(*timestamp)
137 {
138 return Some(result.clone());
139 }
140 None
141 }
142
143 fn set_cached_rest_auth(&self, context: &AuthenticationContext, result: AuthenticationResult) {
145 let cache_key = self.generate_rest_cache_key(context);
146 let timestamp = self.clock.raw();
147 self.cached_rest_auth
148 .write()
149 .insert(cache_key, (result, timestamp));
150 }
151
152 fn get_cached_ws_auth(&self) -> Option<AuthenticationResult> {
154 let cache = self.cached_ws_auth.read();
155 if let Some((result, timestamp)) = cache.as_ref()
156 && self.is_cache_valid(*timestamp)
157 {
158 return Some(result.clone());
159 }
160 None
161 }
162
163 fn set_cached_ws_auth(&self, result: AuthenticationResult) {
165 let timestamp = self.clock.raw();
166 *self.cached_ws_auth.write() = Some((result, timestamp));
167 }
168
169 #[must_use]
171 pub fn create_rest_context(
172 &self,
173 method: &str,
174 path: &str,
175 body: Option<&str>,
176 ) -> AuthenticationContext {
177 let mut context = AuthenticationContext::new(method, path);
178 if let Some(body) = body {
179 context = context.with_body(body);
180 }
181 context
182 }
183
184 pub async fn authenticate_simple_rest(
186 &self,
187 method: &str,
188 path: &str,
189 body: Option<&str>,
190 ) -> Result<HeaderMap> {
191 let context = self.create_rest_context(method, path, body);
192 let result = self.authenticate_rest_request(&context).await?;
193 Ok(result.headers)
194 }
195
196 #[must_use]
198 pub fn get_auth_stats(&self) -> AuthenticationStats {
199 let last_auth = *self.last_auth_time.read();
200 let now = self.clock.raw();
201
202 let time_since_last_auth =
203 last_auth.map(|time| Duration::from_nanos(now.saturating_sub(time)));
204
205 let rest_cache_valid = {
206 let cache = self.cached_rest_auth.read();
207 cache
208 .values()
209 .any(|(_, timestamp)| self.is_cache_valid(*timestamp))
210 };
211
212 let ws_cache_valid = self
213 .cached_ws_auth
214 .read()
215 .as_ref()
216 .is_some_and(|(_, timestamp)| self.is_cache_valid(*timestamp));
217
218 AuthenticationStats {
219 exchange_name: self.exchange_name.clone(),
220 last_authentication: time_since_last_auth,
221 rest_cache_valid,
222 websocket_cache_valid: ws_cache_valid,
223 cache_validity_duration: Duration::from_nanos(self.cache_validity_ns),
224 supports_websocket_trading: self.adapter.supports_websocket_trading(),
225 requires_refresh: self.adapter.requires_refresh(),
226 }
227 }
228}
229
230#[async_trait]
231impl AuthenticationManager for UnifiedAuthManager {
232 async fn authenticate_rest_request(
233 &self,
234 context: &AuthenticationContext,
235 ) -> Result<AuthenticationResult> {
236 let use_cache = self.should_cache_request(context);
238
239 if use_cache && let Some(cached) = self.get_cached_rest_auth(context) {
240 return Ok(cached);
241 }
242
243 let result = self.adapter.authenticate_rest_request(context).await?;
245
246 *self.last_auth_time.write() = Some(self.clock.raw());
248
249 if use_cache {
251 self.set_cached_rest_auth(context, result.clone());
252 }
253
254 Ok(result)
255 }
256
257 async fn authenticate_websocket(&self) -> Result<AuthenticationResult> {
258 if let Some(cached) = self.get_cached_ws_auth() {
260 return Ok(cached);
261 }
262
263 let result = self.adapter.authenticate_websocket().await?;
265
266 *self.last_auth_time.write() = Some(self.clock.raw());
268
269 self.set_cached_ws_auth(result.clone());
271
272 Ok(result)
273 }
274
275 async fn authenticate_websocket_trading(
276 &self,
277 timestamp: Option<u64>,
278 ) -> Result<AuthenticationResult> {
279 let result = self
281 .adapter
282 .authenticate_websocket_trading(timestamp)
283 .await?;
284
285 *self.last_auth_time.write() = Some(self.clock.raw());
287
288 Ok(result)
289 }
290
291 fn is_authentication_valid(&self) -> bool {
292 self.adapter.is_authentication_valid()
293 }
294
295 fn time_until_expiration(&self) -> Option<Duration> {
296 self.adapter.time_until_expiration()
297 }
298
299 async fn refresh_authentication(&self) -> Result<()> {
300 self.clear_cache();
302
303 self.adapter.refresh_authentication().await?;
305
306 *self.last_auth_time.write() = Some(self.clock.raw());
308
309 Ok(())
310 }
311
312 fn exchange_name(&self) -> &str {
313 &self.exchange_name
314 }
315
316 fn api_key(&self) -> &str {
317 self.adapter.api_key()
318 }
319
320 fn supports_websocket_trading(&self) -> bool {
321 self.adapter.supports_websocket_trading()
322 }
323
324 fn requires_refresh(&self) -> bool {
325 self.adapter.requires_refresh()
326 }
327
328 fn refresh_interval(&self) -> Option<Duration> {
329 self.adapter.refresh_interval()
330 }
331}
332
333#[derive(Debug, Clone)]
335pub struct AuthenticationStats {
336 pub exchange_name: SmartString,
338 pub last_authentication: Option<Duration>,
340 pub rest_cache_valid: bool,
342 pub websocket_cache_valid: bool,
344 pub cache_validity_duration: Duration,
346 pub supports_websocket_trading: bool,
348 pub requires_refresh: bool,
350}
351
352impl AuthenticationStats {
353 #[must_use]
355 pub fn is_healthy(&self) -> bool {
356 if let Some(last_auth) = self.last_authentication {
359 last_auth < Duration::from_secs(300)
361 } else {
362 false
364 }
365 }
366
367 #[must_use]
369 pub fn health_status(&self) -> SmartString {
370 if self.is_healthy() {
371 "Healthy".into()
372 } else if self.last_authentication.is_none() {
373 "Not authenticated".into()
374 } else {
375 "Stale authentication".into()
376 }
377 }
378}