/// Reduces a sequence of `f64` values using one, two or four `f64x4` vector accumulators.
///
/// The arm is picked by the literal `ACCUMULATORS, STRIDE` pair written right after `$values`:
///
/// * `1, 4`  – one accumulator, four elements per step; `$extract_expr` receives one vector.
/// * `2, 8`  – two accumulators, eight elements per step; `$extract_expr` receives two vectors.
/// * `4, 16` – four accumulators, sixteen elements per step; `$extract_expr` receives four vectors.
///
/// Remaining arguments:
///
/// * `$init_expr` – starting value of each vector accumulator (expanded once per accumulator).
/// * `$combine_expr` – `(acc, v) -> acc`, folds one `f64x4` holding four consecutive elements.
/// * `$extract_expr` – turns the accumulator vector(s) into the scalar running result.
/// * `$scalar_init` – returned as-is for empty input; no callback is invoked in that case.
/// * `$scalar_combine` – `(result, x) -> result`, folds the final (< 4) elements in index order.
///
/// After the main loop, the wider arms feed any remaining full 8- and/or 4-element blocks into
/// the leading accumulator(s) before extraction. The callback expressions are expanded at every
/// call site, so closure literals are re-created per call.
///
/// `$values` must support `is_empty()`, `len()` and `usize` indexing yielding `f64`. The expansion
/// names `f64x4` unqualified, so a type of that name implementing `From<[f64; 4]>` must be in
/// scope where the macro is invoked.
#[macro_export]
macro_rules! simd_reduce {
    // Two accumulators, eight elements per main-loop step.
    (
        $values:expr,
        2, 8,
        $init_expr:expr,
        $combine_expr:expr,
        $extract_expr:expr,
        $scalar_init:expr,
        $scalar_combine:expr
    ) => {{
        let values = $values;
        if values.is_empty() {
            $scalar_init
        } else {
            let len = values.len();
            // Packs the four elements starting at `at` into one vector.
            let load = |at: usize| {
                f64x4::from([values[at], values[at + 1], values[at + 2], values[at + 3]])
            };

            let mut acc_a = $init_expr;
            let mut acc_b = $init_expr;
            let mut pos = 0;

            // `pos <= len` always holds, so `len - pos` cannot underflow.
            while len - pos >= 8 {
                let (lo, hi) = (load(pos), load(pos + 4));
                acc_a = ($combine_expr)(acc_a, lo);
                acc_b = ($combine_expr)(acc_b, hi);
                pos += 8;
            }

            // At most one full 4-element block is left; it goes to the first accumulator.
            if len - pos >= 4 {
                acc_a = ($combine_expr)(acc_a, load(pos));
                pos += 4;
            }

            let mut result = ($extract_expr)(acc_a, acc_b);
            while pos < len {
                result = ($scalar_combine)(result, values[pos]);
                pos += 1;
            }
            result
        }
    }};

    // One accumulator, four elements per main-loop step.
    (
        $values:expr,
        1, 4,
        $init_expr:expr,
        $combine_expr:expr,
        $extract_expr:expr,
        $scalar_init:expr,
        $scalar_combine:expr
    ) => {{
        let values = $values;
        if values.is_empty() {
            $scalar_init
        } else {
            let len = values.len();
            // Packs the four elements starting at `at` into one vector.
            let load = |at: usize| {
                f64x4::from([values[at], values[at + 1], values[at + 2], values[at + 3]])
            };

            let mut acc = $init_expr;
            let mut pos = 0;

            while len - pos >= 4 {
                acc = ($combine_expr)(acc, load(pos));
                pos += 4;
            }

            let mut result = ($extract_expr)(acc);
            while pos < len {
                result = ($scalar_combine)(result, values[pos]);
                pos += 1;
            }
            result
        }
    }};

    // Four accumulators, sixteen elements per main-loop step.
    (
        $values:expr,
        4, 16,
        $init_expr:expr,
        $combine_expr:expr,
        $extract_expr:expr,
        $scalar_init:expr,
        $scalar_combine:expr
    ) => {{
        let values = $values;
        if values.is_empty() {
            $scalar_init
        } else {
            let len = values.len();
            // Packs the four elements starting at `at` into one vector.
            let load = |at: usize| {
                f64x4::from([values[at], values[at + 1], values[at + 2], values[at + 3]])
            };

            let mut acc_a = $init_expr;
            let mut acc_b = $init_expr;
            let mut acc_c = $init_expr;
            let mut acc_d = $init_expr;
            let mut pos = 0;

            while len - pos >= 16 {
                let (v0, v1, v2, v3) =
                    (load(pos), load(pos + 4), load(pos + 8), load(pos + 12));
                acc_a = ($combine_expr)(acc_a, v0);
                acc_b = ($combine_expr)(acc_b, v1);
                acc_c = ($combine_expr)(acc_c, v2);
                acc_d = ($combine_expr)(acc_d, v3);
                pos += 16;
            }

            // Drain a leftover 8-element block into the first two accumulators...
            if len - pos >= 8 {
                let (v0, v1) = (load(pos), load(pos + 4));
                acc_a = ($combine_expr)(acc_a, v0);
                acc_b = ($combine_expr)(acc_b, v1);
                pos += 8;
            }

            // ...and a leftover 4-element block into the first one.
            if len - pos >= 4 {
                acc_a = ($combine_expr)(acc_a, load(pos));
                pos += 4;
            }

            let mut result = ($extract_expr)(acc_a, acc_b, acc_c, acc_d);
            while pos < len {
                result = ($scalar_combine)(result, values[pos]);
                pos += 1;
            }
            result
        }
    }};
}
219
/// Vectorized reduction over two `f64` sequences taken pairwise, e.g. a dot product.
///
/// Only the `1, 4` arm exists: a single `f64x4` accumulator consuming four element pairs per step.
///
/// * `$init_expr` – starting accumulator vector.
/// * `$combine_expr` – `(acc, a_vec, b_vec) -> acc`.
/// * `$extract_expr` – `(acc) -> scalar`, producing the running scalar result.
/// * `$scalar_init` – returned unchanged when either input is empty **or the lengths differ**;
///   no callback runs in that case, and the mismatch is not otherwise reported.
/// * `$scalar_combine` – `(result, a_i, b_i) -> result` for the trailing (< 4) pairs.
///
/// `f64x4` (implementing `From<[f64; 4]>`) must be in scope at the call site.
#[macro_export]
macro_rules! simd_dual_op {
    (
        $a:expr, $b:expr,
        1, 4,
        $init_expr:expr,
        $combine_expr:expr,
        $extract_expr:expr,
        $scalar_init:expr,
        $scalar_combine:expr
    ) => {{
        let a = $a;
        let b = $b;

        if !a.is_empty() && !b.is_empty() && a.len() == b.len() {
            let len = a.len();
            let mut acc = $init_expr;
            let mut pos = 0;

            // `pos <= len` always holds, so the subtraction cannot underflow.
            while len - pos >= 4 {
                let lhs = f64x4::from([a[pos], a[pos + 1], a[pos + 2], a[pos + 3]]);
                let rhs = f64x4::from([b[pos], b[pos + 1], b[pos + 2], b[pos + 3]]);
                acc = ($combine_expr)(acc, lhs, rhs);
                pos += 4;
            }

            let mut result = ($extract_expr)(acc);
            while pos < len {
                result = ($scalar_combine)(result, a[pos], b[pos]);
                pos += 1;
            }
            result
        } else {
            $scalar_init
        }
    }};
}
283
/// Walks `$values` in consecutive, non-overlapping chunks of exactly `$chunk_size` elements.
///
/// * `$process_chunk` is called as `(start, &values[start..start + chunk_size])` for every full
///   chunk, in order.
/// * `$process_remainder` is called once as `(start, &values[start..])` when fewer than
///   `$chunk_size` elements are left over. It is not called when the length is an exact multiple
///   of the chunk size (including empty input).
///
/// Both callback expressions are re-evaluated at each call, so closure literals may mutably
/// borrow local state. `$values` must support `len()` and `usize` range indexing.
///
/// # Panics
///
/// Panics if `$chunk_size` is zero, because such a chunk size could never make progress.
#[macro_export]
macro_rules! simd_chunk_process {
    (
        $values:expr,
        $chunk_size:expr,
        $process_chunk:expr,
        $process_remainder:expr
    ) => {{
        let values = $values;
        let chunk_size: usize = $chunk_size;
        // A zero step would never advance `i` and the loop below would spin forever.
        assert!(
            chunk_size > 0,
            "simd_chunk_process!: chunk_size must be non-zero"
        );

        let n = values.len();
        let mut i: usize = 0;

        // `n - i >= chunk_size` (rather than `i + chunk_size <= n`) cannot overflow, and it
        // guarantees `i + chunk_size <= n` for the slice below.
        while n - i >= chunk_size {
            ($process_chunk)(i, &values[i..i + chunk_size]);
            i += chunk_size;
        }

        if i < n {
            ($process_remainder)(i, &values[i..]);
        }
    }};
}
320
/// Horizontally sums the four lanes of an `f64x4` as `((l0 + l1) + l2) + l3`.
///
/// The lanes are copied out of `$vec` immediately. This lets `$vec` be any expression,
/// including a temporary such as `acc1 + acc2`, not only a named variable.
#[macro_export]
macro_rules! extract_sum_f64x4 {
    ($vec:expr) => {{
        // Copy the lanes right away. Binding the `&[f64; 4]` returned by `as_array_ref()` would
        // keep `$vec` borrowed past the `let`, which fails to compile when `$vec` is a temporary.
        let [l0, l1, l2, l3] = *$vec.as_array_ref();
        l0 + l1 + l2 + l3
    }};
}
329
/// Returns the smallest of the four lanes of an `f64x4`, folding left with `f64::min`.
///
/// `f64::min` returns the other operand when one side is NaN. The result is therefore NaN only
/// if every lane is NaN.
///
/// The lanes are copied out of `$vec` immediately, so `$vec` may be a temporary expression.
#[macro_export]
macro_rules! extract_min_f64x4 {
    ($vec:expr) => {{
        // Copy instead of holding `&[f64; 4]`, so temporaries passed as `$vec` are accepted.
        let [l0, l1, l2, l3] = *$vec.as_array_ref();
        l0.min(l1).min(l2).min(l3)
    }};
}
338
/// Returns the largest of the four lanes of an `f64x4`, folding left with `f64::max`.
///
/// `f64::max` returns the other operand when one side is NaN. The result is therefore NaN only
/// if every lane is NaN.
///
/// The lanes are copied out of `$vec` immediately, so `$vec` may be a temporary expression.
#[macro_export]
macro_rules! extract_max_f64x4 {
    ($vec:expr) => {{
        // Copy instead of holding `&[f64; 4]`, so temporaries passed as `$vec` are accepted.
        let [l0, l1, l2, l3] = *$vec.as_array_ref();
        l0.max(l1).max(l2).max(l3)
    }};
}
347
#[cfg(test)]
mod tests {
    use wide::f64x4;

