competitive/data_structure/union_find.rs

1use super::{Group, Monoid};
2use std::{
3    collections::HashMap,
4    fmt::{self, Debug},
5    marker::PhantomData,
6    mem::swap,
7};
8
9pub struct UnionFindBase<U, F, M, P, H>
10where
11    U: UnionStrategy,
12    F: FindStrategy,
13    M: UfMergeSpec,
14    P: Monoid,
15    H: UndoStrategy<UfCell<U, M, P>>,
16{
17    cells: Vec<UfCell<U, M, P>>,
18    merger: M,
19    history: H::History,
20    _marker: PhantomData<fn() -> F>,
21}
22
23impl<U, F, M, P, H> Clone for UnionFindBase<U, F, M, P, H>
24where
25    U: UnionStrategy,
26    F: FindStrategy,
27    M: UfMergeSpec + Clone,
28    P: Monoid,
29    H: UndoStrategy<UfCell<U, M, P>>,
30    U::Info: Clone,
31    M::Data: Clone,
32    H::History: Clone,
33{
34    fn clone(&self) -> Self {
35        Self {
36            cells: self.cells.clone(),
37            merger: self.merger.clone(),
38            history: self.history.clone(),
39            _marker: self._marker,
40        }
41    }
42}
43
44impl<U, F, M, P, H> Debug for UnionFindBase<U, F, M, P, H>
45where
46    U: UnionStrategy,
47    F: FindStrategy,
48    M: UfMergeSpec,
49    P: Monoid,
50    H: UndoStrategy<UfCell<U, M, P>>,
51    U::Info: Debug,
52    M::Data: Debug,
53    P::T: Debug,
54    H::History: Debug,
55{
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("UnionFindBase")
58            .field("cells", &self.cells)
59            .field("history", &self.history)
60            .finish()
61    }
62}
63
64pub enum UfCell<U, M, P>
65where
66    U: UnionStrategy,
67    M: UfMergeSpec,
68    P: Monoid,
69{
70    Root((U::Info, M::Data)),
71    Child((usize, P::T)),
72}
73
74impl<U, M, P> Clone for UfCell<U, M, P>
75where
76    U: UnionStrategy,
77    M: UfMergeSpec,
78    P: Monoid,
79    U::Info: Clone,
80    M::Data: Clone,
81{
82    fn clone(&self) -> Self {
83        match self {
84            Self::Root(data) => Self::Root(data.clone()),
85            Self::Child(data) => Self::Child(data.clone()),
86        }
87    }
88}
89
90impl<U, M, P> Debug for UfCell<U, M, P>
91where
92    U: UnionStrategy,
93    M: UfMergeSpec,
94    P: Monoid,
95    U::Info: Debug,
96    M::Data: Debug,
97    P::T: Debug,
98{
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        match self {
101            Self::Root(data) => f.debug_tuple("Root").field(data).finish(),
102            Self::Child(data) => f.debug_tuple("Child").field(data).finish(),
103        }
104    }
105}
106
107impl<U, M, P> UfCell<U, M, P>
108where
109    U: UnionStrategy,
110    M: UfMergeSpec,
111    P: Monoid,
112{
113    fn root_mut(&mut self) -> Option<&mut (U::Info, M::Data)> {
114        match self {
115            UfCell::Root(root) => Some(root),
116            UfCell::Child(_) => None,
117        }
118    }
119    fn child_mut(&mut self) -> Option<&mut (usize, P::T)> {
120        match self {
121            UfCell::Child(child) => Some(child),
122            UfCell::Root(_) => None,
123        }
124    }
125}
126
/// Compile-time switch controlling path compression during finds.
///
/// The constant keeps its historical spelling `CHENGE_ROOT` ("change root")
/// so that existing implementors keep compiling.
pub trait FindStrategy {
    /// When `true`, each node visited by a find is relinked directly to its root.
    const CHENGE_ROOT: bool;
}

/// No path compression: finds leave parent links untouched.
impl FindStrategy for () {
    const CHENGE_ROOT: bool = false;
}

/// Path compression: finds relink every visited node straight to the root.
pub enum PathCompression {}

