competitive/data_structure/
segment_tree_map.rs1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3 collections::HashMap,
4 fmt::{self, Debug, Formatter},
5 ops::RangeBounds,
6};
7
8pub struct SegmentTreeMap<M>
9where
10 M: Monoid,
11{
12 n: usize,
13 seg: HashMap<usize, M::T>,
14 u: M::T,
15}
16
17impl<M> Clone for SegmentTreeMap<M>
18where
19 M: Monoid,
20{
21 fn clone(&self) -> Self {
22 Self {
23 n: self.n,
24 seg: self.seg.clone(),
25 u: self.u.clone(),
26 }
27 }
28}
29
30impl<M> Debug for SegmentTreeMap<M>
31where
32 M: Monoid<T: Debug>,
33{
34 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35 f.debug_struct("SegmentTreeMap")
36 .field("n", &self.n)
37 .field("seg", &self.seg)
38 .field("u", &self.u)
39 .finish()
40 }
41}
42
43impl<M> SegmentTreeMap<M>
44where
45 M: Monoid,
46{
47 pub fn new(n: usize) -> Self {
48 let u = M::unit();
49 Self {
50 n,
51 seg: Default::default(),
52 u,
53 }
54 }
55 #[inline]
56 fn get_ref(&self, k: usize) -> &M::T {
57 self.seg.get(&k).unwrap_or(&self.u)
58 }
59 pub fn set(&mut self, k: usize, x: M::T) {
60 debug_assert!(k < self.n);
61 let mut k = k + self.n;
62 *self.seg.entry(k).or_insert(M::unit()) = x;
63 k /= 2;
64 while k > 0 {
65 *self.seg.entry(k).or_insert(M::unit()) =
66 M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
67 k /= 2;
68 }
69 }
70 pub fn update(&mut self, k: usize, x: M::T) {
71 debug_assert!(k < self.n);
72 let mut k = k + self.n;
73 let t = self.seg.entry(k).or_insert(M::unit());
74 *t = M::operate(t, &x);
75 k /= 2;
76 while k > 0 {
77 *self.seg.entry(k).or_insert(M::unit()) =
78 M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
79 k /= 2;
80 }
81 }
82 pub fn get(&self, k: usize) -> M::T {
83 debug_assert!(k < self.n);
84 self.seg.get(&(k + self.n)).cloned().unwrap_or_else(M::unit)
85 }
86 pub fn fold<R>(&self, range: R) -> M::T
87 where
88 R: RangeBounds<usize>,
89 {
90 let range = range.to_range();
91 debug_assert!(range.end <= self.n);
92 let mut l = range.start + self.n;
93 let mut r = range.end + self.n;
94 let mut vl = M::unit();
95 let mut vr = M::unit();
96 while l < r {
97 if l & 1 != 0 {
98 vl = M::operate(&vl, self.get_ref(l));
99 l += 1;
100 }
101 if r & 1 != 0 {
102 r -= 1;
103 vr = M::operate(self.get_ref(r), &vr);
104 }
105 l /= 2;
106 r /= 2;
107 }
108 M::operate(&vl, &vr)
109 }
110 fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
111 where
112 F: Fn(&M::T) -> bool,
113 {
114 while pos < self.n {
115 pos <<= 1;
116 let nacc = M::operate(&acc, self.get_ref(pos));
117 if !f(&nacc) {
118 acc = nacc;
119 pos += 1;
120 }
121 }
122 (pos - self.n, acc)
123 }
124 fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
125 where
126 F: Fn(&M::T) -> bool,
127 {
128 while pos < self.n {
129 pos = pos * 2 + 1;
130 let nacc = M::operate(self.get_ref(pos), &acc);
131 if !f(&nacc) {
132 acc = nacc;
133 pos -= 1;
134 }
135 }
136 (pos - self.n, acc)
137 }
138 pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
140 where
141 R: RangeBounds<usize>,
142 F: Fn(&M::T) -> bool,
143 {
144 let range = range.to_range();
145 debug_assert!(range.end <= self.n);
146 let mut l = range.start + self.n;
147 let r = range.end + self.n;
148 let mut k = 0usize;
149 let mut acc = M::unit();
150 while l < r >> k {
151 if l & 1 != 0 {
152 let nacc = M::operate(&acc, self.get_ref(l));
153 if f(&nacc) {
154 return Some(self.bisect_perfect(l, acc, f).0);
155 }
156 acc = nacc;
157 l += 1;
158 }
159 l >>= 1;
160 k += 1;
161 }
162 for k in (0..k).rev() {
163 let r = r >> k;
164 if r & 1 != 0 {
165 let nacc = M::operate(&acc, self.get_ref(r - 1));
166 if f(&nacc) {
167 return Some(self.bisect_perfect(r - 1, acc, f).0);
168 }
169 acc = nacc;
170 }
171 }
172 None
173 }
174 pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
176 where
177 R: RangeBounds<usize>,
178 F: Fn(&M::T) -> bool,
179 {
180 let range = range.to_range();
181 debug_assert!(range.end <= self.n);
182 let mut l = range.start + self.n;
183 let mut r = range.end + self.n;
184 let mut c = 0usize;
185 let mut k = 0usize;
186 let mut acc = M::unit();
187 while l >> k < r {
188 c <<= 1;
189 if l & (1 << k) != 0 {
190 l += 1 << k;
191 c += 1;
192 }
193 if r & 1 != 0 {
194 r -= 1;
195 let nacc = M::operate(self.get_ref(r), &acc);
196 if f(&nacc) {
197 return Some(self.rbisect_perfect(r, acc, f).0);
198 }
199 acc = nacc;
200 }
201 r >>= 1;
202 k += 1;
203 }
204 for k in (0..k).rev() {
205 if c & 1 != 0 {
206 l -= 1 << k;
207 let l = l >> k;
208 let nacc = M::operate(self.get_ref(l), &acc);
209 if f(&nacc) {
210 return Some(self.rbisect_perfect(l, acc, f).0);
211 }
212 acc = nacc;
213 }
214 c >>= 1;
215 }
216 None
217 }
218}
219
220impl<M> SegmentTreeMap<M>
221where
222 M: AbelianMonoid,
223{
224 pub fn fold_all(&self) -> M::T {
225 self.seg.get(&1).cloned().unwrap_or_else(M::unit)
226 }
227}
228
// Randomized cross-checks of `SegmentTreeMap` against plain arrays.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    // Number of positions, number of random operations/queries per phase,
    // and the magnitude bound used for random values.
    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    #[test]
    fn test_segment_tree_map() {
        let mut rng = Xorshift::default();
        // Phase 1: additive monoid. `arr` mirrors position k at arr[k + 1] so that it can
        // be turned into prefix sums in place below.
        let mut arr = vec![0; N + 1];
        let mut seg = SegmentTreeMap::<AdditiveOperation<_>>::new(N);
        // Random point assignments; a later draw for the same k overwrites the earlier one
        // in both `seg` and `arr`.
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            arr[k + 1] = v;
        }
        // Prefix sums: afterwards arr[j] is the sum of positions 0..j.
        for i in 0..N {
            arr[i + 1] += arr[i];
        }
        // Exhaustive check of every non-empty range (about N^2 / 2 folds).
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
            }
        }
        // Whole-range prefix search. Values are >= 1 (or 0 where unset), so the prefix
        // sums are non-decreasing and the predicate is monotone. `unwrap_or(N)` maps
        // "not found" onto what `position_bisect` presumably returns when nothing
        // matches; confirm against `SliceBisectExt`.
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
                arr[1..].position_bisect(|&x| x >= v)
            );
        }
        // Random sub-ranges (`Nes(N)` presumably yields non-empty l < r <= N pairs).
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            // fold(l..i + 1) >= v  <=>  arr[i + 1] - arr[l] >= v
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                arr[l + 1..r + 1].position_bisect(|&x| x >= v + arr[l]) + l
            );
            // fold(i..r) >= v  <=>  arr[r] - arr[i] >= v; `map_or(l, |i| i + 1)` adapts the
            // `Option<usize>` result to the plain index returned by `rposition_bisect`.
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
            );
        }

        // Phase 2: max monoid with values that may be negative.
        // `rand!` presumably rebinds `arr` as a mutable array of N values drawn from -A..=A.
        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTreeMap::<MaxOperation<_>>::new(N);
        for (i, a) in arr.iter().cloned().enumerate() {
            seg.set(i, a);
        }
        // Random point overwrites on top of the initial fill.
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        // Compare range maxima. `unwrap_or_default()` would only matter for an empty range,
        // which `Nes` is presumed never to produce.
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
            assert_eq!(seg.fold(l..r), res);
        }
    }
}