Skip to main content

competitive/data_structure/
wavelet_matrix.rs

1use super::{AbelianGroup, BitVector, Compressor, RankSelectDictionaries, VecCompress};
2use std::{
3    mem::{self, MaybeUninit},
4    ops::Range,
5};
6
7#[derive(Debug, Clone)]
8pub struct WaveletMatrix<T> {
9    len: usize,
10    bit_length: usize,
11    zeros: Vec<usize>,
12    ones_prefix: Vec<usize>,
13    bit_vector: BitVector,
14    compress: VecCompress<T>,
15}
16
17impl<T> WaveletMatrix<T>
18where
19    T: Ord + Clone,
20{
21    pub fn new(v: Vec<T>) -> Self {
22        let len = v.len();
23        let compress: VecCompress<T> = v.iter().cloned().collect();
24        let bit_length = usize::BITS as usize - compress.size().leading_zeros() as usize;
25        let mut indices: Vec<usize> = v
26            .iter()
27            .map(|value| compress.index_exact(value).unwrap())
28            .collect();
29        let mut bit_vector = BitVector::with_capacity(len * bit_length);
30        let mut zeros = Vec::with_capacity(bit_length);
31        for d in (0..bit_length).rev() {
32            let mut zero_count = 0;
33            for &idx in &indices {
34                let bit = ((idx >> d) & 1) != 0;
35                bit_vector.push(bit);
36                if !bit {
37                    zero_count += 1;
38                }
39            }
40            zeros.push(zero_count);
41            let mut next = Vec::with_capacity(len);
42            next.extend(
43                indices
44                    .iter()
45                    .filter(|&&idx| ((idx >> d) & 1) == 0)
46                    .copied(),
47            );
48            next.extend(
49                indices
50                    .iter()
51                    .filter(|&&idx| ((idx >> d) & 1) == 1)
52                    .copied(),
53            );
54            indices = next;
55        }
56        let mut ones_prefix = Vec::with_capacity(bit_length);
57        let mut prefix = 0;
58        for &zero in &zeros {
59            ones_prefix.push(prefix);
60            prefix += len - zero;
61        }
62        Self {
63            len,
64            bit_length,
65            zeros,
66            ones_prefix,
67            bit_vector,
68            compress,
69        }
70    }
71
72    pub fn new_with_init<F>(v: Vec<T>, mut f: F) -> Self
73    where
74        F: FnMut(usize, usize, T),
75    {
76        let this = Self::new(v.clone());
77        let indices: Vec<usize> = v
78            .iter()
79            .map(|value| this.compress.index_exact(value).unwrap())
80            .collect();
81        for (mut k, value) in v.into_iter().enumerate() {
82            let idx = indices[k];
83            for d in (0..this.bit_length).rev() {
84                let level = this.level(d);
85                if ((idx >> d) & 1) != 0 {
86                    k = this.zeros[level] + this.rank1(level, k);
87                } else {
88                    k = this.rank0(level, k);
89                }
90                f(d, k, value.clone());
91            }
92        }
93        this
94    }
95
96    fn level(&self, d: usize) -> usize {
97        self.bit_length - 1 - d
98    }
99
100    fn rank1(&self, level: usize, k: usize) -> usize {
101        let offset = level * self.len;
102        self.bit_vector.rank1(offset + k) - self.ones_prefix[level]
103    }
104
105    fn rank0(&self, level: usize, k: usize) -> usize {
106        k - self.rank1(level, k)
107    }
108
109    fn rank_by_index(&self, idx: usize, mut range: Range<usize>) -> usize {
110        for d in (0..self.bit_length).rev() {
111            let level = self.level(d);
112            if ((idx >> d) & 1) != 0 {
113                range.start = self.zeros[level] + self.rank1(level, range.start);
114                range.end = self.zeros[level] + self.rank1(level, range.end);
115            } else {
116                range.start = self.rank0(level, range.start);
117                range.end = self.rank0(level, range.end);
118            }
119        }
120        range.end - range.start
121    }
122
123    /// get k-th value
124    pub fn access(&self, mut k: usize) -> T {
125        let mut idx = 0;
126        for d in (0..self.bit_length).rev() {
127            let level = self.level(d);
128            if self.bit_vector.access(level * self.len + k) {
129                idx |= 1 << d;
130                k = self.zeros[level] + self.rank1(level, k);
131            } else {
132                k = self.rank0(level, k);
133            }
134        }
135        self.compress.values()[idx].clone()
136    }
137
138    /// the number of val in range
139    pub fn rank(&self, val: T, range: Range<usize>) -> usize {
140        match self.compress.index_exact(&val) {
141            Some(idx) => self.rank_by_index(idx, range),
142            None => 0,
143        }
144    }
145
146    /// index of k-th val
147    pub fn select(&self, val: T, k: usize) -> Option<usize> {
148        let idx = self.compress.index_exact(&val)?;
149        if self.rank_by_index(idx, 0..self.len) <= k {
150            return None;
151        }
152        let mut i = 0;
153        for d in (0..self.bit_length).rev() {
154            let level = self.level(d);
155            if ((idx >> d) & 1) != 0 {
156                i = self.zeros[level] + self.rank1(level, i);
157            } else {
158                i = self.rank0(level, i);
159            }
160        }
161        i += k;
162        for level in (0..self.bit_length).rev() {
163            let offset = level * self.len;
164            if i >= self.zeros[level] {
165                let global_k = self.ones_prefix[level] + (i - self.zeros[level]);
166                let pos = self.bit_vector.select1(global_k).unwrap();
167                i = pos - offset;
168            } else {
169                let zeros_before = offset - self.ones_prefix[level];
170                let global_k = zeros_before + i;
171                let pos = self.bit_vector.select0(global_k).unwrap();
172                i = pos - offset;
173            }
174        }
175        Some(i)
176    }
177
178    /// get k-th smallest value in range
179    pub fn quantile(&self, mut range: Range<usize>, mut k: usize) -> T {
180        let mut idx = 0;
181        for d in (0..self.bit_length).rev() {
182            let level = self.level(d);
183            let z = self.rank0(level, range.end) - self.rank0(level, range.start);
184            if z <= k {
185                k -= z;
186                idx |= 1 << d;
187                range.start = self.zeros[level] + self.rank1(level, range.start);
188                range.end = self.zeros[level] + self.rank1(level, range.end);
189            } else {
190                range.start = self.rank0(level, range.start);
191                range.end = self.rank0(level, range.end);
192            }
193        }
194        self.compress.values()[idx].clone()
195    }
196
197    /// get k-th smallest value out of range
198    pub fn quantile_outer(&self, mut range: Range<usize>, mut k: usize) -> T {
199        let mut idx = 0;
200        let mut orange = 0..self.len;
201        for d in (0..self.bit_length).rev() {
202            let level = self.level(d);
203            let z = self.rank0(level, orange.end) - self.rank0(level, orange.start)
204                + self.rank0(level, range.start)
205                - self.rank0(level, range.end);
206            if z <= k {
207                k -= z;
208                idx |= 1 << d;
209                range.start = self.zeros[level] + self.rank1(level, range.start);
210                range.end = self.zeros[level] + self.rank1(level, range.end);
211                orange.start = self.zeros[level] + self.rank1(level, orange.start);
212                orange.end = self.zeros[level] + self.rank1(level, orange.end);
213            } else {
214                range.start = self.rank0(level, range.start);
215                range.end = self.rank0(level, range.end);
216                orange.start = self.rank0(level, orange.start);
217                orange.end = self.rank0(level, orange.end);
218            }
219        }
220        self.compress.values()[idx].clone()
221    }
222
223    /// the number of value less than val in range
224    pub fn rank_lessthan(&self, val: T, mut range: Range<usize>) -> usize {
225        let idx = self.compress.index_lower_bound(&val);
226        let mut res = 0;
227        for d in (0..self.bit_length).rev() {
228            let level = self.level(d);
229            if ((idx >> d) & 1) != 0 {
230                res += self.rank0(level, range.end) - self.rank0(level, range.start);
231                range.start = self.zeros[level] + self.rank1(level, range.start);
232                range.end = self.zeros[level] + self.rank1(level, range.end);
233            } else {
234                range.start = self.rank0(level, range.start);
235                range.end = self.rank0(level, range.end);
236            }
237        }
238        res
239    }
240
241    /// the number of valrange in range
242    pub fn rank_range(&self, valrange: Range<T>, range: Range<usize>) -> usize {
243        self.rank_lessthan(valrange.end, range.clone()) - self.rank_lessthan(valrange.start, range)
244    }
245
246    pub fn query_less_than<F>(&self, val: T, mut range: Range<usize>, mut f: F)
247    where
248        F: FnMut(usize, Range<usize>),
249    {
250        let idx = self.compress.index_lower_bound(&val);
251        for d in (0..self.bit_length).rev() {
252            let level = self.level(d);
253            if ((idx >> d) & 1) != 0 {
254                f(
255                    d,
256                    self.rank0(level, range.start)..self.rank0(level, range.end),
257                );
258                range.start = self.zeros[level] + self.rank1(level, range.start);
259                range.end = self.zeros[level] + self.rank1(level, range.end);
260            } else {
261                range.start = self.rank0(level, range.start);
262                range.end = self.rank0(level, range.end);
263            }
264        }
265    }
266
267    pub fn build_fold<M>(&self, weights: &[M::T]) -> WaveletMatrixFold<'_, T, M>
268    where
269        M: AbelianGroup,
270    {
271        let len = self.len;
272        assert_eq!(weights.len(), len);
273        let mut prefix = Vec::with_capacity((self.bit_length + 1) * (len + 1));
274        let mut current: Vec<M::T> = weights.to_vec();
275        for level in 0..self.bit_length {
276            let offset = level * len;
277            let zeros = self.zeros[level];
278            let mut next: Vec<MaybeUninit<M::T>> = Vec::with_capacity(len);
279            next.resize_with(len, MaybeUninit::uninit);
280            let mut zero_pos = 0;
281            let mut one_pos = zeros;
282            let mut acc = M::unit();
283            prefix.push(acc.clone());
284            for (i, w) in current.into_iter().enumerate() {
285                acc = M::operate(&acc, &w);
286                prefix.push(acc.clone());
287                if self.bit_vector.access(offset + i) {
288                    next[one_pos].write(w);
289                    one_pos += 1;
290                } else {
291                    next[zero_pos].write(w);
292                    zero_pos += 1;
293                }
294            }
295            debug_assert_eq!(zero_pos, zeros);
296            debug_assert_eq!(one_pos, len);
297            let next = unsafe {
298                let mut next = mem::ManuallyDrop::new(next);
299                let ptr = next.as_mut_ptr() as *mut M::T;
300                let len = next.len();
301                let cap = next.capacity();
302                Vec::from_raw_parts(ptr, len, cap)
303            };
304            current = next;
305        }
306        let mut acc = M::unit();
307        prefix.push(acc.clone());
308        for w in current.into_iter() {
309            acc = M::operate(&acc, &w);
310            prefix.push(acc.clone());
311        }
312        WaveletMatrixFold {
313            wavelet_matrix: self,
314            prefix,
315        }
316    }
317}
318
319#[derive(Debug, Clone)]
320pub struct WaveletMatrixFold<'a, T, M>
321where
322    T: Ord + Clone,
323    M: AbelianGroup,
324{
325    wavelet_matrix: &'a WaveletMatrix<T>,
326    prefix: Vec<M::T>,
327}
328
329impl<'a, T, M> WaveletMatrixFold<'a, T, M>
330where
331    T: Ord + Clone,
332    M: AbelianGroup,
333{
334    #[inline]
335    fn range_sum(&self, level: usize, range: Range<usize>) -> M::T {
336        let offset = level * (self.wavelet_matrix.len + 1);
337        unsafe {
338            M::rinv_operate(
339                self.prefix.get_unchecked(offset + range.end),
340                self.prefix.get_unchecked(offset + range.start),
341            )
342        }
343    }
344
345    pub fn fold_lessthan(&self, val: T, range: Range<usize>) -> M::T {
346        self.fold_lessthan_with_count(val, range).1
347    }
348
349    pub fn fold_lessthan_with_count(&self, val: T, mut range: Range<usize>) -> (usize, M::T) {
350        debug_assert!(range.end <= self.wavelet_matrix.len);
351        let idx = self.wavelet_matrix.compress.index_lower_bound(&val);
352        let mut count = 0;
353        let mut sum = M::unit();
354        for d in (0..self.wavelet_matrix.bit_length).rev() {
355            let level = self.wavelet_matrix.level(d);
356            let start0 = self.wavelet_matrix.rank0(level, range.start);
357            let end0 = self.wavelet_matrix.rank0(level, range.end);
358            if ((idx >> d) & 1) != 0 {
359                count += end0 - start0;
360                sum = M::operate(&sum, &self.range_sum(level + 1, start0..end0));
361                range.start = self.wavelet_matrix.zeros[level] + (range.start - start0);
362                range.end = self.wavelet_matrix.zeros[level] + (range.end - end0);
363            } else {
364                range.start = start0;
365                range.end = end0;
366            }
367        }
368        (count, sum)
369    }
370
371    pub fn fold_range(&self, valrange: Range<T>, range: Range<usize>) -> M::T {
372        M::rinv_operate(
373            &self.fold_lessthan(valrange.end, range.clone()),
374            &self.fold_lessthan(valrange.start, range),
375        )
376    }
377
378    pub fn fold_range_with_count(&self, valrange: Range<T>, range: Range<usize>) -> (usize, M::T) {
379        let (count_upper, sum_upper) = self.fold_lessthan_with_count(valrange.end, range.clone());
380        let (count_lower, sum_lower) = self.fold_lessthan_with_count(valrange.start, range);
381        (
382            count_upper - count_lower,
383            M::rinv_operate(&sum_upper, &sum_lower),
384        )
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AdditiveOperation,
        rand_value,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    /// Checks every query against a brute-force computation on random data.
    #[test]
    fn test_wavelet_matrix() {
        const N: usize = 1_000;
        const Q: usize = 1_000;
        const A: usize = 1 << 8;
        const B: i64 = 1_000_000_000;
        let mut rng = Xorshift::default();
        crate::rand!(rng, v: [..A; N]);
        crate::rand!(rng, w: [-B..B; N]);
        let wm = WaveletMatrix::new(v.clone());
        let fold = wm.build_fold::<AdditiveOperation<i64>>(&w);

        for (i, expected) in v.iter().cloned().enumerate() {
            assert_eq!(wm.access(i), expected);
        }
        // `A` is above every value, so this folds all weights.
        assert_eq!(fold.fold_lessthan(A, 0..N), w.iter().sum::<i64>());

        for ((l, r), a) in rand_value!(rng, [(Nes(N), ..A); Q]) {
            let window = &v[l..r];
            let window_weights = &w[l..r];

            // rank
            assert_eq!(
                wm.rank(a, l..r),
                window.iter().filter(|&&x| x == a).count()
            );

            // select, only when `a` occurs at all
            let occurrences = wm.rank(a, 0..N);
            if occurrences > 0 {
                let k = rng.random(..occurrences);
                let expected = (0..N)
                    .position(|i| wm.rank(a, 0..i + 1) == k + 1)
                    .unwrap_or(N);
                assert_eq!(wm.select(a, k).unwrap().min(N), expected);
            }

            // quantile over the window
            let mut sorted_inside: Vec<_> = window.to_vec();
            sorted_inside.sort_unstable();
            let quantiles: Vec<_> = (0..r - l).map(|k| wm.quantile(l..r, k)).collect();
            assert_eq!(quantiles, sorted_inside);

            // quantile over everything outside the window
            let mut sorted_outside: Vec<_> = v.to_vec();
            sorted_outside.drain(l..r);
            sorted_outside.sort_unstable();
            let outer_quantiles: Vec<_> = (0..N + l - r)
                .map(|k| wm.quantile_outer(l..r, k))
                .collect();
            assert_eq!(outer_quantiles, sorted_outside);

            // rank_lessthan and its weighted counterparts
            assert_eq!(
                wm.rank_lessthan(a, l..r),
                window.iter().filter(|&&x| x < a).count()
            );
            let (count_lt, sum_lt) = window
                .iter()
                .zip(window_weights.iter())
                .filter(|&(&value, _)| value < a)
                .fold((0usize, 0i64), |(count, sum), (_, &weight)| {
                    (count + 1, sum + weight)
                });
            assert_eq!(fold.fold_lessthan_with_count(a, l..r), (count_lt, sum_lt));
            assert_eq!(
                fold.fold_lessthan(A, l..r),
                window_weights.iter().sum::<i64>()
            );

            // value-range queries
            let (p, q) = rng.random(Nes(A - 1));
            assert_eq!(
                wm.rank_range(p..q, l..r),
                window.iter().filter(|&&x| p <= x && x < q).count()
            );
            let (count_range, sum_range) = window
                .iter()
                .zip(window_weights.iter())
                .filter(|&(&value, _)| p <= value && value < q)
                .fold((0usize, 0i64), |(count, sum), (_, &weight)| {
                    (count + 1, sum + weight)
                });
            assert_eq!(fold.fold_range(p..q, l..r), sum_range);
            assert_eq!(
                fold.fold_range_with_count(p..q, l..r),
                (count_range, sum_range)
            );
        }
    }
}