//! competitive/num/mint/montgomery.rs — Montgomery-reduction modular arithmetic.

1use super::*;
2
impl<M> MIntBase for M
where
    M: MontgomeryReduction32,
{
    // Residues are stored in Montgomery form: the value v is represented by
    // v * n mod m, where n = 2^32 (see `MontgomeryReduction32`).
    type Inner = u32;
    /// The modulus m.
    fn get_mod() -> Self::Inner {
        <Self as MontgomeryReduction32>::MOD
    }
    /// Montgomery form of 0 (0 * n mod m = 0).
    fn mod_zero() -> Self::Inner {
        0
    }
    /// Montgomery form of 1, i.e. n mod m.
    fn mod_one() -> Self::Inner {
        Self::N1
    }
    /// Modular addition; assumes both operands are already reduced (< m).
    fn mod_add(x: Self::Inner, y: Self::Inner) -> Self::Inner {
        // x + y < 2m; this cannot overflow u32 as long as m < 2^31, which
        // holds for every modulus defined in this file.
        let z = x + y;
        let m = Self::get_mod();
        if z >= m { z - m } else { z }
    }
    /// Modular subtraction; assumes both operands are < m.
    fn mod_sub(x: Self::Inner, y: Self::Inner) -> Self::Inner {
        if x < y {
            x + Self::get_mod() - y
        } else {
            x - y
        }
    }
    /// Multiplication of Montgomery-form operands:
    /// reduce(xn * yn) = x*y*n mod m, which is again in Montgomery form.
    fn mod_mul(x: Self::Inner, y: Self::Inner) -> Self::Inner {
        Self::reduce(x as u64 * y as u64)
    }
    /// Division via multiplication by the modular inverse.
    fn mod_div(x: Self::Inner, y: Self::Inner) -> Self::Inner {
        Self::mod_mul(x, Self::mod_inv(y))
    }
    /// Negation; 0 is special-cased so the result stays in [0, m).
    fn mod_neg(x: Self::Inner) -> Self::Inner {
        if x == 0 { 0 } else { Self::get_mod() - x }
    }
    /// Modular inverse of a Montgomery-form value.
    ///
    /// Extended Euclid computes (x*n)^{-1} mod m; multiplying by n^3 and
    /// reducing yields (x*n)^{-1} * n^2 = x^{-1} * n mod m — the inverse,
    /// back in Montgomery form.
    ///
    /// NOTE(review): x == 0 has no inverse; the loop is skipped and 0 is
    /// returned in that case — callers are expected not to divide by zero.
    fn mod_inv(x: Self::Inner) -> Self::Inner {
        let p = Self::get_mod() as i32;
        let (mut a, mut b) = (x as i32, p);
        // `x` is deliberately shadowed below: it accumulates the Bezout
        // coefficient of the original input.
        let (mut u, mut x) = (1, 0);
        while a != 0 {
            let k = b / a;
            x -= k * u;
            b -= k * a;
            std::mem::swap(&mut x, &mut u);
            std::mem::swap(&mut b, &mut a);
        }
        // Fold a possibly negative coefficient into [0, p) before reducing.
        Self::reduce((if x < 0 { x + p } else { x }) as u64 * Self::N3 as u64)
    }
    /// Converts out of Montgomery form: reduce(v*n) = v mod m.
    fn mod_inner(x: Self::Inner) -> Self::Inner {
        Self::reduce(x as u64)
    }
}
55impl<M> MIntConvert<u32> for M
56where
57    M: MontgomeryReduction32,
58{
59    fn from(x: u32) -> Self::Inner {
60        Self::reduce(x as u64 * Self::N2 as u64)
61    }
62    fn into(x: Self::Inner) -> u32 {
63        Self::reduce(x as u64)
64    }
65    fn mod_into() -> u32 {
66        <Self as MIntBase>::get_mod()
67    }
68}
69impl<M> MIntConvert<u64> for M
70where
71    M: MontgomeryReduction32,
72{
73    fn from(x: u64) -> Self::Inner {
74        Self::reduce(x % Self::get_mod() as u64 * Self::N2 as u64)
75    }
76    fn into(x: Self::Inner) -> u64 {
77        Self::reduce(x as u64) as u64
78    }
79    fn mod_into() -> u64 {
80        <Self as MIntBase>::get_mod() as u64
81    }
82}
83impl<M> MIntConvert<usize> for M
84where
85    M: MontgomeryReduction32,
86{
87    fn from(x: usize) -> Self::Inner {
88        Self::reduce(x as u64 % Self::get_mod() as u64 * Self::N2 as u64)
89    }
90    fn into(x: Self::Inner) -> usize {
91        Self::reduce(x as u64) as usize
92    }
93    fn mod_into() -> usize {
94        <Self as MIntBase>::get_mod() as usize
95    }
96}
97impl<M> MIntConvert<i32> for M
98where
99    M: MontgomeryReduction32,
100{
101    fn from(x: i32) -> Self::Inner {
102        let x = x % <Self as MIntBase>::get_mod() as i32;
103        let x = if x < 0 {
104            (x + <Self as MIntBase>::get_mod() as i32) as u64
105        } else {
106            x as u64
107        };
108        Self::reduce(x * Self::N2 as u64)
109    }
110    fn into(x: Self::Inner) -> i32 {
111        Self::reduce(x as u64) as i32
112    }
113    fn mod_into() -> i32 {
114        <Self as MIntBase>::get_mod() as i32
115    }
116}
117impl<M> MIntConvert<i64> for M
118where
119    M: MontgomeryReduction32,
120{
121    fn from(x: i64) -> Self::Inner {
122        let x = x % <Self as MIntBase>::get_mod() as i64;
123        let x = if x < 0 {
124            (x + <Self as MIntBase>::get_mod() as i64) as u64
125        } else {
126            x as u64
127        };
128        Self::reduce(x * Self::N2 as u64)
129    }
130    fn into(x: Self::Inner) -> i64 {
131        Self::reduce(x as u64) as i64
132    }
133    fn mod_into() -> i64 {
134        <Self as MIntBase>::get_mod() as i64
135    }
136}
137impl<M> MIntConvert<isize> for M
138where
139    M: MontgomeryReduction32,
140{
141    fn from(x: isize) -> Self::Inner {
142        let x = x % <Self as MIntBase>::get_mod() as isize;
143        let x = if x < 0 {
144            (x + <Self as MIntBase>::get_mod() as isize) as u64
145        } else {
146            x as u64
147        };
148        Self::reduce(x * Self::N2 as u64)
149    }
150    fn into(x: Self::Inner) -> isize {
151        Self::reduce(x as u64) as isize
152    }
153    fn mod_into() -> isize {
154        <Self as MIntBase>::get_mod() as isize
155    }
156}
/// Compile-time Montgomery reduction parameters and the reduction itself.
///
/// m is prime, n = 2^32
pub trait MontgomeryReduction32 {
    /// m
    const MOD: u32;
    /// (-m)^{-1} mod n
    ///
    /// Built bit by bit. Invariant before step i: `t == (m * r) >> i` and the
    /// low i bits of `m * r` are all ones. `t & 1` is bit i of `m * r`; when
    /// it is zero, setting bit i of `r` (which adds `m << i` to `m * r`, i.e.
    /// `t += m`) makes it one. After 32 steps, m * r ≡ -1 (mod 2^32).
    const R: u32 = {
        let m = Self::MOD;
        let mut r = 0;
        let mut t = 0;
        let mut i = 0;
        while i < 32 {
            if t % 2 == 0 {
                t += m;
                r += 1 << i;
            }
            t /= 2;
            i += 1;
        }
        r
    };
    /// n^1 mod m
    const N1: u32 = ((1u64 << 32) % Self::MOD as u64) as _;
    /// n^2 mod m
    const N2: u32 = (Self::N1 as u64 * Self::N1 as u64 % Self::MOD as u64) as _;
    /// n^3 mod m
    const N3: u32 = (Self::N1 as u64 * Self::N2 as u64 % Self::MOD as u64) as _;
    /// n^{-1}x = (x + (xr mod n)m) / n
    ///
    /// Montgomery reduction. Because r ≡ -m^{-1} (mod n), the low 32 bits of
    /// `x + (x*r mod n)*m` are zero, so the right shift is an exact division.
    /// For x < m * n the pre-subtract result is < 2m, hence the single
    /// conditional subtract yields a fully reduced value in [0, m).
    fn reduce(x: u64) -> u32 {
        let m: u32 = Self::MOD;
        let r = Self::R;
        let mut x = ((x + r.wrapping_mul(x as u32) as u64 * m as u64) >> 32) as u32;
        if x >= m {
            x -= m;
        }
        x
    }
}
/// Declares, for each `[TypeName, modulus, AliasName]` triple, an uninhabited
/// marker type implementing `MontgomeryReduction32` with the given modulus,
/// plus a `MInt` type alias over it.
macro_rules! define_montgomery_reduction_32 {
    ($([$name:ident, $m:expr, $mint_name:ident $(,)?]),* $(,)?) => {
        $(
            // Empty enum: purely a type-level tag, never instantiated.
            pub enum $name {}
            impl MontgomeryReduction32 for $name {
                const MOD: u32 = $m;
            }
            pub type $mint_name = MInt<$name>;
        )*
    };
}
// NOTE(review): these moduli appear to be NTT-friendly primes of the form
// c * 2^k + 1 (e.g. 998244353 = 119 * 2^23 + 1) — confirm intended use.
define_montgomery_reduction_32!(
    [Modulo998244353, 998_244_353, MInt998244353],
    [Modulo2113929217, 2_113_929_217, MInt2113929217],
    [Modulo1811939329, 1_811_939_329, MInt1811939329],
    [Modulo2013265921, 2_013_265_921, MInt2013265921],
);
211
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)] // SIMD intrinsics and raw pointers are confined here
pub mod simd32 {
    //! Vectorized Montgomery arithmetic on packed u32 lanes (8 lanes under
    //! AVX2, 16 lanes under AVX-512).
    //!
    //! NOTE(review): assumptions to confirm at call sites — `sign` is
    //! presumably the broadcast constant 0x8000_0000 (top-bit flip so that
    //! signed compares emulate unsigned ones); `mod2_vec` presumably holds
    //! 2*m for lazily reduced operands; and `r_vec` in the multiply routines
    //! appears to be m^{-1} mod 2^32 (the *negation* of
    //! `MontgomeryReduction32::R`, since these routines subtract the
    //! correction term instead of adding it).

