competitive/data_structure/
segment_tree.rs1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3 fmt::{self, Debug, Formatter},
4 ops::RangeBounds,
5};
6
7pub struct SegmentTree<M>
8where
9 M: Monoid,
10{
11 n: usize,
12 seg: Vec<M::T>,
13}
14
15impl<M> Clone for SegmentTree<M>
16where
17 M: Monoid,
18{
19 fn clone(&self) -> Self {
20 Self {
21 n: self.n,
22 seg: self.seg.clone(),
23 }
24 }
25}
26
27impl<M> Debug for SegmentTree<M>
28where
29 M: Monoid<T: Debug>,
30{
31 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
32 f.debug_struct("SegmentTree")
33 .field("n", &self.n)
34 .field("seg", &self.seg)
35 .finish()
36 }
37}
38
39impl<M> SegmentTree<M>
40where
41 M: Monoid,
42{
43 pub fn new(n: usize) -> Self {
44 let seg = vec![M::unit(); 2 * n];
45 Self { n, seg }
46 }
47 pub fn from_vec(v: Vec<M::T>) -> Self {
48 let n = v.len();
49 let mut seg = vec![M::unit(); 2 * n];
50 for (i, x) in v.into_iter().enumerate() {
51 seg[n + i] = x;
52 }
53 for i in (1..n).rev() {
54 seg[i] = M::operate(&seg[2 * i], &seg[2 * i + 1]);
55 }
56 Self { n, seg }
57 }
58 pub fn set(&mut self, k: usize, x: M::T) {
59 assert!(k < self.n);
60 let mut k = k + self.n;
61 self.seg[k] = x;
62 k /= 2;
63 while k > 0 {
64 self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
65 k /= 2;
66 }
67 }
68 pub fn clear(&mut self, k: usize) {
69 self.set(k, M::unit());
70 }
71 pub fn update(&mut self, k: usize, x: M::T) {
72 assert!(k < self.n);
73 let mut k = k + self.n;
74 self.seg[k] = M::operate(&self.seg[k], &x);
75 k /= 2;
76 while k > 0 {
77 self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
78 k /= 2;
79 }
80 }
81 pub fn get(&self, k: usize) -> M::T {
82 assert!(k < self.n);
83 self.seg[k + self.n].clone()
84 }
85 pub fn fold<R>(&self, range: R) -> M::T
86 where
87 R: RangeBounds<usize>,
88 {
89 let range = range.to_range_bounded(0, self.n).expect("invalid range");
90 let mut l = range.start + self.n;
91 let mut r = range.end + self.n;
92 let mut vl = M::unit();
93 let mut vr = M::unit();
94 while l < r {
95 if l & 1 != 0 {
96 vl = M::operate(&vl, &self.seg[l]);
97 l += 1;
98 }
99 if r & 1 != 0 {
100 r -= 1;
101 vr = M::operate(&self.seg[r], &vr);
102 }
103 l /= 2;
104 r /= 2;
105 }
106 M::operate(&vl, &vr)
107 }
108 fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
109 where
110 F: Fn(&M::T) -> bool,
111 {
112 while pos < self.n {
113 pos <<= 1;
114 let nacc = M::operate(&acc, &self.seg[pos]);
115 if !f(&nacc) {
116 acc = nacc;
117 pos += 1;
118 }
119 }
120 (pos - self.n, acc)
121 }
122 fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
123 where
124 F: Fn(&M::T) -> bool,
125 {
126 while pos < self.n {
127 pos = pos * 2 + 1;
128 let nacc = M::operate(&self.seg[pos], &acc);
129 if !f(&nacc) {
130 acc = nacc;
131 pos -= 1;
132 }
133 }
134 (pos - self.n, acc)
135 }
136 pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
138 where
139 R: RangeBounds<usize>,
140 F: Fn(&M::T) -> bool,
141 {
142 let range = range.to_range_bounded(0, self.n).expect("invalid range");
143 let mut l = range.start + self.n;
144 let r = range.end + self.n;
145 let mut k = 0usize;
146 let mut acc = M::unit();
147 while l < r >> k {
148 if l & 1 != 0 {
149 let nacc = M::operate(&acc, &self.seg[l]);
150 if f(&nacc) {
151 return Some(self.bisect_perfect(l, acc, f).0);
152 }
153 acc = nacc;
154 l += 1;
155 }
156 l >>= 1;
157 k += 1;
158 }
159 for k in (0..k).rev() {
160 let r = r >> k;
161 if r & 1 != 0 {
162 let nacc = M::operate(&acc, &self.seg[r - 1]);
163 if f(&nacc) {
164 return Some(self.bisect_perfect(r - 1, acc, f).0);
165 }
166 acc = nacc;
167 }
168 }
169 None
170 }
171 pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
173 where
174 R: RangeBounds<usize>,
175 F: Fn(&M::T) -> bool,
176 {
177 let range = range.to_range_bounded(0, self.n).expect("invalid range");
178 let mut l = range.start + self.n;
179 let mut r = range.end + self.n;
180 let mut c = 0usize;
181 let mut k = 0usize;
182 let mut acc = M::unit();
183 while l >> k < r {
184 c <<= 1;
185 if l & (1 << k) != 0 {
186 l += 1 << k;
187 c += 1;
188 }
189 if r & 1 != 0 {
190 r -= 1;
191 let nacc = M::operate(&self.seg[r], &acc);
192 if f(&nacc) {
193 return Some(self.rbisect_perfect(r, acc, f).0);
194 }
195 acc = nacc;
196 }
197 r >>= 1;
198 k += 1;
199 }
200 for k in (0..k).rev() {
201 if c & 1 != 0 {
202 l -= 1 << k;
203 let l = l >> k;
204 let nacc = M::operate(&self.seg[l], &acc);
205 if f(&nacc) {
206 return Some(self.rbisect_perfect(l, acc, f).0);
207 }
208 acc = nacc;
209 }
210 c >>= 1;
211 }
212 None
213 }
214 pub fn as_slice(&self) -> &[M::T] {
215 &self.seg[self.n..]
216 }
217}
218impl<M> SegmentTree<M>
219where
220 M: AbelianMonoid,
221{
222 pub fn fold_all(&self) -> M::T {
223 self.seg[1].clone()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    /// Checks `set`, `fold`, `position_acc` and `rposition_acc` against prefix sums, and
    /// `from_vec` + `set` + `fold` for a max monoid against a naive scan.
    #[test]
    fn test_segment_tree() {
        let mut rng = Xorshift::default();
        // `arr` becomes prefix sums: arr[i] = sum of the first i elements.
        let mut arr = vec![0; N + 1];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            arr[k + 1] = v;
        }
        for i in 0..N {
            arr[i + 1] += arr[i];
        }
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
            }
        }
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
                arr[1..].position_bisect(|&x| x >= v)
            );
        }
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                arr[l + 1..r + 1].position_bisect(|&x| x - arr[l] >= v) + l
            );
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
            );
        }

        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTree::<MaxOperation<_>>::from_vec(arr.clone());
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
            assert_eq!(seg.fold(l..r), res);
        }
    }

    /// Checks `update`, `clear`, `get` and `as_slice` against a plain array. Also checks that
    /// a tree rebuilt with `from_vec` agrees with the incrementally updated one.
    #[test]
    fn test_point_update_and_access() {
        let mut rng = Xorshift::default();
        let mut arr = vec![0i64; N];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.update(k, v);
            arr[k] += v;
        }
        // Reset a deterministic subset of positions back to the identity.
        for k in (0..N).step_by(7) {
            seg.clear(k);
            arr[k] = 0;
        }
        for (k, &x) in arr.iter().enumerate() {
            assert_eq!(seg.get(k), x);
        }
        assert_eq!(seg.as_slice(), &arr[..]);
        assert_eq!(seg.fold(..), arr.iter().sum::<i64>());
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            assert_eq!(seg.fold(l..r), arr[l..r].iter().sum::<i64>());
        }
        let rebuilt = SegmentTree::<AdditiveOperation<_>>::from_vec(arr.clone());
        assert_eq!(rebuilt.as_slice(), seg.as_slice());
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            assert_eq!(rebuilt.fold(l..r), seg.fold(l..r));
        }
    }
}