competitive/data_structure/
compressed_segment_tree.rs

1use super::{Monoid, SliceBisectExt};
2use std::{
3    fmt::{self, Debug},
4    marker::PhantomData,
5    mem::swap,
6    ops::{Bound, RangeBounds},
7};
8
/// One dimension of a static, coordinate-compressed multi-dimensional segment tree.
///
/// The point set is fixed when the tree is built; afterwards only the values stored at those
/// points change. `Inner` is either another `CompressedSegmentTree` for the remaining
/// dimensions or, in the last dimension, [`Tag`] holding an aggregated `M::T`. The nested type
/// is normally named through the `CompressedSegmentTreeNd` aliases generated by
/// `impl_compressed_segment_tree!`.
pub struct CompressedSegmentTree<M, X, Inner>
where
    M: Monoid,
{
    // Sorted, deduplicated coordinates of this dimension; leaf `j` corresponds to `compress[j]`.
    compress: Vec<X>,
    // Bottom-up segment tree of `2 * compress.len()` nodes: leaves live at `compress.len()..`,
    // the root is index 1 and index 0 is never read. Each node summarizes every point whose
    // coordinate in this dimension falls in the node's range.
    segs: Vec<Inner>,
    // Binds the monoid type without storing a value of it.
    _marker: PhantomData<fn() -> M>,
}
17
18impl<M, X, Inner> Debug for CompressedSegmentTree<M, X, Inner>
19where
20    M: Monoid,
21    X: Debug,
22    Inner: Debug,
23{
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("CompressedSegmentTree")
26            .field("compress", &self.compress)
27            .field("segs", &self.segs)
28            .finish()
29    }
30}
31
32impl<M, X, Inner> Clone for CompressedSegmentTree<M, X, Inner>
33where
34    M: Monoid,
35    X: Clone,
36    Inner: Clone,
37{
38    fn clone(&self) -> Self {
39        Self {
40            compress: self.compress.clone(),
41            segs: self.segs.clone(),
42            _marker: self._marker,
43        }
44    }
45}
46
47impl<M, X, Inner> Default for CompressedSegmentTree<M, X, Inner>
48where
49    M: Monoid,
50{
51    fn default() -> Self {
52        Self {
53            compress: Default::default(),
54            segs: Default::default(),
55            _marker: Default::default(),
56        }
57    }
58}
59
/// Innermost node of the nested tree: the aggregated monoid value of all points in the node's
/// range (initialized to `M::unit()` when the tree is built).
///
/// `repr(transparent)` gives it the same layout as the wrapped `M::T`.
#[repr(transparent)]
pub struct Tag<M>(M::T)
where
    M: Monoid;
