// competitive/data_structure/union_find.rs

1use super::{Group, Monoid};
2use std::{
3    collections::HashMap,
4    fmt::{self, Debug},
5    marker::PhantomData,
6    mem::swap,
7};
8
9pub struct UnionFindBase<U, F, M, P, H>
10where
11    U: UnionStrategy,
12    F: FindStrategy,
13    M: UfMergeSpec,
14    P: Monoid,
15    H: UndoStrategy<UfCell<U, M, P>>,
16{
17    cells: Vec<UfCell<U, M, P>>,
18    merger: M,
19    history: H::History,
20    _marker: PhantomData<fn() -> F>,
21}
22
23impl<U, F, M, P, H> Clone for UnionFindBase<U, F, M, P, H>
24where
25    U: UnionStrategy<Info: Clone>,
26    F: FindStrategy,
27    M: UfMergeSpec<Data: Clone> + Clone,
28    P: Monoid,
29    H: UndoStrategy<UfCell<U, M, P>, History: Clone>,
30{
31    fn clone(&self) -> Self {
32        Self {
33            cells: self.cells.clone(),
34            merger: self.merger.clone(),
35            history: self.history.clone(),
36            _marker: self._marker,
37        }
38    }
39}
40
41impl<U, F, M, P, H> Debug for UnionFindBase<U, F, M, P, H>
42where
43    U: UnionStrategy<Info: Debug>,
44    F: FindStrategy,
45    M: UfMergeSpec<Data: Debug>,
46    P: Monoid<T: Debug>,
47    H: UndoStrategy<UfCell<U, M, P>, History: Debug>,
48{
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("UnionFindBase")
51            .field("cells", &self.cells)
52            .field("history", &self.history)
53            .finish()
54    }
55}
56
57pub enum UfCell<U, M, P>
58where
59    U: UnionStrategy,
60    M: UfMergeSpec,
61    P: Monoid,
62{
63    Root((U::Info, M::Data)),
64    Child((usize, P::T)),
65}
66
67impl<U, M, P> Clone for UfCell<U, M, P>
68where
69    U: UnionStrategy<Info: Clone>,
70    M: UfMergeSpec<Data: Clone>,
71    P: Monoid,
72{
73    fn clone(&self) -> Self {
74        match self {
75            Self::Root(data) => Self::Root(data.clone()),
76            Self::Child(data) => Self::Child(data.clone()),
77        }
78    }
79}
80
81impl<U, M, P> Debug for UfCell<U, M, P>
82where
83    U: UnionStrategy<Info: Debug>,
84    M: UfMergeSpec<Data: Debug>,
85    P: Monoid<T: Debug>,
86{
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            Self::Root(data) => f.debug_tuple("Root").field(data).finish(),
90            Self::Child(data) => f.debug_tuple("Child").field(data).finish(),
91        }
92    }
93}
94
95impl<U, M, P> UfCell<U, M, P>
96where
97    U: UnionStrategy,
98    M: UfMergeSpec,
99    P: Monoid,
100{
101    fn root_mut(&mut self) -> Option<&mut (U::Info, M::Data)> {
102        match self {
103            UfCell::Root(root) => Some(root),
104            UfCell::Child(_) => None,
105        }
106    }
107    fn child_mut(&mut self) -> Option<&mut (usize, P::T)> {
108        match self {
109            UfCell::Child(child) => Some(child),
110            UfCell::Root(_) => None,
111        }
112    }
113}
114
/// Policy controlling whether root lookups rewrite the visited links.
pub trait FindStrategy {
    /// (sic) When `true`, `find`/`find_root` re-point each visited node directly at its
    /// root and fold the path's potential into it (path compression).
    const CHENGE_ROOT: bool;
}

/// Enables path compression during lookups.
pub enum PathCompression {}

impl FindStrategy for PathCompression {
    const CHENGE_ROOT: bool = true;
}

