competitive/data_structure/
segment_tree.rs1use super::{AbelianMonoid, Monoid, RangeBoundsExt};
2use std::{
3 fmt::{self, Debug, Formatter},
4 ops::RangeBounds,
5};
6
7pub struct SegmentTree<M>
8where
9 M: Monoid,
10{
11 n: usize,
12 seg: Vec<M::T>,
13}
14
15impl<M> Clone for SegmentTree<M>
16where
17 M: Monoid,
18{
19 fn clone(&self) -> Self {
20 Self {
21 n: self.n,
22 seg: self.seg.clone(),
23 }
24 }
25}
26
27impl<M> Debug for SegmentTree<M>
28where
29 M: Monoid,
30 M::T: Debug,
31{
32 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
33 f.debug_struct("SegmentTree")
34 .field("n", &self.n)
35 .field("seg", &self.seg)
36 .finish()
37 }
38}
39
40impl<M> SegmentTree<M>
41where
42 M: Monoid,
43{
44 pub fn new(n: usize) -> Self {
45 let seg = vec![M::unit(); 2 * n];
46 Self { n, seg }
47 }
48 pub fn from_vec(v: Vec<M::T>) -> Self {
49 let n = v.len();
50 let mut seg = vec![M::unit(); 2 * n];
51 for (i, x) in v.into_iter().enumerate() {
52 seg[n + i] = x;
53 }
54 for i in (1..n).rev() {
55 seg[i] = M::operate(&seg[2 * i], &seg[2 * i + 1]);
56 }
57 Self { n, seg }
58 }
59 pub fn set(&mut self, k: usize, x: M::T) {
60 assert!(k < self.n);
61 let mut k = k + self.n;
62 self.seg[k] = x;
63 k /= 2;
64 while k > 0 {
65 self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
66 k /= 2;
67 }
68 }
69 pub fn clear(&mut self, k: usize) {
70 self.set(k, M::unit());
71 }
72 pub fn update(&mut self, k: usize, x: M::T) {
73 assert!(k < self.n);
74 let mut k = k + self.n;
75 self.seg[k] = M::operate(&self.seg[k], &x);
76 k /= 2;
77 while k > 0 {
78 self.seg[k] = M::operate(&self.seg[2 * k], &self.seg[2 * k + 1]);
79 k /= 2;
80 }
81 }
82 pub fn get(&self, k: usize) -> M::T {
83 assert!(k < self.n);
84 self.seg[k + self.n].clone()
85 }
86 pub fn fold<R>(&self, range: R) -> M::T
87 where
88 R: RangeBounds<usize>,
89 {
90 let range = range.to_range_bounded(0, self.n).expect("invalid range");
91 let mut l = range.start + self.n;
92 let mut r = range.end + self.n;
93 let mut vl = M::unit();
94 let mut vr = M::unit();
95 while l < r {
96 if l & 1 != 0 {
97 vl = M::operate(&vl, &self.seg[l]);
98 l += 1;
99 }
100 if r & 1 != 0 {
101 r -= 1;
102 vr = M::operate(&self.seg[r], &vr);
103 }
104 l /= 2;
105 r /= 2;
106 }
107 M::operate(&vl, &vr)
108 }
109 fn bisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
110 where
111 F: Fn(&M::T) -> bool,
112 {
113 while pos < self.n {
114 pos <<= 1;
115 let nacc = M::operate(&acc, &self.seg[pos]);
116 if !f(&nacc) {
117 acc = nacc;
118 pos += 1;
119 }
120 }
121 (pos - self.n, acc)
122 }
123 fn rbisect_perfect<F>(&self, mut pos: usize, mut acc: M::T, f: F) -> (usize, M::T)
124 where
125 F: Fn(&M::T) -> bool,
126 {
127 while pos < self.n {
128 pos = pos * 2 + 1;
129 let nacc = M::operate(&self.seg[pos], &acc);
130 if !f(&nacc) {
131 acc = nacc;
132 pos -= 1;
133 }
134 }
135 (pos - self.n, acc)
136 }
137 pub fn position_acc<R, F>(&self, range: R, f: F) -> Option<usize>
139 where
140 R: RangeBounds<usize>,
141 F: Fn(&M::T) -> bool,
142 {
143 let range = range.to_range_bounded(0, self.n).expect("invalid range");
144 let mut l = range.start + self.n;
145 let r = range.end + self.n;
146 let mut k = 0usize;
147 let mut acc = M::unit();
148 while l < r >> k {
149 if l & 1 != 0 {
150 let nacc = M::operate(&acc, &self.seg[l]);
151 if f(&nacc) {
152 return Some(self.bisect_perfect(l, acc, f).0);
153 }
154 acc = nacc;
155 l += 1;
156 }
157 l >>= 1;
158 k += 1;
159 }
160 for k in (0..k).rev() {
161 let r = r >> k;
162 if r & 1 != 0 {
163 let nacc = M::operate(&acc, &self.seg[r - 1]);
164 if f(&nacc) {
165 return Some(self.bisect_perfect(r - 1, acc, f).0);
166 }
167 acc = nacc;
168 }
169 }
170 None
171 }
172 pub fn rposition_acc<R, F>(&self, range: R, f: F) -> Option<usize>
174 where
175 R: RangeBounds<usize>,
176 F: Fn(&M::T) -> bool,
177 {
178 let range = range.to_range_bounded(0, self.n).expect("invalid range");
179 let mut l = range.start + self.n;
180 let mut r = range.end + self.n;
181 let mut c = 0usize;
182 let mut k = 0usize;
183 let mut acc = M::unit();
184 while l >> k < r {
185 c <<= 1;
186 if l & (1 << k) != 0 {
187 l += 1 << k;
188 c += 1;
189 }
190 if r & 1 != 0 {
191 r -= 1;
192 let nacc = M::operate(&self.seg[r], &acc);
193 if f(&nacc) {
194 return Some(self.rbisect_perfect(r, acc, f).0);
195 }
196 acc = nacc;
197 }
198 r >>= 1;
199 k += 1;
200 }
201 for k in (0..k).rev() {
202 if c & 1 != 0 {
203 l -= 1 << k;
204 let l = l >> k;
205 let nacc = M::operate(&self.seg[l], &acc);
206 if f(&nacc) {
207 return Some(self.rbisect_perfect(l, acc, f).0);
208 }
209 acc = nacc;
210 }
211 c >>= 1;
212 }
213 None
214 }
215 pub fn as_slice(&self) -> &[M::T] {
216 &self.seg[self.n..]
217 }
218}
219impl<M> SegmentTree<M>
220where
221 M: AbelianMonoid,
222{
223 pub fn fold_all(&self) -> M::T {
224 self.seg[1].clone()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, MaxOperation},
        algorithm::SliceBisectExt as _,
        rand,
        tools::{NotEmptySegment as Nes, Xorshift},
    };

    const N: usize = 1_000;
    const Q: usize = 10_000;
    const A: i64 = 1_000_000_000;

    /// Checks `set`, `fold`, `position_acc` and `rposition_acc` against prefix
    /// sums, and `from_vec`/`set`/`fold` with max against a brute-force scan.
    #[test]
    fn test_segment_tree() {
        let mut rng = Xorshift::new();
        let mut arr = vec![0; N + 1];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, 1..=A)).take(Q) {
            seg.set(k, v);
            arr[k + 1] = v;
        }
        for i in 0..N {
            arr[i + 1] += arr[i];
        }
        for i in 0..N {
            for j in i + 1..N + 1 {
                assert_eq!(seg.fold(i..j), arr[j] - arr[i]);
            }
        }
        for v in rng.random_iter(1..=A * N as i64).take(Q) {
            assert_eq!(
                seg.position_acc(0..N, |&x| v <= x).unwrap_or(N),
                arr[1..].position_bisect(|&x| x >= v)
            );
        }
        for ((l, r), v) in rng.random_iter((Nes(N), 1..=A)).take(Q) {
            assert_eq!(
                seg.position_acc(l..r, |&x| v <= x).unwrap_or(r),
                arr[l + 1..r + 1].position_bisect(|&x| x - arr[l] >= v) + l
            );
            assert_eq!(
                seg.rposition_acc(l..r, |&x| v <= x).map_or(l, |i| i + 1),
                arr[l..r].rposition_bisect(|&x| arr[r] - x >= v) + l
            );
        }

        rand!(rng, mut arr: [-A..=A; N]);
        let mut seg = SegmentTree::<MaxOperation<_>>::from_vec(arr.clone());
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.set(k, v);
            arr[k] = v;
        }
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            let res = arr[l..r].iter().max().cloned().unwrap_or_default();
            assert_eq!(seg.fold(l..r), res);
        }
    }

    /// Checks the point operations `update`, `get`, `clear` and `as_slice`,
    /// none of which the test above exercises, against a plain array.
    #[test]
    fn test_segment_tree_point_operations() {
        let mut rng = Xorshift::new();
        let mut arr = vec![0i64; N];
        let mut seg = SegmentTree::<AdditiveOperation<_>>::new(N);
        for (k, v) in rng.random_iter((..N, -A..=A)).take(Q) {
            seg.update(k, v);
            arr[k] += v;
            assert_eq!(seg.get(k), arr[k]);
        }
        for k in (0..N).step_by(7) {
            seg.clear(k);
            arr[k] = 0;
            assert_eq!(seg.get(k), 0);
        }
        assert_eq!(seg.as_slice(), &arr[..]);
        for (l, r) in rng.random_iter(Nes(N)).take(Q) {
            assert_eq!(seg.fold(l..r), arr[l..r].iter().sum::<i64>());
        }
    }
}