fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>

Examples found in repository?
crates/competitive/src/math/number_theoretic_transform.rs (line 858)
840fn convolve_karatsuba<T>(a: &[T], b: &[T]) -> Vec<T>
841where
842 T: Copy + Zero + AddAssign<T> + SubAssign<T> + Mul<Output = T>,
843{
844 if a.len().min(b.len()) <= 30 {
845 return convolve_naive(a, b);
846 }
847 let m = a.len().max(b.len()).div_ceil(2);
848 let (a0, a1) = if a.len() <= m {
849 (a, &[][..])
850 } else {
851 a.split_at(m)
852 };
853 let (b0, b1) = if b.len() <= m {
854 (b, &[][..])
855 } else {
856 b.split_at(m)
857 };
858 let f00 = convolve_karatsuba(a0, b0);
859 let f11 = convolve_karatsuba(a1, b1);
860 let mut a0a1 = a0.to_vec();
861 for (a0a1, &a1) in a0a1.iter_mut().zip(a1) {
862 *a0a1 += a1;
863 }
864 let mut b0b1 = b0.to_vec();
865 for (b0b1, &b1) in b0b1.iter_mut().zip(b1) {
866 *b0b1 += b1;
867 }
868 let mut f01 = convolve_karatsuba(&a0a1, &b0b1);
869 for (f01, &f00) in f01.iter_mut().zip(&f00) {
870 *f01 -= f00;
871 }
872 for (f01, &f11) in f01.iter_mut().zip(&f11) {
873 *f01 -= f11;
874 }
875 let mut c = vec![T::zero(); a.len() + b.len() - 1];
876 for (c, &f00) in c.iter_mut().zip(&f00) {
877 *c += f00;
878 }
879 for (c, &f01) in c[m..].iter_mut().zip(&f01) {
880 *c += f01;
881 }
882 for (c, &f11) in c[m << 1..].iter_mut().zip(&f11) {
883 *c += f11;
884 }
885 c
886}
887
888impl<M> ConvolveSteps for Convolve<M>
889where
890 M: Montgomery32NttModulus,
891{
892 type T = Vec<MInt<M>>;
893 type F = Vec<MInt<M>>;
894 fn length(t: &Self::T) -> usize {
895 t.len()
896 }
897 fn transform(mut t: Self::T, len: usize) -> Self::F {
898 t.resize_with(len.max(1).next_power_of_two(), Zero::zero);
899 ntt(&mut t);
900 t
901 }
902 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
903 intt(&mut f);
904 f.truncate(len);
905 let inv = MInt::from(len.max(1).next_power_of_two() as u32).inv();
906 for f in f.iter_mut() {
907 *f *= inv;
908 }
909 f
910 }
911 fn multiply(f: &mut Self::F, g: &Self::F) {
912 assert_eq!(f.len(), g.len());
913 for (f, g) in f.iter_mut().zip(g.iter()) {
914 *f *= *g;
915 }
916 }
917 fn convolve(mut a: Self::T, mut b: Self::T) -> Self::T {
918 if Self::length(&a).max(Self::length(&b)) <= 100 {
919 return convolve_karatsuba(&a, &b);
920 }
921 if Self::length(&a).min(Self::length(&b)) <= 60 {
922 return convolve_naive(&a, &b);
923 }
924 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
925 let size = len.max(1).next_power_of_two();
926 if len <= size / 2 + 2 {
927 let xa = a.pop().unwrap();
928 let xb = b.pop().unwrap();
929 let mut c = vec![MInt::<M>::zero(); len];
930 *c.last_mut().unwrap() = xa * xb;
931 for (a, c) in a.iter().zip(&mut c[b.len()..]) {
932 *c += *a * xb;
933 }
934 for (b, c) in b.iter().zip(&mut c[a.len()..]) {
935 *c += *b * xa;
936 }
937 let d = Self::convolve(a, b);
938 for (d, c) in d.into_iter().zip(&mut c) {
939 *c += d;
940 }
941 return c;
942 }
943 let same = a == b;
944 let mut a = Self::transform(a, len);
945 if same {
946 for a in a.iter_mut() {
947 *a *= *a;
948 }
949 } else {
950 let b = Self::transform(b, len);
951 Self::multiply(&mut a, &b);
952 }
953 Self::inverse_transform(a, len)
954 }
955}
956
957type MVec<M> = Vec<MInt<M>>;
958impl<M, N1, N2, N3> ConvolveSteps for Convolve<(M, (N1, N2, N3))>
959where
960 M: MIntConvert + MIntConvert<u32>,
961 N1: Montgomery32NttModulus,
962 N2: Montgomery32NttModulus,
963 N3: Montgomery32NttModulus,
964{
965 type T = MVec<M>;
966 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
967 fn length(t: &Self::T) -> usize {
968 t.len()
969 }
970 fn transform(t: Self::T, len: usize) -> Self::F {
971 let npot = len.max(1).next_power_of_two();
972 let mut f = (
973 MVec::<N1>::with_capacity(npot),
974 MVec::<N2>::with_capacity(npot),
975 MVec::<N3>::with_capacity(npot),
976 );
977 for t in t {
978 f.0.push(<M as MIntConvert<u32>>::into(t.inner()).into());
979 f.1.push(<M as MIntConvert<u32>>::into(t.inner()).into());
980 f.2.push(<M as MIntConvert<u32>>::into(t.inner()).into());
981 }
982 f.0.resize_with(npot, Zero::zero);
983 f.1.resize_with(npot, Zero::zero);
984 f.2.resize_with(npot, Zero::zero);
985 ntt(&mut f.0);
986 ntt(&mut f.1);
987 ntt(&mut f.2);
988 f
989 }
990 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
991 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
992 let m1 = MInt::<M>::from(N1::get_mod());
993 let m1_3 = MInt::<N3>::new(N1::get_mod());
994 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
995 let m2 = m1 * MInt::<M>::from(N2::get_mod());
996 Convolve::<N1>::inverse_transform(f.0, len)
997 .into_iter()
998 .zip(Convolve::<N2>::inverse_transform(f.1, len))
999 .zip(Convolve::<N3>::inverse_transform(f.2, len))
1000 .map(|((c1, c2), c3)| {
1001 let d1 = c1.inner();
1002 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1003 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1004 let d3 = ((c3 - x) * t2).inner();
1005 MInt::<M>::from(d1) + MInt::<M>::from(d2) * m1 + MInt::<M>::from(d3) * m2
1006 })
1007 .collect()
1008 }
1009 fn multiply(f: &mut Self::F, g: &Self::F) {
1010 assert_eq!(f.0.len(), g.0.len());
1011 assert_eq!(f.1.len(), g.1.len());
1012 assert_eq!(f.2.len(), g.2.len());
1013 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1014 *f *= *g;
1015 }
1016 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1017 *f *= *g;
1018 }
1019 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1020 *f *= *g;
1021 }
1022 }
1023 fn convolve(a: Self::T, b: Self::T) -> Self::T {
1024 if Self::length(&a).max(Self::length(&b)) <= 300 {
1025 return convolve_karatsuba(&a, &b);
1026 }
1027 if Self::length(&a).min(Self::length(&b)) <= 60 {
1028 return convolve_naive(&a, &b);
1029 }
1030 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1031 let mut a = Self::transform(a, len);
1032 let b = Self::transform(b, len);
1033 Self::multiply(&mut a, &b);
1034 Self::inverse_transform(a, len)
1035 }
1036}
1037
1038impl<N1, N2, N3> ConvolveSteps for Convolve<(u64, (N1, N2, N3))>
1039where
1040 N1: Montgomery32NttModulus,
1041 N2: Montgomery32NttModulus,
1042 N3: Montgomery32NttModulus,
1043{
1044 type T = Vec<u64>;
1045 type F = (MVec<N1>, MVec<N2>, MVec<N3>);
1046
1047 fn length(t: &Self::T) -> usize {
1048 t.len()
1049 }
1050
1051 fn transform(t: Self::T, len: usize) -> Self::F {
1052 let npot = len.max(1).next_power_of_two();
1053 let mut f = (
1054 MVec::<N1>::with_capacity(npot),
1055 MVec::<N2>::with_capacity(npot),
1056 MVec::<N3>::with_capacity(npot),
1057 );
1058 for t in t {
1059 f.0.push(t.into());
1060 f.1.push(t.into());
1061 f.2.push(t.into());
1062 }
1063 f.0.resize_with(npot, Zero::zero);
1064 f.1.resize_with(npot, Zero::zero);
1065 f.2.resize_with(npot, Zero::zero);
1066 ntt(&mut f.0);
1067 ntt(&mut f.1);
1068 ntt(&mut f.2);
1069 f
1070 }
1071
1072 fn inverse_transform(f: Self::F, len: usize) -> Self::T {
1073 let t1 = MInt::<N2>::new(N1::get_mod()).inv();
1074 let m1 = N1::get_mod() as u64;
1075 let m1_3 = MInt::<N3>::new(N1::get_mod());
1076 let t2 = (m1_3 * MInt::<N3>::new(N2::get_mod())).inv();
1077 let m2 = m1 * N2::get_mod() as u64;
1078 Convolve::<N1>::inverse_transform(f.0, len)
1079 .into_iter()
1080 .zip(Convolve::<N2>::inverse_transform(f.1, len))
1081 .zip(Convolve::<N3>::inverse_transform(f.2, len))
1082 .map(|((c1, c2), c3)| {
1083 let d1 = c1.inner();
1084 let d2 = ((c2 - MInt::<N2>::from(d1)) * t1).inner();
1085 let x = MInt::<N3>::new(d1) + MInt::<N3>::new(d2) * m1_3;
1086 let d3 = ((c3 - x) * t2).inner();
1087 d1 as u64 + d2 as u64 * m1 + d3 as u64 * m2
1088 })
1089 .collect()
1090 }
1091
1092 fn multiply(f: &mut Self::F, g: &Self::F) {
1093 assert_eq!(f.0.len(), g.0.len());
1094 assert_eq!(f.1.len(), g.1.len());
1095 assert_eq!(f.2.len(), g.2.len());
1096 for (f, g) in f.0.iter_mut().zip(g.0.iter()) {
1097 *f *= *g;
1098 }
1099 for (f, g) in f.1.iter_mut().zip(g.1.iter()) {
1100 *f *= *g;
1101 }
1102 for (f, g) in f.2.iter_mut().zip(g.2.iter()) {
1103 *f *= *g;
1104 }
1105 }
1106
1107 fn convolve(a: Self::T, b: Self::T) -> Self::T {
1108 if Self::length(&a).max(Self::length(&b)) <= 300 {
1109 return convolve_karatsuba(&a, &b);
1110 }
1111 if Self::length(&a).min(Self::length(&b)) <= 60 {
1112 return convolve_naive(&a, &b);
1113 }
1114 let len = (Self::length(&a) + Self::length(&b)).saturating_sub(1);
1115 let mut a = Self::transform(a, len);
1116 let b = Self::transform(b, len);
1117 Self::multiply(&mut a, &b);
1118 Self::inverse_transform(a, len)
1119 }