// competitive/data_structure/wavelet_matrix.rs

1use super::{BitVector, RankSelectDictionaries};
2use std::ops::Range;
3
4#[derive(Debug, Clone)]
5pub struct WaveletMatrix {
6    len: usize,
7    table: Vec<(usize, BitVector)>,
8}
9
10impl WaveletMatrix {
11    pub fn new<T>(mut v: Vec<T>, bit_length: usize) -> Self
12    where
13        T: Clone + RankSelectDictionaries,
14    {
15        let len = v.len();
16        let mut table = Vec::new();
17        for d in (0..bit_length).rev() {
18            let b: BitVector = v.iter().map(|x| x.access(d)).collect();
19            table.push((b.rank0(len), b));
20            v = v
21                .iter()
22                .filter(|&x| !x.access(d))
23                .chain(v.iter().filter(|&x| x.access(d)))
24                .cloned()
25                .collect();
26        }
27        Self { len, table }
28    }
29    pub fn new_with_init<T, F>(v: Vec<T>, bit_length: usize, mut f: F) -> Self
30    where
31        T: Clone + RankSelectDictionaries,
32        F: FnMut(usize, usize, T),
33    {
34        let this = Self::new(v.clone(), bit_length);
35        for (mut k, v) in v.into_iter().enumerate() {
36            for (d, &(c, ref b)) in this.table.iter().rev().enumerate().rev() {
37                if v.access(d) {
38                    k = c + b.rank1(k);
39                } else {
40                    k = b.rank0(k);
41                }
42                f(d, k, v.clone());
43            }
44        }
45        this
46    }
47    /// get k-th value
48    pub fn access(&self, mut k: usize) -> usize {
49        let mut val = 0;
50        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
51            if b.access(k) {
52                k = c + b.rank1(k);
53                val |= 1 << d;
54            } else {
55                k = b.rank0(k);
56            }
57        }
58        val
59    }
60    /// the number of val in range
61    pub fn rank(&self, val: usize, mut range: Range<usize>) -> usize {
62        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
63            if val.access(d) {
64                range.start = c + b.rank1(range.start);
65                range.end = c + b.rank1(range.end);
66            } else {
67                range.start = b.rank0(range.start);
68                range.end = b.rank0(range.end);
69            }
70        }
71        range.end - range.start
72    }
73    /// index of k-th val
74    pub fn select(&self, val: usize, k: usize) -> Option<usize> {
75        if self.rank(val, 0..self.len) <= k {
76            return None;
77        }
78        let mut i = 0;
79        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
80            if val.access(d) {
81                i = c + b.rank1(i);
82            } else {
83                i = b.rank0(i);
84            }
85        }
86        i += k;
87        for &(c, ref b) in self.table.iter().rev() {
88            if i >= c {
89                i = b.select1(i - c).unwrap();
90            } else {
91                i = b.select0(i).unwrap();
92            }
93        }
94        Some(i)
95    }
96    /// get k-th smallest value in range
97    pub fn quantile(&self, mut range: Range<usize>, mut k: usize) -> usize {
98        let mut val = 0;
99        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
100            let z = b.rank0(range.end) - b.rank0(range.start);
101            if z <= k {
102                k -= z;
103                val |= 1 << d;
104                range.start = c + b.rank1(range.start);
105                range.end = c + b.rank1(range.end);
106            } else {
107                range.start = b.rank0(range.start);
108                range.end = b.rank0(range.end);
109            }
110        }
111        val
112    }
113    /// get k-th smallest value out of range
114    pub fn quantile_outer(&self, mut range: Range<usize>, mut k: usize) -> usize {
115        let mut val = 0;
116        let mut orange = 0..self.len;
117        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
118            let z = b.rank0(orange.end) - b.rank0(orange.start) + b.rank0(range.start)
119                - b.rank0(range.end);
120            if z <= k {
121                k -= z;
122                val |= 1 << d;
123                range.start = c + b.rank1(range.start);
124                range.end = c + b.rank1(range.end);
125                orange.start = c + b.rank1(orange.start);
126                orange.end = c + b.rank1(orange.end);
127            } else {
128                range.start = b.rank0(range.start);
129                range.end = b.rank0(range.end);
130                orange.start = b.rank0(orange.start);
131                orange.end = b.rank0(orange.end);
132            }
133        }
134        val
135    }
136    /// the number of value less than val in range
137    pub fn rank_lessthan(&self, val: usize, mut range: Range<usize>) -> usize {
138        let mut res = 0;
139        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
140            if val.access(d) {
141                res += b.rank0(range.end) - b.rank0(range.start);
142                range.start = c + b.rank1(range.start);
143                range.end = c + b.rank1(range.end);
144            } else {
145                range.start = b.rank0(range.start);
146                range.end = b.rank0(range.end);
147            }
148        }
149        res
150    }
151    /// the number of valrange in range
152    pub fn rank_range(&self, valrange: Range<usize>, range: Range<usize>) -> usize {
153        self.rank_lessthan(valrange.end, range.clone()) - self.rank_lessthan(valrange.start, range)
154    }
155    pub fn query_less_than<F>(&self, val: usize, mut range: Range<usize>, mut f: F)
156    where
157        F: FnMut(usize, Range<usize>),
158    {
159        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
160            if val.access(d) {
161                f(d, b.rank0(range.start)..b.rank0(range.end));
162                range.start = c + b.rank1(range.start);
163                range.end = c + b.rank1(range.end);
164            } else {
165                range.start = b.rank0(range.start);
166                range.end = b.rank0(range.end);
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        rand_value,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    /// Cross-checks every wavelet-matrix query against a brute-force scan of the
    /// same randomly generated sequence.
    #[test]
    fn test_wavelet_matrix() {
        const N: usize = 1_000;
        const Q: usize = 1_000;
        const A: usize = 1 << 8;
        let mut rng = Xorshift::default();
        crate::rand!(rng, v: [..A; N]);
        let wm = WaveletMatrix::new(v.clone(), 8);

        // Point access must reproduce the input sequence.
        for (i, &expected) in v.iter().enumerate() {
            assert_eq!(wm.access(i), expected);
        }

        for ((l, r), a) in rand_value!(rng, [(Nes(N), ..A); Q]) {
            let window = &v[l..r];

            // rank: occurrences of `a` inside the window.
            let occurrences = window.iter().filter(|&&x| x == a).count();
            assert_eq!(wm.rank(a, l..r), occurrences);

            // select: the k-th occurrence is the first prefix holding k + 1 copies.
            let total = wm.rank(a, 0..N);
            if total > 0 {
                let k = rng.random(..total);
                let found = wm.select(a, k).unwrap().min(N);
                let expected = (0..N)
                    .position(|i| wm.rank(a, 0..i + 1) == k + 1)
                    .unwrap_or(N);
                assert_eq!(found, expected);
            }

            // quantile: enumerating all k yields the sorted window.
            let mut sorted_window = window.to_vec();
            sorted_window.sort_unstable();
            let inner: Vec<_> = (0..r - l).map(|k| wm.quantile(l..r, k)).collect();
            assert_eq!(inner, sorted_window);

            // quantile_outer: same, for everything outside the window.
            let mut sorted_rest: Vec<_> = v.to_vec();
            sorted_rest.drain(l..r);
            sorted_rest.sort_unstable();
            let outer: Vec<_> = (0..N + l - r)
                .map(|k| wm.quantile_outer(l..r, k))
                .collect();
            assert_eq!(outer, sorted_rest);

            // rank_lessthan: elements strictly below `a`.
            let below = window.iter().filter(|&&x| x < a).count();
            assert_eq!(wm.rank_lessthan(a, l..r), below);

            // rank_range: elements inside the half-open value range p..q.
            let (p, q) = rng.random(Nes(A - 1));
            let within = window.iter().filter(|&&x| p <= x && x < q).count();
            assert_eq!(wm.rank_range(p..q, l..r), within);
        }
    }
}