competitive/data_structure/
lazy_segment_tree.rs

1use super::{LazyMapMonoid, RangeBoundsExt};
2use std::{
3    fmt::{self, Debug, Formatter},
4    mem::replace,
5    ops::RangeBounds,
6};
7
8pub struct LazySegmentTree<M>
9where
10    M: LazyMapMonoid,
11{
12    n: usize,
13    seg: Vec<(M::Agg, M::Act)>,
14}
15
16impl<M> Clone for LazySegmentTree<M>
17where
18    M: LazyMapMonoid,
19{
20    fn clone(&self) -> Self {
21        Self {
22            n: self.n,
23            seg: self.seg.clone(),
24        }
25    }
26}
27
28impl<M> Debug for LazySegmentTree<M>
29where
30    M: LazyMapMonoid<Agg: Debug, Act: Debug>,
31{
32    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
33        f.debug_struct("LazySegmentTree")
34            .field("n", &self.n)
35            .field("seg", &self.seg)
36            .finish()
37    }
38}
39
40impl<M> LazySegmentTree<M>
41where
42    M: LazyMapMonoid,
43{
44    pub fn new(n: usize) -> Self {
45        let seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
46        Self { n, seg }
47    }
48    pub fn from_vec(v: Vec<M::Agg>) -> Self {
49        let n = v.len();
50        let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
51        for (i, x) in v.into_iter().enumerate() {
52            seg[i + n].0 = x;
53        }
54        for i in (1..n).rev() {
55            seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
56        }
57        Self { n, seg }
58    }
59    pub fn from_keys(keys: impl ExactSizeIterator<Item = M::Key>) -> Self {
60        let n = keys.len();
61        let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
62        for (i, key) in keys.enumerate() {
63            seg[i + n].0 = M::single_agg(&key);
64        }
65        for i in (1..n).rev() {
66            seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
67        }
68        Self { n, seg }
69    }
70    #[inline]
71    fn update_at(&mut self, k: usize, x: &M::Act) {
72        let nx = M::act_agg(&self.seg[k].0, x);
73        if k < self.n {
74            self.seg[k].1 = M::act_operate(&self.seg[k].1, x);
75        }
76        if let Some(nx) = nx {
77            self.seg[k].0 = nx;
78        } else if k < self.n {
79            self.propagate_at(k);
80            self.recalc_at(k);
81        } else {
82            panic!("act failed on leaf");
83        }
84    }
85    #[inline]
86    fn recalc_at(&mut self, k: usize) {
87        self.seg[k].0 = M::agg_operate(&self.seg[2 * k].0, &self.seg[2 * k + 1].0);
88    }
89    #[inline]
90    fn propagate_at(&mut self, k: usize) {
91        debug_assert!(k < self.n);
92        let x = replace(&mut self.seg[k].1, M::act_unit());
93        self.update_at(2 * k, &x);
94        self.update_at(2 * k + 1, &x);
95    }
96    #[inline]
97    fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
98        let right = right as usize;
99        for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
100            if nofilt || (k >> i) << i != k {
101                self.propagate_at((k - right) >> i);
102            }
103        }
104    }
105    #[inline]
106    fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
107        let right = right as usize;
108        for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
109            if nofilt || (k >> i) << i != k {
110                self.recalc_at((k - right) >> i);
111            }
112        }
113    }
114    pub fn update<R>(&mut self, range: R, x: M::Act)
115    where
116        R: RangeBounds<usize>,
117    {
118        let range = range.to_range_bounded(0, self.n).expect("invalid range");
119        let mut a = range.start + self.n;
120        let mut b = range.end + self.n;
121        self.propagate(a, false, false);
122        self.propagate(b, true, false);
123        while a < b {
124            if a & 1 != 0 {
125                self.update_at(a, &x);
126                a += 1;
127            }
128            if b & 1 != 0 {
129                b -= 1;
130                self.update_at(b, &x);
131            }
132            a /= 2;
133            b /= 2;
134        }
135        self.recalc(range.start + self.n, false, false);
136        self.recalc(range.end + self.n, true, false);
137    }
138    pub fn fold<R>(&mut self, range: R) -> M::Agg
139    where
140        R: RangeBounds<usize>,
141    {
142        let range = range.to_range_bounded(0, self.n).expect("invalid range");
143        let mut l = range.start + self.n;
144        let mut r = range.end + self.n;
145        self.propagate(l, false, true);
146        self.propagate(r, true, true);
147        let mut vl = M::agg_unit();
148        let mut vr = M::agg_unit();
149        while l < r {
150            if l & 1 != 0 {
151                vl = M::agg_operate(&vl, &self.seg[l].0);
152                l += 1;
153            }
154            if r & 1 != 0 {
155                r -= 1;
156                vr = M::agg_operate(&self.seg[r].0, &vr);
157            }
158            l /= 2;
159            r /= 2;
160        }
161        M::agg_operate(&vl, &vr)
162    }
163    pub fn set(&mut self, k: usize, x: M::Agg) {
164        let k = k + self.n;
165        self.propagate(k, false, true);
166        self.seg[k] = (x, M::act_unit());
167        self.recalc(k, false, true);
168    }
169    pub fn get(&mut self, k: usize) -> M::Agg {
170        self.fold(k..k + 1)
171    }
172    pub fn fold_all(&mut self) -> M::Agg {
173        self.fold(0..self.n)
174    }
175    fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
176    where
177        P: Fn(&M::Agg) -> bool,
178    {
179        while pos < self.n {
180            self.propagate_at(pos);
181            pos <<= 1;
182            let nacc = M::agg_operate(&acc, &self.seg[pos].0);
183            if !p(&nacc) {
184                acc = nacc;
185                pos += 1;
186            }
187        }
188        (pos - self.n, acc)
189    }
190    fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
191    where
192        P: Fn(&M::Agg) -> bool,
193    {
194        while pos < self.n {
195            self.propagate_at(pos);
196            pos = pos * 2 + 1;
197            let nacc = M::agg_operate(&self.seg[pos].0, &acc);
198            if !p(&nacc) {
199                acc = nacc;
200                pos -= 1;
201            }
202        }
203        (pos - self.n, acc)
204    }
205    /// Returns the first index that satisfies a accumlative predicate.
206    pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
207    where
208        R: RangeBounds<usize>,
209        P: Fn(&M::Agg) -> bool,
210    {
211        let range = range.to_range_bounded(0, self.n).expect("invalid range");
212        let mut l = range.start + self.n;
213        let r = range.end + self.n;
214        self.propagate(l, false, true);
215        self.propagate(r, true, true);
216        let mut k = 0usize;
217        let mut acc = M::agg_unit();
218        while l < r >> k {
219            if l & 1 != 0 {
220                let nacc = M::agg_operate(&acc, &self.seg[l].0);
221                if p(&nacc) {
222                    return Some(self.bisect_perfect(l, acc, p).0);
223                }
224                acc = nacc;
225                l += 1;
226            }
227            l >>= 1;
228            k += 1;
229        }
230        for k in (0..k).rev() {
231            let r = r >> k;
232            if r & 1 != 0 {
233                let nacc = M::agg_operate(&acc, &self.seg[r - 1].0);
234                if p(&nacc) {
235                    return Some(self.bisect_perfect(r - 1, acc, p).0);
236                }
237                acc = nacc;
238            }
239        }
240        None
241    }
242    /// Returns the last index that satisfies a accumlative predicate.
243    pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
244    where
245        R: RangeBounds<usize>,
246        P: Fn(&M::Agg) -> bool,
247    {
248        let range = range.to_range_bounded(0, self.n).expect("invalid range");
249        let mut l = range.start + self.n;
250        let mut r = range.end + self.n;
251        self.propagate(l, false, true);
252        self.propagate(r, true, true);
253        let mut c = 0usize;
254        let mut k = 0usize;
255        let mut acc = M::agg_unit();
256        while l >> k < r {
257            c <<= 1;
258            if l & (1 << k) != 0 {
259                l += 1 << k;
260                c += 1;
261            }
262            if r & 1 != 0 {
263                r -= 1;
264                let nacc = M::agg_operate(&self.seg[r].0, &acc);
265                if p(&nacc) {
266                    return Some(self.rbisect_perfect(r, acc, p).0);
267                }
268                acc = nacc;
269            }
270            r >>= 1;
271            k += 1;
272        }
273        for k in (0..k).rev() {
274            if c & 1 != 0 {
275                l -= 1 << k;
276                let l = l >> k;
277                let nacc = M::agg_operate(&self.seg[l].0, &acc);
278                if p(&nacc) {
279                    return Some(self.rbisect_perfect(l, acc, p).0);
280                }
281                acc = nacc;
282            }
283            c >>= 1;
284        }
285        None
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    // Array length, number of random queries per scenario, and the bound used
    // for random values (drawn from `-A..A`).
    const N: usize = 1_000;
    const Q: usize = 20_000;
    const A: i64 = 1_000_000_000;

    /// Randomized differential test: every tree operation is mirrored on a
    /// plain array, and each query result is compared with a naive computation
    /// over that array.
    #[test]
    fn test_lazy_segment_tree() {
        let mut rng = Xorshift::default();
        // Range Sum Query & Range Add Query
        rand!(rng, mut arr: [-A..A; N]);
        // Leaves are `(value, 1)` pairs; the second component presumably is the
        // element count a range add is scaled by. TODO confirm against `RangeSumRangeAdd`.
        let mut seg =
            LazySegmentTree::<RangeSumRangeAdd<_>>::from_vec(arr.iter().map(|&a| (a, 1)).collect());
        for _ in 0..Q {
            // Presumably draws a non-empty `l..r` within `0..N`; see `NotEmptySegment`.
            rand!(rng, (l, r): NotEmptySegment(N));
            if rng.rand(2) == 0 {
                // Range Add Query
                rand!(rng, x: -A..A);
                seg.update(l..r, x);
                for a in arr[l..r].iter_mut() {
                    *a += x;
                }
            } else {
                // Range Sum Query
                let res = arr[l..r].iter().sum();
                assert_eq!(seg.fold(l..r).0, res);
            }
        }

        // Range Max Query & Range Update Query & Binary Search Query
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg = LazySegmentTree::<RangeMaxRangeUpdate<_>>::from_vec(arr.clone());
        for _ in 0..Q {
            rand!(rng, ty: 0..4, (l, r): NotEmptySegment(N));
            match ty {
                0 => {
                    // Range Update Query: `Some(x)` is the assignment action.
                    rand!(rng, x: -A..A);
                    seg.update(l..r, Some(x));
                    arr[l..r].iter_mut().for_each(|a| *a = x);
                }
                1 => {
                    // Range Max Query
                    let res = arr[l..r].iter().max().cloned().unwrap_or_default();
                    assert_eq!(seg.fold(l..r), res);
                }
                2 => {
                    // Binary Search Query: the naive side scans the running
                    // prefix maximum and takes the first position that reaches `x`.
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.position_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| i + l),
                    );
                }
                _ => {
                    // Binary Search Query: same as above, but scanning the
                    // running suffix maximum from the right end of the range.
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.rposition_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .rev()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| r - i - 1),
                    );
                }
            }
        }
    }
}