convolve_naive

Function convolve_naive 

Source
fn convolve_naive<T>(a: &[T], b: &[T]) -> Vec<T>
where T: Copy + Zero + AddAssign<T> + Mul<Output = T>,
Examples found in repository?
crates/competitive/src/math/number_theoretic_transform.rs (line 281)
276fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
277where
278    T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
279{
280    if a.len().min(b.len()) <= 30 {
281        return convolve_naive(a, b);
282    }
283    let m = a.len().max(b.len()).div_ceil(2);
284    let (a0, a1) = if a.len() <= m {
285        (a, &[][..])
286    } else {
287        a.split_at(m)
288    };
289    let (b0, b1) = if b.len() <= m {
290        (b, &[][..])
291    } else {
292        b.split_at(m)
293    };
294    let f00 = convolve_karatsuba(a0, b0);
295    let f11 = convolve_karatsuba(a1, b1);
296    let mut a0a1 = a0.to_vec();
297    for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
298        *a0a1 += a1;
299    }
300    let mut b0b1 = b0.to_vec();
301    for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
302        *b0b1 += b1;
303    }
304    let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
305    for (f01, &f00) in f01.iter_mut().zip(&f00) {
306        *f01 -= f00;
307    }
308    for (f01, &f11) in f01.iter_mut().zip(&f11) {
309        *f01 -= f11;
310    }
311    let mut c = vec![T::zero(); a.len() + b.len() - 1];
312    for (c, &f00) in c.iter_mut().zip(&f00) {
313        *c += f00;
314    }
315    for (c, &f01) in c[m..].iter_mut().zip(&f01) {
316        *c += f01;
317    }
318    for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
319        *c += f11;
320    }
321    c
322}
323
324impl<M> ConvolveSteps for Convolve<M>
325where
326    M: Montgomery32NttModulus,
327{
328    type T = Vec<MInt<M>>;
329    type F = Vec<MInt<M>>;
330    fn length(t: &Self::T) -> usize {
331        t.len()
332    }
333    fn transform(mut t: Self::T, len: usize) -> Self::F {
334        t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
335        ntt(&mut t);
336        t
337    }
338    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
339        intt(&mut f);
340        f.truncate(len);
341        let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
342        for f in f.iter_mut() {
343            *f *= inv;
344        }
345        f
346    }
347    fn multiply(f: &mut Self::F, g: &Self::F) {
348        assert_eq!(f.len(), g.len());
349        for (f, g) in f.iter_mut().zip(g.iter()) {
350            *f *= *g;
351        }
352    }
353    fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
354        if Self::length(&a).max(Self::length(&b)) <= 100 {
355            return convolve_karatsuba(&a, &b);
356        }
357        if Self::length(&a).min(Self::length(&b)) <= 60 {
358            return convolve_naive(&a, &b);
359        }
360        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
361        let size = len.max(1).next_power_of_two();
362        if len <= size / 2 + 2 {
363            let xa = a.pop().unwrap();
364            let xb = b.pop().unwrap();
365            let mut c = vec![MInt::<M>::zero(); len];
366            *c.last_mut().unwrap() = xa * xb;
367            for (a, c) in a.iter().zip(&mut c[b.len()..]) {
368                *c += *a * xb;
369            }
370            for (b, c) in b.iter().zip(&mut c[a.len()..]) {
371                *c += *b * xa;
372            }
373            let d = Self::convolve(a, b);
374            for (d, c) in d.into_iter().zip(&mut c) {
375                *c += d;
376            }
377            return c;
378        }
379        let same = a == b;
380        let mut a = Self::transform(a, len);
381        if same {
382            for a in a.iter_mut() {
383                *a *= *a;
384            }
385        } else {
386            let b = Self::transform(b, len);
387            Self::multiply(&mut a, &b);
388        }
389        Self::inverse_transform(a, len)
390    }
391}
392
393type MVec<M> = Vec<MInt<M>>;
394impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
395where
396    M: MIntConvert + MIntConvert<u32>,
397    N1: Montgomery32NttModulus,
398    N2: Montgomery32NttModulus,
399    N3: Montgomery32NttModulus,
400{
401    type T = MVec<M>;
402    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
403    fn length(t: &Self::T) -> usize {
404        t.len()
405    }
406    fn transform(t: Self::T, len: usize) -> Self::F {
407        let npot = len.max(1).next_power_of_two();
408        let mut f = (
409            MVec::<N1>::with_capacity(npot),
410            MVec::<N2>::with_capacity(npot),
411            MVec::<N3>::with_capacity(npot),
412        );
413        for t in t {
414            f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
415            f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
416            f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
417        }
418        f.0.resize_with(npot, Zero::zero);
419        f.1.resize_with(npot, Zero::zero);
420        f.2.resize_with(npot, Zero::zero);
421        ntt(&mut f.0);
422        ntt(&mut f.1);
423        ntt(&mut f.2);
424        f
425    }
426    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
427        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
428        let m1 = MInt::<M>::from(N1::get_mod());
429        let m1_3 = MInt::<N3>::new(N1::get_mod());
430        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
431        let m2 = m1 * MInt::<M>::from(N2::get_mod());
432        Convolve::<N1>::inverse_transform(f.0, len)
433            .into_iter()
434            .zip(Convolve::<N2>::inverse_transform(f.1, len))
435            .zip(Convolve::<N3>::inverse_transform(f.2, len))
436            .map(|((c1, c2), c3)| {
437                let d1 = c1.inner();
438                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
439                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
440                let d3 = ((c3 - x) * t2).inner();
441                MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
442            })
443            .collect()
444    }
445    fn multiply(f: &mut Self::F, g: &Self::F) {
446        assert_eq!(f.0.len(), g.0.len());
447        assert_eq!(f.1.len(), g.1.len());
448        assert_eq!(f.2.len(), g.2.len());
449        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
450            *f *= *g;
451        }
452        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
453            *f *= *g;
454        }
455        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
456            *f *= *g;
457        }
458    }
459    fn convolve(a: Self::T, b: Self::T) -> Self::T {
460        if Self::length(&a).max(Self::length(&b)) <= 300 {
461            return convolve_karatsuba(&a, &b);
462        }
463        if Self::length(&a).min(Self::length(&b)) <= 60 {
464            return convolve_naive(&a, &b);
465        }
466        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
467        let mut a = Self::transform(a, len);
468        let b = Self::transform(b, len);
469        Self::multiply(&mut a, &b);
470        Self::inverse_transform(a, len)
471    }
472}
473
474impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
475where
476    N1: Montgomery32NttModulus,
477    N2: Montgomery32NttModulus,
478    N3: Montgomery32NttModulus,
479{
480    type T = Vec<u64>;
481    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
482
483    fn length(t: &Self::T) -> usize {
484        t.len()
485    }
486
487    fn transform(t: Self::T, len: usize) -> Self::F {
488        let npot = len.max(1).next_power_of_two();
489        let mut f = (
490            MVec::<N1>::with_capacity(npot),
491            MVec::<N2>::with_capacity(npot),
492            MVec::<N3>::with_capacity(npot),
493        );
494        for t in t {
495            f.0.push(t.into());
496            f.1.push(t.into());
497            f.2.push(t.into());
498        }
499        f.0.resize_with(npot, Zero::zero);
500        f.1.resize_with(npot, Zero::zero);
501        f.2.resize_with(npot, Zero::zero);
502        ntt(&mut f.0);
503        ntt(&mut f.1);
504        ntt(&mut f.2);
505        f
506    }
507
508    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
509        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
510        let m1 = N1::get_mod() as u64;
511        let m1_3 = MInt::<N3>::new(N1::get_mod());
512        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
513        let m2 = m1 * N2::get_mod() as u64;
514        Convolve::<N1>::inverse_transform(f.0, len)
515            .into_iter()
516            .zip(Convolve::<N2>::inverse_transform(f.1, len))
517            .zip(Convolve::<N3>::inverse_transform(f.2, len))
518            .map(|((c1, c2), c3)| {
519                let d1 = c1.inner();
520                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
521                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
522                let d3 = ((c3 - x) * t2).inner();
523                d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
524            })
525            .collect()
526    }
527
528    fn multiply(f: &mut Self::F, g: &Self::F) {
529        assert_eq!(f.0.len(), g.0.len());
530        assert_eq!(f.1.len(), g.1.len());
531        assert_eq!(f.2.len(), g.2.len());
532        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
533            *f *= *g;
534        }
535        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
536            *f *= *g;
537        }
538        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
539            *f *= *g;
540        }
541    }
542
543    fn convolve(a: Self::T, b: Self::T) -> Self::T {
544        if Self::length(&a).max(Self::length(&b)) <= 300 {
545            return convolve_karatsuba(&a, &b);
546        }
547        if Self::length(&a).min(Self::length(&b)) <= 60 {
548            return convolve_naive(&a, &b);
549        }
550        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
551        let mut a = Self::transform(a, len);
552        let b = Self::transform(b, len);
553        Self::multiply(&mut a, &b);
554        Self::inverse_transform(a, len)
555    }