// competitive/data_structure/compressed_binary_indexed_tree.rs

1use super::{GetDistinctMut, Monoid, SliceBisectExt};
2use std::{
3    fmt::{self, Debug},
4    marker::PhantomData,
5    ops::{Bound, RangeBounds},
6};
7
8pub struct CompressedBinaryIndexedTree<M, X, Inner>
9where
10    M: Monoid,
11{
12    compress: Vec<X>,
13    bits: Vec<Inner>,
14    _marker: PhantomData<fn() -> M>,
15}
16impl<M, X, Inner> Debug for CompressedBinaryIndexedTree<M, X, Inner>
17where
18    M: Monoid,
19    X: Debug,
20    Inner: Debug,
21{
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        f.debug_struct("CompressedBinaryIndexedTree")
24            .field("compress", &self.compress)
25            .field("bits", &self.bits)
26            .finish()
27    }
28}
29impl<M, X, Inner> Clone for CompressedBinaryIndexedTree<M, X, Inner>
30where
31    M: Monoid,
32    X: Clone,
33    Inner: Clone,
34{
35    fn clone(&self) -> Self {
36        Self {
37            compress: self.compress.clone(),
38            bits: self.bits.clone(),
39            _marker: self._marker,
40        }
41    }
42}
43impl<M, X, Inner> Default for CompressedBinaryIndexedTree<M, X, Inner>
44where
45    M: Monoid,
46{
47    fn default() -> Self {
48        Self {
49            compress: Default::default(),
50            bits: Default::default(),
51            _marker: Default::default(),
52        }
53    }
54}
55#[repr(transparent)]
56pub struct Tag<M>(M::T)
57where
58    M: Monoid;
59impl<M> Debug for Tag<M>
60where
61    M: Monoid,
62    M::T: Debug,
63{
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        self.0.fmt(f)
66    }
67}
68impl<M> Clone for Tag<M>
69where
70    M: Monoid,
71{
72    fn clone(&self) -> Self {
73        Self(self.0.clone())
74    }
75}
/// Generates the inherent API (`new`, `accumulate`, `update`) and a public type alias for
/// compressed binary indexed trees of increasing dimension.
///
/// Input is a flat list of `Alias KeyType RangeType` triples, one per dimension, starting with
/// the 1-dimensional tree. Anything after an optional `;` is accepted and ignored, so entries can
/// stay listed without being generated.
///
/// A d-dimensional tree is keyed by right-nested tuples such as `(a, (b, (c,)))`. It is stored as
/// `CompressedBinaryIndexedTree<M, A, CompressedBinaryIndexedTree<M, B, ... Tag<M>>>`: each node of
/// an outer level owns a full tree over the remaining coordinates of the points it covers.
///
/// Internal arms (prefixed with `@`) are implementation details:
/// * `@tuple` - right-nested tuple type `(A, (B, (C,)))` built from `A B C`;
/// * `@cst` - nested tree type for the given key types;
/// * `@from_iter` - constructor body (innermost level vs. outer level);
/// * `@acc` / `@update` - per-node step of a prefix query / point update;
/// * `@impl` / `@inner` - one `impl` block plus alias per dimension.
#[allow(unused_macros)]
macro_rules! impl_compressed_binary_indexed_tree {
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_binary_indexed_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedBinaryIndexedTree<$M, $T, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>
    };
    // Innermost level: compress the last coordinate and allocate one monoid value per node
    // (index 0 is padding; the tree is 1-based).
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            bits: vec![Tag($M::unit()); n + 1],
            _marker: PhantomData,
        }
    }};
    // Outer level: node `i` gets an inner tree over the remaining coordinates of every point
    // whose 1-based compressed first coordinate lies in `(i - lowbit(i), i]`.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut bits = vec![CompressedBinaryIndexedTree::default(); n + 1];
        // Bucket the remaining coordinates by (1-based) compressed first coordinate.
        let mut ps = vec![vec![]; n + 1];
        for (x, q) in $points {
            let i = compress.position_bisect(|c| x <= c);
            ps[i + 1].push(q);
        }
        // Nodes are visited in increasing order. Once node `i` is built, its points move into
        // its parent `i + lowbit(i)`, so each parent has collected all children before it is
        // built.
        for i in 1..=n {
            bits[i] = CompressedBinaryIndexedTree::<_, _, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
            let j = i + (i & (!i + 1));
            if j <= n {
                let (s, ns) = ps.get_distinct_mut((i, j));
                ns.append(s);
            }
        }
        Self {
            compress,
            bits,
            _marker: PhantomData,
        }
    }};
    // Per-node step of `accumulate`: read the stored value, or recurse into the inner tree
    // with the remaining ranges.
    (@acc $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    (@acc $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.accumulate(&$rng.1)
    };
    // Per-node step of `update`: fold `x` into the stored value, or recurse into the inner tree
    // with the remaining coordinates.
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_binary_indexed_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds a tree over `points` with every node initialised to `M::unit()`.
            ///
            /// Keys later passed to `update` must be among these points.
            pub fn new(points: &[impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_binary_indexed_tree!(@from_iter M points $($T)*)
            }
            /// Folds `M::operate` over every registered point inside the prefix box `range`.
            ///
            /// `range` holds one range per coordinate, nested like the key tuple. Only end bounds
            /// are used. Start bounds must be `Bound::Unbounded`; a bounded start panics at any
            /// level the query reaches.
            pub fn accumulate<$($Q,)*>(&self, range: &impl_compressed_binary_indexed_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                match range.0.start_bound() {
                    Bound::Unbounded => (),
                    _ => panic!("expected `Bound::Unbounded`"),
                };
                // `k` = number of compressed coordinates inside the prefix. `index` is already a
                // `&X`, the same type as the closure argument, so it is compared without another
                // borrow.
                let mut k = match range.0.end_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x > index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x >= index),
                    Bound::Unbounded => self.compress.len(),
                };
                let mut x = M::unit();
                while k > 0 {
                    x = M::operate(&x, &impl_compressed_binary_indexed_tree!(@acc self.bits[k], range $($T)*));
                    // Clear the lowest set bit: step to the preceding disjoint node.
                    k -= k & (!k + 1);
                }
                x
            }
            /// Combines `x` into the value stored for `key` (via `M::operate_assign` at the
            /// innermost level).
            ///
            /// `key` must be one of the points given to `new`. An unregistered coordinate at a
            /// visited level panics with "not exist key".
            pub fn update(&mut self, key: &impl_compressed_binary_indexed_tree!(@tuple () () $($T)*), x: &M::T) {
                let mut k = self.compress.binary_search(&key.0).expect("not exist key") + 1;
                while k < self.bits.len() {
                    impl_compressed_binary_indexed_tree!(@update self.bits[k], M key x $($T)*);
                    // Add the lowest set bit: step to the next node that also covers `key`.
                    k += k & (!k + 1);
                }
            }
        }
        /// Nested compressed binary indexed tree with one level per key coordinate.
        pub type $C<M, $($T),*> = impl_compressed_binary_indexed_tree!(@cst M $($T)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_binary_indexed_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_binary_indexed_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
    // Anything else is a malformed invocation. `compile_error!` needs a string literal, so the
    // offending tokens are reported through `stringify!` rather than spliced in raw. Braces keep
    // the expansion valid in item position without a trailing `;`.
    ($($t:tt)*) => {
        compile_error! {
            concat!(
                "invalid input to `impl_compressed_binary_indexed_tree!`: ",
                stringify!($($t)*)
            )
        }
    };
}
198
199impl_compressed_binary_indexed_tree!(
200    CompressedBinaryIndexedTree1d A QA
201    CompressedBinaryIndexedTree2d B QB
202    CompressedBinaryIndexedTree3d C QC
203    CompressedBinaryIndexedTree4d D QD;
204    CompressedBinaryIndexedTree5d E QE
205    CompressedBinaryIndexedTree6d F QF
206    CompressedBinaryIndexedTree7d G QG
207    CompressedBinaryIndexedTree8d H QH
208    CompressedBinaryIndexedTree9d I QI
209);
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AdditiveOperation, tools::Xorshift};
    use std::{collections::HashMap, ops::RangeTo};

    /// Random point updates followed by random prefix queries on a 4-dimensional tree,
    /// checked against a brute-force sum over a table of per-point totals.
    #[test]
    fn test_bit4d() {
        const N: usize = 100;
        const Q: usize = 5000;
        const A: RangeTo<u64> = ..1_000;

        let mut rng = Xorshift::default();
        let mut points: Vec<_> = rng.random_iter((A, (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();

        let mut totals: HashMap<_, u64> = points.iter().map(|p| (p, 0)).collect();
        let mut bit =
            CompressedBinaryIndexedTree4d::<AdditiveOperation<u64>, _, _, _, _>::new(&points);

        for _ in 0..Q {
            // Point update, mirrored into the brute-force table.
            let p = &points[rng.random(0..points.len())];
            let x = rng.random(A);
            *totals.get_mut(p).unwrap() += x;
            bit.update(p, &x);

            // Random prefix range per dimension: `..end`, `..=end` or `..`.
            let mut prefix = || {
                let end = match rng.rand(3) {
                    0 => Bound::Excluded(rng.random(A)),
                    1 => Bound::Included(rng.random(A)),
                    _ => Bound::Unbounded,
                };
                (Bound::Unbounded, end)
            };
            let range = (prefix(), (prefix(), (prefix(), (prefix(),))));

            let (r0, (r1, (r2, (r3,)))) = range;
            let expected: u64 = totals
                .iter()
                .filter(|((p0, (p1, (p2, (p3,)))), _)| {
                    r0.contains(p0) && r1.contains(p1) && r2.contains(p2) && r3.contains(p3)
                })
                .map(|(_, &total)| total)
                .sum();
            assert_eq!(expected, bit.accumulate(&range));
        }
    }
}