Skip to main content

competitive/tools/
avx_helper.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2
/// Instruction-set tier that [`simd_backend`] can report.
///
/// The variants are declared from the plain fallback up to the widest
/// vector extension probed in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdBackend {
    /// Neither AVX2 nor the AVX-512 feature set was selected.
    Scalar,
    /// The `avx2` CPU feature was detected.
    Avx2,
    /// The AVX-512 feature set (`f`, `dq`, `cd`, `bw`, `vl`) was detected
    /// and the runtime switch allows it.
    Avx512,
}
9
// Process-wide switch gating the AVX-512 tier in `simd_backend`; starts as `true`.
// Written by `disable_avx512` / `enable_avx512` and read by `avx512_enabled`,
// always with `Ordering::Relaxed`.
static AVX512_ENABLED: AtomicBool = AtomicBool::new(true);
11
12#[inline]
13pub fn disable_avx512() {
14    AVX512_ENABLED.store(false, Ordering::Relaxed);
15}
16
17#[inline]
18pub fn enable_avx512() {
19    AVX512_ENABLED.store(true, Ordering::Relaxed);
20}
21
22#[inline]
23pub fn avx512_enabled() -> bool {
24    AVX512_ENABLED.load(Ordering::Relaxed)
25}
26
27#[inline]
28pub fn simd_backend() -> SimdBackend {
29    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
30    {
31        if avx512_enabled()
32            && is_x86_feature_detected!("avx512f")
33            && is_x86_feature_detected!("avx512dq")
34            && is_x86_feature_detected!("avx512cd")
35            && is_x86_feature_detected!("avx512bw")
36            && is_x86_feature_detected!("avx512vl")
37        {
38            return SimdBackend::Avx512;
39        }
40        if is_x86_feature_detected!("avx2") {
41            return SimdBackend::Avx2;
42        }
43    }
44    SimdBackend::Scalar
45}
46
/// Defines a function whose body is compiled once per SIMD tier and
/// dispatched at run time by CPU feature detection.
///
/// # Syntax
///
/// ```ignore
/// avx_helper! {
///     @avx512 pub fn name<T, U>(a: A, b: B) -> Ret where [T: Bound, U: Bound] {
///         /* body */
///     }
/// }
/// ```
///
/// * The leading tag is `@avx512` (try AVX-512, then AVX2, then the plain
///   body) or `@avx2` (try AVX2, then the plain body).
/// * Attributes and a visibility may precede `fn`; they are applied to the
///   generated outer function only.
/// * Generic parameters are bare identifiers; bounds go in the bracketed
///   `where [...]` list, which may be omitted.
/// * Parameters must have the form `ident: Type`.
/// * `-> Ret` may be omitted, in which case `()` is used.
///
/// # Expansion
///
/// The outer function keeps the given signature. For each tier, the body is
/// duplicated into a nested `#[target_feature(enable = ...)] unsafe fn` that
/// is called only after `is_x86_feature_detected!` has confirmed every
/// feature it enables. On targets other than `x86` / `x86_64` the dispatch
/// is compiled out and the plain body is used.
///
/// NOTE(review): the generated dispatch does not consult the
/// `avx512_enabled()` runtime switch defined in this file — confirm whether
/// that is intended.
#[macro_export]
macro_rules! avx_helper {
    // Entry point: AVX-512 -> AVX2 -> plain body.
    (@avx512 $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($i:ident: $t:ty),*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        $(#[$meta])*
        $vis fn $name$(<$($T),*>)?($($i: $t),*) -> $ret
        where
            $($clauses)*
        {
            // `is_x86_feature_detected!` only exists on x86 targets, so the
            // whole dispatch is gated; other targets fall through to `$body`.
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx512f")
                    && is_x86_feature_detected!("avx512dq")
                    && is_x86_feature_detected!("avx512cd")
                    && is_x86_feature_detected!("avx512bw")
                    && is_x86_feature_detected!("avx512vl")
                {
                    $crate::avx_helper!(@def_avx512 fn avx512$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: every feature enabled on `avx512` was detected
                    // on the running CPU by the checks just above.
                    return unsafe { avx512$(::<$($T),*>)?($($i),*) };
                }
                if is_x86_feature_detected!("avx2") {
                    $crate::avx_helper!(@def_avx2 fn avx2$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: `avx2` was detected on the running CPU just above.
                    return unsafe { avx2$(::<$($T),*>)?($($i),*) };
                }
            }
            $body
        }
    };
    // Entry point: AVX2 -> plain body.
    (@avx2 $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($i:ident: $t:ty),*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        $(#[$meta])*
        $vis fn $name$(<$($T),*>)?($($i: $t),*) -> $ret
        where
            $($clauses)*
        {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                if is_x86_feature_detected!("avx2") {
                    $crate::avx_helper!(@def_avx2 fn avx2$(<$($T),*>)?($($i: $t),*) -> $ret where [$($clauses)*] $body);
                    // SAFETY: `avx2` was detected on the running CPU just above.
                    return unsafe { avx2$(::<$($T),*>)?($($i),*) };
                }
            }
            $body
        }
    };
    // Internal: emits the nested copy of the body compiled with AVX-512 enabled.
    (@def_avx512 fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
        unsafe fn $name$(<$($T),*>)?($($args)*) -> $ret
        where
            $($clauses)*
        $body
    };
    // Internal: emits the nested copy of the body compiled with AVX2 enabled.
    (@def_avx2 fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty where [$($clauses:tt)*] $body:block) => {
        #[target_feature(enable = "avx2")]
        unsafe fn $name$(<$($T),*>)?($($args)*) -> $ret
        where
            $($clauses)*
        $body
    };
    // Normalisation: return type given, `where [...]` omitted -> add an empty list.
    (@$tag:ident $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) -> $ret:ty $body:block) => {
        $crate::avx_helper!(@$tag $(#[$meta])* $vis fn $name$(<$($T),*>)?($($args)*) -> $ret where [] $body);
    };
    // Normalisation: return type omitted, `where [...]` given -> insert `-> ()`.
    (@$tag:ident $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) where [$($clauses:tt)*] $body:block) => {
        $crate::avx_helper!(@$tag $(#[$meta])* $vis fn $name$(<$($T),*>)?($($args)*) -> () where [$($clauses)*] $body);
    };
    // Normalisation: neither return type nor `where [...]` given.
    (@$tag:ident $(#[$meta:meta])* $vis:vis fn $name:ident$(<$($T:ident),+>)?($($args:tt)*) $body:block) => {
        $crate::avx_helper!(@$tag $(#[$meta])* $vis fn $name$(<$($T),*>)?($($args)*) -> () where [] $body);
    };
    // Anything else is rejected with a message instead of recursing.
    ($($t:tt)*) => {
        ::std::compile_error!(::std::concat!(
            "avx_helper!: unsupported invocation (expected `@avx512` or `@avx2`, then a `fn` item with `ident: Type` parameters): ",
            ::std::stringify!($($t)*)
        ));
    };
}