//! `BitSet`: a fixed-length bit sequence packed into `u64` words, with shift and bitwise
//! operators (`competitive/data_structure/bitset.rs`).

// The shift-assign impls carry bits across word boundaries with the opposite shift
// (`ShlAssign` uses `>>`, `ShrAssign` uses `<<`), which clippy's
// `suspicious_op_assign_impl` lint reports as a likely operator mix-up. It is intentional.
#![allow(clippy::suspicious_op_assign_impl)]

use std::ops::{
    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, ShlAssign, Shr,
    ShrAssign,
};

/// A fixed-length sequence of bits packed into 64-bit words.
///
/// Bit `i` is stored in word `i / 64` at bit position `i % 64`.
///
/// Invariants the methods in this module rely on:
/// - `bits.len() == ceil(size / 64)`;
/// - bits of the last word at positions `>= size` (the padding) are zero. `count_ones`,
///   `count_zeros` and the derived `PartialEq`/`Hash` all read whole words and depend on it.
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BitSet {
    /// Number of logical bits in the set.
    size: usize,
    /// Backing storage; word `w` holds bits `64 * w .. 64 * w + 64`.
    bits: Vec<u64>,
}

14impl BitSet {
15 pub fn new(size: usize) -> Self {
16 Self {
17 size,
18 bits: vec![0; (size + 63) / 64],
19 }
20 }
21
22 pub fn len(&self) -> usize {
23 self.size
24 }
25
26 pub fn is_empty(&self) -> bool {
27 self.size == 0
28 }
29
30 pub fn ones(size: usize) -> Self {
31 let mut self_ = Self {
32 size,
33 bits: vec![u64::MAX; (size + 63) / 64],
34 };
35 self_.trim();
36 self_
37 }
38
39 pub fn get(&self, i: usize) -> bool {
40 self.bits[i >> 6] & (1 << (i & 63)) != 0
41 }
42
43 pub fn set(&mut self, i: usize, b: bool) {
44 if b {
45 self.bits[i >> 6] |= 1 << (i & 63);
46 } else {
47 self.bits[i >> 6] &= !(1 << (i & 63));
48 }
49 }
50
51 pub fn count_ones(&self) -> u64 {
52 self.bits.iter().map(|x| x.count_ones() as u64).sum()
53 }
54
55 pub fn count_zeros(&self) -> u64 {
56 self.size as u64 - self.count_ones()
57 }
58
59 pub fn push(&mut self, b: bool) {
60 let d = self.size & 63;
61 if d == 0 {
62 self.bits.push(b as u64);
63 } else {
64 *self.bits.last_mut().unwrap() |= (b as u64) << d;
65 }
66 self.size += 1;
67 }
68
69 fn trim(&mut self) {
70 if self.size & 63 != 0 {
71 if let Some(x) = self.bits.last_mut() {
72 *x &= 0xffff_ffff_ffff_ffff >> (64 - (self.size & 63));
73 }
74 }
75 }
76
77 pub fn shl_bitor_assign(&mut self, rhs: usize) {
78 let n = self.bits.len();
79 let k = rhs >> 6;
80 let d = rhs & 63;
81 if k < n {
82 if d == 0 {
83 for i in (0..n - k).rev() {
84 self.bits[i + k] |= self.bits[i];
85 }
86 } else {
87 for i in (1..n - k).rev() {
88 self.bits[i + k] |= (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
89 }
90 self.bits[k] |= self.bits[0] << d;
91 }
92 self.trim();
93 }
94 }
95
96 pub fn shr_bitor_assign(&mut self, rhs: usize) {
97 let n = self.bits.len();
98 let k = rhs >> 6;
99 let d = rhs & 63;
100 if k < n {
101 if d == 0 {
102 for i in k..n {
103 self.bits[i - k] |= self.bits[i];
104 }
105 } else {
106 for i in k..n - 1 {
107 self.bits[i - k] |= (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
108 }
109 self.bits[n - k - 1] |= self.bits[n - 1] >> d;
110 }
111 }
112 }
113}
114
115impl Extend<bool> for BitSet {
116 fn extend<T: IntoIterator<Item = bool>>(&mut self, iter: T) {
117 let d = self.size & 63;
118 let mut iter = iter.into_iter();
119 let Some(first) = iter.next() else {
120 return;
121 };
122 if d == 0 {
123 self.bits.push(0);
124 }
125 let mut e = self.bits.last_mut().unwrap();
126 *e |= (first as u64) << d;
127 self.size += 1;
128 for b in iter {
129 let d = self.size & 63;
130 if d == 0 {
131 self.bits.push(b as u64);
132 e = self.bits.last_mut().unwrap();
133 } else {
134 *e |= (b as u64) << d;
135 }
136 self.size += 1;
137 }
138 }
139}
140
141impl FromIterator<bool> for BitSet {
142 fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
143 let mut set = BitSet::new(0);
144 set.extend(iter);
145 set
146 }
147}
148
149impl ShlAssign<usize> for BitSet {
150 fn shl_assign(&mut self, rhs: usize) {
151 let n = self.bits.len();
152 let k = rhs >> 6;
153 let d = rhs & 63;
154 if k >= n {
155 for x in self.bits.iter_mut() {
156 *x = 0;
157 }
158 } else {
159 if d == 0 {
160 for i in (0..n - k).rev() {
161 self.bits[i + k] = self.bits[i];
162 }
163 } else {
164 for i in (1..n - k).rev() {
165 self.bits[i + k] = (self.bits[i] << d) | (self.bits[i - 1] >> (64 - d));
166 }
167 self.bits[k] = self.bits[0] << d;
168 }
169 for x in self.bits[..k].iter_mut() {
170 *x = 0;
171 }
172 self.trim();
173 }
174 }
175}
176
177impl Shl<usize> for BitSet {
178 type Output = Self;
179 fn shl(mut self, rhs: usize) -> Self::Output {
180 self <<= rhs;
181 self
182 }
183}
184
185impl ShrAssign<usize> for BitSet {
186 fn shr_assign(&mut self, rhs: usize) {
187 let n = self.bits.len();
188 let k = rhs >> 6;
189 let d = rhs & 63;
190 if k >= n {
191 for x in self.bits.iter_mut() {
192 *x = 0;
193 }
194 } else {
195 if d == 0 {
196 for i in k..n {
197 self.bits[i - k] = self.bits[i];
198 }
199 } else {
200 for i in k..n - 1 {
201 self.bits[i - k] = (self.bits[i] >> d) | (self.bits[i + 1] << (64 - d));
202 }
203 self.bits[n - k - 1] = self.bits[n - 1] >> d;
204 }
205 for x in self.bits[n - k..].iter_mut() {
206 *x = 0;
207 }
208 }
209 }
210}
211
212impl Shr<usize> for BitSet {
213 type Output = Self;
214 fn shr(mut self, rhs: usize) -> Self::Output {
215 self >>= rhs;
216 self
217 }
218}
219
220impl<'a> BitOrAssign<&'a BitSet> for BitSet {
221 fn bitor_assign(&mut self, rhs: &'a Self) {
222 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
223 *l |= *r;
224 }
225 self.trim();
226 }
227}
228
229impl<'a> BitOr<&'a BitSet> for BitSet {
230 type Output = Self;
231 fn bitor(mut self, rhs: &'a Self) -> Self::Output {
232 self |= rhs;
233 self
234 }
235}
236
237impl<'b> BitOr<&'b BitSet> for &BitSet {
238 type Output = BitSet;
239 fn bitor(self, rhs: &'b BitSet) -> Self::Output {
240 let mut res = self.clone();
241 res |= rhs;
242 res
243 }
244}
245
246impl<'a> BitAndAssign<&'a BitSet> for BitSet {
247 fn bitand_assign(&mut self, rhs: &'a Self) {
248 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
249 *l &= *r;
250 }
251 }
252}
253
254impl<'a> BitAnd<&'a BitSet> for BitSet {
255 type Output = Self;
256 fn bitand(mut self, rhs: &'a Self) -> Self::Output {
257 self &= rhs;
258 self
259 }
260}
261
262impl<'b> BitAnd<&'b BitSet> for &BitSet {
263 type Output = BitSet;
264 fn bitand(self, rhs: &'b BitSet) -> Self::Output {
265 let mut res = self.clone();
266 res &= rhs;
267 res
268 }
269}
270
271impl<'a> BitXorAssign<&'a BitSet> for BitSet {
272 fn bitxor_assign(&mut self, rhs: &'a Self) {
273 for (l, r) in self.bits.iter_mut().zip(rhs.bits.iter()) {
274 *l ^= *r;
275 }
276 self.trim();
277 }
278}
279
280impl<'a> BitXor<&'a BitSet> for BitSet {
281 type Output = Self;
282 fn bitxor(mut self, rhs: &'a Self) -> Self::Output {
283 self ^= rhs;
284 self
285 }
286}
287
288impl<'b> BitXor<&'b BitSet> for &BitSet {
289 type Output = BitSet;
290 fn bitxor(self, rhs: &'b BitSet) -> Self::Output {
291 let mut res = self.clone();
292 res ^= rhs;
293 res
294 }
295}
296
297impl Not for BitSet {
298 type Output = Self;
299 fn not(mut self) -> Self::Output {
300 for x in self.bits.iter_mut() {
301 *x = !*x;
302 }
303 self.trim();
304 self
305 }
306}
307
308impl Not for &BitSet {
309 type Output = BitSet;
310 fn not(self) -> Self::Output {
311 !self.clone()
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{rand, tools::Xorshift};

    // Each test creates its generator once, outside the loop. Re-creating it per iteration
    // would replay the same case on every iteration if `Xorshift::new()` uses a fixed seed
    // (NOTE(review): the seeding of `Xorshift::new` is not visible here).

    #[test]
    fn test_access() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 1..=200);
            let mut bitset = BitSet::new(n);
            let mut arr = vec![false; n];
            for _ in 0..200 {
                rand!(rng, i: 0..n, b: 0..=1u32);
                bitset.set(i, b != 0);
                arr[i] = b != 0;
                assert_eq!(bitset.get(i), arr[i]);
            }
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x);
            }
            assert_eq!(
                bitset.count_ones(),
                arr.iter().filter(|&&x| x).count() as u64
            );
            assert_eq!(
                bitset.count_zeros(),
                arr.iter().filter(|&&x| !x).count() as u64
            );
        }
    }

    #[test]
    fn test_push() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 0..=200, arr: [0..=1u32; n]);
            let mut bitset = BitSet::new(0);
            for &x in &arr {
                bitset.push(x != 0);
            }
            assert_eq!(bitset.len(), n);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }

    #[test]
    fn test_shl_bitor_assign() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shl_bitor_assign(k);
            // Descending order so arr[i - k] still holds its original value.
            for i in (k..n).rev() {
                arr[i] |= arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shr_bitor_assign() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, mut arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset.shr_bitor_assign(k);
            // Ascending order so arr[i] still holds its original value.
            for i in k..n {
                arr[i - k] |= arr[i];
            }
            assert_eq!(bitset, BitSet::from_iter(arr.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shl() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset <<= k;
            let mut arr2 = vec![0; n];
            for i in (k..n).rev() {
                arr2[i] = arr[i - k];
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_shr() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 1..=200, k: 1..=n, arr: [0..=1u32; n]);
            let mut bitset: BitSet = arr.iter().map(|&x| x != 0).collect();
            bitset >>= k;
            let mut arr2 = vec![0; n];
            for (i, &a) in arr.iter().enumerate().skip(k) {
                arr2[i - k] = a;
            }
            assert_eq!(bitset, BitSet::from_iter(arr2.iter().map(|&x| x != 0)));
        }
    }

    #[test]
    fn test_extend() {
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, arr: [0..=1u32; 200], n1: 0..=200);
            let mut bitset: BitSet = arr[..n1].iter().map(|&x| x != 0).collect();
            bitset.extend(arr[n1..].iter().map(|&x| x != 0));
            assert_eq!(bitset.len(), 200);
            for (i, &x) in arr.iter().enumerate() {
                assert_eq!(bitset.get(i), x != 0);
            }
        }
    }

    #[test]
    fn test_bit_ops() {
        // Reference implementation: combine two 0/1 arrays element-wise.
        fn zip_with(a: &[u32], b: &[u32], f: impl Fn(bool, bool) -> bool) -> BitSet {
            a.iter().zip(b).map(|(&x, &y)| f(x != 0, y != 0)).collect()
        }
        let mut rng = Xorshift::new();
        for _ in 0..100 {
            rand!(rng, n: 0..=200, a: [0..=1u32; n], b: [0..=1u32; n]);
            let x: BitSet = a.iter().map(|&v| v != 0).collect();
            let y: BitSet = b.iter().map(|&v| v != 0).collect();
            assert_eq!(&x & &y, zip_with(&a, &b, |p, q| p & q));
            assert_eq!(&x | &y, zip_with(&a, &b, |p, q| p | q));
            assert_eq!(&x ^ &y, zip_with(&a, &b, |p, q| p ^ q));
            assert_eq!(!&x, a.iter().map(|&v| v == 0).collect::<BitSet>());
            assert_eq!(BitSet::ones(n), !BitSet::new(n));
        }
    }
}