/// Leaves links untouched during lookups; needed when unions must be undoable.
impl FindStrategy for () {
    const CHENGE_ROOT: bool = false;
}
128
/// Policy deciding which of two roots becomes the parent when sets are linked.
pub trait UnionStrategy {
    /// Per-root bookkeeping (e.g. set size or rank).
    type Info: Clone;

    /// Info of a freshly created singleton set.
    fn single_info() -> Self::Info;

    /// (sic) Returns `true` if the root with `parent` info may stay on top of the root with
    /// `child` info; `false` asks the caller to swap the two.
    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool;

    /// Info of the merged root after `child`'s root is attached below `parent`'s root.
    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info;
}
135
136pub enum UnionBySize {}
137
138impl UnionStrategy for UnionBySize {
139    type Info = usize;
140
141    fn single_info() -> Self::Info {
142        1
143    }
144
145    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool {
146        parent >= child
147    }
148
149    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info {
150        parent + child
151    }
152}
153
154pub enum UnionByRank {}
155
156impl UnionStrategy for UnionByRank {
157    type Info = u32;
158
159    fn single_info() -> Self::Info {
160        0
161    }
162
163    fn check_directoin(parent: &Self::Info, child: &Self::Info) -> bool {
164        parent >= child
165    }
166
167    fn unite(parent: &Self::Info, child: &Self::Info) -> Self::Info {
168        parent + (parent == child) as u32
169    }
170}
171
172impl UnionStrategy for () {
173    type Info = ();
174
175    fn single_info() -> Self::Info {}
176
177    fn check_directoin(_parent: &Self::Info, _child: &Self::Info) -> bool {
178        false
179    }
180
181    fn unite(_parent: &Self::Info, _child: &Self::Info) -> Self::Info {}
182}
183
/// Combines the data attached to two roots when their sets are united.
pub trait UfMergeSpec {
    /// Data stored at every root.
    type Data;

    /// Folds `from` (the root being absorbed) into `to` (the surviving root).
    fn merge(&mut self, to: &mut Self::Data, from: &mut Self::Data);
}
188
189#[derive(Debug, Clone)]
190pub struct FnMerger<T, F> {
191    f: F,
192    _marker: PhantomData<fn() -> T>,
193}
194
195impl<T, F> UfMergeSpec for FnMerger<T, F>
196where
197    F: FnMut(&mut T, &mut T),
198{
199    type Data = T;
200
201    fn merge(&mut self, to: &mut Self::Data, from: &mut Self::Data) {
202        (self.f)(to, from)
203    }
204}
205
206impl UfMergeSpec for () {
207    type Data = ();
208
209    fn merge(&mut self, _to: &mut Self::Data, _from: &mut Self::Data) {}
210}
211
/// Policy deciding whether unions can be rolled back; `T` is the cell type.
pub trait UndoStrategy<T> {
    /// `true` for the recording strategy (`Undoable`), `false` for `()`.
    const UNDOABLE: bool;

    /// Whatever the strategy keeps between calls.
    type History: Default;

    /// Invoked by the union-find before roots `x` and `y` are modified, with the current cells.
    fn unite(history: &mut Self::History, x: usize, y: usize, cells: &[T]);

