// competitive/math/number_theoretic_transform.rs
1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2#[cfg(target_arch = "x86_64")]
3use super::{SimdBackend, simd_backend};
4use std::{
5    cell::UnsafeCell,
6    marker::PhantomData,
7    ops::{AddAssign, Mul, SubAssign},
8};
9
/// Zero-sized marker type that selects a convolution strategy via `M`.
/// `PhantomData<fn() -> M>` mentions `M` without requiring a value of it.
pub struct Convolve<M>(PhantomData<fn() -> M>);
/// Convolution performed directly in the NTT-friendly modulus 998244353.
pub type Convolve998244353 = Convolve<Modulo998244353>;
/// Convolution for an arbitrary modint `M` over three NTT-friendly moduli
/// (presumably recombined by CRT — the combining code is not in this view).
pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
/// Convolution for `u64` values over the same three NTT-friendly moduli.
pub type U64Convolve = Convolve<(u64, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
14
// Marks each modulus type as NTT-capable. The second element of each pair
// is a known primitive root for that modulus, but `$g` is matched and never
// expanded: `PRIMITIVE_ROOT` is re-derived by brute force at compile time
// (see the trait). NOTE(review): wiring `$g` through would skip that search.
macro_rules! impl_ntt_modulus {
    ($([$name:ident, $g:expr]),*) => {
        $(
            impl Montgomery32NttModulus for $name {}
        )*
    };
}
impl_ntt_modulus!(
    [Modulo998244353, 3],
    [Modulo2113929217, 5],
    [Modulo1811939329, 13],
    [Modulo2013265921, 31]
);
28
/// Montgomery reduction: returns `z * 2^(-32) mod p`, canonical in `[0, p)`.
///
/// `r` must equal `-p^(-1) mod 2^32`, so that `z + (z * r mod 2^32) * p`
/// is an exact multiple of `2^32`. Valid for inputs up to about `p * 2^32`
/// (in particular for any product of two values below `p`).
const fn reduce(z: u64, p: u32, r: u32) -> u32 {
    // m makes z + m*p divisible by 2^32.
    let m = r.wrapping_mul(z as u32);
    let t = ((z + m as u64 * p as u64) >> 32) as u32;
    // Single conditional subtraction brings t into [0, p).
    if t >= p { t - p } else { t }
}
36const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
37    reduce(x as u64 * y as u64, p, r)
38}
39const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
40    while y > 0 {
41        if y & 1 == 1 {
42            z = mod_mul(z, x, p, r);
43        }
44        x = mod_mul(x, x, p, r);
45        y >>= 1;
46    }
47    z
48}
49
/// Marker trait for 32-bit Montgomery moduli suitable for the NTT; supplies
/// a primitive root, the 2-adic rank of `MOD - 1`, and precomputed twiddle
/// tables, all evaluated at compile time.
pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
    /// Smallest odd `g >= 3` that is a primitive root of `MOD`, found by
    /// brute force: `g` is accepted once `g^d != 1 (mod MOD)` for every
    /// proper divisor `d` of `MOD - 1`. Exponentiation runs in Montgomery
    /// form (`N1` is Montgomery 1; multiplying by `N2` and reducing converts
    /// `g` into Montgomery form — confirm against `MontgomeryReduction32`).
    /// NOTE(review): stepping by 2 skips even candidates, which is fine for
    /// the odd generators of the moduli registered in this file.
    const PRIMITIVE_ROOT: u32 = {
        let mut g = 3u32;
        loop {
            let mut ok = true;
            let mut d = 1u32;
            while d * d < Self::MOD {
                if (Self::MOD - 1) % d == 0 {
                    // Check both members of the divisor pair, skipping the
                    // trivial exponent MOD - 1 itself.
                    let ds = [d, (Self::MOD - 1) / d];
                    let mut i = 0;
                    while i < 2 {
                        ok &= ds[i] == Self::MOD - 1
                            || mod_pow(
                                reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
                                ds[i],
                                Self::MOD,
                                Self::R,
                                Self::N1,
                            ) != Self::N1;
                        i += 1;
                    }
                }
                d += 1;
            }
            if ok {
                break;
            }
            g += 2;
        }
        g
    };
    /// 2-adic valuation of `MOD - 1`; a primitive 2^RANK-th root of unity exists.
    const RANK: u32 = (Self::MOD - 1).trailing_zeros();
    /// Twiddle-factor tables precomputed for this modulus.
    const INFO: NttInfo = NttInfo::new::<Self>();
}
84
/// Precomputed twiddle-factor tables for one NTT modulus, built once at
/// compile time by [`NttInfo::new`]. All entries are in Montgomery form.
#[derive(Debug, PartialEq)]
pub struct NttInfo {
    // root[i]: primitive 2^i-th root of unity; inv_root[i]: its inverse.
    root: [u32; 32],
    inv_root: [u32; 32],
    // rate2 / rate3: incremental twiddles for the radix-2 / radix-4
    // butterflies, indexed by trailing_ones of the block counter.
    rate2: [u32; 32],
    inv_rate2: [u32; 32],
    rate3: [u32; 32],
    inv_rate3: [u32; 32],
}
impl NttInfo {
    /// Builds every table for modulus `M` at compile time. All values are
    /// kept in Montgomery form (`M::N1` is the Montgomery image of 1).
    const fn new<M>() -> Self
    where
        M: Montgomery32NttModulus,
    {
        let mut root = [0; 32];
        let mut inv_root = [0; 32];
        let mut rate2 = [0; 32];
        let mut inv_rate2 = [0; 32];
        let mut rate3 = [0; 32];
        let mut inv_rate3 = [0; 32];
        let rank = M::RANK as usize;

        // Primitive root converted into Montgomery form (same N2 conversion
        // as in PRIMITIVE_ROOT's search).
        let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
        // root[rank] is a primitive 2^rank-th root of unity.
        root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
        // Inverse via Fermat's little theorem: x^(MOD-2) = x^(-1).
        inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
        // Fill downward by squaring: root[i] = root[i+1]^2.
        let mut i = rank - 1;
        loop {
            root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
            inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
            if i == 0 {
                break;
            }
            i -= 1;
        }

        // rate2: incremental radix-2 twiddles — each entry folds in the
        // inverses of all previous roots so the butterfly can advance its
        // running twiddle with a single multiplication.
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 1 {
            rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
            inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
            i += 1;
        }

        // rate3: the same construction, offset for the radix-4 butterfly.
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 2 {
            rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
            inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
            i += 1;
        }

        NttInfo {
            root,
            inv_root,
            rate2,
            inv_rate2,
            rate3,
            inv_rate3,
        }
    }
}
148
/// Scalar forward NTT, decimation in frequency: radix-4 passes, plus a
/// final radix-2 pass when log2(n) is odd. `a.len()` must be a power of
/// two. The output is in the permuted (non-natural) order that
/// `intt_scalar` consumes — there is no bit-reversal step.
fn ntt_scalar<M>(a: &mut [MInt<M>])
where
    M: Montgomery32NttModulus,
{
    let n = a.len();
    let mut v = n / 2;
    // 4th root of unity ("imaginary unit") used by the radix-4 butterfly.
    let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
    while v > 1 {
        let mut w1 = MInt::<M>::one();
        for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
            // Split each block into four quarter-length lanes.
            let (l, r) = a.split_at_mut(v);
            let (ll, lr) = l.split_at_mut(v >> 1);
            let (rl, rr) = r.split_at_mut(v >> 1);
            let w2 = w1 * w1;
            let w3 = w1 * w2;
            for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                let a0 = *x0;
                let a1 = *x1 * w1;
                let a2 = *x2 * w2;
                let a3 = *x3 * w3;
                let a0pa2 = a0 + a2;
                let a0na2 = a0 - a2;
                let a1pa3 = a1 + a3;
                let a1na3imag = (a1 - a3) * imag;
                *x0 = a0pa2 + a1pa3;
                *x1 = a0pa2 - a1pa3;
                *x2 = a0na2 + a1na3imag;
                *x3 = a0na2 - a1na3imag;
            }
            // Advance the running twiddle via the precomputed rate3 table.
            w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
        }
        v >>= 2;
    }
    if v == 1 {
        // One radix-2 pass remains when n is an odd power of two.
        let mut w1 = MInt::<M>::one();
        for (s, a) in a.chunks_exact_mut(2).enumerate() {
            // SAFETY: chunks_exact_mut(2) guarantees each chunk has exactly
            // two elements, so index 0 of each half is in bounds.
            unsafe {
                let (l, r) = a.split_at_mut(1);
                let x0 = l.get_unchecked_mut(0);
                let x1 = r.get_unchecked_mut(0);
                let a0 = *x0;
                let a1 = *x1 * w1;
                *x0 = a0 + a1;
                *x1 = a0 - a1;
            }
            w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
        }
    }
}
198
/// Scalar inverse NTT, decimation in time: one radix-2 pass first when
/// log2(n) is odd, then radix-4 passes. Consumes the permuted order
/// produced by `ntt_scalar`. Does NOT apply the `1/n` scaling — the caller
/// is presumably responsible for it (no division appears here).
fn intt_scalar<M>(a: &mut [MInt<M>])
where
    M: Montgomery32NttModulus,
{
    let n = a.len();
    let mut v = 1;
    if n.trailing_zeros() & 1 == 1 {
        // Leading radix-2 pass, mirroring the trailing pass of the forward
        // transform.
        let mut w1 = MInt::<M>::one();
        for (s, a) in a.chunks_exact_mut(2).enumerate() {
            // SAFETY: chunks_exact_mut(2) guarantees exactly two elements.
            unsafe {
                let (l, r) = a.split_at_mut(1);
                let x0 = l.get_unchecked_mut(0);
                let x1 = r.get_unchecked_mut(0);
                let a0 = *x0;
                let a1 = *x1;
                *x0 = a0 + a1;
                *x1 = (a0 - a1) * w1;
            }
            w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
        }
        v <<= 1;
    }
    // Inverse 4th root of unity for the radix-4 butterfly.
    let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
    while v < n {
        let mut w1 = MInt::<M>::one();
        for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
            let (l, r) = a.split_at_mut(v << 1);
            let (ll, lr) = l.split_at_mut(v);
            let (rl, rr) = r.split_at_mut(v);
            let w2 = w1 * w1;
            let w3 = w1 * w2;
            for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                let a0 = *x0;
                let a1 = *x1;
                let a2 = *x2;
                let a3 = *x3;
                let a0pa1 = a0 + a1;
                let a0na1 = a0 - a1;
                let a2pa3 = a2 + a3;
                let a2na3iimag = (a2 - a3) * iimag;
                // Twiddles are applied on output (decimation in time).
                *x0 = a0pa1 + a2pa3;
                *x1 = (a0na1 + a2na3iimag) * w1;
                *x2 = (a0pa1 - a2pa3) * w2;
                *x3 = (a0na1 - a2na3iimag) * w3;
            }
            w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
        }
        v <<= 2;
    }
}
249
250fn ntt<M>(a: &mut [MInt<M>])
251where
252    M: Montgomery32NttModulus,
253{
254    #[cfg(target_arch = "x86_64")]
255    {
256        match simd_backend() {
257            SimdBackend::Avx512 => unsafe { ntt_simd::ntt_avx512::<M>(a) },
258            SimdBackend::Avx2 => unsafe { ntt_simd::ntt_avx2::<M>(a) },
259            SimdBackend::Scalar => ntt_scalar(a),
260        }
261    }
262    #[cfg(not(target_arch = "x86_64"))]
263    {
264        ntt_scalar(a);
265    }
266}
267
268fn intt<M>(a: &mut [MInt<M>])
269where
270    M: Montgomery32NttModulus,
271{
272    #[cfg(target_arch = "x86_64")]
273    {
274        match simd_backend() {
275            SimdBackend::Avx512 => unsafe { ntt_simd::intt_avx512::<M>(a) },
276            SimdBackend::Avx2 => unsafe { ntt_simd::intt_avx2::<M>(a) },
277            SimdBackend::Scalar => intt_scalar(a),
278        }
279    }
280    #[cfg(not(target_arch = "x86_64"))]
281    {
282        intt_scalar(a);
283    }
284}
285
/// AVX2 / AVX-512 kernels for the forward and inverse NTT.
///
/// These mirror the scalar radix-4/radix-2 butterflies in `ntt_scalar` and
/// `intt_scalar`, but operate on the raw `u32` Montgomery words behind
/// `MInt<M>`. For moduli below `LAZY_THRESHOLD` the lane arithmetic is kept
/// lazily reduced (values may sit in `[0, 2*MOD)` — see the `montgomery_*`
/// helpers taking `mod2_vec`); a final normalize pass restores the
/// canonical `[0, MOD)` range before returning.
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)] // SIMD intrinsics and raw pointers are confined here
mod ntt_simd {
    use super::*;
    use std::arch::x86_64::*;

