// competitive/math/bitwisexor_convolve.rs

use super::{ConvolveSteps, Field, Group, Invertible, bitwise_transform};
use std::{fmt::Debug, marker::PhantomData};

/// Marker type selecting bitwise-XOR convolution over the algebraic structure `M`.
///
/// The type stores no data; everything is exposed through associated functions.
/// `TRY` picks how a sequence length is turned into an element when the inverse
/// transform is normalized: `false` goes through `From<usize>`, `true` goes
/// through `TryFrom<usize>` and panics if that conversion fails.
pub struct BitwisexorConvolve<M, const TRY: bool = false> {
    // `fn() -> M` mentions `M` without the marker owning a value of that type.
    _marker: PhantomData<fn() -> M>,
}

8impl<G, const TRY: bool> BitwisexorConvolve<G, TRY>
9where
10    G: Group,
11{
12    pub fn hadamard_transform(f: &mut [G::T]) {
13        bitwise_transform(f, |x, y| {
14            let t = G::operate(x, y);
15            *y = G::rinv_operate(x, y);
16            *x = t;
17        });
18    }
19}
20
21impl<R> ConvolveSteps for BitwisexorConvolve<R, false>
22where
23    R: Field<T: PartialEq + From<usize>, Additive: Invertible, Multiplicative: Invertible>,
24{
25    type T = Vec<R::T>;
26    type F = Vec<R::T>;
27
28    fn length(t: &Self::T) -> usize {
29        t.len()
30    }
31
32    fn transform(mut t: Self::T, _len: usize) -> Self::F {
33        BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut t);
34        t
35    }
36
37    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
38        BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut f);
39        let len = R::T::from(len);
40        for f in f.iter_mut() {
41            *f = R::div(f, &len);
42        }
43        f
44    }
45
46    fn multiply(f: &mut Self::F, g: &Self::F) {
47        for (f, g) in f.iter_mut().zip(g) {
48            *f = R::mul(f, g);
49        }
50    }
51
52    fn convolve(a: Self::T, b: Self::T) -> Self::T {
53        assert_eq!(a.len(), b.len());
54        let len = a.len();
55        let same = a == b;
56        let mut a = Self::transform(a, len);
57        if same {
58            for a in a.iter_mut() {
59                *a = R::mul(a, a);
60            }
61        } else {
62            let b = Self::transform(b, len);
63            Self::multiply(&mut a, &b);
64        }
65        Self::inverse_transform(a, len)
66    }
67}
68
69impl<R> ConvolveSteps for BitwisexorConvolve<R, true>
70where
71    R: Field<T: PartialEq + TryFrom<usize>, Additive: Invertible, Multiplicative: Invertible>,
72    <R::T as TryFrom<usize>>::Error: Debug,
73{
74    type T = Vec<R::T>;
75    type F = Vec<R::T>;
76
77    fn length(t: &Self::T) -> usize {
78        t.len()
79    }
80
81    fn transform(mut t: Self::T, _len: usize) -> Self::F {
82        BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut t);
83        t
84    }
85
86    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
87        BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut f);
88        let len = R::T::try_from(len).unwrap();
89        for f in f.iter_mut() {
90            *f = R::div(f, &len);
91        }
92        f
93    }
94
95    fn multiply(f: &mut Self::F, g: &Self::F) {
96        for (f, g) in f.iter_mut().zip(g) {
97            *f = R::mul(f, g);
98        }
99    }
100
101    fn convolve(a: Self::T, b: Self::T) -> Self::T {
102        assert_eq!(a.len(), b.len());
103        let len = a.len();
104        let same = a == b;
105        let mut a = Self::transform(a, len);
106        if same {
107            for a in a.iter_mut() {
108                *a = R::mul(a, a);
109            }
110        } else {
111            let b = Self::transform(b, len);
112            Self::multiply(&mut a, &b);
113        }
114        Self::inverse_transform(a, len)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AddMulOperation, rand, tools::Xorshift};

    // Bound of the `-A..A` range the random coefficients are drawn from.
    const A: i64 = 100_000;

    /// Compares the transform-based convolution of two random vectors with the
    /// quadratic definition `h[i ^ j] += f[i] * g[j]` for every power-of-two
    /// length up to `2^11`.
    #[test]
    fn test_bitwisexor_convolve() {
        let mut rng = Xorshift::default();

        for k in 0..12 {
            let n = 1 << k;
            rand!(rng, f: [-A..A; n], g: [-A..A; n]);
            let mut h = vec![0i64; n];
            for i in 0..n {
                for j in 0..n {
                    h[i ^ j] += f[i] * g[j];
                }
            }
            let i = BitwisexorConvolve::<AddMulOperation<i64>, true>::convolve(f, g);
            assert_eq!(h, i);
        }
    }

    /// Convolves a random vector with itself so that the `a == b` shortcut
    /// (square the single transform instead of transforming twice) is covered.
    #[test]
    fn test_bitwisexor_convolve_square() {
        let mut rng = Xorshift::default();

        for k in 0..10 {
            let n = 1 << k;
            rand!(rng, f: [-A..A; n]);
            let mut expected = vec![0i64; n];
            for i in 0..n {
                for j in 0..n {
                    expected[i ^ j] += f[i] * f[j];
                }
            }
            let actual =
                BitwisexorConvolve::<AddMulOperation<i64>, true>::convolve(f.clone(), f);
            assert_eq!(expected, actual);
        }
    }
}