competitive/num/mint/
mint_basic.rs

1use super::*;
2use std::{cell::UnsafeCell, mem::swap};
3
/// Declares an uninhabited marker type `$name` and implements `MIntBase` and
/// `MIntConvert` for it using plain integer arithmetic.
///
/// Arguments:
/// - `$name`: identifier of the generated marker `enum`.
/// - `$m`: expression yielding the modulus; it is evaluated on every
///   `get_mod()` call, so it may read runtime state.
/// - `$basety`: unsigned integer type used as `MIntBase::Inner`.
/// - `$signedty`: signed integer type in which `mod_inv` runs the extended
///   Euclidean algorithm.
/// - `$upperty`: wider unsigned type that holds the full product in `mod_mul`.
/// - `[$unsigned, ...]` / `[$signed, ...]`: integer types that receive a
///   `MIntConvert` implementation.
///
/// The caller must additionally provide an inherent
/// `fn rem(x: $upperty) -> $upperty` on `$name`; `mod_mul` relies on it to
/// reduce the widened product.
///
/// All operations expect operands already reduced into `[0, modulus)`.
///
/// NOTE: `mod_inv` and the signed conversions cast the modulus to a signed
/// type, so they are only meaningful while the modulus fits in the positive
/// range of that type. `mod_add` and `mod_sub` have no such restriction.
#[macro_export]
macro_rules! define_basic_mintbase {
    ($name:ident, $m:expr, $basety:ty, $signedty:ty, $upperty:ty, [$($unsigned:ty),*], [$($signed:ty),*]) => {
        pub enum $name {}
        impl MIntBase for $name {
            type Inner = $basety;
            #[inline]
            fn get_mod() -> Self::Inner {
                $m
            }
            #[inline]
            fn mod_zero() -> Self::Inner {
                0
            }
            #[inline]
            fn mod_one() -> Self::Inner {
                1
            }
            #[inline]
            fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                let m = Self::get_mod();
                // The sum of two reduced operands can exceed the range of the base
                // type when the modulus is above half of that range. A carry means
                // the true sum is >= modulus, and wrapping subtraction then yields
                // the exact reduced value.
                let (sum, carried) = <$basety>::overflowing_add(x, y);
                if carried || sum >= m {
                    sum.wrapping_sub(m)
                } else {
                    sum
                }
            }
            #[inline]
            fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                if x < y {
                    // Same value as `x + m - y`, but never exceeds the modulus on the way.
                    Self::get_mod() - (y - x)
                } else {
                    x - y
                }
            }
            #[inline]
            fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                // Multiply in the wide type and let the type-specific `rem` reduce it.
                $name::rem(x as $upperty * y as $upperty) as $basety
            }
            #[inline]
            fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                Self::mod_mul(x, Self::mod_inv(y))
            }
            #[inline]
            fn mod_neg(x: Self::Inner) -> Self::Inner {
                if x == 0 {
                    0
                } else {
                    Self::get_mod() - x
                }
            }
            fn mod_inv(x: Self::Inner) -> Self::Inner {
                // Extended Euclid: maintains a == u * x0 and b == v * x0 (mod p).
                // When the loop ends b is gcd(x0, p) and v is its coefficient, i.e.
                // the inverse of x0 whenever x0 and p are coprime.
                let p = Self::get_mod() as $signedty;
                let (mut a, mut b) = (x as $signedty, p);
                let (mut u, mut v) = (1, 0);
                while a != 0 {
                    let k = b / a;
                    v -= k * u;
                    b -= k * a;
                    swap(&mut v, &mut u);
                    swap(&mut b, &mut a);
                }
                // The coefficient may be negative; shift it into [0, p).
                (if v < 0 { v + p } else { v }) as $basety
            }
        }
        $(impl MIntConvert<$unsigned> for $name {
            #[inline]
            fn from(x: $unsigned) -> Self::Inner {
                (x % <Self as MIntBase>::get_mod() as $unsigned) as $basety
            }
            #[inline]
            fn into(x: Self::Inner) -> $unsigned {
                x as $unsigned
            }
            #[inline]
            fn mod_into() -> $unsigned {
                <Self as MIntBase>::get_mod() as $unsigned
            }
        })*
        $(impl MIntConvert<$signed> for $name {
            #[inline]
            fn from(x: $signed) -> Self::Inner {
                let m = <Self as MIntBase>::get_mod() as $signed;
                // `%` keeps the sign of the dividend, so lift negative remainders.
                let r = x % m;
                if r < 0 {
                    (r + m) as $basety
                } else {
                    r as $basety
                }
            }
            #[inline]
            fn into(x: Self::Inner) -> $signed {
                x as $signed
            }
            #[inline]
            fn mod_into() -> $signed {
                <Self as MIntBase>::get_mod() as $signed
            }
        })*
    };
}
106
/// Declares any number of fixed-modulus 32-bit modular-integer backends.
///
/// Each `[Name, modulus, Alias]` triple expands to:
/// - a `define_basic_mintbase!` invocation with `u32` storage, `i32` for the
///   inverse computation and `u64` for products,
/// - an inherent `Name::rem` that reduces a `u64` with the `%` operator,
/// - `pub type Alias = MInt<Name>`.
#[macro_export]
macro_rules! define_basic_mint32 {
    ($([$modulo:ident, $value:expr, $alias:ident]),*) => {
        $(
            define_basic_mintbase!(
                $modulo,
                $value,
                u32,
                i32,
                u64,
                [u32, u64, u128, usize],
                [i32, i64, i128, isize]
            );

            impl $modulo {
                // Reduction of the widened product; the modulus is part of the
                // expansion, so this is an ordinary `%` by that expression.
                fn rem(x: u64) -> u64 {
                    x % $value
                }
            }

            pub type $alias = MInt<$modulo>;
        )*
    };
}
127
128thread_local!(static DYN_MODULUS_U32: UnsafeCell<BarrettReduction<u64>> = const { UnsafeCell::new(BarrettReduction::<u64>::new(1_000_000_007)) });
129impl DynModuloU32 {
130    pub fn set_mod(m: u32) {
131        DYN_MODULUS_U32
132            .with(|cell| unsafe { *cell.get() = BarrettReduction::<u64>::new(m as u64) });
133    }
134    fn rem(x: u64) -> u64 {
135        DYN_MODULUS_U32.with(|cell| unsafe { (*cell.get()).rem(x) })
136    }
137}
138impl DynMIntU32 {
139    pub fn set_mod(m: u32) {
140        DynModuloU32::set_mod(m)
141    }
142}
143
144thread_local!(static DYN_MODULUS_U64: UnsafeCell<BarrettReduction<u128>> = const { UnsafeCell::new(BarrettReduction::<u128>::new(1_000_000_007)) });
145impl DynModuloU64 {
146    pub fn set_mod(m: u64) {
147        DYN_MODULUS_U64
148            .with(|cell| unsafe { *cell.get() = BarrettReduction::<u128>::new(m as u128) })
149    }
150    fn rem(x: u128) -> u128 {
151        DYN_MODULUS_U64.with(|cell| unsafe { (*cell.get()).rem(x) })
152    }
153}
154impl DynMIntU64 {
155    pub fn set_mod(m: u64) {
156        DynModuloU64::set_mod(m)
157    }
158}
159
160define_basic_mint32!(
161    [Modulo998244353, 998_244_353, MInt998244353],
162    [Modulo1000000007, 1_000_000_007, MInt1000000007],
163    [Modulo1000000009, 1_000_000_009, MInt1000000009]
164);
165
166define_basic_mintbase!(
167    DynModuloU32,
168    DYN_MODULUS_U32.with(|cell| unsafe { (*cell.get()).get_mod() as u32 }),
169    u32,
170    i32,
171    u64,
172    [u32, u64, u128, usize],
173    [i32, i64, i128, isize]
174);
175pub type DynMIntU32 = MInt<DynModuloU32>;
176define_basic_mintbase!(
177    DynModuloU64,
178    DYN_MODULUS_U64.with(|cell| unsafe { (*cell.get()).get_mod() as u64 }),
179    u64,
180    i64,
181    u128,
182    [u64, u128, usize],
183    [i64, i128, isize]
184);
185pub type DynMIntU64 = MInt<DynModuloU64>;
186
187pub struct Modulo2;
188impl MIntBase for Modulo2 {
189    type Inner = u32;
190    #[inline]
191    fn get_mod() -> Self::Inner {
192        2
193    }
194    #[inline]
195    fn mod_zero() -> Self::Inner {
196        0
197    }
198    #[inline]
199    fn mod_one() -> Self::Inner {
200        1
201    }
202    #[inline]
203    fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
204        x ^ y
205    }
206    #[inline]
207    fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
208        x ^ y
209    }
210    #[inline]
211    fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
212        x & y
213    }
214    #[inline]
215    fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
216        assert_ne!(y, 0);
217        x
218    }
219    #[inline]
220    fn mod_neg(x: Self::Inner) -> Self::Inner {
221        x
222    }
223    #[inline]
224    fn mod_inv(x: Self::Inner) -> Self::Inner {
225        assert_ne!(x, 0);
226        x
227    }
228    #[inline]
229    fn mod_pow(x: Self::Inner, y: usize) -> Self::Inner {
230        if y == 0 { 1 } else { x }
231    }
232}
/// Implements `MIntConvert<T>` for the modulo-2 backend `$name` (whose
/// `MIntBase::Inner` is `$basety`) for every listed integer type `T`.
///
/// - `from` keeps only the lowest bit of the input. For negative signed inputs
///   this is still the non-negative residue because of two's complement.
/// - `into` casts the stored residue to `T`.
/// - `mod_into` returns the backend's modulus converted to `T`.
macro_rules! impl_to_mint_base_for_modulo2 {
    ($name:ident, $basety:ty, [$($t:ty),*]) => {
        $(impl MIntConvert<$t> for $name {
            #[inline]
            fn from(x: $t) -> Self::Inner {
                (x & 1) as $basety
            }
            #[inline]
            fn into(x: Self::Inner) -> $t {
                x as $t
            }
            #[inline]
            fn mod_into() -> $t {
                // Report the modulus itself, consistent with the other
                // `MIntConvert` implementations in this file.
                <$name as MIntBase>::get_mod() as $t
            }
        })*
    };
}
251impl_to_mint_base_for_modulo2!(
252    Modulo2,
253    u32,
254    [
255        u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
256    ]
257);
258pub type MInt2 = MInt<Modulo2>;
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    // Generates a randomized test checking that `inv()` yields a reduced value
    // whose product with the original element equals one. The optional trailing
    // expression is an upper bound for a randomly chosen modulus installed via
    // `set_mod` before every round.
    macro_rules! test_mint {
        ($test_name:ident $mint:ident $($m:expr)?) => {
            #[test]
            fn $test_name() {
                const ROUNDS: usize = 10_000;
                let mut rng = Xorshift::new();
                for _ in 0..ROUNDS {
                    $($mint::set_mod(rng.gen(..$m));)?
                    // Zero has no inverse, so sample from 1..modulus only.
                    let value = $mint::new_unchecked(rng.random(1..$mint::get_mod()));
                    let inverse = value.inv();
                    assert!(inverse.inner() < $mint::get_mod());
                    assert_eq!(value * inverse, $mint::one());
                }
            }
        };
    }

    test_mint!(test_mint2 MInt2);
    test_mint!(test_mint998244353 MInt998244353);
    test_mint!(test_mint1000000007 MInt1000000007);
    test_mint!(test_mint1000000009 MInt1000000009);
}