1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2use std::{
3 cell::UnsafeCell,
4 marker::PhantomData,
5 ops::{AddAssign, Mul, SubAssign},
6};
7
/// Zero-sized marker that selects an NTT-based convolution strategy via `M`.
///
/// `M` is either a single NTT-friendly modulus (`Convolve<M>`), a pair
/// `(M, (N1, N2, N3))` for convolving `MInt<M>` values through three NTT
/// primes, or `(u64, (N1, N2, N3))` for plain `u64` coefficients.
/// `PhantomData<fn() -> M>` records the type parameter without owning an `M`.
pub struct Convolve<M>(PhantomData<fn() -> M>);
/// Convolution modulo 998244353 using a single NTT per operand.
pub type Convolve998244353 = Convolve<Modulo998244353>;
/// Convolution of `MInt<M>` sequences, computed modulo three NTT primes and
/// recombined with Garner's algorithm.
pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
/// Convolution of `u64` sequences, computed modulo three NTT primes and
/// recombined with Garner's algorithm.
pub type U64Convolve = Convolve<(u64, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
12
// Marks each listed Montgomery modulus as NTT-capable by implementing
// `Montgomery32NttModulus` with all of its default associated constants.
//
// The second field of every entry (`$g`) is NOT used by the expansion: the
// primitive root is computed at compile time by
// `Montgomery32NttModulus::PRIMITIVE_ROOT`. The listed values presumably record
// the known root of each modulus for reference only.
macro_rules! impl_ntt_modulus {
    ($([$name:ident, $g:expr]),*) => {
        $(
            impl Montgomery32NttModulus for $name {}
        )*
    };
}
impl_ntt_modulus!(
    [Modulo998244353, 3],
    [Modulo2113929217, 5],
    [Modulo1811939329, 13],
    [Modulo2013265921, 31]
);
26
/// Montgomery reduction with `R = 2^32`: returns `z * R^{-1} mod p` in `[0, p)`.
///
/// Assumes `r * p ≡ -1 (mod 2^32)` and `z < p * 2^32`. Under those
/// assumptions, `z + m * p` is an exact multiple of `2^32`, and the shifted
/// value is below `2p`, so a single conditional subtraction normalises it.
const fn reduce(z: u64, p: u32, r: u32) -> u32 {
    // m ≡ -z * p^{-1} (mod 2^32), so the low 32 bits of z + m * p vanish.
    let m = r.wrapping_mul(z as u32) as u64;
    let t = ((z + m * p as u64) >> 32) as u32;
    if t < p { t } else { t - p }
}

/// Montgomery product: `x * y * R^{-1} mod p` (inputs in Montgomery form give
/// the Montgomery form of their product).
const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
    let wide = (x as u64) * (y as u64);
    reduce(wide, p, r)
}

