competitive/algorithm/bitdp.rs

1use super::{One, Zero};
2use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Not, Shl, Shr, Sub};
3
4pub trait BitDpExt:
5    Sized
6    + Copy
7    + Default
8    + PartialEq
9    + Eq
10    + PartialOrd
11    + Ord
12    + Not<Output = Self>
13    + BitAnd<Output = Self>
14    + BitOr<Output = Self>
15    + BitXor<Output = Self>
16    + Shl<usize, Output = Self>
17    + Shr<usize, Output = Self>
18    + Add<Output = Self>
19    + Sub<Output = Self>
20    + Div<Output = Self>
21    + Zero
22    + One
23{
24    fn contains(self, x: usize) -> bool {
25        self & (Self::one() << x) != Self::zero()
26    }
27    fn insert(self, x: usize) -> Self {
28        self | (Self::one() << x)
29    }
30    fn remove(self, x: usize) -> Self {
31        self & !(Self::one() << x)
32    }
33    fn is_subset(self, elements: Self) -> bool {
34        self & elements == elements
35    }
36    fn is_superset(self, elements: Self) -> bool {
37        elements.is_subset(self)
38    }
39    fn subsets(self) -> Subsets<Self> {
40        Subsets {
41            mask: self,
42            cur: Some(self),
43        }
44    }
45    fn combinations(n: usize, k: usize) -> Combinations<Self> {
46        Combinations {
47            mask: Self::one() << n,
48            cur: Some((Self::one() << k) - Self::one()),
49        }
50    }
51}
52
53impl BitDpExt for u8 {}
54impl BitDpExt for u16 {}
55impl BitDpExt for u32 {}
56impl BitDpExt for u64 {}
57impl BitDpExt for u128 {}
58impl BitDpExt for usize {}
59
60#[derive(Debug, Clone)]
61pub struct Subsets<T> {
62    mask: T,
63    cur: Option<T>,
64}
65
66impl<T> Iterator for Subsets<T>
67where
68    T: BitDpExt,
69{
70    type Item = T;
71    fn next(&mut self) -> Option<Self::Item> {
72        if let Some(cur) = self.cur {
73            self.cur = if cur == T::zero() {
74                None
75            } else {
76                Some((cur - T::one()) & self.mask)
77            };
78            Some(cur)
79        } else {
80            None
81        }
82    }
83}
84
85#[derive(Debug, Clone)]
86pub struct Combinations<T> {
87    mask: T,
88    cur: Option<T>,
89}
90
91impl<T> Iterator for Combinations<T>
92where
93    T: BitDpExt,
94{
95    type Item = T;
96    fn next(&mut self) -> Option<Self::Item> {
97        if let Some(cur) = self.cur {
98            if cur < self.mask {
99                self.cur = if cur == T::zero() {
100                    None
101                } else {
102                    let x = cur & (!cur + T::one());
103                    let y = cur + x;
104                    Some(((cur & !y) / x / (T::one() + T::one())) | y)
105                };
106                Some(cur)
107            } else {
108                None
109            }
110        } else {
111            None
112        }
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_contains() {
        assert!(!0b1010u8.contains(0));
        assert!(0b1010u8.contains(1));
        assert!(!0b1010u8.contains(2));
        assert!(0b1010u8.contains(3));
    }

    #[test]
    fn test_insert() {
        assert_eq!(0b1010u8.insert(0), 0b1011);
        assert_eq!(0b1010u8.insert(1), 0b1010);
        assert_eq!(0b1010u8.insert(2), 0b1110);
        assert_eq!(0b1010u8.insert(3), 0b1010);
    }

    #[test]
    fn test_remove() {
        assert_eq!(0b1010u8.remove(0), 0b1010);
        assert_eq!(0b1010u8.remove(1), 0b1000);
        assert_eq!(0b1010u8.remove(2), 0b1010);
        assert_eq!(0b1010u8.remove(3), 0b0010);
    }

    #[test]
    fn test_is_subset() {
        // `a.is_subset(b)` asks whether `b ⊆ a`.
        assert!(0b1010u8.is_subset(0b1010));
        assert!(0b1010u8.is_subset(0b0000));
        assert!(!0b1010u8.is_subset(0b0100));
        assert!(!0b1010u8.is_subset(0b10000));
    }

    #[test]
    fn test_is_superset() {
        // `a.is_superset(b)` asks whether `a ⊆ b`.
        assert!(0b1010u8.is_superset(0b1010));
        assert!(0b1010u8.is_superset(0b1111));
        assert!(!0b1010u8.is_superset(0b0000));
        assert!(!0b1010u8.is_superset(0b10000));
    }

    #[test]
    fn test_subsets() {
        for mask in 0usize..1 << 12 {
            let subsets = mask.subsets().collect::<Vec<_>>();
            assert_eq!(subsets.len(), 1 << mask.count_ones());
            assert!(subsets.iter().all(|&s| mask.is_subset(s)));
            // Strictly decreasing from `mask` down to 0; this also rules out
            // duplicates.
            assert!(subsets.windows(2).all(|w| w[0] > w[1]));
            assert_eq!(subsets.first(), Some(&mask));
            assert_eq!(subsets.last(), Some(&0));
        }
    }

    #[test]
    fn test_combinations() {
        // binom[i][j] == C(i + j, i), filled in with the lattice-path
        // recurrence, so binom[n - k][k] == C(n, k).
        let mut binom = vec![vec![0; 14]; 14];
        binom[0][0] = 1;
        for i in 0..=12 {
            for j in 0..=12 {
                binom[i + 1][j] += binom[i][j];
                binom[i][j + 1] += binom[i][j];
            }
        }

        for n in 0..=12 {
            for k in 0..=n {
                let combinations = usize::combinations(n, k).collect::<Vec<_>>();
                assert_eq!(combinations.len(), binom[n - k][k]);
                assert!(combinations.iter().all(|&s| s.count_ones() as usize == k));
                // Every combination must stay inside the universe {0, ..., n - 1}.
                assert!(combinations.iter().all(|&s| s >> n == 0));
                // Strictly increasing; this also rules out duplicates.
                assert!(combinations.windows(2).all(|w| w[0] < w[1]));
            }
            // Choosing more elements than exist yields nothing.
            assert_eq!(usize::combinations(n, n + 1).count(), 0);
        }
    }
}