competitive/math/fast_fourier_transform.rs

1use super::{AssociatedValue, Complex, ConvolveSteps, One, Zero};
2
/// Uninhabited marker type that implements `ConvolveSteps` for `Vec<i64>` sequences
/// using a floating-point FFT on real input (results are rounded back to `i64`).
pub enum ConvolveRealFft {}
4
5enum RotateCache {}
6impl RotateCache {
7    fn ensure(n: usize) {
8        assert_eq!(n.count_ones(), 1, "call with power of two but {}", n);
9        Self::modify(|cache| {
10            let mut m = cache.len();
11            assert!(
12                m.count_ones() <= 1,
13                "length might be power of two but {}",
14                m
15            );
16            if m >= n {
17                return;
18            }
19            cache.reserve_exact(n - m);
20            if cache.is_empty() {
21                cache.push(Complex::one());
22                m += 1;
23            }
24            while m < n {
25                let p = Complex::primitive_nth_root_of_unity(-((m * 4) as f64));
26                for i in 0..m {
27                    cache.push(cache[i] * p);
28                }
29                m <<= 1;
30            }
31            assert_eq!(cache.len(), n);
32        });
33    }
34}
35crate::impl_assoc_value!(RotateCache, Vec<Complex<f64>>, vec![Complex::one()]);
36
/// Permutes `f` in place so that every index `i` is exchanged with its bit-reversed
/// counterpart (reversal over the `log2(f.len())` index bits).
///
/// Intended for power-of-two lengths; lengths 0, 1 and 2 are left unchanged.
/// Other lengths are not supported.
fn bit_reverse<T>(f: &mut [T]) {
    // Split the index bits into a low part and a high part of `log2(width)` bits each,
    // plus one unpaired middle bit when `log2(len)` is odd.
    // `rev[j]` ends up holding the full-width bit reversal of `j` for every `j < width`.
    let mut rev: Vec<u32> = vec![0];
    let mut rest = f.len();
    let mut width = 1;
    while width * 2 < rest {
        rest /= 2;
        let upper: Vec<u32> = rev.iter().map(|&r| r + rest as u32).collect();
        rev.extend(upper);
        width *= 2;
    }
    // When `rest` did not meet `width` there is a middle bit; every swap is then
    // repeated for the indices with that bit set.
    let has_middle_bit = rest != width;
    for hi in 1..width {
        for lo in 0..hi {
            let p = lo + rev[hi] as usize;
            let q = hi + rev[lo] as usize;
            f.swap(p, q);
            if has_middle_bit {
                f.swap(p + width, q + width);
            }
        }
    }
}
67
/// Convolution of `i64` sequences through a half-length complex FFT.
///
/// A real sequence padded to length `n` (a power of two, at least 4) is packed into
/// `n / 2` complex values: even-indexed samples become real parts and odd-indexed
/// samples become imaginary parts. One `n / 2`-point `fft` is run, followed by a
/// post-processing pass. Outputs are rounded to the nearest integer.
/// NOTE(review): exactness relies on the floating-point error staying below 0.5;
/// no bound on input magnitude or length is enforced here.
impl ConvolveSteps for ConvolveRealFft {
    type T = Vec<i64>;
    type F = Vec<Complex<f64>>;
    /// Number of coefficients in `t`.
    fn length(t: &Self::T) -> usize {
        t.len()
    }
    /// Transforms `t` for a convolution whose result has `len` coefficients.
    ///
    /// The transform size is `n = len.max(4).next_power_of_two()`, and the returned
    /// vector has `n / 2` entries. Slot 0 stores two real values, one in its real part
    /// and one in its imaginary part (see `multiply`).
    ///
    /// Panics (index out of bounds) if `t.len()` exceeds `n`.
    fn transform(t: Self::T, len: usize) -> Self::F {
        let n = len.max(4).next_power_of_two();
        let mut f = vec![Complex::zero(); n / 2];
        // Pack pairs of samples: f[j] = t[2j] + i * t[2j + 1].
        for (i, t) in t.into_iter().enumerate() {
            if i & 1 == 0 {
                f[i / 2].re = t as f64;
            } else {
                f[i / 2].im = t as f64;
            }
        }
        fft(&mut f);
        // `fft` leaves its output in bit-reversed order; put it back in natural order.
        bit_reverse(&mut f);
        // Slot 0 becomes the two real values `re + im` and `re - im`, stored in its real
        // and imaginary parts respectively.
        f[0] = Complex::new(f[0].re + f[0].im, f[0].re - f[0].im);
        // Slot n/4 is its own mirror (n/2 - n/4 == n/4), so the loop below skips it;
        // it only needs conjugation.
        f[n / 4] = f[n / 4].conjugate();
        let w = Complex::primitive_nth_root_of_unity(-(n as f64));
        // `wk` is updated incrementally so that it equals `w` raised to the power `k`.
        let mut wk = Complex::<f64>::one();
        // Untangle each mirrored pair of slots (k, n/2 - k).
        // `transpose` is a project `Complex` method (presumably swaps re/im — TODO confirm).
        for k in 1..n / 4 {
            wk *= w;
            let c = wk.conjugate().transpose() + 1.;
            let d = c * (f[k] - f[n / 2 - k].conjugate()) * 0.5;
            f[k] -= d;
            f[n / 2 - k] += d.conjugate();
        }
        f
    }
    /// Inverse of `transform`: recovers the first `len` integer coefficients.
    ///
    /// It undoes the post-processing using the conjugate root
    /// `primitive_nth_root_of_unity(n)`. It then calls `bit_reverse` followed by `ifft`,
    /// mirroring `transform`'s `fft` followed by `bit_reverse`. Finally it scales by
    /// `1 / (n / 2)` and unpacks real parts to even outputs and imaginary parts to odd
    /// outputs, rounding each value to `i64`.
    ///
    /// Panics if `f.len() != n / 2`.
    fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
        let n = len.max(4).next_power_of_two();
        assert_eq!(f.len(), n / 2);
        // Reverse the slot-0 packing: (a + b) / 2 and (a - b) / 2 recover the original
        // real and imaginary parts from a = re + im, b = re - im.
        f[0] = Complex::new((f[0].re + f[0].im) * 0.5, (f[0].re - f[0].im) * 0.5);
        f[n / 4] = f[n / 4].conjugate();
        let w = Complex::primitive_nth_root_of_unity(n as f64);
        let mut wk = Complex::<f64>::one();
        for k in 1..n / 4 {
            wk *= w;
            let c = wk.transpose().conjugate() + 1.;
            let d = c * (f[k] - f[n / 2 - k].conjugate()) * 0.5;
            f[k] -= d;
            f[n / 2 - k] += d.conjugate();
        }
        bit_reverse(&mut f);
        ifft(&mut f);
        // `ifft` does not normalise; divide by its length `n / 2` here.
        let inv = 1. / (n / 2) as f64;
        (0..len)
            .map(|i| (inv * if i & 1 == 0 { f[i / 2].re } else { f[i / 2].im }).round() as i64)
            .collect()
    }
    /// Pointwise product of two transforms, written into `f`.
    ///
    /// Slot 0 holds two independent real values (see `transform`), so its real and
    /// imaginary parts are multiplied separately instead of as one complex number.
    ///
    /// Panics if the lengths differ.
    fn multiply(f: &mut Self::F, g: &Self::F) {
        assert_eq!(f.len(), g.len());
        f[0].re *= g[0].re;
        f[0].im *= g[0].im;
        for (f, g) in f.iter_mut().zip(g.iter()).skip(1) {
            *f *= *g;
        }
    }
}
129
130pub fn fft(a: &mut [Complex<f64>]) {
131    let n = a.len();
132    RotateCache::ensure(n / 2);
133    RotateCache::with(|cache| {
134        let mut v = n / 2;
135        while v > 0 {
136            for (a, wj) in a.chunks_exact_mut(v << 1).zip(cache) {
137                let (l, r) = a.split_at_mut(v);
138                for (x, y) in l.iter_mut().zip(r) {
139                    let ajv = wj * *y;
140                    *y = *x - ajv;
141                    *x += ajv;
142                }
143            }
144            v >>= 1;
145        }
146    });
147}
148
149pub fn ifft(a: &mut [Complex<f64>]) {
150    let n = a.len();
151    RotateCache::ensure(n / 2);
152    RotateCache::with(|cache| {
153        let mut v = 1;
154        while v < n {
155            for (a, wj) in a
156                .chunks_exact_mut(v << 1)
157                .zip(cache.iter().map(|wj| wj.conjugate()))
158            {
159                let (l, r) = a.split_at_mut(v);
160                for (x, y) in l.iter_mut().zip(r) {
161                    let ajv = *x - *y;
162                    *x += *y;
163                    *y = wj * ajv;
164                }
165            }
166            v <<= 1;
167        }
168    });
169}
170
/// Cross-checks `ConvolveRealFft::convolve` against a schoolbook convolution. Input
/// lengths are powers of two up to 512, optionally shrunk by a random amount, and the
/// entries are drawn uniformly from `[-100_000, 100_000]`.
#[test]
fn test_convolve_fft() {
    use crate::{rand, tools::Xorshift};
    const A: i64 = 100_000;
    let mut rng = Xorshift::default();
    for exp_a in 0u32..10 {
        for exp_b in 0u32..10 {
            for shrink_a in 0..2usize {
                for shrink_b in 0..2usize {
                    let full_a = 1usize << exp_a;
                    let full_b = 1usize << exp_b;
                    // A shrink flag of 1 removes a random amount, keeping at least one element.
                    let n = full_a - rng.random(0..full_a) * shrink_a;
                    let m = full_b - rng.random(0..full_b) * shrink_b;
                    rand!(rng, a: [-A..=A; n], b: [-A..=A; m]);
                    // O(n * m) reference convolution.
                    let mut expected = vec![0i64; n + m - 1];
                    for (i, &x) in a.iter().enumerate() {
                        for (j, &y) in b.iter().enumerate() {
                            expected[i + j] += x * y;
                        }
                    }
                    let actual = ConvolveRealFft::convolve(a, b);
                    assert_eq!(expected, actual);
                }
            }
        }
    }
}