// competitive/data_structure/lazy_segment_tree_map.rs

1use super::{LazyMapMonoid, RangeBoundsExt};
2use std::{
3    collections::HashMap,
4    fmt::{self, Debug, Formatter},
5    mem::replace,
6    ops::RangeBounds,
7};
8
9pub struct LazySegmentTreeMap<M>
10where
11    M: LazyMapMonoid<Act: PartialEq>,
12{
13    n: usize,
14    seg: HashMap<usize, (M::Agg, M::Act)>,
15}
16
17impl<M> Clone for LazySegmentTreeMap<M>
18where
19    M: LazyMapMonoid<Act: PartialEq>,
20{
21    fn clone(&self) -> Self {
22        Self {
23            n: self.n,
24            seg: self.seg.clone(),
25        }
26    }
27}
28
29impl<M> Debug for LazySegmentTreeMap<M>
30where
31    M: LazyMapMonoid<Agg: Debug, Act: PartialEq + Debug>,
32{
33    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34        f.debug_struct("LazySegmentTreeMap")
35            .field("n", &self.n)
36            .field("seg", &self.seg)
37            .finish()
38    }
39}
40
41impl<M> LazySegmentTreeMap<M>
42where
43    M: LazyMapMonoid<Act: PartialEq>,
44{
45    pub fn new(n: usize) -> Self {
46        Self {
47            n,
48            seg: Default::default(),
49        }
50    }
51    #[inline]
52    fn get_mut(&mut self, k: usize) -> &mut (M::Agg, M::Act) {
53        self.seg.entry(k).or_insert((M::agg_unit(), M::act_unit()))
54    }
55    #[inline]
56    fn update_at(&mut self, k: usize, x: &M::Act) {
57        let n = self.n;
58        let a = self.get_mut(k);
59        let nx = M::act_agg(&a.0, x);
60        if k < n {
61            a.1 = M::act_operate(&a.1, x);
62        }
63        if let Some(nx) = nx {
64            a.0 = nx;
65        } else if k < n {
66            self.propagate_at(k);
67            self.recalc_at(k);
68        } else {
69            panic!("act failed on leaf");
70        }
71    }
72    #[inline]
73    fn recalc_at(&mut self, k: usize) {
74        let x = match (self.seg.get(&(2 * k)), self.seg.get(&(2 * k + 1))) {
75            (None, None) => M::agg_unit(),
76            (None, Some((y, _))) => y.clone(),
77            (Some((x, _)), None) => x.clone(),
78            (Some((x, _)), Some((y, _))) => M::agg_operate(x, y),
79        };
80        self.get_mut(k).0 = x;
81    }
82    #[inline]
83    fn propagate_at(&mut self, k: usize) {
84        debug_assert!(k < self.n);
85        let x = match self.seg.get_mut(&k) {
86            Some((_, x)) => replace(x, M::act_unit()),
87            None => M::act_unit(),
88        };
89        self.update_at(2 * k, &x);
90        self.update_at(2 * k + 1, &x);
91    }
92    #[inline]
93    fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
94        let right = right as usize;
95        for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
96            if nofilt || (k >> i) << i != k {
97                self.propagate_at((k - right) >> i);
98            }
99        }
100    }
101    #[inline]
102    fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
103        let right = right as usize;
104        for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
105            if nofilt || (k >> i) << i != k {
106                self.recalc_at((k - right) >> i);
107            }
108        }
109    }
110    pub fn update<R>(&mut self, range: R, x: M::Act)
111    where
112        R: RangeBounds<usize>,
113    {
114        let range = range.to_range_bounded(0, self.n).expect("invalid range");
115        let mut a = range.start + self.n;
116        let mut b = range.end + self.n;
117        self.propagate(a, false, false);
118        self.propagate(b, true, false);
119        while a < b {
120            if a & 1 != 0 {
121                self.update_at(a, &x);
122                a += 1;
123            }
124            if b & 1 != 0 {
125                b -= 1;
126                self.update_at(b, &x);
127            }
128            a /= 2;
129            b /= 2;
130        }
131        self.recalc(range.start + self.n, false, false);
132        self.recalc(range.end + self.n, true, false);
133    }
134    pub fn fold<R>(&mut self, range: R) -> M::Agg
135    where
136        R: RangeBounds<usize>,
137    {
138        let range = range.to_range_bounded(0, self.n).expect("invalid range");
139        let mut l = range.start + self.n;
140        let mut r = range.end + self.n;
141        self.propagate(l, false, true);
142        self.propagate(r, true, true);
143        let mut vl = M::agg_unit();
144        let mut vr = M::agg_unit();
145        while l < r {
146            if l & 1 != 0 {
147                if let Some((x, _)) = self.seg.get(&l) {
148                    vl = M::agg_operate(&vl, x);
149                }
150                l += 1;
151            }
152            if r & 1 != 0 {
153                r -= 1;
154                if let Some((x, _)) = self.seg.get(&r) {
155                    vr = M::agg_operate(x, &vr);
156                }
157            }
158            l /= 2;
159            r /= 2;
160        }
161        M::agg_operate(&vl, &vr)
162    }
163    pub fn set(&mut self, k: usize, x: M::Agg) {
164        let k = k + self.n;
165        self.propagate(k, false, true);
166        *self.get_mut(k) = (x, M::act_unit());
167        self.recalc(k, false, true);
168    }
169    pub fn get(&mut self, k: usize) -> M::Agg {
170        self.fold(k..k + 1)
171    }
172    pub fn fold_all(&mut self) -> M::Agg {
173        self.fold(0..self.n)
174    }
175    fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
176    where
177        P: Fn(&M::Agg) -> bool,
178    {
179        while pos < self.n {
180            self.propagate_at(pos);
181            pos <<= 1;
182            let nacc = match self.seg.get(&pos) {
183                Some((x, _)) => M::agg_operate(&acc, x),
184                None => acc.clone(),
185            };
186            if !p(&nacc) {
187                acc = nacc;
188                pos += 1;
189            }
190        }
191        (pos - self.n, acc)
192    }
193    fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
194    where
195        P: Fn(&M::Agg) -> bool,
196    {
197        while pos < self.n {
198            self.propagate_at(pos);
199            pos = pos * 2 + 1;
200            let nacc = match self.seg.get(&pos) {
201                Some((x, _)) => M::agg_operate(x, &acc),
202                None => acc.clone(),
203            };
204            if !p(&nacc) {
205                acc = nacc;
206                pos -= 1;
207            }
208        }
209        (pos - self.n, acc)
210    }
211    /// Returns the first index that satisfies a accumlative predicate.
212    pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
213    where
214        R: RangeBounds<usize>,
215        P: Fn(&M::Agg) -> bool,
216    {
217        let range = range.to_range_bounded(0, self.n).expect("invalid range");
218        let mut l = range.start + self.n;
219        let r = range.end + self.n;
220        self.propagate(l, false, true);
221        self.propagate(r, true, true);
222        let mut k = 0usize;
223        let mut acc = M::agg_unit();
224        while l < r >> k {
225            if l & 1 != 0 {
226                let nacc = match self.seg.get(&l) {
227                    Some((x, _)) => M::agg_operate(&acc, x),
228                    None => acc.clone(),
229                };
230                if p(&nacc) {
231                    return Some(self.bisect_perfect(l, acc, p).0);
232                }
233                acc = nacc;
234                l += 1;
235            }
236            l >>= 1;
237            k += 1;
238        }
239        for k in (0..k).rev() {
240            let r = r >> k;
241            if r & 1 != 0 {
242                let nacc = match self.seg.get(&(r - 1)) {
243                    Some((x, _)) => M::agg_operate(&acc, x),
244                    None => acc.clone(),
245                };
246                if p(&nacc) {
247                    return Some(self.bisect_perfect(r - 1, acc, p).0);
248                }
249                acc = nacc;
250            }
251        }
252        None
253    }
254    /// Returns the last index that satisfies a accumlative predicate.
255    pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
256    where
257        R: RangeBounds<usize>,
258        P: Fn(&M::Agg) -> bool,
259    {
260        let range = range.to_range_bounded(0, self.n).expect("invalid range");
261        let mut l = range.start + self.n;
262        let mut r = range.end + self.n;
263        self.propagate(l, false, true);
264        self.propagate(r, true, true);
265        let mut c = 0usize;
266        let mut k = 0usize;
267        let mut acc = M::agg_unit();
268        while l >> k < r {
269            c <<= 1;
270            if l & (1 << k) != 0 {
271                l += 1 << k;
272                c += 1;
273            }
274            if r & 1 != 0 {
275                r -= 1;
276                let nacc = match self.seg.get(&r) {
277                    Some((x, _)) => M::agg_operate(x, &acc),
278                    None => acc.clone(),
279                };
280                if p(&nacc) {
281                    return Some(self.rbisect_perfect(r, acc, p).0);
282                }
283                acc = nacc;
284            }
285            r >>= 1;
286            k += 1;
287        }
288        for k in (0..k).rev() {
289            if c & 1 != 0 {
290                l -= 1 << k;
291                let l = l >> k;
292                let nacc = match self.seg.get(&l) {
293                    Some((x, _)) => M::agg_operate(x, &acc),
294                    None => acc.clone(),
295                };
296                if p(&nacc) {
297                    return Some(self.rbisect_perfect(l, acc, p).0);
298                }
299                acc = nacc;
300            }
301            c >>= 1;
302        }
303        None
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 20_000;
    const A: i64 = 1_000_000_000;

    /// Yields the running maximum of `values`, starting from `i64::MIN`; this is the naive
    /// reference for the accumulative binary searches.
    fn running_max(values: impl Iterator<Item = i64>) -> impl Iterator<Item = i64> {
        values.scan(i64::MIN, |best, v| {
            *best = v.max(*best);
            Some(*best)
        })
    }

    #[test]
    fn test_lazy_segment_tree_map() {
        let mut rng = Xorshift::default();

        // Range Sum Query & Range Add Query
        let mut naive = vec![0i64; N];
        let mut seg = LazySegmentTreeMap::<RangeSumRangeAdd<_>>::new(N);
        // Each leaf gets size 1 so that additions scale with the segment length.
        for i in 0..N {
            seg.set(i, (0i64, 1i64));
        }
        for _ in 0..Q {
            rand!(rng, (l, r): NotEmptySegment(N));
            let is_update = rng.rand(2) == 0;
            if is_update {
                // Range Add Query
                rand!(rng, x: -A..A);
                seg.update(l..r, x);
                naive[l..r].iter_mut().for_each(|a| *a += x);
            } else {
                // Range Sum Query
                let expected: i64 = naive[l..r].iter().sum();
                assert_eq!(seg.fold(l..r).0, expected);
            }
        }

        // Range Max Query & Range Update Query & Binary Search Query
        let mut naive = vec![i64::MIN; N];
        let mut seg = LazySegmentTreeMap::<RangeMaxRangeUpdate<_>>::new(N);
        for _ in 0..Q {
            rand!(rng, ty: 0..4, (l, r): NotEmptySegment(N));
            match ty {
                0 => {
                    // Range Update Query
                    rand!(rng, x: -A..A);
                    seg.update(l..r, Some(x));
                    naive[l..r].fill(x);
                }
                1 => {
                    // Range Max Query
                    let expected = naive[l..r].iter().copied().max().unwrap_or_default();
                    assert_eq!(seg.fold(l..r), expected);
                }
                2 => {
                    // Binary Search Query
                    rand!(rng, x: -A..A);
                    let expected = running_max(naive[l..r].iter().copied())
                        .position(|acc| acc >= x)
                        .map(|i| l + i);
                    assert_eq!(seg.position_acc(l..r, |&d| d >= x), expected);
                }
                _ => {
                    // Binary Search Query
                    rand!(rng, x: -A..A);
                    let expected = running_max(naive[l..r].iter().rev().copied())
                        .position(|acc| acc >= x)
                        .map(|i| r - i - 1);
                    assert_eq!(seg.rposition_acc(l..r, |&d| d >= x), expected);
                }
            }
        }
    }
}