competitive/data_structure/
compressed_binary_indexed_tree.rs1use super::{GetDistinctMut, Monoid, SliceBisectExt};
2use std::{
3 fmt::{self, Debug},
4 marker::PhantomData,
5 ops::{Bound, RangeBounds},
6};
7
/// A binary indexed tree (Fenwick tree) over coordinate-compressed keys of type `X`.
///
/// `Inner` is either another `CompressedBinaryIndexedTree` (next dimension) or
/// `Tag<M>` (innermost monoid value), so nesting yields multi-dimensional trees;
/// the concrete nestings are produced by `impl_compressed_binary_indexed_tree!`.
pub struct CompressedBinaryIndexedTree<M, X, Inner>
where
    M: Monoid,
{
    // Sorted, deduplicated keys of this axis (built in the macro's `@from_iter`).
    compress: Vec<X>,
    // 1-indexed BIT nodes; `bits[0]` is unused padding (length is `n + 1`).
    bits: Vec<Inner>,
    // `fn() -> M` keeps the type independent of `M`'s auto traits without owning an `M`.
    _marker: PhantomData<fn() -> M>,
}
16impl<M, X, Inner> Debug for CompressedBinaryIndexedTree<M, X, Inner>
17where
18 M: Monoid,
19 X: Debug,
20 Inner: Debug,
21{
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 f.debug_struct("CompressedBinaryIndexedTree")
24 .field("compress", &self.compress)
25 .field("bits", &self.bits)
26 .finish()
27 }
28}
29impl<M, X, Inner> Clone for CompressedBinaryIndexedTree<M, X, Inner>
30where
31 M: Monoid,
32 X: Clone,
33 Inner: Clone,
34{
35 fn clone(&self) -> Self {
36 Self {
37 compress: self.compress.clone(),
38 bits: self.bits.clone(),
39 _marker: self._marker,
40 }
41 }
42}
43impl<M, X, Inner> Default for CompressedBinaryIndexedTree<M, X, Inner>
44where
45 M: Monoid,
46{
47 fn default() -> Self {
48 Self {
49 compress: Default::default(),
50 bits: Default::default(),
51 _marker: Default::default(),
52 }
53 }
54}
/// Innermost layer of a nested tree: wraps a single monoid value `M::T`.
///
/// `repr(transparent)` gives it the exact layout of the wrapped value.
#[repr(transparent)]
pub struct Tag<M>(M::T)
where
    M: Monoid;
