competitive/data_structure/
lazy_segment_tree.rs

1use super::{MonoidAction, RangeBoundsExt};
2use std::{
3    fmt::{self, Debug, Formatter},
4    mem::replace,
5    ops::RangeBounds,
6};
7
8pub struct LazySegmentTree<M>
9where
10    M: MonoidAction,
11{
12    n: usize,
13    seg: Vec<(M::Agg, M::Act)>,
14}
15
16impl<M> Clone for LazySegmentTree<M>
17where
18    M: MonoidAction,
19{
20    fn clone(&self) -> Self {
21        Self {
22            n: self.n,
23            seg: self.seg.clone(),
24        }
25    }
26}
27
28impl<M> Debug for LazySegmentTree<M>
29where
30    M: MonoidAction,
31    M::Agg: Debug,
32    M::Act: Debug,
33{
34    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35        f.debug_struct("LazySegmentTree")
36            .field("n", &self.n)
37            .field("seg", &self.seg)
38            .finish()
39    }
40}
41
42impl<M> LazySegmentTree<M>
43where
44    M: MonoidAction,
45{
46    pub fn new(n: usize) -> Self {
47        let seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
48        Self { n, seg }
49    }
50    pub fn from_vec(v: Vec<M::Agg>) -> Self {
51        let n = v.len();
52        let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
53        for (i, x) in v.into_iter().enumerate() {
54            seg[i + n].0 = x;
55        }
56        for i in (1..n).rev() {
57            seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
58        }
59        Self { n, seg }
60    }
61    pub fn from_keys(keys: impl ExactSizeIterator<Item = M::Key>) -> Self {
62        let n = keys.len();
63        let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
64        for (i, key) in keys.enumerate() {
65            seg[i + n].0 = M::single_agg(&key);
66        }
67        for i in (1..n).rev() {
68            seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
69        }
70        Self { n, seg }
71    }
72    #[inline]
73    fn update_at(&mut self, k: usize, x: &M::Act) {
74        let nx = M::act_agg(&self.seg[k].0, x);
75        if k < self.n {
76            self.seg[k].1 = M::act_operate(&self.seg[k].1, x);
77        }
78        if let Some(nx) = nx {
79            self.seg[k].0 = nx;
80        } else if k < self.n {
81            self.propagate_at(k);
82            self.recalc_at(k);
83        } else {
84            panic!("act failed on leaf");
85        }
86    }
87    #[inline]
88    fn recalc_at(&mut self, k: usize) {
89        self.seg[k].0 = M::agg_operate(&self.seg[2 * k].0, &self.seg[2 * k + 1].0);
90    }
91    #[inline]
92    fn propagate_at(&mut self, k: usize) {
93        debug_assert!(k < self.n);
94        let x = replace(&mut self.seg[k].1, M::act_unit());
95        self.update_at(2 * k, &x);
96        self.update_at(2 * k + 1, &x);
97    }
98    #[inline]
99    fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
100        let right = right as usize;
101        for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
102            if nofilt || (k >> i) << i != k {
103                self.propagate_at((k - right) >> i);
104            }
105        }
106    }
107    #[inline]
108    fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
109        let right = right as usize;
110        for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
111            if nofilt || (k >> i) << i != k {
112                self.recalc_at((k - right) >> i);
113            }
114        }
115    }
116    pub fn update<R>(&mut self, range: R, x: M::Act)
117    where
118        R: RangeBounds<usize>,
119    {
120        let range = range.to_range_bounded(0, self.n).expect("invalid range");
121        let mut a = range.start + self.n;
122        let mut b = range.end + self.n;
123        self.propagate(a, false, false);
124        self.propagate(b, true, false);
125        while a < b {
126            if a & 1 != 0 {
127                self.update_at(a, &x);
128                a += 1;
129            }
130            if b & 1 != 0 {
131                b -= 1;
132                self.update_at(b, &x);
133            }
134            a /= 2;
135            b /= 2;
136        }
137        self.recalc(range.start + self.n, false, false);
138        self.recalc(range.end + self.n, true, false);
139    }
140    pub fn fold<R>(&mut self, range: R) -> M::Agg
141    where
142        R: RangeBounds<usize>,
143    {
144        let range = range.to_range_bounded(0, self.n).expect("invalid range");
145        let mut l = range.start + self.n;
146        let mut r = range.end + self.n;
147        self.propagate(l, false, true);
148        self.propagate(r, true, true);
149        let mut vl = M::agg_unit();
150        let mut vr = M::agg_unit();
151        while l < r {
152            if l & 1 != 0 {
153                vl = M::agg_operate(&vl, &self.seg[l].0);
154                l += 1;
155            }
156            if r & 1 != 0 {
157                r -= 1;
158                vr = M::agg_operate(&self.seg[r].0, &vr);
159            }
160            l /= 2;
161            r /= 2;
162        }
163        M::agg_operate(&vl, &vr)
164    }
165    pub fn set(&mut self, k: usize, x: M::Agg) {
166        let k = k + self.n;
167        self.propagate(k, false, true);
168        self.seg[k] = (x, M::act_unit());
169        self.recalc(k, false, true);
170    }
171    pub fn get(&mut self, k: usize) -> M::Agg {
172        self.fold(k..k + 1)
173    }
174    pub fn fold_all(&mut self) -> M::Agg {
175        self.fold(0..self.n)
176    }
177    fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
178    where
179        P: Fn(&M::Agg) -> bool,
180    {
181        while pos < self.n {
182            self.propagate_at(pos);
183            pos <<= 1;
184            let nacc = M::agg_operate(&acc, &self.seg[pos].0);
185            if !p(&nacc) {
186                acc = nacc;
187                pos += 1;
188            }
189        }
190        (pos - self.n, acc)
191    }
192    fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
193    where
194        P: Fn(&M::Agg) -> bool,
195    {
196        while pos < self.n {
197            self.propagate_at(pos);
198            pos = pos * 2 + 1;
199            let nacc = M::agg_operate(&self.seg[pos].0, &acc);
200            if !p(&nacc) {
201                acc = nacc;
202                pos -= 1;
203            }
204        }
205        (pos - self.n, acc)
206    }
207    /// Returns the first index that satisfies a accumlative predicate.
208    pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
209    where
210        R: RangeBounds<usize>,
211        P: Fn(&M::Agg) -> bool,
212    {
213        let range = range.to_range_bounded(0, self.n).expect("invalid range");
214        let mut l = range.start + self.n;
215        let r = range.end + self.n;
216        self.propagate(l, false, true);
217        self.propagate(r, true, true);
218        let mut k = 0usize;
219        let mut acc = M::agg_unit();
220        while l < r >> k {
221            if l & 1 != 0 {
222                let nacc = M::agg_operate(&acc, &self.seg[l].0);
223                if p(&nacc) {
224                    return Some(self.bisect_perfect(l, acc, p).0);
225                }
226                acc = nacc;
227                l += 1;
228            }
229            l >>= 1;
230            k += 1;
231        }
232        for k in (0..k).rev() {
233            let r = r >> k;
234            if r & 1 != 0 {
235                let nacc = M::agg_operate(&acc, &self.seg[r - 1].0);
236                if p(&nacc) {
237                    return Some(self.bisect_perfect(r - 1, acc, p).0);
238                }
239                acc = nacc;
240            }
241        }
242        None
243    }
244    /// Returns the last index that satisfies a accumlative predicate.
245    pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
246    where
247        R: RangeBounds<usize>,
248        P: Fn(&M::Agg) -> bool,
249    {
250        let range = range.to_range_bounded(0, self.n).expect("invalid range");
251        let mut l = range.start + self.n;
252        let mut r = range.end + self.n;
253        self.propagate(l, false, true);
254        self.propagate(r, true, true);
255        let mut c = 0usize;
256        let mut k = 0usize;
257        let mut acc = M::agg_unit();
258        while l >> k < r {
259            c <<= 1;
260            if l & (1 << k) != 0 {
261                l += 1 << k;
262                c += 1;
263            }
264            if r & 1 != 0 {
265                r -= 1;
266                let nacc = M::agg_operate(&self.seg[r].0, &acc);
267                if p(&nacc) {
268                    return Some(self.rbisect_perfect(r, acc, p).0);
269                }
270                acc = nacc;
271            }
272            r >>= 1;
273            k += 1;
274        }
275        for k in (0..k).rev() {
276            if c & 1 != 0 {
277                l -= 1 << k;
278                let l = l >> k;
279                let nacc = M::agg_operate(&self.seg[l].0, &acc);
280                if p(&nacc) {
281                    return Some(self.rbisect_perfect(l, acc, p).0);
282                }
283                acc = nacc;
284            }
285            c >>= 1;
286        }
287        None
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    // Number of elements in each randomly generated array.
    const N: usize = 1_000;
    // Number of random queries issued per scenario.
    const Q: usize = 20_000;
    // Values are drawn from `-A..A`.
    const A: i64 = 1_000_000_000;

    /// Brute-force reference for the binary-search queries: the offset of the
    /// first item at which the running maximum of `items` becomes `>= x`.
    fn first_running_max_at_least<I>(mut items: I, x: i64) -> Option<usize>
    where
        I: Iterator<Item = i64>,
    {
        let mut best = i64::MIN;
        items.position(|a| {
            best = best.max(a);
            best >= x
        })
    }

    /// Cross-checks the tree against a plain array under random queries.
    #[test]
    fn test_lazy_segment_tree() {
        let mut rng = Xorshift::default();

        // Scenario 1: range add + range sum.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg =
            LazySegmentTree::<RangeSumRangeAdd<_>>::from_vec(arr.iter().map(|&a| (a, 1)).collect());
        for _ in 0..Q {
            rand!(rng, (l, r): NotEmptySegment(N));
            match rng.rand(2) {
                0 => {
                    // Range add.
                    rand!(rng, x: -A..A);
                    seg.update(l..r, x);
                    arr[l..r].iter_mut().for_each(|a| *a += x);
                }
                _ => {
                    // Range sum.
                    let expected: i64 = arr[l..r].iter().sum();
                    assert_eq!(seg.fold(l..r).0, expected);
                }
            }
        }

        // Scenario 2: range assign + range max + both binary searches.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg = LazySegmentTree::<RangeMaxRangeUpdate<_>>::from_vec(arr.clone());
        for _ in 0..Q {
            rand!(rng, ty: 0..4, (l, r): NotEmptySegment(N));
            match ty {
                0 => {
                    // Range assign.
                    rand!(rng, x: -A..A);
                    seg.update(l..r, Some(x));
                    for a in &mut arr[l..r] {
                        *a = x;
                    }
                }
                1 => {
                    // Range max.
                    let expected = arr[l..r].iter().max().copied().unwrap_or_default();
                    assert_eq!(seg.fold(l..r), expected);
                }
                2 => {
                    // Forward binary search: first prefix whose max reaches `x`.
                    rand!(rng, x: -A..A);
                    let expected =
                        first_running_max_at_least(arr[l..r].iter().copied(), x).map(|i| i + l);
                    assert_eq!(seg.position_acc(l..r, |&d| d >= x), expected);
                }
                _ => {
                    // Backward binary search: first suffix whose max reaches `x`.
                    rand!(rng, x: -A..A);
                    let expected = first_running_max_at_least(arr[l..r].iter().rev().copied(), x)
                        .map(|i| r - i - 1);
                    assert_eq!(seg.rposition_acc(l..r, |&d| d >= x), expected);
                }
            }
        }
    }
}