rusty_common/auth/exchanges/
upbit.rs

1//! Upbit authentication implementation
2
3use crate::Result;
4use crate::SmartString;
5use crate::auth::hmac::build_sorted_query_smartstring;
6use crate::auth::jwt::{JwtClaims, generate_jwt_hs256};
7use crate::auth::signature::sha512_hash;
8use crate::auth::{AuthConfig, ExchangeAuth};
9use crate::collections::FxHashMap;
10use uuid::Uuid;
11
12/// Upbit authentication config
13#[derive(Debug, Clone)]
14pub struct UpbitAuthConfig {
15    /// The access key for the Upbit API.
16    pub access_key: SmartString,
17    /// The secret key for the Upbit API.
18    pub secret_key: SmartString,
19}
20
21impl UpbitAuthConfig {
22    /// Create a new authentication configuration
23    #[must_use]
24    pub const fn new(access_key: SmartString, secret_key: SmartString) -> Self {
25        Self {
26            access_key,
27            secret_key,
28        }
29    }
30}
31
32/// Upbit authentication handler
33#[derive(Debug, Clone)]
34pub struct UpbitAuth {
35    config: AuthConfig,
36}
37
38impl UpbitAuth {
39    /// Create new Upbit auth instance
40    #[must_use]
41    pub fn new(config: UpbitAuthConfig) -> Self {
42        Self {
43            config: AuthConfig::new(config.access_key, config.secret_key),
44        }
45    }
46
47    /// Generate JWT token for REST API GET/DELETE requests
48    pub fn generate_rest_jwt_get(
49        &self,
50        query_params: Option<&[(&str, &str)]>,
51    ) -> Result<SmartString> {
52        // Create JWT claims
53        let mut claims = JwtClaims::new(self.config.api_key.clone());
54        claims.nonce = Uuid::new_v4().to_string().into();
55
56        // Add query hash for GET/DELETE with params
57        if let Some(params) = query_params
58            && !params.is_empty()
59        {
60            let query_smartstring = build_sorted_query_smartstring(params)?;
61            let query_hash = sha512_hash(&query_smartstring);
62            claims = claims.with_query_hash(query_hash);
63        }
64
65        // Generate JWT token
66        generate_jwt_hs256(&claims, &self.config.secret_key)
67    }
68
69    /// Generate JWT token for REST API POST/PUT requests
70    pub fn generate_rest_jwt_post(&self, body: &str) -> Result<SmartString> {
71        // Create JWT claims
72        let mut claims = JwtClaims::new(self.config.api_key.clone());
73        claims.nonce = Uuid::new_v4().to_string().into();
74
75        // Add body hash
76        let body_hash = sha512_hash(body);
77        claims = claims.with_body_hash(body_hash);
78
79        // Generate JWT token
80        generate_jwt_hs256(&claims, &self.config.secret_key)
81    }
82}
83
84impl ExchangeAuth for UpbitAuth {
85    fn generate_authentication_headers(
86        &self,
87        method: &str,
88        _path: &str,
89        params: Option<(&str, &str)>,
90    ) -> Result<FxHashMap<SmartString, SmartString>> {
91        let mut headers = FxHashMap::default();
92
93        // Create JWT claims
94        let mut claims = JwtClaims::new(self.config.api_key.clone());
95        claims.nonce = Uuid::new_v4().to_string().into();
96
97        // Add query hash for GET/DELETE with params
98        if (method == "GET" || method == "DELETE")
99            && let Some(params) = params
100        {
101            let query_smartstring = build_sorted_query_smartstring(&[params])?;
102            let query_hash = sha512_hash(&query_smartstring);
103            claims = claims.with_query_hash(query_hash);
104        }
105
106        // Add body hash for POST/PUT
107        // Note: In actual implementation, body would be passed as parameter
108
109        // Generate JWT token
110        let token = generate_jwt_hs256(&claims, &self.config.secret_key)?;
111
112        headers.insert("Authorization".into(), format!("Bearer {token}").into());
113        headers.insert("Accept".into(), "application/json".into());
114
115        Ok(headers)
116    }
117
118    fn generate_websocket_authentication(&self) -> Result<SmartString> {
119        // Generate JWT for WebSocket authentication
120        let mut claims = JwtClaims::new(self.config.api_key.clone());
121        claims.nonce = Uuid::new_v4().to_string().into(); // Required for Upbit WebSocket
122        generate_jwt_hs256(&claims, &self.config.secret_key)
123    }
124
125    fn api_key(&self) -> &str {
126        &self.config.api_key
127    }
128}
129
/// Unit tests for Upbit authentication.
///
/// Run with:
/// ```sh
/// cargo test upbit
/// ```
#[cfg(test)]
mod test {
    use super::*;
    use crate::auth::ExchangeAuth;
    use crate::collections::FxHashSet;

