competitive/data_structure/
lazy_segment_tree.rs1use super::{LazyMapMonoid, RangeBoundsExt};
2use std::{
3 fmt::{self, Debug, Formatter},
4 mem::replace,
5 ops::RangeBounds,
6};
7
8pub struct LazySegmentTree<M>
9where
10 M: LazyMapMonoid,
11{
12 n: usize,
13 seg: Vec<(M::Agg, M::Act)>,
14}
15
16impl<M> Clone for LazySegmentTree<M>
17where
18 M: LazyMapMonoid,
19{
20 fn clone(&self) -> Self {
21 Self {
22 n: self.n,
23 seg: self.seg.clone(),
24 }
25 }
26}
27
28impl<M> Debug for LazySegmentTree<M>
29where
30 M: LazyMapMonoid<Agg: Debug, Act: Debug>,
31{
32 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
33 f.debug_struct("LazySegmentTree")
34 .field("n", &self.n)
35 .field("seg", &self.seg)
36 .finish()
37 }
38}
39
40impl<M> LazySegmentTree<M>
41where
42 M: LazyMapMonoid,
43{
44 pub fn new(n: usize) -> Self {
45 let seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
46 Self { n, seg }
47 }
48 pub fn from_vec(v: Vec<M::Agg>) -> Self {
49 let n = v.len();
50 let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
51 for (i, x) in v.into_iter().enumerate() {
52 seg[i + n].0 = x;
53 }
54 for i in (1..n).rev() {
55 seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
56 }
57 Self { n, seg }
58 }
59 pub fn from_keys(keys: impl ExactSizeIterator<Item = M::Key>) -> Self {
60 let n = keys.len();
61 let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
62 for (i, key) in keys.enumerate() {
63 seg[i + n].0 = M::single_agg(&key);
64 }
65 for i in (1..n).rev() {
66 seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
67 }
68 Self { n, seg }
69 }
70 #[inline]
71 fn update_at(&mut self, k: usize, x: &M::Act) {
72 let nx = M::act_agg(&self.seg[k].0, x);
73 if k < self.n {
74 self.seg[k].1 = M::act_operate(&self.seg[k].1, x);
75 }
76 if let Some(nx) = nx {
77 self.seg[k].0 = nx;
78 } else if k < self.n {
79 self.propagate_at(k);
80 self.recalc_at(k);
81 } else {
82 panic!("act failed on leaf");
83 }
84 }
85 #[inline]
86 fn recalc_at(&mut self, k: usize) {
87 self.seg[k].0 = M::agg_operate(&self.seg[2 * k].0, &self.seg[2 * k + 1].0);
88 }
89 #[inline]
90 fn propagate_at(&mut self, k: usize) {
91 debug_assert!(k < self.n);
92 let x = replace(&mut self.seg[k].1, M::act_unit());
93 self.update_at(2 * k, &x);
94 self.update_at(2 * k + 1, &x);
95 }
96 #[inline]
97 fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
98 let right = right as usize;
99 for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
100 if nofilt || (k >> i) << i != k {
101 self.propagate_at((k - right) >> i);
102 }
103 }
104 }
105 #[inline]
106 fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
107 let right = right as usize;
108 for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
109 if nofilt || (k >> i) << i != k {
110 self.recalc_at((k - right) >> i);
111 }
112 }
113 }
114 pub fn update<R>(&mut self, range: R, x: M::Act)
115 where
116 R: RangeBounds<usize>,
117 {
118 let range = range.to_range_bounded(0, self.n).expect("invalid range");
119 let mut a = range.start + self.n;
120 let mut b = range.end + self.n;
121 self.propagate(a, false, false);
122 self.propagate(b, true, false);
123 while a < b {
124 if a & 1 != 0 {
125 self.update_at(a, &x);
126 a += 1;
127 }
128 if b & 1 != 0 {
129 b -= 1;
130 self.update_at(b, &x);
131 }
132 a /= 2;
133 b /= 2;
134 }
135 self.recalc(range.start + self.n, false, false);
136 self.recalc(range.end + self.n, true, false);
137 }
138 pub fn fold<R>(&mut self, range: R) -> M::Agg
139 where
140 R: RangeBounds<usize>,
141 {
142 let range = range.to_range_bounded(0, self.n).expect("invalid range");
143 let mut l = range.start + self.n;
144 let mut r = range.end + self.n;
145 self.propagate(l, false, true);
146 self.propagate(r, true, true);
147 let mut vl = M::agg_unit();
148 let mut vr = M::agg_unit();
149 while l < r {
150 if l & 1 != 0 {
151 vl = M::agg_operate(&vl, &self.seg[l].0);
152 l += 1;
153 }
154 if r & 1 != 0 {
155 r -= 1;
156 vr = M::agg_operate(&self.seg[r].0, &vr);
157 }
158 l /= 2;
159 r /= 2;
160 }
161 M::agg_operate(&vl, &vr)
162 }
163 pub fn set(&mut self, k: usize, x: M::Agg) {
164 let k = k + self.n;
165 self.propagate(k, false, true);
166 self.seg[k] = (x, M::act_unit());
167 self.recalc(k, false, true);
168 }
169 pub fn get(&mut self, k: usize) -> M::Agg {
170 self.fold(k..k + 1)
171 }
172 pub fn fold_all(&mut self) -> M::Agg {
173 self.fold(0..self.n)
174 }
175 fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
176 where
177 P: Fn(&M::Agg) -> bool,
178 {
179 while pos < self.n {
180 self.propagate_at(pos);
181 pos <<= 1;
182 let nacc = M::agg_operate(&acc, &self.seg[pos].0);
183 if !p(&nacc) {
184 acc = nacc;
185 pos += 1;
186 }
187 }
188 (pos - self.n, acc)
189 }
190 fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
191 where
192 P: Fn(&M::Agg) -> bool,
193 {
194 while pos < self.n {
195 self.propagate_at(pos);
196 pos = pos * 2 + 1;
197 let nacc = M::agg_operate(&self.seg[pos].0, &acc);
198 if !p(&nacc) {
199 acc = nacc;
200 pos -= 1;
201 }
202 }
203 (pos - self.n, acc)
204 }
205 pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
207 where
208 R: RangeBounds<usize>,
209 P: Fn(&M::Agg) -> bool,
210 {
211 let range = range.to_range_bounded(0, self.n).expect("invalid range");
212 let mut l = range.start + self.n;
213 let r = range.end + self.n;
214 self.propagate(l, false, true);
215 self.propagate(r, true, true);
216 let mut k = 0usize;
217 let mut acc = M::agg_unit();
218 while l < r >> k {
219 if l & 1 != 0 {
220 let nacc = M::agg_operate(&acc, &self.seg[l].0);
221 if p(&nacc) {
222 return Some(self.bisect_perfect(l, acc, p).0);
223 }
224 acc = nacc;
225 l += 1;
226 }
227 l >>= 1;
228 k += 1;
229 }
230 for k in (0..k).rev() {
231 let r = r >> k;
232 if r & 1 != 0 {
233 let nacc = M::agg_operate(&acc, &self.seg[r - 1].0);
234 if p(&nacc) {
235 return Some(self.bisect_perfect(r - 1, acc, p).0);
236 }
237 acc = nacc;
238 }
239 }
240 None
241 }
242 pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
244 where
245 R: RangeBounds<usize>,
246 P: Fn(&M::Agg) -> bool,
247 {
248 let range = range.to_range_bounded(0, self.n).expect("invalid range");
249 let mut l = range.start + self.n;
250 let mut r = range.end + self.n;
251 self.propagate(l, false, true);
252 self.propagate(r, true, true);
253 let mut c = 0usize;
254 let mut k = 0usize;
255 let mut acc = M::agg_unit();
256 while l >> k < r {
257 c <<= 1;
258 if l & (1 << k) != 0 {
259 l += 1 << k;
260 c += 1;
261 }
262 if r & 1 != 0 {
263 r -= 1;
264 let nacc = M::agg_operate(&self.seg[r].0, &acc);
265 if p(&nacc) {
266 return Some(self.rbisect_perfect(r, acc, p).0);
267 }
268 acc = nacc;
269 }
270 r >>= 1;
271 k += 1;
272 }
273 for k in (0..k).rev() {
274 if c & 1 != 0 {
275 l -= 1 << k;
276 let l = l >> k;
277 let nacc = M::agg_operate(&self.seg[l].0, &acc);
278 if p(&nacc) {
279 return Some(self.rbisect_perfect(l, acc, p).0);
280 }
281 acc = nacc;
282 }
283 c >>= 1;
284 }
285 None
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    // Size of the randomized cross-checks: N elements, Q random operations per
    // scenario, values drawn from -A..A.
    const N: usize = 1_000;
    const Q: usize = 20_000;
    const A: i64 = 1_000_000_000;

    /// Randomized differential test. Every operation is mirrored on a plain
    /// `Vec` (`arr`), and each query result is compared against a naive scan.
    #[test]
    fn test_lazy_segment_tree() {
        let mut rng = Xorshift::default();
        // Scenario 1: range add / range sum. Leaves are `(value, 1)` and only
        // the `.0` component of a fold is compared with the naive sum. The `1`
        // presumably counts elements so an added constant can be scaled; TODO
        // confirm against `RangeSumRangeAdd`.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg =
            LazySegmentTree::<RangeSumRangeAdd<_>>::from_vec(arr.iter().map(|&a| (a, 1)).collect());
        for _ in 0..Q {
            rand!(rng, (l, r): NotEmptySegment(N));
            if rng.rand(2) == 0 {
                rand!(rng, x: -A..A);
                seg.update(l..r, x);
                for a in arr[l..r].iter_mut() {
                    *a += x;
                }
            } else {
                let res = arr[l..r].iter().sum();
                assert_eq!(seg.fold(l..r).0, res);
            }
        }

        // Scenario 2: range assign (`Some(x)`, mirrored as `*a = x`) / range
        // max, plus `position_acc` / `rposition_acc` checked against a
        // running-max scan from the left / right.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg = LazySegmentTree::<RangeMaxRangeUpdate<_>>::from_vec(arr.clone());
        for _ in 0..Q {
            rand!(rng, ty: 0..4, (l, r): NotEmptySegment(N));
            match ty {
                0 => {
                    rand!(rng, x: -A..A);
                    seg.update(l..r, Some(x));
                    arr[l..r].iter_mut().for_each(|a| *a = x);
                }
                1 => {
                    // `unwrap_or_default` is presumably never reached, since
                    // `NotEmptySegment` should only yield ranges with l < r.
                    let res = arr[l..r].iter().max().cloned().unwrap_or_default();
                    assert_eq!(seg.fold(l..r), res);
                }
                2 => {
                    // First index whose prefix maximum reaches `x`.
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.position_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| i + l),
                    );
                }
                _ => {
                    // Last index whose suffix maximum reaches `x`; `i` counts
                    // from the right end, so it is mapped back with `r - i - 1`.
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.rposition_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .rev()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| r - i - 1),
                    );
                }
            }
        }
    }
}