    // Moduli below this bound leave enough 32-bit headroom for the lazy
    // (mod 2*MOD) strategy; larger moduli take the canonical path.
    const LAZY_THRESHOLD: u32 = 1 << 30;

    /// Folds every element of `a` into canonical `[0, MOD)`, 8 lanes per
    /// iteration with a scalar tail loop.
    ///
    /// # Safety
    /// Caller must guarantee AVX2 is available (runtime dispatch does).
    #[target_feature(enable = "avx2")]
    unsafe fn normalize_avx2<M>(a: &mut [u32])
    where
        M: Montgomery32NttModulus,
    {
        let mod_vec = _mm256_set1_epi32(M::MOD as i32);
        let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
        let mut i = 0;
        while i + 8 <= a.len() {
            let x = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
            // Unsigned `x >= MOD` built from signed primitives: flip the
            // sign bit of both operands for the `>` half, test `==`
            // separately, then OR the two masks.
            let x_x = _mm256_xor_si256(x, sign);
            let m_x = _mm256_xor_si256(mod_vec, sign);
            let gt = _mm256_cmpgt_epi32(x_x, m_x);
            let eq = _mm256_cmpeq_epi32(x, mod_vec);
            let mask = _mm256_or_si256(gt, eq);
            let sub = _mm256_and_si256(mod_vec, mask);
            let y = _mm256_sub_epi32(x, sub);
            _mm256_storeu_si256(a.as_mut_ptr().add(i) as *mut __m256i, y);
            i += 8;
        }
        while i < a.len() {
            let x = a[i];
            a[i] = if x >= M::MOD { x - M::MOD } else { x };
            i += 1;
        }
    }