    fn create_test_auth() -> UpbitAuth {
        let config = UpbitAuthConfig::new("test_access_key".into(), "test_secret_key".into());
        UpbitAuth::new(config)
    }

    /// Assert that `token` looks like a compact JWT: exactly three non-empty,
    /// dot-separated segments (`header.payload.signature`).
    fn assert_is_jwt(token: &str) {
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "expected header.payload.signature, got {token:?}");
        assert!(
            parts.iter().all(|part| !part.is_empty()),
            "JWT has an empty segment: {token:?}"
        );
    }

    /// Look up a header that must be present.
    fn header<'a>(
        headers: &'a FxHashMap<SmartString, SmartString>,
        name: &str,
    ) -> &'a SmartString {
        headers
            .get(&SmartString::from(name))
            .unwrap_or_else(|| panic!("missing {name} header"))
    }

    /// Assert that `headers` holds a `Bearer` JWT and a JSON `Accept` header.
    fn assert_auth_headers(headers: &FxHashMap<SmartString, SmartString>) {
        let token = header(headers, "Authorization")
            .strip_prefix("Bearer ")
            .expect("Authorization header must use the Bearer scheme");
        assert_is_jwt(token);
        assert_eq!(
            header(headers, "Accept"),
            &SmartString::from("application/json")
        );
    }

    #[test]
    fn test_auth_normal() {
        let upbit_auth = UpbitAuth::new(UpbitAuthConfig::new(
            "test_public_key".into(),
            "test_private_key".into(),
        ));

        assert!(upbit_auth.generate_rest_jwt_get(Some(&[("key", "value")])).is_ok());
        assert!(upbit_auth.generate_rest_jwt_post("body").is_ok());
        assert!(
            upbit_auth
                .generate_authentication_headers("POST", "path", Some(("key", "value")))
                .is_ok()
        );
        assert!(upbit_auth.generate_websocket_authentication().is_ok());
    }

    #[test]
    fn test_auth_without_query_params() {
        let upbit_auth = UpbitAuth::new(UpbitAuthConfig::new(
            "test_public_key".into(),
            "test_private_key".into(),
        ));

        assert!(upbit_auth.generate_rest_jwt_get(None).is_ok());
    }

    // === Extended tests ===
    // Edge cases, uniqueness and concurrency of token generation.

    #[test]
    fn test_upbit_auth_config_creation() {
        let config = UpbitAuthConfig::new("access_key".into(), "secret_key".into());
        assert_eq!(config.access_key, "access_key");
        assert_eq!(config.secret_key, "secret_key");
    }

    #[test]
    fn test_upbit_auth_config_clone() {
        let config = UpbitAuthConfig::new("access_key".into(), "secret_key".into());
        let config_clone = config.clone();
        assert_eq!(config.access_key, config_clone.access_key);
        assert_eq!(config.secret_key, config_clone.secret_key);
    }

    #[test]
    fn test_upbit_auth_creation() {
        let auth = create_test_auth();
        assert_eq!(auth.api_key(), "test_access_key");
    }

    #[test]
    fn test_generate_rest_jwt_get_empty_params() {
        let auth = create_test_auth();
        let jwt = auth.generate_rest_jwt_get(Some(&[])).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_rest_jwt_get_with_params() {
        let auth = create_test_auth();
        let params = &[("market", "KRW-BTC"), ("count", "10")];
        let jwt = auth.generate_rest_jwt_get(Some(params)).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_rest_jwt_get_none_params() {
        let auth = create_test_auth();
        let jwt = auth.generate_rest_jwt_get(None).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_rest_jwt_get_different_params_different_tokens() {
        let auth = create_test_auth();
        let params1 = &[("market", "KRW-BTC")];
        let params2 = &[("market", "KRW-ETH")];

        let jwt1 = auth.generate_rest_jwt_get(Some(params1)).unwrap();
        let jwt2 = auth.generate_rest_jwt_get(Some(params2)).unwrap();

        assert_ne!(jwt1, jwt2);
    }

    #[test]
    fn test_generate_rest_jwt_post_empty_body() {
        let auth = create_test_auth();
        let jwt = auth.generate_rest_jwt_post("").unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_rest_jwt_post_with_body() {
        let auth = create_test_auth();
        let body = r#"{"market":"KRW-BTC","side":"bid","volume":"0.01","price":"1000000","ord_type":"limit"}"#;
        let jwt = auth.generate_rest_jwt_post(body).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_rest_jwt_post_different_bodies_different_tokens() {
        let auth = create_test_auth();
        let body1 = r#"{"market":"KRW-BTC","side":"bid"}"#;
        let body2 = r#"{"market":"KRW-ETH","side":"ask"}"#;

        let jwt1 = auth.generate_rest_jwt_post(body1).unwrap();
        let jwt2 = auth.generate_rest_jwt_post(body2).unwrap();

        assert_ne!(jwt1, jwt2);
    }

    #[test]
    fn test_generate_rest_jwt_post_json_body() {
        let auth = create_test_auth();
        let json_body = r#"{"market":"KRW-BTC","side":"bid","volume":"0.01","price":"1000000","ord_type":"limit"}"#;
        let jwt = auth.generate_rest_jwt_post(json_body).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_generate_websocket_authentication() {
        let auth = create_test_auth();
        let ws_auth = auth.generate_websocket_authentication().unwrap();
        assert_is_jwt(&ws_auth);
    }

    #[test]
    fn test_generate_websocket_authentication_multiple_calls() {
        let auth = create_test_auth();
        let ws_auth1 = auth.generate_websocket_authentication().unwrap();
        let ws_auth2 = auth.generate_websocket_authentication().unwrap();

        // The UUID nonce makes every token unique.
        assert_ne!(ws_auth1, ws_auth2);
    }

    #[test]
    fn test_generate_authentication_headers_get() {
        let auth = create_test_auth();
        let headers = auth
            .generate_authentication_headers("GET", "/v1/accounts", None)
            .unwrap();
        assert_auth_headers(&headers);
    }

    #[test]
    fn test_generate_authentication_headers_get_with_params() {
        let auth = create_test_auth();
        let headers = auth
            .generate_authentication_headers(
                "GET",
                "/v1/candles/minutes/1",
                Some(("market", "KRW-BTC")),
            )
            .unwrap();

        assert_auth_headers(&headers);
        assert!(header(&headers, "Authorization").len() > 20);
    }

    #[test]
    fn test_generate_authentication_headers_delete() {
        let auth = create_test_auth();
        let headers = auth
            .generate_authentication_headers("DELETE", "/v1/order", Some(("uuid", "test-uuid")))
            .unwrap();
        assert_auth_headers(&headers);
    }

    #[test]
    fn test_generate_authentication_headers_post() {
        let auth = create_test_auth();
        let headers = auth
            .generate_authentication_headers("POST", "/v1/orders", None)
            .unwrap();
        assert_auth_headers(&headers);
    }

    #[test]
    fn test_generate_authentication_headers_put() {
        let auth = create_test_auth();
        let headers = auth
            .generate_authentication_headers("PUT", "/v1/order", None)
            .unwrap();
        assert_auth_headers(&headers);
    }

    #[test]
    fn test_generate_authentication_headers_different_methods() {
        let auth = create_test_auth();
        let get_headers = auth
            .generate_authentication_headers("GET", "/v1/accounts", None)
            .unwrap();
        let post_headers = auth
            .generate_authentication_headers("POST", "/v1/orders", None)
            .unwrap();

        // Same header set, but each token has its own nonce.
        assert_eq!(get_headers.len(), post_headers.len());
        assert_auth_headers(&get_headers);
        assert_auth_headers(&post_headers);
        assert_ne!(
            header(&get_headers, "Authorization"),
            header(&post_headers, "Authorization")
        );
    }

    #[test]
    fn test_api_key_getter() {
        let auth = create_test_auth();
        assert_eq!(auth.api_key(), "test_access_key");
    }

    #[test]
    fn test_jwt_token_uniqueness() {
        let auth = create_test_auth();
        let mut tokens = FxHashSet::default();

        for _ in 0..100 {
            let jwt = auth.generate_rest_jwt_get(None).unwrap();
            assert!(tokens.insert(jwt), "duplicate JWT generated");
        }
    }

    #[test]
    fn test_jwt_token_format() {
        let auth = create_test_auth();
        let jwt = auth.generate_rest_jwt_get(None).unwrap();
        assert_is_jwt(&jwt);

        // Segments should use base64/base64url characters; padding and the
        // standard `+`/`/` alphabet are tolerated.
        let is_base64_char =
            |c: char| c.is_alphanumeric() || matches!(c, '-' | '_' | '=' | '+' | '/');
        for part in jwt.split('.') {
            assert!(part.chars().all(is_base64_char), "unexpected character in {part:?}");
        }
    }

    #[test]
    fn test_query_params_encoding() {
        let auth = create_test_auth();

        // Ordinary order-style parameters signed twice.
        let params = &[("market", "KRW-BTC"), ("side", "bid"), ("volume", "0.01")];
        let jwt1 = auth.generate_rest_jwt_get(Some(params)).unwrap();
        let jwt2 = auth.generate_rest_jwt_get(Some(params)).unwrap();

        // Identical parameters still yield distinct tokens because of the UUID nonce.
        assert_ne!(jwt1, jwt2);
        assert_is_jwt(&jwt1);
        assert_is_jwt(&jwt2);
    }

    #[test]
    fn test_korean_market_symbols() {
        let auth = create_test_auth();

        // Parameters for a KRW (Korean won) quoted market.
        let params = &[
            ("market", "KRW-BTC"),
            ("count", "200"),
            ("to", "2023-01-01T00:00:00Z"),
        ];
        let jwt = auth.generate_rest_jwt_get(Some(params)).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_multiple_parameter_combinations() {
        let auth = create_test_auth();

        let test_cases = vec![
            vec![("market", "KRW-BTC")],
            vec![("market", "KRW-ETH"), ("count", "100")],
            vec![
                ("market", "KRW-ADA"),
                ("count", "50"),
                ("to", "2023-01-01T00:00:00Z"),
            ],
            vec![("uuids", "uuid1"), ("uuids", "uuid2")], // Repeated key
        ];

        for params in test_cases {
            let jwt = auth.generate_rest_jwt_get(Some(&params)).unwrap();
            assert_is_jwt(&jwt);
        }
    }

    #[test]
    fn test_large_parameter_values() {
        let auth = create_test_auth();

        let large_value = "a".repeat(1000);
        let params = &[("market", "KRW-BTC"), ("identifier", large_value.as_str())];
        let jwt = auth.generate_rest_jwt_get(Some(params)).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_large_request_body() {
        let auth = create_test_auth();

        let large_body = format!(
            "{{\"market\":\"KRW-BTC\",\"side\":\"bid\",\"volume\":\"0.01\",\"price\":\"1000000\",\"ord_type\":\"limit\",\"notes\":\"{}\"}}",
            "x".repeat(1000)
        );
        let jwt = auth.generate_rest_jwt_post(&large_body).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_special_characters_in_parameters() {
        let auth = create_test_auth();

        // Values containing '@', spaces and punctuation.
        let params = &[
            ("market", "KRW-BTC"),
            ("identifier", "user@example.com"),
            ("note", "Hello World!"),
        ];
        let jwt = auth.generate_rest_jwt_get(Some(params)).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_empty_parameter_values() {
        let auth = create_test_auth();

        let params = &[("market", "KRW-BTC"), ("note", "")];
        let jwt = auth.generate_rest_jwt_get(Some(params)).unwrap();
        assert_is_jwt(&jwt);
    }

    #[test]
    fn test_debug_implementation() {
        let auth = create_test_auth();
        let debug_output = format!("{auth:?}");
        assert!(debug_output.contains("UpbitAuth"));
    }

    #[test]
    fn test_clone_implementation() {
        let auth = create_test_auth();
        let auth_clone = auth.clone();

        assert_eq!(auth.api_key(), auth_clone.api_key());

        let jwt1 = auth.generate_rest_jwt_get(None).unwrap();
        let jwt2 = auth_clone.generate_rest_jwt_get(None).unwrap();

        // Both clones sign valid tokens, distinct because of the UUID nonce.
        assert_ne!(jwt1, jwt2);
        assert_is_jwt(&jwt1);
        assert_is_jwt(&jwt2);
    }

    #[test]
    fn test_exchange_auth_trait_implementation() {
        let auth = create_test_auth();

        assert_eq!(auth.api_key(), "test_access_key");

        let headers = auth
            .generate_authentication_headers("GET", "/v1/accounts", None)
            .unwrap();
        assert_auth_headers(&headers);

        let ws_auth = auth.generate_websocket_authentication().unwrap();
        assert_is_jwt(&ws_auth);
    }

    #[test]
    fn test_concurrent_token_generation() {
        use std::thread;

        let auth = create_test_auth();
        let auth = &auth;

        // Scoped threads may borrow `auth` directly, so no `Arc` is needed.
        let tokens: Vec<SmartString> = thread::scope(|scope| {
            let handles: Vec<_> = (0..10)
                .map(|_| scope.spawn(move || auth.generate_rest_jwt_get(None).unwrap()))
                .collect();
            handles
                .into_iter()
                .map(|handle| handle.join().unwrap())
                .collect()
        });

        let mut unique = FxHashSet::default();
        for jwt in tokens {
            assert_is_jwt(&jwt);
            unique.insert(jwt);
        }
        assert_eq!(unique.len(), 10);
    }

    #[test]
    fn test_consistency_across_calls() {
        let auth = create_test_auth();

        for _ in 0..50 {
            assert_is_jwt(&auth.generate_rest_jwt_get(None).unwrap());
            assert_is_jwt(&auth.generate_websocket_authentication().unwrap());

            let headers = auth
                .generate_authentication_headers("GET", "/v1/accounts", None)
                .unwrap();
            assert_eq!(headers.len(), 2);
        }
    }
}