competitive/algorithm/
solve_01_on_tree.rs1use super::{UnionFindBase, union_find};
2use std::{cmp::Ordering, collections::BinaryHeap, ops::AddAssign};
3
/// Number of `0`s and `1`s carried by a single vertex or by a merged group of vertices.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct Count01 {
    /// Count of `0`s.
    cnt0: usize,
    /// Count of `1`s.
    cnt1: usize,
}

impl PartialOrd for Count01 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Ranks counts by how "zero-heavy" they are, so a max-heap pops the highest-ranked one:
///
/// 1. counts with `cnt1 == 0` rank above every count with `cnt1 > 0`
///    (among themselves, a larger `cnt0` ranks higher);
/// 2. otherwise counts are compared by the exact ratio `cnt0 / cnt1`.
///
/// Remaining ties are broken by `cnt0` and then `cnt1`, so `cmp` returns `Equal`
/// exactly when the derived `PartialEq` reports equality, as the `Ord` contract requires.
impl Ord for Count01 {
    fn cmp(&self, other: &Self) -> Ordering {
        let by_ratio = match (self.cnt1 == 0, other.cnt1 == 0) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => {
                // Compare cnt0/cnt1 by cross-multiplication. Widening usize -> u128 is
                // lossless and the product cannot overflow, unlike a usize product
                // of large merged counts.
                let lhs = self.cnt0 as u128 * other.cnt1 as u128;
                let rhs = other.cnt0 as u128 * self.cnt1 as u128;
                lhs.cmp(&rhs)
            }
        };
        by_ratio
            .then_with(|| self.cnt0.cmp(&other.cnt0))
            .then_with(|| self.cnt1.cmp(&other.cnt1))
    }
}

impl AddAssign for Count01 {
    /// Adds `other`'s counts into `self`, component-wise.
    fn add_assign(&mut self, other: Self) {
        self.cnt0 += other.cnt0;
        self.cnt1 += other.cnt1;
    }
}

impl Count01 {
    /// Creates a count with `cnt0` zeros and `cnt1` ones.
    pub fn new(cnt0: usize, cnt1: usize) -> Self {
        Self { cnt0, cnt1 }
    }
}
41
42pub fn solve_01_on_tree(
43 n: usize,
44 c01: impl Fn(usize) -> (usize, usize),
45 root: usize,
46 parent: impl Fn(usize) -> usize,
47) -> (usize, Vec<usize>) {
48 pub type UF<T, M> =
49 UnionFindBase<(), union_find::PathCompression, union_find::FnMerger<T, M>, (), ()>;
50 let mut cost = 0usize;
51 let c01 = |u| {
52 let c = c01(u);
53 Count01::new(c.0, c.1)
54 };
55 let mut uf = UF::new_with_merger(n, &c01, |x, y| {
56 cost += x.cnt1 * y.cnt0;
57 *x += *y;
58 });
59 let mut label = vec![0; n];
60 let mut heap = BinaryHeap::from_iter((0..n).filter(|&u| u != root).map(|u| (c01(u), u, 0)));
61 let mut next: Vec<_> = (0..n).collect();
62 let mut ord = Vec::with_capacity(n);
63 while let Some((_c, u, l)) = heap.pop() {
64 if label[u] != l {
65 continue;
66 }
67 let p = uf.find_root(parent(u));
68 uf.unite(u, p);
69 if !uf.same(p, root) {
70 label[p] += 1;
71 heap.push((*uf.merge_data(p), p, label[p]));
72 }
73 next.swap(u, p);
74 }
75 let mut u = next[root];
76 ord.push(u);
77 while u != root {
78 u = next[u];
79 ord.push(u);
80 }
81 ord.reverse();
82 (cost, ord)
83}