competitive/data_structure/
segment_tree.rs

1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3    fmt::{self, Debug, Formatter},
4    ops::RangeBounds,
5};
6
7pub struct SegmentTree<M>
8where
9    M: Monoid,
10{
11    n: usize,
12    seg: Vec<M::T>,
13}
14
15impl<M> Clone for SegmentTree<M>
16where
17    M: Monoid,
18{
19    fn clone(&self) -> Self {
20        Self {
21            n: self.n,
22            seg: self.seg.clone(),
23        }
24    }
25}
26
27impl<M> Debug for SegmentTree<M>
28where
29    M: Monoid,
30    M::T: Debug,
31{
32    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
33        f.debug_struct("SegmentTree")
34            .field("n", &self.n)
35            .field("seg", &self.seg)
36            .finish()
37    }
38}
39
40impl<M> SegmentTree<M>
41where
42    M: Monoid,
43{
44    pub fn new(n: usize) -> Self {
45        let seg = vec![M::unit(); 2 * n];
46        Self { n, seg }
47    }
48    pub fn from_vec(v: Vec<M::T>) -> Self {
49        let n = v.len();
50        let mut seg = vec![M::unit(); 2 * n];
51        for (i, x) in v.into_iter().enumerate() {
52            seg[n + i] = x;
53        }
54        for i in (1..n).rev() {
55            seg[i] = M::operate(&seg[2 * i], &seg[2 * i + 1]);
56        }
57        Self { n, seg }
58    }
59    pub fn set(&mut self, k: usize, x: M::T) {
60        assert!(k < self.n);
61        let mut k = k + self.n;
62        self.seg[k] = x;
63        k /= 2;
64        while k > 0 {
65            self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
66            k /= 2;
67        }
68    }
69    pub fn clear(&mut self, k: usize) {
70        self.set(k, M::unit());
71    }
72    pub fn update(&mut self, k: usize, x: M::T) {
73        assert!(k < self.n);
74        let mut k = k + self.n;
75        self.seg[k] = M::operate(&self.seg[k], &x);
76        k /= 2;
77        while k > 0 {
78            self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
79            k /= 2;
80        }
81    }
82    pub fn get(&self, k: usize) -> M::T {
83        assert!(k < self.n);
84        self.seg[k + self.n].clone()
85    }
86    pub fn fold<R>(&self, range: R) -> M::T
87    where
88        R: RangeBounds<usize>,
89    {
90        let range = range.to_range_bounded(0, self.n).expect("invalid range");
91        let mut l = range.start + self.n;
92        let mut r = range.end + self.n;
93        let mut vl = M::unit();
94        let mut vr = M::unit();
95        while l < r {
96            if l & 1 != 0 {
97                vl = M::operate(&vl, &self.seg[l]);
98                l += 1;
99            }
100            if r & 1 != 0 {
101                r -= 1;
102                vr = M::operate(&self.seg[r], &vr);
103            }
104            l /= 2;
105            r /= 2;
106        }
107        M::operate(&vl, &vr)
108    }
109    fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
110    where
111        F: Fn(&M::T) -> bool,
112    {
113        while pos < self.n {
114            pos <<= 1;
115            let nacc = M::operate(&acc, &self.seg[pos]);
116            if !f(&nacc) {
117                acc = nacc;
118                pos += 1;
119            }
120        }
121        (pos - self.n, acc)
122    }
123    fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
124    where
125        F: Fn(&M::T) -> bool,
126    {
127        while pos < self.n {
128            pos = pos * 2 + 1;
129            let nacc = M::operate(&self.seg[pos], &acc);
130            if !f(&nacc) {
131                acc = nacc;
132                pos -= 1;
133            }
134        }
135        (pos - self.n, acc)
136    }
137    /// Returns the first index that satisfies a accumlative predicate.
138    pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
139    where
140        R: RangeBounds<usize>,
141        F: Fn(&M::T) -> bool,
142    {
143        let range = range.to_range_bounded(0, self.n).expect("invalid range");
144        let mut l = range.start + self.n;
145        let r = range.end + self.n;
146        let mut k = 0usize;
147        let mut acc = M::unit();
148        while l < r >> k {
149            if l & 1 != 0 {
150                let nacc = M::operate(&acc, &self.seg[l]);
151                if f(&nacc) {
152                    return Some(self.bisect_perfect(l, acc, f).0);
153                }
154                acc = nacc;
155                l += 1;
156            }
157            l >>= 1;
158            k += 1;
159        }
160        for k in (0..k).rev() {
161            let r = r >> k;
162            if r & 1 != 0 {
163                let nacc = M::operate(&acc, &self.seg[r - 1]);
164                if f(&nacc) {
165                    return Some(self.bisect_perfect(r - 1, acc, f).0);
166                }
167                acc = nacc;
168            }
169        }
170        None
171    }
172    /// Returns the last index that satisfies a accumlative predicate.
173    pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
174    where
175        R: RangeBounds<usize>,
176        F: Fn(&M::T) -> bool,
177    {
178        let range = range.to_range_bounded(0, self.n).expect("invalid range");
179        let mut l = range.start + self.n;
180        let mut r = range.end + self.n;
181        let mut c = 0usize;
182        let mut k = 0usize;
183        let mut acc = M::unit();
184        while l >> k < r {
185            c <<= 1;
186            if l & (1 << k) != 0 {
187                l += 1 << k;
188                c += 1;
189            }
190            if r & 1 != 0 {
191                r -= 1;
192                let nacc = M::operate(&self.seg[r], &acc);
193                if f(&nacc) {
194                    return Some(self.rbisect_perfect(r, acc, f).0);
195                }
196                acc = nacc;
197            }
198            r >>= 1;
199            k += 1;
200        }
201        for k in (0..k).rev() {
202            if c & 1 != 0 {
203                l -= 1 << k;
204                let l = l >> k;
205                let nacc = M::operate(&self.seg[l], &acc);
206                if f(&nacc) {
207                    return Some(self.rbisect_perfect(l, acc, f).0);
208                }
209                acc = nacc;
210            }
211            c >>= 1;
212        }
213        None
214    }
215    pub fn as_slice(&self) -> &[M::T] {
216        &self.seg[self.n..]
217    }
218}
219impl<M> SegmentTree<M>
220where
221    M: AbelianMonoid,
222{
223    pub fn fold_all(&self) -> M::T {
224        self.seg[1].clone()
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    #[test]
    fn test_segment_tree() {
        let mut rng = Xorshift::new();

        // Sum tree: random point assignments are mirrored into `prefix`, which is
        // then turned into prefix sums (`prefix[i]` = sum of the first `i` values).
        let mut prefix = vec![0; N + 1];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            prefix[k + 1] = v;
        }
        for i in 1..=N {
            prefix[i] += prefix[i - 1];
        }

        // Every non-empty range against the prefix-sum difference.
        for l in 0..N {
            for r in l + 1..=N {
                assert_eq!(seg.fold(l..r), prefix[r] - prefix[l]);
            }
        }

        // Whole-array search against a binary search over prefix sums.
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            let expected = prefix[1..].position_bisect(|&x| x >= v);
            assert_eq!(seg.position_acc(0..N, |&x| v <= x).unwrap_or(N), expected);
        }

        // Forward and backward searches restricted to random sub-ranges.
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            let base = prefix[l];
            let expected_fwd = prefix[l + 1..r + 1].position_bisect(|&x| x - base >= v) + l;
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                expected_fwd
            );
            let total = prefix[r];
            let expected_bwd = prefix[l..r].rposition_bisect(|&x| total - x >= v) + l;
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                expected_bwd
            );
        }

        // Max tree: built from a random vector, updated, then compared with naive maxima.
        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTree::<MaxOperation<_>>::from_vec(arr.clone());
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let naive = arr[l..r].iter().copied().max().unwrap_or_default();
            assert_eq!(seg.fold(l..r), naive);
        }
    }
}