Skip to main content

competitive/data_structure/
persistent_segment_tree.rs

1use super::{Allocator, MemoryPool, Monoid, RangeBoundsExt};
2use std::{
3    fmt::{self, Debug, Formatter},
4    ops::{Range, RangeBounds},
5    ptr::NonNull,
6};
7
/// Nullable link to a tree node; `None` stands for an absent subtree.
type NodePtr<T> = Option<NonNull<Node<T>>>;

/// One node of the tree: links to its two children plus the value stored
/// for the segment the node represents.
struct Node<T> {
    /// `children[0]` is the left child, `children[1]` the right child.
    children: [NodePtr<T>; 2],
    /// Value stored at this node.
    value: T,
}

impl<T> Node<T> {
    /// Bundles the given child links and value into a node.
    fn new(children: [NodePtr<T>; 2], value: T) -> Self {
        Node { children, value }
    }
}
20
/// Opaque handle naming one version of a [`PersistentSegmentTree`].
///
/// The wrapped number is the position of that version's root in the tree's
/// list of version roots.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[must_use]
pub struct PersistentSegmentTreeVersion(usize);

impl PersistentSegmentTreeVersion {
    /// Handle with id `0`, the slot a tree fills at construction time.
    fn base() -> Self {
        Self::new(0)
    }

    /// Wraps a raw version id.
    fn new(version_id: usize) -> Self {
        PersistentSegmentTreeVersion(version_id)
    }