    /// Dual-accumulator arm: 8 elements through the main loop, 2 through the scalar tail.
    #[test]
    fn test_simd_reduce_dual_accumulator() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let sum = simd_reduce!(
            &values,
            2,
            8,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc1: f64x4, acc2: f64x4| {
                let total: f64x4 = acc1 + acc2;
                extract_sum_f64x4!(total)
            },
            0.0,
            |acc, val| acc + val
        );

        assert_eq!(sum, 55.0, "Sum should be 1+2+...+10 = 55");
    }

    /// Single-accumulator arm used for min and max reductions.
    #[test]
    fn test_simd_reduce_single_accumulator() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let min_val = simd_reduce!(
            &values,
            1,
            4,
            f64x4::splat(f64::INFINITY),
            |acc: f64x4, vec: f64x4| acc.min(vec),
            |acc: f64x4| extract_min_f64x4!(acc),
            f64::INFINITY,
            |acc: f64, val: f64| acc.min(val)
        );

        assert_eq!(min_val, 1.0, "Minimum value should be 1.0");

        let max_val = simd_reduce!(
            &values,
            1,
            4,
            f64x4::splat(f64::NEG_INFINITY),
            |acc: f64x4, vec: f64x4| acc.max(vec),
            |acc: f64x4| extract_max_f64x4!(acc),
            f64::NEG_INFINITY,
            |acc: f64, val: f64| acc.max(val)
        );

        assert_eq!(max_val, 5.0, "Maximum value should be 5.0");
    }

    /// Quad-accumulator arm over an exact multiple of 16 elements.
    #[test]
    fn test_simd_reduce_quad_accumulator() {
        let values: Vec<f64> = (1..=64).map(|x| x as f64).collect();
        let expected_sum: f64 = (1..=64).sum::<i32>() as f64;

        let sum = simd_reduce!(
            &values,
            4,
            16,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc1: f64x4, acc2: f64x4, acc3: f64x4, acc4: f64x4| {
                let total: f64x4 = acc1 + acc2 + acc3 + acc4;
                extract_sum_f64x4!(total)
            },
            0.0,
            |acc, val| acc + val
        );

        assert_eq!(sum, expected_sum, "Sum of 1..64 should be 2080");
    }

    /// Dot product: one vector step plus one scalar tail pair.
    #[test]
    fn test_simd_dual_op() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
        let expected_dot_product = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0 + 5.0 * 6.0;

        let dot_product = simd_dual_op!(
            &a,
            &b,
            1,
            4,
            f64x4::ZERO,
            |acc: f64x4, a_vec: f64x4, b_vec: f64x4| acc + (a_vec * b_vec),
            |acc: f64x4| extract_sum_f64x4!(acc),
            0.0,
            |acc, a_val, b_val| acc + (a_val * b_val)
        );

        assert_eq!(
            dot_product, expected_dot_product,
            "Dot product should be 70.0"
        );
    }

    /// Two full chunks of 4 followed by a one-element remainder.
    #[test]
    fn test_simd_chunk_process() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut chunk_sums = Vec::new();
        let mut remainder_sum = 0.0;

        simd_chunk_process!(
            &values,
            4,
            |_start, chunk: &[f64]| {
                let chunk_sum: f64 = chunk.iter().sum();
                chunk_sums.push(chunk_sum);
            },
            |_start, remainder: &[f64]| {
                remainder_sum = remainder.iter().sum();
            }
        );

        assert_eq!(chunk_sums.len(), 2, "Should have 2 chunks of size 4");
        assert_eq!(chunk_sums[0], 10.0, "First chunk sum: 1+2+3+4 = 10");
        assert_eq!(chunk_sums[1], 26.0, "Second chunk sum: 5+6+7+8 = 26");
        assert_eq!(remainder_sum, 9.0, "Remainder sum: 9 = 9");
    }

    #[test]
    fn test_extract_helper_macros() {
        let vec = f64x4::from([1.0, 2.0, 3.0, 4.0]);

        let sum = extract_sum_f64x4!(vec);
        assert_eq!(sum, 10.0, "Sum should be 1+2+3+4 = 10");

        let min_vec = f64x4::from([4.0, 1.0, 3.0, 2.0]);
        let min_val = extract_min_f64x4!(min_vec);
        assert_eq!(min_val, 1.0, "Minimum should be 1.0");

        let max_vec = f64x4::from([1.0, 4.0, 2.0, 3.0]);
        let max_val = extract_max_f64x4!(max_vec);
        assert_eq!(max_val, 4.0, "Maximum should be 4.0");
    }

    /// Empty inputs must short-circuit to the scalar initial value.
    #[test]
    fn test_edge_cases_empty_arrays() {
        let empty: Vec<f64> = vec![];

        let sum = simd_reduce!(
            &empty,
            2,
            8,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc1: f64x4, acc2: f64x4| {
                let total: f64x4 = acc1 + acc2;
                extract_sum_f64x4!(total)
            },
            0.0,
            |acc, val| acc + val
        );
        assert_eq!(sum, 0.0, "Empty array sum should be 0.0");

        let dot_product = simd_dual_op!(
            &empty,
            &empty,
            1,
            4,
            f64x4::ZERO,
            |acc: f64x4, a_vec: f64x4, b_vec: f64x4| acc + (a_vec * b_vec),
            |acc: f64x4| extract_sum_f64x4!(acc),
            0.0,
            |acc, a_val, b_val| acc + (a_val * b_val)
        );
        assert_eq!(dot_product, 0.0, "Empty array dot product should be 0.0");
    }

    /// A single element never enters the vector loop; only the scalar tail runs.
    #[test]
    fn test_edge_cases_single_element() {
        let single = vec![42.0];

        let sum = simd_reduce!(
            &single,
            1,
            4,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc: f64x4| extract_sum_f64x4!(acc),
            0.0,
            |acc, val| acc + val
        );
        assert_eq!(sum, 42.0, "Single element sum should be 42.0");
    }

    /// Mismatched lengths are reported by returning `$scalar_init`, not by panicking.
    #[test]
    fn test_edge_cases_mismatched_lengths() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0];

        let dot_product = simd_dual_op!(
            &a,
            &b,
            1,
            4,
            f64x4::ZERO,
            |acc: f64x4, a_vec: f64x4, b_vec: f64x4| acc + (a_vec * b_vec),
            |acc: f64x4| extract_sum_f64x4!(acc),
            0.0,
            |acc, a_val, b_val| acc + (a_val * b_val)
        );
        assert_eq!(
            dot_product, 0.0,
            "Mismatched lengths should return initial value"
        );
    }

    #[test]
    fn test_nan_handling() {
        let values_with_nan = vec![1.0, 2.0, f64::NAN, 4.0, 5.0];

        let sum = simd_reduce!(
            &values_with_nan,
            1,
            4,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc: f64x4| extract_sum_f64x4!(acc),
            0.0,
            |acc, val| acc + val
        );
        assert!(sum.is_nan(), "Sum with NaN should be NaN");

        let min_val = simd_reduce!(
            &values_with_nan,
            1,
            4,
            f64x4::splat(f64::INFINITY),
            |acc: f64x4, vec: f64x4| acc.min(vec),
            |acc: f64x4| extract_min_f64x4!(acc),
            f64::INFINITY,
            |acc: f64, val: f64| acc.min(val)
        );
        // NOTE(review): lane-wise `f64x4::min` NaN semantics may differ from `f64::min`, so
        // either outcome is accepted here — confirm against the `wide` version in use.
        assert!(
            min_val.is_nan() || min_val == 1.0,
            "Min with NaN should handle gracefully"
        );
    }

    /// 10_000 elements exercise the 16-wide loop plus the 8/4 drains; all partial sums are
    /// exactly representable integers, so exact equality is valid.
    #[test]
    fn test_large_arrays() {
        let large_array: Vec<f64> = (0..10000).map(|x| x as f64).collect();
        let expected_sum: f64 = (0..10000).sum::<i32>() as f64;

        let sum = simd_reduce!(
            &large_array,
            4,
            16,
            f64x4::ZERO,
            |acc, vec| acc + vec,
            |acc1: f64x4, acc2: f64x4, acc3: f64x4, acc4: f64x4| {
                let total: f64x4 = acc1 + acc2 + acc3 + acc4;
                extract_sum_f64x4!(total)
            },
            0.0,
            |acc, val| acc + val
        );

        assert_eq!(sum, expected_sum, "Large array sum should be correct");
    }

    /// Sizes chosen to hit every combination of 8-block, 4-block and scalar tail.
    #[test]
    fn test_remainder_handling() {
        let sizes_and_expected = vec![(1, 1.0), (3, 6.0), (5, 15.0), (7, 28.0), (15, 120.0)];

        for (size, expected) in sizes_and_expected {
            let values: Vec<f64> = (1..=size).map(|x| x as f64).collect();

            let sum = simd_reduce!(
                &values,
                2,
                8,
                f64x4::ZERO,
                |acc, vec| acc + vec,
                |acc1, acc2| {
                    let total: f64x4 = acc1 + acc2;
                    extract_sum_f64x4!(total)
                },
                0.0,
                |acc, val| acc + val
            );

            assert_eq!(sum, expected, "Sum for size {size} should be {expected}");
        }
    }

    /// Throughput smoke test for the dual-accumulator arm.
    ///
    /// The wall-clock budget is enforced only in optimized builds. Unoptimized `cargo test` runs
    /// (debug assertions on) are far slower, so a fixed millisecond threshold there is flaky.
    #[test]
    fn test_simd_performance_validation() {
        let values: Vec<f64> = (0..1000).map(|x| x as f64).collect();
        // 0 + 1 + ... + 999; every partial sum is an exactly representable integer.
        let expected_sum: f64 = (0..1000).sum::<i32>() as f64;

        let start = std::time::Instant::now();
        let mut total = 0.0;
        for _ in 0..1000 {
            total += simd_reduce!(
                &values,
                2,
                8,
                f64x4::ZERO,
                |acc, vec| acc + vec,
                |acc1, acc2| {
                    let total: f64x4 = acc1 + acc2;
                    extract_sum_f64x4!(total)
                },
                0.0,
                |acc, val| acc + val
            );
        }
        let duration = start.elapsed();

        // Consuming every result keeps the timed work observable and checks correctness too.
        assert_eq!(
            total,
            expected_sum * 1000.0,
            "Repeated SIMD sums should be exact"
        );

        if !cfg!(debug_assertions) {
            assert!(
                duration.as_millis() < 100,
                "SIMD macro performance regression detected: {duration:?}"
            );
        }
    }
}