competitive/data_structure/
lazy_segment_tree.rs1use super::{MonoidAction, RangeBoundsExt};
2use std::{
3 fmt::{self, Debug, Formatter},
4 mem::replace,
5 ops::RangeBounds,
6};
7
8pub struct LazySegmentTree<M>
9where
10 M: MonoidAction,
11{
12 n: usize,
13 seg: Vec<(M::Agg, M::Act)>,
14}
15
16impl<M> Clone for LazySegmentTree<M>
17where
18 M: MonoidAction,
19{
20 fn clone(&self) -> Self {
21 Self {
22 n: self.n,
23 seg: self.seg.clone(),
24 }
25 }
26}
27
28impl<M> Debug for LazySegmentTree<M>
29where
30 M: MonoidAction,
31 M::Agg: Debug,
32 M::Act: Debug,
33{
34 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35 f.debug_struct("LazySegmentTree")
36 .field("n", &self.n)
37 .field("seg", &self.seg)
38 .finish()
39 }
40}
41
42impl<M> LazySegmentTree<M>
43where
44 M: MonoidAction,
45{
46 pub fn new(n: usize) -> Self {
47 let seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
48 Self { n, seg }
49 }
50 pub fn from_vec(v: Vec<M::Agg>) -> Self {
51 let n = v.len();
52 let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
53 for (i, x) in v.into_iter().enumerate() {
54 seg[i + n].0 = x;
55 }
56 for i in (1..n).rev() {
57 seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
58 }
59 Self { n, seg }
60 }
61 pub fn from_keys(keys: impl ExactSizeIterator<Item = M::Key>) -> Self {
62 let n = keys.len();
63 let mut seg = vec![(M::agg_unit(), M::act_unit()); 2 * n];
64 for (i, key) in keys.enumerate() {
65 seg[i + n].0 = M::single_agg(&key);
66 }
67 for i in (1..n).rev() {
68 seg[i].0 = M::agg_operate(&seg[2 * i].0, &seg[2 * i + 1].0);
69 }
70 Self { n, seg }
71 }
72 #[inline]
73 fn update_at(&mut self, k: usize, x: &M::Act) {
74 let nx = M::act_agg(&self.seg[k].0, x);
75 if k < self.n {
76 self.seg[k].1 = M::act_operate(&self.seg[k].1, x);
77 }
78 if let Some(nx) = nx {
79 self.seg[k].0 = nx;
80 } else if k < self.n {
81 self.propagate_at(k);
82 self.recalc_at(k);
83 } else {
84 panic!("act failed on leaf");
85 }
86 }
87 #[inline]
88 fn recalc_at(&mut self, k: usize) {
89 self.seg[k].0 = M::agg_operate(&self.seg[2 * k].0, &self.seg[2 * k + 1].0);
90 }
91 #[inline]
92 fn propagate_at(&mut self, k: usize) {
93 debug_assert!(k < self.n);
94 let x = replace(&mut self.seg[k].1, M::act_unit());
95 self.update_at(2 * k, &x);
96 self.update_at(2 * k + 1, &x);
97 }
98 #[inline]
99 fn propagate(&mut self, k: usize, right: bool, nofilt: bool) {
100 let right = right as usize;
101 for i in (1..(k + 1 - right).next_power_of_two().trailing_zeros()).rev() {
102 if nofilt || (k >> i) << i != k {
103 self.propagate_at((k - right) >> i);
104 }
105 }
106 }
107 #[inline]
108 fn recalc(&mut self, k: usize, right: bool, nofilt: bool) {
109 let right = right as usize;
110 for i in 1..(k + 1 - right).next_power_of_two().trailing_zeros() {
111 if nofilt || (k >> i) << i != k {
112 self.recalc_at((k - right) >> i);
113 }
114 }
115 }
116 pub fn update<R>(&mut self, range: R, x: M::Act)
117 where
118 R: RangeBounds<usize>,
119 {
120 let range = range.to_range_bounded(0, self.n).expect("invalid range");
121 let mut a = range.start + self.n;
122 let mut b = range.end + self.n;
123 self.propagate(a, false, false);
124 self.propagate(b, true, false);
125 while a < b {
126 if a & 1 != 0 {
127 self.update_at(a, &x);
128 a += 1;
129 }
130 if b & 1 != 0 {
131 b -= 1;
132 self.update_at(b, &x);
133 }
134 a /= 2;
135 b /= 2;
136 }
137 self.recalc(range.start + self.n, false, false);
138 self.recalc(range.end + self.n, true, false);
139 }
140 pub fn fold<R>(&mut self, range: R) -> M::Agg
141 where
142 R: RangeBounds<usize>,
143 {
144 let range = range.to_range_bounded(0, self.n).expect("invalid range");
145 let mut l = range.start + self.n;
146 let mut r = range.end + self.n;
147 self.propagate(l, false, true);
148 self.propagate(r, true, true);
149 let mut vl = M::agg_unit();
150 let mut vr = M::agg_unit();
151 while l < r {
152 if l & 1 != 0 {
153 vl = M::agg_operate(&vl, &self.seg[l].0);
154 l += 1;
155 }
156 if r & 1 != 0 {
157 r -= 1;
158 vr = M::agg_operate(&self.seg[r].0, &vr);
159 }
160 l /= 2;
161 r /= 2;
162 }
163 M::agg_operate(&vl, &vr)
164 }
165 pub fn set(&mut self, k: usize, x: M::Agg) {
166 let k = k + self.n;
167 self.propagate(k, false, true);
168 self.seg[k] = (x, M::act_unit());
169 self.recalc(k, false, true);
170 }
171 pub fn get(&mut self, k: usize) -> M::Agg {
172 self.fold(k..k + 1)
173 }
174 pub fn fold_all(&mut self) -> M::Agg {
175 self.fold(0..self.n)
176 }
177 fn bisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
178 where
179 P: Fn(&M::Agg) -> bool,
180 {
181 while pos < self.n {
182 self.propagate_at(pos);
183 pos <<= 1;
184 let nacc = M::agg_operate(&acc, &self.seg[pos].0);
185 if !p(&nacc) {
186 acc = nacc;
187 pos += 1;
188 }
189 }
190 (pos - self.n, acc)
191 }
192 fn rbisect_perfect<P>(&mut self, mut pos: usize, mut acc: M::Agg, p: P) -> (usize, M::Agg)
193 where
194 P: Fn(&M::Agg) -> bool,
195 {
196 while pos < self.n {
197 self.propagate_at(pos);
198 pos = pos * 2 + 1;
199 let nacc = M::agg_operate(&self.seg[pos].0, &acc);
200 if !p(&nacc) {
201 acc = nacc;
202 pos -= 1;
203 }
204 }
205 (pos - self.n, acc)
206 }
207 pub fn position_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
209 where
210 R: RangeBounds<usize>,
211 P: Fn(&M::Agg) -> bool,
212 {
213 let range = range.to_range_bounded(0, self.n).expect("invalid range");
214 let mut l = range.start + self.n;
215 let r = range.end + self.n;
216 self.propagate(l, false, true);
217 self.propagate(r, true, true);
218 let mut k = 0usize;
219 let mut acc = M::agg_unit();
220 while l < r >> k {
221 if l & 1 != 0 {
222 let nacc = M::agg_operate(&acc, &self.seg[l].0);
223 if p(&nacc) {
224 return Some(self.bisect_perfect(l, acc, p).0);
225 }
226 acc = nacc;
227 l += 1;
228 }
229 l >>= 1;
230 k += 1;
231 }
232 for k in (0..k).rev() {
233 let r = r >> k;
234 if r & 1 != 0 {
235 let nacc = M::agg_operate(&acc, &self.seg[r - 1].0);
236 if p(&nacc) {
237 return Some(self.bisect_perfect(r - 1, acc, p).0);
238 }
239 acc = nacc;
240 }
241 }
242 None
243 }
244 pub fn rposition_acc<R, P>(&mut self, range: R, p: P) -> Option<usize>
246 where
247 R: RangeBounds<usize>,
248 P: Fn(&M::Agg) -> bool,
249 {
250 let range = range.to_range_bounded(0, self.n).expect("invalid range");
251 let mut l = range.start + self.n;
252 let mut r = range.end + self.n;
253 self.propagate(l, false, true);
254 self.propagate(r, true, true);
255 let mut c = 0usize;
256 let mut k = 0usize;
257 let mut acc = M::agg_unit();
258 while l >> k < r {
259 c <<= 1;
260 if l & (1 << k) != 0 {
261 l += 1 << k;
262 c += 1;
263 }
264 if r & 1 != 0 {
265 r -= 1;
266 let nacc = M::agg_operate(&self.seg[r].0, &acc);
267 if p(&nacc) {
268 return Some(self.rbisect_perfect(r, acc, p).0);
269 }
270 acc = nacc;
271 }
272 r >>= 1;
273 k += 1;
274 }
275 for k in (0..k).rev() {
276 if c & 1 != 0 {
277 l -= 1 << k;
278 let l = l >> k;
279 let nacc = M::agg_operate(&self.seg[l].0, &acc);
280 if p(&nacc) {
281 return Some(self.rbisect_perfect(l, acc, p).0);
282 }
283 acc = nacc;
284 }
285 c >>= 1;
286 }
287 None
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{RangeMaxRangeUpdate, RangeSumRangeAdd},
        rand,
        tools::{NotEmptySegment, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 20_000;
    const A: i64 = 1_000_000_000;

    /// Randomized differential test against a plain `Vec` model.
    ///
    /// Covers range add / range sum, then range assign / range max together
    /// with prefix and suffix searches and point `set` / `get`.
    #[test]
    fn test_lazy_segment_tree() {
        let mut rng = Xorshift::default();

        // Range add / range sum. Each leaf is `(value, 1)`, so the second
        // component of a fold is the segment length.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg =
            LazySegmentTree::<RangeSumRangeAdd<_>>::from_vec(arr.iter().map(|&a| (a, 1)).collect());
        for _ in 0..Q {
            rand!(rng, (l, r): NotEmptySegment(N));
            if rng.rand(2) == 0 {
                rand!(rng, x: -A..A);
                seg.update(l..r, x);
                for a in arr[l..r].iter_mut() {
                    *a += x;
                }
            } else {
                let res = arr[l..r].iter().sum();
                assert_eq!(seg.fold(l..r).0, res);
            }
        }

        // Range assign / range max, searches, and point set / get.
        rand!(rng, mut arr: [-A..A; N]);
        let mut seg = LazySegmentTree::<RangeMaxRangeUpdate<_>>::from_vec(arr.clone());
        for _ in 0..Q {
            rand!(rng, ty: 0..6, (l, r): NotEmptySegment(N));
            match ty {
                0 => {
                    rand!(rng, x: -A..A);
                    seg.update(l..r, Some(x));
                    arr[l..r].iter_mut().for_each(|a| *a = x);
                }
                1 => {
                    let res = arr[l..r].iter().max().cloned().unwrap_or_default();
                    assert_eq!(seg.fold(l..r), res);
                }
                2 => {
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.position_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| i + l),
                    );
                }
                3 => {
                    rand!(rng, x: -A..A);
                    assert_eq!(
                        seg.rposition_acc(l..r, |&d| d >= x),
                        arr[l..r]
                            .iter()
                            .rev()
                            .scan(i64::MIN, |acc, &a| {
                                *acc = a.max(*acc);
                                Some(*acc)
                            })
                            .position(|acc| acc >= x)
                            .map(|i| r - i - 1),
                    );
                }
                4 => {
                    // Point assignment. `l < r <= N`, so `l` is a valid index.
                    rand!(rng, x: -A..A);
                    seg.set(l, x);
                    arr[l] = x;
                }
                _ => {
                    assert_eq!(seg.get(l), arr[l]);
                }
            }
        }
        assert_eq!(seg.fold_all(), *arr.iter().max().unwrap());
    }
}