competitive/math/
bitwisexor_convolve.rs1use super::{ConvolveSteps, Field, Group, Invertible, bitwise_transform};
2use std::{fmt::Debug, marker::PhantomData};
3
/// Bitwise-XOR convolution, `h[i ^ j] += f[i] * g[j]`, computed through the
/// Walsh–Hadamard transform.
///
/// `M` is the algebraic structure supplying the element operations. The
/// `TRY` flag selects how the transform length is converted into the element
/// type when normalizing the inverse transform: `false` uses `From<usize>`,
/// `true` uses `TryFrom<usize>` and panics if the conversion fails.
///
/// The type carries no data; it only exists to host associated functions.
pub struct BitwisexorConvolve<M, const TRY: bool = false> {
    // `fn() -> M` keeps the marker covariant in `M` without implying ownership of an `M`.
    _phantom: PhantomData<fn() -> M>,
}
7
8impl<G, const TRY: bool> BitwisexorConvolve<G, TRY>
9where
10 G: Group,
11{
12 pub fn hadamard_transform(f: &mut [G::T]) {
13 bitwise_transform(f, |x, y| {
14 let t = G::operate(x, y);
15 *y = G::rinv_operate(x, y);
16 *x = t;
17 });
18 }
19}
20
21impl<R> ConvolveSteps for BitwisexorConvolve<R, false>
22where
23 R: Field<T: PartialEq + From<usize>, Additive: Invertible, Multiplicative: Invertible>,
24{
25 type T = Vec<R::T>;
26 type F = Vec<R::T>;
27
28 fn length(t: &Self::T) -> usize {
29 t.len()
30 }
31
32 fn transform(mut t: Self::T, _len: usize) -> Self::F {
33 BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut t);
34 t
35 }
36
37 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
38 BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut f);
39 let len = R::T::from(len);
40 for f in f.iter_mut() {
41 *f = R::div(f, &len);
42 }
43 f
44 }
45
46 fn multiply(f: &mut Self::F, g: &Self::F) {
47 for (f, g) in f.iter_mut().zip(g) {
48 *f = R::mul(f, g);
49 }
50 }
51
52 fn convolve(a: Self::T, b: Self::T) -> Self::T {
53 assert_eq!(a.len(), b.len());
54 let len = a.len();
55 let same = a == b;
56 let mut a = Self::transform(a, len);
57 if same {
58 for a in a.iter_mut() {
59 *a = R::mul(a, a);
60 }
61 } else {
62 let b = Self::transform(b, len);
63 Self::multiply(&mut a, &b);
64 }
65 Self::inverse_transform(a, len)
66 }
67}
68
69impl<R> ConvolveSteps for BitwisexorConvolve<R, true>
70where
71 R: Field<T: PartialEq + TryFrom<usize>, Additive: Invertible, Multiplicative: Invertible>,
72 <R::T as TryFrom<usize>>::Error: Debug,
73{
74 type T = Vec<R::T>;
75 type F = Vec<R::T>;
76
77 fn length(t: &Self::T) -> usize {
78 t.len()
79 }
80
81 fn transform(mut t: Self::T, _len: usize) -> Self::F {
82 BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut t);
83 t
84 }
85
86 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
87 BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut f);
88 let len = R::T::try_from(len).unwrap();
89 for f in f.iter_mut() {
90 *f = R::div(f, &len);
91 }
92 f
93 }
94
95 fn multiply(f: &mut Self::F, g: &Self::F) {
96 for (f, g) in f.iter_mut().zip(g) {
97 *f = R::mul(f, g);
98 }
99 }
100
101 fn convolve(a: Self::T, b: Self::T) -> Self::T {
102 assert_eq!(a.len(), b.len());
103 let len = a.len();
104 let same = a == b;
105 let mut a = Self::transform(a, len);
106 if same {
107 for a in a.iter_mut() {
108 *a = R::mul(a, a);
109 }
110 } else {
111 let b = Self::transform(b, len);
112 Self::multiply(&mut a, &b);
113 }
114 Self::inverse_transform(a, len)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AddMulOperation, rand, tools::Xorshift};

    const A: i64 = 100_000;

    /// Reference O(n^2) XOR convolution: `h[i ^ j] += f[i] * g[j]`.
    /// Assumes `f.len() == g.len()` is a power of two, so `i ^ j` stays in range.
    fn naive(f: &[i64], g: &[i64]) -> Vec<i64> {
        let mut h = vec![0i64; f.len()];
        for (i, &fi) in f.iter().enumerate() {
            for (j, &gj) in g.iter().enumerate() {
                h[i ^ j] += fi * gj;
            }
        }
        h
    }

    #[test]
    fn test_bitwisexor_convolve() {
        let mut rng = Xorshift::default();

        for k in 0..12 {
            let n = 1 << k;
            rand!(rng, f: [-A..A; n], g: [-A..A; n]);
            let h = naive(&f, &g);
            let i = BitwisexorConvolve::<AddMulOperation<i64>, true>::convolve(f, g);
            assert_eq!(h, i);
        }
    }

    /// Equal operands take the `a == b` squaring shortcut in `convolve`, which
    /// independent random inputs essentially never reach.
    #[test]
    fn test_bitwisexor_convolve_square() {
        let mut rng = Xorshift::default();

        for k in 0..12 {
            let n = 1 << k;
            rand!(rng, f: [-A..A; n]);
            let h = naive(&f, &f);
            let i = BitwisexorConvolve::<AddMulOperation<i64>, true>::convolve(f.clone(), f);
            assert_eq!(h, i);
        }
    }
}