1use rust_decimal::Decimal;
2use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
3use wide::f64x4;
4
5use rusty_common::{
7 extract_max_f64x4, extract_min_f64x4, extract_sum_f64x4, simd_dual_op, simd_reduce,
8};
9
/// Namespace for slice-level numeric kernels over `f64` and `Decimal`.
///
/// `SimdOps` carries no state; every operation is an associated function
/// (`SimdOps::sum_f64(&values)`, ...). The derives exist so the marker type
/// can be debug-printed, copied and defaulted like any other zero-sized value.
#[derive(Debug, Clone, Copy, Default)]
pub struct SimdOps;
24
25impl SimdOps {
26 #[inline]
30 #[must_use]
31 pub fn sum_f64(values: &[f64]) -> f64 {
32 if values.is_empty() {
33 return 0.0;
34 }
35
36 if values.len() == 1 {
37 return values[0];
38 }
39
40 if values.len() < 4 {
42 return values.iter().sum();
43 }
44
45 Self::sum_f64_wide(values)
47 }
48
49 #[inline]
51 #[must_use]
52 pub fn decimal_to_f64(values: &[Decimal]) -> Vec<f64> {
53 values
54 .iter()
55 .map(|v| rusty_common::decimal_utils::decimal_to_f64_or_nan(*v))
56 .collect()
57 }
58
59 #[inline]
61 #[must_use]
62 pub fn sum_decimal(values: &[Decimal]) -> Decimal {
63 if values.is_empty() {
64 return Decimal::ZERO;
65 }
66
67 if values.len() == 1 {
68 return values[0];
69 }
70
71 if values.len() < 4 {
73 return values.iter().sum();
74 }
75
76 if values.len() <= 64 {
78 let mut f64_values = [0.0f64; 64];
79 for (i, &value) in values.iter().enumerate() {
80 if i >= 64 {
81 break;
82 }
83 let converted = rusty_common::decimal_utils::decimal_to_f64_or_nan(value);
84 f64_values[i] = if converted.is_nan() {
85 0.0 } else {
87 converted
88 };
89 }
90
91 let sum_f64 = Self::sum_f64(&f64_values[..values.len()]);
93
94 return Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO);
96 }
97
98 let f64_values = Self::decimal_to_f64(values);
101
102 let sum_f64 = Self::sum_f64(&f64_values);
104
105 Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO)
108 }
109
110 #[inline]
113 pub fn sum_price_level_sizes<T>(levels: &[T], extract_size: impl Fn(&T) -> Decimal) -> Decimal {
114 if levels.is_empty() {
115 return Decimal::ZERO;
116 }
117
118 if levels.len() == 1 {
119 return extract_size(&levels[0]);
120 }
121
122 if levels.len() < 4 {
124 return levels.iter().map(extract_size).sum();
125 }
126
127 if levels.len() <= 64 {
129 let mut f64_values = [0.0f64; 64];
130 for (i, level) in levels.iter().enumerate() {
131 if i >= 64 {
132 break;
133 }
134 f64_values[i] =
135 rusty_common::decimal_utils::decimal_to_f64_or_nan(extract_size(level));
136 }
137
138 let sum_f64 = Self::sum_f64(&f64_values[..levels.len()]);
140
141 return Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO);
143 }
144
145 let f64_values: Vec<f64> = levels
148 .iter()
149 .map(|level| rusty_common::decimal_utils::decimal_to_f64_or_nan(extract_size(level)))
150 .collect();
151
152 let sum_f64 = Self::sum_f64(&f64_values);
154
155 Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO)
157 }
158
159 #[inline]
162 #[must_use]
163 pub fn min_f64(values: &[f64]) -> f64 {
164 if values.is_empty() {
165 return f64::NAN;
166 }
167
168 if values.len() == 1 {
169 return values[0];
170 }
171
172 if values.len() >= 4 {
173 return Self::min_f64_wide(values);
174 }
175
176 let mut min_val = values[0];
178 for &val in &values[1..] {
179 min_val = min_val.min(val);
180 }
181 min_val
182 }
183
184 #[inline]
187 #[must_use]
188 pub fn max_f64(values: &[f64]) -> f64 {
189 if values.is_empty() {
190 return f64::NAN;
191 }
192
193 if values.len() == 1 {
194 return values[0];
195 }
196
197 if values.len() >= 4 {
198 return Self::max_f64_wide(values);
199 }
200
201 let mut max_val = values[0];
203 for &val in &values[1..] {
204 max_val = max_val.max(val);
205 }
206 max_val
207 }
208
209 #[inline]
211 #[must_use]
212 pub fn min_decimal(values: &[Decimal]) -> Decimal {
213 if values.is_empty() {
214 return Decimal::ZERO; }
216
217 if values.len() == 1 {
218 return values[0];
219 }
220
221 if values.len() < 4 {
223 let mut min_val = values[0];
224 for &val in &values[1..] {
225 if val < min_val {
226 min_val = val;
227 }
228 }
229 return min_val;
230 }
231
232 let f64_values = Self::decimal_to_f64(values);
234
235 let min_f64 = Self::min_f64(&f64_values);
237
238 for &val in values {
241 if let Some(f64_val) = val.to_f64()
242 && (f64_val - min_f64).abs() < f64::EPSILON
243 {
244 return val;
245 }
246 }
247
248 Decimal::from_f64(min_f64).unwrap_or(Decimal::ZERO)
250 }
251
252 #[inline]
254 #[must_use]
255 pub fn max_decimal(values: &[Decimal]) -> Decimal {
256 if values.is_empty() {
257 return Decimal::ZERO; }
259
260 if values.len() == 1 {
261 return values[0];
262 }
263
264 if values.len() < 4 {
266 let mut max_val = values[0];
267 for &val in &values[1..] {
268 if val > max_val {
269 max_val = val;
270 }
271 }
272 return max_val;
273 }
274
275 let f64_values = Self::decimal_to_f64(values);
277
278 let max_f64 = Self::max_f64(&f64_values);
280
281 for &val in values {
284 if let Some(f64_val) = val.to_f64()
285 && (f64_val - max_f64).abs() < f64::EPSILON
286 {
287 return val;
288 }
289 }
290
291 Decimal::from_f64(max_f64).unwrap_or(Decimal::ZERO)
293 }
294
295 #[inline]
297 #[must_use]
298 pub fn dot_product_f64(a: &[f64], b: &[f64]) -> f64 {
299 if a.is_empty() || b.is_empty() || a.len() != b.len() {
300 return 0.0;
301 }
302
303 if a.len() == 1 {
304 return a[0] * b[0];
305 }
306
307 if a.len() >= 4 {
308 return Self::dot_product_f64_wide(a, b);
309 }
310
311 let mut sum = 0.0;
313 for i in 0..a.len() {
314 sum += a[i] * b[i];
315 }
316 sum
317 }
318
319 #[inline]
321 #[must_use]
322 #[allow(clippy::cast_precision_loss)] pub fn mean_f64(values: &[f64]) -> f64 {
324 if values.is_empty() {
325 return f64::NAN;
326 }
327
328 if values.len() == 1 {
329 return values[0];
330 }
331
332 let sum = Self::sum_f64(values);
333 sum / (values.len() as f64)
334 }
335
336 #[inline]
338 #[must_use]
339 pub fn mean_decimal(values: &[Decimal]) -> Decimal {
340 if values.is_empty() {
341 return Decimal::ZERO; }
343
344 if values.len() == 1 {
345 return values[0];
346 }
347
348 let sum = Self::sum_decimal(values);
349 sum / Decimal::from(values.len())
350 }
351
352 #[inline]
355 #[must_use]
356 pub fn variance_f64(values: &[f64]) -> f64 {
357 if values.is_empty() {
358 return f64::NAN;
359 }
360
361 if values.len() == 1 {
362 return 0.0; }
364
365 let mut count = 0;
367 let mut mean = 0.0;
368 let mut m2 = 0.0;
369
370 let mut i = 0;
371 let n = values.len();
372
373 while i + 4 <= n {
375 let chunk = f64x4::from([values[i], values[i + 1], values[i + 2], values[i + 3]]);
376 let chunk_array = chunk.as_array_ref();
377
378 for &value in chunk_array {
380 if !value.is_nan() {
381 count += 1;
382 let delta = value - mean;
383 mean += delta / f64::from(count);
384 let delta2 = value - mean;
385 m2 += delta * delta2;
386 }
387 }
388
389 i += 4;
390 }
391
392 while i < n {
394 let value = values[i];
395 if !value.is_nan() {
396 count += 1;
397 let delta = value - mean;
398 mean += delta / f64::from(count);
399 let delta2 = value - mean;
400 m2 += delta * delta2;
401 }
402 i += 1;
403 }
404
405 if count > 0 {
406 m2 / f64::from(count)
407 } else {
408 f64::NAN
409 }
410 }
411
412 #[inline]
415 #[must_use]
416 pub fn std_dev_f64(values: &[f64]) -> f64 {
417 if values.is_empty() {
418 return f64::NAN;
419 }
420
421 if values.len() == 1 {
422 return 0.0; }
424
425 let variance = Self::variance_f64(values);
426 variance.sqrt()
427 }
428
429 #[inline]
432 #[must_use]
433 pub fn variance_decimal(values: &[Decimal]) -> Decimal {
434 if values.is_empty() {
435 return Decimal::ZERO; }
437
438 if values.len() == 1 {
439 return Decimal::ZERO; }
441
442 let mean = Self::mean_decimal(values);
444
445 if values.len() < 4 {
447 let mut sum_squared_diff = Decimal::ZERO;
448 for &val in values {
449 let diff = val - mean;
450 sum_squared_diff += diff * diff;
451 }
452 return sum_squared_diff / Decimal::from(values.len());
453 }
454
455 let f64_values = Self::decimal_to_f64(values);
457 let f64_mean = rusty_common::decimal_utils::decimal_to_f64_or_nan(mean);
458
459 let mut squared_diffs = vec![0.0; values.len()];
461 for (i, &val) in f64_values.iter().enumerate() {
462 squared_diffs[i] = (val - f64_mean).powi(2);
463 }
464
465 let variance_f64 = Self::mean_f64(&squared_diffs);
467
468 Decimal::from_f64(variance_f64).unwrap_or(Decimal::ZERO)
470 }
471
472 #[inline]
475 #[must_use]
476 pub fn std_dev_decimal(values: &[Decimal]) -> Decimal {
477 if values.is_empty() {
478 return Decimal::ZERO; }
480
481 if values.len() == 1 {
482 return Decimal::ZERO; }
484
485 let variance = Self::variance_decimal(values);
486
487 if let Some(variance_f64) = variance.to_f64() {
489 let std_dev_f64 = variance_f64.sqrt();
490 return Decimal::from_f64(std_dev_f64).unwrap_or(Decimal::ZERO);
491 }
492
493 Decimal::ZERO
495 }
496
497 #[inline]
499 fn sum_f64_wide(values: &[f64]) -> f64 {
500 simd_reduce!(
501 values,
502 2,
503 8, f64x4::ZERO, |acc: f64x4, val: f64x4| acc + val, |acc1: f64x4, acc2: f64x4| {
507 let total: f64x4 = acc1 + acc2; extract_sum_f64x4!(total) },
510 0.0, |acc: f64, val: f64| acc + val )
513 }
514
515 #[inline]
517 fn min_f64_wide(values: &[f64]) -> f64 {
518 simd_reduce!(
519 values,
520 1,
521 4, f64x4::splat(values[0]), |acc: f64x4, val: f64x4| acc.min(val), |acc: f64x4| extract_min_f64x4!(acc), values[0], |acc: f64, val: f64| acc.min(val) )
528 }
529
530 #[inline]
532 fn max_f64_wide(values: &[f64]) -> f64 {
533 simd_reduce!(
534 values,
535 1,
536 4, f64x4::splat(values[0]), |acc: f64x4, val: f64x4| acc.max(val), |acc: f64x4| extract_max_f64x4!(acc), values[0], |acc: f64, val: f64| acc.max(val) )
543 }
544
545 #[inline]
547 fn dot_product_f64_wide(a: &[f64], b: &[f64]) -> f64 {
548 simd_dual_op!(
549 a,
550 b,
551 1,
552 4, f64x4::ZERO, |acc: f64x4, a_vec: f64x4, b_vec: f64x4| acc + (a_vec * b_vec), |acc: f64x4| extract_sum_f64x4!(acc), 0.0, |acc: f64, a_val: f64, b_val: f64| a_val.mul_add(b_val, acc) )
559 }
560}
561
#[cfg(test)]
mod tests {
    // `Decimal` reaches this module through `super::*`.
    use super::*;
    use rust_decimal_macros::dec;

