// competitive/algorithm/combinations.rs

1use std::{collections::BTreeSet, mem::swap};

/// Combinatorial enumeration and in-place rearrangement helpers for slices.
///
/// The `for_each_*` methods call a closure once per generated sequence. The closure
/// receives a borrowed slice that is only valid for the duration of that call (the
/// `FnMut(&[T])` bound does not let it escape); copy it with `to_vec()` to keep it.
pub trait SliceCombinationsExt<T> {
    /// Calls `f` with every length-`r` sequence of elements of `self`, repetition allowed
    /// (the Cartesian power `self^r`).
    fn for_each_product<F>(&self, r: usize, f: F)
    where
        F: FnMut(&[T]);

    /// Calls `f` with every ordered selection of `r` elements taken from distinct positions
    /// of `self` (the `r`-permutations). Nothing is produced when `r > self.len()`.
    fn for_each_permutations<F>(&self, r: usize, f: F)
    where
        F: FnMut(&[T]);

    /// Calls `f` with every selection of `r` elements taken from distinct positions of
    /// `self`, kept in their original relative order (the `r`-combinations). Nothing is
    /// produced when `r > self.len()`.
    fn for_each_combinations<F>(&self, r: usize, f: F)
    where
        F: FnMut(&[T]);

    /// Calls `f` with every selection of `r` elements whose positions are non-decreasing,
    /// i.e. combinations in which a position may be chosen more than once.
    fn for_each_combinations_with_replacement<F>(&self, r: usize, f: F)
    where
        F: FnMut(&[T]);

    /// Rearranges `self` into the lexicographically next permutation and returns `true`.
    /// Returns `false` and leaves `self` unchanged when it is already the last permutation.
    fn next_permutation(&mut self) -> bool
    where
        T: Ord;

    /// Rearranges `self` into the lexicographically previous permutation and returns `true`.
    /// Returns `false` and leaves `self` unchanged when it is already the first permutation.
    fn prev_permutation(&mut self) -> bool
    where
        T: Ord;

    /// Treats `self[..r]` as the chosen elements and `self[r..]` as the rest, and advances
    /// to the lexicographically next combination of `r` elements, returning `true`.
    /// When already at the last combination, wraps around to the first one and returns
    /// `false`.
    fn next_combination(&mut self, r: usize) -> bool
    where
        T: Ord;

    /// Treats `self[..r]` as the chosen elements and `self[r..]` as the rest, and steps back
    /// to the lexicographically previous combination of `r` elements, returning `true`.
    /// When already at the first combination, wraps around to the last one and returns
    /// `false`.
    fn prev_combination(&mut self, r: usize) -> bool
    where
        T: Ord;

    /// Reorders the elements in place so that the new `self[i]` is the old
    /// `self[permutation[i]]`.
    fn apply_permutation(&mut self, permutation: &[usize]);
}