impl FindStrategy for PathCompression {
    const CHENGE_ROOT: bool = true;
}
140
/// Decides which root survives a union and what bookkeeping each root keeps.
///
/// The method `check_directoin` keeps its historical spelling for compatibility.
pub trait UnionStrategy {
    /// Bookkeeping stored at every root (for example a size or a rank).
    type Info: Clone;
    /// Info for a freshly created singleton set.
    fn single_info() -> Self::Info;
    /// Returns `true` if the root holding `parent` may stay root with the
    /// root holding `child` attached beneath it.
    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool;
    /// Info of the surviving root after `child`'s set is attached under `parent`.
    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info;
}

/// Union by size: the larger set keeps its root; on a tie the `parent` side wins.
pub enum UnionBySize {}

impl UnionStrategy for UnionBySize {
    /// Number of elements in the set.
    type Info = usize;

    fn single_info() -> Self::Info {
        1
    }

    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool {
        child <= parent
    }

    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info {
        *parent + *child
    }
}

/// Union by rank: the higher-rank root survives, and equal ranks bump the result by one.
pub enum UnionByRank {}

impl UnionStrategy for UnionByRank {
    /// Rank of the tree rooted here.
    type Info = u32;

    fn single_info() -> Self::Info {
        0
    }

    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool {
        child <= parent
    }

    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info {
        if parent == child {
            *parent + 1
        } else {
            *parent
        }
    }
}

/// No balancing. `check_directoin` always answers `false`, so the set passed
/// as `child` keeps its root.
impl UnionStrategy for () {
    type Info = ();

    fn single_info() -> Self::Info {}

    fn check_directoin(_: &Self::Info, _: &Self::Info) -> bool {
        false
    }

    fn unite(_: &Self::Info, _: &Self::Info) -> Self::Info {}
}
195
/// How the data attached to two sets is combined when the sets are united.
pub trait UfMergeSpec {
    /// Data stored at each root.
    type Data;
    /// Merges `from` into `to`. The union code passes the surviving root's data as `to`.
    fn merge(&mut self, to: &mut Self::Data, from: &mut Self::Data);
}

/// Merger that delegates to a closure `FnMut(&mut T, &mut T)`.
#[derive(Debug, Clone)]
pub struct FnMerger<T, F> {
    f: F,
    _marker: PhantomData<fn() -> T>,
}

impl<T, F> UfMergeSpec for FnMerger<T, F>
where
    F: FnMut(&mut T, &mut T),
{
    type Data = T;

    fn merge(&mut self, to: &mut Self::Data, from: &mut Self::Data) {
        let callback = &mut self.f;
        callback(to, from);
    }
}

/// No per-set data: merging is a no-op.
impl UfMergeSpec for () {
    type Data = ();

    fn merge(&mut self, _: &mut Self::Data, _: &mut Self::Data) {}
}
223
/// Controls whether unions are recorded so that they can be rolled back.
pub trait UndoStrategy<T> {
    /// Whether this strategy actually records anything.
    const UNDOABLE: bool;

    /// Storage for the recorded unions.
    type History: Default;

    /// Records the state of cells `x` and `y`. The union code calls this
    /// before it modifies either cell.
    fn unite(history: &mut Self::History, x: usize, y: usize, cells: &[T]);

    /// Reverts the most recently recorded union, or does nothing if none is recorded.
    fn undo_unite(history: &mut Self::History, cells: &mut [T]);
}

/// Snapshots both cells touched by each union so that the union can be undone.
pub enum Undoable {}

impl<T> UndoStrategy<T> for Undoable
where
    T: Clone,
{
    const UNDOABLE: bool = true;

    /// Each entry stores `(index, previous cell)` for the two cells involved.
    type History = Vec<[(usize, T); 2]>;

    fn unite(history: &mut Self::History, x: usize, y: usize, cells: &[T]) {
        let snapshot = [(x, cells[x].clone()), (y, cells[y].clone())];
        history.push(snapshot);
    }

    fn undo_unite(history: &mut Self::History, cells: &mut [T]) {
        if let Some(entry) = history.pop() {
            let [(x, old_x), (y, old_y)] = entry;
            cells[x] = old_x;
            cells[y] = old_y;
        }
    }
}

