competitive/math/
number_theoretic_transform.rs

1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2use std::{
3    cell::UnsafeCell,
4    marker::PhantomData,
5    ops::{AddAssign, Mul, SubAssign},
6};
7
8pub struct Convolve<M>(PhantomData<fn() -> M>);
9pub type Convolve998244353 = Convolve<Modulo998244353>;
10pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
11pub type U64Convolve = Convolve<(u64, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
12
13macro_rules! impl_ntt_modulus {
14    ($([$name:ident, $g:expr]),*) => {
15        $(
16            impl Montgomery32NttModulus for $name {}
17        )*
18    };
19}
20impl_ntt_modulus!(
21    [Modulo998244353, 3],
22    [Modulo2113929217, 5],
23    [Modulo1811939329, 13],
24    [Modulo2013265921, 31]
25);
26
/// Montgomery reduction: returns `z · 2^-32 mod p`, fully reduced into `[0, p)`.
///
/// For the result to be exact, `r` must satisfy `p · r ≡ -1 (mod 2^32)`, `p` must
/// be odd and below `2^31`, and `z < p · 2^32` (so `z + m·p` fits in a `u64`).
const fn reduce(z: u64, p: u32, r: u32) -> u32 {
    // `m` is chosen so that `z + m·p` is divisible by 2^32.
    let m = r.wrapping_mul(z as u32) as u64;
    let t = ((z + m * p as u64) >> 32) as u32;
    // `t < 2p`, so a single conditional subtraction suffices.
    if t < p {
        t
    } else {
        t - p
    }
}

/// Montgomery product of two values in Montgomery form: `x · y · 2^-32 mod p`.
const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
    let wide = (x as u64) * (y as u64);
    reduce(wide, p, r)
}