/// Square-and-multiply exponentiation in Montgomery form.
///
/// With `x` and `z` in Montgomery form, returns the Montgomery form of
/// `z * x^y`; passing the Montgomery form of 1 as `z` yields `x^y`.
const fn mod_pow(x: u32, y: u32, p: u32, r: u32, z: u32) -> u32 {
    let (mut base, mut exp, mut acc) = (x, y, z);
    while exp != 0 {
        if exp & 1 != 0 {
            acc = mod_mul(acc, base, p, r);
        }
        base = mod_mul(base, base, p, r);
        exp >>= 1;
    }
    acc
}
47
/// A 32-bit Montgomery modulus that supports number-theoretic transforms.
///
/// Every associated constant has a default that is evaluated at compile time,
/// so an empty `impl` (as generated by `impl_ntt_modulus!`) is sufficient.
/// NOTE(review): the defaults rely on `MontgomeryReduction32` providing `MOD`,
/// `R`, `N1` and `N2`; they are used here as if `N1` is the Montgomery form of 1
/// and `reduce(x * N2)` converts `x` into Montgomery form. Confirm against that trait.
pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
    /// Smallest odd `g >= 3` that is a primitive root modulo `MOD` (even
    /// candidates, including 2, are never tried).
    ///
    /// `g` is accepted when `g^d != 1` for every proper divisor `d` of
    /// `MOD - 1`. Divisors are enumerated in pairs `(d, (MOD - 1) / d)` for
    /// `d * d < MOD`, and the trivial divisor `MOD - 1` is skipped.
    const PRIMITIVE_ROOT: u32 = {
        let mut g = 3u32;
        loop {
            let mut ok = true;
            let mut d = 1u32;
            while d * d < Self::MOD {
                if (Self::MOD - 1) % d == 0 {
                    let ds = [d, (Self::MOD - 1) / d];
                    let mut i = 0;
                    while i < 2 {
                        // Compare g^d (in Montgomery form) against Montgomery 1.
                        ok &= ds[i] == Self::MOD - 1
                            || mod_pow(
                                reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
                                ds[i],
                                Self::MOD,
                                Self::R,
                                Self::N1,
                            ) != Self::N1;
                        i += 1;
                    }
                }
                d += 1;
            }
            if ok {
                break;
            }
            g += 2;
        }
        g
    };
    /// Exponent of 2 in `MOD - 1`. Roots of unity of order up to `2^RANK`
    /// exist, which bounds the supported transform length.
    const RANK: u32 = (Self::MOD - 1).trailing_zeros();
    /// Root-of-unity and twiddle tables used by `ntt` / `intt`.
    const INFO: NttInfo = NttInfo::new::<Self>();
}
82
/// Precomputed roots of unity and twiddle increments for `ntt` / `intt`.
///
/// All values are stored in Montgomery form. Entries beyond what the modulus'
/// `RANK` allows are left as 0.
#[derive(Debug, PartialEq)]
pub struct NttInfo {
    /// `root[i]` is a primitive `2^i`-th root of unity (for `i <= RANK`).
    root: [u32; 32],
    /// `inv_root[i]` is the inverse of `root[i]`.
    inv_root: [u32; 32],
    /// `rate2[i] = root[i + 2] * inv_root[2] * ... * inv_root[i + 1]`; the
    /// multiplier that advances the radix-2 twiddle between blocks in `ntt`.
    rate2: [u32; 32],
    /// Inverse of `rate2[i]`, used by `intt`.
    inv_rate2: [u32; 32],
    /// `rate3[i] = root[i + 3] * inv_root[3] * ... * inv_root[i + 2]`; the
    /// multiplier that advances the radix-4 twiddle between blocks in `ntt`.
    rate3: [u32; 32],
    /// Inverse of `rate3[i]`, used by `intt`.
    inv_rate3: [u32; 32],
}
impl NttInfo {
    /// Builds the tables for modulus `M` (evaluated at compile time via `M::INFO`).
    ///
    /// Requires `M::RANK >= 2`; otherwise `rank - 1` / `rank - 2` underflow and
    /// const evaluation fails.
    const fn new<M>() -> Self
    where
        M: Montgomery32NttModulus,
    {
        let mut root = [0; 32];
        let mut inv_root = [0; 32];
        let mut rate2 = [0; 32];
        let mut inv_rate2 = [0; 32];
        let mut rate3 = [0; 32];
        let mut inv_rate3 = [0; 32];
        let rank = M::RANK as usize;

        // g^((MOD - 1) / 2^rank) has order exactly 2^rank; its inverse comes from Fermat.
        let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
        root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
        inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
        // Squaring a primitive 2^(i+1)-th root gives a primitive 2^i-th root.
        let mut i = rank - 1;
        loop {
            root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
            inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
            if i == 0 {
                break;
            }
            i -= 1;
        }

        // Radix-2 step multipliers: prod / inv_prod hold running products of
        // inv_root[j + 2] / root[j + 2] for j < i.
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 1 {
            rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
            inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
            i += 1;
        }

        // Radix-4 step multipliers, built the same way from root[i + 3].
        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
        while i < rank - 2 {
            rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
            inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
            prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
            inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
            i += 1;
        }

        NttInfo {
            root,
            inv_root,
            rate2,
            inv_rate2,
            rate3,
            inv_rate3,
        }
    }
}
146
// Forward NTT of `a` in place, with no scaling.
//
// Assumes `a.len()` is a power of two (`transform` pads to one). With any other
// length, `chunks_exact_mut` would silently skip the tail. The transform runs
// radix-4 passes over blocks of `2 * v` elements (`v` starts at n / 2 and is
// divided by 4 per pass), then one radix-2 pass when log2(n) is odd. No
// bit-reversal permutation is applied, so the output stays in the order this
// butterfly produces; `intt` takes its input in that same order. Between blocks
// the twiddle `w1` is advanced by a `rate3` / `rate2` entry indexed by the
// number of trailing ones of the block index. This appears to follow the scheme
// of AtCoder Library's `butterfly`.
crate::avx_helper!(
    @avx2 fn ntt<M>(a: &mut [MInt<M>])
    where
        [M: Montgomery32NttModulus]
    {
        let n = a.len();
        let mut v = n / 2;
        // Primitive 4th root of unity, used by the radix-4 butterfly.
        let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
        while v > 1 {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
                // Split the block into four quarters of length v / 2.
                let (l, r) = a.split_at_mut(v);
                let (ll, lr) = l.split_at_mut(v >> 1);
                let (rl, rr) = r.split_at_mut(v >> 1);
                let w2 = w1 * w1;
                let w3 = w1 * w2;
                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                    let a0 = *x0;
                    let a1 = *x1 * w1;
                    let a2 = *x2 * w2;
                    let a3 = *x3 * w3;
                    let a0pa2 = a0 + a2;
                    let a0na2 = a0 - a2;
                    let a1pa3 = a1 + a3;
                    let a1na3imag = (a1 - a3) * imag;
                    *x0 = a0pa2 + a1pa3;
                    *x1 = a0pa2 - a1pa3;
                    *x2 = a0na2 + a1na3imag;
                    *x3 = a0na2 - a1na3imag;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
            }
            v >>= 2;
        }
        // Remaining radix-2 pass, reached only when log2(n) is odd.
        if v == 1 {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: `chunks_exact_mut(2)` yields chunks of exactly two
                // elements, so both halves of `split_at_mut(1)` have length 1
                // and index 0 is in bounds for each.
                unsafe {
                    let (l, r) = a.split_at_mut(1);
                    let x0 = l.get_unchecked_mut(0);
                    let x1 = r.get_unchecked_mut(0);
                    let a0 = *x0;
                    let a1 = *x1 * w1;
                    *x0 = a0 + a1;
                    *x1 = a0 - a1;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
            }
        }
    }
);
// Inverse NTT of `a` in place. It takes input in the order produced by `ntt`
// and undoes it up to a factor of n: the result is NOT divided by n, because
// `inverse_transform` applies that factor.
//
// Assumes `a.len()` is a power of two. The passes mirror `ntt` in reverse:
// first a radix-2 pass when log2(n) is odd, then radix-4 passes over blocks of
// `4 * v` elements with `v` growing by a factor of 4. The twiddles use the
// `inv_rate2` / `inv_rate3` tables.
crate::avx_helper!(
    @avx2 fn intt<M>(a: &mut [MInt<M>])
    where
        [M: Montgomery32NttModulus]
    {
        let n = a.len();
        let mut v = 1;
        if n.trailing_zeros() & 1 == 1 {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(2).enumerate() {
                // SAFETY: `chunks_exact_mut(2)` yields chunks of exactly two
                // elements, so both halves of `split_at_mut(1)` have length 1
                // and index 0 is in bounds for each.
                unsafe {
                    let (l, r) = a.split_at_mut(1);
                    let x0 = l.get_unchecked_mut(0);
                    let x1 = r.get_unchecked_mut(0);
                    let a0 = *x0;
                    let a1 = *x1;
                    *x0 = a0 + a1;
                    *x1 = (a0 - a1) * w1;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
            }
            v <<= 1;
        }
        // Inverse of the primitive 4th root used by `ntt`.
        let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
        while v < n {
            let mut w1 = MInt::<M>::one();
            for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
                // Split the block into four quarters of length v.
                let (l, r) = a.split_at_mut(v << 1);
                let (ll, lr) = l.split_at_mut(v);
                let (rl, rr) = r.split_at_mut(v);
                let w2 = w1 * w1;
                let w3 = w1 * w2;
                for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
                    let a0 = *x0;
                    let a1 = *x1;
                    let a2 = *x2;
                    let a3 = *x3;
                    let a0pa1 = a0 + a1;
                    let a0na1 = a0 - a1;
                    let a2pa3 = a2 + a3;
                    let a2na3iimag = (a2 - a3) * iimag;
                    *x0 = a0pa1 + a2pa3;
                    *x1 = (a0na1 + a2na3iimag) * w1;
                    *x2 = (a0pa1 - a2pa3) * w2;
                    *x3 = (a0na1 - a2na3iimag) * w3;
                }
                w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
            }
            v <<= 2;
        }
    }
);
250
251fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
252where
253 T: Copy + Zero + AddAssign<T> + Mul<Output = T>,
254{
255 if a.is_empty() && b.is_empty() {
256 return Vec::new();
257 }
258 let len = a.len() + b.len() - 1;
259 let mut c = vec![T::zero(); len];
260 if a.len() < b.len() {
261 for (i, &b) in b.iter().enumerate() {
262 for (a, c) in a.iter().zip(&mut c[i..]) {
263 *c += *a * b;
264 }
265 }
266 } else {
267 for (i, &a) in a.iter().enumerate() {
268 for (b, c) in b.iter().zip(&mut c[i..]) {
269 *c += *b * a;
270 }
271 }
272 }
273 c
274}
275
276fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
277where
278 T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
279{
280 if a.len().min(b.len()) <= 30 {
281 return convolve_naive(a, b);
282 }
283 let m = a.len().max(b.len()).div_ceil(2);
284 let (a0, a1) = if a.len() <= m {
285 (a, &[][..])
286 } else {
287 a.split_at(m)
288 };
289 let (b0, b1) = if b.len() <= m {
290 (b, &[][..])
291 } else {
292 b.split_at(m)
293 };
294 let f00 = convolve_karatsuba(a0, b0);
295 let f11 = convolve_karatsuba(a1, b1);
296 let mut a0a1 = a0.to_vec();
297 for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
298 *a0a1 += a1;
299 }
300 let mut b0b1 = b0.to_vec();
301 for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
302 *b0b1 += b1;
303 }
304 let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
305 for (f01, &f00) in f01.iter_mut().zip(&f00) {
306 *f01 -= f00;
307 }
308 for (f01, &f11) in f01.iter_mut().zip(&f11) {
309 *f01 -= f11;
310 }
311 let mut c = vec![T::zero(); a.len() + b.len() - 1];
312 for (c, &f00) in c.iter_mut().zip(&f00) {
313 *c += f00;
314 }
315 for (c, &f01) in c[m..].iter_mut().zip(&f01) {
316 *c += f01;
317 }
318 for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
319 *c += f11;
320 }
321 c
322}
323
324impl<M> ConvolveSteps for Convolve<M>
325where
326 M: Montgomery32NttModulus,
327{
328 type T = Vec<MInt<M>>;
329 type F = Vec<MInt<M>>;
330 fn length(t: &Self::T) -> usize {
331 t.len()
332 }
333 fn transform(mut t: Self::T, len: usize) -> Self::F {
334 t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
335 ntt(&mut t);
336 t
337 }
338 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
339 intt(&mut f);
340 f.truncate(len);
341 let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
342 for f in f.iter_mut() {
343 *f *= inv;
344 }
345 f
346 }
347 fn multiply(f: &mut Self::F, g: &Self::F) {
348 assert_eq!(f.len(), g.len());
349 for (f, g) in f.iter_mut().zip(g.iter()) {
350 *f *= *g;
351 }
352 }
353 fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
354 if Self::length(&a).max(Self::length(&b)) <= 100 {
355 return convolve_karatsuba(&a, &b);
356 }
357 if Self::length(&a).min(Self::length(&b)) <= 60 {
358 return convolve_naive(&a, &b);
359 }
360 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
361 let size = len.max(1).next_power_of_two();
362 if len <= size / 2 + 2 {
363 let xa = a.pop().unwrap();
364 let xb = b.pop().unwrap();
365 let mut c = vec![MInt::<M>::zero(); len];
366 *c.last_mut().unwrap() = xa * xb;
367 for (a, c) in a.iter().zip(&mut c[b.len()..]) {
368 *c += *a * xb;
369 }
370 for (b, c) in b.iter().zip(&mut c[a.len()..]) {
371 *c += *b * xa;
372 }
373 let d = Self::convolve(a, b);
374 for (d, c) in d.into_iter().zip(&mut c) {
375 *c += d;
376 }
377 return c;
378 }
379 let same = a == b;
380 let mut a = Self::transform(a, len);
381 if same {
382 for a in a.iter_mut() {
383 *a *= *a;
384 }
385 } else {
386 let b = Self::transform(b, len);
387 Self::multiply(&mut a, &b);
388 }
389 Self::inverse_transform(a, len)
390 }
391}
392
393type MVec<M> = Vec<MInt<M>>;
394impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
395where
396 M: MIntConvert + MIntConvert<u32>,
397 N1: Montgomery32NttModulus,
398 N2: Montgomery32NttModulus,
399 N3: Montgomery32NttModulus,
400{
401 type T = MVec<M>;
402 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
403 fn length(t: &Self::T) -> usize {
404 t.len()
405 }
406 fn transform(t: Self::T, len: usize) -> Self::F {
407 let npot = len.max(1).next_power_of_two();
408 let mut f = (
409 MVec::<N1>::with_capacity(npot),
410 MVec::<N2>::with_capacity(npot),
411 MVec::<N3>::with_capacity(npot),
412 );
413 for t in t {
414 f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
415 f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
416 f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
417 }
418 f.0.resize_with(npot, Zero::zero);
419 f.1.resize_with(npot, Zero::zero);
420 f.2.resize_with(npot, Zero::zero);
421 ntt(&mut f.0);
422 ntt(&mut f.1);
423 ntt(&mut f.2);
424 f
425 }
426 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
427 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
428 let m1 = MInt::<M>::from(N1::get_mod());
429 let m1_3 = MInt::<N3>::new(N1::get_mod());
430 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
431 let m2 = m1 * MInt::<M>::from(N2::get_mod());
432 Convolve::<N1>::inverse_transform(f.0, len)
433 .into_iter()
434 .zip(Convolve::<N2>::inverse_transform(f.1, len))
435 .zip(Convolve::<N3>::inverse_transform(f.2, len))
436 .map(|((c1, c2), c3)| {
437 let d1 = c1.inner();
438 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
439 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
440 let d3 = ((c3 - x) * t2).inner();
441 MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
442 })
443 .collect()
444 }
445 fn multiply(f: &mut Self::F, g: &Self::F) {
446 assert_eq!(f.0.len(), g.0.len());
447 assert_eq!(f.1.len(), g.1.len());
448 assert_eq!(f.2.len(), g.2.len());
449 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
450 *f *= *g;
451 }
452 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
453 *f *= *g;
454 }
455 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
456 *f *= *g;
457 }
458 }
459 fn convolve(a: Self::T, b: Self::T) -> Self::T {
460 if Self::length(&a).max(Self::length(&b)) <= 300 {
461 return convolve_karatsuba(&a, &b);
462 }
463 if Self::length(&a).min(Self::length(&b)) <= 60 {
464 return convolve_naive(&a, &b);
465 }
466 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
467 let mut a = Self::transform(a, len);
468 let b = Self::transform(b, len);
469 Self::multiply(&mut a, &b);
470 Self::inverse_transform(a, len)
471 }
472}
473
474impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
475where
476 N1: Montgomery32NttModulus,
477 N2: Montgomery32NttModulus,
478 N3: Montgomery32NttModulus,
479{
480 type T = Vec<u64>;
481 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
482
483 fn length(t: &Self::T) -> usize {
484 t.len()
485 }
486
487 fn transform(t: Self::T, len: usize) -> Self::F {
488 let npot = len.max(1).next_power_of_two();
489 let mut f = (
490 MVec::<N1>::with_capacity(npot),
491 MVec::<N2>::with_capacity(npot),
492 MVec::<N3>::with_capacity(npot),
493 );
494 for t in t {
495 f.0.push(t.into());
496 f.1.push(t.into());
497 f.2.push(t.into());
498 }
499 f.0.resize_with(npot, Zero::zero);
500 f.1.resize_with(npot, Zero::zero);
501 f.2.resize_with(npot, Zero::zero);
502 ntt(&mut f.0);
503 ntt(&mut f.1);
504 ntt(&mut f.2);
505 f
506 }
507
508 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
509 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
510 let m1 = N1::get_mod() as u64;
511 let m1_3 = MInt::<N3>::new(N1::get_mod());
512 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
513 let m2 = m1 * N2::get_mod() as u64;
514 Convolve::<N1>::inverse_transform(f.0, len)
515 .into_iter()
516 .zip(Convolve::<N2>::inverse_transform(f.1, len))
517 .zip(Convolve::<N3>::inverse_transform(f.2, len))
518 .map(|((c1, c2), c3)| {
519 let d1 = c1.inner();
520 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
521 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
522 let d3 = ((c3 - x) * t2).inner();
523 d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
524 })
525 .collect()
526 }
527
528 fn multiply(f: &mut Self::F, g: &Self::F) {
529 assert_eq!(f.0.len(), g.0.len());
530 assert_eq!(f.1.len(), g.1.len());
531 assert_eq!(f.2.len(), g.2.len());
532 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
533 *f *= *g;
534 }
535 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
536 *f *= *g;
537 }
538 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
539 *f *= *g;
540 }
541 }
542
543 fn convolve(a: Self::T, b: Self::T) -> Self::T {
544 if Self::length(&a).max(Self::length(&b)) <= 300 {
545 return convolve_karatsuba(&a, &b);
546 }
547 if Self::length(&a).min(Self::length(&b)) <= 60 {
548 return convolve_naive(&a, &b);
549 }
550 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
551 let mut a = Self::transform(a, len);
552 let b = Self::transform(b, len);
553 Self::multiply(&mut a, &b);
554 Self::inverse_transform(a, len)
555 }
556}
557
/// Frequency-domain operations that reuse existing transforms instead of
/// recomputing them from coefficients.
pub trait NttReuse: ConvolveSteps {
    /// Whether `F` is built from several moduli. Defaults to `true`; the
    /// single-modulus `Convolve<M>` implementation overrides it to `false`.
    const MULTIPLE: bool = true;

