1use super::*;
2use std::{cell::UnsafeCell, mem::swap};
3
/// Defines a modulus marker type `$name` (`pub enum $name {}`) and implements `MIntBase`
/// plus `MIntConvert<T>` for every listed integer type `T`.
///
/// Parameters:
/// - `$name`: the marker type to create.
/// - `$m`: expression placed in the body of `get_mod()`, so it is evaluated on every call;
///   it must produce the modulus as `$basety`.
/// - `$basety`: unsigned storage type of residues.
/// - `$signedty`: signed counterpart of `$basety`. The expansion no longer needs it; it is
///   kept so existing invocations still match.
/// - `$upperty`: type used for the product of two residues. `mod_mul` passes that product to
///   an inherent `$name::rem`, which the invoker must define, and casts the result back.
/// - `[$unsigned...]` / `[$signed...]`: integer types that get `MIntConvert` impls.
///
/// Arithmetic assumes operands are already reduced (`< get_mod()`). No intermediate value
/// exceeds the modulus, so the code does not overflow even when the modulus is close to
/// `$basety::MAX` (for example a runtime modulus above `i32::MAX` with `u32` storage).
#[macro_export]
macro_rules! define_basic_mintbase {
    ($name:ident, $m:expr, $basety:ty, $signedty:ty, $upperty:ty, [$($unsigned:ty),*], [$($signed:ty),*]) => {
        pub enum $name {}
        impl MIntBase for $name {
            type Inner = $basety;
            #[inline]
            fn get_mod() -> Self::Inner {
                $m
            }
            #[inline]
            fn mod_zero() -> Self::Inner {
                0
            }
            #[inline]
            fn mod_one() -> Self::Inner {
                1
            }
            #[inline]
            fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                // Compare against `m - y` instead of forming `x + y`, which can overflow
                // `$basety` when the modulus is larger than half of its range.
                let gap = Self::get_mod() - y;
                if x >= gap {
                    x - gap
                } else {
                    x + y
                }
            }
            #[inline]
            fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                if x < y {
                    // `x + (m - y) < m` because `x < y`; computing `x + m` first could overflow.
                    x + (Self::get_mod() - y)
                } else {
                    x - y
                }
            }
            #[inline]
            fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                $name::rem(x as $upperty * y as $upperty) as $basety
            }
            #[inline]
            fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
                Self::mod_mul(x, Self::mod_inv(y))
            }
            #[inline]
            fn mod_neg(x: Self::Inner) -> Self::Inner {
                if x == 0 {
                    0
                } else {
                    Self::get_mod() - x
                }
            }
            /// Modular inverse by the extended Euclidean algorithm.
            ///
            /// The result is meaningful only when `gcd(x, m) == 1`. Returns 0 for `x == 0`.
            fn mod_inv(x: Self::Inner) -> Self::Inner {
                // Invariant: a ≡ s_a·x and b ≡ s_b·x (mod m). Only the magnitudes u = |s_a| and
                // v = |s_b| are stored. Because k = b / a >= 1 and the signs of s_a and s_b
                // alternate, the new coefficient has magnitude v + k·u. A parity flag recovers
                // the sign of s_b. All magnitudes stay <= m, so unsigned `$basety` never
                // overflows (a signed type fails once m exceeds its MAX).
                let m = Self::get_mod();
                let (mut a, mut b) = (x, m);
                let (mut u, mut v): (Self::Inner, Self::Inner) = (1, 0);
                let mut v_positive = false;
                while a != 0 {
                    let k = b / a;
                    b -= k * a;
                    v += k * u;
                    swap(&mut a, &mut b);
                    swap(&mut u, &mut v);
                    v_positive = !v_positive;
                }
                // A negative coefficient -v maps to m - v. When v == 0 (x == 0), the result is 0.
                if v_positive || v == 0 {
                    v
                } else {
                    m - v
                }
            }
        }
        $(impl MIntConvert<$unsigned> for $name {
            #[inline]
            fn from(x: $unsigned) -> Self::Inner {
                let m = <Self as MIntBase>::get_mod();
                match <$unsigned as ::std::convert::TryFrom<$basety>>::try_from(m).ok() {
                    Some(m) => (x % m) as $basety,
                    // The modulus exceeds `$unsigned::MAX`, so `x` is already reduced, and
                    // `$unsigned` is narrower than `$basety`, so the cast is lossless.
                    None => x as $basety,
                }
            }
            #[inline]
            fn into(x: Self::Inner) -> $unsigned {
                x as $unsigned
            }
            #[inline]
            fn mod_into() -> $unsigned {
                <Self as MIntBase>::get_mod() as $unsigned
            }
        })*
        $(impl MIntConvert<$signed> for $name {
            #[inline]
            fn from(x: $signed) -> Self::Inner {
                let m = <Self as MIntBase>::get_mod();
                match <$signed as ::std::convert::TryFrom<$basety>>::try_from(m).ok() {
                    Some(m) => {
                        let r = x % m;
                        (if r < 0 { r + m } else { r }) as $basety
                    }
                    // The modulus exceeds `$signed::MAX`, so |x| <= m, and `$signed` is no
                    // wider than `$basety`, so the magnitude fits in `$basety`.
                    None => {
                        if x < 0 {
                            m - x.unsigned_abs() as $basety
                        } else {
                            x as $basety
                        }
                    }
                }
            }
            #[inline]
            fn into(x: Self::Inner) -> $signed {
                x as $signed
            }
            #[inline]
            fn mod_into() -> $signed {
                // NOTE(review): wraps when the modulus exceeds `$signed::MAX`.
                <Self as MIntBase>::get_mod() as $signed
            }
        })*
    };
}
106
/// Defines marker types with compile-time moduli backed by 32-bit storage.
///
/// Each `[$name, $modulus, $alias]` entry expands to a `define_basic_mintbase!` invocation
/// with `u32` residues and `u64` products. It also adds an inherent `$name::rem` that reduces
/// a product by the constant modulus, and `pub type $alias = MInt<$name>`.
#[macro_export]
macro_rules! define_basic_mint32 {
    ($([$name:ident, $modulus:expr, $alias:ident]),*) => {
        $(
            pub type $alias = MInt<$name>;
            impl $name {
                fn rem(value: u64) -> u64 {
                    value % $modulus
                }
            }
            define_basic_mintbase!(
                $name,
                $modulus,
                u32,
                i32,
                u64,
                [u32, u64, u128, usize],
                [i32, i64, i128, isize]
            );
        )*
    };
}
127
128thread_local!(static DYN_MODULUS_U32: UnsafeCell<BarrettReduction<u64>> = const { UnsafeCell::new(BarrettReduction::<u64>::new(1_000_000_007)) });
129impl DynModuloU32 {
130 pub fn set_mod(m: u32) {
131 DYN_MODULUS_U32
132 .with(|cell| unsafe { *cell.get() = BarrettReduction::<u64>::new(m as u64) });
133 }
134 fn rem(x: u64) -> u64 {
135 DYN_MODULUS_U32.with(|cell| unsafe { (*cell.get()).rem(x) })
136 }
137}
138impl DynMIntU32 {
139 pub fn set_mod(m: u32) {
140 DynModuloU32::set_mod(m)
141 }
142}
143
144thread_local!(static DYN_MODULUS_U64: UnsafeCell<BarrettReduction<u128>> = const { UnsafeCell::new(BarrettReduction::<u128>::new(1_000_000_007)) });
145impl DynModuloU64 {
146 pub fn set_mod(m: u64) {
147 DYN_MODULUS_U64
148 .with(|cell| unsafe { *cell.get() = BarrettReduction::<u128>::new(m as u128) })
149 }
150 fn rem(x: u128) -> u128 {
151 DYN_MODULUS_U64.with(|cell| unsafe { (*cell.get()).rem(x) })
152 }
153}
154impl DynMIntU64 {
155 pub fn set_mod(m: u64) {
156 DynModuloU64::set_mod(m)
157 }
158}
159
160define_basic_mint32!(
161 [Modulo998244353, 998_244_353, MInt998244353],
162 [Modulo1000000007, 1_000_000_007, MInt1000000007],
163 [Modulo1000000009, 1_000_000_009, MInt1000000009]
164);
165
166define_basic_mintbase!(
167 DynModuloU32,
168 DYN_MODULUS_U32.with(|cell| unsafe { (*cell.get()).get_mod() as u32 }),
169 u32,
170 i32,
171 u64,
172 [u32, u64, u128, usize],
173 [i32, i64, i128, isize]
174);
175pub type DynMIntU32 = MInt<DynModuloU32>;
176define_basic_mintbase!(
177 DynModuloU64,
178 DYN_MODULUS_U64.with(|cell| unsafe { (*cell.get()).get_mod() as u64 }),
179 u64,
180 i64,
181 u128,
182 [u64, u128, usize],
183 [i64, i128, isize]
184);
185pub type DynMIntU64 = MInt<DynModuloU64>;
186
187pub struct Modulo2;
188impl MIntBase for Modulo2 {
189 type Inner = u32;
190 #[inline]
191 fn get_mod() -> Self::Inner {
192 2
193 }
194 #[inline]
195 fn mod_zero() -> Self::Inner {
196 0
197 }
198 #[inline]
199 fn mod_one() -> Self::Inner {
200 1
201 }
202 #[inline]
203 fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
204 x ^ y
205 }
206 #[inline]
207 fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
208 x ^ y
209 }
210 #[inline]
211 fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
212 x & y
213 }
214 #[inline]
215 fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
216 assert_ne!(y, 0);
217 x
218 }
219 #[inline]
220 fn mod_neg(x: Self::Inner) -> Self::Inner {
221 x
222 }
223 #[inline]
224 fn mod_inv(x: Self::Inner) -> Self::Inner {
225 assert_ne!(x, 0);
226 x
227 }
228 #[inline]
229 fn mod_pow(x: Self::Inner, y: usize) -> Self::Inner {
230 if y == 0 { 1 } else { x }
231 }
232}
/// Implements `MIntConvert<$t>` for the modulo-2 marker `$name` for each listed integer type.
///
/// `from` keeps only the lowest bit. In two's complement this is also the residue mod 2 for
/// negative inputs. `mod_into` reports the modulus, 2.
macro_rules! impl_to_mint_base_for_modulo2 {
    ($name:ident, $basety:ty, [$($t:ty),*]) => {
        $(impl MIntConvert<$t> for $name {
            #[inline]
            fn from(x: $t) -> Self::Inner {
                (x & 1) as $basety
            }
            #[inline]
            fn into(x: Self::Inner) -> $t {
                x as $t
            }
            #[inline]
            fn mod_into() -> $t {
                // Return the modulus (2), like every other `MIntConvert::mod_into` in this
                // module. The previous value, 1, was not the modulus.
                2
            }
        })*
    };
}
251impl_to_mint_base_for_modulo2!(
252 Modulo2,
253 u32,
254 [
255 u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize
256 ]
257);
258pub type MInt2 = MInt<Modulo2>;
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    // Generates a randomized check of `inv`. For random nonzero residues `a`, the inverse must
    // be reduced and satisfy `a * a.inv() == 1`. An optional trailing expression bounds a
    // freshly drawn modulus that is passed to `set_mod` at the start of each round.
    macro_rules! test_mint {
        ($test_name:ident $mint:ident $($bound:expr)?) => {
            #[test]
            fn $test_name() {
                const Q: usize = 10_000;
                let mut rng = Xorshift::new();
                for _round in 0..Q {
                    $($mint::set_mod(rng.gen(..$bound));)?
                    let a = $mint::new_unchecked(rng.random(1..$mint::get_mod()));
                    let x = a.inv();
                    assert!(x.inner() < $mint::get_mod());
                    assert_eq!(a * x, $mint::one());
                }
            }
        };
    }

    test_mint!(test_mint2 MInt2);
    test_mint!(test_mint998244353 MInt998244353);
    test_mint!(test_mint1000000007 MInt1000000007);
    test_mint!(test_mint1000000009 MInt1000000009);
}