    /// AVX-512 variant of [`normalize_avx2`]: 16 lanes per iteration with a
    /// native unsigned compare mask, scalar tail loop.
    ///
    /// # Safety
    /// Caller must guarantee the listed AVX-512 features are available.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    unsafe fn normalize_avx512<M>(a: &mut [u32])
    where
        M: Montgomery32NttModulus,
    {
        let mod_vec = _mm512_set1_epi32(M::MOD as i32);
        let mut i = 0;
        while i + 16 <= a.len() {
            let x = _mm512_loadu_si512(a.as_ptr().add(i) as *const __m512i);
            // mask selects lanes where !(x < MOD), i.e. x >= MOD.
            let mask = !_mm512_cmp_epu32_mask(x, mod_vec, _MM_CMPINT_LT);
            let y = _mm512_mask_sub_epi32(x, mask, x, mod_vec);
            _mm512_storeu_si512(a.as_mut_ptr().add(i) as *mut __m512i, y);
            i += 16;
        }
        while i < a.len() {
            let x = a[i];
            a[i] = if x >= M::MOD { x - M::MOD } else { x };
            i += 1;
        }
    }

    /// Lane-wise modular addition; lazy (mod 2*MOD) for small moduli,
    /// canonical otherwise. The branch is resolved at monomorphization time.
    unsafe fn add_vec_avx2<M>(
        a: __m256i,
        b: __m256i,
        mod_vec: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_add_256(a, b, mod2_vec, sign)
        } else {
            simd32::add_mod_256(a, b, mod_vec, sign)
        }
    }

    /// Lane-wise modular subtraction; lazy or canonical, as above.
    unsafe fn sub_vec_avx2<M>(
        a: __m256i,
        b: __m256i,
        mod_vec: __m256i,
        mod2_vec: __m256i,
        sign: __m256i,
    ) -> __m256i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_sub_256(a, b, mod2_vec, sign)
        } else {
            simd32::sub_mod_256(a, b, mod_vec, sign)
        }
    }

    /// Lane-wise Montgomery multiplication; the canonical variant fully
    /// reduces its result for large moduli.
    unsafe fn mul_vec_avx2<M>(
        a: __m256i,
        b: __m256i,
        r_vec: __m256i,
        mod_vec: __m256i,
        sign: __m256i,
    ) -> __m256i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_mul_256(a, b, r_vec, mod_vec)
        } else {
            simd32::montgomery_mul_256_canon(a, b, r_vec, mod_vec, sign)
        }
    }

    /// AVX-512 counterpart of [`add_vec_avx2`].
    unsafe fn add_vec_avx512<M>(
        a: __m512i,
        b: __m512i,
        mod_vec: __m512i,
        mod2_vec: __m512i,
    ) -> __m512i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_add_512(a, b, mod2_vec)
        } else {
            simd32::add_mod_512(a, b, mod_vec)
        }
    }

    /// AVX-512 counterpart of [`sub_vec_avx2`].
    unsafe fn sub_vec_avx512<M>(
        a: __m512i,
        b: __m512i,
        mod_vec: __m512i,
        mod2_vec: __m512i,
    ) -> __m512i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_sub_512(a, b, mod2_vec)
        } else {
            simd32::sub_mod_512(a, b, mod_vec)
        }
    }

    /// AVX-512 counterpart of [`mul_vec_avx2`].
    unsafe fn mul_vec_avx512<M>(a: __m512i, b: __m512i, r_vec: __m512i, mod_vec: __m512i) -> __m512i
    where
        M: Montgomery32NttModulus,
    {
        if M::MOD < LAZY_THRESHOLD {
            simd32::montgomery_mul_512(a, b, r_vec, mod_vec)
        } else {
            simd32::montgomery_mul_512_canon(a, b, r_vec, mod_vec)
        }
    }

    /// Forward DIF NTT, AVX2 path: radix-4 passes, a final radix-2 pass when
    /// log2(n) is odd, then canonicalization. `a.len()` must be a power of
    /// two.
    ///
    /// # Safety
    /// Requires AVX2. NOTE(review): the pointer cast assumes `MInt<M>` is
    /// layout-identical to `u32` — confirm `MInt` is `repr(transparent)`.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn ntt_avx2<M>(a: &mut [MInt<M>])
    where
        M: Montgomery32NttModulus,
    {
        let n = a.len();
        if n <= 1 {
            return;
        }
        // Reinterpret the coefficients as raw Montgomery words.
        let ptr = a.as_mut_ptr() as *mut u32;
        let a = std::slice::from_raw_parts_mut(ptr, n);
        let mod_vec = _mm256_set1_epi32(M::MOD as i32);
        let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
        let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
        let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
        // 4th root of unity for the radix-4 butterfly (Montgomery form).
        let imag = M::INFO.root[2];
        let imag_vec = _mm256_set1_epi32(imag as i32);

        let mut v = n / 2;
        while v > 1 {
            let half = v >> 1;
            // w1 starts at Montgomery 1 and advances via the rate3 table.
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
                // Four quarter-length lanes of the current block.
                let base = block.as_mut_ptr();
                let ll = base;
                let lr = base.add(half);
                let rl = base.add(v);
                let rr = base.add(v + half);

                let w2 = M::mod_mul(w1, w1);
                let w3 = M::mod_mul(w2, w1);
                let w1v = _mm256_set1_epi32(w1 as i32);
                let w2v = _mm256_set1_epi32(w2 as i32);
                let w3v = _mm256_set1_epi32(w3 as i32);

                let mut i = 0;
                while i + 8 <= half {
                    let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
                    let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
                    let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
                    let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);

                    let a1 = mul_vec_avx2::<M>(x1, w1v, r_vec, mod_vec, sign);
                    let a2 = mul_vec_avx2::<M>(x2, w2v, r_vec, mod_vec, sign);
                    let a3 = mul_vec_avx2::<M>(x3, w3v, r_vec, mod_vec, sign);

                    let a0pa2 = add_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
                    let a0na2 = sub_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
                    let a1pa3 = add_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
                    let a1na3 = sub_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
                    let a1na3imag = mul_vec_avx2::<M>(a1na3, imag_vec, r_vec, mod_vec, sign);

                    let y0 = add_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
                    let y1 = sub_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
                    let y2 = add_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
                    let y3 = sub_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);

                    _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
                    _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
                    _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
                    _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
                    i += 8;
                }
                // Scalar tail: mirrors the vector butterfly exactly.
                while i < half {
                    let a0 = *ll.add(i);
                    let a1 = M::mod_mul(*lr.add(i), w1);
                    let a2 = M::mod_mul(*rl.add(i), w2);
                    let a3 = M::mod_mul(*rr.add(i), w3);
                    let a0pa2 = M::mod_add(a0, a2);
                    let a0na2 = M::mod_sub(a0, a2);
                    let a1pa3 = M::mod_add(a1, a3);
                    let a1na3 = M::mod_sub(a1, a3);
                    let a1na3imag = M::mod_mul(a1na3, imag);
                    *ll.add(i) = M::mod_add(a0pa2, a1pa3);
                    *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
                    *rl.add(i) = M::mod_add(a0na2, a1na3imag);
                    *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
                    i += 1;
                }
                w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
            }
            v >>= 2;
        }
        if v == 1 {
            // Final radix-2 pass (n is an odd power of two); chunks are too
            // small to vectorize, so it stays scalar.
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: chunks_exact_mut(2) yields exactly two elements.
                let a0 = *block.get_unchecked(0);
                let a1 = M::mod_mul(*block.get_unchecked(1), w1);
                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
                *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
                w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
            }
        }
        normalize_avx2::<M>(a);
    }

    /// Inverse DIT NTT, AVX2 path: optional leading radix-2 pass, radix-4
    /// passes, then canonicalization. Consumes `ntt_avx2`'s output order.
    /// Does NOT apply the `1/n` scaling.
    ///
    /// # Safety
    /// Requires AVX2; same `MInt<M>` layout assumption as `ntt_avx2`.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn intt_avx2<M>(a: &mut [MInt<M>])
    where
        M: Montgomery32NttModulus,
    {
        let n = a.len();
        if n <= 1 {
            return;
        }
        let ptr = a.as_mut_ptr() as *mut u32;
        let a = std::slice::from_raw_parts_mut(ptr, n);
        let mod_vec = _mm256_set1_epi32(M::MOD as i32);
        let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
        let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
        let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
        // Inverse 4th root of unity (Montgomery form).
        let iimag = M::INFO.inv_root[2];
        let iimag_vec = _mm256_set1_epi32(iimag as i32);

        let mut v = 1;
        if n.trailing_zeros() & 1 == 1 {
            // Leading radix-2 pass when log2(n) is odd.
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: chunks_exact_mut(2) yields exactly two elements.
                let a0 = *block.get_unchecked(0);
                let a1 = *block.get_unchecked(1);
                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
                *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
                w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
            }
            v <<= 1;
        }
        while v < n {
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
                let base = block.as_mut_ptr();
                let ll = base;
                let lr = base.add(v);
                let rl = base.add(v << 1);
                let rr = base.add(v * 3);

                let w2 = M::mod_mul(w1, w1);
                let w3 = M::mod_mul(w2, w1);
                let w1v = _mm256_set1_epi32(w1 as i32);
                let w2v = _mm256_set1_epi32(w2 as i32);
                let w3v = _mm256_set1_epi32(w3 as i32);

                let mut i = 0;
                while i + 8 <= v {
                    let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
                    let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
                    let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
                    let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);

                    let a0pa1 = add_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
                    let a0na1 = sub_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
                    let a2pa3 = add_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
                    let a2na3 = sub_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
                    let a2na3iimag = mul_vec_avx2::<M>(a2na3, iimag_vec, r_vec, mod_vec, sign);

                    let y0 = add_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
                    let y1 = add_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
                    let y2 = sub_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
                    let y3 = sub_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);

                    // Twiddles are applied on output (decimation in time).
                    let y1 = mul_vec_avx2::<M>(y1, w1v, r_vec, mod_vec, sign);
                    let y2 = mul_vec_avx2::<M>(y2, w2v, r_vec, mod_vec, sign);
                    let y3 = mul_vec_avx2::<M>(y3, w3v, r_vec, mod_vec, sign);

                    _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
                    _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
                    _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
                    _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
                    i += 8;
                }
                // Scalar tail: mirrors the vector butterfly exactly.
                while i < v {
                    let a0 = *ll.add(i);
                    let a1 = *lr.add(i);
                    let a2 = *rl.add(i);
                    let a3 = *rr.add(i);
                    let a0pa1 = M::mod_add(a0, a1);
                    let a0na1 = M::mod_sub(a0, a1);
                    let a2pa3 = M::mod_add(a2, a3);
                    let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
                    *ll.add(i) = M::mod_add(a0pa1, a2pa3);
                    *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
                    *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
                    *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
                    i += 1;
                }
                w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
            }
            v <<= 2;
        }
        normalize_avx2::<M>(a);
    }

    /// Forward DIF NTT, AVX-512 path: identical structure to `ntt_avx2`
    /// with 16-lane vectors and mask-register arithmetic.
    ///
    /// # Safety
    /// Requires the listed AVX-512 features; same `MInt<M>` layout
    /// assumption as `ntt_avx2`.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub(super) unsafe fn ntt_avx512<M>(a: &mut [MInt<M>])
    where
        M: Montgomery32NttModulus,
    {
        let n = a.len();
        if n <= 1 {
            return;
        }
        let ptr = a.as_mut_ptr() as *mut u32;
        let a = std::slice::from_raw_parts_mut(ptr, n);
        let mod_vec = _mm512_set1_epi32(M::MOD as i32);
        let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
        let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
        let imag = M::INFO.root[2];
        let imag_vec = _mm512_set1_epi32(imag as i32);

        let mut v = n / 2;
        while v > 1 {
            let half = v >> 1;
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
                let base = block.as_mut_ptr();
                let ll = base;
                let lr = base.add(half);
                let rl = base.add(v);
                let rr = base.add(v + half);
                let w2 = M::mod_mul(w1, w1);
                let w3 = M::mod_mul(w2, w1);
                let w1v = _mm512_set1_epi32(w1 as i32);
                let w2v = _mm512_set1_epi32(w2 as i32);
                let w3v = _mm512_set1_epi32(w3 as i32);

                let mut i = 0;
                while i + 16 <= half {
                    let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
                    let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
                    let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
                    let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);

                    let a1 = mul_vec_avx512::<M>(x1, w1v, r_vec, mod_vec);
                    let a2 = mul_vec_avx512::<M>(x2, w2v, r_vec, mod_vec);
                    let a3 = mul_vec_avx512::<M>(x3, w3v, r_vec, mod_vec);

                    let a0pa2 = add_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
                    let a0na2 = sub_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
                    let a1pa3 = add_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
                    let a1na3 = sub_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
                    let a1na3imag = mul_vec_avx512::<M>(a1na3, imag_vec, r_vec, mod_vec);

                    let y0 = add_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
                    let y1 = sub_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
                    let y2 = add_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
                    let y3 = sub_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);

                    _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
                    _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
                    _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
                    _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
                    i += 16;
                }
                // Scalar tail: mirrors the vector butterfly exactly.
                while i < half {
                    let a0 = *ll.add(i);
                    let a1 = M::mod_mul(*lr.add(i), w1);
                    let a2 = M::mod_mul(*rl.add(i), w2);
                    let a3 = M::mod_mul(*rr.add(i), w3);
                    let a0pa2 = M::mod_add(a0, a2);
                    let a0na2 = M::mod_sub(a0, a2);
                    let a1pa3 = M::mod_add(a1, a3);
                    let a1na3 = M::mod_sub(a1, a3);
                    let a1na3imag = M::mod_mul(a1na3, imag);
                    *ll.add(i) = M::mod_add(a0pa2, a1pa3);
                    *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
                    *rl.add(i) = M::mod_add(a0na2, a1na3imag);
                    *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
                    i += 1;
                }
                w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
            }
            v >>= 2;
        }
        if v == 1 {
            // Final radix-2 pass (n is an odd power of two).
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: chunks_exact_mut(2) yields exactly two elements.
                let a0 = *block.get_unchecked(0);
                let a1 = M::mod_mul(*block.get_unchecked(1), w1);
                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
                *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
                w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
            }
        }
        normalize_avx512::<M>(a);
    }

    /// Inverse DIT NTT, AVX-512 path: identical structure to `intt_avx2`
    /// with 16-lane vectors. Does NOT apply the `1/n` scaling.
    ///
    /// # Safety
    /// Requires the listed AVX-512 features; same `MInt<M>` layout
    /// assumption as `ntt_avx2`.
    #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
    pub(super) unsafe fn intt_avx512<M>(a: &mut [MInt<M>])
    where
        M: Montgomery32NttModulus,
    {
        let n = a.len();
        if n <= 1 {
            return;
        }
        let ptr = a.as_mut_ptr() as *mut u32;
        let a = std::slice::from_raw_parts_mut(ptr, n);
        let mod_vec = _mm512_set1_epi32(M::MOD as i32);
        let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
        let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
        let iimag = M::INFO.inv_root[2];
        let iimag_vec = _mm512_set1_epi32(iimag as i32);

        let mut v = 1;
        if n.trailing_zeros() & 1 == 1 {
            // Leading radix-2 pass when log2(n) is odd.
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: chunks_exact_mut(2) yields exactly two elements.
                let a0 = *block.get_unchecked(0);
                let a1 = *block.get_unchecked(1);
                *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
                *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
                w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
            }
            v <<= 1;
        }
        while v < n {
            let mut w1 = M::N1;
            for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
                let base = block.as_mut_ptr();
                let ll = base;
                let lr = base.add(v);
                let rl = base.add(v << 1);
                let rr = base.add(v * 3);
                let w2 = M::mod_mul(w1, w1);
                let w3 = M::mod_mul(w2, w1);
                let w1v = _mm512_set1_epi32(w1 as i32);
                let w2v = _mm512_set1_epi32(w2 as i32);
                let w3v = _mm512_set1_epi32(w3 as i32);

                let mut i = 0;
                while i + 16 <= v {
                    let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
                    let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
                    let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
                    let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);

                    let a0pa1 = add_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
                    let a0na1 = sub_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
                    let a2pa3 = add_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
                    let a2na3 = sub_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
                    let a2na3iimag = mul_vec_avx512::<M>(a2na3, iimag_vec, r_vec, mod_vec);

                    let y0 = add_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
                    let y1 = add_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
                    let y2 = sub_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
                    let y3 = sub_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);

                    // Twiddles are applied on output (decimation in time).
                    let y1 = mul_vec_avx512::<M>(y1, w1v, r_vec, mod_vec);
                    let y2 = mul_vec_avx512::<M>(y2, w2v, r_vec, mod_vec);
                    let y3 = mul_vec_avx512::<M>(y3, w3v, r_vec, mod_vec);

                    _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
                    _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
                    _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
                    _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
                    i += 16;
                }
                // Scalar tail: mirrors the vector butterfly exactly.
                while i < v {
                    let a0 = *ll.add(i);
                    let a1 = *lr.add(i);
                    let a2 = *rl.add(i);
                    let a3 = *rr.add(i);
                    let a0pa1 = M::mod_add(a0, a1);
                    let a0na1 = M::mod_sub(a0, a1);
                    let a2pa3 = M::mod_add(a2, a3);
                    let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
                    *ll.add(i) = M::mod_add(a0pa1, a2pa3);
                    *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
                    *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
                    *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
                    i += 1;
                }
                w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
            }
            v <<= 2;
        }
        normalize_avx512::<M>(a);
    }
}
814
815fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
816where
817    T: Copy + Zero + AddAssign<T> + Mul<Output = T>,
818{
819    if a.is_empty() && b.is_empty() {
820        return Vec::new();
821    }
822    let len = a.len() + b.len() - 1;
823    let mut c = vec![T::zero(); len];
824    if a.len() < b.len() {
825        for (i, &b) in b.iter().enumerate() {
826            for (a, c) in a.iter().zip(&mut c[i..]) {
827                *c += *a * b;
828            }
829        }
830    } else {
831        for (i, &a) in a.iter().enumerate() {
832            for (b, c) in b.iter().zip(&mut c[i..]) {
833                *c += *b * a;
834            }
835        }
836    }
837    c
838}
839
840fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
841where
842    T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
843{
844    if a.len().min(b.len()) <= 30 {
845        return convolve_naive(a, b);
846    }
847    let m = a.len().max(b.len()).div_ceil(2);
848    let (a0, a1) = if a.len() <= m {
849        (a, &[][..])
850    } else {
851        a.split_at(m)
852    };
853    let (b0, b1) = if b.len() <= m {
854        (b, &[][..])
855    } else {
856        b.split_at(m)
857    };
858    let f00 = convolve_karatsuba(a0, b0);
859    let f11 = convolve_karatsuba(a1, b1);
860    let mut a0a1 = a0.to_vec();
861    for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
862        *a0a1 += a1;
863    }
864    let mut b0b1 = b0.to_vec();
865    for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
866        *b0b1 += b1;
867    }
868    let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
869    for (f01, &f00) in f01.iter_mut().zip(&f00) {
870        *f01 -= f00;
871    }
872    for (f01, &f11) in f01.iter_mut().zip(&f11) {
873        *f01 -= f11;
874    }
875    let mut c = vec![T::zero(); a.len() + b.len() - 1];
876    for (c, &f00) in c.iter_mut().zip(&f00) {
877        *c += f00;
878    }
879    for (c, &f01) in c[m..].iter_mut().zip(&f01) {
880        *c += f01;
881    }
882    for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
883        *c += f11;
884    }
885    c
886}
887
// Convolution over a single NTT-friendly modulus `M`.
//
// `transform` is a zero-padded power-of-two NTT; `inverse_transform`
// applies the inverse NTT and divides by the transform size so the pair
// round-trips.
impl<M> ConvolveSteps for Convolve<M>
where
    M: Montgomery32NttModulus,
{
    // Both the coefficient (time) and transformed (frequency) domains
    // are plain vectors of modular integers.
    type T = Vec<MInt<M>>;
    type F = Vec<MInt<M>>;
    fn length(t: &Self::T) -> usize {
        t.len()
    }
    /// Zero-pads `t` to `len.next_power_of_two()` and applies the NTT in place.
    fn transform(mut t: Self::T, len: usize) -> Self::F {
        t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
        ntt(&mut t);
        t
    }
    /// Inverse NTT, truncated to `len` and rescaled by 1/size, where
    /// `size` is the power of two `transform` would use for the same `len`.
    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
        intt(&mut f);
        f.truncate(len);
        // `intt` leaves each value scaled by the transform size; undo that.
        let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
        for f in f.iter_mut() {
            *f *= inv;
        }
        f
    }
    /// Pointwise product in the transformed domain.
    fn multiply(f: &mut Self::F, g: &Self::F) {
        assert_eq!(f.len(), g.len());
        for (f, g) in f.iter_mut().zip(g.iter()) {
            *f *= *g;
        }
    }
    fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
        // For small inputs the direct methods beat the NTT's constant factor.
        if Self::length(&a).max(Self::length(&b)) <= 100 {
            return convolve_karatsuba(&a, &b);
        }
        if Self::length(&a).min(Self::length(&b)) <= 60 {
            return convolve_naive(&a, &b);
        }
        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
        let size = len.max(1).next_power_of_two();
        // If `len` only barely exceeds `size / 2`, peel the last element off
        // each input and add its cross terms directly: the recursive call then
        // has result length `len - 2 <= size / 2`, i.e. fits a half-size NTT.
        if len <= size / 2 + 2 {
            let xa = a.pop().unwrap();
            let xb = b.pop().unwrap();
            let mut c = vec![MInt::<M>::zero(); len];
            // Top coefficient comes solely from the two peeled elements.
            *c.last_mut().unwrap() = xa * xb;
            // a[i] * xb lands at index i + b.len() (b already popped),
            // and symmetrically for b[i] * xa.
            for (a, c) in a.iter().zip(&mut c[b.len()..]) {
                *c += *a * xb;
            }
            for (b, c) in b.iter().zip(&mut c[a.len()..]) {
                *c += *b * xa;
            }
            // Product of the shortened inputs, accumulated from index 0.
            let d = Self::convolve(a, b);
            for (d, c) in d.into_iter().zip(&mut c) {
                *c += d;
            }
            return c;
        }
        // Squaring needs only one forward transform.
        let same = a == b;
        let mut a = Self::transform(a, len);
        if same {
            for a in a.iter_mut() {
                *a *= *a;
            }
        } else {
            let b = Self::transform(b, len);
            Self::multiply(&mut a, &b);
        }
        Self::inverse_transform(a, len)
    }
}
956
// Shorthand for a vector of modular integers over modulus `M`.
type MVec<M> = Vec<MInt<M>>;
// Convolution over an arbitrary modulus `M`, computed by convolving over
// three NTT-friendly moduli `N1`, `N2`, `N3` and recombining the residues
// with mixed-radix CRT (Garner's algorithm).
impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
where
    M: MIntConvert + MIntConvert<u32>,
    N1: Montgomery32NttModulus,
    N2: Montgomery32NttModulus,
    N3: Montgomery32NttModulus,
{
    type T = MVec<M>;
    // One transformed vector per NTT modulus.
    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
    fn length(t: &Self::T) -> usize {
        t.len()
    }
    /// Reduces every coefficient into each of the three NTT moduli,
    /// zero-pads to a power of two, and transforms all three copies.
    fn transform(t: Self::T, len: usize) -> Self::F {
        let npot = len.max(1).next_power_of_two();
        let mut f = (
            MVec::<N1>::with_capacity(npot),
            MVec::<N2>::with_capacity(npot),
            MVec::<N3>::with_capacity(npot),
        );
        for t in t {
            // `t.inner()` is the canonical residue mod M; re-reduce it
            // into each NTT modulus.
            f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
            f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
            f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
        }
        f.0.resize_with(npot, Zero::zero);
        f.1.resize_with(npot, Zero::zero);
        f.2.resize_with(npot, Zero::zero);
        ntt(&mut f.0);
        ntt(&mut f.1);
        ntt(&mut f.2);
        f
    }
    /// Inverse-transforms each residue vector and reconstructs each
    /// coefficient with Garner's algorithm: x = d1 + d2*m1 + d3*m1*m2
    /// (digits d1 < m1, d2 < m2, d3 < m3), then maps x into `M`.
    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
        // t1 = m1^{-1} (mod m2); t2 = (m1*m2)^{-1} (mod m3).
        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
        let m1 = MInt::<M>::from(N1::get_mod());
        let m1_3 = MInt::<N3>::new(N1::get_mod());
        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
        // m2 here is the combined radix m1*m2 mapped into M.
        let m2 = m1 * MInt::<M>::from(N2::get_mod());
        Convolve::<N1>::inverse_transform(f.0, len)
            .into_iter()
            .zip(Convolve::<N2>::inverse_transform(f.1, len))
            .zip(Convolve::<N3>::inverse_transform(f.2, len))
            .map(|((c1, c2), c3)| {
                // Mixed-radix digits of the true (pre-reduction) coefficient.
                let d1 = c1.inner();
                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
                let d3 = ((c3 - x) * t2).inner();
                MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
            })
            .collect()
    }
    /// Pointwise product, applied independently per NTT modulus.
    fn multiply(f: &mut Self::F, g: &Self::F) {
        assert_eq!(f.0.len(), g.0.len());
        assert_eq!(f.1.len(), g.1.len());
        assert_eq!(f.2.len(), g.2.len());
        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
            *f *= *g;
        }
        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
            *f *= *g;
        }
        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
            *f *= *g;
        }
    }
    fn convolve(a: Self::T, b: Self::T) -> Self::T {
        // Direct methods win below these sizes (threshold is higher than the
        // single-modulus case because three NTTs must be paid for here).
        if Self::length(&a).max(Self::length(&b)) <= 300 {
            return convolve_karatsuba(&a, &b);
        }
        if Self::length(&a).min(Self::length(&b)) <= 60 {
            return convolve_naive(&a, &b);
        }
        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
        let mut a = Self::transform(a, len);
        let b = Self::transform(b, len);
        Self::multiply(&mut a, &b);
        Self::inverse_transform(a, len)
    }
}
1037
// Convolution of raw `u64` sequences via three NTT-friendly moduli and
// mixed-radix CRT, mirroring the `MInt<M>` implementation above but
// recombining into plain `u64` arithmetic.
impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
where
    N1: Montgomery32NttModulus,
    N2: Montgomery32NttModulus,
    N3: Montgomery32NttModulus,
{
    type T = Vec<u64>;
    // One transformed vector per NTT modulus.
    type F = (MVec<N1>, MVec<N2>, MVec<N3>);

    fn length(t: &Self::T) -> usize {
        t.len()
    }

    /// Reduces every `u64` into each NTT modulus, zero-pads to a power of
    /// two, and transforms all three copies.
    fn transform(t: Self::T, len: usize) -> Self::F {
        let npot = len.max(1).next_power_of_two();
        let mut f = (
            MVec::<N1>::with_capacity(npot),
            MVec::<N2>::with_capacity(npot),
            MVec::<N3>::with_capacity(npot),
        );
        for t in t {
            f.0.push(t.into());
            f.1.push(t.into());
            f.2.push(t.into());
        }
        f.0.resize_with(npot, Zero::zero);
        f.1.resize_with(npot, Zero::zero);
        f.2.resize_with(npot, Zero::zero);
        ntt(&mut f.0);
        ntt(&mut f.1);
        ntt(&mut f.2);
        f
    }

    /// Garner reconstruction into `u64`: x = d1 + d2*m1 + d3*m1*m2.
    ///
    /// Each mixed-radix term is at most the true coefficient, so the final
    /// sum is exact whenever the true coefficient fits in `u64`; for larger
    /// coefficients the arithmetic overflows (panic in debug builds).
    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
        // t1 = m1^{-1} (mod m2); t2 = (m1*m2)^{-1} (mod m3).
        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
        let m1 = N1::get_mod() as u64;
        let m1_3 = MInt::<N3>::new(N1::get_mod());
        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
        // Combined radix m1*m2 (product of two ~2^31 moduli fits u64).
        let m2 = m1 * N2::get_mod() as u64;
        Convolve::<N1>::inverse_transform(f.0, len)
            .into_iter()
            .zip(Convolve::<N2>::inverse_transform(f.1, len))
            .zip(Convolve::<N3>::inverse_transform(f.2, len))
            .map(|((c1, c2), c3)| {
                // Mixed-radix digits of the true coefficient.
                let d1 = c1.inner();
                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
                let d3 = ((c3 - x) * t2).inner();
                d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
            })
            .collect()
    }

    /// Pointwise product, applied independently per NTT modulus.
    fn multiply(f: &mut Self::F, g: &Self::F) {
        assert_eq!(f.0.len(), g.0.len());
        assert_eq!(f.1.len(), g.1.len());
        assert_eq!(f.2.len(), g.2.len());
        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
            *f *= *g;
        }
        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
            *f *= *g;
        }
        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
            *f *= *g;
        }
    }

    fn convolve(a: Self::T, b: Self::T) -> Self::T {
        // Direct methods win below these sizes; thresholds match the
        // triple-modulus MInt implementation above.
        if Self::length(&a).max(Self::length(&b)) <= 300 {
            return convolve_karatsuba(&a, &b);
        }
        if Self::length(&a).min(Self::length(&b)) <= 60 {
            return convolve_naive(&a, &b);
        }
        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
        let mut a = Self::transform(a, len);
        let b = Self::transform(b, len);
        Self::multiply(&mut a, &b);
        Self::inverse_transform(a, len)
    }
}
1121
/// Operations that work directly on already-computed transforms
/// (`Self::F`), avoiding redundant forward/inverse NTTs.
pub trait NttReuse: ConvolveSteps {
    /// `true` when `Self::F` bundles transforms over several moduli (the
    /// CRT-based implementations keep this default); the single-modulus
    /// implementation overrides it to `false`.
    const MULTIPLE: bool = true;

