1use super::{AbelianGroup, BitVector, Compressor, RankSelectDictionaries, VecCompress};
2use std::{
3 mem::{self, MaybeUninit},
4 ops::Range,
5};
6
/// Wavelet matrix over a coordinate-compressed sequence.
///
/// Built level by level from the most significant bit of the compressed
/// index down; each query (access/rank/select/quantile) runs in
/// O(`bit_length`) rank/select operations on the concatenated bitmaps.
#[derive(Debug, Clone)]
pub struct WaveletMatrix<T> {
    /// Length of the original sequence.
    len: usize,
    /// Number of bit levels (bits used to represent `compress.size()`).
    bit_length: usize,
    /// `zeros[level]`: number of 0-bits on that level's bitmap.
    zeros: Vec<usize>,
    /// `ones_prefix[level]`: total 1-bits on all levels before `level`;
    /// used to localize global bit-vector ranks to a single level.
    ones_prefix: Vec<usize>,
    /// All level bitmaps concatenated: bit for position `k` of `level`
    /// lives at index `level * len + k`.
    bit_vector: BitVector,
    /// Coordinate compression mapping values <-> dense indices.
    compress: VecCompress<T>,
}
16
impl<T> WaveletMatrix<T>
where
    T: Ord + Clone,
{
    /// Builds the wavelet matrix from `v`.
    ///
    /// `v` is coordinate-compressed first; each level then stably partitions
    /// the positions by one bit of the compressed index (zeros before ones),
    /// from the most significant bit down — the standard wavelet-matrix layout.
    pub fn new(v: Vec<T>) -> Self {
        let len = v.len();
        let compress: VecCompress<T> = v.iter().cloned().collect();
        // Bits needed to represent `compress.size()` itself. NOTE(review):
        // when the size is an exact power of two this is one level more than
        // the largest index strictly requires — harmless, but worth knowing.
        let bit_length = usize::BITS as usize - compress.size().leading_zeros() as usize;
        let mut indices: Vec<usize> = v
            .iter()
            .map(|value| compress.index_exact(value).unwrap())
            .collect();
        let mut bit_vector = BitVector::with_capacity(len * bit_length);
        let mut zeros = Vec::with_capacity(bit_length);
        for d in (0..bit_length).rev() {
            // Emit bit `d` of every index on this level and count the zeros.
            let mut zero_count = 0;
            for &idx in &indices {
                let bit = ((idx >> d) & 1) != 0;
                bit_vector.push(bit);
                if !bit {
                    zero_count += 1;
                }
            }
            zeros.push(zero_count);
            // Stable partition for the next level: all indices whose bit `d`
            // is 0 (in order), followed by all whose bit `d` is 1.
            let mut next = Vec::with_capacity(len);
            next.extend(
                indices
                    .iter()
                    .filter(|&&idx| ((idx >> d) & 1) == 0)
                    .copied(),
            );
            next.extend(
                indices
                    .iter()
                    .filter(|&&idx| ((idx >> d) & 1) == 1)
                    .copied(),
            );
            indices = next;
        }
        // ones_prefix[level] = total 1-bits on all earlier levels, so a
        // global rank1 on `bit_vector` can be turned into a per-level rank.
        let mut ones_prefix = Vec::with_capacity(bit_length);
        let mut prefix = 0;
        for &zero in &zeros {
            ones_prefix.push(prefix);
            prefix += len - zero;
        }
        Self {
            len,
            bit_length,
            zeros,
            ones_prefix,
            bit_vector,
            compress,
        }
    }

    /// Builds the matrix and, for every element of `v`, reports its position
    /// on every level: `f(d, k, value)` is called with the bit index `d` and
    /// the element's position `k` *after* descending past the level for bit
    /// `d`. Useful for initializing auxiliary per-level structures.
    pub fn new_with_init<F>(v: Vec<T>, mut f: F) -> Self
    where
        F: FnMut(usize, usize, T),
    {
        let this = Self::new(v.clone());
        let indices: Vec<usize> = v
            .iter()
            .map(|value| this.compress.index_exact(value).unwrap())
            .collect();
        for (mut k, value) in v.into_iter().enumerate() {
            let idx = indices[k];
            // Retrace the element's path down the matrix, mirroring the
            // descent used by `access`/`rank_by_index`.
            for d in (0..this.bit_length).rev() {
                let level = this.level(d);
                if ((idx >> d) & 1) != 0 {
                    k = this.zeros[level] + this.rank1(level, k);
                } else {
                    k = this.rank0(level, k);
                }
                f(d, k, value.clone());
            }
        }
        this
    }

    /// Converts bit index `d` (counted from the least significant bit) to the
    /// level index (counted from the top of the matrix).
    fn level(&self, d: usize) -> usize {
        self.bit_length - 1 - d
    }

    /// Number of 1-bits among the first `k` bits of `level`'s bitmap
    /// (assuming `BitVector::rank1(i)` counts set bits before `i`).
    fn rank1(&self, level: usize, k: usize) -> usize {
        let offset = level * self.len;
        // Global rank minus the 1-bits contributed by earlier levels.
        self.bit_vector.rank1(offset + k) - self.ones_prefix[level]
    }

    /// Number of 0-bits among the first `k` bits of `level`'s bitmap.
    fn rank0(&self, level: usize, k: usize) -> usize {
        k - self.rank1(level, k)
    }

    /// Counts occurrences of the compressed index `idx` inside `range` by
    /// narrowing the range level by level along `idx`'s bit path.
    fn rank_by_index(&self, idx: usize, mut range: Range<usize>) -> usize {
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            if ((idx >> d) & 1) != 0 {
                // Follow the 1-side: positions shift past the zero block.
                range.start = self.zeros[level] + self.rank1(level, range.start);
                range.end = self.zeros[level] + self.rank1(level, range.end);
            } else {
                // Follow the 0-side.
                range.start = self.rank0(level, range.start);
                range.end = self.rank0(level, range.end);
            }
        }
        range.end - range.start
    }

    /// Returns the element at position `k` of the original sequence by
    /// reading one bit per level while descending.
    pub fn access(&self, mut k: usize) -> T {
        let mut idx = 0;
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            if self.bit_vector.access(level * self.len + k) {
                idx |= 1 << d;
                k = self.zeros[level] + self.rank1(level, k);
            } else {
                k = self.rank0(level, k);
            }
        }
        self.compress.values()[idx].clone()
    }

    /// Counts occurrences of `val` inside `range` (0 if `val` never occurs).
    pub fn rank(&self, val: T, range: Range<usize>) -> usize {
        match self.compress.index_exact(&val) {
            Some(idx) => self.rank_by_index(idx, range),
            None => 0,
        }
    }

    /// Position of the `k`-th (0-indexed) occurrence of `val`, or `None` if
    /// `val` occurs at most `k` times.
    pub fn select(&self, val: T, k: usize) -> Option<usize> {
        let idx = self.compress.index_exact(&val)?;
        if self.rank_by_index(idx, 0..self.len) <= k {
            return None;
        }
        // Descend from position 0 to find where `val`'s bucket starts on the
        // bottom level.
        let mut i = 0;
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            if ((idx >> d) & 1) != 0 {
                i = self.zeros[level] + self.rank1(level, i);
            } else {
                i = self.rank0(level, i);
            }
        }
        i += k;
        // Ascend back up using select on each level's bitmap to recover the
        // original position.
        for level in (0..self.bit_length).rev() {
            let offset = level * self.len;
            if i >= self.zeros[level] {
                // Came from the 1-side: select the matching 1-bit globally.
                let global_k = self.ones_prefix[level] + (i - self.zeros[level]);
                let pos = self.bit_vector.select1(global_k).unwrap();
                i = pos - offset;
            } else {
                // Came from the 0-side; zeros before this level = bits before
                // it minus ones before it.
                let zeros_before = offset - self.ones_prefix[level];
                let global_k = zeros_before + i;
                let pos = self.bit_vector.select0(global_k).unwrap();
                i = pos - offset;
            }
        }
        Some(i)
    }

    /// Returns the `k`-th (0-indexed) smallest element inside `range`.
    /// The caller must ensure `k < range.len()`.
    pub fn quantile(&self, mut range: Range<usize>, mut k: usize) -> T {
        let mut idx = 0;
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            // z = number of elements of `range` whose current bit is 0.
            let z = self.rank0(level, range.end) - self.rank0(level, range.start);
            if z <= k {
                // Answer lies in the 1-side; skip the z smaller elements.
                k -= z;
                idx |= 1 << d;
                range.start = self.zeros[level] + self.rank1(level, range.start);
                range.end = self.zeros[level] + self.rank1(level, range.end);
            } else {
                range.start = self.rank0(level, range.start);
                range.end = self.rank0(level, range.end);
            }
        }
        self.compress.values()[idx].clone()
    }

    /// Returns the `k`-th (0-indexed) smallest element *outside* `range`,
    /// i.e. among positions `0..range.start` and `range.end..len`.
    pub fn quantile_outer(&self, mut range: Range<usize>, mut k: usize) -> T {
        let mut idx = 0;
        // Track the whole sequence alongside `range`; the complement's zero
        // count is the difference of the two.
        let mut orange = 0..self.len;
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            let z = self.rank0(level, orange.end) - self.rank0(level, orange.start)
                + self.rank0(level, range.start)
                - self.rank0(level, range.end);
            if z <= k {
                k -= z;
                idx |= 1 << d;
                range.start = self.zeros[level] + self.rank1(level, range.start);
                range.end = self.zeros[level] + self.rank1(level, range.end);
                orange.start = self.zeros[level] + self.rank1(level, orange.start);
                orange.end = self.zeros[level] + self.rank1(level, orange.end);
            } else {
                range.start = self.rank0(level, range.start);
                range.end = self.rank0(level, range.end);
                orange.start = self.rank0(level, orange.start);
                orange.end = self.rank0(level, orange.end);
            }
        }
        self.compress.values()[idx].clone()
    }

    /// Counts elements strictly less than `val` inside `range`.
    pub fn rank_lessthan(&self, val: T, mut range: Range<usize>) -> usize {
        let idx = self.compress.index_lower_bound(&val);
        let mut res = 0;
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            if ((idx >> d) & 1) != 0 {
                // Every element whose current bit is 0 here is < val.
                res += self.rank0(level, range.end) - self.rank0(level, range.start);
                range.start = self.zeros[level] + self.rank1(level, range.start);
                range.end = self.zeros[level] + self.rank1(level, range.end);
            } else {
                range.start = self.rank0(level, range.start);
                range.end = self.rank0(level, range.end);
            }
        }
        res
    }

    /// Counts elements with value in the half-open `valrange` inside `range`.
    pub fn rank_range(&self, valrange: Range<T>, range: Range<usize>) -> usize {
        self.rank_lessthan(valrange.end, range.clone()) - self.rank_lessthan(valrange.start, range)
    }

    /// Visits the per-level segments of elements strictly less than `val`
    /// inside `range`: whenever the descent follows a 1-bit of `val`'s index,
    /// `f(d, subrange)` receives the bit index and the 0-side sub-range in
    /// the next level's coordinates.
    pub fn query_less_than<F>(&self, val: T, mut range: Range<usize>, mut f: F)
    where
        F: FnMut(usize, Range<usize>),
    {
        let idx = self.compress.index_lower_bound(&val);
        for d in (0..self.bit_length).rev() {
            let level = self.level(d);
            if ((idx >> d) & 1) != 0 {
                f(
                    d,
                    self.rank0(level, range.start)..self.rank0(level, range.end),
                );
                range.start = self.zeros[level] + self.rank1(level, range.start);
                range.end = self.zeros[level] + self.rank1(level, range.end);
            } else {
                range.start = self.rank0(level, range.start);
                range.end = self.rank0(level, range.end);
            }
        }
    }

    /// Builds a [`WaveletMatrixFold`]: per-level prefix sums of `weights`,
    /// permuted exactly like the matrix levels, enabling weighted range
    /// queries in O(`bit_length`) time.
    ///
    /// # Panics
    /// Panics if `weights.len() != self.len`.
    pub fn build_fold<M>(&self, weights: &[M::T]) -> WaveletMatrixFold<'_, T, M>
    where
        M: AbelianGroup,
    {
        let len = self.len;
        assert_eq!(weights.len(), len);
        // One prefix-sum block of (len + 1) entries per level, plus one for
        // the fully-sorted bottom permutation.
        let mut prefix = Vec::with_capacity((self.bit_length + 1) * (len + 1));
        let mut current: Vec<M::T> = weights.to_vec();
        for level in 0..self.bit_length {
            let offset = level * len;
            let zeros = self.zeros[level];
            // Scatter this level's weights into next-level order without
            // requiring M::T: Default — hence MaybeUninit.
            let mut next: Vec<MaybeUninit<M::T>> = Vec::with_capacity(len);
            next.resize_with(len, MaybeUninit::uninit);
            let mut zero_pos = 0;
            let mut one_pos = zeros;
            let mut acc = M::unit();
            prefix.push(acc.clone());
            for (i, w) in current.into_iter().enumerate() {
                acc = M::operate(&acc, &w);
                prefix.push(acc.clone());
                if self.bit_vector.access(offset + i) {
                    next[one_pos].write(w);
                    one_pos += 1;
                } else {
                    next[zero_pos].write(w);
                    zero_pos += 1;
                }
            }
            debug_assert_eq!(zero_pos, zeros);
            debug_assert_eq!(one_pos, len);
            // SAFETY: the loop above wrote every slot exactly once — zero
            // slots 0..zeros and one slots zeros..len (checked by the
            // debug_asserts) — so all `len` elements are initialized.
            // ManuallyDrop prevents the original Vec from freeing the
            // buffer that `from_raw_parts` takes ownership of, and
            // MaybeUninit<M::T> has the same layout as M::T.
            let next = unsafe {
                let mut next = mem::ManuallyDrop::new(next);
                let ptr = next.as_mut_ptr() as *mut M::T;
                let len = next.len();
                let cap = next.capacity();
                Vec::from_raw_parts(ptr, len, cap)
            };
            current = next;
        }
        // Prefix sums over the final (bottom) permutation.
        let mut acc = M::unit();
        prefix.push(acc.clone());
        for w in current.into_iter() {
            acc = M::operate(&acc, &w);
            prefix.push(acc.clone());
        }
        WaveletMatrixFold {
            wavelet_matrix: self,
            prefix,
        }
    }
}
318
/// Weighted companion of a [`WaveletMatrix`]: per-level prefix sums of the
/// element weights (permuted like the matrix levels), supporting group-sum
/// queries over value ranges in O(`bit_length`) time.
///
/// Construct via [`WaveletMatrix::build_fold`].
#[derive(Debug, Clone)]
pub struct WaveletMatrixFold<'a, T, M>
where
    T: Ord + Clone,
    M: AbelianGroup,
{
    /// The matrix this structure was built from.
    wavelet_matrix: &'a WaveletMatrix<T>,
    /// Flattened prefix sums: (bit_length + 1) blocks of (len + 1) entries,
    /// block `level` covering that level's weight permutation.
    prefix: Vec<M::T>,
}
328
impl<'a, T, M> WaveletMatrixFold<'a, T, M>
where
    T: Ord + Clone,
    M: AbelianGroup,
{
    /// Group-sum of the weights at positions `range` of `level`'s
    /// permutation, computed as a prefix-sum difference.
    #[inline]
    fn range_sum(&self, level: usize, range: Range<usize>) -> M::T {
        let offset = level * (self.wavelet_matrix.len + 1);
        // SAFETY: `prefix` was built with (bit_length + 1) blocks of
        // (len + 1) entries; callers pass `level <= bit_length` and range
        // bounds `<= len` (the public entry point debug_asserts
        // `range.end <= len` and rank0 never exceeds its argument), so both
        // indices are in bounds.
        unsafe {
            M::rinv_operate(
                self.prefix.get_unchecked(offset + range.end),
                self.prefix.get_unchecked(offset + range.start),
            )
        }
    }

    /// Group-sum of the weights of elements strictly less than `val`
    /// inside `range`.
    pub fn fold_lessthan(&self, val: T, range: Range<usize>) -> M::T {
        self.fold_lessthan_with_count(val, range).1
    }

    /// Like [`fold_lessthan`](Self::fold_lessthan), but also returns how
    /// many elements were summed.
    pub fn fold_lessthan_with_count(&self, val: T, mut range: Range<usize>) -> (usize, M::T) {
        debug_assert!(range.end <= self.wavelet_matrix.len);
        let idx = self.wavelet_matrix.compress.index_lower_bound(&val);
        let mut count = 0;
        let mut sum = M::unit();
        for d in (0..self.wavelet_matrix.bit_length).rev() {
            let level = self.wavelet_matrix.level(d);
            let start0 = self.wavelet_matrix.rank0(level, range.start);
            let end0 = self.wavelet_matrix.rank0(level, range.end);
            if ((idx >> d) & 1) != 0 {
                // Elements with a 0-bit here are all < val: take their whole
                // contribution from the next level's prefix block, then
                // descend into the 1-side. (rank1 = k - rank0 explains the
                // `range.start - start0` form.)
                count += end0 - start0;
                sum = M::operate(&sum, &self.range_sum(level + 1, start0..end0));
                range.start = self.wavelet_matrix.zeros[level] + (range.start - start0);
                range.end = self.wavelet_matrix.zeros[level] + (range.end - end0);
            } else {
                // Descend into the 0-side; nothing on this level qualifies.
                range.start = start0;
                range.end = end0;
            }
        }
        (count, sum)
    }

    /// Group-sum of the weights of elements whose value lies in the
    /// half-open `valrange`, within `range` — as a difference of two
    /// less-than folds (requires the group inverse, hence `AbelianGroup`).
    pub fn fold_range(&self, valrange: Range<T>, range: Range<usize>) -> M::T {
        M::rinv_operate(
            &self.fold_lessthan(valrange.end, range.clone()),
            &self.fold_lessthan(valrange.start, range),
        )
    }

    /// Like [`fold_range`](Self::fold_range), but also returns the number of
    /// matching elements.
    pub fn fold_range_with_count(&self, valrange: Range<T>, range: Range<usize>) -> (usize, M::T) {
        let (count_upper, sum_upper) = self.fold_lessthan_with_count(valrange.end, range.clone());
        let (count_lower, sum_lower) = self.fold_lessthan_with_count(valrange.start, range);
        (
            count_upper - count_lower,
            M::rinv_operate(&sum_upper, &sum_lower),
        )
    }
}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AdditiveOperation,
        rand_value,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    /// Randomized cross-check of every query against a naive O(n) oracle on
    /// the raw vectors.
    #[test]
    fn test_wavelet_matrix() {
        const N: usize = 1_000;
        const Q: usize = 1_000;
        // Value universe bound (exclusive) and weight magnitude bound.
        const A: usize = 1 << 8;
        const B: i64 = 1_000_000_000;
        let mut rng = Xorshift::default();
        crate::rand!(rng, v: [..A; N]);
        crate::rand!(rng, w: [-B..B; N]);
        let wm = WaveletMatrix::new(v.clone());
        let fold = wm.build_fold::<AdditiveOperation<i64>>(&w);
        // access reconstructs every element.
        for (i, v) in v.iter().cloned().enumerate() {
            assert_eq!(wm.access(i), v);
        }
        // Everything is < A, so the fold over 0..N is the total weight.
        assert_eq!(fold.fold_lessthan(A, 0..N), w.iter().sum::<i64>());
        for ((l, r), a) in rand_value!(rng, [(Nes(N), ..A); Q]) {
            // rank: occurrences of `a` in v[l..r].
            assert_eq!(
                wm.rank(a, l..r),
                v[l..r].iter().filter(|&&x| x == a).count()
            );

            if wm.rank(a, 0..N) > 0 {
                // select: position of the (k+1)-th occurrence of `a`.
                let k = rng.random(..wm.rank(a, 0..N));
                assert_eq!(
                    wm.select(a, k).unwrap().min(N),
                    (0..N)
                        .position(|i| wm.rank(a, 0..i + 1) == k + 1)
                        .unwrap_or(N)
                );
            }

            // quantile over all k enumerates the sorted sub-slice.
            assert_eq!(
                (0..r - l).map(|k| wm.quantile(l..r, k)).collect::<Vec<_>>(),
                {
                    let mut v: Vec<_> = v[l..r].to_vec();
                    v.sort_unstable();
                    v
                }
            );

            // quantile_outer enumerates the sorted complement of v[l..r].
            assert_eq!(
                (0..N + l - r)
                    .map(|k| wm.quantile_outer(l..r, k))
                    .collect::<Vec<_>>(),
                {
                    let mut v: Vec<_> = v.to_vec();
                    v.drain(l..r);
                    v.sort_unstable();
                    v
                }
            );

            assert_eq!(
                wm.rank_lessthan(a, l..r),
                v[l..r].iter().filter(|&&x| x < a).count()
            );

            // Naive weighted count/sum of elements < a in v[l..r].
            let mut count_lt = 0usize;
            let mut sum_lt = 0i64;
            for (&value, &weight) in v[l..r].iter().zip(w[l..r].iter()) {
                if value < a {
                    count_lt += 1;
                    sum_lt += weight;
                }
            }
            assert_eq!(fold.fold_lessthan_with_count(a, l..r), (count_lt, sum_lt));
            assert_eq!(fold.fold_lessthan(A, l..r), w[l..r].iter().sum::<i64>());

            // Half-open value range [p, q) queries.
            let (p, q) = rng.random(Nes(A - 1));
            assert_eq!(
                wm.rank_range(p..q, l..r),
                v[l..r].iter().filter(|&&x| p <= x && x < q).count()
            );
            let mut count_range = 0usize;
            let mut sum_range = 0i64;
            for (&value, &weight) in v[l..r].iter().zip(w[l..r].iter()) {
                if p <= value && value < q {
                    count_range += 1;
                    sum_range += weight;
                }
            }
            assert_eq!(fold.fold_range(p..q, l..r), sum_range);
            assert_eq!(
                fold.fold_range_with_count(p..q, l..r),
                (count_range, sum_range)
            );
        }
    }
}