/// No history: unions are permanent.
impl<T> UndoStrategy<T> for () {
    const UNDOABLE: bool = false;

    type History = ();

    fn unite(_: &mut Self::History, _: usize, _: usize, _: &[T]) {}

    fn undo_unite(_: &mut Self::History, _: &mut [T]) {}
}
267
268impl<U, F, P, H> UnionFindBase<U, F, (), P, H>
269where
270    U: UnionStrategy,
271    F: FindStrategy,
272    P: Monoid,
273    H: UndoStrategy<UfCell<U, (), P>>,
274{
275    pub fn new(n: usize) -> Self {
276        let cells: Vec<_> = (0..n)
277            .map(|_| UfCell::Root((U::single_info(), ())))
278            .collect();
279        Self {
280            cells,
281            merger: (),
282            history: Default::default(),
283            _marker: PhantomData,
284        }
285    }
286    pub fn push(&mut self) {
287        self.cells.push(UfCell::Root((U::single_info(), ())));
288    }
289}
290
291impl<U, F, T, Merge, P, H> UnionFindBase<U, F, FnMerger<T, Merge>, P, H>
292where
293    U: UnionStrategy,
294    F: FindStrategy,
295    Merge: FnMut(&mut T, &mut T),
296    P: Monoid,
297    H: UndoStrategy<UfCell<U, FnMerger<T, Merge>, P>>,
298{
299    pub fn new_with_merger(n: usize, mut init: impl FnMut(usize) -> T, merge: Merge) -> Self {
300        let cells: Vec<_> = (0..n)
301            .map(|i| UfCell::Root((U::single_info(), init(i))))
302            .collect();
303        Self {
304            cells,
305            merger: FnMerger {
306                f: merge,
307                _marker: PhantomData,
308            },
309            history: Default::default(),
310            _marker: PhantomData,
311        }
312    }
313}
314
315impl<F, M, P, H> UnionFindBase<UnionBySize, F, M, P, H>
316where
317    F: FindStrategy,
318    M: UfMergeSpec,
319    P: Monoid,
320    H: UndoStrategy<UfCell<UnionBySize, M, P>>,
321{
322    pub fn size(&mut self, x: usize) -> <UnionBySize as UnionStrategy>::Info {
323        let root = self.find_root(x);
324        self.root_info(root).unwrap()
325    }
326}
327
328impl<U, F, M, P, H> UnionFindBase<U, F, M, P, H>
329where
330    U: UnionStrategy,
331    F: FindStrategy,
332    M: UfMergeSpec,
333    P: Monoid,
334    H: UndoStrategy<UfCell<U, M, P>>,
335{
336    fn root_info(&mut self, x: usize) -> Option<U::Info> {
337        match &self.cells[x] {
338            UfCell::Root((info, _)) => Some(info.clone()),
339            UfCell::Child(_) => None,
340        }
341    }
342
343    fn root_info_mut(&mut self, x: usize) -> Option<&mut U::Info> {
344        match &mut self.cells[x] {
345            UfCell::Root((info, _)) => Some(info),
346            UfCell::Child(_) => None,
347        }
348    }
349
350    pub fn same(&mut self, x: usize, y: usize) -> bool {
351        self.find_root(x) == self.find_root(y)
352    }
353
354    pub fn merge_data(&mut self, x: usize) -> &M::Data {
355        let root = self.find_root(x);
356        match &self.cells[root] {
357            UfCell::Root((_, data)) => data,
358            UfCell::Child(_) => unreachable!(),
359        }
360    }
361
362    pub fn merge_data_mut(&mut self, x: usize) -> &mut M::Data {
363        let root = self.find_root(x);
364        match &mut self.cells[root] {
365            UfCell::Root((_, data)) => data,
366            UfCell::Child(_) => unreachable!(),
367        }
368    }
369
370    pub fn roots(&self) -> impl Iterator<Item = usize> + '_ {
371        (0..self.cells.len()).filter(|&x| matches!(self.cells[x], UfCell::Root(_)))
372    }
373
374    pub fn all_group_members(&mut self) -> HashMap<usize, Vec<usize>> {
375        let mut groups_map = HashMap::new();
376        for x in 0..self.cells.len() {
377            let r = self.find_root(x);
378            groups_map.entry(r).or_insert_with(Vec::new).push(x);
379        }
380        groups_map
381    }
382
383    pub fn find(&mut self, x: usize) -> (usize, P::T) {
384        let (parent_parent, parent_potential) = match &self.cells[x] {
385            UfCell::Child((parent, _)) => self.find(*parent),
386            UfCell::Root(_) => return (x, P::unit()),
387        };
388        let (parent, potential) = self.cells[x].child_mut().unwrap();
389        let potential = if F::CHENGE_ROOT {
390            *parent = parent_parent;
391            *potential = P::operate(&parent_potential, potential);
392            potential.clone()
393        } else {
394            P::operate(&parent_potential, potential)
395        };
396        (parent_parent, potential)
397    }
398
399    pub fn find_root(&mut self, x: usize) -> usize {
400        let (parent, parent_parent) = match &self.cells[x] {
401            UfCell::Child((parent, _)) => (*parent, self.find_root(*parent)),
402            UfCell::Root(_) => return x,
403        };
404        if F::CHENGE_ROOT {
405            let (cx, cp) = {
406                let ptr = self.cells.as_mut_ptr();
407                unsafe { (&mut *ptr.add(x), &*ptr.add(parent)) }
408            };
409            let (parent, potential) = cx.child_mut().unwrap();
410            *parent = parent_parent;
411            if let UfCell::Child((_, ppot)) = &cp {
412                *potential = P::operate(ppot, potential);
413            }
414        }
415        parent_parent
416    }
417
418    pub fn unite_noninv(&mut self, x: usize, y: usize, potential: P::T) -> bool {
419        let (rx, potx) = self.find(x);
420        let ry = self.find_root(y);
421        if rx == ry || y != ry {
422            return false;
423        }
424        H::unite(&mut self.history, rx, ry, &self.cells);
425        {
426            let ptr = self.cells.as_mut_ptr();
427            let (cx, cy) = unsafe { (&mut *ptr.add(rx), &mut *ptr.add(ry)) };
428            self.merger
429                .merge(&mut cx.root_mut().unwrap().1, &mut cy.root_mut().unwrap().1);
430        }
431        *self.root_info_mut(rx).unwrap() =
432            U::unite(&self.root_info(rx).unwrap(), &self.root_info(ry).unwrap());
433        self.cells[ry] = UfCell::Child((rx, P::operate(&potx, &potential)));
434        true
435    }
436}
437
438impl<U, F, M, P, H> UnionFindBase<U, F, M, P, H>
439where
440    U: UnionStrategy,
441    F: FindStrategy,
442    M: UfMergeSpec,
443    P: Group,
444    H: UndoStrategy<UfCell<U, M, P>>,
445{
446    pub fn difference(&mut self, x: usize, y: usize) -> Option<P::T> {
447        let (rx, potx) = self.find(x);
448        let (ry, poty) = self.find(y);
449        if rx == ry {
450            Some(P::operate(&P::inverse(&potx), &poty))
451        } else {
452            None
453        }
454    }
455
456    pub fn unite_with(&mut self, x: usize, y: usize, potential: P::T) -> bool {
457        let (mut rx, potx) = self.find(x);
458        let (mut ry, poty) = self.find(y);
459        if rx == ry {
460            return false;
461        }
462        let mut xinfo = self.root_info(rx).unwrap();
463        let mut yinfo = self.root_info(ry).unwrap();
464        let inverse = !U::check_directoin(&xinfo, &yinfo);
465        let potential = if inverse {
466            P::rinv_operate(&poty, &P::operate(&potx, &potential))
467        } else {
468            P::operate(&potx, &P::rinv_operate(&potential, &poty))
469        };
470        if inverse {
471            swap(&mut rx, &mut ry);
472            swap(&mut xinfo, &mut yinfo);
473        }
474        H::unite(&mut self.history, rx, ry, &self.cells);
475        {
476            let ptr = self.cells.as_mut_ptr();
477            let (cx, cy) = unsafe { (&mut *ptr.add(rx), &mut *ptr.add(ry)) };
478            self.merger
479                .merge(&mut cx.root_mut().unwrap().1, &mut cy.root_mut().unwrap().1);
480        }
481        *self.root_info_mut(rx).unwrap() = U::unite(&xinfo, &yinfo);
482        self.cells[ry] = UfCell::Child((rx, potential));
483        true
484    }
485
486    pub fn unite(&mut self, x: usize, y: usize) -> bool {
487        self.unite_with(x, y, P::unit())
488    }
489}
490
491impl<U, M, P, H> UnionFindBase<U, (), M, P, H>
492where
493    U: UnionStrategy,
494    M: UfMergeSpec,
495    P: Monoid,
496    H: UndoStrategy<UfCell<U, M, P>>,
497{
498    pub fn undo(&mut self) {
499        H::undo_unite(&mut self.history, &mut self.cells);
500    }
501}
502
503pub type UnionFind = UnionFindBase<UnionBySize, PathCompression, (), (), ()>;
504pub type MergingUnionFind<T, M> =
505    UnionFindBase<UnionBySize, PathCompression, FnMerger<T, M>, (), ()>;
506pub type PotentializedUnionFind<P> = UnionFindBase<UnionBySize, PathCompression, (), P, ()>;
507pub type UndoableUnionFind = UnionFindBase<UnionBySize, (), (), (), Undoable>;
508
// Randomized tests for the union-find variants, checked against DFS-based
// reference answers on small random graphs.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{Invertible, LinearOperation, Magma, Unital},
        graph::UndirectedSparseGraph,
        num::mint_basic::MInt998244353 as M,
        rand,
        tools::{RandomSpec, Xorshift},
        tree::MixedTree,
    };
    use std::collections::HashSet;

    // Returns up to `m` distinct ordered pairs `(x, y)` with `x, y < n`
    // (self-loops included), in the order left by `rng.shuffle`.
    fn distinct_edges(rng: &mut Xorshift, n: usize, m: usize) -> Vec<(usize, usize)> {
        let mut edges = vec![];
        for x in 0..n {
            for y in 0..n {
                edges.push((x, y));
            }
        }
        rng.shuffle(&mut edges);
        edges.truncate(m);
        edges
    }

    // Recursive DFS from `u` over unvisited vertices. Calls `f(v)` when `v` is
    // first visited, and `f2(parent, child, edge_id)` just before descending
    // along a tree edge.
    fn dfs(
        g: &UndirectedSparseGraph,
        u: usize,
        vis: &mut [bool],
        f: &mut impl FnMut(usize),
        f2: &mut impl FnMut(usize, usize, usize),
    ) {
        vis[u] = true;
        f(u);
        for a in g.adjacencies(u) {
            if !vis[a.to] {
                f2(u, a.to, a.id);
                dfs(g, a.to, vis, f, f2);
            }
        }
    }

    // Random element of `M`, drawn with `rng.random(0..M::get_mod())`.
    struct Mspec;
    impl RandomSpec<M> for Mspec {
        fn rand(&self, rng: &mut Xorshift) -> M {
            M::new_unchecked(rng.random(0..M::get_mod()))
        }
    }

    // Unites random edges, then checks `same` and the merged member lists
    // against the connected components found by DFS, for every combination
    // of union strategy and find strategy.
    #[test]
    fn test_union_find() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            rand!(rng, n: 1..=N, m: 1..=n * n);
            let edges = distinct_edges(&mut rng, n, m);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    let mut uf = UnionFindBase::<$union, $find, FnMerger<Vec<usize>, _>, (), ()>::new_with_merger(n, |i| vec![i], |x, y| x.append(y));
                    for &(x, y) in &edges {
                        uf.unite(x, y);
                    }
                    let g = UndirectedSparseGraph::from_edges(n, edges.to_vec());
                    // id[v] = the vertex the DFS started from for v's component
                    // (the smallest vertex of that component).
                    let mut id = vec![!0; n];
                    {
                        let mut vis = vec![false; n];
                        for x in 0..n {
                            if vis[x] {
                                continue;
                            }
                            let mut set = HashSet::new();
                            dfs(
                                &g,
                                x,
                                &mut vis,
                                &mut |x| {
                                    set.insert(x);
                                },
                                &mut |_, _, _| {},
                            );
                            for s in set {
                                id[s] = x;
                            }
                        }
                    }
                    for x in 0..n {
                        for y in 0..n {
                            assert_eq!(id[x] == id[y], uf.same(x, y));
                        }
                        // The merged data of x's set must list exactly its component.
                        assert_eq!(
                            (0..n).filter(|&y| id[x] == id[y]).collect::<HashSet<_>>(),
                            uf.merge_data(x).iter().cloned().collect()
                        );
                    }
                }};
            }
            test_uf!(UnionBySize, PathCompression);
            test_uf!(UnionByRank, PathCompression);
            test_uf!((), PathCompression);
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }

    // Builds a random tree and unites only its first `k` edges, each with a
    // random `LinearOperation<M>` potential. Then checks `difference(x, y)`
    // against potentials accumulated along the tree by DFS (`None` when the
    // path uses an edge that was not united).
    #[test]
    fn test_potential_union_find() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        type G = LinearOperation<M>;
        for _ in 0..1000 {
            rand!(rng, n: 1..=N, g: MixedTree(n), p: [(Mspec, Mspec); n - 1], k: 0..n);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    let mut uf = UnionFindBase::<$union, $find, (), G, ()>::new(n);
                    for (i, &(u, v)) in g.edges.iter().enumerate().take(k) {
                        uf.unite_with(u, v, p[i]);
                    }
                    for x in 0..n {
                        let mut vis = vec![false; n];
                        let mut dp = vec![None; n];
                        dp[x] = Some(G::unit());
                        dfs(&g, x, &mut vis, &mut |_| {}, &mut |u, to, id| {
                            // Traversing an edge against its stored direction uses the inverse potential.
                            let p = if g.edges[id] == (u, to) {
                                p[id]
                            } else {
                                G::inverse(&p[id])
                            };
                            // Only the first `k` edges were united.
                            if id < k {
                                if let Some(d) = dp[u] {
                                    dp[to] = Some(G::operate(&d, &p));
                                }
                            }
                        });
                        for (y, d) in dp.into_iter().enumerate() {
                            assert_eq!(d, uf.difference(x, y));
                        }
                    }
                }};
            }
            test_uf!(UnionBySize, PathCompression);
            test_uf!(UnionByRank, PathCompression);
            test_uf!((), PathCompression);
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }

    // Walks a random tree of union operations depth-first. Each tree edge
    // applies a union on the way down, and `undo` reverts it on the way back.
    // At every node the undoable structure must agree with a plain `UnionFind`
    // cloned and built along the same path.
    #[test]
    fn test_undoable_union_find() {
        const N: usize = 10;
        const M: usize = 200;
        let mut rng = Xorshift::default();
        for _ in 0..10 {
            rand!(rng, n: 1..=N, m: 1..=M, g: MixedTree(m), p: [(0..n, 0..n); m]);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    let uf = UnionFind::new(n);
                    let mut uf2 = UnionFindBase::<$union, $find, (), (), Undoable>::new(n);
                    // `uf` is passed by value (a fresh clone per branch), while
                    // `uf2` is shared and relies on `undo` when backtracking.
                    fn dfs(
                        n: usize,
                        g: &UndirectedSparseGraph,
                        u: usize,
                        vis: &mut [bool],
                        mut uf: UnionFindBase<UnionBySize, PathCompression, (), (), ()>,
                        uf2: &mut UnionFindBase<$union, $find, (), (), Undoable>,
                        p: &[(usize, usize)],
                    ) {
                        vis[u] = true;
                        for x in 0..n {
                            for y in 0..n {
                                assert_eq!(uf.same(x, y), uf2.same(x, y));
                            }
                        }
                        for a in g.adjacencies(u) {
                            if !vis[a.to] {
                                let (x, y) = p[a.id];
                                let mut uf = uf.clone();
                                uf.unite(x, y);
                                let merged = uf2.unite(x, y);
                                dfs(n, g, a.to, vis, uf, uf2, p);
                                // Only successful unions are recorded, so only those are undone.
                                if merged {
                                    uf2.undo();
                                }
                            }
                        }
                    }
                    // Start the walk from every vertex of the tree.
                    for u in 0..m {
                        dfs(n, &g, u, &mut vec![false; m], uf.clone(), &mut uf2, &p);
                    }
                }};
            }
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }
}