competitive/algorithm/solve_01_on_tree.rs

1use super::{UnionFindBase, union_find};
2use std::{cmp::Ordering, collections::BinaryHeap, ops::AddAssign};
3
/// Number of `0` labels and `1` labels carried by a vertex, or by a merged
/// group of vertices, in the 0/1-on-tree problem.
///
/// The `Ord` impl is the greedy key used by `solve_01_on_tree`: a group with a
/// larger `cnt0 / cnt1` ratio compares greater. A group with `cnt1 == 0`
/// (including the empty group) is treated as having an infinite ratio. It
/// compares greater than every group with `cnt1 > 0`, and among such groups
/// the one with more zeros is greater.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct Count01 {
    /// Number of zeros in the group.
    cnt0: usize,
    /// Number of ones in the group.
    cnt1: usize,
}

impl PartialOrd for Count01 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Count01 {
    /// Compares two groups by the ratio `cnt0 / cnt1` without dividing.
    ///
    /// NOTE: groups with equal ratios but different counts (e.g. `(1, 2)` and
    /// `(2, 4)`) compare `Equal` even though they are not `==`. This departs
    /// from the `Ord`/`Eq` consistency contract. In this file the type is only
    /// used as the first field of a heap-key tuple, so ties fall through to the
    /// remaining tuple fields.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.cnt1 == 0, other.cnt1 == 0) {
            (true, true) => self.cnt0.cmp(&other.cnt0),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // Cross-multiply in u128. On targets where usize is at most
                // 64 bits the product of two usize values always fits, whereas
                // `usize * usize` can overflow once counts are large.
                let lhs = self.cnt0 as u128 * other.cnt1 as u128;
                let rhs = other.cnt0 as u128 * self.cnt1 as u128;
                lhs.cmp(&rhs)
            }
        }
    }
}

impl AddAssign for Count01 {
    /// Merges `other` into `self` by summing both counters.
    fn add_assign(&mut self, other: Self) {
        self.cnt0 += other.cnt0;
        self.cnt1 += other.cnt1;
    }
}

impl Count01 {
    /// Creates a group holding `cnt0` zeros and `cnt1` ones.
    pub fn new(cnt0: usize, cnt1: usize) -> Self {
        Self { cnt0, cnt1 }
    }
}
41
42pub fn solve_01_on_tree(
43    n: usize,
44    c01: impl Fn(usize) -> (usize, usize),
45    root: usize,
46    parent: impl Fn(usize) -> usize,
47) -> (usize, Vec<usize>) {
48    pub type UF<T, M> =
49        UnionFindBase<(), union_find::PathCompression, union_find::FnMerger<T, M>, (), ()>;
50    let mut cost = 0usize;
51    let c01 = |u| {
52        let c = c01(u);
53        Count01::new(c.0, c.1)
54    };
55    let mut uf = UF::new_with_merger(n, &c01, |x, y| {
56        cost += x.cnt1 * y.cnt0;
57        *x += *y;
58    });
59    let mut label = vec![0; n];
60    let mut heap = BinaryHeap::from_iter((0..n).filter(|&u| u != root).map(|u| (c01(u), u, 0)));
61    let mut next: Vec<_> = (0..n).collect();
62    let mut ord = Vec::with_capacity(n);
63    while let Some((_c, u, l)) = heap.pop() {
64        if label[u] != l {
65            continue;
66        }
67        let p = uf.find_root(parent(u));
68        uf.unite(u, p);
69        if !uf.same(p, root) {
70            label[p] += 1;
71            heap.push((*uf.merge_data(p), p, label[p]));
72        }
73        next.swap(u, p);
74    }
75    let mut u = next[root];
76    ord.push(u);
77    while u != root {
78        u = next[u];
79        ord.push(u);
80    }
81    ord.reverse();
82    (cost, ord)
83}