competitive/math/
fast_fourier_transform.rs1use super::{AssociatedValue, Complex, ConvolveSteps, One, Zero};
/// Convolution of `i64` sequences through a floating-point FFT.
///
/// A real sequence of length `n` is transformed with a complex FFT of length
/// `n / 2` by packing even/odd coefficients into the real/imaginary parts.
/// Results are rounded back to integers, so exactness depends on the inputs
/// being small enough for `f64` precision.
pub enum ConvolveRealFft {}

5enum RotateCache {}
6impl RotateCache {
7 fn ensure(n: usize) {
8 assert_eq!(n.count_ones(), 1, "call with power of two but {}", n);
9 Self::modify(|cache| {
10 let mut m = cache.len();
11 assert!(
12 m.count_ones() <= 1,
13 "length might be power of two but {}",
14 m
15 );
16 if m >= n {
17 return;
18 }
19 cache.reserve_exact(n - m);
20 if cache.is_empty() {
21 cache.push(Complex::one());
22 m += 1;
23 }
24 while m < n {
25 let p = Complex::primitive_nth_root_of_unity(-((m * 4) as f64));
26 for i in 0..m {
27 cache.push(cache[i] * p);
28 }
29 m <<= 1;
30 }
31 assert_eq!(cache.len(), n);
32 });
33 }
34}
35crate::impl_assoc_value!(RotateCache, Vec<Complex<f64>>, vec![Complex::one()]);
36
/// Permutes `f` in place into bit-reversed index order.
///
/// For `f.len() == 2^p`, the element at index `i` ends up at the index whose
/// low `p` bits are those of `i` reversed. Lengths 0 and 1 are left untouched.
/// Lengths that are not powers of two are not rejected; they yield an
/// unspecified permutation (or an out-of-bounds panic).
fn bit_reverse<T>(f: &mut [T]) {
    // Build `rev` by doubling: for every `j < half`, `rev[j]` is the
    // bit-reversal of `j` over all `p` index bits.
    let mut rev: Vec<u32> = vec![0];
    let mut span = f.len();
    let mut half = 1usize;
    while half * 2 < span {
        span /= 2;
        let offset = span as u32;
        for j in 0..half {
            let entry = rev[j] + offset;
            rev.push(entry);
        }
        half *= 2;
    }
    // When `p` is odd, the loop ends with `span == 2 * half`. The middle bit is
    // then fixed by the reversal, so each swap is replayed with that bit set
    // (an offset of `half`).
    let odd_bits = half != span;
    for i in 1..half {
        for j in 0..i {
            let a = j + rev[i] as usize;
            let b = i + rev[j] as usize;
            f.swap(a, b);
            if odd_bits {
                f.swap(a + half, b + half);
            }
        }
    }
}