32impl<T> SliceCombinationsExt<T> for [T]
33where
34    T: Clone,
35{
36    /// choose `r` elements from `n` independently
37    ///
38    /// # Example
39    ///
40    /// ```
41    /// # use competitive::algorithm::SliceCombinationsExt;
42    /// let n = vec![1, 2, 3, 4];
43    /// let mut p = Vec::new();
44    /// let mut q = Vec::new();
45    /// n.for_each_product(2, |v| p.push(v.to_vec()));
46    /// for x in n.iter().cloned() {
47    ///     for y in n.iter().cloned() {
48    ///         q.push(vec![x, y]);
49    ///     }
50    /// }
51    /// assert_eq!(p, q);
52    /// ```
53    fn for_each_product<F>(&self, r: usize, mut f: F)
54    where
55        F: FnMut(&[T]),
56    {
57        fn product_inner<T, F>(n: &[T], mut r: usize, buf: &mut Vec<T>, f: &mut F)
58        where
59            T: Clone,
60            F: FnMut(&[T]),
61        {
62            if r == 0 {
63                f(buf.as_slice());
64            } else {
65                r -= 1;
66                for a in n.iter().cloned() {
67                    buf.push(a);
68                    product_inner(n, r, buf, f);
69                    buf.pop();
70                }
71            }
72        }
73
74        let mut v = Vec::with_capacity(r);
75        product_inner(self, r, &mut v, &mut f);
76    }
77
78    /// choose `r` elements from `n` independently
79    ///
80    /// # Example
81    ///
82    /// ```
83    /// # use competitive::algorithm::SliceCombinationsExt;
84    /// let n = vec![1, 2, 3, 4];
85    /// let mut p = Vec::new();
86    /// let mut q = Vec::new();
87    /// n.for_each_product(2, |v| p.push(v.to_vec()));
88    /// for x in n.iter().cloned() {
89    ///     for y in n.iter().cloned() {
90    ///         q.push(vec![x, y]);
91    ///     }
92    /// }
93    /// assert_eq!(p, q);
94    /// ```
95    fn for_each_permutations<F>(&self, r: usize, mut f: F)
96    where
97        F: FnMut(&[T]),
98    {
99        fn permutations_inner<T, F>(
100            n: &[T],
101            mut r: usize,
102            rem: &mut BTreeSet<usize>,
103            buf: &mut Vec<T>,
104            f: &mut F,
105        ) where
106            T: Clone,
107            F: FnMut(&[T]),
108        {
109            if r == 0 {
110                f(buf.as_slice());
111            } else {
112                r -= 1;
113                for i in rem.iter().cloned().collect::<Vec<_>>() {
114                    buf.push(n[i].clone());
115                    rem.remove(&i);
116                    permutations_inner(n, r, rem, buf, f);
117                    rem.insert(i);
118                    buf.pop();
119                }
120            }
121        }
122
123        if r <= self.len() {
124            let mut v = Vec::with_capacity(r);
125            let mut rem: BTreeSet<usize> = (0..self.len()).collect();
126            permutations_inner(self, r, &mut rem, &mut v, &mut f);
127        }
128    }
129
130    /// choose distinct `r` elements from `n` in any order
131    ///
132    /// # Example
133    ///
134    /// ```
135    /// # use competitive::algorithm::SliceCombinationsExt;
136    /// let n = vec![1, 2, 3, 4];
137    /// let mut p = Vec::new();
138    /// let mut q = Vec::new();
139    /// n.for_each_permutations(2, |v| p.push(v.to_vec()));
140    /// for (i, x) in n.iter().cloned().enumerate() {
141    ///     for (j, y) in n.iter().cloned().enumerate() {
142    ///         if i != j {
143    ///             q.push(vec![x, y]);
144    ///         }
145    ///     }
146    /// }
147    /// assert_eq!(p, q);
148    /// ```
149    fn for_each_combinations<F>(&self, r: usize, mut f: F)
150    where
151        F: FnMut(&[T]),
152    {
153        fn combinations_inner<T, F>(
154            n: &[T],
155            mut r: usize,
156            start: usize,
157            buf: &mut Vec<T>,
158            f: &mut F,
159        ) where
160            T: Clone,
161            F: FnMut(&[T]),
162        {
163            if r == 0 {
164                f(buf.as_slice());
165            } else {
166                r -= 1;
167                for i in start..n.len() - r {
168                    buf.push(n[i].clone());
169                    combinations_inner(n, r, i + 1, buf, f);
170                    buf.pop();
171                }
172            }
173        }
174
175        if r <= self.len() {
176            let mut v = Vec::with_capacity(r);
177            combinations_inner(self, r, 0, &mut v, &mut f);
178        }
179    }
180
181    /// choose `r` elements from `n` in sorted order
182    ///
183    /// # Example
184    ///
185    /// ```
186    /// # use competitive::algorithm::SliceCombinationsExt;
187    /// let n = vec![1, 2, 3, 4];
188    /// let mut p = Vec::new();
189    /// let mut q = Vec::new();
190    /// n.for_each_combinations_with_replacement(2, |v| p.push(v.to_vec()));
191    /// for (i, x) in n.iter().cloned().enumerate() {
192    ///     for y in n[i..].iter().cloned() {
193    ///         q.push(vec![x, y]);
194    ///     }
195    /// }
196    /// assert_eq!(p, q);
197    /// ```
198    fn for_each_combinations_with_replacement<F>(&self, r: usize, mut f: F)
199    where
200        F: FnMut(&[T]),
201    {
202        fn combinations_with_replacement_inner<T, F>(
203            n: &[T],
204            mut r: usize,
205            start: usize,
206            buf: &mut Vec<T>,
207            f: &mut F,
208        ) where
209            T: Clone,
210            F: FnMut(&[T]),
211        {
212            if r == 0 {
213                f(buf.as_slice());
214            } else {
215                r -= 1;
216                for i in start..n.len() {
217                    buf.push(n[i].clone());
218                    combinations_with_replacement_inner(n, r, i, buf, f);
219                    buf.pop();
220                }
221            }
222        }
223
224        let mut v = Vec::with_capacity(r);
225        combinations_with_replacement_inner(self, r, 0, &mut v, &mut f);
226    }
227
228    /// Permute the elements into next permutation in lexicographical order.
229    /// Return whether such a next permutation exists.
230    fn next_permutation(&mut self) -> bool
231    where
232        T: Ord,
233    {
234        if self.len() < 2 {
235            return false;
236        }
237        let mut target = self.len() - 2;
238        while target > 0 && self[target] > self[target + 1] {
239            target -= 1;
240        }
241        if target == 0 && self[target] > self[target + 1] {
242            return false;
243        }
244        let mut next = self.len() - 1;
245        while next > target && self[next] < self[target] {
246            next -= 1;
247        }
248        self.swap(next, target);
249        self[target + 1..].reverse();
250        true
251    }
252
253    /// Permute the elements into previous permutation in lexicographical order.
254    /// Return whether such a previous permutation exists.
255    fn prev_permutation(&mut self) -> bool
256    where
257        T: Ord,
258    {
259        if self.len() < 2 {
260            return false;
261        }
262        let mut target = self.len() - 2;
263        while target > 0 && self[target] < self[target + 1] {
264            target -= 1;
265        }
266        if target == 0 && self[target] < self[target + 1] {
267            return false;
268        }
269        self[target + 1..].reverse();
270        let mut next = self.len() - 1;
271        while next > target && self[next - 1] < self[target] {
272            next -= 1;
273        }
274        self.swap(target, next);
275        true
276    }
277
278    /// Permute the elements into next combination choosing r elements in lexicographical order.
279    /// Return whether such a next combination exists.
280    fn next_combination(&mut self, r: usize) -> bool
281    where
282        T: Ord,
283    {
284        assert!(r <= self.len());
285        let (a, b) = self.split_at_mut(r);
286        next_combination_inner(a, b)
287    }
288
289    /// Permute the elements into previous combination choosing r elements in lexicographical order.
290    /// Return whether such a previous combination exists.
291    fn prev_combination(&mut self, r: usize) -> bool
292    where
293        T: Ord,
294    {
295        assert!(r <= self.len());
296        let (a, b) = self.split_at_mut(r);
297        next_combination_inner(b, a)
298    }
299
300    /// Apply a permutation to the elements.
301    /// self[i] <- self[p[i]] for each i
302    fn apply_permutation(&mut self, p: &[usize]) {
303        assert_eq!(self.len(), p.len());
304        let mut visited = vec![false; self.len()];
305        for mut current in 0..self.len() {
306            if visited[current] {
307                continue;
308            }
309            loop {
310                visited[current] = true;
311                let next = p[current];
312                if visited[next] {
313                    break;
314                }
315                self.swap(current, next);
316                current = next;
317            }
318        }
319    }
320}
321
/// Rotates the virtual concatenation `left ++ right` to the left by `left.len()`, so the
/// old contents of `right` come first, followed by the old contents of `left`.
///
/// The two slices need not be adjacent in memory. The rotation is done purely with
/// `swap_with_slice` block swaps; each swap fixes the final position of the elements it
/// moves into the front block, so at most `left.len() + right.len()` element swaps occur.
fn rotate_distinct<'a, T>(mut left: &'a mut [T], mut right: &'a mut [T]) {
    loop {
        let (left_len, right_len) = (left.len(), right.len());
        if left_len == 0 || right_len == 0 {
            break;
        }
        if left_len < right_len {
            // `left` now receives its final values. The displaced old `left` sits at the
            // front of `right` and still has to be rotated past the rest of `right`.
            let (front, back) = right.split_at_mut(left_len);
            front.swap_with_slice(left);
            left = front;
            right = back;
        } else {
            // The front of `left` receives old `right` (final position). The remaining tail
            // of `left` followed by `right` still needs rotating.
            let (front, back) = left.split_at_mut(right_len);
            front.swap_with_slice(right);
            left = back;
        }
    }
}

