Skip to main content

convolve_karatsuba

Function convolve_karatsuba 

Source
fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
where T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
Examples found in repository?
crates/competitive/src/math/number_theoretic_transform.rs (line 858)
840fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
841where
842    T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
843{
844    if a.len().min(b.len()) <= 30 {
845        return convolve_naive(a, b);
846    }
847    let m = a.len().max(b.len()).div_ceil(2);
848    let (a0, a1) = if a.len() <= m {
849        (a, &[][..])
850    } else {
851        a.split_at(m)
852    };
853    let (b0, b1) = if b.len() <= m {
854        (b, &[][..])
855    } else {
856        b.split_at(m)
857    };
858    let f00 = convolve_karatsuba(a0, b0);
859    let f11 = convolve_karatsuba(a1, b1);
860    let mut a0a1 = a0.to_vec();
861    for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
862        *a0a1 += a1;
863    }
864    let mut b0b1 = b0.to_vec();
865    for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
866        *b0b1 += b1;
867    }
868    let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
869    for (f01, &f00) in f01.iter_mut().zip(&f00) {
870        *f01 -= f00;
871    }
872    for (f01, &f11) in f01.iter_mut().zip(&f11) {
873        *f01 -= f11;
874    }
875    let mut c = vec![T::zero(); a.len() + b.len() - 1];
876    for (c, &f00) in c.iter_mut().zip(&f00) {
877        *c += f00;
878    }
879    for (c, &f01) in c[m..].iter_mut().zip(&f01) {
880        *c += f01;
881    }
882    for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
883        *c += f11;
884    }
885    c
886}
887
888impl<M> ConvolveSteps for Convolve<M>
889where
890    M: Montgomery32NttModulus,
891{
892    type T = Vec<MInt<M>>;
893    type F = Vec<MInt<M>>;
894    fn length(t: &Self::T) -> usize {
895        t.len()
896    }
897    fn transform(mut t: Self::T, len: usize) -> Self::F {
898        t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
899        ntt(&mut t);
900        t
901    }
902    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
903        intt(&mut f);
904        f.truncate(len);
905        let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
906        for f in f.iter_mut() {
907            *f *= inv;
908        }
909        f
910    }
911    fn multiply(f: &mut Self::F, g: &Self::F) {
912        assert_eq!(f.len(), g.len());
913        for (f, g) in f.iter_mut().zip(g.iter()) {
914            *f *= *g;
915        }
916    }
917    fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
918        if Self::length(&a).max(Self::length(&b)) <= 100 {
919            return convolve_karatsuba(&a, &b);
920        }
921        if Self::length(&a).min(Self::length(&b)) <= 60 {
922            return convolve_naive(&a, &b);
923        }
924        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
925        let size = len.max(1).next_power_of_two();
926        if len <= size / 2 + 2 {
927            let xa = a.pop().unwrap();
928            let xb = b.pop().unwrap();
929            let mut c = vec![MInt::<M>::zero(); len];
930            *c.last_mut().unwrap() = xa * xb;
931            for (a, c) in a.iter().zip(&mut c[b.len()..]) {
932                *c += *a * xb;
933            }
934            for (b, c) in b.iter().zip(&mut c[a.len()..]) {
935                *c += *b * xa;
936            }
937            let d = Self::convolve(a, b);
938            for (d, c) in d.into_iter().zip(&mut c) {
939                *c += d;
940            }
941            return c;
942        }
943        let same = a == b;
944        let mut a = Self::transform(a, len);
945        if same {
946            for a in a.iter_mut() {
947                *a *= *a;
948            }
949        } else {
950            let b = Self::transform(b, len);
951            Self::multiply(&mut a, &b);
952        }
953        Self::inverse_transform(a, len)
954    }
955}
956
957type MVec<M> = Vec<MInt<M>>;
958impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
959where
960    M: MIntConvert + MIntConvert<u32>,
961    N1: Montgomery32NttModulus,
962    N2: Montgomery32NttModulus,
963    N3: Montgomery32NttModulus,
964{
965    type T = MVec<M>;
966    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
967    fn length(t: &Self::T) -> usize {
968        t.len()
969    }
970    fn transform(t: Self::T, len: usize) -> Self::F {
971        let npot = len.max(1).next_power_of_two();
972        let mut f = (
973            MVec::<N1>::with_capacity(npot),
974            MVec::<N2>::with_capacity(npot),
975            MVec::<N3>::with_capacity(npot),
976        );
977        for t in t {
978            f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
979            f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
980            f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
981        }
982        f.0.resize_with(npot, Zero::zero);
983        f.1.resize_with(npot, Zero::zero);
984        f.2.resize_with(npot, Zero::zero);
985        ntt(&mut f.0);
986        ntt(&mut f.1);
987        ntt(&mut f.2);
988        f
989    }
990    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
991        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
992        let m1 = MInt::<M>::from(N1::get_mod());
993        let m1_3 = MInt::<N3>::new(N1::get_mod());
994        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
995        let m2 = m1 * MInt::<M>::from(N2::get_mod());
996        Convolve::<N1>::inverse_transform(f.0, len)
997            .into_iter()
998            .zip(Convolve::<N2>::inverse_transform(f.1, len))
999            .zip(Convolve::<N3>::inverse_transform(f.2, len))
1000            .map(|((c1, c2), c3)| {
1001                let d1 = c1.inner();
1002                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1003                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1004                let d3 = ((c3 - x) * t2).inner();
1005                MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
1006            })
1007            .collect()
1008    }
1009    fn multiply(f: &mut Self::F, g: &Self::F) {
1010        assert_eq!(f.0.len(), g.0.len());
1011        assert_eq!(f.1.len(), g.1.len());
1012        assert_eq!(f.2.len(), g.2.len());
1013        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1014            *f *= *g;
1015        }
1016        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1017            *f *= *g;
1018        }
1019        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1020            *f *= *g;
1021        }
1022    }
1023    fn convolve(a: Self::T, b: Self::T) -> Self::T {
1024        if Self::length(&a).max(Self::length(&b)) <= 300 {
1025            return convolve_karatsuba(&a, &b);
1026        }
1027        if Self::length(&a).min(Self::length(&b)) <= 60 {
1028            return convolve_naive(&a, &b);
1029        }
1030        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1031        let mut a = Self::transform(a, len);
1032        let b = Self::transform(b, len);
1033        Self::multiply(&mut a, &b);
1034        Self::inverse_transform(a, len)
1035    }
1036}
1037
1038impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
1039where
1040    N1: Montgomery32NttModulus,
1041    N2: Montgomery32NttModulus,
1042    N3: Montgomery32NttModulus,
1043{
1044    type T = Vec<u64>;
1045    type F = (MVec<N1>, MVec<N2>, MVec<N3>);
1046
1047    fn length(t: &Self::T) -> usize {
1048        t.len()
1049    }
1050
1051    fn transform(t: Self::T, len: usize) -> Self::F {
1052        let npot = len.max(1).next_power_of_two();
1053        let mut f = (
1054            MVec::<N1>::with_capacity(npot),
1055            MVec::<N2>::with_capacity(npot),
1056            MVec::<N3>::with_capacity(npot),
1057        );
1058        for t in t {
1059            f.0.push(t.into());
1060            f.1.push(t.into());
1061            f.2.push(t.into());
1062        }
1063        f.0.resize_with(npot, Zero::zero);
1064        f.1.resize_with(npot, Zero::zero);
1065        f.2.resize_with(npot, Zero::zero);
1066        ntt(&mut f.0);
1067        ntt(&mut f.1);
1068        ntt(&mut f.2);
1069        f
1070    }
1071
1072    fn inverse_transform(f: Self::F, len: usize) -> Self::T {
1073        let t1 = MInt::<N2>::new(N1::get_mod()).inv();
1074        let m1 = N1::get_mod() as u64;
1075        let m1_3 = MInt::<N3>::new(N1::get_mod());
1076        let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
1077        let m2 = m1 * N2::get_mod() as u64;
1078        Convolve::<N1>::inverse_transform(f.0, len)
1079            .into_iter()
1080            .zip(Convolve::<N2>::inverse_transform(f.1, len))
1081            .zip(Convolve::<N3>::inverse_transform(f.2, len))
1082            .map(|((c1, c2), c3)| {
1083                let d1 = c1.inner();
1084                let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1085                let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1086                let d3 = ((c3 - x) * t2).inner();
1087                d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
1088            })
1089            .collect()
1090    }
1091
1092    fn multiply(f: &mut Self::F, g: &Self::F) {
1093        assert_eq!(f.0.len(), g.0.len());
1094        assert_eq!(f.1.len(), g.1.len());
1095        assert_eq!(f.2.len(), g.2.len());
1096        for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1097            *f *= *g;
1098        }
1099        for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1100            *f *= *g;
1101        }
1102        for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1103            *f *= *g;
1104        }
1105    }
1106
1107    fn convolve(a: Self::T, b: Self::T) -> Self::T {
1108        if Self::length(&a).max(Self::length(&b)) <= 300 {
1109            return convolve_karatsuba(&a, &b);
1110        }
1111        if Self::length(&a).min(Self::length(&b)) <= 60 {
1112            return convolve_naive(&a, &b);
1113        }
1114        let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1115        let mut a = Self::transform(a, len);
1116        let b = Self::transform(b, len);
1117        Self::multiply(&mut a, &b);
1118        Self::inverse_transform(a, len)
1119    }