competitive/data_structure/
segment_tree_map.rs

1#![allow(clippy::or_fun_call)]
2
3use super::{AbelianMonoid, Monoid, RangeBoundsExt};
4use std::{
5    collections::HashMap,
6    fmt::{self, Debug, Formatter},
7    ops::RangeBounds,
8};
9
10pub struct SegmentTreeMap<M>
11where
12    M: Monoid,
13{
14    n: usize,
15    seg: HashMap<usize, M::T>,
16    u: M::T,
17}
18
19impl<M> Clone for SegmentTreeMap<M>
20where
21    M: Monoid,
22{
23    fn clone(&self) -> Self {
24        Self {
25            n: self.n,
26            seg: self.seg.clone(),
27            u: self.u.clone(),
28        }
29    }
30}
31
32impl<M> Debug for SegmentTreeMap<M>
33where
34    M: Monoid,
35    M::T: Debug,
36{
37    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38        f.debug_struct("SegmentTreeMap")
39            .field("n", &self.n)
40            .field("seg", &self.seg)
41            .field("u", &self.u)
42            .finish()
43    }
44}
45
46impl<M> SegmentTreeMap<M>
47where
48    M: Monoid,
49{
50    pub fn new(n: usize) -> Self {
51        let u = M::unit();
52        Self {
53            n,
54            seg: Default::default(),
55            u,
56        }
57    }
58    #[inline]
59    fn get_ref(&self, k: usize) -> &M::T {
60        self.seg.get(&k).unwrap_or(&self.u)
61    }
62    pub fn set(&mut self, k: usize, x: M::T) {
63        debug_assert!(k < self.n);
64        let mut k = k + self.n;
65        *self.seg.entry(k).or_insert(M::unit()) = x;
66        k /= 2;
67        while k > 0 {
68            *self.seg.entry(k).or_insert(M::unit()) =
69                M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
70            k /= 2;
71        }
72    }
73    pub fn update(&mut self, k: usize, x: M::T) {
74        debug_assert!(k < self.n);
75        let mut k = k + self.n;
76        let t = self.seg.entry(k).or_insert(M::unit());
77        *t = M::operate(t, &x);
78        k /= 2;
79        while k > 0 {
80            *self.seg.entry(k).or_insert(M::unit()) =
81                M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
82            k /= 2;
83        }
84    }
85    pub fn get(&self, k: usize) -> M::T {
86        debug_assert!(k < self.n);
87        self.seg.get(&(k + self.n)).cloned().unwrap_or_else(M::unit)
88    }
89    pub fn fold<R>(&self, range: R) -> M::T
90    where
91        R: RangeBounds<usize>,
92    {
93        let range = range.to_range();
94        debug_assert!(range.end <= self.n);
95        let mut l = range.start + self.n;
96        let mut r = range.end + self.n;
97        let mut vl = M::unit();
98        let mut vr = M::unit();
99        while l < r {
100            if l & 1 != 0 {
101                vl = M::operate(&vl, self.get_ref(l));
102                l += 1;
103            }
104            if r & 1 != 0 {
105                r -= 1;
106                vr = M::operate(self.get_ref(r), &vr);
107            }
108            l /= 2;
109            r /= 2;
110        }
111        M::operate(&vl, &vr)
112    }
113    fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
114    where
115        F: Fn(&M::T) -> bool,
116    {
117        while pos < self.n {
118            pos <<= 1;
119            let nacc = M::operate(&acc, self.get_ref(pos));
120            if !f(&nacc) {
121                acc = nacc;
122                pos += 1;
123            }
124        }
125        (pos - self.n, acc)
126    }
127    fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
128    where
129        F: Fn(&M::T) -> bool,
130    {
131        while pos < self.n {
132            pos = pos * 2 + 1;
133            let nacc = M::operate(self.get_ref(pos), &acc);
134            if !f(&nacc) {
135                acc = nacc;
136                pos -= 1;
137            }
138        }
139        (pos - self.n, acc)
140    }
141    /// Returns the first index that satisfies a accumlative predicate.
142    pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
143    where
144        R: RangeBounds<usize>,
145        F: Fn(&M::T) -> bool,
146    {
147        let range = range.to_range();
148        debug_assert!(range.end <= self.n);
149        let mut l = range.start + self.n;
150        let r = range.end + self.n;
151        let mut k = 0usize;
152        let mut acc = M::unit();
153        while l < r >> k {
154            if l & 1 != 0 {
155                let nacc = M::operate(&acc, self.get_ref(l));
156                if f(&nacc) {
157                    return Some(self.bisect_perfect(l, acc, f).0);
158                }
159                acc = nacc;
160                l += 1;
161            }
162            l >>= 1;
163            k += 1;
164        }
165        for k in (0..k).rev() {
166            let r = r >> k;
167            if r & 1 != 0 {
168                let nacc = M::operate(&acc, self.get_ref(r - 1));
169                if f(&nacc) {
170                    return Some(self.bisect_perfect(r - 1, acc, f).0);
171                }
172                acc = nacc;
173            }
174        }
175        None
176    }
177    /// Returns the last index that satisfies a accumlative predicate.
178    pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
179    where
180        R: RangeBounds<usize>,
181        F: Fn(&M::T) -> bool,
182    {
183        let range = range.to_range();
184        debug_assert!(range.end <= self.n);
185        let mut l = range.start + self.n;
186        let mut r = range.end + self.n;
187        let mut c = 0usize;
188        let mut k = 0usize;
189        let mut acc = M::unit();
190        while l >> k < r {
191            c <<= 1;
192            if l & (1 << k) != 0 {
193                l += 1 << k;
194                c += 1;
195            }
196            if r & 1 != 0 {
197                r -= 1;
198                let nacc = M::operate(self.get_ref(r), &acc);
199                if f(&nacc) {
200                    return Some(self.rbisect_perfect(r, acc, f).0);
201                }
202                acc = nacc;
203            }
204            r >>= 1;
205            k += 1;
206        }
207        for k in (0..k).rev() {
208            if c & 1 != 0 {
209                l -= 1 << k;
210                let l = l >> k;
211                let nacc = M::operate(self.get_ref(l), &acc);
212                if f(&nacc) {
213                    return Some(self.rbisect_perfect(l, acc, f).0);
214                }
215                acc = nacc;
216            }
217            c >>= 1;
218        }
219        None
220    }
221}
222
223impl<M> SegmentTreeMap<M>
224where
225    M: AbelianMonoid,
226{
227    pub fn fold_all(&self) -> M::T {
228        self.seg.get(&1).cloned().unwrap_or_else(M::unit)
229    }
230}
231
// Randomized tests that cross-check `SegmentTreeMap` against plain-array computations.
232#[cfg(test)]
233mod tests {
234    use super::*;
235    use crate::{
236        algebra::{AdditiveOperation, MaxOperation},
237        algorithm::SliceBisectExt as _,
238        rand,
239        tools::{NotEmptySegment as Nes, Xorshift},
240    };
241
    // Number of positions in the tree under test.
242    const N: usize = 1_000;
    // Number of random operations / queries drawn in each randomized loop.
243    const Q: usize = 10_000;
    // Bound on the magnitude of the random values.
244    const A: i64 = 1_000_000_000;
245
    // NOTE: every loop below draws from the same `rng`, so the order of the
    // statements determines which random inputs each check sees.
246    #[test]
247    fn test_segment_tree_map() {
248        let mut rng = Xorshift::new();
        // `arr` is shifted by one (`arr[k + 1]` mirrors position `k`) so that it can be
        // turned into a prefix-sum table in place below.
249        let mut arr = vec![0; N + 1];
250        let mut seg = SegmentTreeMap::<AdditiveOperation<_>>::new(N);
        // Random point assignments, mirrored into the reference array.
251        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
252            seg.set(k, v);
253            arr[k + 1] = v;
254        }
        // After this loop `arr[j]` is the sum of positions `0..j`.
255        for i in 0..N {
256            arr[i + 1] += arr[i];
257        }
        // Every non-empty range sum must match the prefix-sum difference.
258        for i in 0..N {
259            for j in i + 1..N + 1 {
260                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
261            }
262        }
        // Whole-range search: first position whose prefix sum reaches `v`
        // (`N` stands in for "not found").
263        for v in rng.random_iter(1..=A * N as i64).take(Q) {
264            assert_eq!(
265                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
266                arr[1..].position_bisect(|&x| x >= v)
267            );
268        }
        // Sub-range searches from both ends, compared with bisection on the prefix sums.
        // `Nes(N)` presumably yields a non-empty segment `(l, r)` within `0..N` — confirm in `tools`.
269        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
270            assert_eq!(
271                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
272                arr[l + 1..r + 1].position_bisect(|&x| x >= v + arr[l]) + l
273            );
274            assert_eq!(
275                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
276                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
277            );
278        }
279
        // Second scenario: range maximum, starting from a fully populated random array.
280        rand!(rng, mut arr: [-A..=A; N]);
281        let mut seg = SegmentTreeMap::<MaxOperation<_>>::new(N);
282        for (i, a) in arr.iter().cloned().enumerate() {
283            seg.set(i, a);
284        }
        // Random overwrites applied to both the tree and the reference array.
285        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
286            seg.set(k, v);
287            arr[k] = v;
288        }
        // Range maxima must agree with a linear scan of the reference array.
289        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
290            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
291            assert_eq!(seg.fold(l..r), res);
292        }
293    }
294}
294}