Skip to main content

competitive/tree/
centroid_decomposition.rs

1use super::{ConvolveSteps, U64Convolve, UndirectedSparseGraph};
2use std::mem::swap;
3
4#[derive(Debug, Clone)]
5struct RootedTree {
6    parents: Vec<usize>,
7    vs: Vec<usize>,
8}
9
10impl RootedTree {
11    fn len(&self) -> usize {
12        self.vs.len()
13    }
14
15    fn split_centroid(&self) -> CentroidSplit {
16        let n = self.len();
17        assert!(n > 2);
18        let parents = &self.parents;
19        let vs = &self.vs;
20        let mut size = vec![1; n];
21        let mut c = usize::MAX;
22        for i in (0..n).rev() {
23            if size[i] >= n.div_ceil(2) {
24                c = i;
25                break;
26            }
27            size[parents[i]] += size[i];
28        }
29        let mut side = vec![u8::MAX; n];
30        let mut order = vec![usize::MAX; n];
31        order[c] = 0;
32        let mut count = 1usize;
33        let mut taken = 0usize;
34        for u in 1..n {
35            if parents[u] == c && taken + size[u] <= (n - 1) / 2 {
36                taken += size[u];
37                side[u] = 0;
38                order[u] = count;
39                count += 1;
40            }
41        }
42        for u in 1..n {
43            if side[parents[u]] == 0 {
44                side[u] = 0;
45                order[u] = count;
46                count += 1;
47            }
48        }
49        let lsize = count - 1;
50        {
51            let mut u = parents[c];
52            while u != usize::MAX {
53                side[u] = 1;
54                order[u] = count;
55                count += 1;
56                u = parents[u];
57            }
58        }
59        for u in 0..n {
60            if u != c && side[u] == u8::MAX {
61                side[u] = 1;
62                order[u] = count;
63                count += 1;
64            }
65        }
66        assert_eq!(count, n);
67        let rsize = n - lsize - 1;
68        let mut whole_parents = vec![usize::MAX; n];
69        let mut whole_vs = vec![usize::MAX; n];
70        let mut left_parents = vec![usize::MAX; lsize + 1];
71        let mut left_vs = vec![usize::MAX; lsize + 1];
72        let mut right_parents = vec![usize::MAX; rsize + 1];
73        let mut right_vs = vec![usize::MAX; rsize + 1];
74        for u in 0..n {
75            let i = order[u];
76            whole_vs[i] = vs[u];
77            if side[u] != 1 {
78                left_vs[i] = vs[u];
79            }
80            if side[u] != 0 {
81                right_vs[if i == 0 { 0 } else { i - lsize }] = vs[u];
82            }
83        }
84        for u in 1..n {
85            let mut x = order[u];
86            let mut y = order[parents[u]];
87            if x > y {
88                swap(&mut x, &mut y);
89            }
90            whole_parents[y] = x;
91            if side[u] != 1 && side[parents[u]] != 1 {
92                left_parents[y] = x;
93            }
94            if side[u] != 0 && side[parents[u]] != 0 {
95                right_parents[if y == 0 { 0 } else { y - lsize }] =
96                    if x == 0 { 0 } else { x - lsize };
97            }
98        }
99        CentroidSplit {
100            whole: RootedTree {
101                parents: whole_parents,
102                vs: whole_vs,
103            },
104            left: RootedTree {
105                parents: left_parents,
106                vs: left_vs,
107            },
108            right: RootedTree {
109                parents: right_parents,
110                vs: right_vs,
111            },
112            old_to_new: order,
113            lsize,
114        }
115    }
116
117    fn centroid_decomposition(self, f: &mut impl FnMut(&[usize], &[usize], usize, usize)) {
118        if self.len() <= 2 {
119            return;
120        }
121        let split = self.split_centroid();
122        f(
123            &split.whole.parents,
124            &split.whole.vs,
125            split.lsize,
126            split.rsize(),
127        );
128        split.left.centroid_decomposition(f);
129        split.right.centroid_decomposition(f);
130    }
131
132    fn append_contour_components(
133        &self,
134        color: &[i8],
135        comp_range: &mut Vec<usize>,
136        vertex_info: &mut [Vec<ContourInfo>],
137    ) {
138        let n = self.len();
139        let mut dist = vec![0usize; n];
140        for i in 1..n {
141            dist[i] = dist[self.parents[i]] + 1;
142        }
143        let mut comp = comp_range.len() - 1;
144        for c1 in [0, 1] {
145            let mut max_a = 0usize;
146            let mut max_b = 0usize;
147            let mut has_a = false;
148            let mut has_b = false;
149            for (v, &c2) in color.iter().enumerate() {
150                if c2 == c1 {
151                    has_a = true;
152                    max_a = max_a.max(dist[v]);
153                } else if c2 > c1 {
154                    has_b = true;
155                    max_b = max_b.max(dist[v]);
156                }
157            }
158            if !has_a || !has_b {
159                continue;
160            }
161            for (v, &c2) in color.iter().enumerate() {
162                if c2 == c1 {
163                    vertex_info[self.vs[v]].push(ContourInfo { comp, dep: dist[v] });
164                }
165            }
166            comp_range.push(comp_range[comp] + max_a + 1);
167            comp += 1;
168            for (v, &c2) in color.iter().enumerate() {
169                if c2 > c1 {
170                    vertex_info[self.vs[v]].push(ContourInfo { comp, dep: dist[v] });
171                }
172            }
173            comp_range.push(comp_range[comp] + max_b + 1);
174            comp += 1;
175        }
176    }
177
178    fn build_contour_query_range(
179        self,
180        real: Vec<bool>,
181        comp_range: &mut Vec<usize>,
182        vertex_info: &mut [Vec<ContourInfo>],
183    ) {
184        let n = self.len();
185        if n <= 1 {
186            return;
187        }
188        if n == 2 {
189            if real[0] && real[1] {
190                self.append_contour_components(&[0, 1], comp_range, vertex_info);
191            }
192            return;
193        }
194        let split = self.split_centroid();
195        let mut whole_real = vec![false; n];
196        for (u, &is_real) in real.iter().enumerate() {
197            if is_real {
198                whole_real[split.old_to_new[u]] = true;
199            }
200        }
201        let mut color = vec![-1i8; n];
202        for (i, &is_real) in whole_real.iter().enumerate().skip(1) {
203            if is_real {
204                color[i] = if i <= split.lsize { 0 } else { 1 };
205            }
206        }
207        if whole_real[0] {
208            color[0] = 2;
209        }
210        split
211            .whole
212            .append_contour_components(&color, comp_range, vertex_info);
213        if whole_real[0] {
214            whole_real[0] = false;
215        }
216        let mut right_real = Vec::with_capacity(split.rsize() + 1);
217        right_real.push(whole_real[0]);
218        right_real.extend_from_slice(&whole_real[split.lsize + 1..]);
219        split.left.build_contour_query_range(
220            whole_real[..=split.lsize].to_vec(),
221            comp_range,
222            vertex_info,
223        );
224        split
225            .right
226            .build_contour_query_range(right_real, comp_range, vertex_info);
227    }
228}
229
230impl From<&UndirectedSparseGraph> for RootedTree {
231    fn from(graph: &UndirectedSparseGraph) -> Self {
232        let n = graph.vertices_size();
233        let mut vs = Vec::with_capacity(n);
234        let mut parent = vec![usize::MAX; n];
235        vs.push(0usize);
236        for i in 0..n {
237            let u = vs[i];
238            for a in graph.adjacencies(u) {
239                if a.to != parent[u] {
240                    vs.push(a.to);
241                    parent[a.to] = u;
242                }
243            }
244        }
245        let mut new_idx = vec![0; n];
246        for (i, &v) in vs.iter().enumerate() {
247            new_idx[v] = i;
248        }
249        let mut parents = vec![usize::MAX; n];
250        for v in 1..n {
251            parents[new_idx[v]] = new_idx[parent[v]];
252        }
253        Self { parents, vs }
254    }
255}
256
257#[derive(Debug)]
258struct CentroidSplit {
259    whole: RootedTree,
260    left: RootedTree,
261    right: RootedTree,
262    old_to_new: Vec<usize>,
263    lsize: usize,
264}
265
266impl CentroidSplit {
267    fn rsize(&self) -> usize {
268        self.whole.len() - self.lsize - 1
269    }
270}
271
/// Membership of one vertex in one contour component.
#[derive(Debug, Clone, Copy)]
struct ContourInfo {
    /// Component id; the component it is paired with is `comp ^ 1`.
    comp: usize,
    /// Depth of the vertex below the root (centroid) of the tree the component was built on.
    dep: usize,
}