    use std::arch::x86_64::*;

    /// Lane-wise low 32 bits of a * b (`mullo` is sign-agnostic).
    #[target_feature(enable = "avx2")]
    unsafe fn my256_mullo_epu32(a: __m256i, b: __m256i) -> __m256i {
        _mm256_mullo_epi32(a, b)
    }

    /// Lane-wise high 32 bits of the unsigned 64-bit product a * b.
    #[target_feature(enable = "avx2")]
    unsafe fn my256_mulhi_epu32(a: __m256i, b: __m256i) -> __m256i {
        // 0xF5 = lanes [1,1,3,3]: move odd lanes into even slots, because
        // `_mm256_mul_epu32` only multiplies the even 32-bit lanes.
        let a13 = _mm256_shuffle_epi32(a, 0xF5);
        let b13 = _mm256_shuffle_epi32(b, 0xF5);
        let prod02 = _mm256_mul_epu32(a, b);
        let prod13 = _mm256_mul_epu32(a13, b13);
        // Interleave the 64-bit products and keep each one's high half.
        let t0 = _mm256_unpacklo_epi32(prod02, prod13);
        let t1 = _mm256_unpackhi_epi32(prod02, prod13);
        _mm256_unpackhi_epi64(t0, t1)
    }

    /// Lane-wise Montgomery product, lazily reduced: each output lane lies
    /// in [0, 2m). Use [`montgomery_mul_256_canon`] for values in [0, m).
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_mul_256(
        a: __m256i,
        b: __m256i,
        r_vec: __m256i,
        mod_vec: __m256i,
    ) -> __m256i {
        let hi = my256_mulhi_epu32(a, b);
        let lo = my256_mullo_epu32(a, b);
        let lo = my256_mullo_epu32(lo, r_vec);
        let lo = my256_mulhi_epu32(lo, mod_vec);
        // hi - mulhi(lo*r, m) is the reduction; adding m first keeps every
        // lane nonnegative.
        _mm256_sub_epi32(_mm256_add_epi32(hi, mod_vec), lo)
    }

