fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
Examples found in repository:
crates/competitive/src/math/number_theoretic_transform.rs (line 281)
276fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
277where
278 T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
279{
280 if a.len().min(b.len()) <= 30 {
281 return convolve_naive(a, b);
282 }
283 let m = a.len().max(b.len()).div_ceil(2);
284 let (a0, a1) = if a.len() <= m {
285 (a, &[][..])
286 } else {
287 a.split_at(m)
288 };
289 let (b0, b1) = if b.len() <= m {
290 (b, &[][..])
291 } else {
292 b.split_at(m)
293 };
294 let f00 = convolve_karatsuba(a0, b0);
295 let f11 = convolve_karatsuba(a1, b1);
296 let mut a0a1 = a0.to_vec();
297 for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
298 *a0a1 += a1;
299 }
300 let mut b0b1 = b0.to_vec();
301 for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
302 *b0b1 += b1;
303 }
304 let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
305 for (f01, &f00) in f01.iter_mut().zip(&f00) {
306 *f01 -= f00;
307 }
308 for (f01, &f11) in f01.iter_mut().zip(&f11) {
309 *f01 -= f11;
310 }
311 let mut c = vec![T::zero(); a.len() + b.len() - 1];
312 for (c, &f00) in c.iter_mut().zip(&f00) {
313 *c += f00;
314 }
315 for (c, &f01) in c[m..].iter_mut().zip(&f01) {
316 *c += f01;
317 }
318 for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
319 *c += f11;
320 }
321 c
322}
323
324impl<M> ConvolveSteps for Convolve<M>
325where
326 M: Montgomery32NttModulus,
327{
328 type T = Vec<MInt<M>>;
329 type F = Vec<MInt<M>>;
330 fn length(t: &Self::T) -> usize {
331 t.len()
332 }
333 fn transform(mut t: Self::T, len: usize) -> Self::F {
334 t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
335 ntt(&mut t);
336 t
337 }
338 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
339 intt(&mut f);
340 f.truncate(len);
341 let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
342 for f in f.iter_mut() {
343 *f *= inv;
344 }
345 f
346 }
347 fn multiply(f: &mut Self::F, g: &Self::F) {
348 assert_eq!(f.len(), g.len());
349 for (f, g) in f.iter_mut().zip(g.iter()) {
350 *f *= *g;
351 }
352 }
353 fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
354 if Self::length(&a).max(Self::length(&b)) <= 100 {
355 return convolve_karatsuba(&a, &b);
356 }
357 if Self::length(&a).min(Self::length(&b)) <= 60 {
358 return convolve_naive(&a, &b);
359 }
360 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
361 let size = len.max(1).next_power_of_two();
362 if len <= size / 2 + 2 {
363 let xa = a.pop().unwrap();
364 let xb = b.pop().unwrap();
365 let mut c = vec![MInt::<M>::zero(); len];
366 *c.last_mut().unwrap() = xa * xb;
367 for (a, c) in a.iter().zip(&mut c[b.len()..]) {
368 *c += *a * xb;
369 }
370 for (b, c) in b.iter().zip(&mut c[a.len()..]) {
371 *c += *b * xa;
372 }
373 let d = Self::convolve(a, b);
374 for (d, c) in d.into_iter().zip(&mut c) {
375 *c += d;
376 }
377 return c;
378 }
379 let same = a == b;
380 let mut a = Self::transform(a, len);
381 if same {
382 for a in a.iter_mut() {
383 *a *= *a;
384 }
385 } else {
386 let b = Self::transform(b, len);
387 Self::multiply(&mut a, &b);
388 }
389 Self::inverse_transform(a, len)
390 }
391}
392
393type MVec<M> = Vec<MInt<M>>;
394impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
395where
396 M: MIntConvert + MIntConvert<u32>,
397 N1: Montgomery32NttModulus,
398 N2: Montgomery32NttModulus,
399 N3: Montgomery32NttModulus,
400{
401 type T = MVec<M>;
402 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
403 fn length(t: &Self::T) -> usize {
404 t.len()
405 }
406 fn transform(t: Self::T, len: usize) -> Self::F {
407 let npot = len.max(1).next_power_of_two();
408 let mut f = (
409 MVec::<N1>::with_capacity(npot),
410 MVec::<N2>::with_capacity(npot),
411 MVec::<N3>::with_capacity(npot),
412 );
413 for t in t {
414 f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
415 f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
416 f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
417 }
418 f.0.resize_with(npot, Zero::zero);
419 f.1.resize_with(npot, Zero::zero);
420 f.2.resize_with(npot, Zero::zero);
421 ntt(&mut f.0);
422 ntt(&mut f.1);
423 ntt(&mut f.2);
424 f
425 }
426 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
427 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
428 let m1 = MInt::<M>::from(N1::get_mod());
429 let m1_3 = MInt::<N3>::new(N1::get_mod());
430 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
431 let m2 = m1 * MInt::<M>::from(N2::get_mod());
432 Convolve::<N1>::inverse_transform(f.0, len)
433 .into_iter()
434 .zip(Convolve::<N2>::inverse_transform(f.1, len))
435 .zip(Convolve::<N3>::inverse_transform(f.2, len))
436 .map(|((c1, c2), c3)| {
437 let d1 = c1.inner();
438 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
439 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
440 let d3 = ((c3 - x) * t2).inner();
441 MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
442 })
443 .collect()
444 }
445 fn multiply(f: &mut Self::F, g: &Self::F) {
446 assert_eq!(f.0.len(), g.0.len());
447 assert_eq!(f.1.len(), g.1.len());
448 assert_eq!(f.2.len(), g.2.len());
449 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
450 *f *= *g;
451 }
452 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
453 *f *= *g;
454 }
455 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
456 *f *= *g;
457 }
458 }
459 fn convolve(a: Self::T, b: Self::T) -> Self::T {
460 if Self::length(&a).max(Self::length(&b)) <= 300 {
461 return convolve_karatsuba(&a, &b);
462 }
463 if Self::length(&a).min(Self::length(&b)) <= 60 {
464 return convolve_naive(&a, &b);
465 }
466 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
467 let mut a = Self::transform(a, len);
468 let b = Self::transform(b, len);
469 Self::multiply(&mut a, &b);
470 Self::inverse_transform(a, len)
471 }
472}
473
474impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
475where
476 N1: Montgomery32NttModulus,
477 N2: Montgomery32NttModulus,
478 N3: Montgomery32NttModulus,
479{
480 type T = Vec<u64>;
481 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
482
483 fn length(t: &Self::T) -> usize {
484 t.len()
485 }
486
487 fn transform(t: Self::T, len: usize) -> Self::F {
488 let npot = len.max(1).next_power_of_two();
489 let mut f = (
490 MVec::<N1>::with_capacity(npot),
491 MVec::<N2>::with_capacity(npot),
492 MVec::<N3>::with_capacity(npot),
493 );
494 for t in t {
495 f.0.push(t.into());
496 f.1.push(t.into());
497 f.2.push(t.into());
498 }
499 f.0.resize_with(npot, Zero::zero);
500 f.1.resize_with(npot, Zero::zero);
501 f.2.resize_with(npot, Zero::zero);
502 ntt(&mut f.0);
503 ntt(&mut f.1);
504 ntt(&mut f.2);
505 f
506 }
507
508 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
509 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
510 let m1 = N1::get_mod() as u64;
511 let m1_3 = MInt::<N3>::new(N1::get_mod());
512 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
513 let m2 = m1 * N2::get_mod() as u64;
514 Convolve::<N1>::inverse_transform(f.0, len)
515 .into_iter()
516 .zip(Convolve::<N2>::inverse_transform(f.1, len))
517 .zip(Convolve::<N3>::inverse_transform(f.2, len))
518 .map(|((c1, c2), c3)| {
519 let d1 = c1.inner();
520 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
521 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
522 let d3 = ((c3 - x) * t2).inner();
523 d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
524 })
525 .collect()
526 }
527
528 fn multiply(f: &mut Self::F, g: &Self::F) {
529 assert_eq!(f.0.len(), g.0.len());
530 assert_eq!(f.1.len(), g.1.len());
531 assert_eq!(f.2.len(), g.2.len());
532 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
533 *f *= *g;
534 }
535 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
536 *f *= *g;
537 }
538 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
539 *f *= *g;
540 }
541 }
542
543 fn convolve(a: Self::T, b: Self::T) -> Self::T {
544 if Self::length(&a).max(Self::length(&b)) <= 300 {
545 return convolve_karatsuba(&a, &b);
546 }
547 if Self::length(&a).min(Self::length(&b)) <= 60 {
548 return convolve_naive(&a, &b);
549 }
550 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
551 let mut a = Self::transform(a, len);
552 let b = Self::transform(b, len);
553 Self::multiply(&mut a, &b);
554 Self::inverse_transform(a, len)
555 }