/// Flattened slot layout answering "all vertices within a distance interval" queries on a tree.
///
/// Each vertex owns a few slots, reported by [`ContourQueryRange::for_each_index`]. For a
/// vertex `v` and an interval `[l, r)`, [`ContourQueryRange::for_each_contour_range`]
/// reports contiguous slot ranges that together cover the slots of each vertex `u != v` with
/// `l <= dist(u, v) < r` exactly once.
#[derive(Debug, Clone)]
pub struct ContourQueryRange {
    /// Prefix offsets: component `k` occupies slots `comp_range[k]..comp_range[k + 1]`.
    comp_range: Vec<usize>,
    /// CSR offsets: the entries of vertex `v` are `infos[info_indptr[v]..info_indptr[v + 1]]`.
    info_indptr: Vec<usize>,
    /// Component memberships of all vertices, concatenated in vertex order.
    infos: Vec<ContourInfo>,
}

impl ContourQueryRange {
    /// Total number of slots (the sum of all component lengths).
    pub fn len(&self) -> usize {
        self.comp_range.last().map_or(0, |&end| end)
    }

    /// Returns `true` when the layout has no slot at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Component memberships of vertex `v`.
    fn infos_of(&self, v: usize) -> &[ContourInfo] {
        let (begin, end) = (self.info_indptr[v], self.info_indptr[v + 1]);
        &self.infos[begin..end]
    }

