competitive/data_structure/
persistent_segment_tree.rs1use super::{Allocator, MemoryPool, Monoid, RangeBoundsExt};
2use std::{
3 fmt::{self, Debug, Formatter},
4 ops::{Range, RangeBounds},
5 ptr::NonNull,
6};
7
/// Nullable link to a [`Node`]; `None` stands for an absent subtree.
type NodePtr<T> = Option<NonNull<Node<T>>>;

/// One segment-tree node: two child links plus the value stored at the node.
struct Node<T> {
    /// Child links; slot `0` is the left child and slot `1` the right child.
    children: [NodePtr<T>; 2],
    /// Value stored at this node.
    value: T,
}

impl<T> Node<T> {
    /// Bundles the given child links and value into a node.
    fn new(children: [NodePtr<T>; 2], value: T) -> Self {
        Node { children, value }
    }
}
20
/// Handle naming one version of a [`PersistentSegmentTree`].
///
/// The handle only wraps the position of the version's root inside the tree's
/// root list, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[must_use]
pub struct PersistentSegmentTreeVersion(usize);

impl PersistentSegmentTreeVersion {
    /// Handle of the base version, which always has index `0`.
    fn base() -> Self {
        Self::new(0)
    }

    /// Wraps a raw version index into a handle.
    fn new(version_id: usize) -> Self {
        PersistentSegmentTreeVersion(version_id)
    }

