1use anyhow::Result;
10use async_trait::async_trait;
11use flume::Sender;
12use parking_lot::RwLock;
13use rust_decimal::Decimal;
14use rusty_common::{SmartString, collections::FxHashMap};
15use rusty_model::{
16 position::{FuturesPosition, PositionUpdate},
17 types::PositionId,
18 venues::Venue,
19};
20use std::sync::Arc;
21
22#[derive(Debug, Clone)]
24pub enum PositionEvent {
25 PositionOpened {
27 position_id: PositionId,
29 venue: Venue,
31 symbol: SmartString,
33 },
34 PositionUpdated {
36 position_id: PositionId,
38 venue: Venue,
40 symbol: SmartString,
42 old_amount: Decimal,
44 new_amount: Decimal,
46 },
47 PositionClosed {
49 position_id: PositionId,
51 venue: Venue,
53 symbol: SmartString,
55 realized_pnl: Decimal,
57 },
58}
59
60#[async_trait]
62pub trait PositionManager: Send + Sync {
63 async fn update_position(&self, update: PositionUpdate) -> Result<()>;
65
66 async fn get_position(&self, position_id: &PositionId) -> Option<FuturesPosition>;
68
69 async fn get_positions_by_symbol(&self, venue: Venue, symbol: &str) -> Vec<FuturesPosition>;
71
72 async fn get_all_positions(&self) -> Vec<FuturesPosition>;
74
75 async fn get_total_unrealized_pnl(&self) -> Decimal;
77
78 async fn get_total_realized_pnl(&self) -> Decimal;
80
81 async fn clear_all_positions(&self);
83}
84
85type SymbolIndex = Arc<RwLock<FxHashMap<(Venue, SmartString), Vec<PositionId>>>>;
87
88pub struct DefaultPositionManager {
90 positions: Arc<RwLock<FxHashMap<PositionId, FuturesPosition>>>,
92
93 symbol_index: SymbolIndex,
95
96 event_sender: Option<Sender<PositionEvent>>,
98}
99
100impl DefaultPositionManager {
101 #[must_use]
103 pub fn new(event_sender: Option<Sender<PositionEvent>>) -> Self {
104 Self {
105 positions: Arc::new(RwLock::new(FxHashMap::default())),
106 symbol_index: Arc::new(RwLock::new(FxHashMap::default())),
107 event_sender,
108 }
109 }
110
111 fn send_event(&self, event: PositionEvent) {
113 if let Some(sender) = &self.event_sender
114 && let Err(e) = sender.send(event)
115 {
116 log::warn!("Failed to send position event: {e}");
117 }
118 }
119
120 fn update_symbol_index(
122 &self,
123 venue: Venue,
124 symbol: &str,
125 position_id: PositionId,
126 is_closed: bool,
127 ) {
128 let mut index = self.symbol_index.write();
129 let key = (venue, SmartString::from(symbol));
130
131 if is_closed {
132 if let Some(positions) = index.get_mut(&key) {
134 positions.retain(|&id| id != position_id);
135 if positions.is_empty() {
136 index.remove(&key);
137 }
138 }
139 } else {
140 let positions = index.entry(key).or_default();
142 if !positions.contains(&position_id) {
143 positions.push(position_id);
144 }
145 }
146 }
147}
148
149#[async_trait]
150impl PositionManager for DefaultPositionManager {
151 async fn update_position(&self, update: PositionUpdate) -> Result<()> {
152 let mut positions = self.positions.write();
153
154 if let Some(existing_position) = positions.get_mut(&update.position_id) {
156 let old_amount = existing_position.amount;
158
159 existing_position.update(
160 update.amount,
161 update.entry_price,
162 update.breakeven_price,
163 update.unrealized_pnl,
164 update.realized_pnl,
165 update.isolated_wallet,
166 );
167
168 if existing_position.is_closed() {
170 self.update_symbol_index(update.venue, &update.symbol, update.position_id, true);
171 self.send_event(PositionEvent::PositionClosed {
172 position_id: update.position_id,
173 venue: update.venue,
174 symbol: update.symbol,
175 realized_pnl: update.realized_pnl,
176 });
177 } else {
178 self.send_event(PositionEvent::PositionUpdated {
179 position_id: update.position_id,
180 venue: update.venue,
181 symbol: update.symbol.clone(),
182 old_amount,
183 new_amount: update.amount,
184 });
185 }
186 } else if !update.amount.is_zero() {
187 let position = FuturesPosition::new(
189 update.venue,
190 &update.symbol,
191 update.side,
192 update.amount,
193 update.entry_price,
194 update.margin_type,
195 );
196
197 let mut position = position;
199 position.id = update.position_id; position.update(
201 update.amount,
202 update.entry_price,
203 update.breakeven_price,
204 update.unrealized_pnl,
205 update.realized_pnl,
206 update.isolated_wallet,
207 );
208
209 positions.insert(update.position_id, position);
210 self.update_symbol_index(update.venue, &update.symbol, update.position_id, false);
211
212 self.send_event(PositionEvent::PositionOpened {
213 position_id: update.position_id,
214 venue: update.venue,
215 symbol: update.symbol,
216 });
217 }
218
219 Ok(())
220 }
221
222 async fn get_position(&self, position_id: &PositionId) -> Option<FuturesPosition> {
223 self.positions.read().get(position_id).cloned()
224 }
225
226 async fn get_positions_by_symbol(&self, venue: Venue, symbol: &str) -> Vec<FuturesPosition> {
227 let positions = self.positions.read();
228 let index = self.symbol_index.read();
229
230 let key = (venue, SmartString::from(symbol));
231 if let Some(position_ids) = index.get(&key) {
232 position_ids
233 .iter()
234 .filter_map(|id| positions.get(id).cloned())
235 .collect()
236 } else {
237 Vec::new()
238 }
239 }
240
241 async fn get_all_positions(&self) -> Vec<FuturesPosition> {
242 self.positions
243 .read()
244 .values()
245 .filter(|p| !p.is_closed())
246 .cloned()
247 .collect()
248 }
249
250 async fn get_total_unrealized_pnl(&self) -> Decimal {
251 self.positions
252 .read()
253 .values()
254 .filter(|p| !p.is_closed())
255 .map(|p| p.unrealized_pnl)
256 .sum()
257 }
258
259 async fn get_total_realized_pnl(&self) -> Decimal {
260 self.positions.read().values().map(|p| p.realized_pnl).sum()
261 }
262
263 async fn clear_all_positions(&self) {
264 self.positions.write().clear();
265 self.symbol_index.write().clear();
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use rusty_model::position::{MarginType, PositionSide};

    /// Baseline update shared by the tests: a cross-margin long of size 1 on
    /// Binance at price 50000 with zero PnL. Individual tests override the
    /// fields they care about via struct-update syntax.
    fn base_update(position_id: PositionId, symbol: &str) -> PositionUpdate {
        PositionUpdate {
            position_id,
            venue: Venue::Binance,
            symbol: symbol.into(),
            side: PositionSide::Long,
            amount: Decimal::from(1),
            entry_price: Decimal::from(50000),
            breakeven_price: Decimal::from(50000),
            unrealized_pnl: Decimal::ZERO,
            realized_pnl: Decimal::ZERO,
            margin_type: MarginType::Cross,
            isolated_wallet: Decimal::ZERO,
            timestamp_ns: rusty_common::time::get_epoch_timestamp_ns(),
        }
    }

    #[tokio::test]
    async fn test_position_creation() -> Result<()> {
        let manager = DefaultPositionManager::new(None);
        let position_id = PositionId::new();

        manager
            .update_position(base_update(position_id, "BTCUSDT"))
            .await?;

        let position = manager.get_position(&position_id).await;
        assert!(position.is_some());

        let position = position.unwrap();
        assert_eq!(position.symbol.as_str(), "BTCUSDT");
        assert_eq!(position.amount, Decimal::from(1));

        Ok(())
    }

    #[tokio::test]
    async fn test_position_update() -> Result<()> {
        let manager = DefaultPositionManager::new(None);
        let position_id = PositionId::new();

        // Open the position.
        manager
            .update_position(base_update(position_id, "BTCUSDT"))
            .await?;

        // Grow it and move the entry price.
        let grown = PositionUpdate {
            amount: Decimal::from(2),
            entry_price: Decimal::from(51000),
            breakeven_price: Decimal::from(51000),
            unrealized_pnl: Decimal::from(1000),
            ..base_update(position_id, "BTCUSDT")
        };
        manager.update_position(grown).await?;

        let position = manager.get_position(&position_id).await;
        assert!(position.is_some());

        let position = position.unwrap();
        assert_eq!(position.amount, Decimal::from(2));
        assert_eq!(position.entry_price, Decimal::from(51000));
        assert_eq!(position.unrealized_pnl, Decimal::from(1000));

        Ok(())
    }

    #[tokio::test]
    async fn test_position_closure() -> Result<()> {
        let (event_tx, event_rx) = flume::unbounded();
        let manager = DefaultPositionManager::new(Some(event_tx));
        let position_id = PositionId::new();

        // Opening must announce itself first.
        manager
            .update_position(base_update(position_id, "BTCUSDT"))
            .await?;

        let opened = event_rx.recv_async().await?;
        assert!(matches!(opened, PositionEvent::PositionOpened { .. }));

        // A zero amount closes the position and reports the realized PnL.
        let closing = PositionUpdate {
            amount: Decimal::ZERO,
            realized_pnl: Decimal::from(1000),
            ..base_update(position_id, "BTCUSDT")
        };
        manager.update_position(closing).await?;

        let closed = event_rx.recv_async().await?;
        let PositionEvent::PositionClosed { realized_pnl, .. } = closed else {
            panic!("Expected PositionClosed event")
        };
        assert_eq!(realized_pnl, Decimal::from(1000));

        // Closed positions are no longer reported as active.
        let all_positions = manager.get_all_positions().await;
        assert!(all_positions.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_get_positions_by_symbol() -> Result<()> {
        let manager = DefaultPositionManager::new(None);

        // Three distinct BTC positions with amounts 1, 2 and 3.
        for i in 0..3 {
            let btc = PositionUpdate {
                amount: Decimal::from(i + 1),
                ..base_update(PositionId::new(), "BTCUSDT")
            };
            manager.update_position(btc).await?;
        }

        // One ETH short.
        let eth = PositionUpdate {
            side: PositionSide::Short,
            amount: Decimal::from(5),
            entry_price: Decimal::from(3000),
            breakeven_price: Decimal::from(3000),
            ..base_update(PositionId::new(), "ETHUSDT")
        };
        manager.update_position(eth).await?;

        let btc_positions = manager
            .get_positions_by_symbol(Venue::Binance, "BTCUSDT")
            .await;
        assert_eq!(btc_positions.len(), 3);

        let eth_positions = manager
            .get_positions_by_symbol(Venue::Binance, "ETHUSDT")
            .await;
        assert_eq!(eth_positions.len(), 1);

        let empty_positions = manager
            .get_positions_by_symbol(Venue::Binance, "XRPUSDT")
            .await;
        assert!(empty_positions.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_pnl_calculations() -> Result<()> {
        let manager = DefaultPositionManager::new(None);

        // (unrealized, realized) PnL per position.
        let pnl_pairs = [
            (Decimal::from(1000), Decimal::from(500)),
            (Decimal::from(-500), Decimal::from(200)),
            (Decimal::from(2000), Decimal::from(-100)),
        ];

        for (i, (unrealized, realized)) in pnl_pairs.into_iter().enumerate() {
            let update = PositionUpdate {
                unrealized_pnl: unrealized,
                realized_pnl: realized,
                ..base_update(PositionId::new(), &format!("TEST{i}"))
            };
            manager.update_position(update).await?;
        }

        // 1000 - 500 + 2000
        let total_unrealized = manager.get_total_unrealized_pnl().await;
        assert_eq!(total_unrealized, Decimal::from(2500));

        // 500 + 200 - 100
        let total_realized = manager.get_total_realized_pnl().await;
        assert_eq!(total_realized, Decimal::from(600));

        Ok(())
    }
}