competitive/num/mint/
montgomery.rs

1use super::*;
2
3impl<M> MIntBase for M
4where
5    M: MontgomeryReduction32,
6{
7    type Inner = u32;
8    #[inline]
9    fn get_mod() -> Self::Inner {
10        <Self as MontgomeryReduction32>::MOD
11    }
12    #[inline]
13    fn mod_zero() -> Self::Inner {
14        0
15    }
16    #[inline]
17    fn mod_one() -> Self::Inner {
18        Self::N1
19    }
20    #[inline]
21    fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
22        let z = x + y;
23        let m = Self::get_mod();
24        if z >= m { z - m } else { z }
25    }
26    #[inline]
27    fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
28        if x < y {
29            x + Self::get_mod() - y
30        } else {
31            x - y
32        }
33    }
34    #[inline]
35    fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
36        Self::reduce(x as u64 * y as u64)
37    }
38    #[inline]
39    fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
40        Self::mod_mul(x, Self::mod_inv(y))
41    }
42    #[inline]
43    fn mod_neg(x: Self::Inner) -> Self::Inner {
44        if x == 0 { 0 } else { Self::get_mod() - x }
45    }
46    fn mod_inv(x: Self::Inner) -> Self::Inner {
47        let p = Self::get_mod() as i32;
48        let (mut a, mut b) = (x as i32, p);
49        let (mut u, mut x) = (1, 0);
50        while a != 0 {
51            let k = b / a;
52            x -= k * u;
53            b -= k * a;
54            std::mem::swap(&mut x, &mut u);
55            std::mem::swap(&mut b, &mut a);
56        }
57        Self::reduce((if x < 0 { x + p } else { x }) as u64 * Self::N3 as u64)
58    }
59    fn mod_inner(x: Self::Inner) -> Self::Inner {
60        Self::reduce(x as u64)
61    }
62}
63impl<M> MIntConvert<u32> for M
64where
65    M: MontgomeryReduction32,
66{
67    #[inline]
68    fn from(x: u32) -> Self::Inner {
69        Self::reduce(x as u64 * Self::N2 as u64)
70    }
71    #[inline]
72    fn into(x: Self::Inner) -> u32 {
73        Self::reduce(x as u64)
74    }
75    #[inline]
76    fn mod_into() -> u32 {
77        <Self as MIntBase>::get_mod()
78    }
79}
80impl<M> MIntConvert<u64> for M
81where
82    M: MontgomeryReduction32,
83{
84    #[inline]
85    fn from(x: u64) -> Self::Inner {
86        Self::reduce(x % Self::get_mod() as u64 * Self::N2 as u64)
87    }
88    #[inline]
89    fn into(x: Self::Inner) -> u64 {
90        Self::reduce(x as u64) as u64
91    }
92    #[inline]
93    fn mod_into() -> u64 {
94        <Self as MIntBase>::get_mod() as u64
95    }
96}
97impl<M> MIntConvert<usize> for M
98where
99    M: MontgomeryReduction32,
100{
101    #[inline]
102    fn from(x: usize) -> Self::Inner {
103        Self::reduce(x as u64 % Self::get_mod() as u64 * Self::N2 as u64)
104    }
105    #[inline]
106    fn into(x: Self::Inner) -> usize {
107        Self::reduce(x as u64) as usize
108    }
109    #[inline]
110    fn mod_into() -> usize {
111        <Self as MIntBase>::get_mod() as usize
112    }
113}
114impl<M> MIntConvert<i32> for M
115where
116    M: MontgomeryReduction32,
117{
118    #[inline]
119    fn from(x: i32) -> Self::Inner {
120        let x = x % <Self as MIntBase>::get_mod() as i32;
121        let x = if x < 0 {
122            (x + <Self as MIntBase>::get_mod() as i32) as u64
123        } else {
124            x as u64
125        };
126        Self::reduce(x * Self::N2 as u64)
127    }
128    #[inline]
129    fn into(x: Self::Inner) -> i32 {
130        Self::reduce(x as u64) as i32
131    }
132    #[inline]
133    fn mod_into() -> i32 {
134        <Self as MIntBase>::get_mod() as i32
135    }
136}
137impl<M> MIntConvert<i64> for M
138where
139    M: MontgomeryReduction32,
140{
141    #[inline]
142    fn from(x: i64) -> Self::Inner {
143        let x = x % <Self as MIntBase>::get_mod() as i64;
144        let x = if x < 0 {
145            (x + <Self as MIntBase>::get_mod() as i64) as u64
146        } else {
147            x as u64
148        };
149        Self::reduce(x * Self::N2 as u64)
150    }
151    #[inline]
152    fn into(x: Self::Inner) -> i64 {
153        Self::reduce(x as u64) as i64
154    }
155    #[inline]
156    fn mod_into() -> i64 {
157        <Self as MIntBase>::get_mod() as i64
158    }
159}
160impl<M> MIntConvert<isize> for M
161where
162    M: MontgomeryReduction32,
163{
164    #[inline]
165    fn from(x: isize) -> Self::Inner {
166        let x = x % <Self as MIntBase>::get_mod() as isize;
167        let x = if x < 0 {
168            (x + <Self as MIntBase>::get_mod() as isize) as u64
169        } else {
170            x as u64
171        };
172        Self::reduce(x * Self::N2 as u64)
173    }
174    #[inline]
175    fn into(x: Self::Inner) -> isize {
176        Self::reduce(x as u64) as isize
177    }
178    #[inline]
179    fn mod_into() -> isize {
180        <Self as MIntBase>::get_mod() as isize
181    }
182}
/// Montgomery reduction for a 32-bit modulus `m = MOD` with `n = 2^32`.
///
/// m is prime, n = 2^32
///
/// Implementors only provide [`MOD`](Self::MOD); all other items are derived
/// from it.
///
/// # Requirements on `MOD`
///
/// `MOD` must be odd (otherwise it has no inverse modulo `2^32`) and smaller
/// than `2^31` (so that the intermediate sum in [`reduce`](Self::reduce) fits
/// in a `u64` and the accumulator used to derive [`R`](Self::R) fits in a
/// `u32`). Both conditions are checked while `R` is const-evaluated, so an
/// unsupported modulus becomes a compile-time error as soon as `R` — and
/// therefore `reduce` — is used, instead of silently producing wrong residues.
pub trait MontgomeryReduction32 {
    /// m
    const MOD: u32;
    /// (-m)^{-1} mod n
    ///
    /// Built bit by bit so that `R * MOD ≡ -1 (mod 2^32)`, i.e. every one of
    /// the low 32 bits of `R * MOD` is set.
    const R: u32 = {
        let m = Self::MOD;
        assert!(
            m % 2 == 1 && m < (1 << 31),
            "MontgomeryReduction32::MOD must be odd and less than 2^31"
        );
        let mut r: u32 = 0;
        // `t` holds `(r * m) >> i`; it stays below `m` at the top of each
        // iteration, so `t + m` cannot overflow for `m < 2^31`.
        let mut t: u32 = 0;
        let mut i = 0;
        while i < 32 {
            // Bit `i` of `r * m` must be 1; if it is currently 0, add `m << i`.
            if t % 2 == 0 {
                t += m;
                r += 1 << i;
            }
            t /= 2;
            i += 1;
        }
        r
    };
    /// n^1 mod m
    const N1: u32 = ((1u64 << 32) % Self::MOD as u64) as _;
    /// n^2 mod m
    const N2: u32 = (Self::N1 as u64 * Self::N1 as u64 % Self::MOD as u64) as _;
    /// n^3 mod m
    const N3: u32 = (Self::N1 as u64 * Self::N2 as u64 % Self::MOD as u64) as _;
    /// n^{-1}x = (x + (xr mod n)m) / n
    ///
    /// Returns `x * n^{-1} mod m`. The result lies in `[0, m)` provided
    /// `x < m * 2^32`: the quotient is then below `2m`, so one conditional
    /// subtraction suffices.
    fn reduce(x: u64) -> u32 {
        let m: u32 = Self::MOD;
        // q ≡ x * (-m^{-1}) (mod n), hence the low 32 bits of `x + q * m` are zero.
        let q = Self::R.wrapping_mul(x as u32);
        let t = ((x + q as u64 * m as u64) >> 32) as u32;
        if t >= m {
            t - m
        } else {
            t
        }
    }
}
220macro_rules! define_montgomery_reduction_32 {
221    ($([$name:ident, $m:expr, $mint_name:ident $(,)?]),* $(,)?) => {
222        $(
223            pub enum $name {}
224            impl MontgomeryReduction32 for $name {
225                const MOD: u32 = $m;
226            }
227            pub type $mint_name = MInt<$name>;
228        )*
229    };
230}
231define_montgomery_reduction_32!(
232    [Modulo998244353, 998_244_353, MInt998244353],
233    [Modulo2113929217, 2_113_929_217, MInt2113929217],
234    [Modulo1811939329, 1_811_939_329, MInt1811939329],
235    [Modulo2013265921, 2_013_265_921, MInt2013265921],
236);
237
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `M` is meant to act as the reference implementation in the
    // comparisons below. If `crate::num::montgomery::MInt998244353` is the same
    // type as the one exported by this module, those comparisons are
    // tautological — confirm that this path names an independent implementation.
    use crate::num::montgomery::MInt998244353 as M;
    use crate::tools::Xorshift;

    /// Shorthand for the Montgomery-backed type under test.
    type Mont = MInt998244353;

    /// Asserts that two modular integers expose the same `inner()` value.
    macro_rules! assert_same_inner {
        ($expected:expr, $actual:expr $(,)?) => {
            assert_eq!($expected.inner(), $actual.inner())
        };
    }

    /// Checks the Montgomery constants, then compares unary operations,
    /// binary operations and integer conversions against `M` on random inputs.
    #[test]
    fn test_mint998244353() {
        const Q: usize = 1000;
        let mut rng = Xorshift::new();

        // Identities and the chain N3 -> N2 -> N1 -> 1 under `reduce`.
        assert_eq!(0, Mont::zero().inner());
        assert_eq!(1, Mont::one().inner());
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N3 as u64),
            Modulo998244353::N2
        );
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N2 as u64),
            Modulo998244353::N1
        );
        assert_eq!(Modulo998244353::reduce(Modulo998244353::N1 as u64), 1);

        // Round trip, negation and inversion.
        for _ in 0..Q {
            let x = rng.random(..Mont::get_mod());
            assert_eq!(x, Mont::new(x).inner());
            assert_same_inner!(-M::new(x), -Mont::new(x));
            assert_eq!(x, Mont::new(x).inv().inv().inner());
            assert_same_inner!(M::new(x).inv(), Mont::new(x).inv());
        }

        // Binary arithmetic and exponentiation.
        for _ in 0..Q {
            let x = rng.random(..Mont::get_mod());
            let y = rng.random(..Mont::get_mod());
            assert_same_inner!(M::new(x) + M::new(y), Mont::new(x) + Mont::new(y));
            assert_same_inner!(M::new(x) - M::new(y), Mont::new(x) - Mont::new(y));
            assert_same_inner!(M::new(x) * M::new(y), Mont::new(x) * Mont::new(y));
            assert_same_inner!(M::new(x) / M::new(y), Mont::new(x) / Mont::new(y));
            assert_same_inner!(M::new(x).pow(y as usize), Mont::new(x).pow(y as usize));
        }

        // Conversions from every supported primitive integer type.
        for _ in 0..Q {
            let x = rng.rand64();
            assert_same_inner!(M::from(x as u32), Mont::from(x as u32));
            assert_same_inner!(M::from(x), Mont::from(x));
            assert_same_inner!(M::from(x as usize), Mont::from(x as usize));
            assert_same_inner!(M::from(x as i32), Mont::from(x as i32));
            assert_same_inner!(M::from(x as i64), Mont::from(x as i64));
            assert_same_inner!(M::from(x as isize), Mont::from(x as isize));
        }
    }
}