1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2use std::marker::PhantomData;
3
/// Convolution engine selected by the marker type `M`.
///
/// `M` is either a single NTT-friendly modulus (see the `ConvolveSteps` impl for
/// `Convolve<M>`) or a tuple `(M, (N1, N2, N3))` describing an arbitrary target modulus plus
/// three NTT moduli. `PhantomData<fn() -> M>` carries `M` at the type level without storing a
/// value of it.
pub struct Convolve<M>(PhantomData<fn() -> M>);
/// NTT convolution modulo 998244353.
pub type Convolve998244353 = Convolve<Modulo998244353>;
/// Convolution modulo an arbitrary modulus `M`: the product is computed with NTTs modulo
/// 2013265921, 1811939329 and 2113929217 and recombined with Garner's algorithm.
pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
7
8macro_rules! impl_ntt_modulus {
9 ($([$name:ident, $g:expr]),*) => {
10 $(
11 impl Montgomery32NttModulus for $name {}
12 )*
13 };
14}
15impl_ntt_modulus!(
16 [Modulo998244353, 3],
17 [Modulo2113929217, 5],
18 [Modulo1811939329, 13],
19 [Modulo2013265921, 31]
20);
21
/// Montgomery reduction of `z` modulo `p`.
///
/// Adds the multiple `m * p` of `p`, with `m = r * z mod 2^32`, then shifts right by 32 bits and
/// applies one conditional subtraction. When `r * p ≡ -1 (mod 2^32)`, the sum is divisible by
/// `2^32` and the result is `z * 2^-32 mod p`. It lands in `[0, p)` for `z < 2^32 * p`.
/// All moduli in this file are below `2^31`, which keeps the `u64` sum from overflowing.
const fn reduce(z: u64, p: u32, r: u32) -> u32 {
    let m = r.wrapping_mul(z as u32) as u64;
    let t = ((z + m * p as u64) >> 32) as u32;
    if t < p {
        t
    } else {
        t - p
    }
}
29const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
30 reduce(x as u64 * y as u64, p, r)
31}
32const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
33 while y > 0 {
34 if y & 1 == 1 {
35 z = mod_mul(z, x, p, r);
36 }
37 x = mod_mul(x, x, p, r);
38 y >>= 1;
39 }
40 z
41}
42
43pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
44 const PRIMITIVE_ROOT: u32 = {
45 let mut g = 3u32;
46 loop {
47 let mut ok = true;
48 let mut d = 1u32;
49 while d * d < Self::MOD {
50 if (Self::MOD - 1) % d == 0 {
51 let ds = [d, (Self::MOD - 1) / d];
52 let mut i = 0;
53 while i < 2 {
54 ok &= ds[i] == Self::MOD - 1
55 || mod_pow(
56 reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
57 ds[i],
58 Self::MOD,
59 Self::R,
60 Self::N1,
61 ) != Self::N1;
62 i += 1;
63 }
64 }
65 d += 1;
66 }
67 if ok {
68 break;
69 }
70 g += 2;
71 }
72 g
73 };
74 const RANK: u32 = (Self::MOD - 1).trailing_zeros();
75 const INFO: NttInfo = NttInfo::new::<Self>();
76}
77
/// Per-modulus twiddle tables consumed by `ntt` / `intt`, built by `NttInfo::new`.
///
/// Values are kept in the same representation `MInt<M>` uses internally (the transforms wrap
/// them with `MInt::new_unchecked`). Entries beyond the ones filled for the modulus's `RANK`
/// remain zero.
pub struct NttInfo {
    /// For `i <= RANK`, `root[i]` has multiplicative order `2^i`. `root[RANK]` is derived from
    /// the primitive root, and each lower entry is the square of the next one.
    root: [u32; 32],
    /// `inv_root[i]` is the inverse of `root[i]`.
    inv_root: [u32; 32],
    /// Twiddle step factors for the radix-2 pass of `ntt`, indexed by `s.trailing_ones()` of
    /// the block index `s`.
    rate2: [u32; 32],
    /// Inverse counterparts of `rate2`, used by the radix-2 pass of `intt`.
    inv_rate2: [u32; 32],
    /// Twiddle step factors for the radix-4 passes of `ntt`, indexed like `rate2`.
    rate3: [u32; 32],
    /// Inverse counterparts of `rate3`, used by the radix-4 passes of `intt`.
    inv_rate3: [u32; 32],
}
86impl NttInfo {
87 const fn new<M>() -> Self
88 where
89 M: Montgomery32NttModulus,
90 {
91 let mut root = [0; 32];
92 let mut inv_root = [0; 32];
93 let mut rate2 = [0; 32];
94 let mut inv_rate2 = [0; 32];
95 let mut rate3 = [0; 32];
96 let mut inv_rate3 = [0; 32];
97 let rank = M::RANK as usize;
98
99 let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
100 root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
101 inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
102 let mut i = rank - 1;
103 loop {
104 root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
105 inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
106 if i == 0 {
107 break;
108 }
109 i -= 1;
110 }
111
112 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
113 while i < rank - 1 {
114 rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
115 inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
116 prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
117 inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
118 i += 1;
119 }
120
121 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
122 while i < rank - 2 {
123 rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
124 inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
125 prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
126 inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
127 i += 1;
128 }
129
130 NttInfo {
131 root,
132 inv_root,
133 rate2,
134 inv_rate2,
135 rate3,
136 inv_rate3,
137 }
138 }
139}
140
// Forward NTT of `a`, in place.
//
// The length is assumed to be a power of two; callers in this file pad to
// `len.max(2).next_power_of_two()`. The transform works top-down: radix-4 butterflies on blocks
// of size `2 * v` while `v > 1`, then, when `log2(n)` is odd, one radix-2 pass over adjacent
// pairs. No bit-reversal permutation is performed, so the output is in the permuted order that
// `intt` consumes. Pointwise products of two transforms are meaningful; individual indices are
// not natural-order evaluations.
// NOTE(review): `avx_helper!` presumably emits an AVX2-enabled variant plus a fallback — confirm.
crate::avx_helper!(
    @avx2 fn ntt<M>(a: &mut [MInt<M>])
    where
        [M: Montgomery32NttModulus]
    {
        let n = a.len();
        let mut v = n / 2;
        // root[2] has order 4 (see NttInfo::new): the "imaginary unit" of the radix-4 butterfly.
        let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
        while v > 1 {
            // Twiddle for the current block; starts at one and is stepped once per block.
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
                // Split the block into four quarters of length v / 2.
                let (l, r) = a.split_at_mut(v);
                let (ll, lr) = l.split_at_mut(v >> 1);
                let (rl, rr) = r.split_at_mut(v >> 1);
                let w2 = w1 * w1;
                let w3 = w1 * w2;
                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                    let a0 = *x0;
                    let a1 = *x1 * w1;
                    let a2 = *x2 * w2;
                    let a3 = *x3 * w3;
                    let a0pa2 = a0 + a2;
                    let a0na2 = a0 - a2;
                    let a1pa3 = a1 + a3;
                    let a1na3imag = (a1 - a3) * imag;
                    *x0 = a0pa2 + a1pa3;
                    *x1 = a0pa2 - a1pa3;
                    *x2 = a0na2 + a1na3imag;
                    *x3 = a0na2 - a1na3imag;
                }
                // Advance to the next block's twiddle via the precomputed rate table.
                w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
            }
            v >>= 2;
        }
        // One level is left over when log2(n) is odd: a radix-2 pass on adjacent pairs.
        if v == 1 {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: `chunks_exact_mut(2)` yields chunks of length exactly 2, so after
                // `split_at_mut(1)` each half has exactly one element and index 0 is in bounds.
                unsafe {
                    let (l, r) = a.split_at_mut(1);
                    let x0 = l.get_unchecked_mut(0);
                    let x1 = r.get_unchecked_mut(0);
                    let a0 = *x0;
                    let a1 = *x1 * w1;
                    *x0 = a0 + a1;
                    *x1 = a0 - a1;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
            }
        }
    }
);
// Inverse NTT of `a`, in place. It undoes `ntt` up to a factor of `n`; it does NOT divide by
// `n` (callers such as `inverse_transform` apply that scaling).
//
// Mirrors `ntt` in reverse order. When `log2(n)` is odd, a radix-2 pass over adjacent pairs
// runs first. Radix-4 butterflies on blocks of size `4 * v` follow, with `v` growing by a factor
// of 4 per level. Twiddles come from the `inv_*` tables.
crate::avx_helper!(
    @avx2 fn intt<M>(a: &mut [MInt<M>])
    where
        [M: Montgomery32NttModulus]
    {
        let n = a.len();
        let mut v = 1;
        // Odd number of levels: undo ntt's final radix-2 pass first.
        if n.trailing_zeros() & 1 == 1 {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: `chunks_exact_mut(2)` yields chunks of length exactly 2, so after
                // `split_at_mut(1)` each half has exactly one element and index 0 is in bounds.
                unsafe {
                    let (l, r) = a.split_at_mut(1);
                    let x0 = l.get_unchecked_mut(0);
                    let x1 = r.get_unchecked_mut(0);
                    let a0 = *x0;
                    let a1 = *x1;
                    *x0 = a0 + a1;
                    *x1 = (a0 - a1) * w1;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
            }
            v <<= 1;
        }
        // Inverse of the order-4 root used by ntt's radix-4 butterfly.
        let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
        while v < n {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
                // Split the block into four quarters of length v.
                let (l, r) = a.split_at_mut(v << 1);
                let (ll, lr) = l.split_at_mut(v);
                let (rl, rr) = r.split_at_mut(v);
                let w2 = w1 * w1;
                let w3 = w1 * w2;
                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                    let a0 = *x0;
                    let a1 = *x1;
                    let a2 = *x2;
                    let a3 = *x3;
                    let a0pa1 = a0 + a1;
                    let a0na1 = a0 - a1;
                    let a2pa3 = a2 + a3;
                    let a2na3iimag = (a2 - a3) * iimag;
                    *x0 = a0pa1 + a2pa3;
                    *x1 = (a0na1 + a2na3iimag) * w1;
                    *x2 = (a0pa1 - a2pa3) * w2;
                    *x3 = (a0na1 - a2na3iimag) * w3;
                }
                // Advance to the next block's inverse twiddle.
                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
            }
            v <<= 2;
        }
    }
);
244
245fn convolve_naive<M>(a: &[MInt<M>], b: &[MInt<M>]) -> Vec<MInt<M>>
246where
247 M: MIntBase,
248{
249 if a.is_empty() && b.is_empty() {
250 return Vec::new();
251 }
252 let len = a.len() + b.len() - 1;
253 let mut c = vec![MInt::<M>::zero(); len];
254 if a.len() < b.len() {
255 for (i, &b) in b.iter().enumerate() {
256 for (a, c) in a.iter().zip(&mut c[i..]) {
257 *c += *a * b;
258 }
259 }
260 } else {
261 for (i, &a) in a.iter().enumerate() {
262 for (b, c) in b.iter().zip(&mut c[i..]) {
263 *c += *b * a;
264 }
265 }
266 }
267 c
268}
269impl<M> ConvolveSteps for Convolve<M>
270where
271 M: Montgomery32NttModulus,
272{
273 type T = Vec<MInt<M>>;
274 type F = Vec<MInt<M>>;
275 fn length(t: &Self::T) -> usize {
276 t.len()
277 }
278 fn transform(mut t: Self::T, len: usize) -> Self::F {
279 t.resize_with(len.max(2).next_power_of_two(), Zero::zero);
280 ntt(&mut t);
281 t
282 }
283 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
284 intt(&mut f);
285 f.truncate(len);
286 let inv = MInt::from(len.max(2).next_power_of_two() as u32).inv();
287 for f in f.iter_mut() {
288 *f *= inv;
289 }
290 f
291 }
292 fn multiply(f: &mut Self::F, g: &Self::F) {
293 assert_eq!(f.len(), g.len());
294 for (f, g) in f.iter_mut().zip(g.iter()) {
295 *f *= *g;
296 }
297 }
298 fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
299 if Self::length(&a).min(Self::length(&b)) <= 60 {
300 return convolve_naive(&a, &b);
301 }
302 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
303 let size = len.max(2).next_power_of_two();
304 if len <= size / 2 + 2 {
305 let xa = a.pop().unwrap();
306 let xb = b.pop().unwrap();
307 let mut c = vec![MInt::<M>::zero(); len];
308 *c.last_mut().unwrap() = xa * xb;
309 for (a, c) in a.iter().zip(&mut c[b.len()..]) {
310 *c += *a * xb;
311 }
312 for (b, c) in b.iter().zip(&mut c[a.len()..]) {
313 *c += *b * xa;
314 }
315 let d = Self::convolve(a, b);
316 for (d, c) in d.into_iter().zip(&mut c) {
317 *c += d;
318 }
319 return c;
320 }
321 let same = a == b;
322 let mut a = Self::transform(a, len);
323 if same {
324 for a in a.iter_mut() {
325 *a *= *a;
326 }
327 } else {
328 let b = Self::transform(b, len);
329 Self::multiply(&mut a, &b);
330 }
331 Self::inverse_transform(a, len)
332 }
333}
/// Shorthand for a vector of `MInt<M>` values.
type MVec<M> = Vec<MInt<M>>;
335impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
336where
337 M: MIntConvert + MIntConvert<u32>,
338 N1: Montgomery32NttModulus,
339 N2: Montgomery32NttModulus,
340 N3: Montgomery32NttModulus,
341{
342 type T = MVec<M>;
343 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
344 fn length(t: &Self::T) -> usize {
345 t.len()
346 }
347 fn transform(t: Self::T, len: usize) -> Self::F {
348 let npot = len.max(2).next_power_of_two();
349 let mut f = (
350 MVec::<N1>::with_capacity(npot),
351 MVec::<N2>::with_capacity(npot),
352 MVec::<N3>::with_capacity(npot),
353 );
354 for t in t {
355 f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
356 f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
357 f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
358 }
359 f.0.resize_with(npot, Zero::zero);
360 f.1.resize_with(npot, Zero::zero);
361 f.2.resize_with(npot, Zero::zero);
362 ntt(&mut f.0);
363 ntt(&mut f.1);
364 ntt(&mut f.2);
365 f
366 }
367 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
368 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
369 let m1 = MInt::<M>::from(N1::get_mod());
370 let m1_3 = MInt::<N3>::new(N1::get_mod());
371 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
372 let m2 = m1 * MInt::<M>::from(N2::get_mod());
373 Convolve::<N1>::inverse_transform(f.0, len)
374 .into_iter()
375 .zip(Convolve::<N2>::inverse_transform(f.1, len))
376 .zip(Convolve::<N3>::inverse_transform(f.2, len))
377 .map(|((c1, c2), c3)| {
378 let d1 = c1.inner();
379 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
380 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
381 let d3 = ((c3 - x) * t2).inner();
382 MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
383 })
384 .collect()
385 }
386 fn multiply(f: &mut Self::F, g: &Self::F) {
387 assert_eq!(f.0.len(), g.0.len());
388 assert_eq!(f.1.len(), g.1.len());
389 assert_eq!(f.2.len(), g.2.len());
390 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
391 *f *= *g;
392 }
393 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
394 *f *= *g;
395 }
396 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
397 *f *= *g;
398 }
399 }
400 fn convolve(a: Self::T, b: Self::T) -> Self::T {
401 if Self::length(&a).min(Self::length(&b)) <= 60 {
402 return convolve_naive(&a, &b);
403 }
404 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
405 let mut a = Self::transform(a, len);
406 let b = Self::transform(b, len);
407 Self::multiply(&mut a, &b);
408 Self::inverse_transform(a, len)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::{
        mint_basic::Modulo1000000009,
        montgomery::{MInt998244353, Modulo998244353},
    };
    use crate::tools::Xorshift;

    /// Input lengths `(|a|, |b|)` chosen to reach every path of `convolve`:
    /// - `(8, 8)`: at or below the 60-element threshold, handled by `convolve_naive`;
    /// - `(65, 65)`: `len = 129` just exceeds 128, taking the peel-off path;
    /// - `(61, 100)`, `(100, 100)`, `(200, 130)`: the plain NTT path with padding.
    const CASES: [(usize, usize); 5] = [(8, 8), (65, 65), (61, 100), (100, 100), (200, 130)];

    #[test]
    fn test_ntt998244353() {
        let mut rng = Xorshift::new();
        for &(n, m) in CASES.iter() {
            let a: Vec<_> = rng
                .random_iter(..MInt998244353::get_mod())
                .map(MInt998244353::new_unchecked)
                .take(n)
                .collect();
            let b: Vec<_> = rng
                .random_iter(..MInt998244353::get_mod())
                .map(MInt998244353::new_unchecked)
                .take(m)
                .collect();
            let mut c = vec![MInt998244353::zero(); n + m - 1];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            // Squaring takes the `a == b` shortcut in `convolve`.
            let mut sq = vec![MInt998244353::zero(); 2 * n - 1];
            for i in 0..n {
                for j in 0..n {
                    sq[i + j] += a[i] * a[j];
                }
            }
            let d = Convolve::<Modulo998244353>::convolve(a.clone(), b);
            assert_eq!(c, d);
            let s = Convolve::<Modulo998244353>::convolve(a.clone(), a);
            assert_eq!(sq, s);
        }
    }

    #[test]
    fn test_convolve3() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::new();
        for &(n, m) in CASES.iter() {
            let a: Vec<_> = rng
                .random_iter(..M::get_mod())
                .map(M::new_unchecked)
                .take(n)
                .collect();
            let b: Vec<_> = rng
                .random_iter(..M::get_mod())
                .map(M::new_unchecked)
                .take(m)
                .collect();
            let mut c = vec![M::zero(); n + m - 1];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // Manual helper, not a `#[test]` (hence `dead_code` is allowed). It searches for primes
    // p = a * 2^b + 1 with b in 22..32 and 2^29 <= p <= 2^31, and prints each with its smallest
    // odd primitive root g >= 3. Invoke it by hand (e.g. temporarily mark it `#[test]`) when
    // choosing new NTT moduli.
    #[allow(dead_code)]
    fn find_proth() {
        use crate::math::{divisors, prime_factors_flatten};
        use crate::num::mint_basic::DynMIntU32;
        for b in 22..32 {
            for a in (1..1u64 << b).step_by(2) {
                let p = a * (1u64 << b) + 1;
                if 1 << 31 < p {
                    break;
                }
                if p < 1 << 29 {
                    continue;
                }
                let f = prime_factors_flatten(p);
                if f.len() == 1 && f[0] == p {
                    DynMIntU32::set_mod(p as u32);
                    for g in (3..).step_by(2) {
                        let g = DynMIntU32::new(g);
                        if divisors(p - 1)
                            .into_iter()
                            .filter(|&d| d != p - 1)
                            .all(|d| g.pow(d as usize) != DynMIntU32::one())
                        {
                            println!("(p,a,b,g) = {:?}", (p, a, b, g));
                            break;
                        }
                    }
                }
            }
        }
    }
}