competitive/tools/
avx_helper.rs1use std::sync::atomic::{AtomicBool, Ordering};
2
/// The SIMD instruction-set level chosen for run-time dispatch.
///
/// Variants are declared from the plainest to the widest; `simd_backend()` reports which one
/// the current CPU (and the AVX-512 toggle) allows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// No explicit SIMD target features.
    Scalar,
    /// x86 AVX2.
    Avx2,
    /// x86 AVX-512 (`simd_backend()` requires the F, DQ, CD, BW and VL subsets).
    Avx512,
}
9
/// Process-wide veto switch for the AVX-512 backend; starts out `true` (AVX-512 allowed).
///
/// All accesses use `Ordering::Relaxed`: in this file the flag is only ever read or written on
/// its own and is not used to order any other memory accesses.
static AVX512_ENABLED: AtomicBool = AtomicBool::new(true);

/// Stores `value` into [`AVX512_ENABLED`].
#[inline]
fn set_avx512_flag(value: bool) {
    AVX512_ENABLED.store(value, Ordering::Relaxed)
}

/// Vetoes AVX-512: afterwards `avx512_enabled()` returns `false` and `simd_backend()` will not
/// report `SimdBackend::Avx512`.
///
/// Note: functions generated by `avx_helper!(@avx512 ...)` do not consult this flag.
#[inline]
pub fn disable_avx512() {
    set_avx512_flag(false);
}

/// Lifts the AVX-512 veto (restores the initial state); `avx512_enabled()` returns `true` again.
#[inline]
pub fn enable_avx512() {
    set_avx512_flag(true);
}

/// Returns whether AVX-512 is currently allowed by the toggle.
///
/// This says nothing about hardware support; `simd_backend()` combines it with CPU detection.
#[inline]
pub fn avx512_enabled() -> bool {
    AVX512_ENABLED.load(Ordering::Relaxed)
}
26
27#[inline]
28pub fn simd_backend() -> SimdBackend {
29 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
30 {
31 if avx512_enabled()
32 && is_x86_feature_detected!("avx512f")
33 && is_x86_feature_detected!("avx512dq")
34 && is_x86_feature_detected!("avx512cd")
35 && is_x86_feature_detected!("avx512bw")
36 && is_x86_feature_detected!("avx512vl")
37 {
38 return SimdBackend::Avx512;
39 }
40 if is_x86_feature_detected!("avx2") {
41 return SimdBackend::Avx2;
42 }
43 }
44 SimdBackend::Scalar
45}
46
/// Defines a function whose body is compiled for several x86 SIMD levels and selected at run time.
///
/// The body is instantiated once per level inside an `unsafe fn` carrying
/// `#[target_feature(enable = ...)]`, and once more as the plain fallback. On every call the
/// generated function probes the CPU with `is_x86_feature_detected!` and runs the best copy.
///
/// Entry tags:
///
/// * `@avx512` — AVX-512 (F, DQ, CD, BW, VL), then AVX2, then the plain body.
/// * `@avx2` — AVX2, then the plain body.
///
/// ```ignore
/// avx_helper!(@avx2 pub fn add_one(x: u32) -> u32 { x + 1 });
/// avx_helper!(@avx512 fn total<T>(xs: &[T]) -> T where [T: Copy + std::iter::Sum<T>] {
///     xs.iter().copied().sum()
/// });
/// ```
///
/// Matcher limitations:
///
/// * Each parameter must be `name: Type`; patterns such as `mut x` are not accepted.
/// * Generic parameters must be bare identifiers, with bounds given in `where [...]`.
/// * Lifetime parameters are not accepted.
/// * An omitted return type means `()`.
///
/// On targets other than `x86`/`x86_64` only the plain body is compiled.
///
/// NOTE(review): `@avx512` checks CPU support only. Unlike `simd_backend()`, it does not consult
/// `avx512_enabled()`, so `disable_avx512()` does not affect functions generated here. Also
/// confirm the toolchain in use accepts the AVX-512 `target_feature` names before using `@avx512`.
#[macro_export]
macro_rules! avx_helper {
    // Public entry: AVX-512 -> AVX2 -> plain body.
    (@avx512 $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($i:ident: $t:ty),*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        $(#[$meta])*
        $vis fn $name$(<$($T),*>)?($($i: $t),*) -> $ret
        where
            $($clauses)*
        {
            // Feature detection and `target_feature` only exist on x86; elsewhere just run the body.
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if ::std::is_x86_feature_detected!("avx512f")
                    && ::std::is_x86_feature_detected!("avx512dq")
                    && ::std::is_x86_feature_detected!("avx512cd")
                    && ::std::is_x86_feature_detected!("avx512bw")
                    && ::std::is_x86_feature_detected!("avx512vl")
                {
                    $crate::avx_helper!(@def_avx512 fn __avx_helper_avx512$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: every target feature enabled on `__avx_helper_avx512` was detected on
                    // this CPU by the checks directly above.
                    return unsafe { __avx_helper_avx512$(::<$($T),*>)?($($i),*) };
                }
                if ::std::is_x86_feature_detected!("avx2") {
                    $crate::avx_helper!(@def_avx2 fn __avx_helper_avx2$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: AVX2, the only feature enabled on `__avx_helper_avx2`, was detected
                    // on this CPU just above.
                    return unsafe { __avx_helper_avx2$(::<$($T),*>)?($($i),*) };
                }
            }
            $body
        }
    };
    // Public entry: AVX2 -> plain body.
    (@avx2 $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($i:ident: $t:ty),*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        $(#[$meta])*
        $vis fn $name$(<$($T),*>)?($($i: $t),*) -> $ret
        where
            $($clauses)*
        {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if ::std::is_x86_feature_detected!("avx2") {
                    $crate::avx_helper!(@def_avx2 fn __avx_helper_avx2$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: AVX2, the only feature enabled on `__avx_helper_avx2`, was detected
                    // on this CPU just above.
                    return unsafe { __avx_helper_avx2$(::<$($T),*>)?($($i),*) };
                }
            }
            $body
        }
    };
    // Internal: the body compiled with the AVX-512 subsets enabled.
    (@def_avx512 fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
        unsafe fn $name$(<$($T),*>)?($($args)*) -> $ret
        where
            $($clauses)*
        $body
    };
    // Internal: the body compiled with AVX2 enabled.
    (@def_avx2 fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        #[target_feature(enable = "avx2")]
        unsafe fn $name$(<$($T),*>)?($($args)*) -> $ret
        where
            $($clauses)*
        $body
    };
    // No `where [...]`: forward with an empty clause list.
    (@$tag:ident $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty $body:block) => {
        $crate::avx_helper!(@$tag $(#[$meta])* $vis fn $name$(<$($T),*>)?($($args)*) -> $ret where [] $body);
    };
    // No return type: forward with `-> ()`.
    (@$tag:ident $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) $($t:tt)*) => {
        $crate::avx_helper!(@$tag $(#[$meta])* $vis fn $name$(<$($T),*>)?($($args)*) -> () $($t)*);
    };
    // Anything else is a usage error; `compile_error!` needs a string, so stringify the input.
    ($($t:tt)*) => {
        ::std::compile_error!(::std::concat!(
            "avx_helper!: unrecognized invocation: ",
            ::std::stringify!($($t)*)
        ));
    };
}