    /// Returns the raw version index wrapped by this handle.
    fn index(self) -> usize {
        let Self(version_id) = self;
        version_id
    }
}
38
39pub struct PersistentSegmentTree<M>
40where
41 M: Monoid,
42{
43 len: usize,
44 version_roots: Vec<NodePtr<M::T>>,
45 allocator: MemoryPool<Node<M::T>>,
46}
47
48impl<M> Debug for PersistentSegmentTree<M>
49where
50 M: Monoid,
51{
52 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
53 f.debug_struct("PersistentSegmentTree")
54 .field("len", &self.len)
55 .field("versions", &self.version_roots.len())
56 .finish()
57 }
58}
59
60impl<M> PersistentSegmentTree<M>
61where
62 M: Monoid,
63{
64 #[must_use]
65 pub fn new(len: usize) -> Self {
66 Self {
67 len,
68 version_roots: vec![None],
69 allocator: MemoryPool::new(),
70 }
71 }
72
73 pub fn base(&self) -> PersistentSegmentTreeVersion {
74 PersistentSegmentTreeVersion::base()
75 }
76
77 pub fn len(&self) -> usize {
78 self.len
79 }
80
81 pub fn is_empty(&self) -> bool {
82 self.len == 0
83 }
84
85 fn version_root(&self, version: PersistentSegmentTreeVersion) -> NodePtr<M::T> {
86 *self
87 .version_roots
88 .get(version.index())
89 .expect("invalid version")
90 }
91
92 fn push_version_root(&mut self, root: NodePtr<M::T>) -> PersistentSegmentTreeVersion {
93 let version_id = self.version_roots.len();
94 self.version_roots.push(root);
95 PersistentSegmentTreeVersion::new(version_id)
96 }
97
98 fn allocate_node(&mut self, children: [NodePtr<M::T>; 2], value: M::T) -> NonNull<Node<M::T>> {
99 self.allocator.allocate(Node::new(children, value))
100 }
101
102 fn build_dfs(&mut self, start: usize, end: usize, values: &[M::T]) -> NodePtr<M::T> {
103 if end - start == 1 {
104 return self.leaf_node(values[start].clone());
105 }
106 let mid = (start + end) / 2;
107 let left = self.build_dfs(start, mid, values);
108 let right = self.build_dfs(mid, end, values);
109 self.merge_nodes(left, right)
110 }
111
112 fn leaf_node(&mut self, value: M::T) -> NodePtr<M::T> {
113 Some(self.allocate_node([None, None], value))
114 }
115
116 fn merge_nodes(&mut self, left: NodePtr<M::T>, right: NodePtr<M::T>) -> NodePtr<M::T> {
117 if left.is_none() && right.is_none() {
118 None
119 } else {
120 let value = M::operate(&Self::subtree_value(left), &Self::subtree_value(right));
121 Some(self.allocate_node([left, right], value))
122 }
123 }
124
125 fn subtree_value(node: NodePtr<M::T>) -> M::T {
126 node.map(|node| unsafe { node.as_ref().value.clone() })
127 .unwrap_or_else(M::unit)
128 }
129
130 fn children(node: NodePtr<M::T>) -> [NodePtr<M::T>; 2] {
131 node.map(|node| unsafe { node.as_ref().children })
132 .unwrap_or([None, None])
133 }
134
135 fn point_get_dfs(node: NodePtr<M::T>, start: usize, end: usize, index: usize) -> M::T {
136 let Some(node) = node else {
137 return M::unit();
138 };
139 let node = unsafe { node.as_ref() };
140 if end - start == 1 {
141 node.value.clone()
142 } else {
143 let mid = (start + end) / 2;
144 if index < mid {
145 Self::point_get_dfs(node.children[0], start, mid, index)
146 } else {
147 Self::point_get_dfs(node.children[1], mid, end, index)
148 }
149 }
150 }
151
152 fn fold_dfs(node: NodePtr<M::T>, start: usize, end: usize, range: &Range<usize>) -> M::T {
153 if range.end <= start || end <= range.start {
154 return M::unit();
155 }
156 let Some(node) = node else {
157 return M::unit();
158 };
159 let node = unsafe { node.as_ref() };
160 if range.start <= start && end <= range.end {
161 node.value.clone()
162 } else {
163 let mid = (start + end) / 2;
164 let left = Self::fold_dfs(node.children[0], start, mid, range);
165 let right = Self::fold_dfs(node.children[1], mid, end, range);
166 M::operate(&left, &right)
167 }
168 }
169
170 fn set_dfs(
171 &mut self,
172 node: NodePtr<M::T>,
173 start: usize,
174 end: usize,
175 index: usize,
176 value: &M::T,
177 ) -> NodePtr<M::T> {
178 if end - start == 1 {
179 return self.leaf_node(value.clone());
180 }
181 let mid = (start + end) / 2;
182 let mut children = Self::children(node);
183 if index < mid {
184 children[0] = self.set_dfs(children[0], start, mid, index, value);
185 } else {
186 children[1] = self.set_dfs(children[1], mid, end, index, value);
187 }
188 self.merge_nodes(children[0], children[1])
189 }
190
191 fn update_dfs(
192 &mut self,
193 node: NodePtr<M::T>,
194 start: usize,
195 end: usize,
196 index: usize,
197 value: &M::T,
198 ) -> NodePtr<M::T> {
199 if end - start == 1 {
200 return self.leaf_node(M::operate(&Self::subtree_value(node), value));
201 }
202 let mid = (start + end) / 2;
203 let mut children = Self::children(node);
204 if index < mid {
205 children[0] = self.update_dfs(children[0], start, mid, index, value);
206 } else {
207 children[1] = self.update_dfs(children[1], mid, end, index, value);
208 }
209 self.merge_nodes(children[0], children[1])
210 }
211
212 pub fn from_vec(&mut self, v: Vec<M::T>) -> PersistentSegmentTreeVersion {
213 assert_eq!(v.len(), self.len);
214 let root = if self.len == 0 {
215 None
216 } else {
217 self.build_dfs(0, self.len, &v)
218 };
219 self.push_version_root(root)
220 }
221
222 pub fn set(
223 &mut self,
224 version: PersistentSegmentTreeVersion,
225 index: usize,
226 value: M::T,
227 ) -> PersistentSegmentTreeVersion {
228 assert!(index < self.len);
229 let root = self.set_dfs(self.version_root(version), 0, self.len, index, &value);
230 self.push_version_root(root)
231 }
232
233 pub fn update(
234 &mut self,
235 version: PersistentSegmentTreeVersion,
236 index: usize,
237 value: M::T,
238 ) -> PersistentSegmentTreeVersion {
239 assert!(index < self.len);
240 let root = self.update_dfs(self.version_root(version), 0, self.len, index, &value);
241 self.push_version_root(root)
242 }
243
244 #[must_use]
245 pub fn get(&self, version: PersistentSegmentTreeVersion, index: usize) -> M::T {
246 assert!(index < self.len);
247 Self::point_get_dfs(self.version_root(version), 0, self.len, index)
248 }
249
250 #[must_use]
251 pub fn fold<R>(&self, version: PersistentSegmentTreeVersion, range: R) -> M::T
252 where
253 R: RangeBounds<usize>,
254 {
255 let range = range.to_range_bounded(0, self.len).expect("invalid range");
256 if range.is_empty() {
257 M::unit()
258 } else {
259 Self::fold_dfs(self.version_root(version), 0, self.len, &range)
260 }
261 }
262
263 #[must_use]
264 pub fn fold_all(&self, version: PersistentSegmentTreeVersion) -> M::T {
265 Self::subtree_value(self.version_root(version))
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::ConcatenateOperation,
        tools::{WithEmptySegment as Wes, Xorshift},
    };

    /// Number of positions in the tree under test.
    const N: usize = 12;
    /// Number of random modify-then-query rounds.
    const Q: usize = 2_000;
    /// Exclusive upper bound for the letters of a random word.
    const SIGMA: u8 = 6;

    /// Draws a random word of length `0..4` over the letters `0..SIGMA`.
    fn rand_word(rng: &mut Xorshift) -> Vec<u8> {
        let len = rng.random(0..4usize);
        (0..len).map(|_| rng.random(0..SIGMA)).collect()
    }

    /// Compares the tree against a plain per-version model under random point
    /// modifications; concatenation is used because it is not commutative.
    #[test]
    fn test_persistent_segment_tree_random_non_commutative() {
        let mut rng = Xorshift::default();
        let mut segtree: PersistentSegmentTree<ConcatenateOperation<u8>> =
            PersistentSegmentTree::new(N);
        let initial: Vec<_> = (0..N).map(|_| rand_word(&mut rng)).collect();

        // `versions[i]` and `states[i]` always describe the same version.
        let base_handle = segtree.base();
        let built_handle = segtree.from_vec(initial.clone());
        let mut versions = vec![base_handle, built_handle];
        let mut states = vec![vec![Vec::new(); N], initial];

        for _ in 0..Q {
            // Derive a new version from a random existing one.
            let parent = rng.random(0..versions.len());
            let position = rng.random(0..N);
            let overwrite = rng.gen_bool(0.5);
            let word = rand_word(&mut rng);

            let mut model = states[parent].clone();
            let child = if overwrite {
                model[position] = word.clone();
                segtree.set(versions[parent], position, word)
            } else {
                model[position].extend_from_slice(&word);
                segtree.update(versions[parent], position, word)
            };
            versions.push(child);
            states.push(model);

            // Query a random version, possibly an old one.
            let probe = rng.random(0..versions.len());
            let at = rng.random(0..N);
            let (lo, hi) = rng.random(Wes(N));
            let handle = versions[probe];
            let model = &states[probe];
            let expected_range: Vec<u8> = model[lo..hi].concat();
            let expected_total: Vec<u8> = model.concat();

            assert_eq!(segtree.get(handle, at), model[at]);
            assert_eq!(segtree.fold(handle, lo..hi), expected_range);
            assert_eq!(segtree.fold_all(handle), expected_total);
        }
    }
}
331}