competitive/data_structure/
bitset.rs1use std::{
2 cmp::Ordering,
3 ops::{
4 BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
5 ShrAssign,
6 },
7};
8
/// A fixed-length sequence of bits packed into 64-bit words.
///
/// Bit `i` is stored at bit position `i % 64` of word `i / 64`, and
/// `bits.len() == size.div_ceil(64)`. The bits of the last word at indices
/// `>= size` ("padding") are meant to stay zero; whole-word operations such as
/// `count_ones` and the derived `PartialEq`/`Hash` rely on that.
///
/// The derived `PartialOrd`/`Ord` compare `size` first and then the words
/// lexicographically (lowest word first), so field order here is significant.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct BitSet {
    /// Number of valid bits.
    size: usize,
    /// Backing storage, lowest-index word first.
    bits: Vec<u64>,
}
14
15impl BitSet {
16 pub fn new(size: usize) -> Self {
17 Self {
18 size,
19 bits: vec![0; size.div_ceil(64)],
20 }
21 }
22
23 pub fn len(&self) -> usize {
24 self.size
25 }
26
27 pub fn is_empty(&self) -> bool {
28 self.size == 0
29 }
30
31 pub fn ones(size: usize) -> Self {
32 let mut self_ = Self {
33 size,
34 bits: vec![u64::MAX; size.div_ceil(64)],
35 };
36 self_.trim();
37 self_
38 }
39
40 pub fn get(&self, i: usize) -> bool {
41 self.bits[i >> 6] & (1 << (i & 63)) != 0
42 }
43
44 pub fn set(&mut self, i: usize, b: bool) {
45 if b {
46 self.bits[i >> 6] |= 1 << (i & 63);
47 } else {
48 self.bits[i >> 6] &= !(1 << (i & 63));
49 }
50 }
51
52 pub fn count_ones(&self) -> u64 {
53 self.bits.iter().map(|x| x.count_ones() as u64).sum()
54 }
55
56 pub fn count_zeros(&self) -> u64 {
57 self.size as u64 - self.count_ones()
58 }
59
60 pub fn push(&mut self, b: bool) {
61 let d = self.size & 63;
62 if d == 0 {
63 self.bits.push(b as u64);
64 } else {
65 *self.bits.last_mut().unwrap() |= (b as u64) << d;
66 }
67 self.size += 1;
68 }
69
70 pub fn resize(&mut self, new_size: usize) {
71 match self.size.cmp(&new_size) {
72 Ordering::Less => self.bits.resize(new_size.div_ceil(64), 0),
73 Ordering::Equal => {}
74 Ordering::Greater => self.bits.truncate(new_size.div_ceil(64)),
75 }
76 self.size = new_size;
77 self.trim();
78 }
79
80 fn trim(&mut self) {
81 if self.size & 63 != 0
82 && let Some(x) = self.bits.last_mut()
83 {
84 *x &= 0xffff_ffff_ffff_ffff >> (64 - (self.size & 63));
85 }
86 }
87
88 pub fn shl_bitor_assign(&mut self, rhs: usize) {
89 let n = self.bits.len();
90 let k = rhs >> 6;
91 let d = rhs & 63;
92 if k < n {
93 if d == 0 {
94 for i in (0..n - k).rev() {
95 self.bits[i + k] |= self.bits[i];
96 }
97 } else {
98 for i in (1..n - k).rev() {
99 self.bits[i + k] |= (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
100 }
101 self.bits[k] |= self.bits[0] << d;
102 }
103 self.trim();
104 }
105 }
106
107 pub fn shr_bitor_assign(&mut self, rhs: usize) {
108 let n = self.bits.len();
109 let k = rhs >> 6;
110 let d = rhs & 63;
111 if k < n {
112 if d == 0 {
113 for i in k..n {
114 self.bits[i - k] |= self.bits[i];
115 }
116 } else {
117 for i in k..n - 1 {
118 self.bits[i - k] |= (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
119 }
120 self.bits[n - k - 1] |= self.bits[n - 1] >> d;
121 }
122 }
123 }
124}
125
126impl Extend<bool> for BitSet {
127 fn extend<T: IntoIterator<Item = bool>>(&mut self, iter: T) {
128 let d = self.size & 63;
129 let mut iter = iter.into_iter();
130 let Some(first) = iter.next() else {
131 return;
132 };
133 if d == 0 {
134 self.bits.push(0);
135 }
136 let mut e = self.bits.last_mut().unwrap();
137 *e |= (first as u64) << d;
138 self.size += 1;
139 for b in iter {
140 let d = self.size & 63;
141 if d == 0 {
142 self.bits.push(b as u64);
143 e = self.bits.last_mut().unwrap();
144 } else {
145 *e |= (b as u64) << d;
146 }
147 self.size += 1;
148 }
149 }
150}
151
152impl FromIterator<bool> for BitSet {
153 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
154 let mut set = BitSet::new(0);
155 set.extend(iter);
156 set
157 }
158}
159
160impl ShlAssign<usize> for BitSet {
161 fn shl_assign(&mut self, rhs: usize) {
162 let n = self.bits.len();
163 let k = rhs >> 6;
164 let d = rhs & 63;
165 if k >= n {
166 for x in self.bits.iter_mut() {
167 *x = 0;
168 }
169 } else {
170 if d == 0 {
171 for i in (0..n - k).rev() {
172 self.bits[i + k] = self.bits[i];
173 }
174 } else {
175 for i in (1..n - k).rev() {
176 self.bits[i + k] = (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
177 }
178 self.bits[k] = self.bits[0] << d;
179 }
180 for x in self.bits[..k].iter_mut() {
181 *x = 0;
182 }
183 self.trim();
184 }
185 }
186}
187
188impl Shl<usize> for BitSet {
189 type Output = Self;
190 fn shl(mut self, rhs: usize) -> Self::Output {
191 self <<= rhs;
192 self
193 }
194}
195
196impl ShrAssign<usize> for BitSet {
197 fn shr_assign(&mut self, rhs: usize) {
198 let n = self.bits.len();
199 let k = rhs >> 6;
200 let d = rhs & 63;
201 if k >= n {
202 for x in self.bits.iter_mut() {
203 *x = 0;
204 }
205 } else {
206 if d == 0 {
207 for i in k..n {
208 self.bits[i - k] = self.bits[i];
209 }
210 } else {
211 for i in k..n - 1 {
212 self.bits[i - k] = (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
213 }
214 self.bits[n - k - 1] = self.bits[n - 1] >> d;
215 }
216 for x in self.bits[n - k..].iter_mut() {
217 *x = 0;
218 }
219 }
220 }
221}
222
223impl Shr<usize> for BitSet {
224 type Output = Self;
225 fn shr(mut self, rhs: usize) -> Self::Output {
226 self >>= rhs;
227 self
228 }
229}
230
231impl<'a> BitOrAssign<&'a BitSet> for BitSet {
232 fn bitor_assign(&mut self, rhs: &'a Self) {
233 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
234 *l |= *r;
235 }
236 self.trim();
237 }
238}
239
240impl<'a> BitOr<&'a BitSet> for BitSet {
241 type Output = Self;
242 fn bitor(mut self, rhs: &'a Self) -> Self::Output {
243 self |= rhs;
244 self
245 }
246}
247
248impl<'b> BitOr<&'b BitSet> for &BitSet {
249 type Output = BitSet;
250 fn bitor(self, rhs: &'b BitSet) -> Self::Output {
251 let mut res = self.clone();
252 res |= rhs;
253 res
254 }
255}
256
257impl<'a> BitAndAssign<&'a BitSet> for BitSet {
258 fn bitand_assign(&mut self, rhs: &'a Self) {
259 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
260 *l &= *r;
261 }
262 }
263}
264
265impl<'a> BitAnd<&'a BitSet> for BitSet {
266 type Output = Self;
267 fn bitand(mut self, rhs: &'a Self) -> Self::Output {
268 self &= rhs;
269 self
270 }
271}
272
273impl<'b> BitAnd<&'b BitSet> for &BitSet {
274 type Output = BitSet;
275 fn bitand(self, rhs: &'b BitSet) -> Self::Output {
276 let mut res = self.clone();
277 res &= rhs;
278 res
279 }
280}
281
282impl<'a> BitXorAssign<&'a BitSet> for BitSet {
283 fn bitxor_assign(&mut self, rhs: &'a Self) {
284 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
285 *l ^= *r;
286 }
287 self.trim();
288 }
289}
290
291impl<'a> BitXor<&'a BitSet> for BitSet {
292 type Output = Self;
293 fn bitxor(mut self, rhs: &'a Self) -> Self::Output {
294 self ^= rhs;
295 self
296 }
297}
298
299impl<'b> BitXor<&'b BitSet> for &BitSet {
300 type Output = BitSet;
301 fn bitxor(self, rhs: &'b BitSet) -> Self::Output {
302 let mut res = self.clone();
303 res ^= rhs;
304 res
305 }
306}
307
308impl Not for BitSet {
309 type Output = Self;
310 fn not(mut self) -> Self::Output {
311 for x in self.bits.iter_mut() {
312 *x = !*x;
313 }
314 self.trim();
315 self
316 }
317}
318
319impl Not for &BitSet {
320 type Output = BitSet;
321 fn not(self) -> Self::Output {
322 !self.clone()
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    // Every test creates its RNG once, outside the loop. Re-creating
    // `Xorshift::default()` per iteration would replay the same case 100 times.

    /// Collects `op(a[i] != 0, b[i] != 0)` for each index into a `BitSet`.
    fn zip_with(a: &[u32], b: &[u32], op: impl Fn(bool, bool) -> bool) -> BitSet {
        a.iter().zip(b).map(|(&p, &q)| op(p != 0, q != 0)).collect()
    }

    #[test]
    fn test_access() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200);
            let mut bitset = BitSet::new(n);
            let mut arr = vec![false; n];
            for _ in 0..200 {
                rand!(rng, i: 0..n, b: 0..=1u32);
                bitset.set(i, b != 0);
                arr[i] = b != 0;
                assert_eq!(bitset.get(i), arr[i]);
            }
            assert_eq!(
                bitset.count_ones(),
                arr.iter().filter(|&&x| x).count() as u64
            );
            assert_eq!(
                bitset.count_zeros(),
                arr.iter().filter(|&&x| !x).count() as u64
            );
        }
    }

    #[test]
    fn test_ones() {
        for n in 0..=200 {
            let ones = BitSet::ones(n);
            assert_eq!(ones.len(), n);
            assert_eq!(ones.count_ones(), n as u64);
            assert_eq!(ones.count_zeros(), 0);
            assert_eq!(ones, (0..n).map(|_| true).collect::<BitSet>());
            assert_eq!(!BitSet::new(n), ones);
        }
    }

    #[test]
    fn test_push() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 0..=200, arr: [0..=1u32; n]);
            let mut bitset = BitSet::new(0);
            for &x in &arr {
                bitset.push(x != 0);
            }
            assert_eq!(bitset.len(), n);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }

    #[test]
    fn test_resize() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, m: 0..=400, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.resize(m);
            assert_eq!(bitset.len(), m);
            // Surviving bits keep their value; bits added by growing are zero.
            for i in 0..m {
                assert_eq!(bitset.get(i), arr.get(i).is_some_and(|&x| x != 0));
            }
            assert_eq!(
                bitset.count_ones(),
                arr.iter().take(m).filter(|&&x| x != 0).count() as u64
            );
        }
    }

    #[test]
    fn test_shl_bitor_assign() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shl_bitor_assign(k);
            for i in (k..n).rev() {
                arr[i] |= arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shr_bitor_assign() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shr_bitor_assign(k);
            for i in k..n {
                arr[i - k] |= arr[i];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shl() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset <<= k;
            let mut arr2 = vec![0; n];
            for i in (k..n).rev() {
                arr2[i] = arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shr() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset >>= k;
            let mut arr2 = vec![0; n];
            for (i, &a) in arr.iter().enumerate().skip(k) {
                arr2[i - k] = a;
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_bitops() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, a: [0..=1u32; n], b: [0..=1u32; n]);
            let x: BitSet = a.iter().map(|&v| v != 0).collect();
            let y: BitSet = b.iter().map(|&v| v != 0).collect();
            assert_eq!(&x | &y, zip_with(&a, &b, |p, q| p | q));
            assert_eq!(&x & &y, zip_with(&a, &b, |p, q| p & q));
            assert_eq!(&x ^ &y, zip_with(&a, &b, |p, q| p ^ q));
            assert_eq!(x.clone() | &y, &x | &y);
            assert_eq!(x.clone() & &y, &x & &y);
            assert_eq!(x.clone() ^ &y, &x ^ &y);
            assert_eq!(!&x, a.iter().map(|&v| v == 0).collect::<BitSet>());
            assert_eq!((!&x).count_ones() + x.count_ones(), n as u64);
        }
    }

    #[test]
    fn test_extend() {
        let mut rng = Xorshift::default();
        for _ in 0..100 {
            rand!(rng, arr: [0..=1u32; 200], n1: 0..=200);
            let mut bitset: BitSet = arr[..n1].iter().map(|&x| x != 0).collect();
            bitset.extend(arr[n1..].iter().map(|&x| x != 0));
            assert_eq!(bitset.len(), 200);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }
}