    #[test]
    fn test_sum_f64() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let sum = SimdOps::sum_f64(&values);
        assert!((sum - 36.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_sum_decimal() {
        let values = [
            dec!(1.0),
            dec!(2.0),
            dec!(3.0),
            dec!(4.0),
            dec!(5.0),
            dec!(6.0),
            dec!(7.0),
            dec!(8.0),
        ];
        let sum = SimdOps::sum_decimal(&values);
        assert_eq!(sum, dec!(36.0));
    }

    #[test]
    fn test_min_f64() {
        let values = [3.0, 1.0, 7.0, 4.0, 5.0, 2.0, 8.0, 6.0];
        let min = SimdOps::min_f64(&values);
        assert!((min - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_max_f64() {
        let values = [3.0, 1.0, 7.0, 4.0, 5.0, 2.0, 8.0, 6.0];
        let max = SimdOps::max_f64(&values);
        assert!((max - 8.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_min_decimal() {
        let values = [
            dec!(3.0),
            dec!(1.0),
            dec!(7.0),
            dec!(4.0),
            dec!(5.0),
            dec!(2.0),
            dec!(8.0),
            dec!(6.0),
        ];
        let min = SimdOps::min_decimal(&values);
        assert_eq!(min, dec!(1.0));
    }

    #[test]
    fn test_max_decimal() {
        let values = [
            dec!(3.0),
            dec!(1.0),
            dec!(7.0),
            dec!(4.0),
            dec!(5.0),
            dec!(2.0),
            dec!(8.0),
            dec!(6.0),
        ];
        let max = SimdOps::max_decimal(&values);
        assert_eq!(max, dec!(8.0));
    }

    #[test]
    fn test_dot_product_f64() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let dot = SimdOps::dot_product_f64(&a, &b);
        // 1*5 + 2*6 + 3*7 + 4*8 = 70
        assert!((dot - 70.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mean_f64() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mean = SimdOps::mean_f64(&values);
        assert!((mean - 4.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_mean_decimal() {
        let values = [
            dec!(1.0),
            dec!(2.0),
            dec!(3.0),
            dec!(4.0),
            dec!(5.0),
            dec!(6.0),
            dec!(7.0),
            dec!(8.0),
        ];
        let mean = SimdOps::mean_decimal(&values);
        assert_eq!(mean, dec!(4.5));
    }

    #[test]
    fn test_edge_cases() {
        // Empty inputs: f64 aggregates yield 0.0 / NaN, Decimal ones yield ZERO.
        assert_eq!(SimdOps::sum_f64(&[]), 0.0);
        assert_eq!(SimdOps::sum_decimal(&[]), dec!(0));
        assert!(SimdOps::min_f64(&[]).is_nan());
        assert!(SimdOps::max_f64(&[]).is_nan());
        assert_eq!(SimdOps::min_decimal(&[]), dec!(0));
        assert_eq!(SimdOps::max_decimal(&[]), dec!(0));
        assert!(SimdOps::mean_f64(&[]).is_nan());
        assert_eq!(SimdOps::mean_decimal(&[]), dec!(0));

        // Dot product of empty or mismatched slices is defined as 0.0.
        assert_eq!(SimdOps::dot_product_f64(&[], &[1.0]), 0.0);
        assert_eq!(SimdOps::dot_product_f64(&[1.0], &[]), 0.0);
        assert_eq!(SimdOps::dot_product_f64(&[1.0, 2.0], &[1.0]), 0.0);

        // Single-element inputs return that element.
        assert_eq!(SimdOps::sum_f64(&[42.0]), 42.0);
        assert_eq!(SimdOps::sum_decimal(&[dec!(42.0)]), dec!(42.0));
        assert_eq!(SimdOps::min_f64(&[42.0]), 42.0);
        assert_eq!(SimdOps::max_f64(&[42.0]), 42.0);
        assert_eq!(SimdOps::min_decimal(&[dec!(42.0)]), dec!(42.0));
        assert_eq!(SimdOps::max_decimal(&[dec!(42.0)]), dec!(42.0));
        assert_eq!(SimdOps::mean_f64(&[42.0]), 42.0);
        assert_eq!(SimdOps::mean_decimal(&[dec!(42.0)]), dec!(42.0));
        assert_eq!(SimdOps::dot_product_f64(&[3.0], &[4.0]), 12.0);
    }

    #[test]
    fn test_small_arrays() {
        // Three elements: below the four-element SIMD threshold.
        let values_f64 = [1.0, 2.0, 3.0];
        let values_decimal = [dec!(1.0), dec!(2.0), dec!(3.0)];

        assert_eq!(SimdOps::sum_f64(&values_f64), 6.0);
        assert_eq!(SimdOps::sum_decimal(&values_decimal), dec!(6.0));
        assert_eq!(SimdOps::min_f64(&values_f64), 1.0);
        assert_eq!(SimdOps::max_f64(&values_f64), 3.0);
        assert_eq!(SimdOps::min_decimal(&values_decimal), dec!(1.0));
        assert_eq!(SimdOps::max_decimal(&values_decimal), dec!(3.0));
        assert_eq!(SimdOps::mean_f64(&values_f64), 2.0);
        assert_eq!(SimdOps::mean_decimal(&values_decimal), dec!(2.0));
    }

    #[test]
    fn test_variance_f64() {
        // Population variance: mean 5, squared deviations 9+1+1+9 = 20, 20/4 = 5.
        let values = [2.0, 4.0, 6.0, 8.0];
        let variance = SimdOps::variance_f64(&values);
        assert_eq!(variance, 5.0);
    }

    #[test]
    fn test_std_dev_f64() {
        // sqrt(5) ≈ 2.2360679...
        let values = [2.0, 4.0, 6.0, 8.0];
        let std_dev = SimdOps::std_dev_f64(&values);
        assert!((std_dev - 2.236).abs() < 0.001);
    }

    #[test]
    fn test_variance_decimal() {
        let values = [dec!(2.0), dec!(4.0), dec!(6.0), dec!(8.0)];
        let variance = SimdOps::variance_decimal(&values);
        assert_eq!(variance, dec!(5.0));
    }

    #[test]
    fn test_std_dev_decimal() {
        let values = [dec!(2.0), dec!(4.0), dec!(6.0), dec!(8.0)];
        let std_dev = SimdOps::std_dev_decimal(&values);
        assert!(std_dev > dec!(2.235) && std_dev < dec!(2.237));
    }

    #[test]
    fn test_variance_std_dev_edge_cases() {
        // Empty input.
        assert!(SimdOps::variance_f64(&[]).is_nan());
        assert!(SimdOps::std_dev_f64(&[]).is_nan());
        assert_eq!(SimdOps::variance_decimal(&[]), dec!(0));
        assert_eq!(SimdOps::std_dev_decimal(&[]), dec!(0));

        // A single value has no spread.
        assert_eq!(SimdOps::variance_f64(&[42.0]), 0.0);
        assert_eq!(SimdOps::std_dev_f64(&[42.0]), 0.0);
        assert_eq!(SimdOps::variance_decimal(&[dec!(42.0)]), dec!(0));
        assert_eq!(SimdOps::std_dev_decimal(&[dec!(42.0)]), dec!(0));

        // Identical values have no spread either.
        assert_eq!(SimdOps::variance_f64(&[5.0, 5.0, 5.0, 5.0]), 0.0);
        assert_eq!(SimdOps::std_dev_f64(&[5.0, 5.0, 5.0, 5.0]), 0.0);
        assert_eq!(
            SimdOps::variance_decimal(&[dec!(5.0), dec!(5.0), dec!(5.0), dec!(5.0)]),
            dec!(0)
        );
        assert_eq!(
            SimdOps::std_dev_decimal(&[dec!(5.0), dec!(5.0), dec!(5.0), dec!(5.0)]),
            dec!(0)
        );
    }

    #[test]
    fn test_large_arrays() {
        let values_f64: Vec<f64> = (1..=100).map(f64::from).collect();
        let values_decimal: Vec<Decimal> = (1..=100).map(Decimal::from).collect();

        assert_eq!(SimdOps::sum_f64(&values_f64), 5050.0);
        assert_eq!(SimdOps::sum_decimal(&values_decimal), dec!(5050));
        assert_eq!(SimdOps::min_f64(&values_f64), 1.0);
        assert_eq!(SimdOps::max_f64(&values_f64), 100.0);
        assert_eq!(SimdOps::min_decimal(&values_decimal), dec!(1));
        assert_eq!(SimdOps::max_decimal(&values_decimal), dec!(100));
        assert_eq!(SimdOps::mean_f64(&values_f64), 50.5);
        assert_eq!(SimdOps::mean_decimal(&values_decimal), dec!(50.5));

        // Population variance of 1..=n is (n^2 - 1) / 12 = 9999 / 12 = 833.25.
        let expected_variance = 833.25;
        let variance = SimdOps::variance_f64(&values_f64);
        assert!((variance - expected_variance).abs() < 0.01);

        let expected_std_dev = expected_variance.sqrt();
        let std_dev = SimdOps::std_dev_f64(&values_f64);
        assert!((std_dev - expected_std_dev).abs() < 0.01);

        let variance_decimal = SimdOps::variance_decimal(&values_decimal);
        assert!(variance_decimal > dec!(833.2) && variance_decimal < dec!(833.3));

        let std_dev_decimal = SimdOps::std_dev_decimal(&values_decimal);
        assert!(std_dev_decimal > dec!(28.86) && std_dev_decimal < dec!(28.87));
    }

    #[test]
    fn test_nan_safety() {
        let values_with_nan = [1.0, f64::NAN, 3.0, 4.0];

        let sum = SimdOps::sum_f64(&values_with_nan);
        assert!(sum.is_nan());

        // NOTE(review): `test_comprehensive_nan_in_min_max` expects NaN to be
        // ignored for `[3.0, NaN, 1.0, 4.0]`, while this test expects NaN for
        // `[1.0, NaN, 3.0, 4.0]`. Both are four-element inputs on the SIMD
        // path, so these two tests pin position-dependent NaN behaviour.
        let min = SimdOps::min_f64(&values_with_nan);
        assert!(min.is_nan());

        let max = SimdOps::max_f64(&values_with_nan);
        assert!(max.is_nan());

        let a = [1.0, 2.0, f64::NAN, 4.0];
        let b = [1.0, 2.0, 3.0, 4.0];
        let dot = SimdOps::dot_product_f64(&a, &b);
        assert!(dot.is_nan());
    }

    #[test]
    fn test_comprehensive_nan_in_sum() {
        let values = [1.0, 2.0, f64::NAN, 4.0, 5.0];
        let sum = SimdOps::sum_f64(&values);
        assert!(sum.is_nan(), "Sum with NaN should be NaN");

        let all_nan = [f64::NAN; 8];
        let sum = SimdOps::sum_f64(&all_nan);
        assert!(sum.is_nan(), "Sum of all NaN should be NaN");

        let single_nan = [f64::NAN];
        let sum = SimdOps::sum_f64(&single_nan);
        assert!(sum.is_nan(), "Sum of single NaN should be NaN");

        let nan_start = [f64::NAN, 1.0, 2.0, 3.0];
        let sum = SimdOps::sum_f64(&nan_start);
        assert!(sum.is_nan(), "Sum with NaN at start should be NaN");

        let nan_end = [1.0, 2.0, 3.0, f64::NAN];
        let sum = SimdOps::sum_f64(&nan_end);
        assert!(sum.is_nan(), "Sum with NaN at end should be NaN");
    }

    #[test]
    fn test_comprehensive_nan_in_min_max() {
        let values = [3.0, f64::NAN, 1.0, 4.0];
        let min = SimdOps::min_f64(&values);
        assert_eq!(min, 1.0, "Min should ignore NaN per IEEE 754-2008");

        let max = SimdOps::max_f64(&values);
        assert_eq!(max, 4.0, "Max should ignore NaN per IEEE 754-2008");

        let all_nan = [f64::NAN; 4];
        assert!(SimdOps::min_f64(&all_nan).is_nan());
        assert!(SimdOps::max_f64(&all_nan).is_nan());

        let single_nan = [f64::NAN];
        assert!(SimdOps::min_f64(&single_nan).is_nan());
        assert!(SimdOps::max_f64(&single_nan).is_nan());
    }

    #[test]
    fn test_comprehensive_nan_in_mean() {
        let values = [1.0, 2.0, f64::NAN, 4.0];
        let mean = SimdOps::mean_f64(&values);
        assert!(mean.is_nan(), "Mean with NaN should be NaN");

        let empty: [f64; 0] = [];
        let mean = SimdOps::mean_f64(&empty);
        assert!(mean.is_nan(), "Mean of empty array should be NaN");
    }

    #[test]
    fn test_comprehensive_nan_in_variance_std_dev() {
        let values = [1.0, 2.0, f64::NAN, 4.0];
        let variance = SimdOps::variance_f64(&values);
        assert!(variance.is_nan(), "Variance with NaN should be NaN");

        let std_dev = SimdOps::std_dev_f64(&values);
        assert!(std_dev.is_nan(), "Std dev with NaN should be NaN");
    }

    #[test]
    fn test_infinity_handling() {
        let pos_inf = [1.0, 2.0, f64::INFINITY, 4.0];
        let sum = SimdOps::sum_f64(&pos_inf);
        assert_eq!(sum, f64::INFINITY, "Sum with +inf should be +inf");

        let neg_inf = [1.0, 2.0, f64::NEG_INFINITY, 4.0];
        let sum = SimdOps::sum_f64(&neg_inf);
        assert_eq!(sum, f64::NEG_INFINITY, "Sum with -inf should be -inf");

        let both_inf = [f64::INFINITY, f64::NEG_INFINITY];
        let sum = SimdOps::sum_f64(&both_inf);
        assert!(sum.is_nan(), "Sum of +inf and -inf should be NaN");

        let inf_values = [1.0, f64::INFINITY, 3.0];
        assert_eq!(SimdOps::min_f64(&inf_values), 1.0);
        assert_eq!(SimdOps::max_f64(&inf_values), f64::INFINITY);

        let neg_inf_values = [1.0, f64::NEG_INFINITY, 3.0];
        assert_eq!(SimdOps::min_f64(&neg_inf_values), f64::NEG_INFINITY);
        assert_eq!(SimdOps::max_f64(&neg_inf_values), 3.0);
    }

    #[test]
    fn test_subnormal_numbers() {
        // Values at and below f64::MIN_POSITIVE (the smallest normal number).
        let subnormals = [
            f64::MIN_POSITIVE / 2.0,
            f64::MIN_POSITIVE / 4.0,
            f64::MIN_POSITIVE / 8.0,
            f64::MIN_POSITIVE,
        ];

        let sum = SimdOps::sum_f64(&subnormals);
        assert!(sum > 0.0, "Sum of subnormals should be positive");
        assert!(sum < f64::MIN_POSITIVE * 2.0, "Sum should be very small");

        let min = SimdOps::min_f64(&subnormals);
        assert_eq!(min, f64::MIN_POSITIVE / 8.0);

        let max = SimdOps::max_f64(&subnormals);
        assert_eq!(max, f64::MIN_POSITIVE);
    }

    #[test]
    fn test_zero_handling() {
        let zeros = [0.0, -0.0, 0.0, -0.0];
        let sum = SimdOps::sum_f64(&zeros);
        assert_eq!(sum, 0.0);

        // `-0.0 == 0.0` under IEEE comparison, so these only check the value,
        // not the sign bit (f64::min/max may return either signed zero).
        let signed_zeros = [-0.0, 0.0];
        let min = SimdOps::min_f64(&signed_zeros);
        let max = SimdOps::max_f64(&signed_zeros);
        assert_eq!(min, -0.0);
        assert_eq!(max, 0.0);
    }

    #[test]
    fn test_large_arrays_edge_cases() {
        let mut large = vec![1.0; 10000];
        large[5000] = f64::NAN;
        let sum = SimdOps::sum_f64(&large);
        assert!(sum.is_nan(), "Large array with NaN should propagate NaN");

        // 20 * (f64::MAX / 10) = 2 * f64::MAX, which overflows to +inf.
        let overflow_risk = [f64::MAX / 10.0; 20];
        let sum = SimdOps::sum_f64(&overflow_risk);
        assert_eq!(sum, f64::INFINITY, "Overflow should produce infinity");
    }

    #[test]
    fn test_boundary_sizes() {
        // Sizes straddling the scalar/SIMD threshold and multiples of the vector width.
        for size in [1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65] {
            let values: Vec<f64> = (0..size).map(f64::from).collect();
            let expected_sum = f64::from(size * (size - 1)) / 2.0;
            let sum = SimdOps::sum_f64(&values);
            assert_eq!(sum, expected_sum, "Sum failed for size {size}");

            let min = SimdOps::min_f64(&values);
            assert_eq!(min, 0.0, "Min failed for size {size}");

            let max = SimdOps::max_f64(&values);
            assert_eq!(max, f64::from(size - 1), "Max failed for size {size}");

            let mean = SimdOps::mean_f64(&values);
            assert_eq!(
                mean,
                f64::from(size - 1) / 2.0,
                "Mean failed for size {size}"
            );
        }
    }

    #[test]
    fn test_dot_product_comprehensive_edge_cases() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0];
        let dot = SimdOps::dot_product_f64(&a, &b);
        assert_eq!(
            dot, 0.0,
            "Dot product with different lengths should return 0"
        );

        let a_nan = [1.0, f64::NAN, 3.0];
        let b_normal = [4.0, 5.0, 6.0];
        let dot_nan = SimdOps::dot_product_f64(&a_nan, &b_normal);
        assert!(dot_nan.is_nan(), "Dot product with NaN should be NaN");

        let empty_a: [f64; 0] = [];
        let empty_b: [f64; 0] = [];
        let dot_empty = SimdOps::dot_product_f64(&empty_a, &empty_b);
        assert_eq!(dot_empty, 0.0, "Dot product of empty arrays should be 0");
    }

    #[test]
    fn test_precision_edge_cases() {
        // Mixed magnitudes: the tiny terms must not disturb the large ones.
        let precise = [1e-10, 1e-10, 1e-10, 1e-10, 1e10, 1e10, 1e10, 1e10];
        let sum = SimdOps::sum_f64(&precise);
        assert!(sum >= 4e10, "Sum should preserve large values");

        let identical = [42.42; 100];
        let variance = SimdOps::variance_f64(&identical);
        assert!(
            variance.abs() < 1e-10,
            "Variance of identical values should be near 0, got {variance}"
        );
    }

    #[test]
    fn test_decimal_nan_equivalent() {
        // Very large magnitudes must not collapse to zero.
        let values = [Decimal::MAX / dec!(10), dec!(1.0), dec!(2.0)];
        let sum = SimdOps::sum_decimal(&values);
        assert!(sum > dec!(0));

        // 28-digit fractions: the sum must stay close to the exact value.
        let precise_values = [
            dec!(0.1111111111111111111111111111),
            dec!(0.2222222222222222222222222222),
            dec!(0.3333333333333333333333333333),
        ];
        let sum = SimdOps::sum_decimal(&precise_values);
        assert!((sum - dec!(0.6666666666666666666666666666)).abs() < dec!(0.0001));
    }
}