rusty_model/
simd.rs

1use rust_decimal::Decimal;
2use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
3use wide::f64x4;
4
5// Import SIMD macros from rusty-common to eliminate manual loop unrolling
6use rusty_common::{
7    extract_max_f64x4, extract_min_f64x4, extract_sum_f64x4, simd_dual_op, simd_reduce,
8};
9
/// Namespace type for safe, SIMD-accelerated numeric reductions.
///
/// All functionality is exposed as associated functions (`SimdOps::sum_f64`,
/// `SimdOps::min_f64`, ...); the type itself carries no state. Vector work is
/// done with `wide::f64x4` together with the reduction macros imported from
/// `rusty_common`, and this file contains no `unsafe` code.
///
/// # Performance considerations
/// - Slices shorter than four elements take a plain scalar path; longer
///   slices are handed to the `f64x4`-based helpers.
/// - For four or more values, `Decimal` sums (and the means and variances
///   built on them) are computed through `f64` and may lose precision.
/// - `sum_decimal` and `sum_price_level_sizes` use a 64-element stack buffer
///   for inputs of up to 64 values and only allocate on the heap beyond that.
/// - Empty and single-element inputs are answered without running any
///   reduction.
/// - NOTE(review): NaN handling is not uniform across operations; confirm the
///   semantics of the specific function before relying on them.
pub struct SimdOps;
24
25impl SimdOps {
26    /// Apply safe SIMD-accelerated sum to an array of f64 values
27    /// Uses wide f64x4 operations for maximum performance and safety
28    /// Optimized with different strategies based on array size
29    #[inline]
30    #[must_use]
31    pub fn sum_f64(values: &[f64]) -> f64 {
32        if values.is_empty() {
33            return 0.0;
34        }
35
36        if values.len() == 1 {
37            return values[0];
38        }
39
40        // For very small arrays, scalar is faster due to SIMD setup overhead
41        if values.len() < 4 {
42            return values.iter().sum();
43        }
44
45        // Use SIMD for medium to large arrays
46        Self::sum_f64_wide(values)
47    }
48
49    /// Convert an array of decimal values to f64 for SIMD operations
50    #[inline]
51    #[must_use]
52    pub fn decimal_to_f64(values: &[Decimal]) -> Vec<f64> {
53        values
54            .iter()
55            .map(|v| rusty_common::decimal_utils::decimal_to_f64_or_nan(*v))
56            .collect()
57    }
58
59    /// Sum an array of Decimal values using safe SIMD acceleration
60    #[inline]
61    #[must_use]
62    pub fn sum_decimal(values: &[Decimal]) -> Decimal {
63        if values.is_empty() {
64            return Decimal::ZERO;
65        }
66
67        if values.len() == 1 {
68            return values[0];
69        }
70
71        // For very small arrays, use scalar sum directly
72        if values.len() < 4 {
73            return values.iter().sum();
74        }
75
76        // Use stack allocation for small to medium arrays to avoid heap allocation
77        if values.len() <= 64 {
78            let mut f64_values = [0.0f64; 64];
79            for (i, &value) in values.iter().enumerate() {
80                if i >= 64 {
81                    break;
82                }
83                let converted = rusty_common::decimal_utils::decimal_to_f64_or_nan(value);
84                f64_values[i] = if converted.is_nan() {
85                    0.0 // Replace NaN with a default value
86                } else {
87                    converted
88                };
89            }
90
91            // Perform SIMD sum on the stack-allocated array
92            let sum_f64 = Self::sum_f64(&f64_values[..values.len()]);
93
94            // Convert back to Decimal with appropriate precision
95            return Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO);
96        }
97
98        // For larger arrays, fall back to heap allocation
99        // Convert to f64 for SIMD operations
100        let f64_values = Self::decimal_to_f64(values);
101
102        // Perform SIMD sum
103        let sum_f64 = Self::sum_f64(&f64_values);
104
105        // Convert back to Decimal with appropriate precision
106        // Note: This may lose some precision, but should be acceptable for most use cases
107        Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO)
108    }
109
110    /// Sum the size values from an array of `PriceLevel` structs using SIMD acceleration
111    /// This avoids creating a temporary Vec<Decimal> for the size values
112    #[inline]
113    pub fn sum_price_level_sizes<T>(levels: &[T], extract_size: impl Fn(&T) -> Decimal) -> Decimal {
114        if levels.is_empty() {
115            return Decimal::ZERO;
116        }
117
118        if levels.len() == 1 {
119            return extract_size(&levels[0]);
120        }
121
122        // For very small arrays, use scalar sum directly
123        if levels.len() < 4 {
124            return levels.iter().map(extract_size).sum();
125        }
126
127        // Use stack allocation for small to medium arrays to avoid heap allocation
128        if levels.len() <= 64 {
129            let mut f64_values = [0.0f64; 64];
130            for (i, level) in levels.iter().enumerate() {
131                if i >= 64 {
132                    break;
133                }
134                f64_values[i] =
135                    rusty_common::decimal_utils::decimal_to_f64_or_nan(extract_size(level));
136            }
137
138            // Perform SIMD sum on the stack-allocated array
139            let sum_f64 = Self::sum_f64(&f64_values[..levels.len()]);
140
141            // Convert back to Decimal with appropriate precision
142            return Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO);
143        }
144
145        // For larger arrays, fall back to heap allocation
146        // Extract size values and convert to f64 for SIMD operations
147        let f64_values: Vec<f64> = levels
148            .iter()
149            .map(|level| rusty_common::decimal_utils::decimal_to_f64_or_nan(extract_size(level)))
150            .collect();
151
152        // Perform SIMD sum
153        let sum_f64 = Self::sum_f64(&f64_values);
154
155        // Convert back to Decimal with appropriate precision
156        Decimal::from_f64(sum_f64).unwrap_or(Decimal::ZERO)
157    }
158
159    /// Apply safe SIMD-accelerated min to an array of f64 values
160    /// Uses wide f64x4 operations with NaN-safe comparisons
161    #[inline]
162    #[must_use]
163    pub fn min_f64(values: &[f64]) -> f64 {
164        if values.is_empty() {
165            return f64::NAN;
166        }
167
168        if values.len() == 1 {
169            return values[0];
170        }
171
172        if values.len() >= 4 {
173            return Self::min_f64_wide(values);
174        }
175
176        // Fallback to scalar implementation
177        let mut min_val = values[0];
178        for &val in &values[1..] {
179            min_val = min_val.min(val);
180        }
181        min_val
182    }
183
184    /// Apply safe SIMD-accelerated max to an array of f64 values
185    /// Uses wide f64x4 operations with NaN-safe comparisons
186    #[inline]
187    #[must_use]
188    pub fn max_f64(values: &[f64]) -> f64 {
189        if values.is_empty() {
190            return f64::NAN;
191        }
192
193        if values.len() == 1 {
194            return values[0];
195        }
196
197        if values.len() >= 4 {
198            return Self::max_f64_wide(values);
199        }
200
201        // Fallback to scalar implementation
202        let mut max_val = values[0];
203        for &val in &values[1..] {
204            max_val = max_val.max(val);
205        }
206        max_val
207    }
208
209    /// Calculate the min value of an array of Decimal values using SIMD acceleration
210    #[inline]
211    #[must_use]
212    pub fn min_decimal(values: &[Decimal]) -> Decimal {
213        if values.is_empty() {
214            return Decimal::ZERO; // Could return NaN if Decimal supported it
215        }
216
217        if values.len() == 1 {
218            return values[0];
219        }
220
221        // For very small arrays, use scalar min directly
222        if values.len() < 4 {
223            let mut min_val = values[0];
224            for &val in &values[1..] {
225                if val < min_val {
226                    min_val = val;
227                }
228            }
229            return min_val;
230        }
231
232        // Convert to f64 for SIMD operations
233        let f64_values = Self::decimal_to_f64(values);
234
235        // Perform SIMD min
236        let min_f64 = Self::min_f64(&f64_values);
237
238        // Find the original Decimal value that corresponds to the min f64 value
239        // This is more accurate than converting back from f64
240        for &val in values {
241            if let Some(f64_val) = val.to_f64()
242                && (f64_val - min_f64).abs() < f64::EPSILON
243            {
244                return val;
245            }
246        }
247
248        // Fallback to converting from f64 if exact match not found
249        Decimal::from_f64(min_f64).unwrap_or(Decimal::ZERO)
250    }
251
252    /// Calculate the max value of an array of Decimal values using SIMD acceleration
253    #[inline]
254    #[must_use]
255    pub fn max_decimal(values: &[Decimal]) -> Decimal {
256        if values.is_empty() {
257            return Decimal::ZERO; // Could return NaN if Decimal supported it
258        }
259
260        if values.len() == 1 {
261            return values[0];
262        }
263
264        // For very small arrays, use scalar max directly
265        if values.len() < 4 {
266            let mut max_val = values[0];
267            for &val in &values[1..] {
268                if val > max_val {
269                    max_val = val;
270                }
271            }
272            return max_val;
273        }
274
275        // Convert to f64 for SIMD operations
276        let f64_values = Self::decimal_to_f64(values);
277
278        // Perform SIMD max
279        let max_f64 = Self::max_f64(&f64_values);
280
281        // Find the original Decimal value that corresponds to the max f64 value
282        // This is more accurate than converting back from f64
283        for &val in values {
284            if let Some(f64_val) = val.to_f64()
285                && (f64_val - max_f64).abs() < f64::EPSILON
286            {
287                return val;
288            }
289        }
290
291        // Fallback to converting from f64 if exact match not found
292        Decimal::from_f64(max_f64).unwrap_or(Decimal::ZERO)
293    }
294
295    /// Calculate the dot product of two f64 arrays using safe SIMD acceleration
296    #[inline]
297    #[must_use]
298    pub fn dot_product_f64(a: &[f64], b: &[f64]) -> f64 {
299        if a.is_empty() || b.is_empty() || a.len() != b.len() {
300            return 0.0;
301        }
302
303        if a.len() == 1 {
304            return a[0] * b[0];
305        }
306
307        if a.len() >= 4 {
308            return Self::dot_product_f64_wide(a, b);
309        }
310
311        // Fallback to scalar implementation
312        let mut sum = 0.0;
313        for i in 0..a.len() {
314            sum += a[i] * b[i];
315        }
316        sum
317    }
318
319    /// Calculate the mean (average) of an array of f64 values using SIMD acceleration
320    #[inline]
321    #[must_use]
322    #[allow(clippy::cast_precision_loss)] // Array sizes in financial computing are << 2^52
323    pub fn mean_f64(values: &[f64]) -> f64 {
324        if values.is_empty() {
325            return f64::NAN;
326        }
327
328        if values.len() == 1 {
329            return values[0];
330        }
331
332        let sum = Self::sum_f64(values);
333        sum / (values.len() as f64)
334    }
335
336    /// Calculate the mean (average) of an array of Decimal values using SIMD acceleration
337    #[inline]
338    #[must_use]
339    pub fn mean_decimal(values: &[Decimal]) -> Decimal {
340        if values.is_empty() {
341            return Decimal::ZERO; // Could return NaN if Decimal supported it
342        }
343
344        if values.len() == 1 {
345            return values[0];
346        }
347
348        let sum = Self::sum_decimal(values);
349        sum / Decimal::from(values.len())
350    }
351
352    /// Calculate the variance of an array of f64 values using SIMD acceleration with Welford's algorithm
353    /// This is the population variance (divides by n, not n-1) - more numerically stable than naive method
354    #[inline]
355    #[must_use]
356    pub fn variance_f64(values: &[f64]) -> f64 {
357        if values.is_empty() {
358            return f64::NAN;
359        }
360
361        if values.len() == 1 {
362            return 0.0; // Variance of a single value is 0
363        }
364
365        // Use Welford's online algorithm for better numerical stability
366        let mut count = 0;
367        let mut mean = 0.0;
368        let mut m2 = 0.0;
369
370        let mut i = 0;
371        let n = values.len();
372
373        // Process 4 elements at a time with SIMD-assisted Welford
374        while i + 4 <= n {
375            let chunk = f64x4::from([values[i], values[i + 1], values[i + 2], values[i + 3]]);
376            let chunk_array = chunk.as_array_ref();
377
378            // Update running statistics for each element
379            for &value in chunk_array {
380                if !value.is_nan() {
381                    count += 1;
382                    let delta = value - mean;
383                    mean += delta / f64::from(count);
384                    let delta2 = value - mean;
385                    m2 += delta * delta2;
386                }
387            }
388
389            i += 4;
390        }
391
392        // Handle remaining elements
393        while i < n {
394            let value = values[i];
395            if !value.is_nan() {
396                count += 1;
397                let delta = value - mean;
398                mean += delta / f64::from(count);
399                let delta2 = value - mean;
400                m2 += delta * delta2;
401            }
402            i += 1;
403        }
404
405        if count > 0 {
406            m2 / f64::from(count)
407        } else {
408            f64::NAN
409        }
410    }
411
412    /// Calculate the standard deviation of an array of f64 values using SIMD acceleration
413    /// This is the population standard deviation (divides by n, not n-1)
414    #[inline]
415    #[must_use]
416    pub fn std_dev_f64(values: &[f64]) -> f64 {
417        if values.is_empty() {
418            return f64::NAN;
419        }
420
421        if values.len() == 1 {
422            return 0.0; // Standard deviation of a single value is 0
423        }
424
425        let variance = Self::variance_f64(values);
426        variance.sqrt()
427    }
428
429    /// Calculate the variance of an array of Decimal values using SIMD acceleration
430    /// This is the population variance (divides by n, not n-1)
431    #[inline]
432    #[must_use]
433    pub fn variance_decimal(values: &[Decimal]) -> Decimal {
434        if values.is_empty() {
435            return Decimal::ZERO; // Could return NaN if Decimal supported it
436        }
437
438        if values.len() == 1 {
439            return Decimal::ZERO; // Variance of a single value is 0
440        }
441
442        // Calculate mean
443        let mean = Self::mean_decimal(values);
444
445        // For very small arrays, use scalar calculation directly
446        if values.len() < 4 {
447            let mut sum_squared_diff = Decimal::ZERO;
448            for &val in values {
449                let diff = val - mean;
450                sum_squared_diff += diff * diff;
451            }
452            return sum_squared_diff / Decimal::from(values.len());
453        }
454
455        // For larger arrays, convert to f64 for SIMD operations
456        let f64_values = Self::decimal_to_f64(values);
457        let f64_mean = rusty_common::decimal_utils::decimal_to_f64_or_nan(mean);
458
459        // Calculate squared differences
460        let mut squared_diffs = vec![0.0; values.len()];
461        for (i, &val) in f64_values.iter().enumerate() {
462            squared_diffs[i] = (val - f64_mean).powi(2);
463        }
464
465        // Calculate mean of squared differences (variance)
466        let variance_f64 = Self::mean_f64(&squared_diffs);
467
468        // Convert back to Decimal
469        Decimal::from_f64(variance_f64).unwrap_or(Decimal::ZERO)
470    }
471
472    /// Calculate the standard deviation of an array of Decimal values using SIMD acceleration
473    /// This is the population standard deviation (divides by n, not n-1)
474    #[inline]
475    #[must_use]
476    pub fn std_dev_decimal(values: &[Decimal]) -> Decimal {
477        if values.is_empty() {
478            return Decimal::ZERO; // Could return NaN if Decimal supported it
479        }
480
481        if values.len() == 1 {
482            return Decimal::ZERO; // Standard deviation of a single value is 0
483        }
484
485        let variance = Self::variance_decimal(values);
486
487        // Decimal doesn't have a sqrt method, so we need to convert to f64
488        if let Some(variance_f64) = variance.to_f64() {
489            let std_dev_f64 = variance_f64.sqrt();
490            return Decimal::from_f64(std_dev_f64).unwrap_or(Decimal::ZERO);
491        }
492
493        // Fallback if conversion fails
494        Decimal::ZERO
495    }
496
    /// `f64x4` sum kernel used by `sum_f64` for slices of four or more values.
    ///
    /// The whole reduction is generated by `rusty_common::simd_reduce!`. The
    /// macro is not visible in this file, so the per-argument notes below
    /// describe the apparent role of each argument - TODO confirm against the
    /// macro definition.
    #[inline]
    fn sum_f64_wide(values: &[f64]) -> f64 {
        simd_reduce!(
            values,
            2,
            8,                                  // Presumably: two accumulators, 8 elements per iteration
            f64x4::ZERO,                        // Initial value of each vector accumulator
            |acc: f64x4, val: f64x4| acc + val, // Lane-wise add; a NaN lane stays NaN
            |acc1: f64x4, acc2: f64x4| {
                let total: f64x4 = acc1 + acc2; // Merge the two accumulators lane-wise
                extract_sum_f64x4!(total) // Presumably adds the four lanes into one f64
            },
            0.0,                            // Initial value for the scalar remainder
            |acc: f64, val: f64| acc + val  // Scalar add for leftover elements
        )
    }
