competitive/math/
bitwisexor_convolve.rs

1use super::{ConvolveSteps, Field, Group, Invertible, bitwise_transform};
2use std::{fmt::Debug, marker::PhantomData};
3
/// Xor (bitwise exclusive-or) convolution over the algebraic structure `M`.
///
/// This is a zero-sized marker type; all functionality is provided through associated
/// functions. The `TRY` flag selects how the transform length is turned into a coefficient
/// for the final division: `false` requires `From<usize>`, `true` uses `TryFrom<usize>` and
/// panics if the conversion fails.
pub struct BitwisexorConvolve<M, const TRY: bool = false> {
    /// `fn() -> M` ties the type to `M` (covariantly) without owning an `M` value.
    _phantom: PhantomData<fn() -> M>,
}
7
8impl<G, const TRY: bool> BitwisexorConvolve<G, TRY>
9where
10    G: Group,
11{
12    pub fn hadamard_transform(f: &mut [G::T]) {
13        bitwise_transform(f, |x, y| {
14            let t = G::operate(x, y);
15            *y = G::rinv_operate(x, y);
16            *x = t;
17        });
18    }
19}
20
21impl<R> ConvolveSteps for BitwisexorConvolve<R, false>
22where
23    R: Field,
24    R::T: PartialEq,
25    R::Additive: Invertible,
26    R::Multiplicative: Invertible,
27    R::T: From<usize>,
28{
29    type T = Vec<R::T>;
30    type F = Vec<R::T>;
31
32    fn length(t: &Self::T) -> usize {
33        t.len()
34    }
35
36    fn transform(mut t: Self::T, _len: usize) -> Self::F {
37        BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut t);
38        t
39    }
40
41    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
42        BitwisexorConvolve::<R::Additive, false>::hadamard_transform(&mut f);
43        let len = R::T::from(len);
44        for f in f.iter_mut() {
45            *f = R::div(f, &len);
46        }
47        f
48    }
49
50    fn multiply(f: &mut Self::F, g: &Self::F) {
51        for (f, g) in f.iter_mut().zip(g) {
52            *f = R::mul(f, g);
53        }
54    }
55
56    fn convolve(a: Self::T, b: Self::T) -> Self::T {
57        assert_eq!(a.len(), b.len());
58        let len = a.len();
59        let same = a == b;
60        let mut a = Self::transform(a, len);
61        if same {
62            for a in a.iter_mut() {
63                *a = R::mul(a, a);
64            }
65        } else {
66            let b = Self::transform(b, len);
67            Self::multiply(&mut a, &b);
68        }
69        Self::inverse_transform(a, len)
70    }
71}
72
73impl<R> ConvolveSteps for BitwisexorConvolve<R, true>
74where
75    R: Field,
76    R::T: PartialEq,
77    R::Additive: Invertible,
78    R::Multiplicative: Invertible,
79    R::T: TryFrom<usize>,
80    <R::T as TryFrom<usize>>::Error: Debug,
81{
82    type T = Vec<R::T>;
83    type F = Vec<R::T>;
84
85    fn length(t: &Self::T) -> usize {
86        t.len()
87    }
88
89    fn transform(mut t: Self::T, _len: usize) -> Self::F {
90        BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut t);
91        t
92    }
93
94    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
95        BitwisexorConvolve::<R::Additive, true>::hadamard_transform(&mut f);
96        let len = R::T::try_from(len).unwrap();
97        for f in f.iter_mut() {
98            *f = R::div(f, &len);
99        }
100        f
101    }
102
103    fn multiply(f: &mut Self::F, g: &Self::F) {
104        for (f, g) in f.iter_mut().zip(g) {
105            *f = R::mul(f, g);
106        }
107    }
108
109    fn convolve(a: Self::T, b: Self::T) -> Self::T {
110        assert_eq!(a.len(), b.len());
111        let len = a.len();
112        let same = a == b;
113        let mut a = Self::transform(a, len);
114        if same {
115            for a in a.iter_mut() {
116                *a = R::mul(a, a);
117            }
118        } else {
119            let b = Self::transform(b, len);
120            Self::multiply(&mut a, &b);
121        }
122        Self::inverse_transform(a, len)
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AddMulOperation, rand, tools::Xorshift};

    const A: i64 = 100_000;

    /// Reference O(n²) xor convolution: `h[i ^ j] += f[i] * g[j]` for every pair `(i, j)`.
    ///
    /// Assumes `f.len() == g.len()` and that the length is a power of two, so `i ^ j` stays
    /// in range.
    fn naive_xor_convolve(f: &[i64], g: &[i64]) -> Vec<i64> {
        let mut h = vec![0i64; f.len()];
        for (i, &fi) in f.iter().enumerate() {
            for (j, &gj) in g.iter().enumerate() {
                h[i ^ j] += fi * gj;
            }
        }
        h
    }

    #[test]
    fn test_bitwisexor_convolve() {
        type Conv = BitwisexorConvolve<AddMulOperation<i64>, true>;
        let mut rng = Xorshift::new();

        for k in 0..12 {
            let n: usize = 1 << k;
            rand!(rng, f: [-A..A; n], g: [-A..A; n]);

            let expected = naive_xor_convolve(&f, &g);
            let expected_square = naive_xor_convolve(&f, &f);

            // Distinct inputs: both operands are transformed and multiplied pointwise.
            assert_eq!(expected, Conv::convolve(f.clone(), g));

            // Identical inputs: exercises the `a == b` squaring shortcut in `convolve`.
            assert_eq!(expected_square, Conv::convolve(f.clone(), f));
        }
    }
}