competitive/algorithm/
solve_01_on_tree.rs1use super::{UnionFindBase, union_find};
2use std::{cmp::Ordering, collections::BTreeSet, ops::AddAssign};

/// Number of `0`s (`cnt0`) and `1`s (`cnt1`) held by a vertex or by a merged group of vertices.
///
/// Counts are ordered by the ratio `cnt0 / cnt1`, compared by cross-multiplication so no
/// division is needed. A count with relatively more zeros compares greater. All counts with
/// `cnt1 == 0` and `cnt0 > 0` tie with each other at the top of the non-empty range.
///
/// The empty count `(0, 0)` has no ratio. Plain cross-multiplication would make it "equal" to
/// every other count. That relation is not transitive and corrupts `BTreeSet` lookups, so the
/// empty count is instead ordered strictly above every non-empty count.
///
/// Equality follows the ordering (e.g. `(1, 2) == (2, 4)`), as the `Ord` contract requires.
#[derive(Debug, Default, Clone, Copy)]
struct Count01 {
    cnt0: usize,
    cnt1: usize,
}

impl Count01 {
    pub fn new(cnt0: usize, cnt1: usize) -> Self {
        Self { cnt0, cnt1 }
    }

    /// `true` when the count holds no elements at all (`(0, 0)`).
    fn is_empty(&self) -> bool {
        self.cnt0 == 0 && self.cnt1 == 0
    }
}

impl PartialEq for Count01 {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp` (ratio equality), not with field-wise equality.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Count01 {}

impl PartialOrd for Count01 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Count01 {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.is_empty(), other.is_empty()) {
            (true, true) => Ordering::Equal,
            // The empty count sits above everything else so the order stays transitive.
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // `usize -> u128` is lossless, so the cross products cannot overflow.
                let lhs = self.cnt0 as u128 * other.cnt1 as u128;
                let rhs = other.cnt0 as u128 * self.cnt1 as u128;
                lhs.cmp(&rhs)
            }
        }
    }
}

impl AddAssign for Count01 {
    fn add_assign(&mut self, other: Self) {
        self.cnt0 += other.cnt0;
        self.cnt1 += other.cnt1;
    }
}
36
37pub fn solve_01_on_tree(
38 n: usize,
39 c01: impl Fn(usize) -> (usize, usize),
40 root: usize,
41 parent: impl Fn(usize) -> usize,
42) -> usize {
43 pub type UF<T, M> =
44 UnionFindBase<(), union_find::PathCompression, union_find::FnMerger<T, M>, (), ()>;
45 let mut cost = 0usize;
46 let c01 = |u| {
47 let c = c01(u);
48 Count01::new(c.0, c.1)
49 };
50 let mut uf = UF::new_with_merger(n, &c01, |x, y| {
51 cost += x.cnt1 * y.cnt0;
52 *x += *y;
53 });
54 let mut heap = BTreeSet::from_iter((0..n).filter(|&u| u != root).map(|u| (c01(u), u)));
55 while let Some((_c, u)) = heap.pop_last() {
56 let p = uf.find_root(parent(u));
57 heap.remove(&(*uf.merge_data(p), p));
58 uf.unite(u, p);
59 if !uf.same(p, root) {
60 heap.insert((*uf.merge_data(p), p));
61 }
62 }
63 cost
64}