    /// Given the transform `f` of a sequence (length `n = f.len()`), returns the
    /// length-`2n` transform of that sequence zero-padded to `2n`.
    fn ntt_doubling(f: Self::F) -> Self::F;

    /// Given length-`2n` transforms `f` of `a` and `g` of `b`, returns a
    /// length-`n` transform of the even-index coefficients of `a(x) * b(-x)`.
    /// Implementations assert equal, power-of-two lengths of at least 2.
    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;

    /// Like `even_mul_normal_neg`, but for the odd-index coefficients of
    /// `a(x) * b(-x)`.
    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
}
570
thread_local!(
    // Per-thread cache of bit-reversal permutations: entry `k` holds the
    // permutation of `0..1 << k`, built lazily by the `odd_mul_normal_neg`
    // implementations. `UnsafeCell` means accesses are unchecked, so users must
    // never hold two references into the cache at the same time.
    static BIT_REVERSE: UnsafeCell<Vec<Vec<usize>>> = const { UnsafeCell::new(vec![]) };
);
574
/// Single-modulus implementation of the NTT reuse operations.
impl<M> NttReuse for Convolve<M>
where
    M: Montgomery32NttModulus,
{
    const MULTIPLE: bool = false;

    /// Extends a length-`n` transform to length `2n` with one inverse and one
    /// forward transform of size `n`. Assumes `n` is a power of two and that
    /// `2n` does not exceed `2^RANK`.
    fn ntt_doubling(mut f: Self::F) -> Self::F {
        let n = f.len();
        let k = n.trailing_zeros() as usize;
        // Recover the length-n coefficient vector.
        let mut a = Self::inverse_transform(f.clone(), n);
        let mut rot = MInt::<M>::one();
        // Primitive 2^(k+1)-th root of unity. Twisting a_i by zeta^i moves the
        // evaluation points to the ones the length-2n transform adds.
        let zeta = MInt::<M>::new_unchecked(M::INFO.root[k + 1]);
        for a in a.iter_mut() {
            *a *= rot;
            rot *= zeta;
        }
        // The new evaluations are appended after the existing ones (this matches
        // `transform(_, 2n)`, as checked by `test_ntt_reuse_998244353`).
        f.extend(Self::transform(a, n));
        f
    }

    /// Length-`n` transform of the even-index coefficients of `a(x) * b(-x)`,
    /// from length-`2n` transforms `f` of `a` and `g` of `b`.
    ///
    /// # Panics
    /// If the lengths differ, are not a power of two, or are less than 2.
    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        assert_eq!(f.len(), g.len());
        assert!(f.len().is_power_of_two());
        assert!(f.len() >= 2);
        let inv2 = MInt::<M>::from(2).inv();
        let n = f.len() / 2;
        // Entries 2i and 2i+1 presumably hold evaluations at some point w and at
        // -w, so (a(w)b(-w) + a(-w)b(w)) / 2 is the even part evaluated at w^2.
        (0..n)
            .map(|i| (f[i << 1] * g[i << 1 | 1] + f[i << 1 | 1] * g[i << 1]) * inv2)
            .collect()
    }

    /// Length-`n` transform of the odd-index coefficients of `a(x) * b(-x)`,
    /// from length-`2n` transforms `f` of `a` and `g` of `b`.
    ///
    /// # Panics
    /// If the lengths differ, are not a power of two, or are less than 2.
    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        assert_eq!(f.len(), g.len());
        assert!(f.len().is_power_of_two());
        assert!(f.len() >= 2);
        // Running factor: the j-th visited position is scaled by inv_root[k]^j / 2,
        // which presumably divides out the evaluation point of that position.
        let mut inv2 = MInt::<M>::from(2).inv();
        let n = f.len() / 2;
        let k = f.len().trailing_zeros() as usize;
        let mut h = vec![MInt::<M>::zero(); n];
        let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
        BIT_REVERSE.with(|br| {
            // SAFETY: this reference never leaves the closure, and nothing in
            // the closure touches `BIT_REVERSE` again, so it is the only live
            // reference to this thread's cache.
            let br = unsafe { &mut *br.get() };
            if br.len() < k {
                br.resize_with(k, Default::default);
            }
            // Permutation for the n = 2^(k-1) output positions.
            let k = k - 1;
            if br[k].is_empty() {
                let mut v = vec![0; 1 << k];
                for i in 0..1 << k {
                    v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
                }
                br[k] = v;
            }
            // Visit positions in bit-reversed order so the twiddle power grows by
            // one per step.
            for &i in &br[k] {
                h[i] = (f[i << 1] * g[i << 1 | 1] - f[i << 1 | 1] * g[i << 1]) * inv2;
                inv2 *= w;
            }
        });
        h
    }
}
636
/// Three-prime implementation of the NTT reuse operations. Each operation is
/// applied to every residue transform independently.
impl<M, N1, N2, N3> NttReuse for Convolve<(M, (N1, N2, N3))>
where
    M: MIntConvert + MIntConvert<u32>,
    N1: Montgomery32NttModulus,
    N2: Montgomery32NttModulus,
    N3: Montgomery32NttModulus,
{
    fn ntt_doubling(f: Self::F) -> Self::F {
        (
            Convolve::<N1>::ntt_doubling(f.0),
            Convolve::<N2>::ntt_doubling(f.1),
            Convolve::<N3>::ntt_doubling(f.2),
        )
    }

    fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Same as the single-modulus version, except that `u = m * f.len()` is
        // added to g[1] in the i == 0 term. NOTE(review): this appears to shift
        // the reconstructed coefficients by a multiple of the target modulus `m`,
        // so that values that would be negative over the integers reconstruct
        // correctly modulo `m`. The derivation is not visible here; the behaviour
        // is pinned only by `test_ntt_reuse_triple`.
        fn even_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            let n = f.len();
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let inv2 = MInt::<M>::from(2).inv();
            let u = MInt::<M>::new(m) * MInt::<M>::from(n as u32);
            let n = f.len() / 2;
            (0..n)
                .map(|i| {
                    (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        + f[i << 1 | 1] * g[i << 1])
                        * inv2
                })
                .collect()
        }

        // Presumably the modulus of `M` as a `u32`.
        let m = M::mod_into();
        (
            even_mul_normal_neg_corrected(&f.0, &g.0, m),
            even_mul_normal_neg_corrected(&f.1, &g.1, m),
            even_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }

    fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
        // Same as the single-modulus version plus the `u = m * f.len()`
        // correction at i == 0 (see the NOTE in `even_mul_normal_neg`'s helper).
        fn odd_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
        where
            M: Montgomery32NttModulus,
        {
            assert_eq!(f.len(), g.len());
            assert!(f.len().is_power_of_two());
            assert!(f.len() >= 2);
            let mut inv2 = MInt::<M>::from(2).inv();
            let u = MInt::<M>::new(m) * MInt::<M>::from(f.len() as u32);
            let n = f.len() / 2;
            let k = f.len().trailing_zeros() as usize;
            let mut h = vec![MInt::<M>::zero(); n];
            let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
            BIT_REVERSE.with(|br| {
                // SAFETY: this reference never leaves the closure, and nothing
                // in the closure touches `BIT_REVERSE` again, so it is the only
                // live reference to this thread's cache.
                let br = unsafe { &mut *br.get() };
                if br.len() < k {
                    br.resize_with(k, Default::default);
                }
                // Permutation for the n = 2^(k-1) output positions.
                let k = k - 1;
                if br[k].is_empty() {
                    let mut v = vec![0; 1 << k];
                    for i in 0..1 << k {
                        v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
                    }
                    br[k] = v;
                }
                for &i in &br[k] {
                    h[i] = (f[i << 1]
                        * if i == 0 {
                            g[i << 1 | 1] + u
                        } else {
                            g[i << 1 | 1]
                        }
                        - f[i << 1 | 1] * g[i << 1])
                        * inv2;
                    inv2 *= w;
                }
            });
            h
        }

        // Presumably the modulus of `M` as a `u32`.
        let m = M::mod_into();
        (
            odd_mul_normal_neg_corrected(&f.0, &g.0, m),
            odd_mul_normal_neg_corrected(&f.1, &g.1, m),
            odd_mul_normal_neg_corrected(&f.2, &g.2, m),
        )
    }
}
736
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::{mint_basic::Modulo1000000009, montgomery::MInt998244353};
    use crate::tools::Xorshift;

    // Schoolbook helper against a direct double loop (including empty inputs).
    #[test]
    fn test_convolve_naive() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=60);
            let m = rng.random(0..=60);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_naive(&a, &b);
            assert_eq!(c, d);
        }
    }

    // Karatsuba (sizes large enough to recurse) against a direct double loop.
    #[test]
    fn test_convolve_karatsuba() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=200);
            let m = rng.random(0..=200);
            let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
            let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
            let mut c = vec![0u32; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = convolve_karatsuba(&a, &b);
            assert_eq!(c, d);
        }
    }

    // Single-modulus convolution: tiny sizes, NTT-sized inputs, the
    // "2^w + 1" sizes that trigger the peel-off path, and squaring (b == a).
    // Also checks the compile-time NttInfo against a runtime construction.
    #[test]
    fn test_ntt998244353() {
        let mut rng = Xorshift::default();
        for t in 0..1000 {
            let n: usize = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=120) } else { n };
            let m: usize = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=120) } else { m };
            let (n, m) = if t % 100 != 0 {
                (n, m)
            } else {
                let w = rng.random(6..=8);
                ((1usize << w) + 1usize, (1usize << w) + 1usize)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let mut b: Vec<MInt998244353> = rng.random_iter(..).take(m).collect();
            if n == m && rng.random(0..2) == 0 {
                b = a.clone();
            }

            let mut c = vec![MInt998244353::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = Convolve998244353::convolve(a, b);
            assert_eq!(c, d);
        }
        assert_eq!(NttInfo::new::<Modulo998244353>(), Modulo998244353::INFO);
    }

    // Three-prime convolution for a non-NTT modulus (1_000_000_009).
    #[test]
    fn test_convolve3() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let b: Vec<M> = rng.random_iter(..).take(m).collect();
            let mut c = vec![M::zero(); (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // u64 convolution with inputs below 2^24, so exact results fit in u64.
    #[test]
    fn test_convolve_u64() {
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            let n = rng.random(0..=5);
            let n = if n == 5 { rng.random(70..=400) } else { n };
            let m = rng.random(0..=5);
            let m = if m == 5 { rng.random(70..=400) } else { m };
            let a: Vec<u64> = rng.random_iter(0u64..1 << 24).take(n).collect();
            let b: Vec<u64> = rng.random_iter(0u64..1 << 24).take(m).collect();
            let mut c = vec![0; (n + m).saturating_sub(1)];
            for i in 0..n {
                for j in 0..m {
                    c[i + j] += a[i] * b[j];
                }
            }
            let d = U64Convolve::convolve(a, b);
            assert_eq!(c, d);
        }
    }

    // NttReuse for the single modulus: the results must match freshly computed transforms.
    #[test]
    fn test_ntt_reuse_998244353() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let f = Convolve998244353::transform(a.clone(), n);

            {
                // Doubling a transform equals transforming the zero-padded input.
                let f_double = Convolve998244353::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = Convolve998244353::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = Convolve998244353::transform(a.clone(), n * 2);
            let b: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
            let g = Convolve998244353::transform(b.clone(), n * 2);
            // b_neg holds the coefficients of b(-x).
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            {
                // Even coefficients of a(x) * b(-x).
                let fg_neg = Convolve998244353::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_even, n);
                assert_eq!(fg_neg, fg);
            }

            {
                // Odd coefficients of a(x) * b(-x).
                let fg_neg = Convolve998244353::odd_mul_normal_neg(&f, &g);
                let ab_neg_odd: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
                    .into_iter()
                    .skip(1)
                    .step_by(2)
                    .collect();
                let fg = Convolve998244353::transform(ab_neg_odd, n);
                assert_eq!(fg_neg, fg);
            }
        }
    }

    // NttReuse for the three-prime convolution, compared after inverse transforms.
    #[test]
    fn test_ntt_reuse_triple() {
        type M = MInt<Modulo1000000009>;
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n: usize = if rng.gen_bool(0.5) {
                rng.random(1..=20)
            } else {
                rng.random(1..=1000)
            };
            let a: Vec<M> = rng.random_iter(..).take(n).collect();
            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n);

            {
                // Doubling a transform equals transforming the zero-padded input.
                let f_double = MIntConvolve::<Modulo1000000009>::ntt_doubling(f.clone());
                let mut a = a.clone();
                a.resize_with(n * 2, Zero::zero);
                let f2 = MIntConvolve::<Modulo1000000009>::transform(a, n * 2);
                assert_eq!(f_double, f2);
            }

            let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n * 2);
            let b: Vec<M> = rng.random_iter(..).take(n).collect();
            let g = MIntConvolve::<Modulo1000000009>::transform(b.clone(), n * 2);
            // b_neg holds the coefficients of b(-x), negated modulo 1_000_000_009.
            let mut b_neg = b.clone();
            for b in b_neg.iter_mut().skip(1).step_by(2) {
                *b = -*b;
            }

            {
                // Even coefficients of a(x) * b(-x).
                let fg_neg = MIntConvolve::<Modulo1000000009>::even_mul_normal_neg(&f, &g);
                let ab_neg_even: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .step_by(2)
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_even
                );
            }

            {
                // Odd coefficients of a(x) * b(-x), padded with one trailing zero to length n.
                let fg_neg = MIntConvolve::<Modulo1000000009>::odd_mul_normal_neg(&f, &g);
                let ab_neg_odd: Vec<_> =
                    MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
                        .into_iter()
                        .skip(1)
                        .step_by(2)
                        .chain([M::zero()])
                        .collect();
                assert_eq!(
                    MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
                    ab_neg_odd
                );
            }
        }
    }
}