// competitive/algorithm/solve_01_on_tree.rs

use super::{UnionFindBase, union_find};
use std::{cmp::Ordering, collections::BTreeSet, ops::AddAssign};

/// Tally of `0`s and `1`s carried by one vertex, or by a group of vertices that have already been
/// merged into one contiguous run.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct Count01 {
    /// Number of `0`s in the group.
    cnt0: usize,
    /// Number of `1`s in the group.
    cnt1: usize,
}

impl PartialOrd for Count01 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders groups by the ratio `cnt0 / cnt1`. A greater value is a group that should be emitted
/// earlier; `solve_01_on_tree` pops the greatest entry first.
///
/// * The ratio is compared by cross-multiplication widened to `u128`, so it cannot overflow.
/// * The empty group `(0, 0)` has no meaningful ratio. Plain cross-multiplication would make it
///   compare `Equal` to every value, even though `(0, 1) < (1, 0)`. That breaks transitivity and
///   can make a `BTreeSet` lose track of entries. The empty group is therefore ranked above
///   everything: merging it adds nothing to the cost and leaves its parent's counts unchanged.
/// * Groups with the same ratio are tie-broken on the raw counts. This makes `cmp` return
///   `Equal` exactly when the derived `PartialEq` does, as the `Ord` contract requires. For two
///   equal-ratio groups `a` and `b`, `a.cnt1 * b.cnt0 == b.cnt1 * a.cnt0`, so putting either one
///   first is charged the same.
impl Ord for Count01 {
    fn cmp(&self, other: &Self) -> Ordering {
        let is_empty = |c: &Self| c.cnt0 == 0 && c.cnt1 == 0;
        match (is_empty(self), is_empty(other)) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // `usize -> u128` is lossless; the product of two usizes always fits.
                let lhs = self.cnt0 as u128 * other.cnt1 as u128;
                let rhs = other.cnt0 as u128 * self.cnt1 as u128;
                lhs.cmp(&rhs)
                    .then_with(|| (self.cnt0, self.cnt1).cmp(&(other.cnt0, other.cnt1)))
            }
        }
    }
}

impl AddAssign for Count01 {
    /// Absorbs `other`'s counts into `self`; used when one group is merged into another.
    fn add_assign(&mut self, other: Self) {
        self.cnt0 += other.cnt0;
        self.cnt1 += other.cnt1;
    }
}

impl Count01 {
    /// Creates a tally holding `cnt0` zeros and `cnt1` ones.
    pub fn new(cnt0: usize, cnt1: usize) -> Self {
        Self { cnt0, cnt1 }
    }
}

37pub fn solve_01_on_tree(
38    n: usize,
39    c01: impl Fn(usize) -> (usize, usize),
40    root: usize,
41    parent: impl Fn(usize) -> usize,
42) -> usize {
43    pub type UF<T, M> =
44        UnionFindBase<(), union_find::PathCompression, union_find::FnMerger<T, M>, (), ()>;
45    let mut cost = 0usize;
46    let c01 = |u| {
47        let c = c01(u);
48        Count01::new(c.0, c.1)
49    };
50    let mut uf = UF::new_with_merger(n, &c01, |x, y| {
51        cost += x.cnt1 * y.cnt0;
52        *x += *y;
53    });
54    let mut heap = BTreeSet::from_iter((0..n).filter(|&u| u != root).map(|u| (c01(u), u)));
55    while let Some((_c, u)) = heap.pop_last() {
56        let p = uf.find_root(parent(u));
57        heap.remove(&(*uf.merge_data(p), p));
58        uf.unite(u, p);
59        if !uf.same(p, root) {
60            heap.insert((*uf.merge_data(p), p));
61        }
62    }
63    cost
64}