1use std::env;
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use crate::execution_engine::Exchange;
10use async_trait::async_trait;
11use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
12use hmac::{Hmac, Mac};
13use log::error;
14use quanta::Clock;
15use reqwest::Client;
16use rusty_model::{
17 enums::{OrderSide, OrderType},
18 trading_order::Order,
19};
20use serde::{Deserialize, Serialize};
21use sha2::{Digest, Sha256, Sha512};
22use simd_json::json;
23type HmacSha256 = Hmac<Sha256>;
26
27const API_URL: &str = "https://api.upbit.com";
29
30pub struct UpbitExchange {
32 client: UpbitClient,
34}
35
36impl UpbitExchange {
37 #[must_use]
39 pub fn new() -> Self {
40 Self {
41 client: UpbitClient::new(),
42 }
43 }
44
45 #[must_use]
47 pub fn with_credentials(
48 access_key: std::string::String,
49 secret_key: std::string::String,
50 ) -> Self {
51 Self {
52 client: UpbitClient::with_credentials(access_key, secret_key),
53 }
54 }
55}
56
57impl Default for UpbitExchange {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63#[async_trait]
64impl Exchange for UpbitExchange {
65 async fn send_order(&self, order: Order) -> crate::Result<()> {
66 self.client
67 .place_order(order)
68 .await
69 .map_err(|e| crate::OmsError::Exchange(e.to_string().into()))
70 }
71
72 async fn cancel_order(&self, order_id: std::string::String) -> crate::Result<()> {
73 self.client
74 .cancel_order(order_id)
75 .await
76 .map_err(|e| crate::OmsError::Exchange(e.to_string().into()))
77 }
78
79 async fn get_order_status(
80 &self,
81 order_id: std::string::String,
82 ) -> crate::Result<std::string::String> {
83 return self
84 .client
85 .get_order_status(order_id)
86 .await
87 .map_err(|e| crate::OmsError::Exchange(e.to_string().into()));
88 }
89}
90
91#[derive(Debug, Serialize, Deserialize)]
93pub struct UpbitCredentials {
94 access_key: std::string::String,
96 secret_key: std::string::String,
98}
99
100impl UpbitCredentials {
101 #[must_use]
103 pub fn new() -> Self {
104 Self {
105 access_key: env::var("UPBIT_ACCESS_KEY").expect("UPBIT_ACCESS_KEY not found in .env"),
106 secret_key: env::var("UPBIT_SECRET_KEY").expect("UPBIT_SECRET_KEY not found in .env"),
107 }
108 }
109
110 pub const fn with_credentials(
112 access_key: std::string::String,
113 secret_key: std::string::String,
114 ) -> Self {
115 Self {
116 access_key,
117 secret_key,
118 }
119 }
120}
121
122#[derive(Debug, Serialize, Deserialize)]
124struct UpbitOrderResponse {
125 uuid: std::string::String,
127 side: std::string::String,
129 ord_type: std::string::String,
131 price: Option<std::string::String>,
133 state: std::string::String,
135 market: std::string::String,
137 created_at: std::string::String,
139 volume: std::string::String,
141 remaining_volume: std::string::String,
143 reserved_fee: std::string::String,
145 remaining_fee: std::string::String,
147 paid_fee: std::string::String,
149 locked: std::string::String,
151 executed_volume: std::string::String,
153 trades_count: u64,
155}
156
157#[derive(Debug, Serialize, Deserialize)]
159struct UpbitErrorResponse {
160 error: UpbitError,
162}
163
164#[derive(Debug, Serialize, Deserialize)]
166struct UpbitError {
167 name: std::string::String,
169 message: std::string::String,
171}
172
173pub struct UpbitClient {
175 client: Client,
177 credentials: UpbitCredentials,
179 clock: Clock,
181}
182
183impl UpbitClient {
184 #[must_use]
186 pub fn new() -> Self {
187 Self {
188 client: Client::new(),
189 credentials: UpbitCredentials::new(),
190 clock: Clock::new(),
191 }
192 }
193
194 pub fn with_credentials(
196 access_key: std::string::String,
197 secret_key: std::string::String,
198 ) -> Self {
199 Self {
200 client: Client::new(),
201 credentials: UpbitCredentials::with_credentials(access_key, secret_key),
202 clock: Clock::new(),
203 }
204 }
205
206 async fn generate_jwt(&self, query_params: Option<&str>) -> std::string::String {
208 let nonce = format!(
211 "{}",
212 SystemTime::now()
213 .duration_since(UNIX_EPOCH)
214 .expect("Time went backwards")
215 .as_millis() );
217
218 let mut payload = json!({
220 "access_key": self.credentials.access_key,
221 "nonce": nonce,
222 });
223
224 if let Some(params) = query_params {
226 let mut hasher = Sha512::new();
227 hasher.update(params.as_bytes());
228 let query_hash = hex::encode(hasher.finalize());
229
230 payload["query_hash"] = json!(query_hash);
231 payload["query_hash_alg"] = json!("SHA512");
232 }
233
234 let payload_str = simd_json::to_string(&payload).expect("Failed to serialize payload");
236
237 let header = json!({
239 "alg": "HS256",
240 "typ": "JWT"
241 });
242 let header_str = simd_json::to_string(&header).expect("Failed to serialize header");
243
244 let header_b64 = BASE64.encode(header_str.as_bytes());
246 let payload_b64 = BASE64.encode(payload_str.as_bytes());
247
248 let signature_input = format!("{header_b64}.{payload_b64}");
250 let mut mac = HmacSha256::new_from_slice(self.credentials.secret_key.as_bytes())
251 .expect("HMAC can take key of any size");
252 mac.update(signature_input.as_bytes());
253 let signature = BASE64.encode(mac.finalize().into_bytes());
254
255 format!("{header_b64}.{payload_b64}.{signature}")
257 }
258
259 const fn map_order_type(order_type: OrderType) -> &'static str {
261 match order_type {
262 OrderType::Market => "price", OrderType::Limit => "limit",
264 _ => "limit", }
266 }
267
268 const fn map_order_side(side: OrderSide) -> &'static str {
270 match side {
271 OrderSide::Buy => "bid",
272 OrderSide::Sell => "ask",
273 }
274 }
275
276 pub async fn place_order(
277 &self,
278 order: Order,
279 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
280 let market = order.symbol.clone();
282
283 let side = Self::map_order_side(order.side);
285 let ord_type = Self::map_order_type(order.order_type);
286
287 let mut body = json!({
289 "market": market,
290 "side": side,
291 "ord_type": ord_type,
292 });
293
294 if ord_type == "limit" {
296 body["price"] = json!(
297 order
298 .price
299 .map_or_else(|| "0".to_string(), |p| p.to_string())
300 );
301 body["volume"] = json!(order.quantity.to_string());
302 } else if ord_type == "price" {
303 body["price"] = json!(
305 order
306 .price
307 .map_or_else(|| "0".to_string(), |p| p.to_string())
308 );
309 } else if ord_type == "market" {
310 body["volume"] = json!(order.quantity.to_string());
312 }
313
314 let token = self.generate_jwt(None).await;
316
317 let url = format!("{API_URL}/v1/orders");
319 let response = self
320 .client
321 .post(&url)
322 .header("Authorization", format!("Bearer {token}"))
323 .header("Content-Type", "application/json")
324 .body(simd_json::to_string(&body)?)
325 .send()
326 .await?;
327
328 if !response.status().is_success() {
329 let bytes = response.bytes().await?;
330 let mut bytes_vec = bytes.to_vec();
331 let error_response: UpbitErrorResponse = simd_json::from_slice(&mut bytes_vec)?;
332 let error_message = format!(
333 "{}: {}",
334 error_response.error.name, error_response.error.message
335 );
336 error!("Upbit order placement failed: {error_message}");
337 return Err(Box::new(std::io::Error::other(error_message)));
341 }
342
343 Ok(())
344 }
345
346 pub async fn cancel_order(
347 &self,
348 order_id: std::string::String,
349 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
350 let query_params = format!("uuid={order_id}");
352
353 let token = self.generate_jwt(Some(&query_params)).await;
355
356 let url = format!("{API_URL}/v1/order?{query_params}");
358 let response = self
359 .client
360 .delete(&url)
361 .header("Authorization", format!("Bearer {token}"))
362 .send()
363 .await?;
364
365 if !response.status().is_success() {
366 let bytes = response.bytes().await?;
367 let mut bytes_vec = bytes.to_vec();
368 let error_response: UpbitErrorResponse = simd_json::from_slice(&mut bytes_vec)?;
369 let error_message = format!(
370 "{}: {}",
371 error_response.error.name, error_response.error.message
372 );
373 error!("Upbit order cancellation failed: {error_message}");
374 return Err(Box::new(std::io::Error::other(error_message)));
378 }
379
380 Ok(())
381 }
382
383 pub async fn get_order_status(
384 &self,
385 order_id: std::string::String,
386 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
387 let query_params = format!("uuid={order_id}");
389
390 let token = self.generate_jwt(Some(&query_params)).await;
392
393 let url = format!("{API_URL}/v1/order?{query_params}");
395 let response = self
396 .client
397 .get(&url)
398 .header("Authorization", format!("Bearer {token}"))
399 .send()
400 .await?;
401
402 if response.status().is_success() {
403 let bytes = response.bytes().await?;
404 let mut bytes_vec = bytes.to_vec();
405 let order_response: UpbitOrderResponse = simd_json::from_slice(&mut bytes_vec)?;
406 Ok(order_response.state)
407 } else {
408 let bytes = response.bytes().await?;
409 let mut bytes_vec = bytes.to_vec();
410 let error_response: UpbitErrorResponse = simd_json::from_slice(&mut bytes_vec)?;
411 let error_message = format!(
412 "{}: {}",
413 error_response.error.name, error_response.error.message
414 );
415 error!("Failed to get order status: {error_message}");
416 Err::<String, Box<dyn std::error::Error + Send + Sync>>(Box::new(
420 std::io::Error::other(error_message),
421 ))
422 }
423 }
424
425 #[allow(dead_code)]
426 pub async fn get_orderbook(&self, market: &str) -> Result<String, reqwest::Error> {
427 let url = format!("{API_URL}/v1/orderbook?markets={market}");
428 let response = self.client.get(&url).send().await?.text().await?;
429 Ok(response)
430 }
431}