    /// Unwraps the raw version id.
    fn index(self) -> usize {
        let Self(version_id) = self;
        version_id
    }
}
38
39pub struct PersistentSegmentTree<M>
40where
41    M: Monoid,
42{
43    len: usize,
44    version_roots: Vec<NodePtr<M::T>>,
45    allocator: MemoryPool<Node<M::T>>,
46}
47
48impl<M> Debug for PersistentSegmentTree<M>
49where
50    M: Monoid,
51{
52    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
53        f.debug_struct("PersistentSegmentTree")
54            .field("len", &self.len)
55            .field("versions", &self.version_roots.len())
56            .finish()
57    }
58}
59
60impl<M> PersistentSegmentTree<M>
61where
62    M: Monoid,
63{
64    #[must_use]
65    pub fn new(len: usize) -> Self {
66        Self {
67            len,
68            version_roots: vec![None],
69            allocator: MemoryPool::new(),
70        }
71    }
72
73    pub fn base(&self) -> PersistentSegmentTreeVersion {
74        PersistentSegmentTreeVersion::base()
75    }
76
77    pub fn len(&self) -> usize {
78        self.len
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.len == 0
83    }
84
85    fn version_root(&self, version: PersistentSegmentTreeVersion) -> NodePtr<M::T> {
86        *self
87            .version_roots
88            .get(version.index())
89            .expect("invalid version")
90    }
91
92    fn push_version_root(&mut self, root: NodePtr<M::T>) -> PersistentSegmentTreeVersion {
93        let version_id = self.version_roots.len();
94        self.version_roots.push(root);
95        PersistentSegmentTreeVersion::new(version_id)
96    }
97
98    fn allocate_node(&mut self, children: [NodePtr<M::T>; 2], value: M::T) -> NonNull<Node<M::T>> {
99        self.allocator.allocate(Node::new(children, value))
100    }
101
102    fn build_dfs(&mut self, start: usize, end: usize, values: &[M::T]) -> NodePtr<M::T> {
103        if end - start == 1 {
104            return self.leaf_node(values[start].clone());
105        }
106        let mid = (start + end) / 2;
107        let left = self.build_dfs(start, mid, values);
108        let right = self.build_dfs(mid, end, values);
109        self.merge_nodes(left, right)
110    }
111
112    fn leaf_node(&mut self, value: M::T) -> NodePtr<M::T> {
113        Some(self.allocate_node([None, None], value))
114    }
115
116    fn merge_nodes(&mut self, left: NodePtr<M::T>, right: NodePtr<M::T>) -> NodePtr<M::T> {
117        if left.is_none() && right.is_none() {
118            None
119        } else {
120            let value = M::operate(&Self::subtree_value(left), &Self::subtree_value(right));
121            Some(self.allocate_node([left, right], value))
122        }
123    }
124
125    fn subtree_value(node: NodePtr<M::T>) -> M::T {
126        node.map(|node| unsafe { node.as_ref().value.clone() })
127            .unwrap_or_else(M::unit)
128    }
129
130    fn children(node: NodePtr<M::T>) -> [NodePtr<M::T>; 2] {
131        node.map(|node| unsafe { node.as_ref().children })
132            .unwrap_or([None, None])
133    }
134
135    fn point_get_dfs(node: NodePtr<M::T>, start: usize, end: usize, index: usize) -> M::T {
136        let Some(node) = node else {
137            return M::unit();
138        };
139        let node = unsafe { node.as_ref() };
140        if end - start == 1 {
141            node.value.clone()
142        } else {
143            let mid = (start + end) / 2;
144            if index < mid {
145                Self::point_get_dfs(node.children[0], start, mid, index)
146            } else {
147                Self::point_get_dfs(node.children[1], mid, end, index)
148            }
149        }
150    }
151
152    fn fold_dfs(node: NodePtr<M::T>, start: usize, end: usize, range: &Range<usize>) -> M::T {
153        if range.end <= start || end <= range.start {
154            return M::unit();
155        }
156        let Some(node) = node else {
157            return M::unit();
158        };
159        let node = unsafe { node.as_ref() };
160        if range.start <= start && end <= range.end {
161            node.value.clone()
162        } else {
163            let mid = (start + end) / 2;
164            let left = Self::fold_dfs(node.children[0], start, mid, range);
165            let right = Self::fold_dfs(node.children[1], mid, end, range);
166            M::operate(&left, &right)
167        }
168    }
169
170    fn set_dfs(
171        &mut self,
172        node: NodePtr<M::T>,
173        start: usize,
174        end: usize,
175        index: usize,
176        value: &M::T,
177    ) -> NodePtr<M::T> {
178        if end - start == 1 {
179            return self.leaf_node(value.clone());
180        }
181        let mid = (start + end) / 2;
182        let mut children = Self::children(node);
183        if index < mid {
184            children[0] = self.set_dfs(children[0], start, mid, index, value);
185        } else {
186            children[1] = self.set_dfs(children[1], mid, end, index, value);
187        }
188        self.merge_nodes(children[0], children[1])
189    }
190
191    fn update_dfs(
192        &mut self,
193        node: NodePtr<M::T>,
194        start: usize,
195        end: usize,
196        index: usize,
197        value: &M::T,
198    ) -> NodePtr<M::T> {
199        if end - start == 1 {
200            return self.leaf_node(M::operate(&Self::subtree_value(node), value));
201        }
202        let mid = (start + end) / 2;
203        let mut children = Self::children(node);
204        if index < mid {
205            children[0] = self.update_dfs(children[0], start, mid, index, value);
206        } else {
207            children[1] = self.update_dfs(children[1], mid, end, index, value);
208        }
209        self.merge_nodes(children[0], children[1])
210    }
211
212    pub fn from_vec(&mut self, v: Vec<M::T>) -> PersistentSegmentTreeVersion {
213        assert_eq!(v.len(), self.len);
214        let root = if self.len == 0 {
215            None
216        } else {
217            self.build_dfs(0, self.len, &v)
218        };
219        self.push_version_root(root)
220    }
221
222    pub fn set(
223        &mut self,
224        version: PersistentSegmentTreeVersion,
225        index: usize,
226        value: M::T,
227    ) -> PersistentSegmentTreeVersion {
228        assert!(index < self.len);
229        let root = self.set_dfs(self.version_root(version), 0, self.len, index, &value);
230        self.push_version_root(root)
231    }
232
233    pub fn update(
234        &mut self,
235        version: PersistentSegmentTreeVersion,
236        index: usize,
237        value: M::T,
238    ) -> PersistentSegmentTreeVersion {
239        assert!(index < self.len);
240        let root = self.update_dfs(self.version_root(version), 0, self.len, index, &value);
241        self.push_version_root(root)
242    }
243
244    #[must_use]
245    pub fn get(&self, version: PersistentSegmentTreeVersion, index: usize) -> M::T {
246        assert!(index < self.len);
247        Self::point_get_dfs(self.version_root(version), 0, self.len, index)
248    }
249
250    #[must_use]
251    pub fn fold<R>(&self, version: PersistentSegmentTreeVersion, range: R) -> M::T
252    where
253        R: RangeBounds<usize>,
254    {
255        let range = range.to_range_bounded(0, self.len).expect("invalid range");
256        if range.is_empty() {
257            M::unit()
258        } else {
259            Self::fold_dfs(self.version_root(version), 0, self.len, &range)
260        }
261    }
262
263    #[must_use]
264    pub fn fold_all(&self, version: PersistentSegmentTreeVersion) -> M::T {
265        Self::subtree_value(self.version_root(version))
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::ConcatenateOperation,
        tools::{WithEmptySegment as Wes, Xorshift},
    };

    /// Number of positions in the tree under test.
    const N: usize = 12;
    /// Number of random update/query rounds.
    const Q: usize = 2_000;
    /// Alphabet size for the random words.
    const SIGMA: u8 = 6;

    /// Draws a random word: first its length from `0..4`, then each letter
    /// from `0..SIGMA`.
    fn rand_word(rng: &mut Xorshift) -> Vec<u8> {
        let len = rng.random(0..4usize);
        let mut word = Vec::new();
        for _ in 0..len {
            word.push(rng.random(0..SIGMA));
        }
        word
    }

    /// Concatenates `words` in order into a single byte string.
    fn concat(words: &[Vec<u8>]) -> Vec<u8> {
        let mut joined = Vec::new();
        for word in words {
            joined.extend_from_slice(word);
        }
        joined
    }

    /// Checks `set`, `update`, `get`, `fold` and `fold_all` against a naive
    /// per-version copy of the array, using concatenation so that operand
    /// order matters.
    #[test]
    fn test_persistent_segment_tree_random_non_commutative() {
        let mut rng = Xorshift::default();
        let mut segtree = PersistentSegmentTree::<ConcatenateOperation<u8>>::new(N);
        let initial: Vec<Vec<u8>> = (0..N).map(|_| rand_word(&mut rng)).collect();
        let base = segtree.base();
        let built = segtree.from_vec(initial.clone());
        let mut versions = vec![base, built];
        // `states[i]` is the naive array mirroring `versions[i]`.
        let mut states = vec![vec![Vec::new(); N], initial];

        for _ in 0..Q {
            // Branch a new version off a random existing one.
            let base_version = rng.random(0..versions.len());
            let index = rng.random(0..N);
            let mut state = states[base_version].clone();

            let overwrite = rng.gen_bool(0.5);
            let value = rand_word(&mut rng);
            let created = if overwrite {
                state[index] = value.clone();
                segtree.set(versions[base_version], index, value)
            } else {
                state[index].extend_from_slice(&value);
                segtree.update(versions[base_version], index, value)
            };
            versions.push(created);
            states.push(state);

            // Query a random version (possibly an old one) and compare.
            let version = rng.random(0..versions.len());
            let index = rng.random(0..N);
            let (start, end) = rng.random(Wes(N));
            let snapshot = &states[version];
            let handle = versions[version];
            let expected = concat(&snapshot[start..end]);
            let expected_all = concat(snapshot);

            assert_eq!(segtree.get(handle, index), snapshot[index]);
            assert_eq!(segtree.fold(handle, start..end), expected);
            assert_eq!(segtree.fold_all(handle), expected_all);
        }
    }
}
331}