competitive/data_structure/
wavelet_matrix.rs1use super::{BitVector, RankSelectDictionaries};
2use std::ops::Range;
3
4#[derive(Debug, Clone)]
5pub struct WaveletMatrix {
6 len: usize,
7 table: Vec<(usize, BitVector)>,
8}
9
10impl WaveletMatrix {
11 pub fn new<T>(mut v: Vec<T>, bit_length: usize) -> Self
12 where
13 T: Clone + RankSelectDictionaries,
14 {
15 let len = v.len();
16 let mut table = Vec::new();
17 for d in (0..bit_length).rev() {
18 let b: BitVector = v.iter().map(|x| x.access(d)).collect();
19 table.push((b.rank0(len), b));
20 v = v
21 .iter()
22 .filter(|&x| !x.access(d))
23 .chain(v.iter().filter(|&x| x.access(d)))
24 .cloned()
25 .collect();
26 }
27 Self { len, table }
28 }
29 pub fn new_with_init<T, F>(v: Vec<T>, bit_length: usize, mut f: F) -> Self
30 where
31 T: Clone + RankSelectDictionaries,
32 F: FnMut(usize, usize, T),
33 {
34 let this = Self::new(v.clone(), bit_length);
35 for (mut k, v) in v.into_iter().enumerate() {
36 for (d, &(c, ref b)) in this.table.iter().rev().enumerate().rev() {
37 if v.access(d) {
38 k = c + b.rank1(k);
39 } else {
40 k = b.rank0(k);
41 }
42 f(d, k, v.clone());
43 }
44 }
45 this
46 }
47 pub fn access(&self, mut k: usize) -> usize {
49 let mut val = 0;
50 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
51 if b.access(k) {
52 k = c + b.rank1(k);
53 val |= 1 << d;
54 } else {
55 k = b.rank0(k);
56 }
57 }
58 val
59 }
60 pub fn rank(&self, val: usize, mut range: Range<usize>) -> usize {
62 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
63 if val.access(d) {
64 range.start = c + b.rank1(range.start);
65 range.end = c + b.rank1(range.end);
66 } else {
67 range.start = b.rank0(range.start);
68 range.end = b.rank0(range.end);
69 }
70 }
71 range.end - range.start
72 }
73 pub fn select(&self, val: usize, k: usize) -> Option<usize> {
75 if self.rank(val, 0..self.len) <= k {
76 return None;
77 }
78 let mut i = 0;
79 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
80 if val.access(d) {
81 i = c + b.rank1(i);
82 } else {
83 i = b.rank0(i);
84 }
85 }
86 i += k;
87 for &(c, ref b) in self.table.iter().rev() {
88 if i >= c {
89 i = b.select1(i - c).unwrap();
90 } else {
91 i = b.select0(i).unwrap();
92 }
93 }
94 Some(i)
95 }
96 pub fn quantile(&self, mut range: Range<usize>, mut k: usize) -> usize {
98 let mut val = 0;
99 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
100 let z = b.rank0(range.end) - b.rank0(range.start);
101 if z <= k {
102 k -= z;
103 val |= 1 << d;
104 range.start = c + b.rank1(range.start);
105 range.end = c + b.rank1(range.end);
106 } else {
107 range.start = b.rank0(range.start);
108 range.end = b.rank0(range.end);
109 }
110 }
111 val
112 }
113 pub fn quantile_outer(&self, mut range: Range<usize>, mut k: usize) -> usize {
115 let mut val = 0;
116 let mut orange = 0..self.len;
117 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
118 let z = b.rank0(orange.end) - b.rank0(orange.start) + b.rank0(range.start)
119 - b.rank0(range.end);
120 if z <= k {
121 k -= z;
122 val |= 1 << d;
123 range.start = c + b.rank1(range.start);
124 range.end = c + b.rank1(range.end);
125 orange.start = c + b.rank1(orange.start);
126 orange.end = c + b.rank1(orange.end);
127 } else {
128 range.start = b.rank0(range.start);
129 range.end = b.rank0(range.end);
130 orange.start = b.rank0(orange.start);
131 orange.end = b.rank0(orange.end);
132 }
133 }
134 val
135 }
136 pub fn rank_lessthan(&self, val: usize, mut range: Range<usize>) -> usize {
138 let mut res = 0;
139 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
140 if val.access(d) {
141 res += b.rank0(range.end) - b.rank0(range.start);
142 range.start = c + b.rank1(range.start);
143 range.end = c + b.rank1(range.end);
144 } else {
145 range.start = b.rank0(range.start);
146 range.end = b.rank0(range.end);
147 }
148 }
149 res
150 }
151 pub fn rank_range(&self, valrange: Range<usize>, range: Range<usize>) -> usize {
153 self.rank_lessthan(valrange.end, range.clone()) - self.rank_lessthan(valrange.start, range)
154 }
155 pub fn query_less_than<F>(&self, val: usize, mut range: Range<usize>, mut f: F)
156 where
157 F: FnMut(usize, Range<usize>),
158 {
159 for (d, &(c, ref b)) in self.table.iter().rev().enumerate().rev() {
160 if val.access(d) {
161 f(d, b.rank0(range.start)..b.rank0(range.end));
162 range.start = c + b.rank1(range.start);
163 range.end = c + b.rank1(range.end);
164 } else {
165 range.start = b.rank0(range.start);
166 range.end = b.rank0(range.end);
167 }
168 }
169 }
170}
171
#[test]
fn test_wavelet_matrix() {
    use crate::rand_value;
    use crate::tools::{NotEmptySegment as Nes, Xorshift};
    const N: usize = 1_000;
    const Q: usize = 1_000;
    const A: usize = 1 << 8;
    let mut rng = Xorshift::new();
    crate::rand!(rng, v: [..A; N]);
    let matrix = WaveletMatrix::new(v.clone(), 8);

    // `access` must reproduce every element.
    for (idx, &expected) in v.iter().enumerate() {
        assert_eq!(matrix.access(idx), expected);
    }

    for ((l, r), a) in rand_value!(rng, [(Nes(N), ..A); Q]) {
        let window = &v[l..r];

        // rank: occurrences of `a` inside the window.
        let expected_rank = window.iter().filter(|&&x| x == a).count();
        assert_eq!(matrix.rank(a, l..r), expected_rank);

        // select: cross-check against the first prefix whose rank reaches k + 1.
        let total = matrix.rank(a, 0..N);
        if total > 0 {
            let k = rng.random(..total);
            let expected_pos = (0..N)
                .position(|i| matrix.rank(a, 0..i + 1) == k + 1)
                .unwrap_or(N);
            assert_eq!(matrix.select(a, k).unwrap().min(N), expected_pos);
        }

        // quantile: every order statistic of the window, in ascending order.
        let mut sorted_window = window.to_vec();
        sorted_window.sort_unstable();
        let quantiles: Vec<_> = (0..r - l).map(|k| matrix.quantile(l..r, k)).collect();
        assert_eq!(quantiles, sorted_window);

        // quantile_outer: order statistics of everything outside the window.
        let mut outside: Vec<_> = v[..l].iter().chain(&v[r..]).copied().collect();
        outside.sort_unstable();
        let outer_quantiles: Vec<_> = (0..N + l - r)
            .map(|k| matrix.quantile_outer(l..r, k))
            .collect();
        assert_eq!(outer_quantiles, outside);

        // rank_lessthan: elements strictly below `a`.
        let expected_less = window.iter().filter(|&&x| x < a).count();
        assert_eq!(matrix.rank_lessthan(a, l..r), expected_less);

        // rank_range: elements inside the value interval [p, q).
        let (p, q) = rng.random(Nes(A - 1));
        let expected_in_range = window.iter().filter(|&&x| p <= x && x < q).count();
        assert_eq!(matrix.rank_range(p..q, l..r), expected_in_range);
    }
}