    /// F(a) → F(a + [0] * a.len())
    ///
    /// Given the transform of `a`, returns the transform of `a`
    /// zero-extended to twice its length.
    fn ntt_doubling(f: Self::F) -> Self::F;

    /// F(a(x)), F(b(x)) → even(F(a(x) * b(-x)))
    ///
    /// Half-length transform of the even-indexed coefficients of
    /// `a(x) * b(-x)`; both inputs must be transforms of the same
    /// power-of-two length (>= 2).
    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;

    /// F(a(x)), F(b(x)) → odd(F(a(x) * b(-x)))
    ///
    /// Half-length transform of the odd-indexed coefficients of
    /// `a(x) * b(-x)`; same input requirements as above.
    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
}
1134
thread_local!(
    // Per-thread cache of bit-reversal permutation tables: slot `j`
    // (filled lazily) holds the bit-reversal of 0..2^j.
    // NOTE(review): the `UnsafeCell` is only ever borrowed via short-lived
    // `&mut` inside `with` closures below; soundness relies on those
    // closures never re-entering code that touches BIT_REVERSE — confirm
    // before adding new call paths.
    static BIT_REVERSE: UnsafeCell<Vec<Vec<usize>>> = const { UnsafeCell::new(vec![]) };
);
1138
impl<M> NttReuse for Convolve<M>
where
    M: Montgomery32NttModulus,
{
    // Single modulus: `Self::F` is one transform, not a CRT bundle.
    const MULTIPLE: bool = false;

    /// Doubles the transform length: recovers the coefficients, twists
    /// them by powers of `root[k + 1]` (so the second transform evaluates
    /// the polynomial at the points interleaving the original ones), and
    /// appends that second half.
    fn ntt_doubling(mut f: Self::F) -> Self::F {
        let n = f.len();
        let k = n.trailing_zeros() as usize;
        // Round-trip to coefficient space; `f` itself stays the first half.
        let mut a = Self::inverse_transform(f.clone(), n);
        let mut rot = MInt::<M>::one();
        let zeta = MInt::<M>::new_unchecked(M::INFO.root[k + 1]);
        for a in a.iter_mut() {
            *a *= rot;
            rot *= zeta;
        }
        f.extend(Self::transform(a, n));
        f
    }

    /// h[i] = (f[2i] * g[2i+1] + f[2i+1] * g[2i]) / 2.
    ///
    /// NOTE(review): assumes adjacent transform entries are evaluations at
    /// paired points ±w (as produced by `ntt` here), so `g[2i+1]` plays the
    /// role of b(-x) at the point of `f[2i]` — confirm if `ntt`'s output
    /// ordering ever changes.
    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        assert_eq!(f.len(), g.len());
        assert!(f.len().is_power_of_two());
        assert!(f.len() >= 2);
        let inv2 = MInt::<M>::from(2).inv();
        let n = f.len() / 2;
        (0..n)
            .map(|i| (f[i << 1] * g[i << 1 | 1] + f[i << 1 | 1] * g[i << 1]) * inv2)
            .collect()
    }

    /// h[i] = (f[2i] * g[2i+1] - f[2i+1] * g[2i]) * scale, where the scale
    /// starts at 1/2 and absorbs one inverse root per emitted entry.
    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        assert_eq!(f.len(), g.len());
        assert!(f.len().is_power_of_two());
        assert!(f.len() >= 2);
        let mut inv2 = MInt::<M>::from(2).inv();
        let n = f.len() / 2;
        let k = f.len().trailing_zeros() as usize;
        let mut h = vec![MInt::<M>::zero(); n];
        let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
        BIT_REVERSE.with(|br| {
            // SAFETY: thread-local, and the &mut borrow never escapes this
            // closure; no code inside re-enters BIT_REVERSE.
            let br = unsafe { &mut *br.get() };
            if br.len() < k {
                br.resize_with(k, Default::default);
            }
            let k = k - 1;
            // Lazily build the bit-reversal table of 0..2^k using the
            // standard recurrence rev(i) = rev(i >> 1) >> 1 | (i & 1) << (k-1).
            if br[k].is_empty() {
                let mut v = vec![0; 1 << k];
                for i in 0..1 << k {
                    v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
                }
                br[k] = v;
            }
            // Visiting outputs in bit-reversed order makes the running
            // `inv2 *= w` supply each entry i the factor w^rev(i) / 2.
            // NOTE(review): this pairing with `ntt`'s output order is
            // load-bearing — confirm before reordering.
            for &i in &br[k] {
                h[i] = (f[i << 1] * g[i << 1 | 1] - f[i << 1 | 1] * g[i << 1]) * inv2;
                inv2 *= w;
            }
        });
        h
    }
}
1200
// CRT bundle: each NttReuse operation is applied per NTT modulus, with a
// correction term (see below) keeping the three residue computations
// consistent with each other.
impl<M, N1, N2, N3> NttReuse for Convolve<(M, (N1, N2, N3))>
where
    M: MIntConvert + MIntConvert<u32>,
    N1: Montgomery32NttModulus,
    N2: Montgomery32NttModulus,
    N3: Montgomery32NttModulus,
{
    /// Doubles each of the three per-modulus transforms independently.
    fn ntt_doubling(f: Self::F) -> Self::F {
        (
            Convolve::<N1>::ntt_doubling(f.0),
            Convolve::<N2>::ntt_doubling(f.1),
            Convolve::<N3>::ntt_doubling(f.2),
        )
    }

    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Per-modulus version of the single-modulus even extraction, with a
        // correction `u = m * n` added at the pair containing index 0.
        // NOTE(review): `u` appears to account for representing the negated
        // coefficients of b(-x) as the fixed integers `m - b_j` (identical
        // across all three NTT moduli) instead of a per-modulus negation —
        // confirm the derivation before modifying.
        fn even_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            let n = f.len();
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let inv2 = MInt::<M>::from(2).inv();
            // m = target modulus of M; n = full transform length.
            let u = MInt::<M>::new(m) * MInt::<M>::from(n as u32);
            let n = f.len() / 2;
            (0..n)
                .map(|i| {
                    (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        + f[i << 1 | 1] * g[i << 1])
                        * inv2
                })
                .collect()
        }

        // `m` is the modulus of the outer type M, shared by all three calls.
        let m = M::mod_into();
        (
            even_mul_normal_neg_corrected(&f.0, &g.0, m),
            even_mul_normal_neg_corrected(&f.1, &g.1, m),
            even_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }

    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Per-modulus version of the single-modulus odd extraction, with the
        // same index-0 correction `u = m * f.len()` as the even case above.
        fn odd_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let mut inv2 = MInt::<M>::from(2).inv();
            let u = MInt::<M>::new(m) * MInt::<M>::from(f.len() as u32);
            let n = f.len() / 2;
            let k = f.len().trailing_zeros() as usize;
            let mut h = vec![MInt::<M>::zero(); n];
            let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
            BIT_REVERSE.with(|br| {
                // SAFETY: thread-local; the &mut borrow stays inside this
                // non-reentrant closure.
                let br = unsafe { &mut *br.get() };
                if br.len() < k {
                    br.resize_with(k, Default::default);
                }
                let k = k - 1;
                // Lazily build the bit-reversal table of 0..2^k.
                if br[k].is_empty() {
                    let mut v = vec![0; 1 << k];
                    for i in 0..1 << k {
                        v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
                    }
                    br[k] = v;
                }
                // Bit-reversed visit order gives entry i the running scale
                // w^rev(i) / 2 (see the single-modulus impl).
                for &i in &br[k] {
                    h[i] = (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        - f[i << 1 | 1] * g[i << 1])
                        * inv2;
                    inv2 *= w;
                }
            });
            h
        }

        let m = M::mod_into();
        (
            odd_mul_normal_neg_corrected(&f.0, &g.0, m),
            odd_mul_normal_neg_corrected(&f.1, &g.1, m),
            odd_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }
}
1300
#[cfg(test)]
mod tests {
    // Randomized tests comparing every convolution path against a direct
    // O(n*m) reference computation.
    use super::*;
    use crate::num::{mint_basic::Modulo1000000009, montgomery::MInt998244353};
    use crate::tools::Xorshift;

