competitive/math/number_theoretic_transform.rs

1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2use std::marker::PhantomData;
3
4pub struct Convolve<M>(PhantomData<fn() -> M>);
5pub type Convolve998244353 = Convolve<Modulo998244353>;
6pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
7
8macro_rules! impl_ntt_modulus {
9    ($([$name:ident, $g:expr]),*) => {
10        $(
11            impl Montgomery32NttModulus for $name {}
12        )*
13    };
14}
15impl_ntt_modulus!(
16    [Modulo998244353, 3],
17    [Modulo2113929217, 5],
18    [Modulo1811939329, 13],
19    [Modulo2013265921, 31]
20);
21
/// Montgomery reduction with a 2^32 radix.
///
/// Computes `(z + m * p) / 2^32` with `m = (r * z) mod 2^32` and then
/// subtracts `p` at most once. For the division by 2^32 to be exact, `r` is
/// expected to be `-p^{-1} mod 2^32`; the result is then `z * 2^{-32} mod p`.
/// A single final subtraction is enough when `z < p * 2^32`.
const fn reduce(z: u64, p: u32, r: u32) -> u32 {
    let m = r.wrapping_mul(z as u32) as u64;
    let t = ((z + m * p as u64) >> 32) as u32;
    if t >= p {
        t - p
    } else {
        t
    }
}
29const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
30    reduce(x as u64 * y as u64, p, r)
31}
32const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
33    while y > 0 {
34        if y & 1 == 1 {
35            z = mod_mul(z, x, p, r);
36        }
37        x = mod_mul(x, x, p, r);
38        y >>= 1;
39    }
40    z
41}
42
43pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
44    const PRIMITIVE_ROOT: u32 = {
45        let mut g = 3u32;
46        loop {
47            let mut ok = true;
48            let mut d = 1u32;
49            while d * d < Self::MOD {
50                if (Self::MOD - 1) % d == 0 {
51                    let ds = [d, (Self::MOD - 1) / d];
52                    let mut i = 0;
53                    while i < 2 {
54                        ok &= ds[i] == Self::MOD - 1
55                            || mod_pow(
56                                reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
57                                ds[i],
58                                Self::MOD,
59                                Self::R,
60                                Self::N1,
61                            ) != Self::N1;
62                        i += 1;
63                    }
64                }
65                d += 1;
66            }
67            if ok {
68                break;
69            }
70            g += 2;
71        }
72        g
73    };
74    const RANK: u32 = (Self::MOD - 1).trailing_zeros();
75    const INFO: NttInfo = NttInfo::new::<Self>();
76}
77
/// Root-of-unity tables for one NTT modulus.
///
/// Every entry is a Montgomery-form residue. Entries past the ones that
/// `NttInfo::new` fills are left at zero.
#[derive(Debug, Clone)]
pub struct NttInfo {
    /// `root[k]`: a primitive `2^k`-th root of unity.
    root: [u32; 32],
    /// `inv_root[k]`: the inverse of `root[k]`.
    inv_root: [u32; 32],
    /// Per-block twiddle increments for the radix-2 layer of the forward transform.
    rate2: [u32; 32],
    /// Per-block twiddle increments for the radix-2 layer of the inverse transform.
    inv_rate2: [u32; 32],
    /// Per-block twiddle increments for the radix-4 layers of the forward transform.
    rate3: [u32; 32],
    /// Per-block twiddle increments for the radix-4 layers of the inverse transform.
    inv_rate3: [u32; 32],
}
86impl NttInfo {
87    const fn new<M>() -> Self
88    where
89        M: Montgomery32NttModulus,
90    {
91        let mut root = [0; 32];
92        let mut inv_root = [0; 32];
93        let mut rate2 = [0; 32];
94        let mut inv_rate2 = [0; 32];
95        let mut rate3 = [0; 32];
96        let mut inv_rate3 = [0; 32];
97        let rank = M::RANK as usize;
98
99        let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
100        root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
101        inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
102        let mut i = rank - 1;
103        loop {
104            root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
105            inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
106            if i == 0 {
107                break;
108            }
109            i -= 1;
110        }
111
112        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
113        while i < rank - 1 {
114            rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
115            inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
116            prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
117            inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
118            i += 1;
119        }
120
121        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
122        while i < rank - 2 {
123            rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
124            inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
125            prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
126            inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
127            i += 1;
128        }
129
130        NttInfo {
131            root,
132            inv_root,
133            rate2,
134            inv_rate2,
135            rate3,
136            inv_rate3,
137        }
138    }
139}
140
141crate::avx_helper!(
142    @avx2 fn ntt<M>(a: &mut [MInt<M>])
143    where
144        [M: Montgomery32NttModulus]
145    {
146        let n = a.len();
147        let mut v = n / 2;
148        let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
149        while v > 1 {
150            let mut w1 = MInt::<M>::one();
151            for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
152                let (l, r) = a.split_at_mut(v);
153                let (ll, lr) = l.split_at_mut(v >> 1);
154                let (rl, rr) = r.split_at_mut(v >> 1);
155                let w2 = w1 * w1;
156                let w3 = w1 * w2;
157                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
158                    let a0 = *x0;
159                    let a1 = *x1 * w1;
160                    let a2 = *x2 * w2;
161                    let a3 = *x3 * w3;
162                    let a0pa2 = a0 + a2;
163                    let a0na2 = a0 - a2;
164                    let a1pa3 = a1 + a3;
165                    let a1na3imag = (a1 - a3) * imag;
166                    *x0 = a0pa2 + a1pa3;
167                    *x1 = a0pa2 - a1pa3;
168                    *x2 = a0na2 + a1na3imag;
169                    *x3 = a0na2 - a1na3imag;
170                }
171                w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
172            }
173            v >>= 2;
174        }
175        if v == 1 {
176            let mut w1 = MInt::<M>::one();
177            for (s, a) in a.chunks_exact_mut(2).enumerate() {
178                unsafe {
179                    let (l, r) = a.split_at_mut(1);
180                    let x0 = l.get_unchecked_mut(0);
181                    let x1 = r.get_unchecked_mut(0);
182                    let a0 = *x0;
183                    let a1 = *x1 * w1;
184                    *x0 = a0 + a1;
185                    *x1 = a0 - a1;
186                }
187                w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
188            }
189        }
190    }
191);
192crate::avx_helper!(
193    @avx2 fn intt<M>(a: &mut [MInt<M>])
194    where
195        [M: Montgomery32NttModulus]
196    {
197        let n = a.len();
198        let mut v = 1;
199        if n.trailing_zeros() & 1 == 1 {
200            let mut w1 = MInt::<M>::one();
201            for (s, a) in a.chunks_exact_mut(2).enumerate() {
202                unsafe {
203                    let (l, r) = a.split_at_mut(1);
204                    let x0 = l.get_unchecked_mut(0);
205                    let x1 = r.get_unchecked_mut(0);
206                    let a0 = *x0;
207                    let a1 = *x1;
208                    *x0 = a0 + a1;
209                    *x1 = (a0 - a1) * w1;
210                }
211                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
212            }
213            v <<= 1;
214        }
215        let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
216        while v < n {
217            let mut w1 = MInt::<M>::one();
218            for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
219                let (l, r) = a.split_at_mut(v << 1);
220                let (ll, lr) = l.split_at_mut(v);
221                let (rl, rr) = r.split_at_mut(v);
222                let w2 = w1 * w1;
223                let w3 = w1 * w2;
224                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
225                    let a0 = *x0;
226                    let a1 = *x1;
227                    let a2 = *x2;
228                    let a3 = *x3;
229                    let a0pa1 = a0 + a1;
230                    let a0na1 = a0 - a1;
231                    let a2pa3 = a2 + a3;
232                    let a2na3iimag = (a2 - a3) * iimag;
233                    *x0 = a0pa1 + a2pa3;
234                    *x1 = (a0na1 + a2na3iimag) * w1;
235                    *x2 = (a0pa1 - a2pa3) * w2;
236                    *x3 = (a0na1 - a2na3iimag) * w3;
237                }
238                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
239            }
240            v <<= 2;
241        }
242    }
243);
244
245fn convolve_naive<M>(a: &[MInt<M>], b: &[MInt<M>]) -> Vec<MInt<M>>
246where
247    M: MIntBase,
248{
249    if a.is_empty() && b.is_empty() {
250        return Vec::new();
251    }
252    let len = a.len() + b.len() - 1;
253    let mut c = vec![MInt::<M>::zero(); len];
254    if a.len() < b.len() {
255        for (i, &b) in b.iter().enumerate() {
256            for (a, c) in a.iter().zip(&mut c[i..]) {
257                *c += *a * b;
258            }
259        }
260    } else {
261        for (i, &a) in a.iter().enumerate() {
262            for (b, c) in b.iter().zip(&mut c[i..]) {
263                *c += *b * a;
264            }
265        }
266    }
267    c
268}
269impl<M> ConvolveSteps for Convolve<M>
270where
271    M: Montgomery32NttModulus,
272{
273    type T = Vec<MInt<M>>;
274    type F = Vec<MInt<M>>;
275    fn length(t: &Self::T) -> usize {
276        t.len()
277    }
278    fn transform(mut t: Self::T, len: usize) -> Self::F {
279        t.resize_with(len.max(2).next_power_of_two(), Zero::zero);
280        ntt(&mut t);
281        t
282    }
283    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
284        intt(&mut f);
285        f.truncate(len);
286        let inv = MInt::from(len.max(2).next_power_of_two() as u32).inv();
287        for f in f.iter_mut() {
288            *f *= inv;
289        }
290        f
291    }
292    fn multiply(f: &mut Self::F, g: &Self::F) {
293        assert_eq!(f.len(), g.len());
294        for (f, g) in f.iter_mut().zip(g.iter()) {
295            *f *= *g;
296        }
297    }
298    fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
299        if Self::length(&a).min(Self::length(&b)) <= 60 {
300            return convolve_naive(&a, &b);
301        }
302        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
303        let size = len.max(2).next_power_of_two();
304        if len <= size / 2 + 2 {
305            let xa = a.pop().unwrap();
306            let xb = b.pop().unwrap();
307            let mut c = vec![MInt::<M>::zero(); len];
308            *c.last_mut().unwrap() = xa * xb;
309            for (a, c) in a.iter().zip(&mut c[b.len()..]) {
310                *c += *a * xb;
311            }
312            for (b, c) in b.iter().zip(&mut c[a.len()..]) {
313                *c += *b * xa;
314            }
315            let d = Self::convolve(a, b);
316            for (d, c) in d.into_iter().zip(&mut c) {
317                *c += d;
318            }
319            return c;
320        }
321        let same = a == b;
322        let mut a = Self::transform(a, len);
323        if same {
324            for a in a.iter_mut() {
325                *a *= *a;
326            }
327        } else {
328            let b = Self::transform(b, len);
329            Self::multiply(&mut a, &b);
330        }
331        Self::inverse_transform(a, len)
332    }
333}
334type MVec<M> = Vec<MInt<M>>;
335impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
336where
337    M: MIntConvert + MIntConvert<u32>,
338    N1: Montgomery32NttModulus,
339    N2: Montgomery32NttModulus,
340    N3: Montgomery32NttModulus,
341{
342    type T = MVec<M>;
343    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
344    fn length(t: &Self::T) -> usize {
345        t.len()
346    }
347    fn transform(t: Self::T, len: usize) -> Self::F {
348        let npot = len.max(2).next_power_of_two();
349        let mut f = (
350            MVec::<N1>::with_capacity(npot),
351            MVec::<N2>::with_capacity(npot),
352            MVec::<N3>::with_capacity(npot),
353        );
354        for t in t {
355            f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
356            f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
357            f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
358        }
359        f.0.resize_with(npot, Zero::zero);
360        f.1.resize_with(npot, Zero::zero);
361        f.2.resize_with(npot, Zero::zero);
362        ntt(&mut f.0);
363        ntt(&mut f.1);
364        ntt(&mut f.2);
365        f
366    }
367    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
368        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
369        let m1 = MInt::<M>::from(N1::get_mod());
370        let m1_3 = MInt::<N3>::new(N1::get_mod());
371        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
372        let m2 = m1 * MInt::<M>::from(N2::get_mod());
373        Convolve::<N1>::inverse_transform(f.0, len)
374            .into_iter()
375            .zip(Convolve::<N2>::inverse_transform(f.1, len))
376            .zip(Convolve::<N3>::inverse_transform(f.2, len))
377            .map(|((c1, c2), c3)| {
378                let d1 = c1.inner();
379                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
380                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
381                let d3 = ((c3 - x) * t2).inner();
382                MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
383            })
384            .collect()
385    }
386    fn multiply(f: &mut Self::F, g: &Self::F) {
387        assert_eq!(f.0.len(), g.0.len());
388        assert_eq!(f.1.len(), g.1.len());
389        assert_eq!(f.2.len(), g.2.len());
390        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
391            *f *= *g;
392        }
393        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
394            *f *= *g;
395        }
396        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
397            *f *= *g;
398        }
399    }
400    fn convolve(a: Self::T, b: Self::T) -> Self::T {
401        if Self::length(&a).min(Self::length(&b)) <= 60 {
402            return convolve_naive(&a, &b);
403        }
404        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
405        let mut a = Self::transform(a, len);
406        let b = Self::transform(b, len);
407        Self::multiply(&mut a, &b);
408        Self::inverse_transform(a, len)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::{
        mint_basic::Modulo1000000009,
        montgomery::{MInt998244353, Modulo998244353},
    };
    use crate::tools::Xorshift;

    /// `(a.len(), b.len())` pairs used by the convolution tests.
    ///
    /// `convolve` only reaches the NTT when both inputs are longer than 60, so
    /// short pairs cover the schoolbook path and long ones cover:
    /// - a plain 128-point transform: (61, 61);
    /// - result lengths just above a power of two, which take the peel-off
    ///   branch: (65, 65) and (66, 65);
    /// - transform sizes 256, 512 and 1024, with both even and odd log2, i.e.
    ///   with and without the final radix-2 layer.
    const LENGTHS: [(usize, usize); 9] = [
        (1, 1),
        (8, 8),
        (3, 100),
        (61, 61),
        (65, 65),
        (66, 65),
        (100, 100),
        (100, 161),
        (300, 300),
    ];

    /// Reference `O(n * m)` product, written independently of `convolve_naive`.
    fn schoolbook<M: MIntBase>(a: &[MInt<M>], b: &[MInt<M>]) -> Vec<MInt<M>> {
        let mut c = vec![MInt::<M>::zero(); a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                c[i + j] += x * y;
            }
        }
        c
    }

    #[test]
    fn test_ntt998244353() {
        let mut rng = Xorshift::new();
        for &(n, m) in LENGTHS.iter() {
            let a: Vec<_> = rng
                .random_iter(..MInt998244353::get_mod())
                .map(MInt998244353::new_unchecked)
                .take(n)
                .collect();
            let b: Vec<_> = rng
                .random_iter(..MInt998244353::get_mod())
                .map(MInt998244353::new_unchecked)
                .take(m)
                .collect();
            let c = schoolbook(&a, &b);
            let d = Convolve::<Modulo998244353>::convolve(a.clone(), b);
            assert_eq!(c, d, "lengths ({}, {})", n, m);

            // Equal inputs take the squaring shortcut.
            let c = schoolbook(&a, &a);
            let d = Convolve::<Modulo998244353>::convolve(a.clone(), a);
            assert_eq!(c, d, "square of length {}", n);
        }
    }

    #[test]
    fn test_convolve3() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::new();
        for &(n, m) in LENGTHS.iter() {
            let a: Vec<_> = rng
                .random_iter(..M::get_mod())
                .map(M::new_unchecked)
                .take(n)
                .collect();
            let b: Vec<_> = rng
                .random_iter(..M::get_mod())
                .map(M::new_unchecked)
                .take(m)
                .collect();
            let c = schoolbook(&a, &b);
            let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
            assert_eq!(c, d, "lengths ({}, {})", n, m);
        }
    }

    // #[test]
    #[allow(dead_code)]
    fn find_proth() {
        use crate::math::{divisors, prime_factors_flatten};
        use crate::num::mint_basic::DynMIntU32;
        // p = a * 2^b + 1 (b >= 1, a < 2^b)
        for b in 22..32 {
            for a in (1..1u64 << b).step_by(2) {
                let p = a * (1u64 << b) + 1;
                if 1 << 31 < p {
                    break;
                }
                if p < 1 << 29 {
                    continue;
                }
                let f = prime_factors_flatten(p);
                if f.len() == 1 && f[0] == p {
                    DynMIntU32::set_mod(p as u32);
                    for g in (3..).step_by(2) {
                        let g = DynMIntU32::new(g);
                        if divisors(p - 1)
                            .into_iter()
                            .filter(|&d| d != p - 1)
                            .all(|d| g.pow(d as usize) != DynMIntU32::one())
                        {
                            println!("(p,a,b,g) = {:?}", (p, a, b, g));
                            break;
                        }
                    }
                }
            }
        }
        // (p,a,b,g) = (666894337, 159, 22, 5)
        // (p,a,b,g) = (683671553, 163, 22, 3)
        // (p,a,b,g) = (918552577, 219, 22, 5)
        // (p,a,b,g) = (935329793, 223, 22, 3)
        // (p,a,b,g) = (943718401, 225, 22, 7)
        // (p,a,b,g) = (985661441, 235, 22, 3)
        // (p,a,b,g) = (1161822209, 277, 22, 3)
        // (p,a,b,g) = (1212153857, 289, 22, 3)
        // (p,a,b,g) = (1321205761, 315, 22, 11)
        // (p,a,b,g) = (1438646273, 343, 22, 3)
        // (p,a,b,g) = (1572864001, 375, 22, 13)
        // (p,a,b,g) = (1790967809, 427, 22, 13)
        // (p,a,b,g) = (1866465281, 445, 22, 3)
        // (p,a,b,g) = (2025848833, 483, 22, 11)
        // (p,a,b,g) = (595591169, 71, 23, 3)
        // (p,a,b,g) = (645922817, 77, 23, 3)
        // (p,a,b,g) = (880803841, 105, 23, 37)
        // (p,a,b,g) = (897581057, 107, 23, 3)
        // (p,a,b,g) = (998244353, 119, 23, 3)
        // (p,a,b,g) = (1300234241, 155, 23, 3)
        // (p,a,b,g) = (1484783617, 177, 23, 5)
        // (p,a,b,g) = (2088763393, 249, 23, 5)
        // (p,a,b,g) = (754974721, 45, 24, 11)
        // (p,a,b,g) = (1224736769, 73, 24, 3)
        // (p,a,b,g) = (2130706433, 127, 24, 3)
        // (p,a,b,g) = (1107296257, 33, 25, 31)
        // (p,a,b,g) = (1711276033, 51, 25, 29)
        // (p,a,b,g) = (2113929217, 63, 25, 5)
        // (p,a,b,g) = (1811939329, 27, 26, 13)
        // (p,a,b,g) = (2013265921, 15, 27, 31)
    }
}