514
    /// `f64x4` minimum kernel used by `min_f64` for slices of four or more values.
    ///
    /// Indexes `values[0]` to seed the accumulators, so it must not be called
    /// with an empty slice (`min_f64` only calls it with at least four
    /// elements).
    /// NOTE(review): how NaN inputs are treated depends on `f64x4::min`,
    /// `simd_reduce!` and `extract_min_f64x4!`, none of which are visible
    /// here - confirm before relying on either propagation or skipping.
    #[inline]
    fn min_f64_wide(values: &[f64]) -> f64 {
        simd_reduce!(
            values,
            1,
            4,                                     // Presumably: one accumulator, 4 elements per iteration
            f64x4::splat(values[0]),               // Seed every lane with the first element
            |acc: f64x4, val: f64x4| acc.min(val), // Lane-wise minimum
            |acc: f64x4| extract_min_f64x4!(acc),  // Presumably reduces the vector to its smallest lane
            values[0],                             // Seed for the scalar remainder
            |acc: f64, val: f64| acc.min(val)      // Scalar minimum for leftover elements
        )
    }
529
    /// `f64x4` maximum kernel used by `max_f64` for slices of four or more values.
    ///
    /// Indexes `values[0]` to seed the accumulators, so it must not be called
    /// with an empty slice (`max_f64` only calls it with at least four
    /// elements).
    /// NOTE(review): how NaN inputs are treated depends on `f64x4::max`,
    /// `simd_reduce!` and `extract_max_f64x4!`, none of which are visible
    /// here - confirm before relying on either propagation or skipping.
    #[inline]
    fn max_f64_wide(values: &[f64]) -> f64 {
        simd_reduce!(
            values,
            1,
            4,                                     // Presumably: one accumulator, 4 elements per iteration
            f64x4::splat(values[0]),               // Seed every lane with the first element
            |acc: f64x4, val: f64x4| acc.max(val), // Lane-wise maximum
            |acc: f64x4| extract_max_f64x4!(acc),  // Presumably reduces the vector to its largest lane
            values[0],                             // Seed for the scalar remainder
            |acc: f64, val: f64| acc.max(val)      // Scalar maximum for leftover elements
        )
    }