    // Schoolbook fallback agrees with the reference sum (inputs small
    // enough that u32 arithmetic cannot overflow: 60 * 999^2 < 2^32).
    #[test]
    fn test_convolve_naive() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=60);
            let m = rng.random(0..=60);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_naive(&a, &b);
            assert_eq!(c, d);
        }
    }

    // Karatsuba (sizes up to 200 exercise the recursive path) agrees with
    // the reference sum.
    #[test]
    fn test_convolve_karatsuba() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=200);
            let m = rng.random(0..=200);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_karatsuba(&a, &b);
            assert_eq!(c, d);
        }
    }

    // Single-modulus NTT convolution: mixes tiny sizes, mid sizes, exact
    // power-of-two-plus-one sizes (stressing the half-size trick), and the
    // a == b squaring shortcut.
    #[test]
    fn test_ntt998244353() {
        let mut rng = Xorshift::default();
        for t in 0..1000 {
            let n: usize = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=120) } else { n };
            let m: usize = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=120) } else { m };
            let (n, m) = if t % 100 != 0 {
                (n, m)
            } else {
                // Every 100th round: length 2^w + 1 on both sides.
                let w = rng.random(6..=8);
                ((1usize << w) + 1usize, (1usize << w) + 1usize)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let mut b: Vec<MInt998244353> = rng.random_iter(..).take(m).collect();
            if n == m && rng.random(0..2) == 0 {
                b = a.clone();
            }

            let mut c = vec![MInt998244353::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = Convolve998244353::convolve(a, b);
            assert_eq!(c, d);
        }
        // The precomputed table must match a freshly computed one.
        assert_eq!(NttInfo::new::<Modulo998244353>(), Modulo998244353::INFO);
    }

    // Arbitrary-modulus (three-prime CRT) convolution against the reference.
    #[test]
    fn test_convolve3() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let b: Vec<M> = rng.random_iter(..).take(m).collect();
            let mut c = vec![M::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // u64 convolution; inputs < 2^24 keep every true coefficient well
    // inside u64 (400 * 2^48 < 2^57).
    #[test]
    fn test_convolve_u64() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<u64> = rng.random_iter(0u64..1 << 24).take(n).collect();
            let b: Vec<u64> = rng.random_iter(0u64..1 << 24).take(m).collect();
            let mut c = vec![0; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = U64Convolve::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // NttReuse (single modulus): doubling and even/odd extraction are
    // checked against the definitionally computed transforms.
    #[test]
    fn test_ntt_reuse_998244353() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let f = Convolve998244353::transform(a.clone(), n);

            // doubling
            {
                let f_double = Convolve998244353::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = Convolve998244353::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = Convolve998244353::transform(a.clone(), n * 2);
            let b: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let g = Convolve998244353::transform(b.clone(), n * 2);
            // b_neg holds the coefficients of b(-x).
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            // even_mul_normal_neg
            {
                let fg_neg = Convolve998244353::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_even, n);
                assert_eq!(fg_neg, fg);
            }

            // odd_mul_normal_neg
            {
                let fg_neg = Convolve998244353::odd_mul_normal_neg(&f, &g);
                let ab_neg_odd: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .skip(1)
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_odd, n);
                assert_eq!(fg_neg, fg);
            }
        }
    }

    // NttReuse (triple-modulus CRT): same properties, compared after
    // inverse-transforming since the bundled transforms are not unique.
    #[test]
    fn test_ntt_reuse_triple() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n);

            // doubling
            {
                let f_double = MIntConvolve::<Modulo1000000009>::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = MIntConvolve::<Modulo1000000009>::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n * 2);
            let b: Vec<M> = rng.random_iter(..).take(n).collect();
            let g = MIntConvolve::<Modulo1000000009>::transform(b.clone(), n * 2);
            // b_neg holds the coefficients of b(-x).
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            // even_mul_normal_neg
            {
                let fg_neg = MIntConvolve::<Modulo1000000009>::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .step_by(2)
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_even
                );
            }

            // odd_mul_normal_neg
            {
                let fg_neg = MIntConvolve::<Modulo1000000009>::odd_mul_normal_neg(&f, &g);
                // Pad with a trailing zero: the odd part has one fewer
                // coefficient than the half length.
                let ab_neg_odd: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .skip(1)
                        .step_by(2)
                        .chain([M::zero()])
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_odd
                );
            }
        }
    }
}