competitive/data_structure/
segment_tree_map.rs1#![allow(clippy::or_fun_call)]
2
3use super::{AbelianMonoid, Monoid, RangeBoundsExt};
4use std::{
5 collections::HashMap,
6 fmt::{self, Debug, Formatter},
7 ops::RangeBounds,
8};
9
10pub struct SegmentTreeMap<M>
11where
12 M: Monoid,
13{
14 n: usize,
15 seg: HashMap<usize, M::T>,
16 u: M::T,
17}
18
19impl<M> Clone for SegmentTreeMap<M>
20where
21 M: Monoid,
22{
23 fn clone(&self) -> Self {
24 Self {
25 n: self.n,
26 seg: self.seg.clone(),
27 u: self.u.clone(),
28 }
29 }
30}
31
32impl<M> Debug for SegmentTreeMap<M>
33where
34 M: Monoid,
35 M::T: Debug,
36{
37 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
38 f.debug_struct("SegmentTreeMap")
39 .field("n", &self.n)
40 .field("seg", &self.seg)
41 .field("u", &self.u)
42 .finish()
43 }
44}
45
46impl<M> SegmentTreeMap<M>
47where
48 M: Monoid,
49{
50 pub fn new(n: usize) -> Self {
51 let u = M::unit();
52 Self {
53 n,
54 seg: Default::default(),
55 u,
56 }
57 }
58 #[inline]
59 fn get_ref(&self, k: usize) -> &M::T {
60 self.seg.get(&k).unwrap_or(&self.u)
61 }
62 pub fn set(&mut self, k: usize, x: M::T) {
63 debug_assert!(k < self.n);
64 let mut k = k + self.n;
65 *self.seg.entry(k).or_insert(M::unit()) = x;
66 k /= 2;
67 while k > 0 {
68 *self.seg.entry(k).or_insert(M::unit()) =
69 M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
70 k /= 2;
71 }
72 }
73 pub fn update(&mut self, k: usize, x: M::T) {
74 debug_assert!(k < self.n);
75 let mut k = k + self.n;
76 let t = self.seg.entry(k).or_insert(M::unit());
77 *t = M::operate(t, &x);
78 k /= 2;
79 while k > 0 {
80 *self.seg.entry(k).or_insert(M::unit()) =
81 M::operate(self.get_ref(2 * k), self.get_ref(2 * k + 1));
82 k /= 2;
83 }
84 }
85 pub fn get(&self, k: usize) -> M::T {
86 debug_assert!(k < self.n);
87 self.seg.get(&(k + self.n)).cloned().unwrap_or_else(M::unit)
88 }
89 pub fn fold<R>(&self, range: R) -> M::T
90 where
91 R: RangeBounds<usize>,
92 {
93 let range = range.to_range();
94 debug_assert!(range.end <= self.n);
95 let mut l = range.start + self.n;
96 let mut r = range.end + self.n;
97 let mut vl = M::unit();
98 let mut vr = M::unit();
99 while l < r {
100 if l & 1 != 0 {
101 vl = M::operate(&vl, self.get_ref(l));
102 l += 1;
103 }
104 if r & 1 != 0 {
105 r -= 1;
106 vr = M::operate(self.get_ref(r), &vr);
107 }
108 l /= 2;
109 r /= 2;
110 }
111 M::operate(&vl, &vr)
112 }
113 fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
114 where
115 F: Fn(&M::T) -> bool,
116 {
117 while pos < self.n {
118 pos <<= 1;
119 let nacc = M::operate(&acc, self.get_ref(pos));
120 if !f(&nacc) {
121 acc = nacc;
122 pos += 1;
123 }
124 }
125 (pos - self.n, acc)
126 }
127 fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
128 where
129 F: Fn(&M::T) -> bool,
130 {
131 while pos < self.n {
132 pos = pos * 2 + 1;
133 let nacc = M::operate(self.get_ref(pos), &acc);
134 if !f(&nacc) {
135 acc = nacc;
136 pos -= 1;
137 }
138 }
139 (pos - self.n, acc)
140 }
141 pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
143 where
144 R: RangeBounds<usize>,
145 F: Fn(&M::T) -> bool,
146 {
147 let range = range.to_range();
148 debug_assert!(range.end <= self.n);
149 let mut l = range.start + self.n;
150 let r = range.end + self.n;
151 let mut k = 0usize;
152 let mut acc = M::unit();
153 while l < r >> k {
154 if l & 1 != 0 {
155 let nacc = M::operate(&acc, self.get_ref(l));
156 if f(&nacc) {
157 return Some(self.bisect_perfect(l, acc, f).0);
158 }
159 acc = nacc;
160 l += 1;
161 }
162 l >>= 1;
163 k += 1;
164 }
165 for k in (0..k).rev() {
166 let r = r >> k;
167 if r & 1 != 0 {
168 let nacc = M::operate(&acc, self.get_ref(r - 1));
169 if f(&nacc) {
170 return Some(self.bisect_perfect(r - 1, acc, f).0);
171 }
172 acc = nacc;
173 }
174 }
175 None
176 }
177 pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
179 where
180 R: RangeBounds<usize>,
181 F: Fn(&M::T) -> bool,
182 {
183 let range = range.to_range();
184 debug_assert!(range.end <= self.n);
185 let mut l = range.start + self.n;
186 let mut r = range.end + self.n;
187 let mut c = 0usize;
188 let mut k = 0usize;
189 let mut acc = M::unit();
190 while l >> k < r {
191 c <<= 1;
192 if l & (1 << k) != 0 {
193 l += 1 << k;
194 c += 1;
195 }
196 if r & 1 != 0 {
197 r -= 1;
198 let nacc = M::operate(self.get_ref(r), &acc);
199 if f(&nacc) {
200 return Some(self.rbisect_perfect(r, acc, f).0);
201 }
202 acc = nacc;
203 }
204 r >>= 1;
205 k += 1;
206 }
207 for k in (0..k).rev() {
208 if c & 1 != 0 {
209 l -= 1 << k;
210 let l = l >> k;
211 let nacc = M::operate(self.get_ref(l), &acc);
212 if f(&nacc) {
213 return Some(self.rbisect_perfect(l, acc, f).0);
214 }
215 acc = nacc;
216 }
217 c >>= 1;
218 }
219 None
220 }
221}
222
223impl<M> SegmentTreeMap<M>
224where
225 M: AbelianMonoid,
226{
227 pub fn fold_all(&self) -> M::T {
228 self.seg.get(&1).cloned().unwrap_or_else(M::unit)
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    /// Cross-checks the tree against plain arrays: range sums and both
    /// accumulated searches against prefix sums, then range maxima against a
    /// direct scan.
    #[test]
    fn test_segment_tree_map() {
        let mut rng = Xorshift::new();

        // Additive tree versus prefix sums; `prefix[j]` ends up as the sum of
        // the first `j` elements.
        let mut prefix = vec![0; N + 1];
        let mut sums = SegmentTreeMap::<AdditiveOperation<_>>::new(N);
        for (idx, value) in rng.random_iter((..N, 1..=A)).take(Q) {
            sums.set(idx, value);
            prefix[idx + 1] = value;
        }
        for i in 1..=N {
            prefix[i] += prefix[i - 1];
        }
        for lo in 0..N {
            for hi in lo + 1..=N {
                assert_eq!(sums.fold(lo..hi), prefix[hi] - prefix[lo]);
            }
        }

        // Forward search over the whole range.
        for target in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                sums.position_acc(0..N, |&acc| acc >= target).unwrap_or(N),
                prefix[1..].position_bisect(|&p| p >= target)
            );
        }

        // Forward and backward search over random non-empty sub-ranges.
        for ((l, r), target) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            assert_eq!(
                sums.position_acc(l..r, |&acc| acc >= target).unwrap_or(r),
                prefix[l + 1..=r].position_bisect(|&p| p >= target + prefix[l]) + l
            );
            assert_eq!(
                sums.rposition_acc(l..r, |&acc| acc >= target)
                    .map_or(l, |i| i + 1),
                prefix[l..r].rposition_bisect(|&p| prefix[r] - p >= target) + l
            );
        }

        // Max tree versus a direct scan of the backing array.
        rand!(rng, mut arr: [-A..=A; N]);
        let mut maxima = SegmentTreeMap::<MaxOperation<_>>::new(N);
        for (idx, value) in arr.iter().copied().enumerate() {
            maxima.set(idx, value);
        }
        for (idx, value) in rng.random_iter((..N, -A..=A)).take(Q) {
            maxima.set(idx, value);
            arr[idx] = value;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let expected = arr[l..r].iter().max().copied().unwrap_or_default();
            assert_eq!(maxima.fold(l..r), expected);
        }
    }
}