59impl<M> Debug for Tag<M>
60where
61 M: Monoid,
62 M::T: Debug,
63{
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 self.0.fmt(f)
66 }
67}
68impl<M> Clone for Tag<M>
69where
70 M: Monoid,
71{
72 fn clone(&self) -> Self {
73 Self(self.0.clone())
74 }
75}
/// Generates the nested types and inherent impls for 1- to N-dimensional
/// compressed binary indexed trees (prefix fold / point update over
/// coordinate-compressed keys).
///
/// Internal rules:
/// - `@tuple`: builds the right-nested tuple `(T1, (T2, (..., (Tk,))))` used
///   both for point keys and for query ranges.
/// - `@cst`: builds the nested tree type, one `CompressedBinaryIndexedTree`
///   layer per coordinate type, terminated by `Tag<M>`.
/// - `@from_iter`: constructs a tree from an iterator of key tuples.
/// - `@acc` / `@update`: recurse into the next dimension, or touch the
///   innermost `Tag` value in the 1-d base case.
/// - `@impl` / `@inner`: emit the inherent impl and a `pub type` alias for
///   each requested dimensionality.
#[allow(unused_macros)]
macro_rules! impl_compressed_binary_indexed_tree {
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_binary_indexed_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedBinaryIndexedTree<$M, $T, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>
    };
    // 1-d base case: compress the keys and fill every node with the unit value.
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            bits: vec![Tag(M::unit()); n + 1],
            _marker: PhantomData,
        }
    }};
    // Recursive case: bucket points by their first coordinate, build each BIT
    // node's inner tree from its bucket, then merge the bucket upward into the
    // parent node `i + lowbit(i)` so ancestors cover their whole subrange.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut bits = vec![CompressedBinaryIndexedTree::default(); n + 1];
        let mut ps = vec![vec![]; n + 1];
        for (x, q) in $points {
            let i = compress.position_bisect(|c| x <= c);
            ps[i + 1].push(q);
        }
        for i in 1..=n {
            bits[i] = CompressedBinaryIndexedTree::<_, _, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
            let j = i + (i & (!i + 1));
            if j <= n {
                let (s, ns) = ps.get_distinct_mut((i, j));
                ns.append(s);
            }
        }
        Self {
            compress,
            bits,
            _marker: PhantomData,
        }
    }};
    // Accumulate: innermost dimension reads the stored monoid value ...
    (@acc $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    // ... outer dimensions recurse with the remaining range tuple.
    (@acc $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.accumulate(&$rng.1)
    };
    // Update: innermost dimension folds `x` into the stored monoid value ...
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    // ... outer dimensions recurse with the remaining key tuple.
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_binary_indexed_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds a tree over exactly this set of points; `update` later
            /// accepts only keys that appeared here.
            pub fn new(points: &[impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_binary_indexed_tree!(@from_iter M points $($T)*)
            }
            /// Folds all stored values whose key lies inside `range` on every axis.
            ///
            /// # Panics
            /// Panics if any axis has a bounded start: only prefix ranges
            /// (`..x`, `..=x`, `..`) are supported.
            pub fn accumulate<$($Q,)*>(&self, range: &impl_compressed_binary_indexed_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                match range.0.start_bound() {
                    Bound::Unbounded => (),
                    _ => panic!("expected `Bound::Unbounded`"),
                };
                // Number of compressed keys covered by the prefix on this axis.
                let mut k = match range.0.end_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x > &index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x >= &index),
                    Bound::Unbounded => self.compress.len(),
                };
                let mut x = M::unit();
                // Standard BIT prefix walk: strip the lowest set bit each step.
                while k > 0 {
                    x = M::operate(&x, &impl_compressed_binary_indexed_tree!(@acc self.bits[k], range $($T)*));
                    k -= k & (!k + 1);
                }
                x
            }
            /// Folds `x` into the value stored at `key`.
            ///
            /// # Panics
            /// Panics if `key` was not among the points given at construction.
            pub fn update(&mut self, key: &impl_compressed_binary_indexed_tree!(@tuple () () $($T)*), x: &M::T) {
                let mut k = self.compress.binary_search(&key.0).expect("not exist key") + 1;
                // Standard BIT ascent: add the lowest set bit each step.
                while k < self.bits.len() {
                    impl_compressed_binary_indexed_tree!(@update self.bits[k], M key x $($T)*);
                    k += k & (!k + 1);
                }
            }
        }
        pub type $C<M, $($T),*> = impl_compressed_binary_indexed_tree!(@cst M $($T)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_binary_indexed_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    // Entry: `Name1d T1 Q1 Name2d T2 Q2 ...` triples; anything after an
    // optional `;` is accepted but ignored (no impls generated for it).
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_binary_indexed_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
    // Fallback for malformed input. The previous version transcribed
    // `compile_error!($($t:tt)*)`, which re-emits `: tt` as literal tokens and
    // fails with a confusing secondary error because `compile_error!` requires
    // a string literal; emit a clear diagnostic instead.
    ($($t:tt)*) => {
        compile_error!("invalid arguments: expected `TypeName KeyType RangeType` triples, optionally followed by `; <ignored tokens>`")
    };
}
198
// Generate the 1d-4d tree types and impls. The entries after the `;` match the
// macro's trailing `$(;$($t:tt)*)?` catch-all and are deliberately NOT
// generated (presumably to limit compile time/code size — TODO confirm); they
// are kept as a record of the naming scheme for higher dimensions.
impl_compressed_binary_indexed_tree!(
    CompressedBinaryIndexedTree1d A QA
    CompressedBinaryIndexedTree2d B QB
    CompressedBinaryIndexedTree3d C QC
    CompressedBinaryIndexedTree4d D QD;
    CompressedBinaryIndexedTree5d E QE
    CompressedBinaryIndexedTree6d F QF
    CompressedBinaryIndexedTree7d G QG
    CompressedBinaryIndexedTree8d H QH
    CompressedBinaryIndexedTree9d I QI
);
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AdditiveOperation, tools::Xorshift};
    use std::{collections::HashMap, ops::RangeTo};

    /// Randomized cross-check of the 4-dimensional additive tree against a
    /// brute-force `HashMap` model: N random points, Q rounds of one point
    /// update followed by one prefix-range query.
    /// NOTE(review): values depend on the exact RNG call order — do not reorder.
    #[test]
    fn test_bit4d() {
        let mut rng = Xorshift::default();
        const N: usize = 100;
        const Q: usize = 5000;
        const A: RangeTo<u64> = ..1_000;
        // Keys are nested tuples (a, (b, (c, (d,)))) matching the tree's key type.
        let mut points: Vec<_> = rng.random_iter(((A), (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();
        // Brute-force model: point -> accumulated sum.
        let mut map: HashMap<_, _> = points.iter().map(|p| (p, 0u64)).collect();
        let mut bit =
            CompressedBinaryIndexedTree4d::<AdditiveOperation<u64>, _, _, _, _>::new(&points);
        for _ in 0..Q {
            // Add a random amount at a random existing point, in both structures.
            let p = &points[rng.random(0..points.len())];
            let x = rng.random(A);
            *map.get_mut(p).unwrap() += x;
            bit.update(p, &x);

            // One random prefix bound per axis: start is always unbounded (the
            // tree only supports prefix queries); end is excluded / included /
            // unbounded with equal probability.
            let mut f = || {
                (
                    Bound::Unbounded,
                    match rng.rand(3) {
                        0 => Bound::Excluded(rng.random(A)),
                        1 => Bound::Included(rng.random(A)),
                        _ => Bound::Unbounded,
                    },
                )
            };

            let range = (f(), (f(), (f(), (f(),))));
            let (r0, (r1, (r2, (r3,)))) = range;
            // Expected answer: sum over points contained in the range on every axis.
            let expected: u64 = map
                .iter()
                .filter_map(|((p0, (p1, (p2, (p3,)))), x)| {
                    if RangeBounds::contains(&r0, p0)
                        && RangeBounds::contains(&r1, p1)
                        && RangeBounds::contains(&r2, p2)
                        && RangeBounds::contains(&r3, p3)
                    {
                        Some(*x)
                    } else {
                        None
                    }
                })
                .sum();
            let result = bit.accumulate(&range);
            assert_eq!(expected, result);
        }
    }
}