68impl ConvolveSteps for ConvolveRealFft {
69 type T = Vec<i64>;
70 type F = Vec<Complex<f64>>;
71 fn length(t: &Self::T) -> usize {
72 t.len()
73 }
74 fn transform(t: Self::T, len: usize) -> Self::F {
75 let n = len.max(4).next_power_of_two();
76 let mut f = vec![Complex::zero(); n / 2];
77 for (i, t) in t.into_iter().enumerate() {
78 if i & 1 == 0 {
79 f[i / 2].re = t as f64;
80 } else {
81 f[i / 2].im = t as f64;
82 }
83 }
84 fft(&mut f);
85 bit_reverse(&mut f);
86 f[0] = Complex::new(f[0].re + f[0].im, f[0].re - f[0].im);
87 f[n / 4] = f[n / 4].conjugate();
88 let w = Complex::primitive_nth_root_of_unity(-(n as f64));
89 let mut wk = Complex::<f64>::one();
90 for k in 1..n / 4 {
91 wk *= w;
92 let c = wk.conjugate().transpose() + 1.;
93 let d = c * (f[k] - f[n / 2 - k].conjugate()) * 0.5;
94 f[k] -= d;
95 f[n / 2 - k] += d.conjugate();
96 }
97 f
98 }
99 fn inverse_transform(mut f: Self::F, len: usize) -> Self::T {
100 let n = len.max(4).next_power_of_two();
101 assert_eq!(f.len(), n / 2);
102 f[0] = Complex::new((f[0].re + f[0].im) * 0.5, (f[0].re - f[0].im) * 0.5);
103 f[n / 4] = f[n / 4].conjugate();
104 let w = Complex::primitive_nth_root_of_unity(n as f64);
105 let mut wk = Complex::<f64>::one();
106 for k in 1..n / 4 {
107 wk *= w;
108 let c = wk.transpose().conjugate() + 1.;
109 let d = c * (f[k] - f[n / 2 - k].conjugate()) * 0.5;
110 f[k] -= d;
111 f[n / 2 - k] += d.conjugate();
112 }
113 bit_reverse(&mut f);
114 ifft(&mut f);
115 let inv = 1. / (n / 2) as f64;
116 (0..len)
117 .map(|i| (inv * if i & 1 == 0 { f[i / 2].re } else { f[i / 2].im }).round() as i64)
118 .collect()
119 }
120 fn multiply(f: &mut Self::F, g: &Self::F) {
121 assert_eq!(f.len(), g.len());
122 f[0].re *= g[0].re;
123 f[0].im *= g[0].im;
124 for (f, g) in f.iter_mut().zip(g.iter()).skip(1) {
125 *f *= *g;
126 }
127 }
128}
129
130pub fn fft(a: &mut [Complex<f64>]) {
131 let n = a.len();
132 RotateCache::ensure(n / 2);
133 RotateCache::with(|cache| {
134 let mut v = n / 2;
135 while v > 0 {
136 for (a, wj) in a.chunks_exact_mut(v << 1).zip(cache) {
137 let (l, r) = a.split_at_mut(v);
138 for (x, y) in l.iter_mut().zip(r) {
139 let ajv = wj * *y;
140 *y = *x - ajv;
141 *x += ajv;
142 }
143 }
144 v >>= 1;
145 }
146 });
147}
148
149pub fn ifft(a: &mut [Complex<f64>]) {
150 let n = a.len();
151 RotateCache::ensure(n / 2);
152 RotateCache::with(|cache| {
153 let mut v = 1;
154 while v < n {
155 for (a, wj) in a
156 .chunks_exact_mut(v << 1)
157 .zip(cache.iter().map(|wj| wj.conjugate()))
158 {
159 let (l, r) = a.split_at_mut(v);
160 for (x, y) in l.iter_mut().zip(r) {
161 let ajv = *x - *y;
162 *x += *y;
163 *y = wj * ajv;
164 }
165 }
166 v <<= 1;
167 }
168 });
169}
170
/// Cross-checks `ConvolveRealFft::convolve` against schoolbook convolution.
///
/// Lengths are powers of two from 1 to 512 on each side, optionally shrunk by
/// a random amount. Coefficients are drawn from `[-100_000, 100_000]`.
#[test]
fn test_convolve_fft() {
    use crate::{rand, tools::Xorshift};
    const A: i64 = 100_000;
    let mut rng = Xorshift::default();
    for log_n in 0..10 {
        for log_m in 0..10 {
            for shrink_n in 0..2 {
                for shrink_m in 0..2 {
                    let full_n = 2usize.pow(log_n);
                    let full_m = 2usize.pow(log_m);
                    // Both draws happen even when the shrink flag is 0, which
                    // keeps the RNG stream identical across cases.
                    let n = full_n - rng.random(0..full_n) * shrink_n;
                    let m = full_m - rng.random(0..full_m) * shrink_m;
                    rand!(rng, a: [-A..=A; n], b: [-A..=A; m]);
                    let mut expected = vec![0; n + m - 1];
                    for (i, &x) in a.iter().enumerate() {
                        for (j, &y) in b.iter().enumerate() {
                            expected[i + j] += x * y;
                        }
                    }
                    let actual = ConvolveRealFft::convolve(a, b);
                    assert_eq!(expected, actual);
                }
            }
        }
    }
}