competitive/tree/
heavy_light_decomposition.rs

1use super::{Monoid, UndirectedSparseGraph};
2
/// Heavy-light decomposition of a rooted tree.
///
/// All vectors are indexed by vertex id and have one entry per vertex of the
/// graph the decomposition was built from.
#[derive(Debug, Clone)]
pub struct HeavyLightDecomposition {
    /// `par[v]` is the parent of `v` in the rooted tree. The root stores the
    /// number of vertices of the graph as a "no parent" sentinel.
    pub par: Vec<usize>,
    /// `size[v]` is the number of vertices in the subtree rooted at `v`.
    size: Vec<usize>,
    /// `head[v]` is the topmost vertex of the chain that contains `v`.
    head: Vec<usize>,
    /// `vidx[v]` is the visiting order index of `v`; the vertices of one chain
    /// receive consecutive indices, increasing from the chain head downwards.
    pub vidx: Vec<usize>,
}
9
10impl HeavyLightDecomposition {
11    pub fn new(root: usize, graph: &mut UndirectedSparseGraph) -> Self {
12        let mut self_ = Self {
13            par: vec![0; graph.vertices_size()],
14            size: vec![0; graph.vertices_size()],
15            head: vec![0; graph.vertices_size()],
16            vidx: vec![0; graph.vertices_size()],
17        };
18        self_.build(root, graph);
19        self_
20    }
21
22    fn dfs_size(&mut self, u: usize, p: usize, graph: &mut UndirectedSparseGraph) {
23        self.par[u] = p;
24        self.size[u] = 1;
25        let base = graph.start[u];
26        if graph.adjacencies(u).len() > 1 && graph.adjacencies(u).next().unwrap().to == p {
27            graph.elist.swap(base, base + 1);
28        }
29        for i in base..graph.start[u + 1] {
30            let a = graph.elist[i];
31            if a.to != p {
32                self.dfs_size(a.to, u, graph);
33                self.size[u] += self.size[a.to];
34                if self.size[graph.elist[base].to] < self.size[a.to] {
35                    graph.elist.swap(base, i);
36                }
37            }
38        }
39    }
40
41    fn dfs_hld(&mut self, u: usize, p: usize, t: &mut usize, graph: &UndirectedSparseGraph) {
42        self.vidx[u] = *t;
43        *t += 1;
44        let mut adjacencies = graph.adjacencies(u).filter(|a| a.to != p);
45        if let Some(a) = adjacencies.next() {
46            self.head[a.to] = self.head[u];
47            self.dfs_hld(a.to, u, t, graph);
48        }
49        for a in adjacencies {
50            self.head[a.to] = a.to;
51            self.dfs_hld(a.to, u, t, graph);
52        }
53    }
54
55    fn build(&mut self, root: usize, graph: &mut UndirectedSparseGraph) {
56        self.head[root] = root;
57        self.dfs_size(root, graph.vertices_size(), graph);
58        let mut t = 0;
59        self.dfs_hld(root, graph.vertices_size(), &mut t, graph);
60    }
61
62    pub fn lca(&self, mut u: usize, mut v: usize) -> usize {
63        loop {
64            if self.vidx[u] > self.vidx[v] {
65                std::mem::swap(&mut u, &mut v);
66            }
67            if self.head[u] == self.head[v] {
68                return u;
69            }
70            v = self.par[self.head[v]];
71        }
72    }
73
74    pub fn update<F: FnMut(usize, usize)>(
75        &self,
76        mut u: usize,
77        mut v: usize,
78        is_edge: bool,
79        mut f: F,
80    ) {
81        loop {
82            if self.vidx[u] > self.vidx[v] {
83                std::mem::swap(&mut u, &mut v);
84            }
85            if self.head[u] == self.head[v] {
86                break;
87            }
88            f(self.vidx[self.head[v]], self.vidx[v] + 1);
89            v = self.par[self.head[v]];
90        }
91        f(self.vidx[u] + is_edge as usize, self.vidx[v] + 1);
92    }
93
94    pub fn query<M: Monoid, F: FnMut(usize, usize) -> M::T>(
95        &self,
96        mut u: usize,
97        mut v: usize,
98        is_edge: bool,
99        mut f: F,
100    ) -> M::T {
101        let (mut l, mut r) = (M::unit(), M::unit());
102        loop {
103            if self.vidx[u] > self.vidx[v] {
104                std::mem::swap(&mut u, &mut v);
105                std::mem::swap(&mut l, &mut r);
106            }
107            if self.head[u] == self.head[v] {
108                break;
109            }
110            l = M::operate(&f(self.vidx[self.head[v]], self.vidx[v] + 1), &l);
111            v = self.par[self.head[v]];
112        }
113        M::operate(
114            &M::operate(&f(self.vidx[u] + is_edge as usize, self.vidx[v] + 1), &l),
115            &r,
116        )
117    }
118
119    pub fn query_noncom<
120        M: Monoid,
121        F1: FnMut(usize, usize) -> M::T,
122        F2: FnMut(usize, usize) -> M::T,
123    >(
124        &self,
125        mut u: usize,
126        mut v: usize,
127        is_edge: bool,
128        mut f1: F1,
129        mut f2: F2,
130    ) -> M::T {
131        let (mut l, mut r) = (M::unit(), M::unit());
132        while self.head[u] != self.head[v] {
133            if self.vidx[u] > self.vidx[v] {
134                l = M::operate(&l, &f2(self.vidx[self.head[u]], self.vidx[u] + 1));
135                u = self.par[self.head[u]];
136            } else {
137                r = M::operate(&f1(self.vidx[self.head[v]], self.vidx[v] + 1), &r);
138                v = self.par[self.head[v]];
139            }
140        }
141        M::operate(
142            &M::operate(
143                &l,
144                &if self.vidx[u] > self.vidx[v] {
145                    f2(self.vidx[v] + is_edge as usize, self.vidx[u] + 1)
146                } else {
147                    f1(self.vidx[u] + is_edge as usize, self.vidx[v] + 1)
148                },
149            ),
150            &r,
151        )
152    }
153}