competitive/data_structure/
wavelet_matrix.rs

1use super::{BitVector, RankSelectDictionaries};
2use std::ops::Range;
3
4#[derive(Debug, Clone)]
5pub struct WaveletMatrix {
6    len: usize,
7    table: Vec<(usize, BitVector)>,
8}
9
10impl WaveletMatrix {
11    pub fn new<T>(mut v: Vec<T>, bit_length: usize) -> Self
12    where
13        T: Clone + RankSelectDictionaries,
14    {
15        let len = v.len();
16        let mut table = Vec::new();
17        for d in (0..bit_length).rev() {
18            let b: BitVector = v.iter().map(|x| x.access(d)).collect();
19            table.push((b.rank0(len), b));
20            v = v
21                .iter()
22                .filter(|&x| !x.access(d))
23                .chain(v.iter().filter(|&x| x.access(d)))
24                .cloned()
25                .collect();
26        }
27        Self { len, table }
28    }
29    pub fn new_with_init<T, F>(v: Vec<T>, bit_length: usize, mut f: F) -> Self
30    where
31        T: Clone + RankSelectDictionaries,
32        F: FnMut(usize, usize, T),
33    {
34        let this = Self::new(v.clone(), bit_length);
35        for (mut k, v) in v.into_iter().enumerate() {
36            for (d, &(c, ref b)) in this.table.iter().rev().enumerate().rev() {
37                if v.access(d) {
38                    k = c + b.rank1(k);
39                } else {
40                    k = b.rank0(k);
41                }
42                f(d, k, v.clone());
43            }
44        }
45        this
46    }
47    /// get k-th value
48    pub fn access(&self, mut k: usize) -> usize {
49        let mut val = 0;
50        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
51            if b.access(k) {
52                k = c + b.rank1(k);
53                val |= 1 << d;
54            } else {
55                k = b.rank0(k);
56            }
57        }
58        val
59    }
60    /// the number of val in range
61    pub fn rank(&self, val: usize, mut range: Range<usize>) -> usize {
62        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
63            if val.access(d) {
64                range.start = c + b.rank1(range.start);
65                range.end = c + b.rank1(range.end);
66            } else {
67                range.start = b.rank0(range.start);
68                range.end = b.rank0(range.end);
69            }
70        }
71        range.end - range.start
72    }
73    /// index of k-th val
74    pub fn select(&self, val: usize, k: usize) -> Option<usize> {
75        if self.rank(val, 0..self.len) <= k {
76            return None;
77        }
78        let mut i = 0;
79        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
80            if val.access(d) {
81                i = c + b.rank1(i);
82            } else {
83                i = b.rank0(i);
84            }
85        }
86        i += k;
87        for &(c, ref b) in self.table.iter().rev() {
88            if i >= c {
89                i = b.select1(i - c).unwrap();
90            } else {
91                i = b.select0(i).unwrap();
92            }
93        }
94        Some(i)
95    }
96    /// get k-th smallest value in range
97    pub fn quantile(&self, mut range: Range<usize>, mut k: usize) -> usize {
98        let mut val = 0;
99        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
100            let z = b.rank0(range.end) - b.rank0(range.start);
101            if z <= k {
102                k -= z;
103                val |= 1 << d;
104                range.start = c + b.rank1(range.start);
105                range.end = c + b.rank1(range.end);
106            } else {
107                range.start = b.rank0(range.start);
108                range.end = b.rank0(range.end);
109            }
110        }
111        val
112    }
113    /// get k-th smallest value out of range
114    pub fn quantile_outer(&self, mut range: Range<usize>, mut k: usize) -> usize {
115        let mut val = 0;
116        let mut orange = 0..self.len;
117        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
118            let z = b.rank0(orange.end) - b.rank0(orange.start) + b.rank0(range.start)
119                - b.rank0(range.end);
120            if z <= k {
121                k -= z;
122                val |= 1 << d;
123                range.start = c + b.rank1(range.start);
124                range.end = c + b.rank1(range.end);
125                orange.start = c + b.rank1(orange.start);
126                orange.end = c + b.rank1(orange.end);
127            } else {
128                range.start = b.rank0(range.start);
129                range.end = b.rank0(range.end);
130                orange.start = b.rank0(orange.start);
131                orange.end = b.rank0(orange.end);
132            }
133        }
134        val
135    }
136    /// the number of value less than val in range
137    pub fn rank_lessthan(&self, val: usize, mut range: Range<usize>) -> usize {
138        let mut res = 0;
139        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
140            if val.access(d) {
141                res += b.rank0(range.end) - b.rank0(range.start);
142                range.start = c + b.rank1(range.start);
143                range.end = c + b.rank1(range.end);
144            } else {
145                range.start = b.rank0(range.start);
146                range.end = b.rank0(range.end);
147            }
148        }
149        res
150    }
151    /// the number of valrange in range
152    pub fn rank_range(&self, valrange: Range<usize>, range: Range<usize>) -> usize {
153        self.rank_lessthan(valrange.end, range.clone()) - self.rank_lessthan(valrange.start, range)
154    }
155    pub fn query_less_than<F>(&self, val: usize, mut range: Range<usize>, mut f: F)
156    where
157        F: FnMut(usize, Range<usize>),
158    {
159        for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
160            if val.access(d) {
161                f(d, b.rank0(range.start)..b.rank0(range.end));
162                range.start = c + b.rank1(range.start);
163                range.end = c + b.rank1(range.end);
164            } else {
165                range.start = b.rank0(range.start);
166                range.end = b.rank0(range.end);
167            }
168        }
169    }
170}
171
// Randomised cross-check: every query of the matrix is compared against a
// brute-force computation on the raw vector.
#[test]
fn test_wavelet_matrix() {
    use crate::rand_value;
    use crate::tools::{NotEmptySegment as Nes, Xorshift};
    const N: usize = 1_000;
    const Q: usize = 1_000;
    const A: usize = 1 << 8;
    let mut rng = Xorshift::new();
    crate::rand!(rng, v: [..A; N]);
    let wm = WaveletMatrix::new(v.clone(), 8);

    // access: every element round-trips.
    for (i, &x) in v.iter().enumerate() {
        assert_eq!(wm.access(i), x);
    }

    for ((l, r), a) in rand_value!(rng, [(Nes(N), ..A); Q]) {
        let window = &v[l..r];

        // rank
        let occurrences = window.iter().filter(|&&x| x == a).count();
        assert_eq!(wm.rank(a, l..r), occurrences);

        // select: the k-th occurrence sits at the first prefix containing k + 1 of them.
        let total = wm.rank(a, 0..N);
        if total > 0 {
            let k = rng.random(..total);
            let expected = (0..N)
                .position(|i| wm.rank(a, 0..i + 1) == k + 1)
                .unwrap_or(N);
            assert_eq!(wm.select(a, k).unwrap().min(N), expected);
        }

        // quantile: enumerating all k yields the sorted window.
        let mut sorted_inner = window.to_vec();
        sorted_inner.sort_unstable();
        let inner: Vec<_> = (0..r - l).map(|k| wm.quantile(l..r, k)).collect();
        assert_eq!(inner, sorted_inner);

        // quantile_outer: same for everything outside the window.
        let mut sorted_outer = v.to_vec();
        sorted_outer.drain(l..r);
        sorted_outer.sort_unstable();
        let outer: Vec<_> = (0..N + l - r)
            .map(|k| wm.quantile_outer(l..r, k))
            .collect();
        assert_eq!(outer, sorted_outer);

        // rank_lessthan
        let smaller = window.iter().filter(|&&x| x < a).count();
        assert_eq!(wm.rank_lessthan(a, l..r), smaller);

        // rank_range
        let (p, q) = rng.random(Nes(A - 1));
        let within = window.iter().filter(|&&x| p <= x && x < q).count();
        assert_eq!(wm.rank_range(p..q, l..r), within);
    }
}