1use super::{Monoid, UndirectedSparseGraph};
2
/// Heavy-light decomposition of a rooted tree.
///
/// All vectors are indexed by vertex and have one entry per vertex of the
/// graph the decomposition was built from.
#[derive(Clone, Debug)]
pub struct HeavyLightDecomposition {
    /// `par[v]` is the parent of `v` in the rooted tree. The root has no
    /// parent and stores the graph's vertex count as a sentinel instead.
    pub par: Vec<usize>,
    /// `size[v]` is the number of vertices in the subtree rooted at `v`
    /// (including `v` itself).
    size: Vec<usize>,
    /// `head[v]` is the topmost vertex of the chain that contains `v`.
    head: Vec<usize>,
    /// `vidx[v]` is the position of `v` in the depth-first visiting order in
    /// which the first non-parent neighbour (the largest-subtree child after
    /// construction) is always visited first, so every chain occupies a
    /// contiguous range of positions.
    pub vidx: Vec<usize>,
}
9
10impl HeavyLightDecomposition {
11 pub fn new(root: usize, graph: &mut UndirectedSparseGraph) -> Self {
12 let mut self_ = Self {
13 par: vec![0; graph.vertices_size()],
14 size: vec![0; graph.vertices_size()],
15 head: vec![0; graph.vertices_size()],
16 vidx: vec![0; graph.vertices_size()],
17 };
18 self_.build(root, graph);
19 self_
20 }
21
22 fn dfs_size(&mut self, u: usize, p: usize, graph: &mut UndirectedSparseGraph) {
23 self.par[u] = p;
24 self.size[u] = 1;
25 let base = graph.start[u];
26 if graph.adjacencies(u).len() > 1 && graph.adjacencies(u).next().unwrap().to == p {
27 graph.elist.swap(base, base + 1);
28 }
29 for i in base..graph.start[u + 1] {
30 let a = graph.elist[i];
31 if a.to != p {
32 self.dfs_size(a.to, u, graph);
33 self.size[u] += self.size[a.to];
34 if self.size[graph.elist[base].to] < self.size[a.to] {
35 graph.elist.swap(base, i);
36 }
37 }
38 }
39 }
40
41 fn dfs_hld(&mut self, u: usize, p: usize, t: &mut usize, graph: &UndirectedSparseGraph) {
42 self.vidx[u] = *t;
43 *t += 1;
44 let mut adjacencies = graph.adjacencies(u).filter(|a| a.to != p);
45 if let Some(a) = adjacencies.next() {
46 self.head[a.to] = self.head[u];
47 self.dfs_hld(a.to, u, t, graph);
48 }
49 for a in adjacencies {
50 self.head[a.to] = a.to;
51 self.dfs_hld(a.to, u, t, graph);
52 }
53 }
54
55 fn build(&mut self, root: usize, graph: &mut UndirectedSparseGraph) {
56 self.head[root] = root;
57 self.dfs_size(root, graph.vertices_size(), graph);
58 let mut t = 0;
59 self.dfs_hld(root, graph.vertices_size(), &mut t, graph);
60 }
61
62 pub fn lca(&self, mut u: usize, mut v: usize) -> usize {
63 loop {
64 if self.vidx[u] > self.vidx[v] {
65 std::mem::swap(&mut u, &mut v);
66 }
67 if self.head[u] == self.head[v] {
68 return u;
69 }
70 v = self.par[self.head[v]];
71 }
72 }
73
74 pub fn update<F: FnMut(usize, usize)>(
75 &self,
76 mut u: usize,
77 mut v: usize,
78 is_edge: bool,
79 mut f: F,
80 ) {
81 loop {
82 if self.vidx[u] > self.vidx[v] {
83 std::mem::swap(&mut u, &mut v);
84 }
85 if self.head[u] == self.head[v] {
86 break;
87 }
88 f(self.vidx[self.head[v]], self.vidx[v] + 1);
89 v = self.par[self.head[v]];
90 }
91 f(self.vidx[u] + is_edge as usize, self.vidx[v] + 1);
92 }
93
94 pub fn query<M: Monoid, F: FnMut(usize, usize) -> M::T>(
95 &self,
96 mut u: usize,
97 mut v: usize,
98 is_edge: bool,
99 mut f: F,
100 ) -> M::T {
101 let (mut l, mut r) = (M::unit(), M::unit());
102 loop {
103 if self.vidx[u] > self.vidx[v] {
104 std::mem::swap(&mut u, &mut v);
105 std::mem::swap(&mut l, &mut r);
106 }
107 if self.head[u] == self.head[v] {
108 break;
109 }
110 l = M::operate(&f(self.vidx[self.head[v]], self.vidx[v] + 1), &l);
111 v = self.par[self.head[v]];
112 }
113 M::operate(
114 &M::operate(&f(self.vidx[u] + is_edge as usize, self.vidx[v] + 1), &l),
115 &r,
116 )
117 }
118
119 pub fn query_noncom<
120 M: Monoid,
121 F1: FnMut(usize, usize) -> M::T,
122 F2: FnMut(usize, usize) -> M::T,
123 >(
124 &self,
125 mut u: usize,
126 mut v: usize,
127 is_edge: bool,
128 mut f1: F1,
129 mut f2: F2,
130 ) -> M::T {
131 let (mut l, mut r) = (M::unit(), M::unit());
132 while self.head[u] != self.head[v] {
133 if self.vidx[u] > self.vidx[v] {
134 l = M::operate(&l, &f2(self.vidx[self.head[u]], self.vidx[u] + 1));
135 u = self.par[self.head[u]];
136 } else {
137 r = M::operate(&f1(self.vidx[self.head[v]], self.vidx[v] + 1), &r);
138 v = self.par[self.head[v]];
139 }
140 }
141 M::operate(
142 &M::operate(
143 &l,
144 &if self.vidx[u] > self.vidx[v] {
145 f2(self.vidx[v] + is_edge as usize, self.vidx[u] + 1)
146 } else {
147 f1(self.vidx[u] + is_edge as usize, self.vidx[v] + 1)
148 },
149 ),
150 &r,
151 )
152 }
153}