Skip to main content

competitive/tree/
euler_tour.rs

1use super::{RangeMinimumQuery, UndirectedSparseGraph};
2use std::{marker::PhantomData, mem::swap, ops::Range};
3
/// Describes which vertex events an Euler tour records.
///
/// Every kind records a vertex once when the DFS first enters it. On top of that:
/// - `USE_VISIT`: the parent is recorded again each time the DFS returns to it
///   from a child (one extra slot per edge).
/// - `USE_LAST`: each vertex is recorded once more after its whole subtree has
///   been explored (one extra slot per vertex).
pub trait EulerTourKind {
    /// Record each vertex a second time once its subtree is finished.
    const USE_LAST: bool = false;
    /// Record the parent again after returning from every child.
    const USE_VISIT: bool = false;

    /// Number of tour slots produced for a tree with `n` vertices.
    ///
    /// Entry events contribute `n`. `USE_VISIT` adds one slot per edge (`n - 1`),
    /// and `USE_LAST` adds one slot per vertex (`n`). The flags are independent,
    /// so a kind enabling both yields `3n - 1`. An empty tree (`n == 0`) has size 0
    /// instead of underflowing.
    fn size(n: usize) -> usize {
        let mut size = n;
        if Self::USE_VISIT {
            size += n.saturating_sub(1);
        }
        if Self::USE_LAST {
            size += n;
        }
        size
    }
}
18
19mod marker {
20    use super::EulerTourKind;
21
22    #[derive(Debug, Clone)]
23    pub enum First {}
24    #[derive(Debug, Clone)]
25    pub enum FirstLast {}
26    #[derive(Debug, Clone)]
27    pub enum Visit {}
28
29    impl EulerTourKind for First {}
30    impl EulerTourKind for FirstLast {
31        const USE_LAST: bool = true;
32    }
33    impl EulerTourKind for Visit {
34        const USE_VISIT: bool = true;
35    }
36}
37
38#[derive(Debug)]
39pub struct EulerTourBuilder<'a, K>
40where
41    K: EulerTourKind,
42{
43    tree: &'a UndirectedSparseGraph,
44    root: usize,
45    vidx: Vec<[usize; 2]>,
46    eidx: Vec<[usize; 2]>,
47    pos: usize,
48    _marker: PhantomData<fn() -> K>,
49}
50
51#[derive(Debug, Clone)]
52pub struct EulerTour<K>
53where
54    K: EulerTourKind,
55{
56    pub root: usize,
57    pub vidx: Vec<[usize; 2]>,
58    pub eidx: Vec<[usize; 2]>,
59    pub size: usize,
60    _marker: PhantomData<fn() -> K>,
61}
62
63impl<'a, K> EulerTourBuilder<'a, K>
64where
65    K: EulerTourKind,
66{
67    pub fn new(tree: &'a UndirectedSparseGraph, root: usize) -> Self {
68        let n = tree.vertices_size();
69        Self {
70            tree,
71            root,
72            vidx: vec![[0usize; 2]; n],
73            eidx: vec![[0usize; 2]; n - 1],
74            pos: 0,
75            _marker: PhantomData,
76        }
77    }
78
79    pub fn build_with_trace(mut self, mut trace: impl FnMut(usize)) -> EulerTour<K> {
80        self.dfs(self.root, !0, &mut trace);
81        EulerTour {
82            root: self.root,
83            vidx: self.vidx,
84            eidx: self.eidx,
85            size: self.pos,
86            _marker: PhantomData,
87        }
88    }
89
90    pub fn build(self) -> EulerTour<K> {
91        self.build_with_trace(|_u| {})
92    }
93
94    fn dfs(&mut self, u: usize, parent: usize, trace: &mut impl FnMut(usize)) {
95        self.vidx[u][0] = self.pos;
96        trace(u);
97        self.pos += 1;
98        for a in self.tree.adjacencies(u) {
99            if a.to != parent {
100                self.eidx[a.id][0] = self.pos;
101                self.dfs(a.to, u, trace);
102                self.eidx[a.id][1] = self.pos;
103                if K::USE_VISIT {
104                    trace(u);
105                    self.pos += 1;
106                }
107            }
108        }
109        self.vidx[u][1] = self.pos;
110        if K::USE_LAST {
111            trace(u);
112            self.pos += 1;
113        }
114    }
115}
116
117impl EulerTourBuilder<'_, marker::First> {
118    pub fn build_with_rearrange<T>(self, s: &[T]) -> (EulerTour<marker::First>, Vec<T>)
119    where
120        T: Clone,
121    {
122        assert_eq!(s.len(), self.tree.vertices_size());
123        let mut trace = Vec::with_capacity(marker::First::size(s.len()));
124        let tour = self.build_with_trace(|u| {
125            trace.push(s[u].clone());
126        });
127        (tour, trace)
128    }
129}
130
131impl EulerTourBuilder<'_, marker::FirstLast> {
132    pub fn build_with_rearrange<T>(
133        self,
134        s: &[T],
135        mut inverse: impl FnMut(T) -> T,
136    ) -> (EulerTour<marker::FirstLast>, Vec<T>)
137    where
138        T: Clone,
139    {
140        assert_eq!(s.len(), self.tree.vertices_size());
141        let mut visited = vec![false; s.len()];
142        let mut trace = Vec::with_capacity(marker::FirstLast::size(s.len()));
143        let tour = self.build_with_trace(|u| {
144            if !visited[u] {
145                trace.push(s[u].clone());
146                visited[u] = true;
147            } else {
148                trace.push(inverse(s[u].clone()));
149            }
150        });
151        (tour, trace)
152    }
153}
154
155impl EulerTourBuilder<'_, marker::Visit> {
156    pub fn build_with_rearrange<T>(self, s: &[T]) -> (EulerTour<marker::Visit>, Vec<T>)
157    where
158        T: Clone,
159    {
160        assert_eq!(s.len(), self.tree.vertices_size());
161        let mut trace = Vec::with_capacity(marker::Visit::size(s.len()));
162        let tour = self.build_with_trace(|u| {
163            trace.push(s[u].clone());
164        });
165        (tour, trace)
166    }
167}
168
169impl UndirectedSparseGraph {
170    pub fn subtree_euler_tour_builder<'a>(
171        &'a self,
172        root: usize,
173    ) -> EulerTourBuilder<'a, marker::First> {
174        EulerTourBuilder::new(self, root)
175    }
176
177    pub fn path_euler_tour_builder<'a>(
178        &'a self,
179        root: usize,
180    ) -> EulerTourBuilder<'a, marker::FirstLast> {
181        EulerTourBuilder::new(self, root)
182    }
183
184    pub fn full_euler_tour_builder<'a>(
185        &'a self,
186        root: usize,
187    ) -> EulerTourBuilder<'a, marker::Visit> {
188        EulerTourBuilder::new(self, root)
189    }
190
191    pub fn lca(&self, root: usize) -> LowestCommonAncestor {
192        let depth = self.tree_depth(root);
193        let mut trace = Vec::with_capacity(2 * self.vertices_size() - 1);
194        let mut depth_trace = Vec::with_capacity(2 * self.vertices_size() - 1);
195        let euler_tour = self.full_euler_tour_builder(root).build_with_trace(|u| {
196            trace.push(u);
197            depth_trace.push(depth[u]);
198        });
199        let rmq = RangeMinimumQuery::new(depth_trace);
200        LowestCommonAncestor {
201            euler_tour,
202            trace,
203            rmq,
204        }
205    }
206}
207
208impl EulerTour<marker::First> {
209    pub fn get<T>(&self, u: usize, mut f: impl FnMut(usize) -> T) -> T {
210        let [l, _] = self.vidx[u];
211        f(l)
212    }
213
214    pub fn update<T>(&self, u: usize, x: T, mut f: impl FnMut(usize, T)) {
215        let [l, _] = self.vidx[u];
216        f(l, x);
217    }
218
219    pub fn fold<T>(&self, u: usize, mut f: impl FnMut(Range<usize>) -> T) -> T {
220        let [l, r] = self.vidx[u];
221        f(l..r)
222    }
223
224    pub fn range_update<T>(&self, u: usize, x: T, mut f: impl FnMut(Range<usize>, T)) {
225        let [l, r] = self.vidx[u];
226        f(l..r, x);
227    }
228}
229
230impl EulerTour<marker::FirstLast> {
231    pub fn get<T>(&self, u: usize, mut f: impl FnMut(usize) -> T) -> T {
232        let [l, _] = self.vidx[u];
233        f(l)
234    }
235
236    pub fn update<T>(&self, u: usize, x: T, invx: T, mut f: impl FnMut(usize, T)) {
237        let [l, r] = self.vidx[u];
238        f(l, x);
239        f(r, invx);
240    }
241
242    // f: accumulate
243    pub fn fold<T>(&self, u: usize, mut f: impl FnMut(usize) -> T) -> T {
244        f(self.vidx[u][0])
245    }
246}
247
248#[derive(Debug)]
249pub struct LowestCommonAncestor {
250    euler_tour: EulerTour<marker::Visit>,
251    trace: Vec<usize>,
252    rmq: RangeMinimumQuery<u64>,
253}
254
255impl LowestCommonAncestor {
256    pub fn lca(&self, u: usize, v: usize) -> usize {
257        let mut l = self.euler_tour.vidx[u][0];
258        let mut r = self.euler_tour.vidx[v][0];
259        if l > r {
260            swap(&mut l, &mut r);
261        }
262        let idx = self.rmq.argmin(l, r + 1);
263        self.trace[idx]
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{AdditiveOperation, RangeSumRangeAdd},
        crecurse,
        data_structure::{LazySegmentTree, SegmentTree},
        tools::Xorshift,
        tree::MixedTree,
    };

    /// Every kind must produce exactly `size(n)` slots and well-formed vertex ranges.
    #[test]
    fn test_builder() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1..=200);
            let tree = rng.random(MixedTree(n));
            let root = rng.random(0..n);
            let subtree = tree.subtree_euler_tour_builder(root).build();
            let path = tree.path_euler_tour_builder(root).build();
            let full = tree.full_euler_tour_builder(root).build();
            assert_eq!(subtree.size, marker::First::size(n));
            assert_eq!(path.size, marker::FirstLast::size(n));
            assert_eq!(full.size, marker::Visit::size(n));
            let ranges = subtree.vidx.iter().zip(&path.vidx).zip(&full.vidx);
            for ((&[sl, sr], &[pl, pr]), &[fl, fr]) in ranges {
                assert!(sl < sr);
                assert!(sr <= marker::First::size(n));
                assert!(pl < pr);
                assert!(pr < marker::FirstLast::size(n));
                assert!(fl < fr);
                assert!(fr <= marker::Visit::size(n));
            }
        }
    }

    /// Subtree tour + lazy segment tree: vertex get/add and subtree sum/add,
    /// checked against a brute-force DFS.
    #[test]
    fn test_subtree_euler_tour() {
        const A: i64 = 1_000_000;
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1..=200);
            let tree = rng.random(MixedTree(n));
            let root = rng.random(0..n);
            let mut values: Vec<_> = rng.random_iter(0..A).take(n).collect();
            let (et, arr) = tree
                .subtree_euler_tour_builder(root)
                .build_with_rearrange(&values);
            let mut seg = LazySegmentTree::<RangeSumRangeAdd<i64>>::from_keys(arr.into_iter());
            for _ in 0..200 {
                let op = rng.random(0..4);
                let u = rng.random(0..n);
                match op {
                    0 => {
                        let result = et.get(u, |idx| seg.get(idx)).0;
                        assert_eq!(result, values[u]);
                    }
                    1 => {
                        let x = rng.random(0..A);
                        et.update(u, x, |i, x| seg.update(i..=i, x));
                        values[u] += x;
                    }
                    2 => {
                        let result = et.fold(u, |r| seg.fold(r)).0;
                        let mut expected = 0;
                        // Sum `values` over every vertex at or below `u`.
                        crecurse!(
                            unsafe fn dfs(v: usize, p: usize, inside: bool) {
                                let inside = inside || v == u;
                                if inside {
                                    expected += values[v];
                                }
                                for adj in tree.adjacencies(v) {
                                    if adj.to != p {
                                        dfs!(adj.to, v, inside);
                                    }
                                }
                            }
                        )(root, !0, false);
                        assert_eq!(result, expected);
                    }
                    _ => {
                        let x = rng.random(0..A);
                        et.range_update(u, x, |r, x| seg.update(r, x));
                        // Add `x` to every vertex at or below `u`.
                        crecurse!(
                            unsafe fn dfs(v: usize, p: usize, inside: bool) {
                                let inside = inside || v == u;
                                if inside {
                                    values[v] += x;
                                }
                                for adj in tree.adjacencies(v) {
                                    if adj.to != p {
                                        dfs!(adj.to, v, inside);
                                    }
                                }
                            }
                        )(root, !0, false);
                    }
                }
            }
        }
    }

    /// Path tour + point-add segment tree: vertex get/add and root-to-vertex
    /// path sums, checked against a brute-force DFS.
    #[test]
    fn test_path_euler_tour() {
        const A: i64 = 1_000_000;
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1..=200);
            let tree = rng.random(MixedTree(n));
            let root = rng.random(0..n);
            let mut values: Vec<_> = rng.random_iter(0..A).take(n).collect();
            let (et, arr) = tree
                .path_euler_tour_builder(root)
                .build_with_rearrange(&values, |x| -x);
            let mut seg = SegmentTree::<AdditiveOperation<i64>>::from_vec(arr);
            for _ in 0..200 {
                let op = rng.random(0..3);
                let u = rng.random(0..n);
                match op {
                    0 => {
                        let result = et.get(u, |idx| seg.get(idx));
                        assert_eq!(result, values[u]);
                    }
                    1 => {
                        let x = rng.random(0..A);
                        et.update(u, x, -x, |i, x| seg.update(i, x));
                        values[u] += x;
                    }
                    _ => {
                        let result = et.fold(u, |k| seg.fold(0..=k));
                        let mut expected = 0;
                        // Sum `values` along the root-to-`u` path.
                        crecurse!(
                            unsafe fn dfs(v: usize, p: usize) -> bool {
                                if v == u {
                                    expected += values[v];
                                    return true;
                                }
                                for adj in tree.adjacencies(v) {
                                    if adj.to != p && dfs!(adj.to, v) {
                                        expected += values[v];
                                        return true;
                                    }
                                }
                                false
                            }
                        )(root, !0);
                        assert_eq!(result, expected);
                    }
                }
            }
        }
    }

    /// LCA queries checked against a brute-force DFS that reports which of
    /// `u`/`v` each subtree contains.
    #[test]
    fn test_lca() {
        let mut rng = Xorshift::default();
        for _ in 0..200 {
            let n = rng.random(1..=200);
            let tree = rng.random(MixedTree(n));
            let root = rng.random(0..n);
            let lca = tree.lca(root);
            for _ in 0..200 {
                let u = rng.random(0..n);
                let v = rng.random(0..n);
                let result = lca.lca(u, v);
                // `Ok(w)`: the first (deepest) vertex whose subtree holds both;
                // `Err(found)`: which of `u`/`v` the subtree holds so far.
                let expected = crecurse!(
                    unsafe fn dfs(w: usize, p: usize) -> Result<usize, [bool; 2]> {
                        let mut found = [w == u, w == v];
                        for adj in tree.adjacencies(w) {
                            if adj.to != p {
                                match dfs!(adj.to, w) {
                                    Ok(ancestor) => return Ok(ancestor),
                                    Err(res) => {
                                        for (f, r) in found.iter_mut().zip(&res) {
                                            *f |= *r;
                                        }
                                    }
                                }
                            }
                        }
                        if found[0] && found[1] {
                            Ok(w)
                        } else {
                            Err(found)
                        }
                    }
                )(root, !0)
                .unwrap();
                assert_eq!(result, expected);
            }
        }
    }
}