// competitive/data_structure/compressed_binary_indexed_tree.rs

use super::{Monoid, SliceBisectExt};
use std::{
    fmt::{self, Debug},
    marker::PhantomData,
    ops::{Bound, RangeBounds},
};

8pub struct CompressedBinaryIndexedTree<M, X, Inner>
9where
10    M: Monoid,
11{
12    compress: Vec<X>,
13    bits: Vec<Inner>,
14    _marker: PhantomData<fn() -> M>,
15}
16impl<M, X, Inner> Debug for CompressedBinaryIndexedTree<M, X, Inner>
17where
18    M: Monoid,
19    X: Debug,
20    Inner: Debug,
21{
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        f.debug_struct("CompressedBinaryIndexedTree")
24            .field("compress", &self.compress)
25            .field("bits", &self.bits)
26            .finish()
27    }
28}
29impl<M, X, Inner> Clone for CompressedBinaryIndexedTree<M, X, Inner>
30where
31    M: Monoid,
32    X: Clone,
33    Inner: Clone,
34{
35    fn clone(&self) -> Self {
36        Self {
37            compress: self.compress.clone(),
38            bits: self.bits.clone(),
39            _marker: self._marker,
40        }
41    }
42}
43impl<M, X, Inner> Default for CompressedBinaryIndexedTree<M, X, Inner>
44where
45    M: Monoid,
46{
47    fn default() -> Self {
48        Self {
49            compress: Default::default(),
50            bits: Default::default(),
51            _marker: Default::default(),
52        }
53    }
54}
55#[repr(transparent)]
56pub struct Tag<M>(M::T)
57where
58    M: Monoid;
59impl<M> Debug for Tag<M>
60where
61    M: Monoid<T: Debug>,
62{
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        self.0.fmt(f)
65    }
66}
67impl<M> Clone for Tag<M>
68where
69    M: Monoid,
70{
71    fn clone(&self) -> Self {
72        Self(self.0.clone())
73    }
74}
75
/// Generates the `new` / `accumulate` / `update` API for nested
/// [`CompressedBinaryIndexedTree`]s, plus one type alias per dimension.
///
/// The public input is a list of `AliasName CoordinateType RangeType` triples, one per
/// dimension in order; alias `k` names the `k`-dimensional tree whose coordinate types are the
/// first `k` `CoordinateType`s. Everything after an optional `;` is accepted but ignored.
/// Arms starting with `@` are internal helpers.
macro_rules! impl_compressed_binary_indexed_tree {
    // `@tuple () () A B C` => `(A, (B, (C,)))`: the right-nested tuple shape used for both
    // points and query ranges. The `$l` / `$r` tokens are emitted around every element.
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_binary_indexed_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    // `@cst M A B` => `CompressedBinaryIndexedTree<M, A, CompressedBinaryIndexedTree<M, B, Tag<M>>>`.
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedBinaryIndexedTree<$M, $T, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>
    };
    // Innermost dimension: compress the coordinates and start every leaf at `M::unit()`.
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            // Index 0 is a placeholder so that the Fenwick nodes are 1-based.
            bits: vec![Tag($M::unit()); n + 1],
            _marker: PhantomData,
        }
    }};
    // Outer dimension: node `i` gets an inner tree built from the tails of every point whose
    // first coordinate falls in the Fenwick range covered by `i`.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut bits = vec![CompressedBinaryIndexedTree::default(); n + 1];
        // `tails[i]` starts with the tails of the points whose first coordinate is the
        // `i`-th compressed key (1-based) ...
        let mut tails = vec![vec![]; n + 1];
        for (head, tail) in $points {
            let i = compress.position_bisect(|c| head <= c);
            tails[i + 1].push(tail);
        }
        for i in 1..=n {
            bits[i] = CompressedBinaryIndexedTree::<_, _, impl_compressed_binary_indexed_tree!(@cst $M $($Rest)*)>::from_iter(tails[i].iter().cloned());
            // ... and is then handed to the Fenwick parent `i + lowbit(i)`, whose range
            // contains node `i`'s. Parents have larger indices, so their lists are complete
            // by the time the loop builds them.
            let parent = i + (i & i.wrapping_neg());
            if parent <= n {
                let moved = std::mem::take(&mut tails[i]);
                tails[parent].extend(moved);
            }
        }
        Self {
            compress,
            bits,
            _marker: PhantomData,
        }
    }};
    // One summand of `accumulate`: the leaf value, or the inner tree queried with the
    // remaining range components.
    (@acc $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    (@acc $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.accumulate(&$rng.1)
    };
    // One node of `update`: fold `x` into the leaf, or forward the key's tail to the inner tree.
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    // Inherent API for the tree with coordinate types `$T...`; `$Q...` name the
    // per-dimension range type parameters of `accumulate`.
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_binary_indexed_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds a tree over `points` (duplicates allowed) with every value set to
            /// `M::unit()`. `update` expects one of these points as its key.
            pub fn new(points: &[impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_binary_indexed_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_binary_indexed_tree!(@from_iter M points $($T)*)
            }
            /// Folds, with `M::operate`, the values of all points inside `range`.
            ///
            /// Every component of `range` must be a prefix range: its start bound is
            /// `Bound::Unbounded`; its end bound may be included, excluded or unbounded.
            ///
            /// # Panics
            ///
            /// Panics with "expected `Bound::Unbounded`" if the first component has a start
            /// bound; inner components are checked the same way when their trees are queried.
            pub fn accumulate<$($Q,)*>(&self, range: &impl_compressed_binary_indexed_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                assert!(
                    matches!(range.0.start_bound(), Bound::Unbounded),
                    "expected `Bound::Unbounded`"
                );
                // Prefix length: the number of compressed keys inside the end bound
                // (`position_bisect` is used as "first index where the predicate holds").
                // Both sides of each comparison are `&X`.
                let mut k = match range.0.end_bound() {
                    Bound::Included(end) => self.compress.position_bisect(|c| c > end),
                    Bound::Excluded(end) => self.compress.position_bisect(|c| c >= end),
                    Bound::Unbounded => self.compress.len(),
                };
                let mut acc = M::unit();
                while k > 0 {
                    acc = M::operate(&acc, &impl_compressed_binary_indexed_tree!(@acc self.bits[k], range $($T)*));
                    // Clear the lowest set bit: jump to the node covering the preceding keys.
                    k -= k & k.wrapping_neg();
                }
                acc
            }
            /// Combines `x` into the value stored at `key` using `M::operate_assign`.
            ///
            /// # Panics
            ///
            /// Panics with "not exist key" if `key.0` is not one of this tree's compressed
            /// coordinates; inner coordinates are checked the same way by the inner trees.
            pub fn update(&mut self, key: &impl_compressed_binary_indexed_tree!(@tuple () () $($T)*), x: &M::T) {
                let mut k = self.compress.binary_search(&key.0).expect("not exist key") + 1;
                while k < self.bits.len() {
                    impl_compressed_binary_indexed_tree!(@update self.bits[k], M key x $($T)*);
                    // Add the lowest set bit: jump to the next node whose range contains `key`.
                    k += k & k.wrapping_neg();
                }
            }
        }
        /// Nested compressed binary indexed tree; coordinate types are listed outermost first.
        pub type $C<M, $($T),*> = impl_compressed_binary_indexed_tree!(@cst M $($T)*);
    };
    // No dimensions left to add: emit the impl for everything collected so far.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
    };
    // Emit the impl for the dimensions collected so far, then append the next dimension.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_binary_indexed_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_binary_indexed_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    // Public entry point; any tokens after `;` are dropped.
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_binary_indexed_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
    // Anything else is a usage error; report the offending input instead of pasting raw
    // tokens into `compile_error!` (which only accepts a string literal).
    ($($t:tt)*) => {
        compile_error!(concat!(
            "impl_compressed_binary_indexed_tree!: unexpected input: ",
            stringify!($($t)*)
        ))
    };
}