    /// Calls `f` with every slot owned by vertex `v`.
    pub fn for_each_index(&self, v: usize, mut f: impl FnMut(usize)) {
        self.infos_of(v)
            .iter()
            .for_each(|info| f(self.comp_range[info.comp] + info.dep));
    }

    /// Calls `f(start, end)` for each non-empty slot range `start..end` holding vertices
    /// whose distance from `v` lies in `[l, r)`.
    pub fn for_each_contour_range(
        &self,
        v: usize,
        l: usize,
        r: usize,
        mut f: impl FnMut(usize, usize),
    ) {
        for info in self.infos_of(v) {
            // The paired component holds the vertices on the other side of the same centroid.
            let partner = info.comp ^ 1;
            let offset = self.comp_range[partner];
            let width = self.comp_range[partner + 1] - offset;
            // A slot at depth `d` of the partner lies at distance `info.dep + d` from `v`.
            let clamp = |bound: usize| bound.saturating_sub(info.dep).min(width);
            let (lo, hi) = (clamp(l), clamp(r));
            if lo < hi {
                f(offset + lo, offset + hi);
            }
        }
    }
}
319
320impl UndirectedSparseGraph {
321    /// 1/3 centroid decomposition
322    ///
323    /// - f: (parents: &[usize], vs: &[usize], lsize: usize, rsize: usize)
324    /// - 0: root, 1..=lsize: left subtree, lsize+1..=lsize+rsize: right subtree
325    pub fn centroid_decomposition(&self, mut f: impl FnMut(&[usize], &[usize], usize, usize)) {
326        if self.vertices_size() <= 1 {
327            return;
328        }
329        RootedTree::from(self).centroid_decomposition(&mut f);
330    }
331
332    pub fn contour_query_range(&self) -> ContourQueryRange {
333        let n = self.vertices_size();
334        if n <= 1 {
335            return ContourQueryRange {
336                comp_range: vec![0],
337                info_indptr: vec![0; n + 1],
338                infos: vec![],
339            };
340        }
341        let mut comp_range = vec![0usize];
342        let mut vertex_info = vec![vec![]; n];
343        RootedTree::from(self).build_contour_query_range(
344            vec![true; n],
345            &mut comp_range,
346            &mut vertex_info,
347        );
348        let mut info_indptr = vec![0usize; n + 1];
349        for (v, infos) in vertex_info.iter().enumerate() {
350            info_indptr[v + 1] = info_indptr[v] + infos.len();
351        }
352        let mut infos = Vec::with_capacity(info_indptr[n]);
353        for entries in vertex_info {
354            infos.extend(entries);
355        }
356        ContourQueryRange {
357            comp_range,
358            info_indptr,
359            infos,
360        }
361    }
362
363    pub fn distance_frequencies(&self) -> Vec<u64> {
364        let n = self.vertices_size();
365        let mut table = vec![0u64; n];
366        if n == 0 {
367            return table;
368        }
369        table[0] = n as u64;
370        if n == 1 {
371            return table;
372        }
373        table[1] = (n * 2 - 2) as u64;
374        self.centroid_decomposition(|parents, vs, lsize, _rsize| {
375            let n = vs.len();
376            let mut dist = vec![0usize; n];
377            for i in 1..n {
378                dist[i] = dist[parents[i]] + 1;
379            }
380            let d_max = dist.iter().max().cloned().unwrap_or_default();
381            let mut f = vec![0u64; d_max + 1];
382            let mut g = vec![0u64; d_max + 1];
383            for i in 1..=lsize {
384                f[dist[i]] += 1;
385            }
386            for i in lsize + 1..n {
387                g[dist[i]] += 1;
388            }
389            while f.last().is_some_and(|&x| x == 0) {
390                f.pop();
391            }
392            while g.last().is_some_and(|&x| x == 0) {
393                g.pop();
394            }
395            let h = U64Convolve::convolve(f, g);
396            for (i, &x) in h.iter().enumerate() {
397                table[i] += x * 2;
398            }
399        });
400        table
401    }
402}
403
#[cfg(test)]
mod tests {
    use crate::{tools::Xorshift, tree::MixedTree};