/// Right-to-left binary exponentiation in Montgomery form.
///
/// Returns `z · x^y`; pass the Montgomery form of one as `z` to get `x^y`.
const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
    while y != 0 {
        if y % 2 != 0 {
            z = mod_mul(z, x, p, r);
        }
        x = mod_mul(x, x, p, r);
        y /= 2;
    }
    z
}
47
/// A 32-bit Montgomery modulus usable for number-theoretic transforms.
///
/// The associated constants are evaluated at compile time from the modulus
/// `Self::MOD` and the Montgomery parameters of [`MontgomeryReduction32`]:
/// `Self::R` is the reduction factor passed to `reduce`, `Self::N1` is used as
/// the Montgomery-form one, and `Self::N2` converts a plain integer into
/// Montgomery form via `reduce(x * N2)` (presumably `N2 = 2^64 mod MOD`).
pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
    /// A primitive root modulo `Self::MOD`.
    ///
    /// Default: the first odd `g >= 3` with `g^d != 1` for every divisor `d` of
    /// `MOD - 1` other than `MOD - 1` itself. Divisors are enumerated in pairs
    /// `(d, (MOD - 1) / d)` by trial division over `d * d < MOD`.
    /// NOTE(review): assumes `MOD` is prime; otherwise a primitive root may not
    /// exist and this compile-time loop would never terminate.
    const PRIMITIVE_ROOT: u32 = {
        let mut g = 3u32;
        loop {
            let mut ok = true;
            let mut d = 1u32;
            while d * d < Self::MOD {
                if (Self::MOD - 1) % d == 0 {
                    // Test both members of the divisor pair.
                    let ds = [d, (Self::MOD - 1) / d];
                    let mut i = 0;
                    while i < 2 {
                        // The exponent MOD - 1 itself is skipped (g^(MOD-1) is always one).
                        // `g` is moved into Montgomery form and the power is accumulated
                        // starting from `N1`, so `!= Self::N1` means "g^d != 1".
                        ok &= ds[i] == Self::MOD - 1
                            || mod_pow(
                                reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
                                ds[i],
                                Self::MOD,
                                Self::R,
                                Self::N1,
                            ) != Self::N1;
                        i += 1;
                    }
                }
                d += 1;
            }
            if ok {
                break;
            }
            // Only odd candidates are tried.
            g += 2;
        }
        g
    };
    /// Exponent of the largest power of two dividing `MOD - 1`; NTT lengths are
    /// limited to `2^RANK`.
    const RANK: u32 = (Self::MOD - 1).trailing_zeros();
    /// Root-of-unity tables consumed by `ntt` / `intt`.
    const INFO: NttInfo = NttInfo::new::<Self>();
}
82
/// Root-of-unity tables for the radix-2 / radix-4 NTT, stored in Montgomery form.
///
/// Only the prefixes filled by `NttInfo::new` are meaningful: `root` and
/// `inv_root` for indices `0..=RANK`, `rate2` / `inv_rate2` for `0..RANK - 1`,
/// and `rate3` / `inv_rate3` for `0..RANK - 2`. All other entries stay zero.
#[derive(Debug, PartialEq)]
pub struct NttInfo {
    /// `root[i]` is a root of unity of order `2^i`.
    root: [u32; 32],
    /// `inv_root[i]` is the inverse of `root[i]`.
    inv_root: [u32; 32],
    /// Radix-2 twiddle step, indexed by the number of trailing ones of the block index.
    rate2: [u32; 32],
    /// Inverse-transform counterpart of `rate2`.
    inv_rate2: [u32; 32],
    /// Radix-4 twiddle step, indexed by the number of trailing ones of the block index.
    rate3: [u32; 32],
    /// Inverse-transform counterpart of `rate3`.
    inv_rate3: [u32; 32],
}
impl NttInfo {
    /// Builds the root-of-unity tables for modulus `M`; usable in `const` context.
    ///
    /// All values are in Montgomery form: the primitive root is converted with
    /// `reduce(g * M::N2)` and everything else is combined with `mod_mul`.
    /// Requires `M::RANK >= 2`; the `rank - 1` and `rank - 2` bounds below
    /// underflow otherwise.
    const fn new<M>() -> Self
    where
        M: Montgomery32NttModulus,
    {
        let mut root = [0; 32];
        let mut inv_root = [0; 32];
        let mut rate2 = [0; 32];
        let mut inv_rate2 = [0; 32];
        let mut rate3 = [0; 32];
        let mut inv_rate3 = [0; 32];
        let rank = M::RANK as usize;

        // Primitive root in Montgomery form.
        let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
        // g^((MOD - 1) / 2^rank) has order exactly 2^rank; its inverse comes from
        // Fermat's little theorem (x^(MOD - 2)).
        root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
        inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
        // Square downwards so that root[i] has order 2^i (root[0] ends up as one).
        let mut i = rank - 1;
        loop {
            root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
            inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
            if i == 0 {
                break;
            }
            i -= 1;
        }

        // rate2[i] = root[i + 2] * inv_root[2] * ... * inv_root[i + 1]: the factor that
        // advances the radix-2 twiddle after a block whose index has `i` trailing ones
        // (consumed as `rate2[s.trailing_ones()]` in `ntt`).
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 1 {
            rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
            inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
            i += 1;
        }

        // rate3: the same construction one level up (root[i + 3]) for the radix-4 pass.
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 2 {
            rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
            inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
            i += 1;
        }

        NttInfo {
            root,
            inv_root,
            rate2,
            inv_rate2,
            rate3,
            inv_rate3,
        }
    }
}
146
147crate::avx_helper!(
148    @avx2 fn ntt<M>(a: &mut [MInt<M>])
149    where
150        [M: Montgomery32NttModulus]
151    {
152        let n = a.len();
153        let mut v = n / 2;
154        let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
155        while v > 1 {
156            let mut w1 = MInt::<M>::one();
157            for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
158                let (l, r) = a.split_at_mut(v);
159                let (ll, lr) = l.split_at_mut(v >> 1);
160                let (rl, rr) = r.split_at_mut(v >> 1);
161                let w2 = w1 * w1;
162                let w3 = w1 * w2;
163                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
164                    let a0 = *x0;
165                    let a1 = *x1 * w1;
166                    let a2 = *x2 * w2;
167                    let a3 = *x3 * w3;
168                    let a0pa2 = a0 + a2;
169                    let a0na2 = a0 - a2;
170                    let a1pa3 = a1 + a3;
171                    let a1na3imag = (a1 - a3) * imag;
172                    *x0 = a0pa2 + a1pa3;
173                    *x1 = a0pa2 - a1pa3;
174                    *x2 = a0na2 + a1na3imag;
175                    *x3 = a0na2 - a1na3imag;
176                }
177                w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
178            }
179            v >>= 2;
180        }
181        if v == 1 {
182            let mut w1 = MInt::<M>::one();
183            for (s, a) in a.chunks_exact_mut(2).enumerate() {
184                unsafe {
185                    let (l, r) = a.split_at_mut(1);
186                    let x0 = l.get_unchecked_mut(0);
187                    let x1 = r.get_unchecked_mut(0);
188                    let a0 = *x0;
189                    let a1 = *x1 * w1;
190                    *x0 = a0 + a1;
191                    *x1 = a0 - a1;
192                }
193                w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
194            }
195        }
196    }
197);
198crate::avx_helper!(
199    @avx2 fn intt<M>(a: &mut [MInt<M>])
200    where
201        [M: Montgomery32NttModulus]
202    {
203        let n = a.len();
204        let mut v = 1;
205        if n.trailing_zeros() & 1 == 1 {
206            let mut w1 = MInt::<M>::one();
207            for (s, a) in a.chunks_exact_mut(2).enumerate() {
208                unsafe {
209                    let (l, r) = a.split_at_mut(1);
210                    let x0 = l.get_unchecked_mut(0);
211                    let x1 = r.get_unchecked_mut(0);
212                    let a0 = *x0;
213                    let a1 = *x1;
214                    *x0 = a0 + a1;
215                    *x1 = (a0 - a1) * w1;
216                }
217                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
218            }
219            v <<= 1;
220        }
221        let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
222        while v < n {
223            let mut w1 = MInt::<M>::one();
224            for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
225                let (l, r) = a.split_at_mut(v << 1);
226                let (ll, lr) = l.split_at_mut(v);
227                let (rl, rr) = r.split_at_mut(v);
228                let w2 = w1 * w1;
229                let w3 = w1 * w2;
230                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
231                    let a0 = *x0;
232                    let a1 = *x1;
233                    let a2 = *x2;
234                    let a3 = *x3;
235                    let a0pa1 = a0 + a1;
236                    let a0na1 = a0 - a1;
237                    let a2pa3 = a2 + a3;
238                    let a2na3iimag = (a2 - a3) * iimag;
239                    *x0 = a0pa1 + a2pa3;
240                    *x1 = (a0na1 + a2na3iimag) * w1;
241                    *x2 = (a0pa1 - a2pa3) * w2;
242                    *x3 = (a0na1 - a2na3iimag) * w3;
243                }
244                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
245            }
246            v <<= 2;
247        }
248    }
249);
250
251fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
252where
253    T: Copy + Zero + AddAssign<T> + Mul<Output = T>,
254{
255    if a.is_empty() && b.is_empty() {
256        return Vec::new();
257    }
258    let len = a.len() + b.len() - 1;
259    let mut c = vec![T::zero(); len];
260    if a.len() < b.len() {
261        for (i, &b) in b.iter().enumerate() {
262            for (a, c) in a.iter().zip(&mut c[i..]) {
263                *c += *a * b;
264            }
265        }
266    } else {
267        for (i, &a) in a.iter().enumerate() {
268            for (b, c) in b.iter().zip(&mut c[i..]) {
269                *c += *b * a;
270            }
271        }
272    }
273    c
274}
275
276fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
277where
278    T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
279{
280    if a.len().min(b.len()) <= 30 {
281        return convolve_naive(a, b);
282    }
283    let m = a.len().max(b.len()).div_ceil(2);
284    let (a0, a1) = if a.len() <= m {
285        (a, &[][..])
286    } else {
287        a.split_at(m)
288    };
289    let (b0, b1) = if b.len() <= m {
290        (b, &[][..])
291    } else {
292        b.split_at(m)
293    };
294    let f00 = convolve_karatsuba(a0, b0);
295    let f11 = convolve_karatsuba(a1, b1);
296    let mut a0a1 = a0.to_vec();
297    for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
298        *a0a1 += a1;
299    }
300    let mut b0b1 = b0.to_vec();
301    for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
302        *b0b1 += b1;
303    }
304    let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
305    for (f01, &f00) in f01.iter_mut().zip(&f00) {
306        *f01 -= f00;
307    }
308    for (f01, &f11) in f01.iter_mut().zip(&f11) {
309        *f01 -= f11;
310    }
311    let mut c = vec![T::zero(); a.len() + b.len() - 1];
312    for (c, &f00) in c.iter_mut().zip(&f00) {
313        *c += f00;
314    }
315    for (c, &f01) in c[m..].iter_mut().zip(&f01) {
316        *c += f01;
317    }
318    for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
319        *c += f11;
320    }
321    c
322}
323
324impl<M> ConvolveSteps for Convolve<M>
325where
326    M: Montgomery32NttModulus,
327{
328    type T = Vec<MInt<M>>;
329    type F = Vec<MInt<M>>;
330    fn length(t: &Self::T) -> usize {
331        t.len()
332    }
333    fn transform(mut t: Self::T, len: usize) -> Self::F {
334        t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
335        ntt(&mut t);
336        t
337    }
338    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
339        intt(&mut f);
340        f.truncate(len);
341        let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
342        for f in f.iter_mut() {
343            *f *= inv;
344        }
345        f
346    }
347    fn multiply(f: &mut Self::F, g: &Self::F) {
348        assert_eq!(f.len(), g.len());
349        for (f, g) in f.iter_mut().zip(g.iter()) {
350            *f *= *g;
351        }
352    }
353    fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
354        if Self::length(&a).max(Self::length(&b)) <= 100 {
355            return convolve_karatsuba(&a, &b);
356        }
357        if Self::length(&a).min(Self::length(&b)) <= 60 {
358            return convolve_naive(&a, &b);
359        }
360        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
361        let size = len.max(1).next_power_of_two();
362        if len <= size / 2 + 2 {
363            let xa = a.pop().unwrap();
364            let xb = b.pop().unwrap();
365            let mut c = vec![MInt::<M>::zero(); len];
366            *c.last_mut().unwrap() = xa * xb;
367            for (a, c) in a.iter().zip(&mut c[b.len()..]) {
368                *c += *a * xb;
369            }
370            for (b, c) in b.iter().zip(&mut c[a.len()..]) {
371                *c += *b * xa;
372            }
373            let d = Self::convolve(a, b);
374            for (d, c) in d.into_iter().zip(&mut c) {
375                *c += d;
376            }
377            return c;
378        }
379        let same = a == b;
380        let mut a = Self::transform(a, len);
381        if same {
382            for a in a.iter_mut() {
383                *a *= *a;
384            }
385        } else {
386            let b = Self::transform(b, len);
387            Self::multiply(&mut a, &b);
388        }
389        Self::inverse_transform(a, len)
390    }
391}
392
393type MVec<M> = Vec<MInt<M>>;
394impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
395where
396    M: MIntConvert + MIntConvert<u32>,
397    N1: Montgomery32NttModulus,
398    N2: Montgomery32NttModulus,
399    N3: Montgomery32NttModulus,
400{
401    type T = MVec<M>;
402    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
403    fn length(t: &Self::T) -> usize {
404        t.len()
405    }
406    fn transform(t: Self::T, len: usize) -> Self::F {
407        let npot = len.max(1).next_power_of_two();
408        let mut f = (
409            MVec::<N1>::with_capacity(npot),
410            MVec::<N2>::with_capacity(npot),
411            MVec::<N3>::with_capacity(npot),
412        );
413        for t in t {
414            f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
415            f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
416            f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
417        }
418        f.0.resize_with(npot, Zero::zero);
419        f.1.resize_with(npot, Zero::zero);
420        f.2.resize_with(npot, Zero::zero);
421        ntt(&mut f.0);
422        ntt(&mut f.1);
423        ntt(&mut f.2);
424        f
425    }
426    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
427        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
428        let m1 = MInt::<M>::from(N1::get_mod());
429        let m1_3 = MInt::<N3>::new(N1::get_mod());
430        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
431        let m2 = m1 * MInt::<M>::from(N2::get_mod());
432        Convolve::<N1>::inverse_transform(f.0, len)
433            .into_iter()
434            .zip(Convolve::<N2>::inverse_transform(f.1, len))
435            .zip(Convolve::<N3>::inverse_transform(f.2, len))
436            .map(|((c1, c2), c3)| {
437                let d1 = c1.inner();
438                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
439                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
440                let d3 = ((c3 - x) * t2).inner();
441                MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
442            })
443            .collect()
444    }
445    fn multiply(f: &mut Self::F, g: &Self::F) {
446        assert_eq!(f.0.len(), g.0.len());
447        assert_eq!(f.1.len(), g.1.len());
448        assert_eq!(f.2.len(), g.2.len());
449        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
450            *f *= *g;
451        }
452        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
453            *f *= *g;
454        }
455        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
456            *f *= *g;
457        }
458    }
459    fn convolve(a: Self::T, b: Self::T) -> Self::T {
460        if Self::length(&a).max(Self::length(&b)) <= 300 {
461            return convolve_karatsuba(&a, &b);
462        }
463        if Self::length(&a).min(Self::length(&b)) <= 60 {
464            return convolve_naive(&a, &b);
465        }
466        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
467        let mut a = Self::transform(a, len);
468        let b = Self::transform(b, len);
469        Self::multiply(&mut a, &b);
470        Self::inverse_transform(a, len)
471    }
472}
473
474impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
475where
476    N1: Montgomery32NttModulus,
477    N2: Montgomery32NttModulus,
478    N3: Montgomery32NttModulus,
479{
480    type T = Vec<u64>;
481    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
482
483    fn length(t: &Self::T) -> usize {
484        t.len()
485    }
486
487    fn transform(t: Self::T, len: usize) -> Self::F {
488        let npot = len.max(1).next_power_of_two();
489        let mut f = (
490            MVec::<N1>::with_capacity(npot),
491            MVec::<N2>::with_capacity(npot),
492            MVec::<N3>::with_capacity(npot),
493        );
494        for t in t {
495            f.0.push(t.into());
496            f.1.push(t.into());
497            f.2.push(t.into());
498        }
499        f.0.resize_with(npot, Zero::zero);
500        f.1.resize_with(npot, Zero::zero);
501        f.2.resize_with(npot, Zero::zero);
502        ntt(&mut f.0);
503        ntt(&mut f.1);
504        ntt(&mut f.2);
505        f
506    }
507
508    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
509        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
510        let m1 = N1::get_mod() as u64;
511        let m1_3 = MInt::<N3>::new(N1::get_mod());
512        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
513        let m2 = m1 * N2::get_mod() as u64;
514        Convolve::<N1>::inverse_transform(f.0, len)
515            .into_iter()
516            .zip(Convolve::<N2>::inverse_transform(f.1, len))
517            .zip(Convolve::<N3>::inverse_transform(f.2, len))
518            .map(|((c1, c2), c3)| {
519                let d1 = c1.inner();
520                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
521                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
522                let d3 = ((c3 - x) * t2).inner();
523                d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
524            })
525            .collect()
526    }
527
528    fn multiply(f: &mut Self::F, g: &Self::F) {
529        assert_eq!(f.0.len(), g.0.len());
530        assert_eq!(f.1.len(), g.1.len());
531        assert_eq!(f.2.len(), g.2.len());
532        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
533            *f *= *g;
534        }
535        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
536            *f *= *g;
537        }
538        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
539            *f *= *g;
540        }
541    }
542
543    fn convolve(a: Self::T, b: Self::T) -> Self::T {
544        if Self::length(&a).max(Self::length(&b)) <= 300 {
545            return convolve_karatsuba(&a, &b);
546        }
547        if Self::length(&a).min(Self::length(&b)) <= 60 {
548            return convolve_naive(&a, &b);
549        }
550        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
551        let mut a = Self::transform(a, len);
552        let b = Self::transform(b, len);
553        Self::multiply(&mut a, &b);
554        Self::inverse_transform(a, len)
555    }
556}
557
558pub trait NttReuse: ConvolveSteps {
559    const MULTIPLE: bool = true;
560
561    /// F(a) → F(a + [0] * a.len())
562    fn ntt_doubling(f: Self::F) -> Self::F;
563
564    /// F(a(x)), F(b(x)) → even(F(a(x) * b(-x)))
565    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
566
567    /// F(a(x)), F(b(x)) → odd(F(a(x) * b(-x)))
568    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
569}
570
// Per-thread cache of bit-reversal permutations: entry `k` holds the
// permutation of `0..2^k`, filled lazily by `odd_mul_normal_neg`.
thread_local! {
    static BIT_REVERSE: UnsafeCell<Vec<Vec<usize>>> = const { UnsafeCell::new(Vec::new()) };
}
574
575impl<M> NttReuse for Convolve<M>
576where
577    M: Montgomery32NttModulus,
578{
579    const MULTIPLE: bool = false;
580
581    fn ntt_doubling(mut f: Self::F) -> Self::F {
582        let n = f.len();
583        let k = n.trailing_zeros() as usize;
584        let mut a = Self::inverse_transform(f.clone(), n);
585        let mut rot = MInt::<M>::one();
586        let zeta = MInt::<M>::new_unchecked(M::INFO.root[k + 1]);
587        for a in a.iter_mut() {
588            *a *= rot;
589            rot *= zeta;
590        }
591        f.extend(Self::transform(a, n));
592        f
593    }
594
595    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
596        assert_eq!(f.len(), g.len());
597        assert!(f.len().is_power_of_two());
598        assert!(f.len() >= 2);
599        let inv2 = MInt::<M>::from(2).inv();
600        let n = f.len() / 2;
601        (0..n)
602            .map(|i| (f[i << 1] * g[i << 1 | 1] + f[i << 1 | 1] * g[i << 1]) * inv2)
603            .collect()
604    }
605
606    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
607        assert_eq!(f.len(), g.len());
608        assert!(f.len().is_power_of_two());
609        assert!(f.len() >= 2);
610        let mut inv2 = MInt::<M>::from(2).inv();
611        let n = f.len() / 2;
612        let k = f.len().trailing_zeros() as usize;
613        let mut h = vec![MInt::<M>::zero(); n];
614        let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
615        BIT_REVERSE.with(|br| {
616            let br = unsafe { &mut *br.get() };
617            if br.len() < k {
618                br.resize_with(k, Default::default);
619            }
620            let k = k - 1;
621            if br[k].is_empty() {
622                let mut v = vec![0; 1 << k];
623                for i in 0..1 << k {
624                    v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
625                }
626                br[k] = v;
627            }
628            for &i in &br[k] {
629                h[i] = (f[i << 1] * g[i << 1 | 1] - f[i << 1 | 1] * g[i << 1]) * inv2;
630                inv2 *= w;
631            }
632        });
633        h
634    }
635}
636
/// `NttReuse` for the three-modulus CRT convolution.
///
/// Each component of `F` is processed per NTT modulus. The `*_mul_normal_neg`
/// variants add a correction built from the outer modulus `m = M::mod_into()`
/// (see the helper comments).
impl<M, N1, N2, N3> NttReuse for Convolve<(M, (N1, N2, N3))>
where
    M: MIntConvert + MIntConvert<u32>,
    N1: Montgomery32NttModulus,
    N2: Montgomery32NttModulus,
    N3: Montgomery32NttModulus,
{
    fn ntt_doubling(f: Self::F) -> Self::F {
        // Doubling is applied independently to each modulus component.
        (
            Convolve::<N1>::ntt_doubling(f.0),
            Convolve::<N2>::ntt_doubling(f.1),
            Convolve::<N3>::ntt_doubling(f.2),
        )
    }

    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Same formula as the single-modulus `even_mul_normal_neg`, except that
        // `u = m * f.len()` is added to `g[1]` when forming output 0.
        // NOTE(review): presumably this shifts every reconstructed coefficient by a
        // multiple of `m`, keeping the (possibly negative) integer coefficients of
        // a(x)·b(−x) non-negative for Garner reconstruction while leaving them
        // unchanged modulo `m`. Confirm before changing.
        fn even_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            let n = f.len();
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let inv2 = MInt::<M>::from(2).inv();
            // Correction term: outer modulus times the full transform length.
            let u = MInt::<M>::new(m) * MInt::<M>::from(n as u32);
            let n = f.len() / 2;
            (0..n)
                .map(|i| {
                    (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        + f[i << 1 | 1] * g[i << 1])
                        * inv2
                })
                .collect()
        }

        let m = M::mod_into();
        (
            even_mul_normal_neg_corrected(&f.0, &g.0, m),
            even_mul_normal_neg_corrected(&f.1, &g.1, m),
            even_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }

    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Same as the single-modulus `odd_mul_normal_neg` plus the `u = m * f.len()`
        // correction on `g[1]` for output 0 (see the note in `even_mul_normal_neg`).
        fn odd_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let mut inv2 = MInt::<M>::from(2).inv();
            let u = MInt::<M>::new(m) * MInt::<M>::from(f.len() as u32);
            let n = f.len() / 2;
            let k = f.len().trailing_zeros() as usize;
            let mut h = vec![MInt::<M>::zero(); n];
            // Inverse root of unity of order f.len(), stepped in bit-reversed index order.
            let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
            BIT_REVERSE.with(|br| {
                // SAFETY: this reference never leaves the closure, and the closure does
                // not re-enter `BIT_REVERSE.with`, so it is the only live reference.
                let br = unsafe { &mut *br.get() };
                if br.len() < k {
                    br.resize_with(k, Default::default);
                }
                let k = k - 1;
                if br[k].is_empty() {
                    // v[i] reverses the low k bits of i.
                    let mut v = vec![0; 1 << k];
                    for i in 0..1 << k {
                        v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
                    }
                    br[k] = v;
                }
                for &i in &br[k] {
                    h[i] = (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        - f[i << 1 | 1] * g[i << 1])
                        * inv2;
                    inv2 *= w;
                }
            });
            h
        }

        let m = M::mod_into();
        (
            odd_mul_normal_neg_corrected(&f.0, &g.0, m),
            odd_mul_normal_neg_corrected(&f.1, &g.1, m),
            odd_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }
}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::{mint_basic::Modulo1000000009, montgomery::MInt998244353};
    use crate::tools::Xorshift;

    // `convolve_naive` vs. an inline double loop on small `u32` inputs
    // (values < 1000 and lengths <= 60, so the sums cannot overflow).
    #[test]
    fn test_convolve_naive() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=60);
            let m = rng.random(0..=60);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            // Expected length follows the `(n + m).saturating_sub(1)` convention.
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_naive(&a, &b);
            assert_eq!(c, d);
        }
    }

    // `convolve_karatsuba` vs. the double loop; lengths up to 200 exercise the
    // recursive split (shorter sides of at most 30 fall back to `convolve_naive`).
    #[test]
    fn test_convolve_karatsuba() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=200);
            let m = rng.random(0..=200);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_karatsuba(&a, &b);
            assert_eq!(c, d);
        }
    }

    // End-to-end `Convolve998244353::convolve`. Lengths are 0..=4 or 70..=120,
    // mixing the Karatsuba, naive and NTT paths. Every 100th case uses lengths
    // 2^w + 1 (w = 7, 8 reach the `len <= size / 2 + 2` peeling branch). Equal
    // lengths sometimes use b == a to hit the squaring path.
    #[test]
    fn test_ntt998244353() {
        let mut rng = Xorshift::default();
        for t in 0..1000 {
            let n: usize = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=120) } else { n };
            let m: usize = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=120) } else { m };
            let (n, m) = if t % 100 != 0 {
                (n, m)
            } else {
                let w = rng.random(6..=8);
                ((1usize << w) + 1usize, (1usize << w) + 1usize)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let mut b: Vec<MInt998244353> = rng.random_iter(..).take(m).collect();
            if n == m && rng.random(0..2) == 0 {
                b = a.clone();
            }

            let mut c = vec![MInt998244353::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = Convolve998244353::convolve(a, b);
            assert_eq!(c, d);
        }
        // The compile-time table must equal a runtime evaluation of the same const fn.
        assert_eq!(NttInfo::new::<Modulo998244353>(), Modulo998244353::INFO);
    }

    // Three-modulus CRT convolution modulo 1e9+9. Lengths 0..=4 or 70..=400 cover
    // the Karatsuba (max <= 300), naive (min <= 60) and NTT paths.
    #[test]
    fn test_convolve3() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let b: Vec<M> = rng.random_iter(..).take(m).collect();
            let mut c = vec![M::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // `U64Convolve` with values < 2^24 and lengths <= 400, so every exact
    // coefficient (< 400 · 2^48) fits in u64 and the reference loop cannot overflow.
    #[test]
    fn test_convolve_u64() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<u64> = rng.random_iter(0u64..1 << 24).take(n).collect();
            let b: Vec<u64> = rng.random_iter(0u64..1 << 24).take(m).collect();
            let mut c = vec![0; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = U64Convolve::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // `NttReuse` for 998244353, compared directly in the transformed domain:
    // doubling vs. transforming the zero-padded input, and even/odd products vs.
    // transforming the even/odd coefficients of a(x)·b(−x). `b_neg` is b(−x),
    // i.e. b with its odd coefficients negated.
    #[test]
    fn test_ntt_reuse_998244353() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let f = Convolve998244353::transform(a.clone(), n);

            // doubling
            {
                let f_double = Convolve998244353::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = Convolve998244353::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = Convolve998244353::transform(a.clone(), n * 2);
            let b: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let g = Convolve998244353::transform(b.clone(), n * 2);
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            // even_mul_normal_neg
            {
                let fg_neg = Convolve998244353::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_even, n);
                assert_eq!(fg_neg, fg);
            }

            // odd_mul_normal_neg
            {
                let fg_neg = Convolve998244353::odd_mul_normal_neg(&f, &g);
                let ab_neg_odd: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .skip(1)
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_odd, n);
                assert_eq!(fg_neg, fg);
            }
        }
    }

    // `NttReuse` for the three-modulus CRT variant. The even/odd results are
    // compared after `inverse_transform` rather than in the transformed domain,
    // presumably because the CRT correction terms make the raw transforms differ
    // by multiples of the outer modulus. The odd result has one fewer coefficient,
    // so the expectation is padded with a zero.
    #[test]
    fn test_ntt_reuse_triple() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n);

            // doubling
            {
                let f_double = MIntConvolve::<Modulo1000000009>::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = MIntConvolve::<Modulo1000000009>::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n * 2);
            let b: Vec<M> = rng.random_iter(..).take(n).collect();
            let g = MIntConvolve::<Modulo1000000009>::transform(b.clone(), n * 2);
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            // even_mul_normal_neg
            {
                let fg_neg = MIntConvolve::<Modulo1000000009>::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .step_by(2)
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_even
                );
            }

            // odd_mul_normal_neg
            {
                let fg_neg = MIntConvolve::<Modulo1000000009>::odd_mul_normal_neg(&f, &g);
                let ab_neg_odd: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .skip(1)
                        .step_by(2)
                        .chain([M::zero()])
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_odd
                );
            }
        }
    }
}