competitive/data_structure/lazy_segment_tree_map.rs

1use super::{MonoidAction, RangeBoundsExt};
2use std::{
3    collections::HashMap,
4    fmt::{self, Debug, Formatter},
5    mem::replace,
6    ops::RangeBounds,
7};
8
9pub struct LazySegmentTreeMap<M>
10where
11    M: MonoidAction,
12    M::Act: PartialEq,
13{
14    n: usize,
15    seg: HashMap<usize, (M::Agg, M::Act)>,
16}
17
18impl<M> Clone for LazySegmentTreeMap<M>
19where
20    M: MonoidAction,
21    M::Act: PartialEq,
22{
23    fn clone(&self) -> Self {
24        Self {
25            n: self.n,
26            seg: self.seg.clone(),
27        }
28    }
29}
30
31impl<M> Debug for LazySegmentTreeMap<M>
32where
33    M: MonoidAction,
34    M::Agg: Debug,
35    M::Act: PartialEq + Debug,
36{
37    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38        f.debug_struct("LazySegmentTreeMap")
39            .field("n", &self.n)
40            .field("seg", &self.seg)
41            .finish()
42    }
43}
44
45impl<M> LazySegmentTreeMap<M>
46where
47    M: MonoidAction,
48    M::Act: PartialEq,
49{
50    pub fn new(n: usize) -> Self {
51        Self {
52            n,
53            seg: Default::default(),
54        }
55    }
56    #[inline]
57    fn get_mut(&mut self, k: usize) -> &mut (M::Agg, M::Act) {
58        self.seg.entry(k).or_insert((M::agg_unit(), M::act_unit()))
59    }
60    #[inline]
61    fn update_at(&mut self, k: usize, x: &M::Act) {
62        let n = self.n;
63        let a = self.get_mut(k);
64        let nx = M::act_agg(&a.0, x);
65        if k < n {
66            a.1 = M::act_operate(&a.1, x);
67        }
68        if let Some(nx) = nx {
69            a.0 = nx;
70        } else if k < n {
71            self.propagate_at(k);
72            self.recalc_at(k);
73        } else {
74            panic!("act failed on leaf");
75        }
76    }
77    #[inline]
78    fn recalc_at(&mut self, k: usize) {
79        let x = match (self.seg.get(&(2 * k)), self.seg.get(&(2 * k + 1))) {
80            (None, None) => M::agg_unit(),
81            (None, Some((y, _))) => y.clone(),
82            (Some((x, _)), None) => x.clone(),
83            (Some((x, _)), Some((y, _))) => M::agg_operate(x, y),
84        };
85        self.get_mut(k).0 = x;
86    }
87    #[inline]
88    fn propagate_at(&mut self, k: usize) {
89        debug_assert!(k < self.n);
90        let x = match self.seg.get_mut(&k) {
91            Some((_, x)) => replace(x, M::act_unit()),
92            None => M::act_unit(),
93        };
94        self.update_at(2 * k, &x);
95        self.update_at(2 * k + 1, &x);
96    }
97    #[inline]
98    fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
99        let right = right as usize;
100        for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
101            if nofilt || (k >> i) << i != k {
102                self.propagate_at((k - right) >> i);
103            }
104        }
105    }
106    #[inline]
107    fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
108        let right = right as usize;
109        for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
110            if nofilt || (k >> i) << i != k {
111                self.recalc_at((k - right) >> i);
112            }
113        }
114    }
115    pub fn update<R>(&mut self, range: R, x: M::Act)
116    where
117        R: RangeBounds<usize>,
118    {
119        let range = range.to_range_bounded(0, self.n).expect("invalid range");
120        let mut a = range.start + self.n;
121        let mut b = range.end + self.n;
122        self.propagate(a, false, false);
123        self.propagate(b, true, false);
124        while a < b {
125            if a & 1 != 0 {
126                self.update_at(a, &x);
127                a += 1;
128            }
129            if b & 1 != 0 {
130                b -= 1;
131                self.update_at(b, &x);
132            }
133            a /= 2;
134            b /= 2;
135        }
136        self.recalc(range.start + self.n, false, false);
137        self.recalc(range.end + self.n, true, false);
138    }
139    pub fn fold<R>(&mut self, range: R) -> M::Agg
140    where
141        R: RangeBounds<usize>,
142    {
143        let range = range.to_range_bounded(0, self.n).expect("invalid range");
144        let mut l = range.start + self.n;
145        let mut r = range.end + self.n;
146        self.propagate(l, false, true);
147        self.propagate(r, true, true);
148        let mut vl = M::agg_unit();
149        let mut vr = M::agg_unit();
150        while l < r {
151            if l & 1 != 0 {
152                if let Some((x, _)) = self.seg.get(&l) {
153                    vl = M::agg_operate(&vl, x);
154                }
155                l += 1;
156            }
157            if r & 1 != 0 {
158                r -= 1;
159                if let Some((x, _)) = self.seg.get(&r) {
160                    vr = M::agg_operate(x, &vr);
161                }
162            }
163            l /= 2;
164            r /= 2;
165        }
166        M::agg_operate(&vl, &vr)
167    }
168    pub fn set(&mut self, k: usize, x: M::Agg) {
169        let k = k + self.n;
170        self.propagate(k, false, true);
171        *self.get_mut(k) = (x, M::act_unit());
172        self.recalc(k, false, true);
173    }
174    pub fn get(&mut self, k: usize) -> M::Agg {
175        self.fold(k..k + 1)
176    }
177    pub fn fold_all(&mut self) -> M::Agg {
178        self.fold(0..self.n)
179    }
180    fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
181    where
182        P: Fn(&M::Agg) -> bool,
183    {
184        while pos < self.n {
185            self.propagate_at(pos);
186            pos <<= 1;
187            let nacc = match self.seg.get(&pos) {
188                Some((x, _)) => M::agg_operate(&acc, x),
189                None => acc.clone(),
190            };
191            if !p(&nacc) {
192                acc = nacc;
193                pos += 1;
194            }
195        }
196        (pos - self.n, acc)
197    }
198    fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
199    where
200        P: Fn(&M::Agg) -> bool,
201    {
202        while pos < self.n {
203            self.propagate_at(pos);
204            pos = pos * 2 + 1;
205            let nacc = match self.seg.get(&pos) {
206                Some((x, _)) => M::agg_operate(x, &acc),
207                None => acc.clone(),
208            };
209            if !p(&nacc) {
210                acc = nacc;
211                pos -= 1;
212            }
213        }
214        (pos - self.n, acc)
215    }
216    /// Returns the first index that satisfies a accumlative predicate.
217    pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
218    where
219        R: RangeBounds<usize>,
220        P: Fn(&M::Agg) -> bool,
221    {
222        let range = range.to_range_bounded(0, self.n).expect("invalid range");
223        let mut l = range.start + self.n;
224        let r = range.end + self.n;
225        self.propagate(l, false, true);
226        self.propagate(r, true, true);
227        let mut k = 0usize;
228        let mut acc = M::agg_unit();
229        while l < r >> k {
230            if l & 1 != 0 {
231                let nacc = match self.seg.get(&l) {
232                    Some((x, _)) => M::agg_operate(&acc, x),
233                    None => acc.clone(),
234                };
235                if p(&nacc) {
236                    return Some(self.bisect_perfect(l, acc, p).0);
237                }
238                acc = nacc;
239                l += 1;
240            }
241            l >>= 1;
242            k += 1;
243        }
244        for k in (0..k).rev() {
245            let r = r >> k;
246            if r & 1 != 0 {
247                let nacc = match self.seg.get(&(r - 1)) {
248                    Some((x, _)) => M::agg_operate(&acc, x),
249                    None => acc.clone(),
250                };
251                if p(&nacc) {
252                    return Some(self.bisect_perfect(r - 1, acc, p).0);
253                }
254                acc = nacc;
255            }
256        }
257        None
258    }
259    /// Returns the last index that satisfies a accumlative predicate.
260    pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
261    where
262        R: RangeBounds<usize>,
263        P: Fn(&M::Agg) -> bool,
264    {
265        let range = range.to_range_bounded(0, self.n).expect("invalid range");
266        let mut l = range.start + self.n;
267        let mut r = range.end + self.n;
268        self.propagate(l, false, true);
269        self.propagate(r, true, true);
270        let mut c = 0usize;
271        let mut k = 0usize;
272        let mut acc = M::agg_unit();
273        while l >> k < r {
274            c <<= 1;
275            if l & (1 << k) != 0 {
276                l += 1 << k;
277                c += 1;
278            }
279            if r & 1 != 0 {
280                r -= 1;
281                let nacc = match self.seg.get(&r) {
282                    Some((x, _)) => M::agg_operate(x, &acc),
283                    None => acc.clone(),
284                };
285                if p(&nacc) {
286                    return Some(self.rbisect_perfect(r, acc, p).0);
287                }
288                acc = nacc;
289            }
290            r >>= 1;
291            k += 1;
292        }
293        for k in (0..k).rev() {
294            if c & 1 != 0 {
295                l -= 1 << k;
296                let l = l >> k;
297                let nacc = match self.seg.get(&l) {
298                    Some((x, _)) => M::agg_operate(x, &acc),
299                    None => acc.clone(),
300                };
301                if p(&nacc) {
302                    return Some(self.rbisect_perfect(l, acc, p).0);
303                }
304                acc = nacc;
305            }
306            c >>= 1;
307        }
308        None
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 20_000;
    const A: i64 = 1_000_000_000;

    /// Index (within the iteration order of `values`) of the first running maximum that
    /// reaches `x`, or `None` if the running maximum never gets there.
    fn first_running_max_at_least<'a, I>(values: I, x: i64) -> Option<usize>
    where
        I: Iterator<Item = &'a i64>,
    {
        let mut running = i64::MIN;
        values
            .map(|&v| {
                running = running.max(v);
                running
            })
            .position(|m| m >= x)
    }

    /// Randomised cross-check of `LazySegmentTreeMap` against a plain `Vec` model.
    #[test]
    fn test_lazy_segment_tree_map() {
        let mut rng = Xorshift::new();

        // Phase 1: range add / range sum.
        let mut naive = vec![0i64; N];
        let mut tree = LazySegmentTreeMap::<RangeSumRangeAdd<_>>::new(N);
        // Every leaf starts as `(0, 1)`; the second component is presumably the segment
        // length that additions get multiplied by.
        for i in 0..N {
            tree.set(i, (0i64, 1i64));
        }
        for _ in 0..Q {
            rand!(rng, (l, r): NotEmptySegment(N));
            if rng.rand(2) != 0 {
                // sum query
                let expected = naive[l..r].iter().sum();
                assert_eq!(tree.fold(l..r).0, expected);
            } else {
                // add x on [l, r)
                rand!(rng, x: -A..A);
                tree.update(l..r, x);
                naive[l..r].iter_mut().for_each(|v| *v += x);
            }
        }

        // Phase 2: range assign / range max / binary searches.
        let mut naive = vec![i64::MIN; N];
        let mut tree = LazySegmentTreeMap::<RangeMaxRangeUpdate<_>>::new(N);
        for _ in 0..Q {
            rand!(rng, ty: 0..4, (l, r): NotEmptySegment(N));
            if ty == 0 {
                // assign x on [l, r)
                rand!(rng, x: -A..A);
                tree.update(l..r, Some(x));
                for v in &mut naive[l..r] {
                    *v = x;
                }
            } else if ty == 1 {
                // max query
                let expected = naive[l..r].iter().max().cloned().unwrap_or_default();
                assert_eq!(tree.fold(l..r), expected);
            } else if ty == 2 {
                // leftmost position whose prefix maximum reaches x
                rand!(rng, x: -A..A);
                let expected = first_running_max_at_least(naive[l..r].iter(), x).map(|i| l + i);
                assert_eq!(tree.position_acc(l..r, |&d| d >= x), expected);
            } else {
                // rightmost position whose suffix maximum reaches x
                rand!(rng, x: -A..A);
                let expected =
                    first_running_max_at_least(naive[l..r].iter().rev(), x).map(|i| r - 1 - i);
                assert_eq!(tree.rposition_acc(l..r, |&d| d >= x), expected);
            }
        }
    }
}