    /// Compares `distance_frequencies` with a brute-force histogram of all BFS depths.
    #[test]
    fn test_distance_frequencies() {
        let mut rng = Xorshift::default();
        (0..200).for_each(|_| {
            let tree = rng.random(MixedTree(1usize..100));
            let size = tree.vertices_size();
            let actual = tree.distance_frequencies();
            let mut brute = vec![0u64; size];
            for src in 0..size {
                let depth = tree.tree_depth(src);
                for dst in 0..size {
                    brute[depth[dst] as usize] += 1;
                }
            }
            assert_eq!(actual, brute);
        });
    }

    /// With every vertex weighted 1, a contour query must count the vertices in range.
    #[test]
    fn test_contour_query_range_counts() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let tree = rng.random(MixedTree(1usize..80));
            let size = tree.vertices_size();
            let cq = tree.contour_query_range();
            let mut weight = vec![0; cq.len()];
            for u in 0..size {
                cq.for_each_index(u, |i| weight[i] += 1);
            }
            for _ in 0..200 {
                let v = rng.random(0..size);
                let l = rng.random(0..=size);
                let r = rng.random(l..=size + 1);
                let depth = tree.tree_depth(v);
                let expected = (0..size)
                    .filter(|&u| u != v && (l..r).contains(&(depth[u] as usize)))
                    .count();
                let mut actual = 0usize;
                cq.for_each_contour_range(v, l, r, |lo, hi| {
                    actual += weight[lo..hi].iter().sum::<usize>();
                });
                assert_eq!(actual, expected);
            }
        }
    }

    /// With only vertex `u` weighted, a contour query from `v` must count 1 exactly when
    /// `u` is a different vertex within the distance interval.
    #[test]
    fn test_contour_query_range_single_vertex() {
        let mut rng = Xorshift::default();
        for _ in 0..80 {
            let tree = rng.random(MixedTree(1usize..60));
            let size = tree.vertices_size();
            let cq = tree.contour_query_range();
            for _ in 0..120 {
                let u = rng.random(0..size);
                let v = rng.random(0..size);
                let l = rng.random(0..=size);
                let r = rng.random(l..=size + 1);
                let mut weight = vec![0; cq.len()];
                cq.for_each_index(u, |i| weight[i] += 1);
                let d = tree.tree_depth(v)[u] as usize;
                let expected = usize::from(u != v && (l..r).contains(&d));
                let mut actual = 0usize;
                cq.for_each_contour_range(v, l, r, |lo, hi| {
                    actual += weight[lo..hi].iter().sum::<usize>();
                });
                assert_eq!(actual, expected);
            }
        }
    }
}