1use super::{ConvolveSteps, MInt, MIntBase, MIntConvert, One, Zero, montgomery::*};
2#[cfg(target_arch = "x86_64")]
3use super::{SimdBackend, simd_backend};
4use std::{
5 cell::UnsafeCell,
6 marker::PhantomData,
7 ops::{AddAssign, Mul, SubAssign},
8};
9
10pub struct Convolve<M>(PhantomData<fn() -> M>);
11pub type Convolve998244353 = Convolve<Modulo998244353>;
12pub type MIntConvolve<M> = Convolve<(M, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
13pub type U64Convolve = Convolve<(u64, (Modulo2013265921, Modulo1811939329, Modulo2113929217))>;
14
15macro_rules! impl_ntt_modulus {
16 ($([$name:ident, $g:expr]),*) => {
17 $(
18 impl Montgomery32NttModulus for $name {}
19 )*
20 };
21}
22impl_ntt_modulus!(
23 [Modulo998244353, 3],
24 [Modulo2113929217, 5],
25 [Modulo1811939329, 13],
26 [Modulo2013265921, 31]
27);
28
29const fn reduce(z: u64, p: u32, r: u32) -> u32 {
30 let mut z = ((z + r.wrapping_mul(z as u32) as u64 * p as u64) >> 32) as u32;
31 if z >= p {
32 z -= p;
33 }
34 z
35}
36const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32 {
37 reduce(x as u64 * y as u64, p, r)
38}
39const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
40 while y > 0 {
41 if y & 1 == 1 {
42 z = mod_mul(z, x, p, r);
43 }
44 x = mod_mul(x, x, p, r);
45 y >>= 1;
46 }
47 z
48}
49
50pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
51 const PRIMITIVE_ROOT: u32 = {
52 let mut g = 3u32;
53 loop {
54 let mut ok = true;
55 let mut d = 1u32;
56 while d * d < Self::MOD {
57 if (Self::MOD - 1) % d == 0 {
58 let ds = [d, (Self::MOD - 1) / d];
59 let mut i = 0;
60 while i < 2 {
61 ok &= ds[i] == Self::MOD - 1
62 || mod_pow(
63 reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
64 ds[i],
65 Self::MOD,
66 Self::R,
67 Self::N1,
68 ) != Self::N1;
69 i += 1;
70 }
71 }
72 d += 1;
73 }
74 if ok {
75 break;
76 }
77 g += 2;
78 }
79 g
80 };
81 const RANK: u32 = (Self::MOD - 1).trailing_zeros();
82 const INFO: NttInfo = NttInfo::new::<Self>();
83}
84
85#[derive(Debug, PartialEq)]
86pub struct NttInfo {
87 root: [u32; 32],
88 inv_root: [u32; 32],
89 rate2: [u32; 32],
90 inv_rate2: [u32; 32],
91 rate3: [u32; 32],
92 inv_rate3: [u32; 32],
93}
94impl NttInfo {
95 const fn new<M>() -> Self
96 where
97 M: Montgomery32NttModulus,
98 {
99 let mut root = [0; 32];
100 let mut inv_root = [0; 32];
101 let mut rate2 = [0; 32];
102 let mut inv_rate2 = [0; 32];
103 let mut rate3 = [0; 32];
104 let mut inv_rate3 = [0; 32];
105 let rank = M::RANK as usize;
106
107 let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
108 root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
109 inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
110 let mut i = rank - 1;
111 loop {
112 root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
113 inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
114 if i == 0 {
115 break;
116 }
117 i -= 1;
118 }
119
120 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
121 while i < rank - 1 {
122 rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
123 inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
124 prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
125 inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
126 i += 1;
127 }
128
129 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
130 while i < rank - 2 {
131 rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
132 inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
133 prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
134 inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
135 i += 1;
136 }
137
138 NttInfo {
139 root,
140 inv_root,
141 rate2,
142 inv_rate2,
143 rate3,
144 inv_rate3,
145 }
146 }
147}
148
149fn ntt_scalar<M>(a: &mut [MInt<M>])
150where
151 M: Montgomery32NttModulus,
152{
153 let n = a.len();
154 let mut v = n / 2;
155 let imag = MInt::<M>::new_unchecked(M::INFO.root[2]);
156 while v > 1 {
157 let mut w1 = MInt::<M>::one();
158 for (s, a) in a.chunks_exact_mut(v << 1).enumerate() {
159 let (l, r) = a.split_at_mut(v);
160 let (ll, lr) = l.split_at_mut(v >> 1);
161 let (rl, rr) = r.split_at_mut(v >> 1);
162 let w2 = w1 * w1;
163 let w3 = w1 * w2;
164 for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
165 let a0 = *x0;
166 let a1 = *x1 * w1;
167 let a2 = *x2 * w2;
168 let a3 = *x3 * w3;
169 let a0pa2 = a0 + a2;
170 let a0na2 = a0 - a2;
171 let a1pa3 = a1 + a3;
172 let a1na3imag = (a1 - a3) * imag;
173 *x0 = a0pa2 + a1pa3;
174 *x1 = a0pa2 - a1pa3;
175 *x2 = a0na2 + a1na3imag;
176 *x3 = a0na2 - a1na3imag;
177 }
178 w1 *= MInt::<M>::new_unchecked(M::INFO.rate3[s.trailing_ones() as usize]);
179 }
180 v >>= 2;
181 }
182 if v == 1 {
183 let mut w1 = MInt::<M>::one();
184 for (s, a) in a.chunks_exact_mut(2).enumerate() {
185 unsafe {
186 let (l, r) = a.split_at_mut(1);
187 let x0 = l.get_unchecked_mut(0);
188 let x1 = r.get_unchecked_mut(0);
189 let a0 = *x0;
190 let a1 = *x1 * w1;
191 *x0 = a0 + a1;
192 *x1 = a0 - a1;
193 }
194 w1 *= MInt::<M>::new_unchecked(M::INFO.rate2[s.trailing_ones() as usize]);
195 }
196 }
197}
198
199fn intt_scalar<M>(a: &mut [MInt<M>])
200where
201 M: Montgomery32NttModulus,
202{
203 let n = a.len();
204 let mut v = 1;
205 if n.trailing_zeros() & 1 == 1 {
206 let mut w1 = MInt::<M>::one();
207 for (s, a) in a.chunks_exact_mut(2).enumerate() {
208 unsafe {
209 let (l, r) = a.split_at_mut(1);
210 let x0 = l.get_unchecked_mut(0);
211 let x1 = r.get_unchecked_mut(0);
212 let a0 = *x0;
213 let a1 = *x1;
214 *x0 = a0 + a1;
215 *x1 = (a0 - a1) * w1;
216 }
217 w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate2[s.trailing_ones() as usize]);
218 }
219 v <<= 1;
220 }
221 let iimag = MInt::<M>::new_unchecked(M::INFO.inv_root[2]);
222 while v < n {
223 let mut w1 = MInt::<M>::one();
224 for (s, a) in a.chunks_exact_mut(v << 2).enumerate() {
225 let (l, r) = a.split_at_mut(v << 1);
226 let (ll, lr) = l.split_at_mut(v);
227 let (rl, rr) = r.split_at_mut(v);
228 let w2 = w1 * w1;
229 let w3 = w1 * w2;
230 for (((x0, x1), x2), x3) in ll.iter_mut().zip(lr).zip(rl).zip(rr) {
231 let a0 = *x0;
232 let a1 = *x1;
233 let a2 = *x2;
234 let a3 = *x3;
235 let a0pa1 = a0 + a1;
236 let a0na1 = a0 - a1;
237 let a2pa3 = a2 + a3;
238 let a2na3iimag = (a2 - a3) * iimag;
239 *x0 = a0pa1 + a2pa3;
240 *x1 = (a0na1 + a2na3iimag) * w1;
241 *x2 = (a0pa1 - a2pa3) * w2;
242 *x3 = (a0na1 - a2na3iimag) * w3;
243 }
244 w1 *= MInt::<M>::new_unchecked(M::INFO.inv_rate3[s.trailing_ones() as usize]);
245 }
246 v <<= 2;
247 }
248}
249
250fn ntt<M>(a: &mut [MInt<M>])
251where
252 M: Montgomery32NttModulus,
253{
254 #[cfg(target_arch = "x86_64")]
255 {
256 match simd_backend() {
257 SimdBackend::Avx512 => unsafe { ntt_simd::ntt_avx512::<M>(a) },
258 SimdBackend::Avx2 => unsafe { ntt_simd::ntt_avx2::<M>(a) },
259 SimdBackend::Scalar => ntt_scalar(a),
260 }
261 }
262 #[cfg(not(target_arch = "x86_64"))]
263 {
264 ntt_scalar(a);
265 }
266}
267
268fn intt<M>(a: &mut [MInt<M>])
269where
270 M: Montgomery32NttModulus,
271{
272 #[cfg(target_arch = "x86_64")]
273 {
274 match simd_backend() {
275 SimdBackend::Avx512 => unsafe { ntt_simd::intt_avx512::<M>(a) },
276 SimdBackend::Avx2 => unsafe { ntt_simd::intt_avx2::<M>(a) },
277 SimdBackend::Scalar => intt_scalar(a),
278 }
279 }
280 #[cfg(not(target_arch = "x86_64"))]
281 {
282 intt_scalar(a);
283 }
284}
285
286#[cfg(target_arch = "x86_64")]
287#[allow(unsafe_op_in_unsafe_fn)] mod ntt_simd {
289 use super::*;
290 use std::arch::x86_64::*;
291
292 const LAZY_THRESHOLD: u32 = 1 << 30;
293
294 #[target_feature(enable = "avx2")]
295 unsafe fn normalize_avx2<M>(a: &mut [u32])
296 where
297 M: Montgomery32NttModulus,
298 {
299 let mod_vec = _mm256_set1_epi32(M::MOD as i32);
300 let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
301 let mut i = 0;
302 while i + 8 <= a.len() {
303 let x = _mm256_loadu_si256(a.as_ptr().add(i) as *const __m256i);
304 let x_x = _mm256_xor_si256(x, sign);
305 let m_x = _mm256_xor_si256(mod_vec, sign);
306 let gt = _mm256_cmpgt_epi32(x_x, m_x);
307 let eq = _mm256_cmpeq_epi32(x, mod_vec);
308 let mask = _mm256_or_si256(gt, eq);
309 let sub = _mm256_and_si256(mod_vec, mask);
310 let y = _mm256_sub_epi32(x, sub);
311 _mm256_storeu_si256(a.as_mut_ptr().add(i) as *mut __m256i, y);
312 i += 8;
313 }
314 while i < a.len() {
315 let x = a[i];
316 a[i] = if x >= M::MOD { x - M::MOD } else { x };
317 i += 1;
318 }
319 }
320
321 #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
322 unsafe fn normalize_avx512<M>(a: &mut [u32])
323 where
324 M: Montgomery32NttModulus,
325 {
326 let mod_vec = _mm512_set1_epi32(M::MOD as i32);
327 let mut i = 0;
328 while i + 16 <= a.len() {
329 let x = _mm512_loadu_si512(a.as_ptr().add(i) as *const __m512i);
330 let mask = !_mm512_cmp_epu32_mask(x, mod_vec, _MM_CMPINT_LT);
331 let y = _mm512_mask_sub_epi32(x, mask, x, mod_vec);
332 _mm512_storeu_si512(a.as_mut_ptr().add(i) as *mut __m512i, y);
333 i += 16;
334 }
335 while i < a.len() {
336 let x = a[i];
337 a[i] = if x >= M::MOD { x - M::MOD } else { x };
338 i += 1;
339 }
340 }
341
342 unsafe fn add_vec_avx2<M>(
343 a: __m256i,
344 b: __m256i,
345 mod_vec: __m256i,
346 mod2_vec: __m256i,
347 sign: __m256i,
348 ) -> __m256i
349 where
350 M: Montgomery32NttModulus,
351 {
352 if M::MOD < LAZY_THRESHOLD {
353 simd32::montgomery_add_256(a, b, mod2_vec, sign)
354 } else {
355 simd32::add_mod_256(a, b, mod_vec, sign)
356 }
357 }
358
359 unsafe fn sub_vec_avx2<M>(
360 a: __m256i,
361 b: __m256i,
362 mod_vec: __m256i,
363 mod2_vec: __m256i,
364 sign: __m256i,
365 ) -> __m256i
366 where
367 M: Montgomery32NttModulus,
368 {
369 if M::MOD < LAZY_THRESHOLD {
370 simd32::montgomery_sub_256(a, b, mod2_vec, sign)
371 } else {
372 simd32::sub_mod_256(a, b, mod_vec, sign)
373 }
374 }
375
376 unsafe fn mul_vec_avx2<M>(
377 a: __m256i,
378 b: __m256i,
379 r_vec: __m256i,
380 mod_vec: __m256i,
381 sign: __m256i,
382 ) -> __m256i
383 where
384 M: Montgomery32NttModulus,
385 {
386 if M::MOD < LAZY_THRESHOLD {
387 simd32::montgomery_mul_256(a, b, r_vec, mod_vec)
388 } else {
389 simd32::montgomery_mul_256_canon(a, b, r_vec, mod_vec, sign)
390 }
391 }
392
393 unsafe fn add_vec_avx512<M>(
394 a: __m512i,
395 b: __m512i,
396 mod_vec: __m512i,
397 mod2_vec: __m512i,
398 ) -> __m512i
399 where
400 M: Montgomery32NttModulus,
401 {
402 if M::MOD < LAZY_THRESHOLD {
403 simd32::montgomery_add_512(a, b, mod2_vec)
404 } else {
405 simd32::add_mod_512(a, b, mod_vec)
406 }
407 }
408
409 unsafe fn sub_vec_avx512<M>(
410 a: __m512i,
411 b: __m512i,
412 mod_vec: __m512i,
413 mod2_vec: __m512i,
414 ) -> __m512i
415 where
416 M: Montgomery32NttModulus,
417 {
418 if M::MOD < LAZY_THRESHOLD {
419 simd32::montgomery_sub_512(a, b, mod2_vec)
420 } else {
421 simd32::sub_mod_512(a, b, mod_vec)
422 }
423 }
424
425 unsafe fn mul_vec_avx512<M>(a: __m512i, b: __m512i, r_vec: __m512i, mod_vec: __m512i) -> __m512i
426 where
427 M: Montgomery32NttModulus,
428 {
429 if M::MOD < LAZY_THRESHOLD {
430 simd32::montgomery_mul_512(a, b, r_vec, mod_vec)
431 } else {
432 simd32::montgomery_mul_512_canon(a, b, r_vec, mod_vec)
433 }
434 }
435
436 #[target_feature(enable = "avx2")]
437 pub(super) unsafe fn ntt_avx2<M>(a: &mut [MInt<M>])
438 where
439 M: Montgomery32NttModulus,
440 {
441 let n = a.len();
442 if n <= 1 {
443 return;
444 }
445 let ptr = a.as_mut_ptr() as *mut u32;
446 let a = std::slice::from_raw_parts_mut(ptr, n);
447 let mod_vec = _mm256_set1_epi32(M::MOD as i32);
448 let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
449 let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
450 let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
451 let imag = M::INFO.root[2];
452 let imag_vec = _mm256_set1_epi32(imag as i32);
453
454 let mut v = n / 2;
455 while v > 1 {
456 let half = v >> 1;
457 let mut w1 = M::N1;
458 for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
459 let base = block.as_mut_ptr();
460 let ll = base;
461 let lr = base.add(half);
462 let rl = base.add(v);
463 let rr = base.add(v + half);
464
465 let w2 = M::mod_mul(w1, w1);
466 let w3 = M::mod_mul(w2, w1);
467 let w1v = _mm256_set1_epi32(w1 as i32);
468 let w2v = _mm256_set1_epi32(w2 as i32);
469 let w3v = _mm256_set1_epi32(w3 as i32);
470
471 let mut i = 0;
472 while i + 8 <= half {
473 let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
474 let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
475 let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
476 let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
477
478 let a1 = mul_vec_avx2::<M>(x1, w1v, r_vec, mod_vec, sign);
479 let a2 = mul_vec_avx2::<M>(x2, w2v, r_vec, mod_vec, sign);
480 let a3 = mul_vec_avx2::<M>(x3, w3v, r_vec, mod_vec, sign);
481
482 let a0pa2 = add_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
483 let a0na2 = sub_vec_avx2::<M>(x0, a2, mod_vec, mod2_vec, sign);
484 let a1pa3 = add_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
485 let a1na3 = sub_vec_avx2::<M>(a1, a3, mod_vec, mod2_vec, sign);
486 let a1na3imag = mul_vec_avx2::<M>(a1na3, imag_vec, r_vec, mod_vec, sign);
487
488 let y0 = add_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
489 let y1 = sub_vec_avx2::<M>(a0pa2, a1pa3, mod_vec, mod2_vec, sign);
490 let y2 = add_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
491 let y3 = sub_vec_avx2::<M>(a0na2, a1na3imag, mod_vec, mod2_vec, sign);
492
493 _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
494 _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
495 _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
496 _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
497 i += 8;
498 }
499 while i < half {
500 let a0 = *ll.add(i);
501 let a1 = M::mod_mul(*lr.add(i), w1);
502 let a2 = M::mod_mul(*rl.add(i), w2);
503 let a3 = M::mod_mul(*rr.add(i), w3);
504 let a0pa2 = M::mod_add(a0, a2);
505 let a0na2 = M::mod_sub(a0, a2);
506 let a1pa3 = M::mod_add(a1, a3);
507 let a1na3 = M::mod_sub(a1, a3);
508 let a1na3imag = M::mod_mul(a1na3, imag);
509 *ll.add(i) = M::mod_add(a0pa2, a1pa3);
510 *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
511 *rl.add(i) = M::mod_add(a0na2, a1na3imag);
512 *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
513 i += 1;
514 }
515 w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
516 }
517 v >>= 2;
518 }
519 if v == 1 {
520 let mut w1 = M::N1;
521 for (s, block) in a.chunks_exact_mut(2).enumerate() {
522 let a0 = *block.get_unchecked(0);
523 let a1 = M::mod_mul(*block.get_unchecked(1), w1);
524 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
525 *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
526 w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
527 }
528 }
529 normalize_avx2::<M>(a);
530 }
531
532 #[target_feature(enable = "avx2")]
533 pub(super) unsafe fn intt_avx2<M>(a: &mut [MInt<M>])
534 where
535 M: Montgomery32NttModulus,
536 {
537 let n = a.len();
538 if n <= 1 {
539 return;
540 }
541 let ptr = a.as_mut_ptr() as *mut u32;
542 let a = std::slice::from_raw_parts_mut(ptr, n);
543 let mod_vec = _mm256_set1_epi32(M::MOD as i32);
544 let mod2_vec = _mm256_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
545 let r_vec = _mm256_set1_epi32(M::R.wrapping_neg() as i32);
546 let sign = _mm256_set1_epi32(0x8000_0000u32 as i32);
547 let iimag = M::INFO.inv_root[2];
548 let iimag_vec = _mm256_set1_epi32(iimag as i32);
549
550 let mut v = 1;
551 if n.trailing_zeros() & 1 == 1 {
552 let mut w1 = M::N1;
553 for (s, block) in a.chunks_exact_mut(2).enumerate() {
554 let a0 = *block.get_unchecked(0);
555 let a1 = *block.get_unchecked(1);
556 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
557 *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
558 w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
559 }
560 v <<= 1;
561 }
562 while v < n {
563 let mut w1 = M::N1;
564 for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
565 let base = block.as_mut_ptr();
566 let ll = base;
567 let lr = base.add(v);
568 let rl = base.add(v << 1);
569 let rr = base.add(v * 3);
570
571 let w2 = M::mod_mul(w1, w1);
572 let w3 = M::mod_mul(w2, w1);
573 let w1v = _mm256_set1_epi32(w1 as i32);
574 let w2v = _mm256_set1_epi32(w2 as i32);
575 let w3v = _mm256_set1_epi32(w3 as i32);
576
577 let mut i = 0;
578 while i + 8 <= v {
579 let x0 = _mm256_loadu_si256(ll.add(i) as *const __m256i);
580 let x1 = _mm256_loadu_si256(lr.add(i) as *const __m256i);
581 let x2 = _mm256_loadu_si256(rl.add(i) as *const __m256i);
582 let x3 = _mm256_loadu_si256(rr.add(i) as *const __m256i);
583
584 let a0pa1 = add_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
585 let a0na1 = sub_vec_avx2::<M>(x0, x1, mod_vec, mod2_vec, sign);
586 let a2pa3 = add_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
587 let a2na3 = sub_vec_avx2::<M>(x2, x3, mod_vec, mod2_vec, sign);
588 let a2na3iimag = mul_vec_avx2::<M>(a2na3, iimag_vec, r_vec, mod_vec, sign);
589
590 let y0 = add_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
591 let y1 = add_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
592 let y2 = sub_vec_avx2::<M>(a0pa1, a2pa3, mod_vec, mod2_vec, sign);
593 let y3 = sub_vec_avx2::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec, sign);
594
595 let y1 = mul_vec_avx2::<M>(y1, w1v, r_vec, mod_vec, sign);
596 let y2 = mul_vec_avx2::<M>(y2, w2v, r_vec, mod_vec, sign);
597 let y3 = mul_vec_avx2::<M>(y3, w3v, r_vec, mod_vec, sign);
598
599 _mm256_storeu_si256(ll.add(i) as *mut __m256i, y0);
600 _mm256_storeu_si256(lr.add(i) as *mut __m256i, y1);
601 _mm256_storeu_si256(rl.add(i) as *mut __m256i, y2);
602 _mm256_storeu_si256(rr.add(i) as *mut __m256i, y3);
603 i += 8;
604 }
605 while i < v {
606 let a0 = *ll.add(i);
607 let a1 = *lr.add(i);
608 let a2 = *rl.add(i);
609 let a3 = *rr.add(i);
610 let a0pa1 = M::mod_add(a0, a1);
611 let a0na1 = M::mod_sub(a0, a1);
612 let a2pa3 = M::mod_add(a2, a3);
613 let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
614 *ll.add(i) = M::mod_add(a0pa1, a2pa3);
615 *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
616 *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
617 *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
618 i += 1;
619 }
620 w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
621 }
622 v <<= 2;
623 }
624 normalize_avx2::<M>(a);
625 }
626
627 #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
628 pub(super) unsafe fn ntt_avx512<M>(a: &mut [MInt<M>])
629 where
630 M: Montgomery32NttModulus,
631 {
632 let n = a.len();
633 if n <= 1 {
634 return;
635 }
636 let ptr = a.as_mut_ptr() as *mut u32;
637 let a = std::slice::from_raw_parts_mut(ptr, n);
638 let mod_vec = _mm512_set1_epi32(M::MOD as i32);
639 let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
640 let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
641 let imag = M::INFO.root[2];
642 let imag_vec = _mm512_set1_epi32(imag as i32);
643
644 let mut v = n / 2;
645 while v > 1 {
646 let half = v >> 1;
647 let mut w1 = M::N1;
648 for (s, block) in a.chunks_exact_mut(v << 1).enumerate() {
649 let base = block.as_mut_ptr();
650 let ll = base;
651 let lr = base.add(half);
652 let rl = base.add(v);
653 let rr = base.add(v + half);
654 let w2 = M::mod_mul(w1, w1);
655 let w3 = M::mod_mul(w2, w1);
656 let w1v = _mm512_set1_epi32(w1 as i32);
657 let w2v = _mm512_set1_epi32(w2 as i32);
658 let w3v = _mm512_set1_epi32(w3 as i32);
659
660 let mut i = 0;
661 while i + 16 <= half {
662 let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
663 let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
664 let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
665 let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);
666
667 let a1 = mul_vec_avx512::<M>(x1, w1v, r_vec, mod_vec);
668 let a2 = mul_vec_avx512::<M>(x2, w2v, r_vec, mod_vec);
669 let a3 = mul_vec_avx512::<M>(x3, w3v, r_vec, mod_vec);
670
671 let a0pa2 = add_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
672 let a0na2 = sub_vec_avx512::<M>(x0, a2, mod_vec, mod2_vec);
673 let a1pa3 = add_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
674 let a1na3 = sub_vec_avx512::<M>(a1, a3, mod_vec, mod2_vec);
675 let a1na3imag = mul_vec_avx512::<M>(a1na3, imag_vec, r_vec, mod_vec);
676
677 let y0 = add_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
678 let y1 = sub_vec_avx512::<M>(a0pa2, a1pa3, mod_vec, mod2_vec);
679 let y2 = add_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
680 let y3 = sub_vec_avx512::<M>(a0na2, a1na3imag, mod_vec, mod2_vec);
681
682 _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
683 _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
684 _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
685 _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
686 i += 16;
687 }
688 while i < half {
689 let a0 = *ll.add(i);
690 let a1 = M::mod_mul(*lr.add(i), w1);
691 let a2 = M::mod_mul(*rl.add(i), w2);
692 let a3 = M::mod_mul(*rr.add(i), w3);
693 let a0pa2 = M::mod_add(a0, a2);
694 let a0na2 = M::mod_sub(a0, a2);
695 let a1pa3 = M::mod_add(a1, a3);
696 let a1na3 = M::mod_sub(a1, a3);
697 let a1na3imag = M::mod_mul(a1na3, imag);
698 *ll.add(i) = M::mod_add(a0pa2, a1pa3);
699 *lr.add(i) = M::mod_sub(a0pa2, a1pa3);
700 *rl.add(i) = M::mod_add(a0na2, a1na3imag);
701 *rr.add(i) = M::mod_sub(a0na2, a1na3imag);
702 i += 1;
703 }
704 w1 = M::mod_mul(w1, M::INFO.rate3[s.trailing_ones() as usize]);
705 }
706 v >>= 2;
707 }
708 if v == 1 {
709 let mut w1 = M::N1;
710 for (s, block) in a.chunks_exact_mut(2).enumerate() {
711 let a0 = *block.get_unchecked(0);
712 let a1 = M::mod_mul(*block.get_unchecked(1), w1);
713 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
714 *block.get_unchecked_mut(1) = M::mod_sub(a0, a1);
715 w1 = M::mod_mul(w1, M::INFO.rate2[s.trailing_ones() as usize]);
716 }
717 }
718 normalize_avx512::<M>(a);
719 }
720
721 #[target_feature(enable = "avx512f,avx512dq,avx512cd,avx512bw,avx512vl")]
722 pub(super) unsafe fn intt_avx512<M>(a: &mut [MInt<M>])
723 where
724 M: Montgomery32NttModulus,
725 {
726 let n = a.len();
727 if n <= 1 {
728 return;
729 }
730 let ptr = a.as_mut_ptr() as *mut u32;
731 let a = std::slice::from_raw_parts_mut(ptr, n);
732 let mod_vec = _mm512_set1_epi32(M::MOD as i32);
733 let mod2_vec = _mm512_set1_epi32(M::MOD.wrapping_add(M::MOD) as i32);
734 let r_vec = _mm512_set1_epi32(M::R.wrapping_neg() as i32);
735 let iimag = M::INFO.inv_root[2];
736 let iimag_vec = _mm512_set1_epi32(iimag as i32);
737
738 let mut v = 1;
739 if n.trailing_zeros() & 1 == 1 {
740 let mut w1 = M::N1;
741 for (s, block) in a.chunks_exact_mut(2).enumerate() {
742 let a0 = *block.get_unchecked(0);
743 let a1 = *block.get_unchecked(1);
744 *block.get_unchecked_mut(0) = M::mod_add(a0, a1);
745 *block.get_unchecked_mut(1) = M::mod_mul(M::mod_sub(a0, a1), w1);
746 w1 = M::mod_mul(w1, M::INFO.inv_rate2[s.trailing_ones() as usize]);
747 }
748 v <<= 1;
749 }
750 while v < n {
751 let mut w1 = M::N1;
752 for (s, block) in a.chunks_exact_mut(v << 2).enumerate() {
753 let base = block.as_mut_ptr();
754 let ll = base;
755 let lr = base.add(v);
756 let rl = base.add(v << 1);
757 let rr = base.add(v * 3);
758 let w2 = M::mod_mul(w1, w1);
759 let w3 = M::mod_mul(w2, w1);
760 let w1v = _mm512_set1_epi32(w1 as i32);
761 let w2v = _mm512_set1_epi32(w2 as i32);
762 let w3v = _mm512_set1_epi32(w3 as i32);
763
764 let mut i = 0;
765 while i + 16 <= v {
766 let x0 = _mm512_loadu_si512(ll.add(i) as *const __m512i);
767 let x1 = _mm512_loadu_si512(lr.add(i) as *const __m512i);
768 let x2 = _mm512_loadu_si512(rl.add(i) as *const __m512i);
769 let x3 = _mm512_loadu_si512(rr.add(i) as *const __m512i);
770
771 let a0pa1 = add_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
772 let a0na1 = sub_vec_avx512::<M>(x0, x1, mod_vec, mod2_vec);
773 let a2pa3 = add_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
774 let a2na3 = sub_vec_avx512::<M>(x2, x3, mod_vec, mod2_vec);
775 let a2na3iimag = mul_vec_avx512::<M>(a2na3, iimag_vec, r_vec, mod_vec);
776
777 let y0 = add_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
778 let y1 = add_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
779 let y2 = sub_vec_avx512::<M>(a0pa1, a2pa3, mod_vec, mod2_vec);
780 let y3 = sub_vec_avx512::<M>(a0na1, a2na3iimag, mod_vec, mod2_vec);
781
782 let y1 = mul_vec_avx512::<M>(y1, w1v, r_vec, mod_vec);
783 let y2 = mul_vec_avx512::<M>(y2, w2v, r_vec, mod_vec);
784 let y3 = mul_vec_avx512::<M>(y3, w3v, r_vec, mod_vec);
785
786 _mm512_storeu_si512(ll.add(i) as *mut __m512i, y0);
787 _mm512_storeu_si512(lr.add(i) as *mut __m512i, y1);
788 _mm512_storeu_si512(rl.add(i) as *mut __m512i, y2);
789 _mm512_storeu_si512(rr.add(i) as *mut __m512i, y3);
790 i += 16;
791 }
792 while i < v {
793 let a0 = *ll.add(i);
794 let a1 = *lr.add(i);
795 let a2 = *rl.add(i);
796 let a3 = *rr.add(i);
797 let a0pa1 = M::mod_add(a0, a1);
798 let a0na1 = M::mod_sub(a0, a1);
799 let a2pa3 = M::mod_add(a2, a3);
800 let a2na3iimag = M::mod_mul(M::mod_sub(a2, a3), iimag);
801 *ll.add(i) = M::mod_add(a0pa1, a2pa3);
802 *lr.add(i) = M::mod_mul(M::mod_add(a0na1, a2na3iimag), w1);
803 *rl.add(i) = M::mod_mul(M::mod_sub(a0pa1, a2pa3), w2);
804 *rr.add(i) = M::mod_mul(M::mod_sub(a0na1, a2na3iimag), w3);
805 i += 1;
806 }
807 w1 = M::mod_mul(w1, M::INFO.inv_rate3[s.trailing_ones() as usize]);
808 }
809 v <<= 2;
810 }
811 normalize_avx512::<M>(a);
812 }
813}
814
815fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
816where
817 T: Copy + Zero + AddAssign<T> + Mul<Output = T>,
818{
819 if a.is_empty() && b.is_empty() {
820 return Vec::new();
821 }
822 let len = a.len() + b.len() - 1;
823 let mut c = vec![T::zero(); len];
824 if a.len() < b.len() {
825 for (i, &b) in b.iter().enumerate() {
826 for (a, c) in a.iter().zip(&mut c[i..]) {
827 *c += *a * b;
828 }
829 }
830 } else {
831 for (i, &a) in a.iter().enumerate() {
832 for (b, c) in b.iter().zip(&mut c[i..]) {
833 *c += *b * a;
834 }
835 }
836 }
837 c
838}
839
840fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
841where
842 T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
843{
844 if a.len().min(b.len()) <= 30 {
845 return convolve_naive(a, b);
846 }
847 let m = a.len().max(b.len()).div_ceil(2);
848 let (a0, a1) = if a.len() <= m {
849 (a, &[][..])
850 } else {
851 a.split_at(m)
852 };
853 let (b0, b1) = if b.len() <= m {
854 (b, &[][..])
855 } else {
856 b.split_at(m)
857 };
858 let f00 = convolve_karatsuba(a0, b0);
859 let f11 = convolve_karatsuba(a1, b1);
860 let mut a0a1 = a0.to_vec();
861 for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
862 *a0a1 += a1;
863 }
864 let mut b0b1 = b0.to_vec();
865 for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
866 *b0b1 += b1;
867 }
868 let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
869 for (f01, &f00) in f01.iter_mut().zip(&f00) {
870 *f01 -= f00;
871 }
872 for (f01, &f11) in f01.iter_mut().zip(&f11) {
873 *f01 -= f11;
874 }
875 let mut c = vec![T::zero(); a.len() + b.len() - 1];
876 for (c, &f00) in c.iter_mut().zip(&f00) {
877 *c += f00;
878 }
879 for (c, &f01) in c[m..].iter_mut().zip(&f01) {
880 *c += f01;
881 }
882 for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
883 *c += f11;
884 }
885 c
886}
887
888impl<M> ConvolveSteps for Convolve<M>
889where
890 M: Montgomery32NttModulus,
891{
892 type T = Vec<MInt<M>>;
893 type F = Vec<MInt<M>>;
894 fn length(t: &Self::T) -> usize {
895 t.len()
896 }
897 fn transform(mut t: Self::T, len: usize) -> Self::F {
898 t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
899 ntt(&mut t);
900 t
901 }
902 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
903 intt(&mut f);
904 f.truncate(len);
905 let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
906 for f in f.iter_mut() {
907 *f *= inv;
908 }
909 f
910 }
911 fn multiply(f: &mut Self::F, g: &Self::F) {
912 assert_eq!(f.len(), g.len());
913 for (f, g) in f.iter_mut().zip(g.iter()) {
914 *f *= *g;
915 }
916 }
917 fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
918 if Self::length(&a).max(Self::length(&b)) <= 100 {
919 return convolve_karatsuba(&a, &b);
920 }
921 if Self::length(&a).min(Self::length(&b)) <= 60 {
922 return convolve_naive(&a, &b);
923 }
924 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
925 let size = len.max(1).next_power_of_two();
926 if len <= size / 2 + 2 {
927 let xa = a.pop().unwrap();
928 let xb = b.pop().unwrap();
929 let mut c = vec![MInt::<M>::zero(); len];
930 *c.last_mut().unwrap() = xa * xb;
931 for (a, c) in a.iter().zip(&mut c[b.len()..]) {
932 *c += *a * xb;
933 }
934 for (b, c) in b.iter().zip(&mut c[a.len()..]) {
935 *c += *b * xa;
936 }
937 let d = Self::convolve(a, b);
938 for (d, c) in d.into_iter().zip(&mut c) {
939 *c += d;
940 }
941 return c;
942 }
943 let same = a == b;
944 let mut a = Self::transform(a, len);
945 if same {
946 for a in a.iter_mut() {
947 *a *= *a;
948 }
949 } else {
950 let b = Self::transform(b, len);
951 Self::multiply(&mut a, &b);
952 }
953 Self::inverse_transform(a, len)
954 }
955}
956
957type MVec<M> = Vec<MInt<M>>;
958impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
959where
960 M: MIntConvert + MIntConvert<u32>,
961 N1: Montgomery32NttModulus,
962 N2: Montgomery32NttModulus,
963 N3: Montgomery32NttModulus,
964{
965 type T = MVec<M>;
966 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
967 fn length(t: &Self::T) -> usize {
968 t.len()
969 }
970 fn transform(t: Self::T, len: usize) -> Self::F {
971 let npot = len.max(1).next_power_of_two();
972 let mut f = (
973 MVec::<N1>::with_capacity(npot),
974 MVec::<N2>::with_capacity(npot),
975 MVec::<N3>::with_capacity(npot),
976 );
977 for t in t {
978 f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
979 f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
980 f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
981 }
982 f.0.resize_with(npot, Zero::zero);
983 f.1.resize_with(npot, Zero::zero);
984 f.2.resize_with(npot, Zero::zero);
985 ntt(&mut f.0);
986 ntt(&mut f.1);
987 ntt(&mut f.2);
988 f
989 }
990 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
991 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
992 let m1 = MInt::<M>::from(N1::get_mod());
993 let m1_3 = MInt::<N3>::new(N1::get_mod());
994 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
995 let m2 = m1 * MInt::<M>::from(N2::get_mod());
996 Convolve::<N1>::inverse_transform(f.0, len)
997 .into_iter()
998 .zip(Convolve::<N2>::inverse_transform(f.1, len))
999 .zip(Convolve::<N3>::inverse_transform(f.2, len))
1000 .map(|((c1, c2), c3)| {
1001 let d1 = c1.inner();
1002 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1003 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1004 let d3 = ((c3 - x) * t2).inner();
1005 MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
1006 })
1007 .collect()
1008 }
1009 fn multiply(f: &mut Self::F, g: &Self::F) {
1010 assert_eq!(f.0.len(), g.0.len());
1011 assert_eq!(f.1.len(), g.1.len());
1012 assert_eq!(f.2.len(), g.2.len());
1013 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1014 *f *= *g;
1015 }
1016 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1017 *f *= *g;
1018 }
1019 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1020 *f *= *g;
1021 }
1022 }
1023 fn convolve(a: Self::T, b: Self::T) -> Self::T {
1024 if Self::length(&a).max(Self::length(&b)) <= 300 {
1025 return convolve_karatsuba(&a, &b);
1026 }
1027 if Self::length(&a).min(Self::length(&b)) <= 60 {
1028 return convolve_naive(&a, &b);
1029 }
1030 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1031 let mut a = Self::transform(a, len);
1032 let b = Self::transform(b, len);
1033 Self::multiply(&mut a, &b);
1034 Self::inverse_transform(a, len)
1035 }
1036}
1037
1038impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
1039where
1040 N1: Montgomery32NttModulus,
1041 N2: Montgomery32NttModulus,
1042 N3: Montgomery32NttModulus,
1043{
1044 type T = Vec<u64>;
1045 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
1046
1047 fn length(t: &Self::T) -> usize {
1048 t.len()
1049 }
1050
1051 fn transform(t: Self::T, len: usize) -> Self::F {
1052 let npot = len.max(1).next_power_of_two();
1053 let mut f = (
1054 MVec::<N1>::with_capacity(npot),
1055 MVec::<N2>::with_capacity(npot),
1056 MVec::<N3>::with_capacity(npot),
1057 );
1058 for t in t {
1059 f.0.push(t.into());
1060 f.1.push(t.into());
1061 f.2.push(t.into());
1062 }
1063 f.0.resize_with(npot, Zero::zero);
1064 f.1.resize_with(npot, Zero::zero);
1065 f.2.resize_with(npot, Zero::zero);
1066 ntt(&mut f.0);
1067 ntt(&mut f.1);
1068 ntt(&mut f.2);
1069 f
1070 }
1071
1072 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
1073 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
1074 let m1 = N1::get_mod() as u64;
1075 let m1_3 = MInt::<N3>::new(N1::get_mod());
1076 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
1077 let m2 = m1 * N2::get_mod() as u64;
1078 Convolve::<N1>::inverse_transform(f.0, len)
1079 .into_iter()
1080 .zip(Convolve::<N2>::inverse_transform(f.1, len))
1081 .zip(Convolve::<N3>::inverse_transform(f.2, len))
1082 .map(|((c1, c2), c3)| {
1083 let d1 = c1.inner();
1084 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1085 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1086 let d3 = ((c3 - x) * t2).inner();
1087 d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
1088 })
1089 .collect()
1090 }
1091
1092 fn multiply(f: &mut Self::F, g: &Self::F) {
1093 assert_eq!(f.0.len(), g.0.len());
1094 assert_eq!(f.1.len(), g.1.len());
1095 assert_eq!(f.2.len(), g.2.len());
1096 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1097 *f *= *g;
1098 }
1099 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1100 *f *= *g;
1101 }
1102 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1103 *f *= *g;
1104 }
1105 }
1106
1107 fn convolve(a: Self::T, b: Self::T) -> Self::T {
1108 if Self::length(&a).max(Self::length(&b)) <= 300 {
1109 return convolve_karatsuba(&a, &b);
1110 }
1111 if Self::length(&a).min(Self::length(&b)) <= 60 {
1112 return convolve_naive(&a, &b);
1113 }
1114 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1115 let mut a = Self::transform(a, len);
1116 let b = Self::transform(b, len);
1117 Self::multiply(&mut a, &b);
1118 Self::inverse_transform(a, len)
1119 }
1120}
1121
1122pub trait NttReuse: ConvolveSteps {
1123 const MULTIPLE: bool = true;
1124
1125 fn ntt_doubling(f: Self::F) -> Self::F;
1127
1128 fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
1130
1131 fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F;
1133}
1134
1135thread_local!(
1136 static BIT_REVERSE: UnsafeCell<Vec<Vec<usize>>> = const { UnsafeCell::new(vec![]) };
1137);
1138
1139impl<M> NttReuse for Convolve<M>
1140where
1141 M: Montgomery32NttModulus,
1142{
1143 const MULTIPLE: bool = false;
1144
1145 fn ntt_doubling(mut f: Self::F) -> Self::F {
1146 let n = f.len();
1147 let k = n.trailing_zeros() as usize;
1148 let mut a = Self::inverse_transform(f.clone(), n);
1149 let mut rot = MInt::<M>::one();
1150 let zeta = MInt::<M>::new_unchecked(M::INFO.root[k + 1]);
1151 for a in a.iter_mut() {
1152 *a *= rot;
1153 rot *= zeta;
1154 }
1155 f.extend(Self::transform(a, n));
1156 f
1157 }
1158
1159 fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
1160 assert_eq!(f.len(), g.len());
1161 assert!(f.len().is_power_of_two());
1162 assert!(f.len() >= 2);
1163 let inv2 = MInt::<M>::from(2).inv();
1164 let n = f.len() / 2;
1165 (0..n)
1166 .map(|i| (f[i << 1] * g[i << 1 | 1] + f[i << 1 | 1] * g[i << 1]) * inv2)
1167 .collect()
1168 }
1169
1170 fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
1171 assert_eq!(f.len(), g.len());
1172 assert!(f.len().is_power_of_two());
1173 assert!(f.len() >= 2);
1174 let mut inv2 = MInt::<M>::from(2).inv();
1175 let n = f.len() / 2;
1176 let k = f.len().trailing_zeros() as usize;
1177 let mut h = vec![MInt::<M>::zero(); n];
1178 let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
1179 BIT_REVERSE.with(|br| {
1180 let br = unsafe { &mut *br.get() };
1181 if br.len() < k {
1182 br.resize_with(k, Default::default);
1183 }
1184 let k = k - 1;
1185 if br[k].is_empty() {
1186 let mut v = vec![0; 1 << k];
1187 for i in 0..1 << k {
1188 v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
1189 }
1190 br[k] = v;
1191 }
1192 for &i in &br[k] {
1193 h[i] = (f[i << 1] * g[i << 1 | 1] - f[i << 1 | 1] * g[i << 1]) * inv2;
1194 inv2 *= w;
1195 }
1196 });
1197 h
1198 }
1199}
1200
1201impl<M, N1, N2, N3> NttReuse for Convolve<(M, (N1, N2, N3))>
1202where
1203 M: MIntConvert + MIntConvert<u32>,
1204 N1: Montgomery32NttModulus,
1205 N2: Montgomery32NttModulus,
1206 N3: Montgomery32NttModulus,
1207{
1208 fn ntt_doubling(f: Self::F) -> Self::F {
1209 (
1210 Convolve::<N1>::ntt_doubling(f.0),
1211 Convolve::<N2>::ntt_doubling(f.1),
1212 Convolve::<N3>::ntt_doubling(f.2),
1213 )
1214 }
1215
1216 fn even_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
1217 fn even_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
1218 where
1219 M: Montgomery32NttModulus,
1220 {
1221 let n = f.len();
1222 assert_eq!(f.len(), g.len());
1223 assert!(f.len().is_power_of_two());
1224 assert!(f.len() >= 2);
1225 let inv2 = MInt::<M>::from(2).inv();
1226 let u = MInt::<M>::new(m) * MInt::<M>::from(n as u32);
1227 let n = f.len() / 2;
1228 (0..n)
1229 .map(|i| {
1230 (f[i << 1]
1231 * if i == 0 {
1232 g[i << 1 | 1] + u
1233 } else {
1234 g[i << 1 | 1]
1235 }
1236 + f[i << 1 | 1] * g[i << 1])
1237 * inv2
1238 })
1239 .collect()
1240 }
1241
1242 let m = M::mod_into();
1243 (
1244 even_mul_normal_neg_corrected(&f.0, &g.0, m),
1245 even_mul_normal_neg_corrected(&f.1, &g.1, m),
1246 even_mul_normal_neg_corrected(&f.2, &g.2, m),
1247 )
1248 }
1249
1250 fn odd_mul_normal_neg(f: &Self::F, g: &Self::F) -> Self::F {
1251 fn odd_mul_normal_neg_corrected<M>(f: &[MInt<M>], g: &[MInt<M>], m: u32) -> Vec<MInt<M>>
1252 where
1253 M: Montgomery32NttModulus,
1254 {
1255 assert_eq!(f.len(), g.len());
1256 assert!(f.len().is_power_of_two());
1257 assert!(f.len() >= 2);
1258 let mut inv2 = MInt::<M>::from(2).inv();
1259 let u = MInt::<M>::new(m) * MInt::<M>::from(f.len() as u32);
1260 let n = f.len() / 2;
1261 let k = f.len().trailing_zeros() as usize;
1262 let mut h = vec![MInt::<M>::zero(); n];
1263 let w = MInt::<M>::new_unchecked(M::INFO.inv_root[k]);
1264 BIT_REVERSE.with(|br| {
1265 let br = unsafe { &mut *br.get() };
1266 if br.len() < k {
1267 br.resize_with(k, Default::default);
1268 }
1269 let k = k - 1;
1270 if br[k].is_empty() {
1271 let mut v = vec![0; 1 << k];
1272 for i in 0..1 << k {
1273 v[i] = (v[i >> 1] >> 1) | ((i & 1) << (k.saturating_sub(1)));
1274 }
1275 br[k] = v;
1276 }
1277 for &i in &br[k] {
1278 h[i] = (f[i << 1]
1279 * if i == 0 {
1280 g[i << 1 | 1] + u
1281 } else {
1282 g[i << 1 | 1]
1283 }
1284 - f[i << 1 | 1] * g[i << 1])
1285 * inv2;
1286 inv2 *= w;
1287 }
1288 });
1289 h
1290 }
1291
1292 let m = M::mod_into();
1293 (
1294 odd_mul_normal_neg_corrected(&f.0, &g.0, m),
1295 odd_mul_normal_neg_corrected(&f.1, &g.1, m),
1296 odd_mul_normal_neg_corrected(&f.2, &g.2, m),
1297 )
1298 }
1299}
1300
1301#[cfg(test)]
1302mod tests {
1303 use super::*;
1304 use crate::num::{mint_basic::Modulo1000000009, montgomery::MInt998244353};
1305 use crate::tools::Xorshift;
1306
1307 #[test]
1308 fn test_convolve_naive() {
1309 let mut rng = Xorshift::default();
1310 for _ in 0..1000 {
1311 let n = rng.random(0..=60);
1312 let m = rng.random(0..=60);
1313 let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
1314 let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
1315 let mut c = vec![0u32; (n + m).saturating_sub(1)];
1316 for i in 0..n {
1317 for j in 0..m {
1318 c[i + j] += a[i] * b[j];
1319 }
1320 }
1321 let d = convolve_naive(&a, &b);
1322 assert_eq!(c, d);
1323 }
1324 }
1325
1326 #[test]
1327 fn test_convolve_karatsuba() {
1328 let mut rng = Xorshift::default();
1329 for _ in 0..1000 {
1330 let n = rng.random(0..=200);
1331 let m = rng.random(0..=200);
1332 let a: Vec<u32> = rng.random_iter(0u32..1000).take(n).collect();
1333 let b: Vec<u32> = rng.random_iter(0u32..1000).take(m).collect();
1334 let mut c = vec![0u32; (n + m).saturating_sub(1)];
1335 for i in 0..n {
1336 for j in 0..m {
1337 c[i + j] += a[i] * b[j];
1338 }
1339 }
1340 let d = convolve_karatsuba(&a, &b);
1341 assert_eq!(c, d);
1342 }
1343 }
1344
1345 #[test]
1346 fn test_ntt998244353() {
1347 let mut rng = Xorshift::default();
1348 for t in 0..1000 {
1349 let n: usize = rng.random(0..=5);
1350 let n = if n == 5 { rng.random(70..=120) } else { n };
1351 let m: usize = rng.random(0..=5);
1352 let m = if m == 5 { rng.random(70..=120) } else { m };
1353 let (n, m) = if t % 100 != 0 {
1354 (n, m)
1355 } else {
1356 let w = rng.random(6..=8);
1357 ((1usize << w) + 1usize, (1usize << w) + 1usize)
1358 };
1359 let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
1360 let mut b: Vec<MInt998244353> = rng.random_iter(..).take(m).collect();
1361 if n == m && rng.random(0..2) == 0 {
1362 b = a.clone();
1363 }
1364
1365 let mut c = vec![MInt998244353::zero(); (n + m).saturating_sub(1)];
1366 for i in 0..n {
1367 for j in 0..m {
1368 c[i + j] += a[i] * b[j];
1369 }
1370 }
1371 let d = Convolve998244353::convolve(a, b);
1372 assert_eq!(c, d);
1373 }
1374 assert_eq!(NttInfo::new::<Modulo998244353>(), Modulo998244353::INFO);
1375 }
1376
1377 #[test]
1378 fn test_convolve3() {
1379 type M = MInt<Modulo1000000009>;
1380 let mut rng = Xorshift::default();
1381 for _ in 0..1000 {
1382 let n = rng.random(0..=5);
1383 let n = if n == 5 { rng.random(70..=400) } else { n };
1384 let m = rng.random(0..=5);
1385 let m = if m == 5 { rng.random(70..=400) } else { m };
1386 let a: Vec<M> = rng.random_iter(..).take(n).collect();
1387 let b: Vec<M> = rng.random_iter(..).take(m).collect();
1388 let mut c = vec![M::zero(); (n + m).saturating_sub(1)];
1389 for i in 0..n {
1390 for j in 0..m {
1391 c[i + j] += a[i] * b[j];
1392 }
1393 }
1394 let d = MIntConvolve::<Modulo1000000009>::convolve(a, b);
1395 assert_eq!(c, d);
1396 }
1397 }
1398
1399 #[test]
1400 fn test_convolve_u64() {
1401 let mut rng = Xorshift::default();
1402 for _ in 0..1000 {
1403 let n = rng.random(0..=5);
1404 let n = if n == 5 { rng.random(70..=400) } else { n };
1405 let m = rng.random(0..=5);
1406 let m = if m == 5 { rng.random(70..=400) } else { m };
1407 let a: Vec<u64> = rng.random_iter(0u64..1 << 24).take(n).collect();
1408 let b: Vec<u64> = rng.random_iter(0u64..1 << 24).take(m).collect();
1409 let mut c = vec![0; (n + m).saturating_sub(1)];
1410 for i in 0..n {
1411 for j in 0..m {
1412 c[i + j] += a[i] * b[j];
1413 }
1414 }
1415 let d = U64Convolve::convolve(a, b);
1416 assert_eq!(c, d);
1417 }
1418 }
1419
1420 #[test]
1421 fn test_ntt_reuse_998244353() {
1422 let mut rng = Xorshift::default();
1423 for _ in 0..100 {
1424 let n: usize = if rng.gen_bool(0.5) {
1425 rng.random(1..=20)
1426 } else {
1427 rng.random(1..=1000)
1428 };
1429 let a: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
1430 let f = Convolve998244353::transform(a.clone(), n);
1431
1432 {
1434 let f_double = Convolve998244353::ntt_doubling(f.clone());
1435 let mut a = a.clone();
1436 a.resize_with(n * 2, Zero::zero);
1437 let f2 = Convolve998244353::transform(a, n * 2);
1438 assert_eq!(f_double, f2);
1439 }
1440
1441 let f = Convolve998244353::transform(a.clone(), n * 2);
1442 let b: Vec<MInt998244353> = rng.random_iter(..).take(n).collect();
1443 let g = Convolve998244353::transform(b.clone(), n * 2);
1444 let mut b_neg = b.clone();
1445 for b in b_neg.iter_mut().skip(1).step_by(2) {
1446 *b = -*b;
1447 }
1448
1449 {
1451 let fg_neg = Convolve998244353::even_mul_normal_neg(&f, &g);
1452 let ab_neg_even: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
1453 .into_iter()
1454 .step_by(2)
1455 .collect();
1456 let fg = Convolve998244353::transform(ab_neg_even, n);
1457 assert_eq!(fg_neg, fg);
1458 }
1459
1460 {
1462 let fg_neg = Convolve998244353::odd_mul_normal_neg(&f, &g);
1463 let ab_neg_odd: Vec<_> = Convolve998244353::convolve(a.clone(), b_neg.clone())
1464 .into_iter()
1465 .skip(1)
1466 .step_by(2)
1467 .collect();
1468 let fg = Convolve998244353::transform(ab_neg_odd, n);
1469 assert_eq!(fg_neg, fg);
1470 }
1471 }
1472 }
1473
1474 #[test]
1475 fn test_ntt_reuse_triple() {
1476 type M = MInt<Modulo1000000009>;
1477 let mut rng = Xorshift::default();
1478 for _ in 0..100 {
1479 let n: usize = if rng.gen_bool(0.5) {
1480 rng.random(1..=20)
1481 } else {
1482 rng.random(1..=1000)
1483 };
1484 let a: Vec<M> = rng.random_iter(..).take(n).collect();
1485 let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n);
1486
1487 {
1489 let f_double = MIntConvolve::<Modulo1000000009>::ntt_doubling(f.clone());
1490 let mut a = a.clone();
1491 a.resize_with(n * 2, Zero::zero);
1492 let f2 = MIntConvolve::<Modulo1000000009>::transform(a, n * 2);
1493 assert_eq!(f_double, f2);
1494 }
1495
1496 let f = MIntConvolve::<Modulo1000000009>::transform(a.clone(), n * 2);
1497 let b: Vec<M> = rng.random_iter(..).take(n).collect();
1498 let g = MIntConvolve::<Modulo1000000009>::transform(b.clone(), n * 2);
1499 let mut b_neg = b.clone();
1500 for b in b_neg.iter_mut().skip(1).step_by(2) {
1501 *b = -*b;
1502 }
1503
1504 {
1506 let fg_neg = MIntConvolve::<Modulo1000000009>::even_mul_normal_neg(&f, &g);
1507 let ab_neg_even: Vec<_> =
1508 MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
1509 .into_iter()
1510 .step_by(2)
1511 .collect();
1512 assert_eq!(
1513 MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
1514 ab_neg_even
1515 );
1516 }
1517
1518 {
1520 let fg_neg = MIntConvolve::<Modulo1000000009>::odd_mul_normal_neg(&f, &g);
1521 let ab_neg_odd: Vec<_> =
1522 MIntConvolve::<Modulo1000000009>::convolve(a.clone(), b_neg.clone())
1523 .into_iter()
1524 .skip(1)
1525 .step_by(2)
1526 .chain([M::zero()])
1527 .collect();
1528 assert_eq!(
1529 MIntConvolve::<Modulo1000000009>::inverse_transform(fg_neg.clone(), n),
1530 ab_neg_odd
1531 );
1532 }
1533 }
1534 }
1535}