    /// Lane-wise (a + b) with one conditional subtract of m. `sign` flips
    /// the top bit of both compare operands so the signed `cmpgt` behaves
    /// as an unsigned comparison.
    #[target_feature(enable = "avx2")]
    pub unsafe fn add_mod_256(a: __m256i, b: __m256i, mod_vec: __m256i, sign: __m256i) -> __m256i {
        let sum = _mm256_add_epi32(a, b);
        let sum_x = _mm256_xor_si256(sum, sign);
        let mod_x = _mm256_xor_si256(mod_vec, sign);
        // mask = (sum >= mod), assembled from unsigned-> plus ==.
        let gt = _mm256_cmpgt_epi32(sum_x, mod_x);
        let eq = _mm256_cmpeq_epi32(sum, mod_vec);
        let mask = _mm256_or_si256(gt, eq);
        let sub = _mm256_and_si256(mod_vec, mask);
        _mm256_sub_epi32(sum, sub)
    }

    /// Lane-wise (a - b) with one conditional add of m in lanes where
    /// a < b (unsigned, via the `sign` top-bit flip).
    #[target_feature(enable = "avx2")]
    pub unsafe fn sub_mod_256(a: __m256i, b: __m256i, mod_vec: __m256i, sign: __m256i) -> __m256i {
        let diff = _mm256_sub_epi32(a, b);
        let a_x = _mm256_xor_si256(a, sign);
        let b_x = _mm256_xor_si256(b, sign);
        let mask = _mm256_cmpgt_epi32(b_x, a_x);
        let add = _mm256_and_si256(mod_vec, mask);
        _mm256_add_epi32(diff, add)
    }