    /// Reverts the most recently recorded `unite`, if the strategy recorded anything.
    fn undo_unite(history: &mut Self::History, cells: &mut [T]);
}
221
222pub enum Undoable {}
223
224impl<T> UndoStrategy<T> for Undoable
225where
226    T: Clone,
227{
228    const UNDOABLE: bool = true;
229
230    type History = Vec<[(usize, T); 2]>;
231
232    fn unite(history: &mut Self::History, x: usize, y: usize, cells: &[T]) {
233        let cx = cells[x].clone();
234        let cy = cells[y].clone();
235        history.push([(x, cx), (y, cy)]);
236    }
237
238    fn undo_unite(history: &mut Self::History, cells: &mut [T]) {
239        if let Some([(x, cx), (y, cy)]) = history.pop() {
240            cells[x] = cx;
241            cells[y] = cy;
242        }
243    }
244}
245
246impl<T> UndoStrategy<T> for () {
247    const UNDOABLE: bool = false;
248
249    type History = ();
250
251    fn unite(_history: &mut Self::History, _x: usize, _y: usize, _cells: &[T]) {}
252
253    fn undo_unite(_history: &mut Self::History, _cells: &mut [T]) {}
254}
255
256impl<U, F, P, H> UnionFindBase<U, F, (), P, H>
257where
258    U: UnionStrategy,
259    F: FindStrategy,
260    P: Monoid,
261    H: UndoStrategy<UfCell<U, (), P>>,
262{
263    pub fn new(n: usize) -> Self {
264        let cells: Vec<_> = (0..n)
265            .map(|_| UfCell::Root((U::single_info(), ())))
266            .collect();
267        Self {
268            cells,
269            merger: (),
270            history: Default::default(),
271            _marker: PhantomData,
272        }
273    }
274    pub fn push(&mut self) {
275        self.cells.push(UfCell::Root((U::single_info(), ())));
276    }
277}
278
279impl<U, F, T, Merge, P, H> UnionFindBase<U, F, FnMerger<T, Merge>, P, H>
280where
281    U: UnionStrategy,
282    F: FindStrategy,
283    Merge: FnMut(&mut T, &mut T),
284    P: Monoid,
285    H: UndoStrategy<UfCell<U, FnMerger<T, Merge>, P>>,
286{
287    pub fn new_with_merger(n: usize, mut init: impl FnMut(usize) -> T, merge: Merge) -> Self {
288        let cells: Vec<_> = (0..n)
289            .map(|i| UfCell::Root((U::single_info(), init(i))))
290            .collect();
291        Self {
292            cells,
293            merger: FnMerger {
294                f: merge,
295                _marker: PhantomData,
296            },
297            history: Default::default(),
298            _marker: PhantomData,
299        }
300    }
301}
302
303impl<F, M, P, H> UnionFindBase<UnionBySize, F, M, P, H>
304where
305    F: FindStrategy,
306    M: UfMergeSpec,
307    P: Monoid,
308    H: UndoStrategy<UfCell<UnionBySize, M, P>>,
309{
310    pub fn size(&mut self, x: usize) -> <UnionBySize as UnionStrategy>::Info {
311        let root = self.find_root(x);
312        self.root_info(root).unwrap()
313    }
314}
315
316impl<U, F, M, P, H> UnionFindBase<U, F, M, P, H>
317where
318    U: UnionStrategy,
319    F: FindStrategy,
320    M: UfMergeSpec,
321    P: Monoid,
322    H: UndoStrategy<UfCell<U, M, P>>,
323{
324    fn root_info(&mut self, x: usize) -> Option<U::Info> {
325        match &self.cells[x] {
326            UfCell::Root((info, _)) => Some(info.clone()),
327            UfCell::Child(_) => None,
328        }
329    }
330
331    fn root_info_mut(&mut self, x: usize) -> Option<&mut U::Info> {
332        match &mut self.cells[x] {
333            UfCell::Root((info, _)) => Some(info),
334            UfCell::Child(_) => None,
335        }
336    }
337
338    pub fn same(&mut self, x: usize, y: usize) -> bool {
339        self.find_root(x) == self.find_root(y)
340    }
341
342    pub fn merge_data(&mut self, x: usize) -> &M::Data {
343        let root = self.find_root(x);
344        match &self.cells[root] {
345            UfCell::Root((_, data)) => data,
346            UfCell::Child(_) => unreachable!(),
347        }
348    }
349
350    pub fn merge_data_mut(&mut self, x: usize) -> &mut M::Data {
351        let root = self.find_root(x);
352        match &mut self.cells[root] {
353            UfCell::Root((_, data)) => data,
354            UfCell::Child(_) => unreachable!(),
355        }
356    }
357
358    pub fn roots(&self) -> impl Iterator<Item = usize> + '_ {
359        (0..self.cells.len()).filter(|&x| matches!(self.cells[x], UfCell::Root(_)))
360    }
361
362    pub fn all_group_members(&mut self) -> HashMap<usize, Vec<usize>> {
363        let mut groups_map = HashMap::new();
364        for x in 0..self.cells.len() {
365            let r = self.find_root(x);
366            groups_map.entry(r).or_insert_with(Vec::new).push(x);
367        }
368        groups_map
369    }
370
371    pub fn find(&mut self, x: usize) -> (usize, P::T) {
372        let (parent_parent, parent_potential) = match &self.cells[x] {
373            UfCell::Child((parent, _)) => self.find(*parent),
374            UfCell::Root(_) => return (x, P::unit()),
375        };
376        let (parent, potential) = self.cells[x].child_mut().unwrap();
377        let potential = if F::CHENGE_ROOT {
378            *parent = parent_parent;
379            *potential = P::operate(&parent_potential, potential);
380            potential.clone()
381        } else {
382            P::operate(&parent_potential, potential)
383        };
384        (parent_parent, potential)
385    }
386
387    pub fn find_root(&mut self, x: usize) -> usize {
388        let (parent, parent_parent) = match &self.cells[x] {
389            UfCell::Child((parent, _)) => (*parent, self.find_root(*parent)),
390            UfCell::Root(_) => return x,
391        };
392        if F::CHENGE_ROOT {
393            let (cx, cp) = {
394                let ptr = self.cells.as_mut_ptr();
395                unsafe { (&mut *ptr.add(x), &*ptr.add(parent)) }
396            };
397            let (parent, potential) = cx.child_mut().unwrap();
398            *parent = parent_parent;
399            if let UfCell::Child((_, ppot)) = &cp {
400                *potential = P::operate(ppot, potential);
401            }
402        }
403        parent_parent
404    }
405
406    pub fn unite_noninv(&mut self, x: usize, y: usize, potential: P::T) -> bool {
407        let (rx, potx) = self.find(x);
408        let ry = self.find_root(y);
409        if rx == ry || y != ry {
410            return false;
411        }
412        H::unite(&mut self.history, rx, ry, &self.cells);
413        {
414            let ptr = self.cells.as_mut_ptr();
415            let (cx, cy) = unsafe { (&mut *ptr.add(rx), &mut *ptr.add(ry)) };
416            self.merger
417                .merge(&mut cx.root_mut().unwrap().1, &mut cy.root_mut().unwrap().1);
418        }
419        *self.root_info_mut(rx).unwrap() =
420            U::unite(&self.root_info(rx).unwrap(), &self.root_info(ry).unwrap());
421        self.cells[ry] = UfCell::Child((rx, P::operate(&potx, &potential)));
422        true
423    }
424}
425
426impl<U, F, M, P, H> UnionFindBase<U, F, M, P, H>
427where
428    U: UnionStrategy,
429    F: FindStrategy,
430    M: UfMergeSpec,
431    P: Group,
432    H: UndoStrategy<UfCell<U, M, P>>,
433{
434    pub fn difference(&mut self, x: usize, y: usize) -> Option<P::T> {
435        let (rx, potx) = self.find(x);
436        let (ry, poty) = self.find(y);
437        if rx == ry {
438            Some(P::operate(&P::inverse(&potx), &poty))
439        } else {
440            None
441        }
442    }
443
444    pub fn unite_with(&mut self, x: usize, y: usize, potential: P::T) -> bool {
445        let (mut rx, potx) = self.find(x);
446        let (mut ry, poty) = self.find(y);
447        if rx == ry {
448            return false;
449        }
450        let mut xinfo = self.root_info(rx).unwrap();
451        let mut yinfo = self.root_info(ry).unwrap();
452        let inverse = !U::check_directoin(&xinfo, &yinfo);
453        let potential = if inverse {
454            P::rinv_operate(&poty, &P::operate(&potx, &potential))
455        } else {
456            P::operate(&potx, &P::rinv_operate(&potential, &poty))
457        };
458        if inverse {
459            swap(&mut rx, &mut ry);
460            swap(&mut xinfo, &mut yinfo);
461        }
462        H::unite(&mut self.history, rx, ry, &self.cells);
463        {
464            let ptr = self.cells.as_mut_ptr();
465            let (cx, cy) = unsafe { (&mut *ptr.add(rx), &mut *ptr.add(ry)) };
466            self.merger
467                .merge(&mut cx.root_mut().unwrap().1, &mut cy.root_mut().unwrap().1);
468        }
469        *self.root_info_mut(rx).unwrap() = U::unite(&xinfo, &yinfo);
470        self.cells[ry] = UfCell::Child((rx, potential));
471        true
472    }
473
474    pub fn unite(&mut self, x: usize, y: usize) -> bool {
475        self.unite_with(x, y, P::unit())
476    }
477}
478
479impl<U, M, P, H> UnionFindBase<U, (), M, P, H>
480where
481    U: UnionStrategy,
482    M: UfMergeSpec,
483    P: Monoid,
484    H: UndoStrategy<UfCell<U, M, P>>,
485{
486    pub fn undo(&mut self) {
487        H::undo_unite(&mut self.history, &mut self.cells);
488    }
489}
490
491pub type UnionFind = UnionFindBase<UnionBySize, PathCompression, (), (), ()>;
492pub type MergingUnionFind<T, M> =
493    UnionFindBase<UnionBySize, PathCompression, FnMerger<T, M>, (), ()>;
494pub type PotentializedUnionFind<P> = UnionFindBase<UnionBySize, PathCompression, (), P, ()>;
495pub type UndoableUnionFind = UnionFindBase<UnionBySize, (), (), (), Undoable>;
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        algebra::{Invertible, LinearOperation, Magma, Unital},
        graph::UndirectedSparseGraph,
        num::mint_basic::MInt998244353 as M,
        rand,
        tools::Xorshift,
        tree::MixedTree,
    };
    use std::collections::HashSet;

    /// Returns `m` distinct ordered pairs `(x, y)` with `x, y < n`, in shuffled order.
    fn distinct_edges(rng: &mut Xorshift, n: usize, m: usize) -> Vec<(usize, usize)> {
        let mut pairs: Vec<(usize, usize)> = (0..n)
            .flat_map(|x| (0..n).map(move |y| (x, y)))
            .collect();
        rng.shuffle(&mut pairs);
        pairs.truncate(m);
        pairs
    }

    /// Depth-first traversal from `u`. Calls `on_enter(v)` for each vertex reached and
    /// `on_tree_edge(from, to, edge_id)` before descending along each tree edge.
    fn dfs(
        g: &UndirectedSparseGraph,
        u: usize,
        vis: &mut [bool],
        on_enter: &mut impl FnMut(usize),
        on_tree_edge: &mut impl FnMut(usize, usize, usize),
    ) {
        vis[u] = true;
        on_enter(u);
        for a in g.adjacencies(u) {
            if vis[a.to] {
                continue;
            }
            on_tree_edge(u, a.to, a.id);
            dfs(g, a.to, vis, on_enter, on_tree_edge);
        }
    }

    #[test]
    fn test_union_find() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        for _ in 0..1000 {
            rand!(rng, n: 1..=N, m: 1..=n * n);
            let edges = distinct_edges(&mut rng, n, m);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    let mut uf = UnionFindBase::<$union, $find, FnMerger<Vec<usize>, _>, (), ()>::new_with_merger(
                        n,
                        |i| vec![i],
                        |x, y| x.append(y),
                    );
                    for &(x, y) in &edges {
                        uf.unite(x, y);
                    }

                    // Label every vertex with the smallest vertex of its connected component.
                    let g = UndirectedSparseGraph::from_edges(n, edges.to_vec());
                    let mut label = vec![!0; n];
                    let mut visited = vec![false; n];
                    for start in 0..n {
                        if !visited[start] {
                            dfs(
                                &g,
                                start,
                                &mut visited,
                                &mut |v| {
                                    label[v] = start;
                                },
                                &mut |_, _, _| {},
                            );
                        }
                    }

                    for x in 0..n {
                        for y in 0..n {
                            assert_eq!(label[x] == label[y], uf.same(x, y));
                        }
                        let expected: HashSet<usize> =
                            (0..n).filter(|&y| label[x] == label[y]).collect();
                        let actual: HashSet<usize> = uf.merge_data(x).iter().copied().collect();
                        assert_eq!(expected, actual);
                    }
                }};
            }
            test_uf!(UnionBySize, PathCompression);
            test_uf!(UnionByRank, PathCompression);
            test_uf!((), PathCompression);
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }

    #[test]
    fn test_potential_union_find() {
        const N: usize = 20;
        let mut rng = Xorshift::default();
        type G = LinearOperation<M>;
        for _ in 0..1000 {
            rand!(rng, n: 1..=N, g: MixedTree(n), p: [(.., ..); n - 1], k: 0..n);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    let mut uf = UnionFindBase::<$union, $find, (), G, ()>::new(n);
                    // Only the first `k` tree edges are united.
                    for (&(u, v), &w) in g.edges.iter().zip(p.iter()).take(k) {
                        uf.unite_with(u, v, w);
                    }
                    for src in 0..n {
                        let mut seen = vec![false; n];
                        let mut expected = vec![None; n];
                        expected[src] = Some(G::unit());
                        dfs(&g, src, &mut seen, &mut |_| {}, &mut |u, to, id| {
                            let w = if g.edges[id] == (u, to) {
                                p[id]
                            } else {
                                G::inverse(&p[id])
                            };
                            if id >= k {
                                return;
                            }
                            if let Some(d) = expected[u] {
                                expected[to] = Some(G::operate(&d, &w));
                            }
                        });
                        for (y, d) in expected.into_iter().enumerate() {
                            assert_eq!(d, uf.difference(src, y));
                        }
                    }
                }};
            }
            test_uf!(UnionBySize, PathCompression);
            test_uf!(UnionByRank, PathCompression);
            test_uf!((), PathCompression);
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }

    #[test]
    fn test_undoable_union_find() {
        const N: usize = 10;
        const MAX_M: usize = 200;
        let mut rng = Xorshift::default();
        for _ in 0..10 {
            rand!(rng, n: 1..=N, m: 1..=MAX_M, g: MixedTree(m), p: [(0..n, 0..n); m]);

            macro_rules! test_uf {
                ($union:ty, $find:ty) => {{
                    // Walks the tree `g` from `u`. Along each tree edge it unites the edge's
                    // pair in a cloned reference structure and in `undoable`, recurses, and
                    // then rolls `undoable` back with `undo`.
                    fn walk(
                        n: usize,
                        g: &UndirectedSparseGraph,
                        u: usize,
                        vis: &mut [bool],
                        mut reference: UnionFind,
                        undoable: &mut UnionFindBase<$union, $find, (), (), Undoable>,
                        p: &[(usize, usize)],
                    ) {
                        vis[u] = true;
                        for x in 0..n {
                            for y in 0..n {
                                assert_eq!(reference.same(x, y), undoable.same(x, y));
                            }
                        }
                        for a in g.adjacencies(u) {
                            if vis[a.to] {
                                continue;
                            }
                            let (x, y) = p[a.id];
                            let mut next = reference.clone();
                            next.unite(x, y);
                            let merged = undoable.unite(x, y);
                            walk(n, g, a.to, vis, next, undoable, p);
                            if merged {
                                undoable.undo();
                            }
                        }
                    }

                    let reference = UnionFind::new(n);
                    let mut undoable = UnionFindBase::<$union, $find, (), (), Undoable>::new(n);
                    for u in 0..m {
                        walk(n, &g, u, &mut vec![false; m], reference.clone(), &mut undoable, &p);
                    }
                }};
            }
            test_uf!(UnionBySize, ());
            test_uf!(UnionByRank, ());
            test_uf!((), ());
        }
    }
}