competitive/data_structure/compressed_segment_tree.rs

1use super::{GetDistinctMut, Monoid, SliceBisectExt};
2use std::{
3    fmt::{self, Debug},
4    marker::PhantomData,
5    mem::swap,
6    ops::{Bound, RangeBounds},
7};
8
/// One dimension of a multi-dimensional compressed segment tree.
///
/// `compress` holds this dimension's key values, sorted and deduplicated, in leaf order.
/// `segs` holds the `2 * compress.len()` nodes of a bottom-up segment tree:
/// - leaves live at `compress.len()..`;
/// - the children of node `i` are `2 * i` and `2 * i + 1`;
/// - node `0` is never used.
///
/// As produced by the generating macro, `Inner` is either the tree of the next dimension or a
/// `Tag<M>` holding a single monoid value.
///
/// The monoid `M` is only needed by the impls that build and query the tree, so the type itself
/// places no bound on it.
pub struct CompressedSegmentTree<M, X, Inner> {
    /// Sorted, deduplicated coordinates of this dimension.
    compress: Vec<X>,
    /// Segment-tree nodes; index `0` is unused.
    segs: Vec<Inner>,
    _marker: PhantomData<fn() -> M>,
}
17
18impl<M, X, Inner> Debug for CompressedSegmentTree<M, X, Inner>
19where
20    M: Monoid,
21    X: Debug,
22    Inner: Debug,
23{
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        f.debug_struct("CompressedSegmentTree")
26            .field("compress", &self.compress)
27            .field("segs", &self.segs)
28            .finish()
29    }
30}
31
32impl<M, X, Inner> Clone for CompressedSegmentTree<M, X, Inner>
33where
34    M: Monoid,
35    X: Clone,
36    Inner: Clone,
37{
38    fn clone(&self) -> Self {
39        Self {
40            compress: self.compress.clone(),
41            segs: self.segs.clone(),
42            _marker: self._marker,
43        }
44    }
45}
46
47impl<M, X, Inner> Default for CompressedSegmentTree<M, X, Inner>
48where
49    M: Monoid,
50{
51    fn default() -> Self {
52        Self {
53            compress: Default::default(),
54            segs: Default::default(),
55            _marker: Default::default(),
56        }
57    }
58}
59
60#[repr(transparent)]
61pub struct Tag<M>(M::T)
62where
63    M: Monoid;
64
65impl<M> Debug for Tag<M>
66where
67    M: Monoid,
68    M::T: Debug,
69{
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        self.0.fmt(f)
72    }
73}
74
75impl<M> Clone for Tag<M>
76where
77    M: Monoid,
78{
79    fn clone(&self) -> Self {
80        Self(self.0.clone())
81    }
82}
83
// Generates `new` / `fold` / `update` for compressed segment trees of every dimension listed in
// the invocation, plus one `pub type` alias per dimension.
//
// Invocation: `impl_compressed_segment_tree!(Alias1 T1 Q1 Alias2 T2 Q2 ...);`
// - There is one `(alias, key type parameter, range type parameter)` triple per dimension.
// - The k-th triple yields an impl on the k-level nested type
//   `CompressedSegmentTree<M, T1, ... CompressedSegmentTree<M, Tk, Tag<M>>>`.
// - Keys and ranges of that type are nested tuples `(T1, (T2, ... (Tk,)))`.
// - Tokens after an optional `;` are accepted and discarded.
//
// Commutativity: `update` applies `operate_assign(value, x)` to every node on the leaf-to-root
// path instead of recomputing nodes from their children, and `fold` merges left and right
// boundary pieces through a single accumulator. Results are therefore only order-correct when
// `M` is commutative.
#[allow(unused_macros)]
macro_rules! impl_compressed_segment_tree {
    // `@tuple () () T1 T2 ...` expands to the nested tuple type `(T1, (T2, (...,)))`. The two
    // parenthesised token groups are spliced around every element; all call sites in this
    // macro pass them empty.
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident) => {
        ($($l)* $T $($r)*,)
    };
    (@tuple ($($l:tt)*) ($($r:tt)*) $T:ident $($Rest:ident)+) => {
        ($($l)* $T $($r)*, impl_compressed_segment_tree!(@tuple ($($l)*) ($($r)*) $($Rest)+))
    };
    // `@cst M T1 T2 ...` expands to the nested tree type; the innermost node payload is `Tag<M>`.
    (@cst $M:ident) => {
        Tag<$M>
    };
    (@cst $M:ident $T:ident $($Rest:ident)*) => {
        CompressedSegmentTree<$M, $T, impl_compressed_segment_tree!(@cst $M $($Rest)*)>
    };
    // Innermost dimension: compress the keys and start all `2 * n` nodes at the unit.
    (@from_iter $M:ident $points:ident $T:ident) => {{
        let mut compress: Vec<_> = $points.into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        Self {
            compress,
            // `$M` is the monoid parameter handed down from `@impl`, exactly as in the other arms.
            segs: vec![Tag($M::unit()); n * 2],
            _marker: PhantomData,
        }
    }};
    // Outer dimension:
    // 1. compress the first key component;
    // 2. bucket each point's remaining components (`q`) under its leaf;
    // 3. build every node's inner tree from the points below it.
    // Leaves are `n..2n`, node `i` has children `2i` and `2i + 1`, and node `0` is unused.
    (@from_iter $M:ident $points:ident $T:ident $U:ident $($Rest:ident)*) => {{
        // `points` is walked twice (here and in the bucketing loop), hence its `Clone` bound.
        let mut compress: Vec<_> = $points.clone().into_iter().map(|t| t.0.clone()).collect();
        compress.sort_unstable();
        compress.dedup();
        let n = compress.len();
        let mut segs = vec![CompressedSegmentTree::default(); n * 2];
        // When `segs[i]` is built, `ps[i]` holds the remaining key components of the points
        // below node `i`.
        let mut ps = vec![vec![]; n * 2];
        for (x, q) in $points {
            // Position of `x` in `compress`, presumably the first index where `x <= c` holds.
            let i = compress.position_bisect(|c| x <= c);
            ps[i + n].push(q);
        }
        for i in (n..n * 2).rev() {
            segs[i] = CompressedSegmentTree::<_, _, impl_compressed_segment_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
        }
        // Children have larger indices than their parent, so iterating downwards guarantees both
        // child buckets are complete. They are moved (not copied) into the parent, which leaves
        // the child buckets empty.
        for i in (1..n).rev() {
            let (p, l, r) = ps.get_distinct_mut((i, i * 2, i * 2 + 1));
            swap(p, l);
            p.append(r);
            segs[i] = CompressedSegmentTree::<_, _, impl_compressed_segment_tree!(@cst $M $($Rest)*)>::from_iter(ps[i].iter().cloned());
        }
        Self {
            compress,
            segs,
            _marker: PhantomData,
        }
    }};
    // Value of one node restricted to the remaining range components: the raw monoid value at
    // the innermost level, otherwise a fold over the node's inner tree.
    (@fold $e:expr, $rng:ident $T:ident) => {
        $e.0
    };
    (@fold $e:expr, $rng:ident $T:ident $($Rest:ident)+) => {
        $e.fold(&$rng.1)
    };
    // Point update of one node: apply `x` to the value at the innermost level, otherwise update
    // the node's inner tree with the remaining key components.
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident) => {
        $M::operate_assign(&mut $e.0, $x);
    };
    (@update $e:expr, $M:ident $key:ident $x:ident $T:ident $($Rest:ident)+) => {
        $e.update(&$key.1, $x);
    };
    // One inherent impl on the tree with key type parameters `$($T)*` and range type parameters
    // `$($Q)*`, plus the public alias `$C` naming that tree type.
    (@impl $C:ident $($T:ident)*, $($Q:ident)*) => {
        impl<M, $($T,)*> impl_compressed_segment_tree!(@cst M $($T)*)
        where
            M: Monoid,
            $($T: Clone + Ord,)*
        {
            /// Builds a tree over `points`, every value starting at `M::unit()`.
            ///
            /// Only these points can later be passed to `update`.
            pub fn new(points: &[impl_compressed_segment_tree!(@tuple () () $($T)*)]) -> Self {
                Self::from_iter(points)
            }
            /// Shared constructor behind `new`; outer dimensions iterate `points` twice, hence
            /// the `Clone` bound.
            fn from_iter<'a, Iter>(points: Iter) -> Self
            where
                $($T: 'a,)*
                Iter: IntoIterator<Item = &'a impl_compressed_segment_tree!(@tuple () () $($T)*)> + Clone,
            {
                impl_compressed_segment_tree!(@from_iter M points $($T)*)
            }
            /// Returns the `M`-product over all points whose key components each lie in the
            /// matching component of `range`.
            ///
            /// Points never passed to `update` contribute `M::unit()`. The combination order is
            /// only meaningful for commutative `M`.
            pub fn fold<$($Q,)*>(&self, range: &impl_compressed_segment_tree!(@tuple () () $($Q)*)) -> M::T
            where
                $($Q: RangeBounds<$T>,)*
            {
                // Map the first range component to the half-open leaf interval `[l, r)`.
                let mut l = match range.0.start_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x >= &index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x > &index),
                    Bound::Unbounded => 0,
                } + self.compress.len();
                let mut r = match range.0.end_bound() {
                    Bound::Included(index) => self.compress.position_bisect(|x| x > &index),
                    Bound::Excluded(index) => self.compress.position_bisect(|x| x >= &index),
                    Bound::Unbounded => self.compress.len(),
                } + self.compress.len();
                // Bottom-up traversal. Left and right boundary nodes share the single
                // accumulator `x`.
                let mut x = M::unit();
                while l < r {
                    if l & 1 != 0 {
                        x = M::operate(&x, &impl_compressed_segment_tree!(@fold self.segs[l], range $($T)*));
                        l += 1;
                    }
                    if r & 1 != 0 {
                        r -= 1;
                        x = M::operate(&impl_compressed_segment_tree!(@fold self.segs[r], range $($T)*), &x);
                    }
                    l /= 2;
                    r /= 2;
                }
                x
            }
            /// Applies `x` (via `M::operate_assign`) to the value stored for `key` at every level.
            ///
            /// # Panics
            ///
            /// Panics with "not exist key" if `key` was not among the points given to `new`.
            pub fn update(&mut self, key: &impl_compressed_segment_tree!(@tuple () () $($T)*), x: &M::T) {
                // Leaf of `key.0`, then every ancestor up to the root (node 1).
                let mut i = self.compress.binary_search(&key.0).expect("not exist key") + self.compress.len();
                while i > 0 {
                    impl_compressed_segment_tree!(@update self.segs[i], M key x $($T)*);
                    i /= 2;
                }
            }
        }
        /// Compressed segment tree over points keyed by nested tuples of the listed key types.
        pub type $C<M, $($T),*> = impl_compressed_segment_tree!(@cst M $($T)*);
    };
    // Emit the impl for the dimensions collected so far. If more idents remain (they must come
    // in complete `(alias, key, range)` triples), append the next triple and recurse.
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][]) => {
        impl_compressed_segment_tree!(@impl $C $($T)*, $($Q)*);
    };
    (@inner [$C:ident][$($T:ident)*][$($Q:ident)*][$D:ident $U:ident $R:ident $($Rest:ident)*]) => {
        impl_compressed_segment_tree!(@impl $C $($T)*, $($Q)*);
        impl_compressed_segment_tree!(@inner [$D][$($T)* $U][$($Q)* $R][$($Rest)*]);
    };
    // Entry point. Everything after an optional `;` is captured by `$(;$($t:tt)*)?` and dropped.
    ($C:ident $T:ident $Q:ident $($Rest:ident)* $(;$($t:tt)*)?) => {
        impl_compressed_segment_tree!(@inner [$C][$T][$Q][$($Rest)*]);
    };
}
213
// Generate the 1d-4d tree impls and their aliases.
// - Each line is one `(alias, key type parameter, range type parameter)` triple.
// - The entry arm of `impl_compressed_segment_tree!` matches `$(;$($t:tt)*)?` and never uses
//   it, so everything after the `;` (the 5d-9d triples) is discarded.
// - Moving the `;` further down (or removing it) would generate the higher-dimensional
//   aliases as well.
impl_compressed_segment_tree!(
    CompressedSegmentTree1d T1 Q1
    CompressedSegmentTree2d T2 Q2
    CompressedSegmentTree3d T3 Q3
    CompressedSegmentTree4d T4 Q4;
    CompressedSegmentTree5d T5 Q5
    CompressedSegmentTree6d T6 Q6
    CompressedSegmentTree7d T7 Q7
    CompressedSegmentTree8d T8 Q8
    CompressedSegmentTree9d T9 Q9
);
225
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::AdditiveOperation,
        tools::{RandRange as RR, Xorshift},
    };
    use std::{collections::HashMap, ops::Range};

    /// Randomized cross-check of `CompressedSegmentTree4d` against a brute-force scan over a
    /// `HashMap` of per-point totals.
    #[test]
    fn test_seg4d() {
        const N: usize = 100;
        const Q: usize = 5000;
        const A: Range<i64> = -1_000..1_000;
        let mut rng = Xorshift::default();

        let mut points: Vec<_> = rng.random_iter((A, (A, (A, (A,))))).take(N).collect();
        points.sort();
        points.dedup();

        // Brute-force shadow of the tree: accumulated value per distinct point.
        let mut naive: HashMap<_, i64> = points.iter().map(|p| (p, 0)).collect();
        let mut seg = CompressedSegmentTree4d::<AdditiveOperation<i64>, _, _, _, _>::new(&points);

        for _ in 0..Q {
            let key = &points[rng.random(0..points.len())];
            let delta = rng.random(A);
            *naive.get_mut(key).unwrap() += delta;
            seg.update(key, &delta);

            let range = rng.random((RR::new(A), (RR::new(A), (RR::new(A), (RR::new(A),)))));
            let (r0, (r1, (r2, (r3,)))) = range;
            let mut expected = 0i64;
            for ((p0, (p1, (p2, (p3,)))), value) in &naive {
                let inside = RangeBounds::contains(&r0, p0)
                    && RangeBounds::contains(&r1, p1)
                    && RangeBounds::contains(&r2, p2)
                    && RangeBounds::contains(&r3, p3);
                if inside {
                    expected += value;
                }
            }
            assert_eq!(expected, seg.fold(&range));
        }
    }
}