544
    /// `f64x4` dot-product kernel used by `dot_product_f64` for slices of four
    /// or more values.
    ///
    /// The reduction is generated by `rusty_common::simd_dual_op!`. The macro
    /// is not visible in this file, so the per-argument notes describe the
    /// apparent role of each argument - TODO confirm against the macro
    /// definition. The caller is expected to pass slices of equal length
    /// (`dot_product_f64` checks this before delegating).
    #[inline]
    fn dot_product_f64_wide(a: &[f64], b: &[f64]) -> f64 {
        simd_dual_op!(
            a,
            b,
            1,
            4,           // Presumably: one accumulator, 4 element pairs per iteration
            f64x4::ZERO, // Initial value of the vector accumulator
            |acc: f64x4, a_vec: f64x4, b_vec: f64x4| acc + (a_vec * b_vec), // Separate lane-wise multiply then add (not a fused operation)
            |acc: f64x4| extract_sum_f64x4!(acc), // Presumably adds the four lanes into one f64
            0.0,                                  // Initial value for the scalar remainder
            |acc: f64, a_val: f64, b_val: f64| a_val.mul_add(b_val, acc) // Scalar fused multiply-add for leftover pairs
        )
    }
560}
561
562#[cfg(test)]
563mod tests {
564    use super::*;
565    use rust_decimal_macros::dec;
566
567    #[test]
568    fn test_sum_f64() {
569        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
570        let sum = SimdOps::sum_f64(&values);
571        assert!((sum - 36.0).abs() < f64::EPSILON);
572    }
573
574    #[test]
575    fn test_sum_decimal() {
576        let values = vec![
577            dec!(1.0),
578            dec!(2.0),
579            dec!(3.0),
580            dec!(4.0),
581            dec!(5.0),
582            dec!(6.0),
583            dec!(7.0),
584            dec!(8.0),
585        ];
586        let sum = SimdOps::sum_decimal(&values);
587        assert_eq!(sum, dec!(36.0));
588    }
589
590    #[test]
591    fn test_min_f64() {
592        let values = vec![3.0, 1.0, 7.0, 4.0, 5.0, 2.0, 8.0, 6.0];
593        let min = SimdOps::min_f64(&values);
594        assert!((min - 1.0).abs() < f64::EPSILON);
595    }
596
597    #[test]
598    fn test_max_f64() {
599        let values = vec![3.0, 1.0, 7.0, 4.0, 5.0, 2.0, 8.0, 6.0];
600        let max = SimdOps::max_f64(&values);
601        assert!((max - 8.0).abs() < f64::EPSILON);
602    }
603
604    #[test]
605    fn test_min_decimal() {
606        let values = vec![
607            dec!(3.0),
608            dec!(1.0),
609            dec!(7.0),
610            dec!(4.0),
611            dec!(5.0),
612            dec!(2.0),
613            dec!(8.0),
614            dec!(6.0),
615        ];
616        let min = SimdOps::min_decimal(&values);
617        assert_eq!(min, dec!(1.0));
618    }
619
620    #[test]
621    fn test_max_decimal() {
622        let values = vec![
623            dec!(3.0),
624            dec!(1.0),
625            dec!(7.0),
626            dec!(4.0),
627            dec!(5.0),
628            dec!(2.0),
629            dec!(8.0),
630            dec!(6.0),
631        ];
632        let max = SimdOps::max_decimal(&values);
633        assert_eq!(max, dec!(8.0));
634    }
635
636    #[test]
637    fn test_dot_product_f64() {
638        let a = vec![1.0, 2.0, 3.0, 4.0];
639        let b = vec![5.0, 6.0, 7.0, 8.0];
640        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
641        let dot = SimdOps::dot_product_f64(&a, &b);
642        assert!((dot - 70.0).abs() < f64::EPSILON);
643    }
644
645    #[test]
646    fn test_mean_f64() {
647        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
648        let mean = SimdOps::mean_f64(&values);
649        assert!((mean - 4.5).abs() < f64::EPSILON);
650    }
651
652    #[test]
653    fn test_mean_decimal() {
654        let values = vec![
655            dec!(1.0),
656            dec!(2.0),
657            dec!(3.0),
658            dec!(4.0),
659            dec!(5.0),
660            dec!(6.0),
661            dec!(7.0),
662            dec!(8.0),
663        ];
664        let mean = SimdOps::mean_decimal(&values);
665        assert_eq!(mean, dec!(4.5));
666    }
667
668    #[test]
669    fn test_edge_cases() {
670        // Empty array
671        assert_eq!(SimdOps::sum_f64(&[]), 0.0);
672        assert_eq!(SimdOps::sum_decimal(&[]), dec!(0));
673        assert!(SimdOps::min_f64(&[]).is_nan());
674        assert!(SimdOps::max_f64(&[]).is_nan());
675        assert_eq!(SimdOps::min_decimal(&[]), dec!(0));
676        assert_eq!(SimdOps::max_decimal(&[]), dec!(0));
677        assert!(SimdOps::mean_f64(&[]).is_nan());
678        assert_eq!(SimdOps::mean_decimal(&[]), dec!(0));
679        assert_eq!(SimdOps::dot_product_f64(&[], &[1.0]), 0.0);
680        assert_eq!(SimdOps::dot_product_f64(&[1.0], &[]), 0.0);
681        assert_eq!(SimdOps::dot_product_f64(&[1.0, 2.0], &[1.0]), 0.0); // Mismatched lengths
682
683        // Single element
684        assert_eq!(SimdOps::sum_f64(&[42.0]), 42.0);
685        assert_eq!(SimdOps::sum_decimal(&[dec!(42.0)]), dec!(42.0));
686        assert_eq!(SimdOps::min_f64(&[42.0]), 42.0);
687        assert_eq!(SimdOps::max_f64(&[42.0]), 42.0);
688        assert_eq!(SimdOps::min_decimal(&[dec!(42.0)]), dec!(42.0));
689        assert_eq!(SimdOps::max_decimal(&[dec!(42.0)]), dec!(42.0));
690        assert_eq!(SimdOps::mean_f64(&[42.0]), 42.0);
691        assert_eq!(SimdOps::mean_decimal(&[dec!(42.0)]), dec!(42.0));
692        assert_eq!(SimdOps::dot_product_f64(&[3.0], &[4.0]), 12.0);
693    }
694
695    #[test]
696    fn test_small_arrays() {
697        // Arrays with fewer than 4 elements (will use scalar fallback)
698        let values_f64 = [1.0, 2.0, 3.0];
699        let values_decimal = [dec!(1.0), dec!(2.0), dec!(3.0)];
700
701        assert_eq!(SimdOps::sum_f64(&values_f64), 6.0);
702        assert_eq!(SimdOps::sum_decimal(&values_decimal), dec!(6.0));
703        assert_eq!(SimdOps::min_f64(&values_f64), 1.0);
704        assert_eq!(SimdOps::max_f64(&values_f64), 3.0);
705        assert_eq!(SimdOps::min_decimal(&values_decimal), dec!(1.0));
706        assert_eq!(SimdOps::max_decimal(&values_decimal), dec!(3.0));
707        assert_eq!(SimdOps::mean_f64(&values_f64), 2.0);
708        assert_eq!(SimdOps::mean_decimal(&values_decimal), dec!(2.0));
709    }
710
711    #[test]
712    fn test_variance_f64() {
713        // Test with a simple array where variance is easy to calculate
714        // [2, 4, 6, 8] has mean 5 and variance ((2-5)² + (4-5)² + (6-5)² + (8-5)²) / 4 = (9 + 1 + 1 + 9) / 4 = 5
715        let values = vec![2.0, 4.0, 6.0, 8.0];
716        let variance = SimdOps::variance_f64(&values);
717        assert_eq!(variance, 5.0);
718    }
719
720    #[test]
721    fn test_std_dev_f64() {
722        // Test with a simple array where standard deviation is easy to calculate
723        // [2, 4, 6, 8] has variance 5, so std_dev = √5 ≈ 2.236
724        let values = vec![2.0, 4.0, 6.0, 8.0];
725        let std_dev = SimdOps::std_dev_f64(&values);
726        assert!((std_dev - 2.236).abs() < 0.001);
727    }
728
729    #[test]
730    fn test_variance_decimal() {
731        // Test with a simple array where variance is easy to calculate
732        let values = vec![dec!(2.0), dec!(4.0), dec!(6.0), dec!(8.0)];
733        let variance = SimdOps::variance_decimal(&values);
734        assert_eq!(variance, dec!(5.0));
735    }
736
737    #[test]
738    fn test_std_dev_decimal() {
739        // Test with a simple array where standard deviation is easy to calculate
740        let values = vec![dec!(2.0), dec!(4.0), dec!(6.0), dec!(8.0)];
741        let std_dev = SimdOps::std_dev_decimal(&values);
742        // Check that it's approximately 2.236
743        assert!(std_dev > dec!(2.235) && std_dev < dec!(2.237));
744    }
745
746    #[test]
747    fn test_variance_std_dev_edge_cases() {
748        // Empty array
749        assert!(SimdOps::variance_f64(&[]).is_nan());
750        assert!(SimdOps::std_dev_f64(&[]).is_nan());
751        assert_eq!(SimdOps::variance_decimal(&[]), dec!(0));
752        assert_eq!(SimdOps::std_dev_decimal(&[]), dec!(0));
753
754        // Single element (variance and std_dev should be 0)
755        assert_eq!(SimdOps::variance_f64(&[42.0]), 0.0);
756        assert_eq!(SimdOps::std_dev_f64(&[42.0]), 0.0);
757        assert_eq!(SimdOps::variance_decimal(&[dec!(42.0)]), dec!(0));
758        assert_eq!(SimdOps::std_dev_decimal(&[dec!(42.0)]), dec!(0));
759
760        // All elements the same (variance and std_dev should be 0)
761        assert_eq!(SimdOps::variance_f64(&[5.0, 5.0, 5.0, 5.0]), 0.0);
762        assert_eq!(SimdOps::std_dev_f64(&[5.0, 5.0, 5.0, 5.0]), 0.0);
763        assert_eq!(
764            SimdOps::variance_decimal(&[dec!(5.0), dec!(5.0), dec!(5.0), dec!(5.0)]),
765            dec!(0)
766        );
767        assert_eq!(
768            SimdOps::std_dev_decimal(&[dec!(5.0), dec!(5.0), dec!(5.0), dec!(5.0)]),
769            dec!(0)
770        );
771    }
772
773    #[test]
774    fn test_large_arrays() {
775        // Arrays with more than 64 elements (will use heap allocation)
776        let values_f64: Vec<f64> = (1..=100).map(f64::from).collect();
777        let values_decimal: Vec<Decimal> = (1..=100).map(Decimal::from).collect();
778
779        // Sum of 1..100 = 5050
780        assert_eq!(SimdOps::sum_f64(&values_f64), 5050.0);
781        assert_eq!(SimdOps::sum_decimal(&values_decimal), dec!(5050));
782        assert_eq!(SimdOps::min_f64(&values_f64), 1.0);
783        assert_eq!(SimdOps::max_f64(&values_f64), 100.0);
784        assert_eq!(SimdOps::min_decimal(&values_decimal), dec!(1));
785        assert_eq!(SimdOps::max_decimal(&values_decimal), dec!(100));
786        assert_eq!(SimdOps::mean_f64(&values_f64), 50.5);
787        assert_eq!(SimdOps::mean_decimal(&values_decimal), dec!(50.5));
788
789        // Variance of 1..100 = 833.25 (formula: (n²-1)/12 for integers 1 to n)
790        let expected_variance = 833.25;
791        let variance = SimdOps::variance_f64(&values_f64);
792        assert!((variance - expected_variance).abs() < 0.01);
793
794        // Standard deviation = √variance ≈ 28.866
795        let expected_std_dev = expected_variance.sqrt();
796        let std_dev = SimdOps::std_dev_f64(&values_f64);
797        assert!((std_dev - expected_std_dev).abs() < 0.01);
798
799        // Same tests for Decimal
800        let variance_decimal = SimdOps::variance_decimal(&values_decimal);
801        assert!(variance_decimal > dec!(833.2) && variance_decimal < dec!(833.3));
802
803        let std_dev_decimal = SimdOps::std_dev_decimal(&values_decimal);
804        assert!(std_dev_decimal > dec!(28.86) && std_dev_decimal < dec!(28.87));
805    }
806
807    #[test]
808    fn test_nan_safety() {
809        // Test NaN handling in various operations
810        let values_with_nan = vec![1.0, f64::NAN, 3.0, 4.0];
811
812        // Sum should handle NaN gracefully (NaN propagates)
813        let sum = SimdOps::sum_f64(&values_with_nan);
814        assert!(sum.is_nan());
815
816        // Min/Max should handle NaN gracefully (NaN propagates)
817        let min = SimdOps::min_f64(&values_with_nan);
818        assert!(min.is_nan());
819
820        let max = SimdOps::max_f64(&values_with_nan);
821        assert!(max.is_nan());
822
823        // Dot product with NaN
824        let a = vec![1.0, 2.0, f64::NAN, 4.0];
825        let b = vec![1.0, 2.0, 3.0, 4.0];
826        let dot = SimdOps::dot_product_f64(&a, &b);
827        assert!(dot.is_nan());
828    }
829
830    // ===== Comprehensive NaN and Edge Case Tests =====
831
832    #[test]
833    fn test_comprehensive_nan_in_sum() {
834        // NaN in various positions
835        let values = vec![1.0, 2.0, f64::NAN, 4.0, 5.0];
836        let sum = SimdOps::sum_f64(&values);
837        assert!(sum.is_nan(), "Sum with NaN should be NaN");
838
839        // All NaN
840        let all_nan = vec![f64::NAN; 8];
841        let sum = SimdOps::sum_f64(&all_nan);
842        assert!(sum.is_nan(), "Sum of all NaN should be NaN");
843
844        // Single NaN
845        let single_nan = vec![f64::NAN];
846        let sum = SimdOps::sum_f64(&single_nan);
847        assert!(sum.is_nan(), "Sum of single NaN should be NaN");
848
849        // NaN at start
850        let nan_start = vec![f64::NAN, 1.0, 2.0, 3.0];
851        let sum = SimdOps::sum_f64(&nan_start);
852        assert!(sum.is_nan(), "Sum with NaN at start should be NaN");
853
854        // NaN at end
855        let nan_end = vec![1.0, 2.0, 3.0, f64::NAN];
856        let sum = SimdOps::sum_f64(&nan_end);
857        assert!(sum.is_nan(), "Sum with NaN at end should be NaN");
858    }
859
860    #[test]
861    fn test_comprehensive_nan_in_min_max() {
862        // Note: The standard f64 min/max operations in Rust do NOT propagate NaN
863        // They follow IEEE 754-2008 semantics where NaN is ignored
864        // This is different from arithmetic operations which DO propagate NaN
865
866        // NaN in various positions for min
867        let values = vec![3.0, f64::NAN, 1.0, 4.0];
868        let min = SimdOps::min_f64(&values);
869        // IEEE 754-2008: min ignores NaN values, so result is 1.0
870        assert_eq!(min, 1.0, "Min should ignore NaN per IEEE 754-2008");
871
872        let max = SimdOps::max_f64(&values);
873        // IEEE 754-2008: max ignores NaN values, so result is 4.0
874        assert_eq!(max, 4.0, "Max should ignore NaN per IEEE 754-2008");
875
876        // All NaN - this case will propagate NaN since there are no non-NaN values
877        let all_nan = vec![f64::NAN; 4];
878        assert!(SimdOps::min_f64(&all_nan).is_nan());
879        assert!(SimdOps::max_f64(&all_nan).is_nan());
880
881        // Single NaN
882        let single_nan = vec![f64::NAN];
883        assert!(SimdOps::min_f64(&single_nan).is_nan());
884        assert!(SimdOps::max_f64(&single_nan).is_nan());
885    }
886
887    #[test]
888    fn test_comprehensive_nan_in_mean() {
889        let values = vec![1.0, 2.0, f64::NAN, 4.0];
890        let mean = SimdOps::mean_f64(&values);
891        assert!(mean.is_nan(), "Mean with NaN should be NaN");
892
893        // Empty array mean
894        let empty: Vec<f64> = vec![];
895        let mean = SimdOps::mean_f64(&empty);
896        assert!(mean.is_nan(), "Mean of empty array should be NaN");
897    }
898
899    #[test]
900    fn test_comprehensive_nan_in_variance_std_dev() {
901        let values = vec![1.0, 2.0, f64::NAN, 4.0];
902        let variance = SimdOps::variance_f64(&values);
903        assert!(variance.is_nan(), "Variance with NaN should be NaN");
904
905        let std_dev = SimdOps::std_dev_f64(&values);
906        assert!(std_dev.is_nan(), "Std dev with NaN should be NaN");
907    }
908
909    #[test]
910    fn test_infinity_handling() {
911        // Positive infinity
912        let pos_inf = vec![1.0, 2.0, f64::INFINITY, 4.0];
913        let sum = SimdOps::sum_f64(&pos_inf);
914        assert_eq!(sum, f64::INFINITY, "Sum with +inf should be +inf");
915
916        // Negative infinity
917        let neg_inf = vec![1.0, 2.0, f64::NEG_INFINITY, 4.0];
918        let sum = SimdOps::sum_f64(&neg_inf);
919        assert_eq!(sum, f64::NEG_INFINITY, "Sum with -inf should be -inf");
920
921        // Both infinities (results in NaN)
922        let both_inf = vec![f64::INFINITY, f64::NEG_INFINITY];
923        let sum = SimdOps::sum_f64(&both_inf);
924        assert!(sum.is_nan(), "Sum of +inf and -inf should be NaN");
925
926        // Min/max with infinity
927        let inf_values = vec![1.0, f64::INFINITY, 3.0];
928        assert_eq!(SimdOps::min_f64(&inf_values), 1.0);
929        assert_eq!(SimdOps::max_f64(&inf_values), f64::INFINITY);
930
931        let neg_inf_values = vec![1.0, f64::NEG_INFINITY, 3.0];
932        assert_eq!(SimdOps::min_f64(&neg_inf_values), f64::NEG_INFINITY);
933        assert_eq!(SimdOps::max_f64(&neg_inf_values), 3.0);
934    }
935
936    #[test]
937    fn test_subnormal_numbers() {
938        // Very small numbers (subnormals)
939        let subnormals = vec![
940            f64::MIN_POSITIVE / 2.0,
941            f64::MIN_POSITIVE / 4.0,
942            f64::MIN_POSITIVE / 8.0,
943            f64::MIN_POSITIVE,
944        ];
945
946        let sum = SimdOps::sum_f64(&subnormals);
947        assert!(sum > 0.0, "Sum of subnormals should be positive");
948        assert!(sum < f64::MIN_POSITIVE * 2.0, "Sum should be very small");
949
950        let min = SimdOps::min_f64(&subnormals);
951        assert_eq!(min, f64::MIN_POSITIVE / 8.0);
952
953        let max = SimdOps::max_f64(&subnormals);
954        assert_eq!(max, f64::MIN_POSITIVE);
955    }
956
957    #[test]
958    fn test_zero_handling() {
959        // Positive and negative zeros
960        let zeros = vec![0.0, -0.0, 0.0, -0.0];
961        let sum = SimdOps::sum_f64(&zeros);
962        assert_eq!(sum, 0.0);
963
964        // Min/max with signed zeros
965        let signed_zeros = vec![-0.0, 0.0];
966        let min = SimdOps::min_f64(&signed_zeros);
967        let max = SimdOps::max_f64(&signed_zeros);
968        // Both should be treated as equal
969        assert_eq!(min, -0.0);
970        assert_eq!(max, 0.0);
971    }
972
973    #[test]
974    fn test_large_arrays_edge_cases() {
975        // Very large array with NaN
976        let mut large = vec![1.0; 10000];
977        large[5000] = f64::NAN;
978        let sum = SimdOps::sum_f64(&large);
979        assert!(sum.is_nan(), "Large array with NaN should propagate NaN");
980
981        // Array that would overflow without NaN
982        let overflow_risk = vec![f64::MAX / 10.0; 20];
983        let sum = SimdOps::sum_f64(&overflow_risk);
984        assert_eq!(sum, f64::INFINITY, "Overflow should produce infinity");
985    }
986
987    #[test]
988    fn test_boundary_sizes() {
989        // Test arrays of various sizes around SIMD lane boundaries
990        for size in [1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65] {
991            let values: Vec<f64> = (0..size).map(f64::from).collect();
992            let expected_sum = f64::from(size * (size - 1)) / 2.0;
993            let sum = SimdOps::sum_f64(&values);
994            assert_eq!(sum, expected_sum, "Sum failed for size {size}");
995
996            let min = SimdOps::min_f64(&values);
997            assert_eq!(min, 0.0, "Min failed for size {size}");
998
999            let max = SimdOps::max_f64(&values);
1000            assert_eq!(max, f64::from(size - 1), "Max failed for size {size}");
1001
1002            let mean = SimdOps::mean_f64(&values);
1003            assert_eq!(
1004                mean,
1005                f64::from(size - 1) / 2.0,
1006                "Mean failed for size {size}"
1007            );
1008        }
1009    }
1010
1011    #[test]
1012    fn test_dot_product_comprehensive_edge_cases() {
1013        // Different lengths (should return 0)
1014        let a = vec![1.0, 2.0, 3.0];
1015        let b = vec![4.0, 5.0];
1016        let dot = SimdOps::dot_product_f64(&a, &b);
1017        assert_eq!(
1018            dot, 0.0,
1019            "Dot product with different lengths should return 0"
1020        );
1021
1022        // With NaN
1023        let a_nan = vec![1.0, f64::NAN, 3.0];
1024        let b_normal = vec![4.0, 5.0, 6.0];
1025        let dot_nan = SimdOps::dot_product_f64(&a_nan, &b_normal);
1026        assert!(dot_nan.is_nan(), "Dot product with NaN should be NaN");
1027
1028        // Empty arrays
1029        let empty_a: Vec<f64> = vec![];
1030        let empty_b: Vec<f64> = vec![];
1031        let dot_empty = SimdOps::dot_product_f64(&empty_a, &empty_b);
1032        assert_eq!(dot_empty, 0.0, "Dot product of empty arrays should be 0");
1033    }
1034
1035    #[test]
1036    fn test_precision_edge_cases() {
1037        // Test with values that might lose precision
1038        let precise = vec![1e-10, 1e-10, 1e-10, 1e-10, 1e10, 1e10, 1e10, 1e10];
1039        let sum = SimdOps::sum_f64(&precise);
1040        // Should be 4e10 + 4e-10, but precision might be lost
1041        assert!(sum >= 4e10, "Sum should preserve large values");
1042
1043        // Variance of identical values should be very close to 0
1044        let identical = vec![42.42; 100];
1045        let variance = SimdOps::variance_f64(&identical);
1046        assert!(
1047            variance.abs() < 1e-10,
1048            "Variance of identical values should be near 0, got {variance}"
1049        );
1050    }
1051
1052    #[test]
1053    fn test_decimal_nan_equivalent() {
1054        // Decimals don't have NaN, but we convert through f64
1055        // Test what happens when conversion produces NaN
1056        use rust_decimal::Decimal;
1057
1058        // Very large decimal values that are still safe to add
1059        let values = vec![Decimal::MAX / dec!(10), dec!(1.0), dec!(2.0)];
1060
1061        // This should handle the conversion gracefully
1062        let sum = SimdOps::sum_decimal(&values);
1063        // The sum should be valid (not panic)
1064        assert!(sum > dec!(0));
1065
1066        // Test with values that lose precision in f64 conversion
1067        let precise_values = vec![
1068            dec!(0.1111111111111111111111111111),
1069            dec!(0.2222222222222222222222222222),
1070            dec!(0.3333333333333333333333333333),
1071        ];
1072        let sum = SimdOps::sum_decimal(&precise_values);
1073        // Should be close to 0.6666... even with f64 conversion
1074        assert!((sum - dec!(0.6666666666666666666666666666)).abs() < dec!(0.0001));
1075    }
1076}