1#[cfg(test)]
7use super::TradeSide;
8use super::tardis_features::{
9 ExponentialDecayCalculator, PriceEntropyCalculator, RollingPriceImpactCalculator,
10 TradingBurstDetector, VolumePriceSensitivityCalculator,
11};
12use super::{OrderBookSnapshot, TradeTick};
13use rust_decimal::Decimal;
14use rust_decimal::prelude::ToPrimitive;
15use std::collections::VecDeque;
16
17#[allow(dead_code)] struct ValidatedTrade<'a> {
24 trade: &'a TradeTick,
25 price_f64: f64,
26 quantity_f64: f64,
27}
28
29pub struct HarmonicOscillator {
33 short_window: VecDeque<f64>,
34 long_window: VecDeque<f64>,
35 short_size: usize,
36 long_size: usize,
37}
38
39#[derive(Debug, Clone)]
41pub struct HarmonicResult {
42 pub frequency: f64,
44 pub amplitude: Decimal,
46 pub phase: f64,
48}
49
50impl HarmonicOscillator {
51 #[must_use]
53 pub fn new(short_window_size: usize, long_window_size: usize) -> Self {
54 Self {
55 short_window: VecDeque::with_capacity(short_window_size),
56 long_window: VecDeque::with_capacity(long_window_size),
57 short_size: short_window_size,
58 long_size: long_window_size,
59 }
60 }
61
62 pub fn update(&mut self, price: Decimal) -> bool {
66 let price_f64 = match price.to_f64() {
68 Some(p) => p,
69 None => {
70 #[cfg(debug_assertions)]
72 eprintln!(
73 "Warning: Failed to convert price {price} to f64, skipping oscillator update"
74 );
75 return false;
76 }
77 };
78
79 self.short_window.push_back(price_f64);
81 self.long_window.push_back(price_f64);
82
83 if self.short_window.len() > self.short_size {
85 self.short_window.pop_front();
86 }
87 if self.long_window.len() > self.long_size {
88 self.long_window.pop_front();
89 }
90
91 true
92 }
93
94 #[must_use]
96 pub fn detect_oscillation(&self) -> HarmonicResult {
97 if self.long_window.len() < self.long_size {
98 return HarmonicResult {
99 frequency: 0.0,
100 amplitude: Decimal::ZERO,
101 phase: 0.0,
102 };
103 }
104
105 let short_vol = self.calculate_volatility(&self.short_window);
107 let long_vol = self.calculate_volatility(&self.long_window);
108
109 let frequency = if long_vol > 0.0 {
111 short_vol / long_vol
112 } else {
113 0.0
114 };
115
116 let prices: Vec<f64> = self.long_window.iter().copied().collect();
118 let min_price = prices.iter().fold(f64::INFINITY, |a, &b| a.min(b));
119 let max_price = prices.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
120 let amplitude =
121 Decimal::from_f64_retain((max_price - min_price) / 2.0).unwrap_or(Decimal::ZERO);
122
123 let phase = self.estimate_phase(&prices);
125
126 HarmonicResult {
127 frequency,
128 amplitude,
129 phase,
130 }
131 }
132
133 fn calculate_volatility(&self, window: &VecDeque<f64>) -> f64 {
135 if window.len() < 2 {
136 return 0.0;
137 }
138
139 let mean = window.iter().sum::<f64>() / window.len() as f64;
140 let variance =
141 window.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / window.len() as f64;
142
143 variance.sqrt()
144 }
145
146 fn estimate_phase(&self, prices: &[f64]) -> f64 {
148 if prices.len() < 10 {
149 return 0.0;
150 }
151
152 let mut best_correlation: f64 = 0.0;
153 let mut best_phase: f64 = 0.0;
154
155 for phase_steps in 0..20 {
157 let phase = (phase_steps as f64) * std::f64::consts::PI / 10.0;
158 let correlation = self.calculate_sine_correlation(prices, phase);
159
160 if correlation.abs() > best_correlation.abs() {
161 best_correlation = correlation;
162 best_phase = phase;
163 }
164 }
165
166 best_phase
167 }
168
169 fn calculate_sine_correlation(&self, prices: &[f64], phase: f64) -> f64 {
171 if prices.is_empty() {
172 return 0.0;
173 }
174
175 let n = prices.len();
176 let frequency = 2.0 * std::f64::consts::PI / n as f64;
177
178 let mut sum_xy = 0.0;
179 let mut sum_x = 0.0;
180 let mut sum_y = 0.0;
181 let mut sum_x2 = 0.0;
182 let mut sum_y2 = 0.0;
183
184 for (i, &price) in prices.iter().enumerate() {
185 let sine_val = (frequency * i as f64 + phase).sin();
186
187 sum_xy += price * sine_val;
188 sum_x += price;
189 sum_y += sine_val;
190 sum_x2 += price * price;
191 sum_y2 += sine_val * sine_val;
192 }
193
194 let n_f64 = n as f64;
195 let numerator = n_f64 * sum_xy - sum_x * sum_y;
196 let denominator =
197 ((n_f64 * sum_x2 - sum_x * sum_x) * (n_f64 * sum_y2 - sum_y * sum_y)).sqrt();
198
199 if denominator != 0.0 {
200 numerator / denominator
201 } else {
202 0.0
203 }
204 }
205}
206
207pub struct OrderTypeTracker {
216 limit_orders: u64,
217 market_orders: u64,
218 unknown_orders: u64,
219}
220
221impl OrderTypeTracker {
222 #[must_use]
224 pub const fn new() -> Self {
225 Self {
226 limit_orders: 0,
227 market_orders: 0,
228 unknown_orders: 0,
229 }
230 }
231
232 pub const fn add_order(&mut self, order_type: OrderType) {
234 match order_type {
235 OrderType::Limit => self.limit_orders += 1,
236 OrderType::Market => self.market_orders += 1,
237 OrderType::Unknown => self.unknown_orders += 1,
238 }
239 }
240
241 #[must_use]
243 pub fn limit_to_market_ratio(&self) -> Option<Decimal> {
244 if self.market_orders == 0 {
245 return None;
246 }
247 Some(Decimal::from(self.limit_orders) / Decimal::from(self.market_orders))
248 }
249
250 #[must_use]
255 pub fn market_aggressiveness(&self) -> Decimal {
256 let total_known = self.limit_orders + self.market_orders;
257 if total_known == 0 {
258 return Decimal::ZERO;
259 }
260 Decimal::from(self.market_orders) / Decimal::from(total_known)
261 }
262
263 pub const fn reset(&mut self) {
265 self.limit_orders = 0;
266 self.market_orders = 0;
267 self.unknown_orders = 0;
268 }
269
270 #[must_use]
272 pub const fn get_counts(&self) -> (u64, u64, u64) {
273 (self.limit_orders, self.market_orders, self.unknown_orders)
274 }
275
276 #[must_use]
278 pub const fn get_known_counts(&self) -> (u64, u64) {
279 (self.limit_orders, self.market_orders)
280 }
281
282 #[must_use]
284 pub const fn unknown_count(&self) -> u64 {
285 self.unknown_orders
286 }
287
288 #[must_use]
290 pub const fn total_count(&self) -> u64 {
291 self.limit_orders + self.market_orders + self.unknown_orders
292 }
293}
294
295#[derive(Debug, Clone, Copy, PartialEq, Eq)]
301pub enum OrderType {
302 Limit,
304 Market,
306 Unknown,
308}
309
310impl Default for OrderTypeTracker {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Window sizes and switches used by `TardisAdvancedFeatures::new` to build
/// each feature calculator.
#[derive(Debug, Clone)]
pub struct TardisConfig {
    /// Span passed to `ExponentialDecayCalculator::new`.
    pub decay_span: usize,
    /// Window passed to `RollingPriceImpactCalculator::new`.
    pub price_impact_window: usize,
    /// Window passed to `VolumePriceSensitivityCalculator::new`.
    pub volume_sensitivity_window: usize,
    /// Window passed to `PriceEntropyCalculator::new`.
    pub entropy_window: usize,
    /// Window passed to `TradingBurstDetector::new`.
    pub burst_detector_window: usize,
    /// Short-window capacity of the harmonic oscillator.
    pub harmonic_short_window: usize,
    /// Long-window capacity of the harmonic oscillator.
    pub harmonic_long_window: usize,
    /// Read by `TardisAdvancedFeatures::infer_order_type`. No inference is
    /// implemented yet, so every trade is recorded as unknown either way.
    pub enable_order_type_inference: bool,
}

impl Default for TardisConfig {
    /// Returns the defaults:
    /// - decay span 20;
    /// - impact, entropy and burst windows of 100;
    /// - sensitivity window 50;
    /// - harmonic windows 50/200;
    /// - order-type inference off.
    fn default() -> Self {
        let (harmonic_short_window, harmonic_long_window) = (50, 200);
        Self {
            decay_span: 20,
            price_impact_window: 100,
            entropy_window: 100,
            burst_detector_window: 100,
            volume_sensitivity_window: 50,
            harmonic_short_window,
            harmonic_long_window,
            enable_order_type_inference: false,
        }
    }
}
355
356pub struct TardisAdvancedFeatures {
360 decay_calculator: ExponentialDecayCalculator,
361 impact_tracker: RollingPriceImpactCalculator,
362 oscillator: HarmonicOscillator,
363 burst_detector: TradingBurstDetector,
364 sensitivity: VolumePriceSensitivityCalculator,
365 order_tracker: OrderTypeTracker,
366 entropy: PriceEntropyCalculator,
367 last_price_impact: f64,
369 last_burst_intensity: f64,
370 last_volume_price_sensitivity: f64,
371 last_price_entropy: f64,
372 config: TardisConfig,
374}
375
/// Snapshot of every Tardis feature, as returned by `TardisAdvancedFeatures::get_features`.
#[derive(Debug, Clone)]
pub struct TardisFeatureVector {
    /// Mean reported by the exponential decay calculator.
    pub exponential_mean: f64,
    /// Standard deviation reported by the exponential decay calculator.
    pub exponential_std: f64,
    /// Variance reported by the exponential decay calculator.
    pub exponential_var: f64,
    /// Most recent rolling price impact. May be NaN: a NaN result from the
    /// impact calculator is logged and then stored as-is.
    pub price_impact: f64,
    /// Harmonic oscillator's short/long volatility ratio.
    pub harmonic_frequency: f64,
    /// Half of the peak-to-peak range of the oscillator's long window.
    pub harmonic_amplitude: Decimal,
    /// Estimated phase, in radians, of the oscillator's long window.
    pub harmonic_phase: f64,
    /// Most recent trading-burst intensity.
    pub burst_intensity: f64,
    /// Most recent volume/price sensitivity.
    pub volume_price_sensitivity: f64,
    /// Limit-to-market order ratio; `None` when no market orders have been recorded.
    pub limit_market_ratio: Option<Decimal>,
    /// Market orders as a fraction of known-type orders (zero when none are known).
    pub market_aggressiveness: Decimal,
    /// Most recent price entropy.
    pub price_entropy: f64,
    /// Number of trades whose order type was recorded as unknown.
    pub unknown_order_count: u64,
    /// Number of trades recorded by the order-type tracker, including unknown ones.
    pub total_order_count: u64,
}
408
409impl TardisAdvancedFeatures {
410 #[must_use]
412 pub fn new(config: TardisConfig) -> Self {
413 Self {
414 decay_calculator: ExponentialDecayCalculator::new(config.decay_span),
415 impact_tracker: RollingPriceImpactCalculator::new(config.price_impact_window),
416 oscillator: HarmonicOscillator::new(
417 config.harmonic_short_window,
418 config.harmonic_long_window,
419 ),
420 burst_detector: TradingBurstDetector::new(config.burst_detector_window),
421 sensitivity: VolumePriceSensitivityCalculator::new(config.volume_sensitivity_window),
422 order_tracker: OrderTypeTracker::new(),
423 entropy: PriceEntropyCalculator::new(config.entropy_window),
424 last_price_impact: 0.0,
426 last_burst_intensity: 0.0,
427 last_volume_price_sensitivity: 0.0,
428 last_price_entropy: 0.0,
429 config,
431 }
432 }
433
434 pub fn update_with_trade(&mut self, trade: &TradeTick) -> Result<(), &'static str> {
436 let price_f64 = match trade.price.to_f64() {
438 Some(p) => p,
439 None => {
440 #[cfg(debug_assertions)]
442 eprintln!(
443 "Warning: Failed to convert trade price {} to f64, skipping all feature updates",
444 trade.price
445 );
446 return Err("Failed to convert trade price to f64");
447 }
448 };
449
450 let quantity_f64 = match trade.quantity.to_f64() {
451 Some(q) => q,
452 None => {
453 #[cfg(debug_assertions)]
455 eprintln!(
456 "Warning: Failed to convert trade quantity {} to f64, skipping all feature updates",
457 trade.quantity
458 );
459 return Err("Failed to convert trade quantity to f64");
460 }
461 };
462
463 let validated_trade = ValidatedTrade {
465 trade,
466 price_f64,
467 quantity_f64,
468 };
469
470 self.decay_calculator.update(price_f64);
475 self.update_harmonic_oscillator_safe(trade.price)?;
476 self.last_price_entropy = self.entropy.update(trade.price);
477
478 self.last_price_impact = self.update_price_impact_safe(&validated_trade)?;
480 self.last_burst_intensity = self.update_burst_detector_safe(&validated_trade)?;
481 self.last_volume_price_sensitivity =
482 self.update_volume_sensitivity_safe(&validated_trade)?;
483
484 let order_type = self.infer_order_type(trade);
486 self.order_tracker.add_order(order_type);
487
488 Ok(())
489 }
490
491 pub fn update_with_snapshot(
493 &mut self,
494 snapshot: &OrderBookSnapshot,
495 ) -> Result<(), &'static str> {
496 let mid_price = snapshot.mid_price();
497 let mid_price_f64 = match mid_price.to_f64() {
499 Some(p) => p,
500 None => {
501 #[cfg(debug_assertions)]
503 eprintln!(
504 "Warning: Failed to convert mid price {mid_price} to f64, skipping all feature updates"
505 );
506 return Err("Failed to convert mid price to f64");
507 }
508 };
509
510 self.decay_calculator.update(mid_price_f64);
512 self.update_harmonic_oscillator_safe(mid_price)?;
513 self.last_price_entropy = self.entropy.update(mid_price);
514
515 Ok(())
516 }
517
518 fn update_harmonic_oscillator_safe(&mut self, price: Decimal) -> Result<(), &'static str> {
520 if self.oscillator.update(price) {
522 Ok(())
523 } else {
524 #[cfg(debug_assertions)]
525 log::warn!("Failed to update harmonic oscillator with price {price}");
526 Err("Failed to update harmonic oscillator")
527 }
528 }
529
530 fn update_price_impact_safe(
532 &mut self,
533 validated_trade: &ValidatedTrade,
534 ) -> Result<f64, &'static str> {
535 if validated_trade.price_f64.is_nan() {
537 log::warn!("Price impact calculation skipped: price is NaN");
538 return Err("Price is NaN in validated trade");
539 }
540
541 if validated_trade.quantity_f64.is_nan() {
542 log::warn!("Price impact calculation skipped: quantity is NaN");
543 return Err("Quantity is NaN in validated trade");
544 }
545
546 let impact = self.impact_tracker.add_trade(validated_trade.trade);
549
550 if impact.is_nan() {
553 #[cfg(debug_assertions)]
555 log::debug!(
556 "Price impact calculation returned NaN (may be expected for initial state)"
557 );
558 }
559
560 Ok(impact) }
562
563 fn update_burst_detector_safe(
565 &mut self,
566 validated_trade: &ValidatedTrade,
567 ) -> Result<f64, &'static str> {
568 let burst = self.burst_detector.add_trade(validated_trade.trade);
570
571 Ok(burst) }
574
575 fn update_volume_sensitivity_safe(
577 &mut self,
578 validated_trade: &ValidatedTrade,
579 ) -> Result<f64, &'static str> {
580 let sensitivity = self.sensitivity.add_trade(validated_trade.trade);
582
583 Ok(sensitivity) }
586
587 #[must_use]
589 pub fn get_features(&self) -> TardisFeatureVector {
590 let harmonic_result = self.oscillator.detect_oscillation();
591
592 TardisFeatureVector {
593 exponential_mean: self.decay_calculator.get_mean(),
594 exponential_std: self.decay_calculator.get_std(),
595 exponential_var: self.decay_calculator.get_var(),
596 price_impact: self.last_price_impact,
597 harmonic_frequency: harmonic_result.frequency,
598 harmonic_amplitude: harmonic_result.amplitude,
599 harmonic_phase: harmonic_result.phase,
600 burst_intensity: self.last_burst_intensity,
601 volume_price_sensitivity: self.last_volume_price_sensitivity,
602 limit_market_ratio: self.order_tracker.limit_to_market_ratio(),
603 market_aggressiveness: self.order_tracker.market_aggressiveness(),
604 price_entropy: self.last_price_entropy,
605 unknown_order_count: self.order_tracker.unknown_count(),
606 total_order_count: self.order_tracker.total_count(),
607 }
608 }
609
610 pub const fn reset(&mut self) {
612 self.order_tracker.reset();
613 }
615
616 #[must_use]
622 const fn infer_order_type(&self, _trade: &TradeTick) -> OrderType {
623 if self.config.enable_order_type_inference {
630 OrderType::Unknown
633 } else {
634 OrderType::Unknown
635 }
636 }
637
638 #[must_use]
640 pub fn get_config(&self) -> TardisConfig {
641 self.config.clone()
642 }
643}
644
645impl Default for TardisAdvancedFeatures {
646 fn default() -> Self {
647 Self::new(TardisConfig::default())
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[test]
    fn test_harmonic_oscillations() {
        let mut oscillator = HarmonicOscillator::new(50, 200);

        for i in 0..300 {
            let price = 100.0 + 5.0 * (i as f64 * 0.1).sin();
            let success = oscillator.update(Decimal::from_f64_retain(price).unwrap());
            assert!(
                success,
                "Harmonic oscillator update should succeed for valid price"
            );
        }

        let harmonics = oscillator.detect_oscillation();
        assert!(harmonics.frequency > 0.0);
        assert!(harmonics.amplitude > dec!(4.0));
    }

    /// A flat price series has no volatility, no range and no correlated phase.
    #[test]
    fn test_harmonic_oscillator_constant_prices() {
        let mut oscillator = HarmonicOscillator::new(10, 20);
        for _ in 0..20 {
            assert!(oscillator.update(dec!(100)));
        }

        let result = oscillator.detect_oscillation();
        assert_eq!(result.frequency, 0.0);
        assert_eq!(result.amplitude, Decimal::ZERO);
        assert_eq!(result.phase, 0.0);
    }

    #[test]
    fn test_limit_market_order_ratio() {
        let mut tracker = OrderTypeTracker::new();

        for _ in 0..70 {
            tracker.add_order(OrderType::Limit);
        }
        for _ in 0..30 {
            tracker.add_order(OrderType::Market);
        }

        let ratio = tracker.limit_to_market_ratio();
        assert!(ratio.is_some());
        assert!((ratio.unwrap() - dec!(2.33)).abs() < dec!(0.01));

        let aggressiveness = tracker.market_aggressiveness();
        assert!((aggressiveness - dec!(0.30)).abs() < dec!(0.01));

        assert_eq!(tracker.get_counts(), (70, 30, 0));
        assert_eq!(tracker.total_count(), 100);
        assert_eq!(tracker.unknown_count(), 0);
    }

    #[test]
    fn test_tardis_advanced_features_integration() {
        let config = TardisConfig::default();
        let mut features = TardisAdvancedFeatures::new(config);

        let trade = TradeTick {
            timestamp_ns: 1000000000,
            symbol: "BTCUSDT".to_string(),
            side: TradeSide::Buy,
            price: dec!(50000.0),
            quantity: dec!(1.5),
        };

        let update_result = features.update_with_trade(&trade);
        assert!(
            update_result.is_ok(),
            "Valid trade should update successfully"
        );

        let feature_vector = features.get_features();

        assert!(!feature_vector.exponential_mean.is_nan());
        assert!(feature_vector.exponential_std >= 0.0);
        assert!(feature_vector.exponential_var >= 0.0);
        assert!(feature_vector.harmonic_frequency >= 0.0);
        assert!(feature_vector.harmonic_amplitude >= Decimal::ZERO);

        // Order-type inference always yields Unknown, so the single trade is
        // counted as unknown and no known-type ratios are available.
        assert_eq!(feature_vector.limit_market_ratio, None);
        assert_eq!(feature_vector.market_aggressiveness, Decimal::ZERO);
        assert_eq!(feature_vector.unknown_order_count, 1);
        assert_eq!(feature_vector.total_order_count, 1);
    }

    #[test]
    fn test_order_type_inference_returns_unknown() {
        let config = TardisConfig::default();
        let features = TardisAdvancedFeatures::new(config);

        let small_trade = TradeTick {
            timestamp_ns: 1000000000,
            symbol: "BTCUSDT".to_string(),
            side: TradeSide::Buy,
            price: dec!(50000.0),
            quantity: dec!(0.1),
        };

        let large_trade = TradeTick {
            timestamp_ns: 1000000000,
            symbol: "BTCUSDT".to_string(),
            side: TradeSide::Buy,
            price: dec!(50000.0),
            quantity: dec!(100.0),
        };

        assert_eq!(features.infer_order_type(&small_trade), OrderType::Unknown);
        assert_eq!(features.infer_order_type(&large_trade), OrderType::Unknown);
    }

    #[test]
    fn test_order_type_tracker_reset() {
        let mut tracker = OrderTypeTracker::new();

        tracker.add_order(OrderType::Limit);
        tracker.add_order(OrderType::Market);

        assert_eq!(tracker.get_counts(), (1, 1, 0));

        tracker.reset();
        assert_eq!(tracker.get_counts(), (0, 0, 0));
    }

    #[test]
    fn test_harmonic_oscillator_empty_windows() {
        let oscillator = HarmonicOscillator::new(10, 20);
        let result = oscillator.detect_oscillation();

        assert_eq!(result.frequency, 0.0);
        assert_eq!(result.amplitude, Decimal::ZERO);
        assert_eq!(result.phase, 0.0);
    }

    #[test]
    fn test_order_tracker_edge_cases() {
        let tracker = OrderTypeTracker::new();

        assert_eq!(tracker.limit_to_market_ratio(), None);
        assert_eq!(tracker.market_aggressiveness(), Decimal::ZERO);
        assert_eq!(OrderTypeTracker::default().get_counts(), (0, 0, 0));
    }

    #[test]
    fn test_unknown_order_type_handling() {
        let mut tracker = OrderTypeTracker::new();

        tracker.add_order(OrderType::Limit);
        tracker.add_order(OrderType::Market);
        tracker.add_order(OrderType::Unknown);
        tracker.add_order(OrderType::Unknown);

        assert_eq!(tracker.get_counts(), (1, 1, 2));
        assert_eq!(tracker.get_known_counts(), (1, 1));
        assert_eq!(tracker.unknown_count(), 2);
        assert_eq!(tracker.total_count(), 4);

        // Unknown orders are excluded from aggressiveness: 1 market / 2 known.
        let aggressiveness = tracker.market_aggressiveness();
        assert_eq!(aggressiveness, dec!(0.5));

        tracker.reset();
        assert_eq!(tracker.get_counts(), (0, 0, 0));
    }

    #[test]
    fn test_tardis_config_default() {
        let config = TardisConfig::default();

        assert_eq!(config.decay_span, 20);
        assert_eq!(config.price_impact_window, 100);
        assert_eq!(config.harmonic_short_window, 50);
        assert_eq!(config.harmonic_long_window, 200);
        assert!(!config.enable_order_type_inference);
    }

    #[test]
    fn test_atomic_update_behavior() {
        let config = TardisConfig::default();
        let mut features = TardisAdvancedFeatures::new(config);

        let valid_trade = TradeTick {
            timestamp_ns: 1000000000,
            symbol: "BTCUSDT".to_string(),
            side: TradeSide::Buy,
            price: dec!(50000.0),
            quantity: dec!(1.5),
        };

        let result = features.update_with_trade(&valid_trade);
        assert!(result.is_ok(), "Valid trade should update successfully");

        let valid_trade2 = TradeTick {
            timestamp_ns: 2000000000,
            symbol: "BTCUSDT".to_string(),
            side: TradeSide::Sell,
            price: dec!(49500.0),
            quantity: dec!(2.0),
        };

        let result = features.update_with_trade(&valid_trade2);
        assert!(
            result.is_ok(),
            "Second valid trade should also update successfully"
        );

        let feature_vector = features.get_features();
        assert!(!feature_vector.exponential_mean.is_nan());
        assert!(feature_vector.exponential_std >= 0.0);
        assert!(feature_vector.exponential_var >= 0.0);
        // Each successful trade update records exactly one order.
        assert_eq!(feature_vector.total_order_count, 2);
    }

    /// Despite its name, this checks that extreme `Decimal` values are accepted:
    /// `to_f64` succeeds at both ends of the `Decimal` range, so the oscillator's
    /// failure branch is not reachable through these inputs.
    #[test]
    fn test_harmonic_oscillator_conversion_failure() {
        let mut oscillator = HarmonicOscillator::new(10, 20);

        let valid_price = dec!(100.0);
        let success = oscillator.update(valid_price);
        assert!(success, "Valid price should succeed");

        let large_decimal = Decimal::MAX;
        let success = oscillator.update(large_decimal);
        assert!(success, "Large decimal should still succeed");

        let min_decimal = Decimal::MIN;
        let success = oscillator.update(min_decimal);
        assert!(success, "Minimum decimal should still succeed");

        // Three samples is below the long window size, so the result stays zeroed.
        assert_eq!(oscillator.detect_oscillation().frequency, 0.0);
    }
}