    /// Montgomery product canonicalized into [0, m): adding zero mod m
    /// performs the final conditional subtract.
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_mul_256_canon(
        a: __m256i,
        b: __m256i,
        r_vec: __m256i,
        mod_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let x = montgomery_mul_256(a, b, r_vec, mod_vec);
        add_mod_256(x, _mm256_setzero_si256(), mod_vec, sign)
    }

    /// Same shape as [`add_mod_256`] but reduces against `mod2_vec`
    /// (presumably 2*m), for lazily reduced operands in [0, 2m).
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_add_256(
        a: __m256i,
        b: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let sum = _mm256_add_epi32(a, b);
        let sum_x = _mm256_xor_si256(sum, sign);
        let mod_x = _mm256_xor_si256(mod2_vec, sign);
        let gt = _mm256_cmpgt_epi32(sum_x, mod_x);
        let eq = _mm256_cmpeq_epi32(sum, mod2_vec);
        let mask = _mm256_or_si256(gt, eq);
        let sub = _mm256_and_si256(mod2_vec, mask);
        _mm256_sub_epi32(sum, sub)
    }

    /// Same shape as [`sub_mod_256`] but adds back `mod2_vec` (presumably
    /// 2*m), for lazily reduced operands.
    #[target_feature(enable = "avx2")]
    pub unsafe fn montgomery_sub_256(
        a: __m256i,
        b: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i {
        let diff = _mm256_sub_epi32(a, b);
        let a_x = _mm256_xor_si256(a, sign);
        let b_x = _mm256_xor_si256(b, sign);
        let mask = _mm256_cmpgt_epi32(b_x, a_x);
        let add = _mm256_and_si256(mod2_vec, mask);
        _mm256_add_epi32(diff, add)
    }

    /// AVX-512 analogue of [`my256_mullo_epu32`].
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    unsafe fn my512_mullo_epu32(a: __m512i, b: __m512i) -> __m512i {
        _mm512_mullo_epi32(a, b)
    }

    /// AVX-512 analogue of [`my256_mulhi_epu32`]; same odd/even split and
    /// recombination via unpacks.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    unsafe fn my512_mulhi_epu32(a: __m512i, b: __m512i) -> __m512i {
        let a13 = _mm512_shuffle_epi32(a, 0xF5);
        let b13 = _mm512_shuffle_epi32(b, 0xF5);
        let prod02 = _mm512_mul_epu32(a, b);
        let prod13 = _mm512_mul_epu32(a13, b13);
        let t0 = _mm512_unpacklo_epi32(prod02, prod13);
        let t1 = _mm512_unpackhi_epi32(prod02, prod13);
        _mm512_unpackhi_epi64(t0, t1)
    }

