competitive/data_structure/
segment_tree.rs

1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3    fmt::{self, Debug, Formatter},
4    ops::RangeBounds,
5};
6
7pub struct SegmentTree<M>
8where
9    M: Monoid,
10{
11    n: usize,
12    seg: Vec<M::T>,
13}
14
15impl<M> Clone for SegmentTree<M>
16where
17    M: Monoid,
18{
19    fn clone(&self) -> Self {
20        Self {
21            n: self.n,
22            seg: self.seg.clone(),
23        }
24    }
25}
26
27impl<M> Debug for SegmentTree<M>
28where
29    M: Monoid<T: Debug>,
30{
31    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32        f.debug_struct("SegmentTree")
33            .field("n", &self.n)
34            .field("seg", &self.seg)
35            .finish()
36    }
37}
38
39impl<M> SegmentTree<M>
40where
41    M: Monoid,
42{
43    pub fn new(n: usize) -> Self {
44        let seg = vec![M::unit(); 2 * n];
45        Self { n, seg }
46    }
47    pub fn from_vec(v: Vec<M::T>) -> Self {
48        let n = v.len();
49        let mut seg = vec![M::unit(); 2 * n];
50        for (i, x) in v.into_iter().enumerate() {
51            seg[n + i] = x;
52        }
53        for i in (1..n).rev() {
54            seg[i] = M::operate(&seg[2 * i], &seg[2 * i + 1]);
55        }
56        Self { n, seg }
57    }
58    pub fn set(&mut self, k: usize, x: M::T) {
59        assert!(k < self.n);
60        let mut k = k + self.n;
61        self.seg[k] = x;
62        k /= 2;
63        while k > 0 {
64            self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
65            k /= 2;
66        }
67    }
68    pub fn clear(&mut self, k: usize) {
69        self.set(k, M::unit());
70    }
71    pub fn update(&mut self, k: usize, x: M::T) {
72        assert!(k < self.n);
73        let mut k = k + self.n;
74        self.seg[k] = M::operate(&self.seg[k], &x);
75        k /= 2;
76        while k > 0 {
77            self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
78            k /= 2;
79        }
80    }
81    pub fn get(&self, k: usize) -> M::T {
82        assert!(k < self.n);
83        self.seg[k + self.n].clone()
84    }
85    pub fn fold<R>(&self, range: R) -> M::T
86    where
87        R: RangeBounds<usize>,
88    {
89        let range = range.to_range_bounded(0, self.n).expect("invalid range");
90        let mut l = range.start + self.n;
91        let mut r = range.end + self.n;
92        let mut vl = M::unit();
93        let mut vr = M::unit();
94        while l < r {
95            if l & 1 != 0 {
96                vl = M::operate(&vl, &self.seg[l]);
97                l += 1;
98            }
99            if r & 1 != 0 {
100                r -= 1;
101                vr = M::operate(&self.seg[r], &vr);
102            }
103            l /= 2;
104            r /= 2;
105        }
106        M::operate(&vl, &vr)
107    }
108    fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
109    where
110        F: Fn(&M::T) -> bool,
111    {
112        while pos < self.n {
113            pos <<= 1;
114            let nacc = M::operate(&acc, &self.seg[pos]);
115            if !f(&nacc) {
116                acc = nacc;
117                pos += 1;
118            }
119        }
120        (pos - self.n, acc)
121    }
122    fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
123    where
124        F: Fn(&M::T) -> bool,
125    {
126        while pos < self.n {
127            pos = pos * 2 + 1;
128            let nacc = M::operate(&self.seg[pos], &acc);
129            if !f(&nacc) {
130                acc = nacc;
131                pos -= 1;
132            }
133        }
134        (pos - self.n, acc)
135    }
136    /// Returns the first index that satisfies a accumlative predicate.
137    pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
138    where
139        R: RangeBounds<usize>,
140        F: Fn(&M::T) -> bool,
141    {
142        let range = range.to_range_bounded(0, self.n).expect("invalid range");
143        let mut l = range.start + self.n;
144        let r = range.end + self.n;
145        let mut k = 0usize;
146        let mut acc = M::unit();
147        while l < r >> k {
148            if l & 1 != 0 {
149                let nacc = M::operate(&acc, &self.seg[l]);
150                if f(&nacc) {
151                    return Some(self.bisect_perfect(l, acc, f).0);
152                }
153                acc = nacc;
154                l += 1;
155            }
156            l >>= 1;
157            k += 1;
158        }
159        for k in (0..k).rev() {
160            let r = r >> k;
161            if r & 1 != 0 {
162                let nacc = M::operate(&acc, &self.seg[r - 1]);
163                if f(&nacc) {
164                    return Some(self.bisect_perfect(r - 1, acc, f).0);
165                }
166                acc = nacc;
167            }
168        }
169        None
170    }
171    /// Returns the last index that satisfies a accumlative predicate.
172    pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
173    where
174        R: RangeBounds<usize>,
175        F: Fn(&M::T) -> bool,
176    {
177        let range = range.to_range_bounded(0, self.n).expect("invalid range");
178        let mut l = range.start + self.n;
179        let mut r = range.end + self.n;
180        let mut c = 0usize;
181        let mut k = 0usize;
182        let mut acc = M::unit();
183        while l >> k < r {
184            c <<= 1;
185            if l & (1 << k) != 0 {
186                l += 1 << k;
187                c += 1;
188            }
189            if r & 1 != 0 {
190                r -= 1;
191                let nacc = M::operate(&self.seg[r], &acc);
192                if f(&nacc) {
193                    return Some(self.rbisect_perfect(r, acc, f).0);
194                }
195                acc = nacc;
196            }
197            r >>= 1;
198            k += 1;
199        }
200        for k in (0..k).rev() {
201            if c & 1 != 0 {
202                l -= 1 << k;
203                let l = l >> k;
204                let nacc = M::operate(&self.seg[l], &acc);
205                if f(&nacc) {
206                    return Some(self.rbisect_perfect(l, acc, f).0);
207                }
208                acc = nacc;
209            }
210            c >>= 1;
211        }
212        None
213    }
214    pub fn as_slice(&self) -> &[M::T] {
215        &self.seg[self.n..]
216    }
217}
218impl<M> SegmentTree<M>
219where
220    M: AbelianMonoid,
221{
222    pub fn fold_all(&self) -> M::T {
223        self.seg[1].clone()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    /// Cross-checks `set`, `fold`, `position_acc` and `rposition_acc` against
    /// prefix sums, then `from_vec` + `set` + `fold` against a naive maximum.
    #[test]
    fn test_segment_tree() {
        let mut rng = Xorshift::default();
        let mut arr = vec![0; N + 1];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            arr[k + 1] = v;
        }
        // Turn `arr` into prefix sums: arr[i] = sum of elements 0..i.
        for i in 0..N {
            arr[i + 1] += arr[i];
        }
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
            }
        }
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
                arr[1..].position_bisect(|&x| x >= v)
            );
        }
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                arr[l + 1..r + 1].position_bisect(|&x| x - arr[l] >= v) + l
            );
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
            );
        }

        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTree::<MaxOperation<_>>::from_vec(arr.clone());
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
            assert_eq!(seg.fold(l..r), res);
        }
    }

    /// Cross-checks the point operations `update`, `clear` and `get`, plus
    /// `as_slice` and `fold`, against a plain array.
    #[test]
    fn test_point_operations() {
        let mut rng = Xorshift::default();
        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTree::<AdditiveOperation<_>>::from_vec(arr.clone());
        assert_eq!(seg.as_slice(), &arr[..]);
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            if v % 4 == 0 {
                // The additive unit is 0, so clearing zeroes the element.
                seg.clear(k);
                arr[k] = 0;
            } else {
                seg.update(k, v);
                arr[k] += v;
            }
            assert_eq!(seg.get(k), arr[k]);
        }
        assert_eq!(seg.as_slice(), &arr[..]);
        assert_eq!(seg.fold(..), arr.iter().sum::<i64>());
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            assert_eq!(seg.fold(l..r), arr[l..r].iter().sum::<i64>());
        }
    }
}