64
65impl<M> Debug for Tag<M>
66where
67    M: Monoid<T: Debug>,
68{
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        self.0.fmt(f)
71    }
72}
73
74impl<M> Clone for Tag<M>
75where
76    M: Monoid,
77{
78    fn clone(&self) -> Self {
79        Self(self.0.clone())
80    }
81}
82
// Generates, for each dimension count, an inherent impl (`new`, `fold`, `update`) on the nested
// `CompressedSegmentTree<M, T1, CompressedSegmentTree<M, T2, ... Tag<M>>>` type plus a
// `pub type` alias naming that type.
//
// Internal rules (all start with `@`):
// - `@tuple`     builds the nested key / range tuple type `(T1, (T2, (T3,)))`.
// - `@cst`       builds the nested tree type, terminated by `Tag<M>`.
// - `@from_iter` builds a tree from an iterator of references to nested-tuple points.
// - `@fold`      reads one node during a fold: the `Tag` value in the last dimension, otherwise
//                a recursive `fold` of the inner tree over the remaining ranges.
// - `@update`    updates one node: `operate_assign` in the last dimension, otherwise a recursive
//                `update` of the inner tree with the remaining key.
// - `@impl`      emits the inherent impl and the alias for one dimension count.
// - `@inner`     walks the `Alias Coord Range` triples, emitting `@impl` for every prefix.
//
// `update` combines `x` into the leaf and into every ancestor with `operate_assign`, and `fold`
// uses a single accumulator that puts right-hand pieces in front and left-hand pieces behind,
// so `M` must be commutative for the results to be correct.
macro_rules! impl_compressed_segment_tree {
    // Last coordinate: `(T,)`.
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    // `(T, <tuple of the rest>)`.
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_segment_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    // No dimensions left: the node type is the aggregated value itself.
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedSegmentTree<$M, $T, impl_compressed_segment_tree!(@cst $M $($Rest)*)>
    };
    // Last dimension: compress the coordinates; every node starts at the monoid unit.
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            segs: vec![Tag(M::unit()); n * 2],
            _marker: PhantomData,
        }
    }};
    // Outer dimension: every node gets an inner tree built from the remaining coordinates of
    // all points whose first coordinate lies in the node's range.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        // `points` is iterated twice, hence the `Clone` bound on the iterator.
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut segs = vec![CompressedSegmentTree::default(); n * 2];
        // `ps[i]`: references to the remaining coordinates of the points under node `i`.
        let mut ps = vec![vec![]; n * 2];
        for (x, q) in $points {
            // Index of `x` in `compress` (assumes `position_bisect` returns the first index
            // whose predicate holds; `x` is always present).
            let i = compress.position_bisect(|c| x <= c);
            ps[i + n].push(q);
        }
        for i in (n..n * 2).rev() {
            segs[i] = CompressedSegmentTree::<_, _, impl_compressed_segment_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
        }
        for i in (1..n).rev() {
            // `ps[i]` is still empty here. Move the left child's list in and drain the right
            // child's; the children's inner trees are already built, so their lists are no
            // longer needed.
            let [p, l, r] = ps.get_disjoint_mut([i, i * 2, i * 2 + 1]).unwrap();
            swap(p, l);
            p.append(r);
            segs[i] = CompressedSegmentTree::<_, _, impl_compressed_segment_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
        }
        Self {
            compress,
            segs,
            _marker: PhantomData,
        }
    }};
    (@fold $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    (@fold $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.fold(&$rng.1)
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_segment_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds the tree over `points` (duplicates are allowed); every point starts at
            /// `M::unit()`. The point set cannot be extended afterwards.
            pub fn new(points: &[impl_compressed_segment_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }
            /// Builds the tree from references to points; the iterator is cloned because the
            /// outer dimensions traverse it twice.
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_segment_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_segment_tree!(@from_iter M points $($T)*)
            }
            /// Folds the values of every point whose coordinate in each dimension lies in the
            /// matching component of `range`. Empty or inverted ranges yield `M::unit()`.
            pub fn fold<$($Q,)*>(&self, range: &impl_compressed_segment_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                // Translate this dimension's bounds into a half-open range `[l, r)` of
                // compressed indices, shifted to leaf positions.
                let mut l = match range.0.start_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x >= &index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x > &index),
                    Bound::Unbounded => 0,
                } + self.compress.len();
                let mut r = match range.0.end_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x > &index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x >= &index),
                    Bound::Unbounded => self.compress.len(),
                } + self.compress.len();
                // Bottom-up walk: take a node whenever a boundary is a right / left child, then
                // move both boundaries one level up. Each taken node folds the remaining
                // dimensions.
                let mut x = M::unit();
                while l < r {
                    if l & 1 != 0 {
                        x = M::operate(&x, &impl_compressed_segment_tree!(@fold self.segs[l], range $($T)*));
                        l += 1;
                    }
                    if r & 1 != 0 {
                        r -= 1;
                        x = M::operate(&impl_compressed_segment_tree!(@fold self.segs[r], range $($T)*), &x);
                    }
                    l /= 2;
                    r /= 2;
                }
                x
            }
            /// Combines `x` into the value stored at `key` (value = value · x) in the leaf and
            /// every ancestor.
            ///
            /// # Panics
            ///
            /// Panics with `"not exist key"` if `key` is not one of the points the tree was
            /// built from.
            pub fn update(&mut self, key: &impl_compressed_segment_tree!(@tuple () () $($T)*), x: &M::T) {
                let mut i = self.compress.binary_search(&key.0).expect("not exist key") + self.compress.len();
                while i > 0 {
                    impl_compressed_segment_tree!(@update self.segs[i], M key x $($T)*);
                    i /= 2;
                }
            }
        }
        /// Nested compressed segment tree type for this number of dimensions.
        pub type $C<M, $($T),*> = impl_compressed_segment_tree!(@cst M $($T)*);
    };
    // No more entries: emit the impl for the dimensions collected so far.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_segment_tree!(@impl $C $($T)*, $($Q)*);
    };
    // Emit the impl for the current prefix, then add one more `Alias Coord Range` triple.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_segment_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_segment_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    // Entry point. Tokens after an optional `;` are accepted but ignored.
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_segment_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
}
211
// Generates `CompressedSegmentTree1d` through `CompressedSegmentTree4d` (type aliases plus
// inherent `new` / `fold` / `update`). Each entry is `AliasName CoordType RangeType`.
// Everything after the `;` is matched by the macro but discarded, so the 5d..9d entries are
// NOT generated. Move the `;` further down to enable higher dimensions (presumably they are
// left out to keep the amount of generated code small).
impl_compressed_segment_tree!(
    CompressedSegmentTree1d T1 Q1
    CompressedSegmentTree2d T2 Q2
    CompressedSegmentTree3d T3 Q3
    CompressedSegmentTree4d T4 Q4;
    CompressedSegmentTree5d T5 Q5
    CompressedSegmentTree6d T6 Q6
    CompressedSegmentTree7d T7 Q7
    CompressedSegmentTree8d T8 Q8
    CompressedSegmentTree9d T9 Q9
);
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AdditiveOperation,
        tools::{RandRange as RR, Xorshift},
    };
    use std::{collections::HashMap, ops::Range};

    /// Randomized 4d check against a brute-force sum over a `HashMap` of point values.
    #[test]
    fn test_seg4d() {
        let mut rng = Xorshift::default();
        const N: usize = 100;
        const Q: usize = 5000;
        const A: Range<i64> = -1_000..1_000;
        let mut points: Vec<_> = rng.random_iter(((A), (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();
        let mut map: HashMap<_, _> = points.iter().map(|p| (p, 0i64)).collect();
        let mut seg = CompressedSegmentTree4d::<AdditiveOperation<i64>, _, _, _, _>::new(&points);
        for _ in 0..Q {
            let p = &points[rng.random(0..points.len())];
            let x = rng.random(A);
            *map.get_mut(p).unwrap() += x;
            seg.update(p, &x);

            let range = rng.random((RR::new(A), (RR::new(A), (RR::new(A), (RR::new(A),)))));
            let (r0, (r1, (r2, (r3,)))) = range;
            let expected: i64 = map
                .iter()
                .filter_map(|((p0, (p1, (p2, (p3,)))), x)| {
                    if RangeBounds::contains(&r0, p0)
                        && RangeBounds::contains(&r1, p1)
                        && RangeBounds::contains(&r2, p2)
                        && RangeBounds::contains(&r3, p3)
                    {
                        Some(*x)
                    } else {
                        None
                    }
                })
                .sum();
            let result = seg.fold(&range);
            assert_eq!(expected, result);
        }
    }

    /// Deterministic 1d check covering every `Bound` kind, repeated updates to the same point,
    /// duplicate input points and empty / inverted ranges.
    #[test]
    fn test_seg1d_bounds() {
        let points = [(1i64,), (3,), (5,), (3,)];
        let mut seg = CompressedSegmentTree1d::<AdditiveOperation<i64>, _>::new(&points);
        assert_eq!(seg.fold(&(..,)), 0);
        seg.update(&(1,), &1);
        seg.update(&(3,), &4);
        seg.update(&(3,), &6);
        seg.update(&(5,), &100);
        assert_eq!(seg.fold(&(..,)), 111);
        assert_eq!(seg.fold(&(2..=4,)), 10);
        assert_eq!(seg.fold(&(..3,)), 1);
        assert_eq!(seg.fold(&(..=3,)), 11);
        assert_eq!(seg.fold(&(3..,)), 110);
        assert_eq!(
            seg.fold(&((Bound::Excluded(1i64), Bound::Excluded(5i64)),)),
            10
        );
        assert_eq!(
            seg.fold(&((Bound::Excluded(3i64), Bound::Unbounded),)),
            100
        );
        assert_eq!(seg.fold(&(6..,)), 0);
        // Inverted range: start lies after end.
        assert_eq!(
            seg.fold(&((Bound::Included(4i64), Bound::Excluded(2i64)),)),
            0
        );
    }

    /// Deterministic 2d check: every node of the outer dimension folds its inner tree.
    #[test]
    fn test_seg2d_small() {
        let points = [(0i64, (0i64,)), (0, (1,)), (1, (0,))];
        let mut seg = CompressedSegmentTree2d::<AdditiveOperation<i64>, _, _>::new(&points);
        seg.update(&(0, (0,)), &1);
        seg.update(&(0, (1,)), &2);
        seg.update(&(1, (0,)), &4);
        assert_eq!(seg.fold(&(.., (..,))), 7);
        assert_eq!(seg.fold(&(..1, (..,))), 3);
        assert_eq!(seg.fold(&(.., (0..1,))), 5);
        assert_eq!(seg.fold(&(1.., (1..,))), 0);
        assert_eq!(seg.fold(&(0..=1, (1..=1,))), 2);
    }

    /// A tree built from no points folds to the monoid unit for any range.
    #[test]
    fn test_seg2d_empty() {
        let seg = CompressedSegmentTree2d::<AdditiveOperation<i64>, i64, i64>::new(&[]);
        assert_eq!(seg.fold(&(.., (..,))), 0);
        assert_eq!(seg.fold(&(-5..5, (..=0,))), 0);
    }

    /// Updating a coordinate that was not part of the construction set panics.
    #[test]
    #[should_panic(expected = "not exist key")]
    fn test_seg1d_update_unknown_key() {
        let mut seg = CompressedSegmentTree1d::<AdditiveOperation<i64>, i64>::new(&[(1,), (2,)]);
        seg.update(&(3,), &1);
    }
}