    /// AVX-512 analogue of [`montgomery_mul_256`]; output lanes in [0, 2m).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_mul_512(
        a: __m512i,
        b: __m512i,
        r_vec: __m512i,
        mod_vec: __m512i,
    ) -> __m512i {
        let hi = my512_mulhi_epu32(a, b);
        let lo = my512_mullo_epu32(a, b);
        let lo = my512_mullo_epu32(lo, r_vec);
        let lo = my512_mulhi_epu32(lo, mod_vec);
        _mm512_sub_epi32(_mm512_add_epi32(hi, mod_vec), lo)
    }

    /// Lane-wise (a + b) mod m. AVX-512 has native unsigned mask compares,
    /// so no sign-flip trick is needed: subtract m in lanes where sum >= m.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn add_mod_512(a: __m512i, b: __m512i, mod_vec: __m512i) -> __m512i {
        let sum = _mm512_add_epi32(a, b);
        let mask = !_mm512_cmp_epu32_mask(sum, mod_vec, _MM_CMPINT_LT);
        _mm512_mask_sub_epi32(sum, mask, sum, mod_vec)
    }

    /// Lane-wise (a - b) mod m: add m back in lanes where a < b (unsigned).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn sub_mod_512(a: __m512i, b: __m512i, mod_vec: __m512i) -> __m512i {
        let diff = _mm512_sub_epi32(a, b);
        let mask = _mm512_cmp_epu32_mask(a, b, _MM_CMPINT_LT);
        _mm512_mask_add_epi32(diff, mask, diff, mod_vec)
    }

    /// AVX-512 Montgomery product canonicalized into [0, m).
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_mul_512_canon(
        a: __m512i,
        b: __m512i,
        r_vec: __m512i,
        mod_vec: __m512i,
    ) -> __m512i {
        let x = montgomery_mul_512(a, b, r_vec, mod_vec);
        add_mod_512(x, _mm512_setzero_si512(), mod_vec)
    }

    /// Same shape as [`add_mod_512`] but against `mod2_vec` (presumably 2*m),
    /// for lazily reduced operands.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_add_512(a: __m512i, b: __m512i, mod2_vec: __m512i) -> __m512i {
        let sum = _mm512_add_epi32(a, b);
        let mask = !_mm512_cmp_epu32_mask(sum, mod2_vec, _MM_CMPINT_LT);
        _mm512_mask_sub_epi32(sum, mask, sum, mod2_vec)
    }

    /// Same shape as [`sub_mod_512`] but adds back `mod2_vec` (presumably
    /// 2*m), for lazily reduced operands.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub unsafe fn montgomery_sub_512(a: __m512i, b: __m512i, mod2_vec: __m512i) -> __m512i {
        let diff = _mm512_sub_epi32(a, b);
        let mask = _mm512_cmp_epu32_mask(a, b, _MM_CMPINT_LT);
        _mm512_mask_add_epi32(diff, mask, diff, mod2_vec)
    }
}
382
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `M` is imported from `crate::num::montgomery` while
    // `MInt998244353` below resolves through `super::*`; presumably `M` is a
    // reference implementation these Montgomery ops are checked against —
    // confirm the module paths.
    use crate::num::montgomery::MInt998244353 as M;
    use crate::tools::Xorshift;

    #[test]
    fn test_mint998244353() {
        let mut rng = Xorshift::default();
        // Number of randomized trials per property.
        const Q: usize = 1000;
        assert_eq!(0, MInt998244353::zero().inner());
        assert_eq!(1, MInt998244353::one().inner());
        // `reduce` strips one factor of n, so it steps N3 -> N2 -> N1 -> 1.
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N3 as u64),
            Modulo998244353::N2
        );
        assert_eq!(
            Modulo998244353::reduce(Modulo998244353::N2 as u64),
            Modulo998244353::N1
        );
        assert_eq!(Modulo998244353::reduce(Modulo998244353::N1 as u64), 1);
        // Unary properties: new/inner round-trip, negation, double inverse.
        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            assert_eq!(x, MInt998244353::new(x).inner());
            assert_eq!((-M::new(x)).inner(), (-MInt998244353::new(x)).inner());
            assert_eq!(x, MInt998244353::new(x).inv().inv().inner());
            assert_eq!(M::new(x).inv().inner(), MInt998244353::new(x).inv().inner());
        }

        // Binary operators agree with the reference implementation.
        for _ in 0..Q {
            let x = rng.random(..MInt998244353::get_mod());
            let y = rng.random(..MInt998244353::get_mod());
            assert_eq!(
                (M::new(x) + M::new(y)).inner(),
                (MInt998244353::new(x) + MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) - M::new(y)).inner(),
                (MInt998244353::new(x) - MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) * M::new(y)).inner(),
                (MInt998244353::new(x) * MInt998244353::new(y)).inner()
            );
            assert_eq!(
                (M::new(x) / M::new(y)).inner(),
                (MInt998244353::new(x) / MInt998244353::new(y)).inner()
            );
            assert_eq!(
                M::new(x).pow(y as usize).inner(),
                MInt998244353::new(x).pow(y as usize).inner()
            );
        }

        // Conversions from every supported integer type agree.
        for _ in 0..Q {
            let x = rng.rand64();
            assert_eq!(
                M::from(x as u32).inner(),
                MInt998244353::from(x as u32).inner()
            );
            assert_eq!(M::from(x).inner(), MInt998244353::from(x).inner());
            assert_eq!(
                M::from(x as usize).inner(),
                MInt998244353::from(x as usize).inner()
            );
            assert_eq!(
                M::from(x as i32).inner(),
                MInt998244353::from(x as i32).inner()
            );
            assert_eq!(
                M::from(x as i64).inner(),
                MInt998244353::from(x as i64).inner()
            );
            assert_eq!(
                M::from(x as isize).inner(),
                MInt998244353::from(x as isize).inner()
            );
        }
    }
}