198impl_compressed_binary_indexed_tree!(
199    CompressedBinaryIndexedTree1d A QA
200    CompressedBinaryIndexedTree2d B QB
201    CompressedBinaryIndexedTree3d C QC
202    CompressedBinaryIndexedTree4d D QD;
203    CompressedBinaryIndexedTree5d E QE
204    CompressedBinaryIndexedTree6d F QF
205    CompressedBinaryIndexedTree7d G QG
206    CompressedBinaryIndexedTree8d H QH
207    CompressedBinaryIndexedTree9d I QI
208);
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::AdditiveOperation, tools::Xorshift};
    use std::{collections::HashMap, ops::RangeTo};

    /// Randomised cross-check of the 4-dimensional tree against a brute-force `HashMap` sum.
    #[test]
    fn test_bit4d() {
        const N: usize = 100;
        const Q: usize = 5000;
        const A: RangeTo<u64> = ..1_000;
        let mut rng = Xorshift::default();

        // Random, sorted, duplicate-free 4d points, each starting with weight 0.
        let mut points: Vec<_> = rng.random_iter((A, (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();
        let mut weight_of: HashMap<_, u64> = points.iter().map(|point| (point, 0)).collect();
        let mut bit =
            CompressedBinaryIndexedTree4d::<AdditiveOperation<u64>, _, _, _, _>::new(&points);

        for _ in 0..Q {
            let point = &points[rng.random(0..points.len())];
            let delta = rng.random(A);
            *weight_of.get_mut(point).unwrap() += delta;
            bit.update(point, &delta);

            // A random prefix range: `..v`, `..=v`, or `..`.
            let mut random_prefix = || {
                let end = match rng.rand(3) {
                    0 => Bound::Excluded(rng.random(A)),
                    1 => Bound::Included(rng.random(A)),
                    _ => Bound::Unbounded,
                };
                (Bound::Unbounded, end)
            };

            let range = (
                random_prefix(),
                (random_prefix(), (random_prefix(), (random_prefix(),))),
            );
            let (r0, (r1, (r2, (r3,)))) = range;
            let expected: u64 = weight_of
                .iter()
                .filter(|((p0, (p1, (p2, (p3,)))), _)| {
                    RangeBounds::contains(&r0, p0)
                        && RangeBounds::contains(&r1, p1)
                        && RangeBounds::contains(&r2, p2)
                        && RangeBounds::contains(&r3, p3)
                })
                .map(|(_, &weight)| weight)
                .sum();
            assert_eq!(expected, bit.accumulate(&range));
        }
    }
}