competitive/data_structure/
segment_tree_map.rs

1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3    collections::HashMap,
4    fmt::{self, Debug, Formatter},
5    ops::RangeBounds,
6};
7
8pub struct SegmentTreeMap<M>
9where
10    M: Monoid,
11{
12    n: usize,
13    seg: HashMap<usize, M::T>,
14    u: M::T,
15}
16
17impl<M> Clone for SegmentTreeMap<M>
18where
19    M: Monoid,
20{
21    fn clone(&self) -> Self {
22        Self {
23            n: self.n,
24            seg: self.seg.clone(),
25            u: self.u.clone(),
26        }
27    }
28}
29
30impl<M> Debug for SegmentTreeMap<M>
31where
32    M: Monoid<T: Debug>,
33{
34    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35        f.debug_struct("SegmentTreeMap")
36            .field("n", &self.n)
37            .field("seg", &self.seg)
38            .field("u", &self.u)
39            .finish()
40    }
41}
42
43impl<M> SegmentTreeMap<M>
44where
45    M: Monoid,
46{
47    pub fn new(n: usize) -> Self {
48        let u = M::unit();
49        Self {
50            n,
51            seg: Default::default(),
52            u,
53        }
54    }
55    #[inline]
56    fn get_ref(&self, k: usize) -> &M::T {
57        self.seg.get(&k).unwrap_or(&self.u)
58    }
59    pub fn set(&mut self, k: usize, x: M::T) {
60        debug_assert!(k < self.n);
61        let mut k = k + self.n;
62        *self.seg.entry(k).or_insert(M::unit()) = x;
63        k /= 2;
64        while k > 0 {
65            *self.seg.entry(k).or_insert(M::unit()) =
66                M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
67            k /= 2;
68        }
69    }
70    pub fn update(&mut self, k: usize, x: M::T) {
71        debug_assert!(k < self.n);
72        let mut k = k + self.n;
73        let t = self.seg.entry(k).or_insert(M::unit());
74        *t = M::operate(t, &x);
75        k /= 2;
76        while k > 0 {
77            *self.seg.entry(k).or_insert(M::unit()) =
78                M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
79            k /= 2;
80        }
81    }
82    pub fn get(&self, k: usize) -> M::T {
83        debug_assert!(k < self.n);
84        self.seg.get(&(k + self.n)).cloned().unwrap_or_else(M::unit)
85    }
86    pub fn fold<R>(&self, range: R) -> M::T
87    where
88        R: RangeBounds<usize>,
89    {
90        let range = range.to_range();
91        debug_assert!(range.end <= self.n);
92        let mut l = range.start + self.n;
93        let mut r = range.end + self.n;
94        let mut vl = M::unit();
95        let mut vr = M::unit();
96        while l < r {
97            if l & 1 != 0 {
98                vl = M::operate(&vl, self.get_ref(l));
99                l += 1;
100            }
101            if r & 1 != 0 {
102                r -= 1;
103                vr = M::operate(self.get_ref(r), &vr);
104            }
105            l /= 2;
106            r /= 2;
107        }
108        M::operate(&vl, &vr)
109    }
110    fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
111    where
112        F: Fn(&M::T) -> bool,
113    {
114        while pos < self.n {
115            pos <<= 1;
116            let nacc = M::operate(&acc, self.get_ref(pos));
117            if !f(&nacc) {
118                acc = nacc;
119                pos += 1;
120            }
121        }
122        (pos - self.n, acc)
123    }
124    fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
125    where
126        F: Fn(&M::T) -> bool,
127    {
128        while pos < self.n {
129            pos = pos * 2 + 1;
130            let nacc = M::operate(self.get_ref(pos), &acc);
131            if !f(&nacc) {
132                acc = nacc;
133                pos -= 1;
134            }
135        }
136        (pos - self.n, acc)
137    }
138    /// Returns the first index that satisfies a accumlative predicate.
139    pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
140    where
141        R: RangeBounds<usize>,
142        F: Fn(&M::T) -> bool,
143    {
144        let range = range.to_range();
145        debug_assert!(range.end <= self.n);
146        let mut l = range.start + self.n;
147        let r = range.end + self.n;
148        let mut k = 0usize;
149        let mut acc = M::unit();
150        while l < r >> k {
151            if l & 1 != 0 {
152                let nacc = M::operate(&acc, self.get_ref(l));
153                if f(&nacc) {
154                    return Some(self.bisect_perfect(l, acc, f).0);
155                }
156                acc = nacc;
157                l += 1;
158            }
159            l >>= 1;
160            k += 1;
161        }
162        for k in (0..k).rev() {
163            let r = r >> k;
164            if r & 1 != 0 {
165                let nacc = M::operate(&acc, self.get_ref(r - 1));
166                if f(&nacc) {
167                    return Some(self.bisect_perfect(r - 1, acc, f).0);
168                }
169                acc = nacc;
170            }
171        }
172        None
173    }
174    /// Returns the last index that satisfies a accumlative predicate.
175    pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
176    where
177        R: RangeBounds<usize>,
178        F: Fn(&M::T) -> bool,
179    {
180        let range = range.to_range();
181        debug_assert!(range.end <= self.n);
182        let mut l = range.start + self.n;
183        let mut r = range.end + self.n;
184        let mut c = 0usize;
185        let mut k = 0usize;
186        let mut acc = M::unit();
187        while l >> k < r {
188            c <<= 1;
189            if l & (1 << k) != 0 {
190                l += 1 << k;
191                c += 1;
192            }
193            if r & 1 != 0 {
194                r -= 1;
195                let nacc = M::operate(self.get_ref(r), &acc);
196                if f(&nacc) {
197                    return Some(self.rbisect_perfect(r, acc, f).0);
198                }
199                acc = nacc;
200            }
201            r >>= 1;
202            k += 1;
203        }
204        for k in (0..k).rev() {
205            if c & 1 != 0 {
206                l -= 1 << k;
207                let l = l >> k;
208                let nacc = M::operate(self.get_ref(l), &acc);
209                if f(&nacc) {
210                    return Some(self.rbisect_perfect(l, acc, f).0);
211                }
212                acc = nacc;
213            }
214            c >>= 1;
215        }
216        None
217    }
218}
219
220impl<M> SegmentTreeMap<M>
221where
222    M: AbelianMonoid,
223{
224    pub fn fold_all(&self) -> M::T {
225        self.seg.get(&1).cloned().unwrap_or_else(M::unit)
226    }
227}
228
// Randomized tests that compare `SegmentTreeMap` against plain-slice
// computations.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    // Number of positions in the tree under test.
    const N: usize = 1_000;
    // Number of random operations / queries drawn in each phase.
    const Q: usize = 10_000;
    // Bound on the magnitude of the random values.
    const A: i64 = 1_000_000_000;

    /// Exercises `set`, `fold`, `position_acc` and `rposition_acc` with an
    /// additive monoid, then `set` and `fold` with a max monoid.
    #[test]
    fn test_segment_tree_map() {
        let mut rng = Xorshift::default();
        // `arr` is shifted by one so that it can be turned into prefix sums
        // in place: after the loop below, `arr[j] - arr[i]` is the sum of
        // positions `i..j`.
        let mut arr = vec![0; N + 1];
        let mut seg = SegmentTreeMap::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            arr[k + 1] = v;
        }
        for i in 0..N {
            arr[i + 1] += arr[i];
        }
        // Every non-empty range sum must agree with the prefix-sum difference.
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
            }
        }
        // Whole-range search: first position whose prefix sum reaches `v`.
        // NOTE(review): `position_bisect` is a project helper; presumably it
        // returns the first index satisfying the predicate, or the slice
        // length when none does — confirm against `algorithm::SliceBisectExt`.
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
                arr[1..].position_bisect(|&x| x >= v)
            );
        }
        // Sub-range search from both ends.
        // NOTE(review): `Nes(N)` presumably yields non-empty segments `(l, r)`
        // with `l < r <= N` — confirm against `tools::NotEmptySegment`.
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                arr[l + 1..r + 1].position_bisect(|&x| x >= v + arr[l]) + l
            );
            // The tree reports the last matching position; `+ 1` (or `l` when
            // there is none) converts it to the form the slice helper yields.
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
            );
        }

        // Second phase: max monoid, checked against a direct slice maximum.
        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTreeMap::<MaxOperation<_>>::new(N);
        for (i, a) in arr.iter().cloned().enumerate() {
            seg.set(i, a);
        }
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
            assert_eq!(seg.fold(l..r), res);
        }
    }
}