337fn next_combination_inner<T>(a: &mut [T], b: &mut [T]) -> bool
338where
339    T: Ord,
340{
341    if a.is_empty() || b.is_empty() {
342        return false;
343    }
344    let mut target = a.len() - 1;
345    let last_elem = b.last().unwrap();
346    while target > 0 && &a[target] >= last_elem {
347        target -= 1;
348    }
349    if target == 0 && &a[target] >= last_elem {
350        rotate_distinct(a, b);
351        return false;
352    }
353    let mut next = 0;
354    while a[target] >= b[next] {
355        next += 1;
356    }
357    swap(&mut a[target], &mut b[next]);
358    rotate_distinct(&mut a[target + 1..], &mut b[next + 1..]);
359    true
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;

    /// Visits every length-`r` sequence over `0..n` in lexicographic order (an odometer).
    /// For `r == 0` it visits the empty sequence exactly once.
    fn each_sequence(n: i32, r: usize, mut visit: impl FnMut(&[i32])) {
        let mut digits = vec![0; r];
        loop {
            visit(&digits);
            // Bump the rightmost digit that can still grow and reset everything after it.
            match (0..r).rev().find(|&i| digits[i] + 1 < n) {
                Some(i) => {
                    digits[i] += 1;
                    for d in &mut digits[i + 1..] {
                        *d = 0;
                    }
                }
                None => break,
            }
        }
    }

    #[test]
    fn test_for_each_product() {
        for n in 1..=6 {
            let values: Vec<i32> = (0..n).collect();
            for r in 0..=6 {
                let mut actual = Vec::new();
                values.for_each_product(r, |s| actual.push(s.to_vec()));
                let mut expected = Vec::new();
                each_sequence(n, r, |s| expected.push(s.to_vec()));
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_for_each_permutations_small_cases() {
        for n in 1..=6 {
            let values: Vec<i32> = (0..n).collect();
            for r in 0..=6 {
                let mut actual = Vec::new();
                values.for_each_permutations(r, |s| actual.push(s.to_vec()));
                let mut expected = Vec::new();
                each_sequence(n, r, |s| {
                    // Keep only sequences whose entries are pairwise distinct.
                    let mut distinct = s.to_vec();
                    distinct.sort_unstable();
                    distinct.dedup();
                    if distinct.len() == r {
                        expected.push(s.to_vec());
                    }
                });
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_for_each_combinations_small_cases() {
        for n in 1..=6 {
            let values: Vec<i32> = (0..n).collect();
            for r in 0..=6 {
                let mut actual = Vec::new();
                values.for_each_combinations(r, |s| actual.push(s.to_vec()));
                let mut expected = Vec::new();
                each_sequence(n, r, |s| {
                    if s.windows(2).all(|w| w[0] < w[1]) {
                        expected.push(s.to_vec());
                    }
                });
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_for_each_combinations_with_replacement_small_cases() {
        for n in 1..=6 {
            let values: Vec<i32> = (0..n).collect();
            for r in 0..=6 {
                let mut actual = Vec::new();
                values.for_each_combinations_with_replacement(r, |s| actual.push(s.to_vec()));
                let mut expected = Vec::new();
                each_sequence(n, r, |s| {
                    if s.windows(2).all(|w| w[0] <= w[1]) {
                        expected.push(s.to_vec());
                    }
                });
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_next_prev_permutation() {
        for n in 1..=7usize {
            let mut perm: Vec<usize> = (0..n).collect();
            let mut expected = Vec::new();
            perm.for_each_permutations(n, |q| expected.push(q.to_vec()));
            let mut visited = vec![perm.clone()];
            while perm.next_permutation() {
                // Stepping back and forth again must land on the same permutations.
                assert!(perm.prev_permutation());
                assert_eq!(visited.last().unwrap(), &perm);
                assert!(perm.next_permutation());
                visited.push(perm.clone());
            }
            assert_eq!(expected, visited);
        }
    }

    #[test]
    fn test_next_prev_combination() {
        for n in 1..=7usize {
            for r in 0..=n {
                let mut items: Vec<usize> = (0..n).collect();
                let mut expected = Vec::new();
                items.for_each_combinations(r, |c| expected.push(c.to_vec()));
                let mut visited = vec![items[..r].to_vec()];
                while items.next_combination(r) {
                    assert!(items.prev_combination(r));
                    assert_eq!(visited.last().unwrap().as_slice(), &items[..r]);
                    assert!(items.next_combination(r));
                    visited.push(items[..r].to_vec());
                }
                assert_eq!(expected, visited);
            }
        }
    }

    #[test]
    fn test_apply_permutation() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            let n = rng.random(1..100);
            let source: Vec<_> = rng.random_iter(0..1_000).take(n).collect();
            let mut perm: Vec<usize> = (0..n).collect();
            rng.shuffle(&mut perm);
            let mut actual = source.clone();
            actual.apply_permutation(&perm);
            let expected: Vec<_> = perm.iter().map(|&i| source[i]).collect();
            assert_eq!(expected, actual);
        }
    }
}