// competitive/algorithm/bitdp.rs
use super::{One, Zero};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Not, Shl, Shr, Sub};
3
4pub trait BitDpExt:
5 Sized
6 + Copy
7 + Default
8 + PartialEq
9 + Eq
10 + PartialOrd
11 + Ord
12 + Not<Output = Self>
13 + BitAnd<Output = Self>
14 + BitOr<Output = Self>
15 + BitXor<Output = Self>
16 + Shl<usize, Output = Self>
17 + Shr<usize, Output = Self>
18 + Add<Output = Self>
19 + Sub<Output = Self>
20 + Div<Output = Self>
21 + Zero
22 + One
23{
24 fn contains(self, x: usize) -> bool {
25 self & (Self::one() << x) != Self::zero()
26 }
27 fn insert(self, x: usize) -> Self {
28 self | (Self::one() << x)
29 }
30 fn remove(self, x: usize) -> Self {
31 self & !(Self::one() << x)
32 }
33 fn is_subset(self, elements: Self) -> bool {
34 self & elements == elements
35 }
36 fn is_superset(self, elements: Self) -> bool {
37 elements.is_subset(self)
38 }
39 fn subsets(self) -> Subsets<Self> {
40 Subsets {
41 mask: self,
42 cur: Some(self),
43 }
44 }
45 fn combinations(n: usize, k: usize) -> Combinations<Self> {
46 Combinations {
47 mask: Self::one() << n,
48 cur: Some((Self::one() << k) - Self::one()),
49 }
50 }
51}
52
53impl BitDpExt for u8 {}
54impl BitDpExt for u16 {}
55impl BitDpExt for u32 {}
56impl BitDpExt for u64 {}
57impl BitDpExt for u128 {}
58impl BitDpExt for usize {}
59
/// Iterator state for enumerating the sub-masks of a bit mask.
#[derive(Debug, Clone)]
pub struct Subsets<T> {
    /// The mask whose sub-masks are enumerated.
    mask: T,
    /// The next value to yield; `None` once iteration has finished.
    cur: Option<T>,
}
65
66impl<T> Iterator for Subsets<T>
67where
68 T: BitDpExt,
69{
70 type Item = T;
71 fn next(&mut self) -> Option<Self::Item> {
72 if let Some(cur) = self.cur {
73 self.cur = if cur == T::zero() {
74 None
75 } else {
76 Some((cur - T::one()) & self.mask)
77 };
78 Some(cur)
79 } else {
80 None
81 }
82 }
83}
84
/// Iterator state for enumerating bit masks with a fixed number of set bits.
#[derive(Debug, Clone)]
pub struct Combinations<T> {
    /// Exclusive upper bound: iteration stops once `cur` is no longer below it.
    mask: T,
    /// The next value to yield; `None` once iteration has finished.
    cur: Option<T>,
}
90
91impl<T> Iterator for Combinations<T>
92where
93 T: BitDpExt,
94{
95 type Item = T;
96 fn next(&mut self) -> Option<Self::Item> {
97 if let Some(cur) = self.cur {
98 if cur < self.mask {
99 self.cur = if cur == T::zero() {
100 None
101 } else {
102 let x = cur & (!cur + T::one());
103 let y = cur + x;
104 Some(((cur & !y) / x / (T::one() + T::one())) | y)
105 };
106 Some(cur)
107 } else {
108 None
109 }
110 } else {
111 None
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    // 0b1010 has exactly bits 1 and 3 set.
    #[test]
    fn test_contains() {
        assert!(!0b1010u8.contains(0));
        assert!(0b1010u8.contains(1));
        assert!(!0b1010u8.contains(2));
        assert!(0b1010u8.contains(3));
    }

    // Inserting an already-set bit leaves the value unchanged.
    #[test]
    fn test_insert() {
        assert_eq!(0b1010u8.insert(0), 0b1011);
        assert_eq!(0b1010u8.insert(1), 0b1010);
        assert_eq!(0b1010u8.insert(2), 0b1110);
        assert_eq!(0b1010u8.insert(3), 0b1010);
    }

    // Removing an unset bit leaves the value unchanged.
    #[test]
    fn test_remove() {
        assert_eq!(0b1010u8.remove(0), 0b1010);
        assert_eq!(0b1010u8.remove(1), 0b1000);
        assert_eq!(0b1010u8.remove(2), 0b1010);
        assert_eq!(0b1010u8.remove(3), 0b0010);
    }

    // `a.is_subset(b)` holds when every bit of `b` is set in `a`.
    #[test]
    fn test_is_subset() {
        assert!(0b1010u8.is_subset(0b1010));
        assert!(0b1010u8.is_subset(0b0000));
        assert!(!0b1010u8.is_subset(0b0100));
        assert!(!0b1010u8.is_subset(0b10000));
    }

    // `a.is_superset(b)` holds when every bit of `a` is set in `b`.
    #[test]
    fn test_is_superset() {
        assert!(0b1010u8.is_superset(0b1010));
        assert!(0b1010u8.is_superset(0b1111));
        assert!(!0b1010u8.is_superset(0b0000));
        assert!(!0b1010u8.is_superset(0b10000));
    }

    // For every 12-bit mask: the iterator yields 2^popcount values, each one
    // is contained in the mask, and no value is repeated.
    #[test]
    fn test_subsets() {
        for mask in 0usize..1 << 12 {
            let mut subsets = mask.subsets().collect::<Vec<_>>();
            let n = subsets.len();
            assert_eq!(n, 1 << mask.count_ones());
            assert!(subsets.iter().all(|&s| mask.is_subset(s)));
            subsets.sort_unstable();
            subsets.dedup();
            assert_eq!(n, subsets.len());
        }
    }

    // For every 0 <= k <= n <= 12: the iterator yields C(n, k) values, each
    // with exactly k set bits, and no value is repeated.
    #[test]
    fn test_combinations() {
        // comb[i][j] counts lattice paths from (0, 0) to (i, j), i.e.
        // C(i + j, i); hence comb[n - k][k] == C(n, k).
        let mut comb = vec![vec![0; 14]; 14];
        comb[0][0] = 1;
        for i in 0..=12 {
            for j in 0..=12 {
                comb[i + 1][j] += comb[i][j];
                comb[i][j + 1] += comb[i][j];
            }
        }

        for n in 0..=12 {
            for k in 0..=n {
                let mut combinations = usize::combinations(n, k).collect::<Vec<_>>();
                let len = combinations.len();
                assert_eq!(len, comb[n - k][k]);
                assert!(combinations.iter().all(|&s| s.count_ones() as usize == k));
                combinations.sort_unstable();
                combinations.dedup();
                assert_eq!(len, combinations.len());
            }
        }
    }
}